Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy

import torch

import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Process a batch of nested dicts of tensors using plain torch operations.

    This is the native-torch counterpart of ``with_treetensor``. For every
    sample dict in ``batch_``:

    * ``'a'`` is cast to float and its even-indexed rows (``[::2]``) are
      collected;
    * ``'b'`` is cast to float and replaced by ``b ** 2 + 1``, whose mean is
      collected;
    * ``'c'['d']`` is cast to float, and a ``'noise'`` tensor of shape
      ``(3, 4, 5)`` drawn from ``torch.randn`` is added under ``'c'``;
    * any other key is reported with a ``print`` and left untouched.

    The sample dicts are modified in place.

    Args:
        batch_: list of sample dicts shaped like those returned by
            ``get_data()``.

    Returns:
        tuple ``(batch_, mean_b, even_index_a)``:

        * ``batch_`` is the same (now modified) list;
        * ``mean_b`` is the average of the per-sample means of the
          transformed ``'b'``;
        * ``even_index_a`` stacks the even rows of every ``'a'`` along a new
          leading dim 0. This requires equal shapes across samples.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                # Write the converted value back so the returned batch
                # reflects it. Replacing an existing key while iterating is safe.
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                v = v.float()
                transformed_v = torch.pow(v, 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
                # Add the noise after iterating 'c'. Inserting a key into a
                # dict while iterating over it raises RuntimeError.
                v['noise'] = torch.randn(size=(3, 4, 5))
            else:
                print('ignore keys: {}'.format(k))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Process a batch of nested dicts of tensors using the treetensor API.

    The samples are converted to tree tensors and stacked along a new leading
    batch dimension. Then:

    * every leaf is cast to float;
    * ``'b'`` is replaced by ``b ** 2 + 1``;
    * a random ``'noise'`` leaf is added under ``'c'``. Per sample it has
      shape ``(3, 4, 5)``.

    Args:
        batch_: list of sample dicts shaped like those returned by
            ``get_data()``.

    Returns:
        tuple ``(samples, mean_b, even_index_a)``:

        * ``samples`` is the processed tree split back into chunks of size 1
          along dim 0, so each leaf keeps a leading dim of size 1;
        * ``mean_b`` is the mean of the transformed ``'b'`` over the whole
          batch;
        * ``even_index_a`` is ``'a'`` restricted to its even-indexed rows.
    """
    # Size the noise from the actual input rather than the module-level B,
    # so batches of any length work.
    batch_size = len(batch_)
    batch_ = ttorch.stack([ttorch.tensor(b) for b in batch_])
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample as a nested dict of tensors.

    Returns:
        dict with:

        * ``'a'``: float tensor of shape ``(T, 8)``, uniform in ``[0, 1)``;
        * ``'b'``: float tensor of shape ``(6,)``, uniform in ``[0, 1)``;
        * ``'c'``: nested dict whose ``'d'`` is an integer tensor of shape
          ``(1,)`` with values in ``[0, 10)``.
    """
    # Fill the keys in the same order as they are drawn, so the random
    # sequence and key order stay a -> b -> c.d.
    sample = {}
    sample['a'] = torch.rand(size=(T, 8))
    sample['b'] = torch.rand(size=(6,))
    sample['c'] = {'d': torch.randint(0, 10, size=(1,))}
    return sample


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Give each implementation its own deep copy so the two runs cannot
    # affect each other's input.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take the absolute value element-wise BEFORE the max. abs(max(diff))
    # would let arbitrarily large negative differences pass.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following, and all the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
[{'a': tensor([[0.0686, 0.1086, 0.1260, 0.0717, 0.1534, 0.4248, 0.8107, 0.9792],
        [0.8358, 0.7622, 0.0328, 0.2995, 0.4200, 0.4186, 0.7946, 0.6256],
        [0.5847, 0.1426, 0.1226, 0.4180, 0.4402, 0.9689, 0.9807, 0.1741]]), 'b': tensor([0.2408, 0.0534, 0.2229, 0.3362, 0.0190, 0.5164]), 'c': {'d': tensor([0])}}, {'a': tensor([[0.5001, 0.0582, 0.2682, 0.9194, 0.8338, 0.3869, 0.4686, 0.1196],
        [0.8403, 0.9976, 0.7340, 0.1658, 0.1127, 0.6743, 0.1626, 0.5464],
        [0.0086, 0.5789, 0.5836, 0.5645, 0.2969, 0.5460, 0.3157, 0.0195]]), 'b': tensor([0.5788, 0.7546, 0.1636, 0.9155, 0.6175, 0.2375]), 'c': {'d': tensor([8])}}, {'a': tensor([[0.8284, 0.8438, 0.7916, 0.1016, 0.2882, 0.1145, 0.6095, 0.0657],
        [0.4021, 0.4337, 0.7934, 0.3543, 0.9150, 0.0930, 0.7815, 0.1391],
        [0.0771, 0.6045, 0.6823, 0.1701, 0.2156, 0.5855, 0.6355, 0.9394]]), 'b': tensor([0.9815, 0.2230, 0.6183, 0.0756, 0.1584, 0.6588]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.0669, 0.6374, 0.1598, 0.7754, 0.2292, 0.0127, 0.4531, 0.9546],
        [0.1285, 0.1093, 0.6841, 0.0386, 0.5381, 0.5292, 0.2525, 0.4983],
        [0.4484, 0.7647, 0.6624, 0.1767, 0.2220, 0.4530, 0.6271, 0.0049]]), 'b': tensor([0.0129, 0.8459, 0.0295, 0.2026, 0.2082, 0.3370]), 'c': {'d': tensor([4])}}]



