Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy

import torch

import treetensor.torch as ttorch

# T: leading dimension of the 'a' field generated by get_data().
# B: number of samples in the batch built by the script entry point.
T, B = 3, 4


def with_nativetensor(batch_):
    """Apply the per-field operations using only the native torch API.

    Each element of ``batch_`` is treated as a dict of fields:

    * ``'a'``: cast to float and keep every second entry along dim 0.
    * ``'b'``: cast to float, compute ``b ** 2 + 1`` and record its mean.
    * ``'c'``: a nested dict; ``'d'`` is cast to float (the result is not
      stored back), any other nested key is reported via ``print``.
    * anything else is reported via ``print``.

    The float casts and the transformed ``'b'`` are never written back into
    the samples, so the returned ``batch_`` is the same object that was
    passed in; only the optional noise step below mutates it.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)`` where ``mean_b`` is the
        average of the per-sample means of the transformed ``'b'`` and
        ``even_index_a`` stacks the strided ``'a'`` tensors along a new
        leading dimension.
    """
    per_sample_b_means = []
    strided_a = []
    for sample in batch_:
        for key, value in sample.items():
            if key == 'a':
                strided_a.append(value.float()[::2])
            elif key == 'b':
                shifted_square = torch.pow(value.float(), 2) + 1.0
                per_sample_b_means.append(shifted_square.mean())
            elif key == 'c':
                for sub_key, sub_value in value.items():
                    if sub_key != 'd':
                        print('ignore keys: {}'.format(sub_key))
                        continue
                    # The cast result is intentionally not stored, matching
                    # the existing behavior of this function.
                    sub_value.float()
            else:
                print('ignore keys: {}'.format(key))

    # NOTE(review): noise is attached only when a sample has a *top-level*
    # 'd' entry. Samples shaped like get_data() keep 'd' under 'c', so this
    # step is a no-op for them, whereas with_treetensor() writes noise under
    # 'c'. Confirm which behavior is intended before changing it.
    for sample in batch_:
        if 'd' in sample:
            sample['d']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(per_sample_b_means) / len(per_sample_b_means)
    even_index_a = torch.stack(strided_a, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the per-field operations through the treetensor API.

    The samples are converted to tree tensors, stacked into one batched
    tree, cast to float, and then processed field by field:

    * ``b`` is replaced by ``b ** 2 + 1``.
    * ``c.noise`` is filled with random normal values whose leading
      dimension equals the number of samples.
    * the mean of ``b`` and the ``[:, ::2]`` slice of ``a`` are extracted.

    Finally the batched tree is split back along dim 0 into chunks of
    size 1.

    Args:
        batch_: iterable of nested dicts of tensors, one per sample.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)``: the split tree tensors,
        the mean of the transformed ``b`` field, and the strided ``a`` field.
    """
    trees = [ttorch.tensor(b) for b in batch_]
    # Derive the batch size from the input itself instead of the module-level
    # constant ``B`` so the noise always lines up with the stacked batch,
    # whatever the number of samples passed in.
    batch_size = len(trees)
    batch_ = ttorch.stack(trees)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample as a nested dict of tensors.

    Returns:
        A dict with:

        * ``'a'``: uniform random tensor of shape ``(T, 8)``.
        * ``'b'``: uniform random tensor of shape ``(6,)``.
        * ``'c'``: nested dict whose ``'d'`` entry is a single random
          integer in ``[0, 10)`` stored as a tensor of shape ``(1,)``.
    """
    # Fields are created in the same order as the keys appear so the random
    # number generator is consumed in a fixed sequence.
    sample = {}
    sample['a'] = torch.rand(size=(T, 8))
    sample['b'] = torch.rand(size=(6,))
    sample['c'] = {'d': torch.randint(0, 10, size=(1,))}
    return sample


if __name__ == "__main__":
    # Build B samples and hand each implementation its own deep copy so
    # neither call can observe mutations made by the other.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    # Both implementations must agree on the mean of the transformed 'b'.
    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare using the largest *absolute* element-wise difference. Taking
    # abs() of the signed maximum would miss purely negative deviations
    # whenever at least one element matches exactly.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    # Each implementation returns one entry per input sample.
    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following, and all the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
[{'a': tensor([[0.1760, 0.9876, 0.7666, 0.6312, 0.1349, 0.4530, 0.2687, 0.9414],
        [0.8164, 0.5808, 0.3442, 0.3910, 0.5593, 0.3062, 0.1630, 0.2680],
        [0.8270, 0.5847, 0.2073, 0.1759, 0.1533, 0.9498, 0.6257, 0.6740]]), 'b': tensor([0.5486, 0.2000, 0.2006, 0.3559, 0.4115, 0.9562]), 'c': {'d': tensor([9])}}, {'a': tensor([[0.6299, 0.4004, 0.4542, 0.7033, 0.5401, 0.3561, 0.3821, 0.1780],
        [0.3425, 0.5969, 0.5505, 0.6960, 0.4339, 0.9264, 0.5039, 0.7249],
        [0.0891, 0.3546, 0.2275, 0.9833, 0.5123, 0.2282, 0.7045, 0.4442]]), 'b': tensor([0.4095, 0.3095, 0.8428, 0.8548, 0.7888, 0.5654]), 'c': {'d': tensor([2])}}, {'a': tensor([[0.7502, 0.8676, 0.9999, 0.7419, 0.8389, 0.0293, 0.1430, 0.7604],
        [0.6477, 0.8092, 0.2047, 0.1347, 0.1611, 0.8959, 0.7691, 0.3033],
        [0.8967, 0.5410, 0.6066, 0.7471, 0.4575, 0.2954, 0.4237, 0.7751]]), 'b': tensor([0.1823, 0.2579, 0.7932, 0.9217, 0.2090, 0.0825]), 'c': {'d': tensor([0])}}, {'a': tensor([[0.1884, 0.1080, 0.1079, 0.4079, 0.0615, 0.3103, 0.5896, 0.2523],
        [0.1338, 0.0235, 0.6606, 0.3652, 0.8009, 0.3431, 0.9553, 0.6687],
        [0.3617, 0.2567, 0.5046, 0.0802, 0.5889, 0.4226, 0.0626, 0.3299]]), 'b': tensor([0.5523, 0.4252, 0.6987, 0.2098, 0.3236, 0.6150]), 'c': {'d': tensor([0])}}]