(<Tensor 0x7fea7bb5cac0>
├── 'a' --> tensor([[[0.0686, 0.1086, 0.1260, 0.0717, 0.1534, 0.4248, 0.8107, 0.9792],
│                    [0.8358, 0.7622, 0.0328, 0.2995, 0.4200, 0.4186, 0.7946, 0.6256],
│                    [0.5847, 0.1426, 0.1226, 0.4180, 0.4402, 0.9689, 0.9807, 0.1741]]])
├── 'b' --> tensor([[1.0580, 1.0029, 1.0497, 1.1130, 1.0004, 1.2667]])
└── 'c' --> <Tensor 0x7fea7bb5cb20>
    ├── 'd' --> tensor([[0.]])
    └── 'noise' --> tensor([[[[ 1.4809, -0.8466,  1.0720,  0.7674,  1.2988],
                              [-0.1379, -1.1622, -0.4715,  0.2469,  0.2314],
                              [-0.4344, -2.3258, -0.1021,  0.5908, -0.0135],
                              [ 0.2125, -0.7161, -0.6925, -0.8530, -0.3846]],
                    
                             [[ 0.6735, -0.9661,  0.5631,  0.9926, -3.1729],
                              [-0.5627,  0.5281,  1.4474, -1.0707,  1.2655],
                              [ 2.8633, -1.2859,  0.3502,  0.0993,  1.5777],
                              [ 1.0557, -0.0891,  0.8168, -1.4842,  0.1666]],
                    
                             [[ 0.3919, -0.0320,  1.8884, -0.6711, -1.8297],
                              [ 0.1443,  1.5710, -0.0992, -0.6423,  0.6008],
                              [ 0.7714,  0.4063,  0.8384, -1.2446,  0.7972],
                              [-0.6708, -0.1289, -0.4705, -1.0139,  1.0656]]]])
, <Tensor 0x7fea7bb5cbe0>
├── 'a' --> tensor([[[0.5001, 0.0582, 0.2682, 0.9194, 0.8338, 0.3869, 0.4686, 0.1196],
│                    [0.8403, 0.9976, 0.7340, 0.1658, 0.1127, 0.6743, 0.1626, 0.5464],
│                    [0.0086, 0.5789, 0.5836, 0.5645, 0.2969, 0.5460, 0.3157, 0.0195]]])
├── 'b' --> tensor([[1.3350, 1.5694, 1.0268, 1.8381, 1.3813, 1.0564]])
└── 'c' --> <Tensor 0x7fea7bb5ca30>
    ├── 'd' --> tensor([[8.]])
    └── 'noise' --> tensor([[[[-0.4667,  0.2000,  1.2140,  0.4890,  0.5628],
                              [ 0.3238,  1.4185, -0.4402, -0.5429,  0.5033],
                              [ 0.5419, -0.5719,  1.0255,  0.6894, -0.3771],
                              [-0.7727, -2.1226,  0.0043,  0.2789, -0.8093]],
                    
                             [[-0.8213, -1.9765,  0.6290,  0.7305, -0.6955],
                              [-2.0336, -0.3545, -0.6026,  0.6712, -1.2341],
                              [ 0.9422, -1.2487,  1.7422, -0.4708,  0.5907],
                              [ 0.1950, -1.3476, -0.0215,  0.6312, -0.7755]],
                    
                             [[-1.0473, -0.0140,  0.1878,  0.0064,  1.8081],
                              [-2.4747, -0.3494,  0.0361, -0.6426,  0.4053],
                              [-0.2242,  0.2586,  0.4661, -1.7540, -1.9929],
                              [-0.8916,  0.7854,  1.5631,  0.2244, -0.4089]]]])
, <Tensor 0x7fea7bb5cc40>
├── 'a' --> tensor([[[0.8284, 0.8438, 0.7916, 0.1016, 0.2882, 0.1145, 0.6095, 0.0657],
│                    [0.4021, 0.4337, 0.7934, 0.3543, 0.9150, 0.0930, 0.7815, 0.1391],
│                    [0.0771, 0.6045, 0.6823, 0.1701, 0.2156, 0.5855, 0.6355, 0.9394]]])
├── 'b' --> tensor([[1.9633, 1.0497, 1.3823, 1.0057, 1.0251, 1.4340]])
└── 'c' --> <Tensor 0x7fea7bb5cc10>
    ├── 'd' --> tensor([[1.]])
    └── 'noise' --> tensor([[[[-0.8266, -0.2700, -0.5078,  0.6761,  0.2158],
                              [ 0.6552, -0.5522, -0.0266,  1.0813, -1.5548],
                              [ 2.3831, -0.8522,  0.6716, -0.2994,  0.6217],
                              [-0.0827,  1.1236,  0.5660, -0.1817,  1.2789]],
                    
                             [[-1.3301,  1.3541,  0.5394,  0.1224,  0.3887],
                              [-1.5058, -1.1427,  1.2118,  1.2243,  0.2612],
                              [ 0.4264, -1.6123, -0.9390, -1.1020, -0.8110],
                              [-0.0263,  0.1119,  0.5974,  0.2575, -0.8753]],
                    
                             [[ 1.7161,  1.6931, -2.6318, -0.2364,  0.1694],
                              [-0.7443,  0.9577,  1.1957, -0.5674, -0.4003],
                              [-1.4331,  0.6762,  1.2686,  0.6880, -0.4489],
                              [-0.9069, -0.4411,  0.9527,  1.6256, -1.3127]]]])
, <Tensor 0x7fea7bb5cca0>
├── 'a' --> tensor([[[0.0669, 0.6374, 0.1598, 0.7754, 0.2292, 0.0127, 0.4531, 0.9546],
│                    [0.1285, 0.1093, 0.6841, 0.0386, 0.5381, 0.5292, 0.2525, 0.4983],
│                    [0.4484, 0.7647, 0.6624, 0.1767, 0.2220, 0.4530, 0.6271, 0.0049]]])
├── 'b' --> tensor([[1.0002, 1.7155, 1.0009, 1.0410, 1.0434, 1.1136]])
└── 'c' --> <Tensor 0x7fea7bb5cc70>
    ├── 'd' --> tensor([[4.]])
    └── 'noise' --> tensor([[[[-0.9848, -0.8801, -2.0229,  0.2040,  1.0911],
                              [ 0.1213,  3.1088, -1.0881,  2.1307, -0.0602],
                              [-0.5257, -0.5790,  0.5905,  0.6199,  0.0638],
                              [-1.9652, -1.0409,  0.6534,  2.0500,  1.4129]],
                    
                             [[-0.9417, -1.2216, -0.6902,  0.6150,  0.1717],
                              [ 0.5321, -0.1839, -1.9115,  1.6711, -0.6858],
                              [-2.2324, -0.3537,  1.4024,  0.8769, -1.0161],
                              [-0.6217, -0.3269,  0.1266, -0.5366, -0.3699]],
                    
                             [[ 0.0931,  0.4600, -0.0660, -0.3847,  0.5237],
                              [ 0.7623,  0.6981,  1.2001,  1.2765,  0.6914],
                              [ 0.9582, -0.2392, -1.2116,  0.0220, -1.5339],
                              [-1.7506,  0.0362,  0.5130,  0.5271, -0.0502]]]])
)
mean0 & mean1: tensor(1.2280) tensor(1.2280)