(<Tensor 0x7fc8ed333580>
├── 'a' --> tensor([[[0.1760, 0.9876, 0.7666, 0.6312, 0.1349, 0.4530, 0.2687, 0.9414],
│                    [0.8164, 0.5808, 0.3442, 0.3910, 0.5593, 0.3062, 0.1630, 0.2680],
│                    [0.8270, 0.5847, 0.2073, 0.1759, 0.1533, 0.9498, 0.6257, 0.6740]]])
├── 'b' --> tensor([[1.3009, 1.0400, 1.0402, 1.1267, 1.1693, 1.9144]])
└── 'c' --> <Tensor 0x7fc8ed3335e0>
    ├── 'd' --> tensor([[9.]])
    └── 'noise' --> tensor([[[[ 0.2810,  1.4643, -0.9719,  0.3324,  0.1614],
                              [-2.4645,  0.0850, -0.5767, -1.1389,  0.4888],
                              [-2.3588, -0.4756, -1.1212,  0.8527, -0.3302],
                              [ 0.7773, -1.2100,  1.1747,  0.9436,  1.3061]],
                    
                             [[-0.5174,  0.4250, -1.0929,  0.5150,  0.0363],
                              [ 0.2181, -0.8368,  0.6428, -0.7121, -0.5522],
                              [-0.1242,  0.8859,  1.4753,  0.6179,  0.0561],
                              [ 0.0670, -1.6503,  0.4668, -0.1533, -1.2691]],
                    
                             [[-0.5205, -0.5433,  0.5092, -0.3479,  0.9470],
                              [-0.2283, -0.4997,  0.4974, -1.4480, -2.0320],
                              [-0.9231,  0.3258, -0.4301,  1.3455, -0.5889],
                              [-1.1355,  0.8925, -1.5451, -0.2237, -0.6219]]]])
, <Tensor 0x7fc8ed3336a0>
├── 'a' --> tensor([[[0.6299, 0.4004, 0.4542, 0.7033, 0.5401, 0.3561, 0.3821, 0.1780],
│                    [0.3425, 0.5969, 0.5505, 0.6960, 0.4339, 0.9264, 0.5039, 0.7249],
│                    [0.0891, 0.3546, 0.2275, 0.9833, 0.5123, 0.2282, 0.7045, 0.4442]]])
├── 'b' --> tensor([[1.1677, 1.0958, 1.7103, 1.7307, 1.6222, 1.3197]])
└── 'c' --> <Tensor 0x7fc8ed3334c0>
    ├── 'd' --> tensor([[2.]])
    └── 'noise' --> tensor([[[[-1.1136, -1.0463,  0.2548,  0.1471, -0.3970],
                              [ 1.8490,  1.5962, -0.6584, -1.9240,  0.1384],
                              [ 0.0654,  0.2813, -1.2626, -0.1989, -1.7316],
                              [-1.2335, -1.1391, -0.2785, -0.6234, -1.7957]],
                    
                             [[-0.2992,  2.4934,  0.1978,  1.3581, -0.0760],
                              [ 1.6081,  0.7025,  1.6493,  0.4080, -0.4362],
                              [-0.1329,  0.2666,  1.8063, -1.3381, -0.5317],
                              [-1.6255,  0.5464,  1.0031, -0.6533,  0.1872]],
                    
                             [[-0.2006,  0.1562,  0.3206, -1.1773,  0.5603],
                              [-1.0208, -2.2813,  1.7826, -0.8397,  0.1486],
                              [-1.3046, -1.7730,  2.0844, -1.9943, -0.2563],
                              [ 0.1521, -0.7617,  0.9459,  0.2308,  0.7644]]]])
, <Tensor 0x7fc8ed333700>
├── 'a' --> tensor([[[0.7502, 0.8676, 0.9999, 0.7419, 0.8389, 0.0293, 0.1430, 0.7604],
│                    [0.6477, 0.8092, 0.2047, 0.1347, 0.1611, 0.8959, 0.7691, 0.3033],
│                    [0.8967, 0.5410, 0.6066, 0.7471, 0.4575, 0.2954, 0.4237, 0.7751]]])
├── 'b' --> tensor([[1.0332, 1.0665, 1.6292, 1.8496, 1.0437, 1.0068]])
└── 'c' --> <Tensor 0x7fc8ed3336d0>
    ├── 'd' --> tensor([[0.]])
    └── 'noise' --> tensor([[[[ 0.1881,  1.3646, -1.3650,  1.7777,  1.1769],
                              [-0.2026, -0.4343, -2.5894,  0.8237, -0.1804],
                              [ 0.7298,  1.2785, -1.0835, -1.1562, -2.3360],
                              [ 2.3419,  0.3254, -0.9012, -1.5318,  0.3249]],
                    
                             [[-0.9911, -1.3211,  0.7636,  1.3447, -0.9473],
                              [ 0.4414,  0.3585, -1.6078, -0.0313,  0.8511],
                              [ 2.0062,  0.8407,  1.4273,  1.4176, -0.8433],
                              [ 1.6407, -0.5180,  0.6023, -0.3928, -2.2764]],
                    
                             [[ 2.6079, -1.3053, -1.6408, -0.7152,  1.5177],
                              [ 1.1402, -1.3566, -1.5336,  1.6338,  0.0867],
                              [-0.3611,  0.3741,  0.6614,  1.9662,  0.3121],
                              [ 0.1861, -0.1407,  1.1927, -0.6658,  1.1299]]]])
, <Tensor 0x7fc8ed333760>
├── 'a' --> tensor([[[0.1884, 0.1080, 0.1079, 0.4079, 0.0615, 0.3103, 0.5896, 0.2523],
│                    [0.1338, 0.0235, 0.6606, 0.3652, 0.8009, 0.3431, 0.9553, 0.6687],
│                    [0.3617, 0.2567, 0.5046, 0.0802, 0.5889, 0.4226, 0.0626, 0.3299]]])
├── 'b' --> tensor([[1.3050, 1.1808, 1.4882, 1.0440, 1.1047, 1.3782]])
└── 'c' --> <Tensor 0x7fc8ed333730>
    ├── 'd' --> tensor([[0.]])
    └── 'noise' --> tensor([[[[-2.4965,  0.3270,  1.2949, -0.1508,  0.5471],
                              [ 1.1143,  0.7643, -0.5936, -0.2020,  0.8068],
                              [ 0.3543,  1.0698, -0.0111, -0.0924, -1.7070],
                              [-0.0966,  0.0233,  0.3531, -1.0438, -0.8790]],
                    
                             [[ 1.8667,  0.2731,  2.7688, -0.8521, -0.7153],
                              [-0.1571,  1.7009,  1.0061,  0.0157, -1.3572],
                              [-0.1685,  1.5744, -0.3597, -0.3344,  0.2629],
                              [ 0.3098,  1.9305, -0.2067, -1.2160,  0.2546]],
                    
                             [[ 0.3821, -1.5974, -1.7580, -0.5292,  1.7847],
                              [-2.1281,  0.2113,  1.0670, -0.8068, -0.9332],
                              [-1.3835, -0.5270, -1.7189, -0.9176, -0.5866],
                              [-0.9154, -1.4674,  2.5163,  0.1607, -1.0668]]]])
)
mean0 & mean1: tensor(1.3070) tensor(1.3070)


even_index_a0: tensor([[[0.1760, 0.9876, 0.7666, 0.6312, 0.1349, 0.4530, 0.2687, 0.9414],
         [0.8270, 0.5847, 0.2073, 0.1759, 0.1533, 0.9498, 0.6257, 0.6740]],

        [[0.6299, 0.4004, 0.4542, 0.7033, 0.5401, 0.3561, 0.3821, 0.1780],
         [0.0891, 0.3546, 0.2275, 0.9833, 0.5123, 0.2282, 0.7045, 0.4442]],

        [[0.7502, 0.8676, 0.9999, 0.7419, 0.8389, 0.0293, 0.1430, 0.7604],
         [0.8967, 0.5410, 0.6066, 0.7471, 0.4575, 0.2954, 0.4237, 0.7751]],

        [[0.1884, 0.1080, 0.1079, 0.4079, 0.0615, 0.3103, 0.5896, 0.2523],
         [0.3617, 0.2567, 0.5046, 0.0802, 0.5889, 0.4226, 0.0626, 0.3299]]])
even_index_a1: tensor([[[0.1760, 0.9876, 0.7666, 0.6312, 0.1349, 0.4530, 0.2687, 0.9414],
         [0.8270, 0.5847, 0.2073, 0.1759, 0.1533, 0.9498, 0.6257, 0.6740]],

        [[0.6299, 0.4004, 0.4542, 0.7033, 0.5401, 0.3561, 0.3821, 0.1780],
         [0.0891, 0.3546, 0.2275, 0.9833, 0.5123, 0.2282, 0.7045, 0.4442]],

        [[0.7502, 0.8676, 0.9999, 0.7419, 0.8389, 0.0293, 0.1430, 0.7604],
         [0.8967, 0.5410, 0.6066, 0.7471, 0.4575, 0.2954, 0.4237, 0.7751]],

        [[0.1884, 0.1080, 0.1079, 0.4079, 0.0615, 0.3103, 0.5896, 0.2523],
         [0.3617, 0.2567, 0.5046, 0.0802, 0.5889, 0.4226, 0.0626, 0.3299]]])
<Size 0x7fc95322e430>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7fc8ed3ad940>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.