even_index_a0: tensor([[[0.0686, 0.1086, 0.1260, 0.0717, 0.1534, 0.4248, 0.8107, 0.9792],
         [0.5847, 0.1426, 0.1226, 0.4180, 0.4402, 0.9689, 0.9807, 0.1741]],

        [[0.5001, 0.0582, 0.2682, 0.9194, 0.8338, 0.3869, 0.4686, 0.1196],
         [0.0086, 0.5789, 0.5836, 0.5645, 0.2969, 0.5460, 0.3157, 0.0195]],

        [[0.8284, 0.8438, 0.7916, 0.1016, 0.2882, 0.1145, 0.6095, 0.0657],
         [0.0771, 0.6045, 0.6823, 0.1701, 0.2156, 0.5855, 0.6355, 0.9394]],

        [[0.0669, 0.6374, 0.1598, 0.7754, 0.2292, 0.0127, 0.4531, 0.9546],
         [0.4484, 0.7647, 0.6624, 0.1767, 0.2220, 0.4530, 0.6271, 0.0049]]])
even_index_a1: tensor([[[0.0686, 0.1086, 0.1260, 0.0717, 0.1534, 0.4248, 0.8107, 0.9792],
         [0.5847, 0.1426, 0.1226, 0.4180, 0.4402, 0.9689, 0.9807, 0.1741]],

        [[0.5001, 0.0582, 0.2682, 0.9194, 0.8338, 0.3869, 0.4686, 0.1196],
         [0.0086, 0.5789, 0.5836, 0.5645, 0.2969, 0.5460, 0.3157, 0.0195]],

        [[0.8284, 0.8438, 0.7916, 0.1016, 0.2882, 0.1145, 0.6095, 0.0657],
         [0.0771, 0.6045, 0.6823, 0.1701, 0.2156, 0.5855, 0.6355, 0.9394]],

        [[0.0669, 0.6374, 0.1598, 0.7754, 0.2292, 0.0127, 0.4531, 0.9546],
         [0.4484, 0.7647, 0.6624, 0.1767, 0.2220, 0.4530, 0.6271, 0.0049]]])
<Size 0x7feadba56460>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7fea7bbd7460>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.