Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

import copy

import torch

import treetensor.torch as ttorch

# Number of rows in each sample's 'a' tensor (see get_data).
T = 3
# Number of samples in the demo batch built under __main__.
B = 4


def with_nativetensor(batch_):
    """Apply per-field operations to a list of sample dicts with plain torch.

    For every sample dict in ``batch_``:

    * ``'a'`` is cast to float and written back; its even-indexed rows
      (``v[::2]``) are collected.
    * ``'b'`` is cast to float, replaced by ``b ** 2 + 1`` and its mean is
      collected.
    * ``'c'`` is a nested dict: its ``'d'`` entry is cast to float, and a
      ``'noise'`` entry of shape ``(3, 4, 5)`` drawn with ``torch.randn`` is
      added.
    * Any other key (top-level or inside ``'c'``) is reported as ignored.

    This mirrors what ``with_treetensor`` does to each sample. Earlier versions
    of this function discarded the float casts, the transformed ``'b'`` and the
    noise.

    Args:
        batch_: non-empty list of sample dicts. It is modified in place.

    Returns:
        ``(batch_, mean_b, even_index_a)``:

        * ``batch_`` is the same (mutated) list.
        * ``mean_b`` is the average of the per-sample means of the transformed
          ``'b'``. This equals the overall mean only when every ``'b'`` has the
          same length.
        * ``even_index_a`` stacks the collected ``'a'`` rows along a new dim 0.

    Raises:
        ZeroDivisionError: if ``batch_`` is empty (``torch.stack`` would
            also reject the empty list).
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Iterate over a snapshot, because values are reassigned in the loop.
        for k, v in list(sample.items()):
            if k == 'a':
                v = v.float()
                sample[k] = v  # persist the cast; rebinding `v` alone is lost
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in list(v.items()):
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
                # 'noise' belongs inside the nested 'c' dict, matching
                # `batch_.c.noise` in the treetensor version.
                v['noise'] = torch.randn(size=(3, 4, 5))
            else:
                print('ignore keys: {}'.format(k))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field operations as ``with_nativetensor`` via treetensor.

    The steps are:

    1. Each sample dict is converted with ``ttorch.tensor``.
    2. The converted samples are stacked along a new leading dim 0.
    3. Every leaf is cast to float.
    4. ``b`` is replaced by ``b ** 2 + 1``.
    5. ``c.noise`` is filled with ``ttorch.randn`` of shape
       ``(len(batch_), 3, 4, 5)``.

    Args:
        batch_: non-empty list of sample dicts. The list itself is not
            modified.

    Returns:
        ``(batch_, mean_b, even_index_a)``:

        * ``batch_`` is the stacked tree split back into per-sample trees,
          each with a leading dim of size 1.
        * ``mean_b`` is the mean over all stacked ``b`` values.
        * ``even_index_a`` takes the even-indexed rows of each sample's
          ``a``.
    """
    num_samples = len(batch_)
    batch_ = ttorch.stack([ttorch.tensor(sample) for sample in batch_])
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    # One (3, 4, 5) noise tensor per sample. Size the leading dim from the
    # actual batch rather than the module-level B, so any batch size works.
    batch_.c.noise = ttorch.randn(size=(num_samples, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Return one random sample dict.

    The dict has these fields:

    * ``'a'``: uniform float tensor of shape ``(T, 8)``.
    * ``'b'``: uniform float tensor of shape ``(6,)``.
    * ``'c'``: nested dict whose ``'d'`` is a 1-element integer tensor
      drawn from ``[0, 10)``.
    """
    # Draw in the same order as the dict layout so a seeded RNG stays stable.
    field_a = torch.rand(size=(T, 8))
    field_b = torch.rand(size=(6,))
    field_d = torch.randint(0, 10, size=(1,))
    return dict(a=field_a, b=field_b, c=dict(d=field_d))


if __name__ == "__main__":
    # Run both implementations on independent deep copies of one random batch
    # so the two runs cannot interfere with each other.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take the element-wise absolute value *before* reducing. abs(max(diff))
    # would let arbitrarily large negative differences pass the check.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following (the random values differ from run to run), and all the assertions should pass.

[{'a': tensor([[0.2909, 0.8774, 0.5245, 0.2834, 0.5976, 0.6227, 0.4479, 0.5658],
        [0.1502, 0.0066, 0.2932, 0.7347, 0.4004, 0.5038, 0.3965, 0.8740],
        [0.9986, 0.1072, 0.5157, 0.1964, 0.1072, 0.1204, 0.6982, 0.2265]]), 'b': tensor([0.6827, 0.6241, 0.4580, 0.2430, 0.0198, 0.4970]), 'c': {'d': tensor([7])}}, {'a': tensor([[0.2647, 0.5880, 0.1264, 0.1287, 0.3078, 0.3299, 0.1620, 0.6071],
        [0.5385, 0.2389, 0.3826, 0.3370, 0.1770, 0.9450, 0.2802, 0.8462],
        [0.7176, 0.5219, 0.2742, 0.4196, 0.4705, 0.5566, 0.0080, 0.3343]]), 'b': tensor([0.7906, 0.6741, 0.2049, 0.5595, 0.9478, 0.7007]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.8778, 0.1243, 0.9530, 0.1061, 0.8310, 0.0896, 0.8134, 0.3172],
        [0.0549, 0.6482, 0.8368, 0.7538, 0.2837, 0.1982, 0.6072, 0.8325],
        [0.1204, 0.1436, 0.3943, 0.2892, 0.8983, 0.7299, 0.5482, 0.5508]]), 'b': tensor([0.7508, 0.2203, 0.4806, 0.0496, 0.5695, 0.0591]), 'c': {'d': tensor([0])}}, {'a': tensor([[0.0139, 0.5196, 0.4118, 0.8342, 0.2534, 0.9744, 0.8006, 0.8289],
        [0.5989, 0.7918, 0.3672, 0.7529, 0.7780, 0.8163, 0.0953, 0.8385],
        [0.4482, 0.9899, 0.5385, 0.3456, 0.6094, 0.8605, 0.0776, 0.2461]]), 'b': tensor([0.3306, 0.7661, 0.6920, 0.6690, 0.2996, 0.6166]), 'c': {'d': tensor([3])}}]



(<Tensor 0x7fce010dd1f0>
├── 'a' --> tensor([[[0.2909, 0.8774, 0.5245, 0.2834, 0.5976, 0.6227, 0.4479, 0.5658],
│                    [0.1502, 0.0066, 0.2932, 0.7347, 0.4004, 0.5038, 0.3965, 0.8740],
│                    [0.9986, 0.1072, 0.5157, 0.1964, 0.1072, 0.1204, 0.6982, 0.2265]]])
├── 'b' --> tensor([[1.4660, 1.3895, 1.2097, 1.0590, 1.0004, 1.2470]])
└── 'c' --> <Tensor 0x7fce010dd250>
    ├── 'd' --> tensor([[7.]])
    └── 'noise' --> tensor([[[[-0.9183,  0.6631,  1.2223,  0.9913, -0.7930],
                              [-1.7144, -1.5954,  0.3963, -0.0340, -1.7042],
                              [-1.1111,  0.4383, -1.1122,  0.2265,  0.6809],
                              [-1.8948, -1.1574,  0.5751,  0.5739, -0.5946]],
                    
                             [[-0.3911,  0.5059, -1.5592,  0.8105, -0.4599],
                              [ 0.1479,  1.8968,  0.3216,  0.7993,  0.4830],
                              [-0.7883, -1.1408, -0.7055,  0.5743, -0.1041],
                              [ 0.6040, -1.2881,  1.9320,  0.8264,  1.0876]],
                    
                             [[-2.0007, -0.7446, -0.7728, -0.0177, -0.7398],
                              [-0.2534, -0.0327, -0.4782, -3.0775,  2.2030],
                              [-0.4641,  0.5182,  1.5898,  1.0946,  0.2854],
                              [-0.6921,  0.8066,  1.6027, -0.1465,  0.4847]]]])
, <Tensor 0x7fce010dd310>
├── 'a' --> tensor([[[0.2647, 0.5880, 0.1264, 0.1287, 0.3078, 0.3299, 0.1620, 0.6071],
│                    [0.5385, 0.2389, 0.3826, 0.3370, 0.1770, 0.9450, 0.2802, 0.8462],
│                    [0.7176, 0.5219, 0.2742, 0.4196, 0.4705, 0.5566, 0.0080, 0.3343]]])
├── 'b' --> tensor([[1.6251, 1.4544, 1.0420, 1.3130, 1.8983, 1.4910]])
└── 'c' --> <Tensor 0x7fce010dd190>
    ├── 'd' --> tensor([[3.]])
    └── 'noise' --> tensor([[[[ 2.1365,  0.2016,  1.4825, -0.3637, -0.0737],
                              [-1.7718, -0.8321, -1.6644,  0.1703,  1.6251],
                              [ 0.6342, -2.0468, -0.6453,  0.6979,  0.6614],
                              [-0.2988, -0.5733, -0.5838,  1.0056, -0.0442]],
                    
                             [[ 1.2284,  0.7206,  0.9217,  0.6188, -0.0728],
                              [ 0.4928,  0.5809,  0.0663,  1.6436,  1.4629],
                              [ 0.0755,  0.6302, -0.0471, -0.4710,  1.2841],
                              [ 0.8259, -0.6029, -1.3071,  1.2015,  1.1368]],
                    
                             [[ 0.8509, -0.8950, -0.4758, -0.2302,  0.6813],
                              [-0.1901,  1.1689, -0.7822,  0.4300,  0.6451],
                              [-2.7162,  0.5373,  0.2350, -0.5809,  2.3647],
                              [-0.5974, -0.7653,  1.2431, -0.0800, -0.8557]]]])
, <Tensor 0x7fce010dd370>
├── 'a' --> tensor([[[0.8778, 0.1243, 0.9530, 0.1061, 0.8310, 0.0896, 0.8134, 0.3172],
│                    [0.0549, 0.6482, 0.8368, 0.7538, 0.2837, 0.1982, 0.6072, 0.8325],
│                    [0.1204, 0.1436, 0.3943, 0.2892, 0.8983, 0.7299, 0.5482, 0.5508]]])
├── 'b' --> tensor([[1.5637, 1.0485, 1.2310, 1.0025, 1.3243, 1.0035]])
└── 'c' --> <Tensor 0x7fce010dd340>
    ├── 'd' --> tensor([[0.]])
    └── 'noise' --> tensor([[[[-1.7963,  0.2216, -0.8436,  0.6543, -0.3107],
                              [ 1.0097,  2.1742,  0.5098,  0.0165, -2.1901],
                              [-0.8981,  0.0152,  0.2663, -0.3163, -0.0566],
                              [-1.8476, -0.3444,  0.3505,  2.2829,  0.5392]],
                    
                             [[-0.9531,  1.4015, -0.9820,  0.1828, -1.0707],
                              [-0.6989, -0.5245, -0.1681, -1.2357, -0.0397],
                              [-0.0702,  0.1157,  0.0083, -0.3045, -0.1781],
                              [ 0.5893,  0.0633, -0.1190,  0.7433, -0.4827]],
                    
                             [[-0.8336,  0.1313,  0.1840,  0.1536, -0.1993],
                              [ 0.7144,  0.1805,  0.0081, -0.6981, -0.7325],
                              [ 0.3349,  0.5839, -0.5527,  1.0175, -0.4347],
                              [ 0.1843,  0.0505,  0.3090, -0.1539, -0.0230]]]])
, <Tensor 0x7fce010dd3d0>
├── 'a' --> tensor([[[0.0139, 0.5196, 0.4118, 0.8342, 0.2534, 0.9744, 0.8006, 0.8289],
│                    [0.5989, 0.7918, 0.3672, 0.7529, 0.7780, 0.8163, 0.0953, 0.8385],
│                    [0.4482, 0.9899, 0.5385, 0.3456, 0.6094, 0.8605, 0.0776, 0.2461]]])
├── 'b' --> tensor([[1.1093, 1.5869, 1.4788, 1.4476, 1.0898, 1.3802]])
└── 'c' --> <Tensor 0x7fce010dd3a0>
    ├── 'd' --> tensor([[3.]])
    └── 'noise' --> tensor([[[[ 0.1460,  0.5519,  0.1847, -0.2209, -2.3965],
                              [ 0.0868, -2.5259, -1.6959,  1.2219, -0.7213],
                              [ 0.2837, -0.7940, -1.1715,  1.0738,  0.1309],
                              [ 0.0484,  0.1060, -0.4734,  0.3928, -0.1920]],
                    
                             [[-0.1676,  1.2269,  1.3391,  1.1828, -1.3600],
                              [-1.9169, -1.2024,  0.8078,  0.2364, -0.6363],
                              [ 0.2980,  0.2787,  1.8287, -1.6175,  0.2046],
                              [ 0.0681, -0.7894, -0.0525, -2.3374, -1.0105]],
                    
                             [[ 1.4547, -1.1449,  0.3292,  1.1134,  1.1361],
                              [ 0.4465, -0.6099, -1.6036,  0.7052,  0.8402],
                              [-1.7581, -0.3325,  1.1684,  0.6402, -0.3822],
                              [ 0.8802,  0.9288,  1.2364, -0.6207, -0.9612]]]])
)
mean0 & mean1: tensor(1.3109) tensor(1.3109)


even_index_a0: tensor([[[0.2909, 0.8774, 0.5245, 0.2834, 0.5976, 0.6227, 0.4479, 0.5658],
         [0.9986, 0.1072, 0.5157, 0.1964, 0.1072, 0.1204, 0.6982, 0.2265]],

        [[0.2647, 0.5880, 0.1264, 0.1287, 0.3078, 0.3299, 0.1620, 0.6071],
         [0.7176, 0.5219, 0.2742, 0.4196, 0.4705, 0.5566, 0.0080, 0.3343]],

        [[0.8778, 0.1243, 0.9530, 0.1061, 0.8310, 0.0896, 0.8134, 0.3172],
         [0.1204, 0.1436, 0.3943, 0.2892, 0.8983, 0.7299, 0.5482, 0.5508]],

        [[0.0139, 0.5196, 0.4118, 0.8342, 0.2534, 0.9744, 0.8006, 0.8289],
         [0.4482, 0.9899, 0.5385, 0.3456, 0.6094, 0.8605, 0.0776, 0.2461]]])
even_index_a1: tensor([[[0.2909, 0.8774, 0.5245, 0.2834, 0.5976, 0.6227, 0.4479, 0.5658],
         [0.9986, 0.1072, 0.5157, 0.1964, 0.1072, 0.1204, 0.6982, 0.2265]],

        [[0.2647, 0.5880, 0.1264, 0.1287, 0.3078, 0.3299, 0.1620, 0.6071],
         [0.7176, 0.5219, 0.2742, 0.4196, 0.4705, 0.5566, 0.0080, 0.3343]],

        [[0.8778, 0.1243, 0.9530, 0.1061, 0.8310, 0.0896, 0.8134, 0.3172],
         [0.1204, 0.1436, 0.3943, 0.2892, 0.8983, 0.7299, 0.5482, 0.5508]],

        [[0.0139, 0.5196, 0.4118, 0.8342, 0.2534, 0.9744, 0.8006, 0.8289],
         [0.4482, 0.9899, 0.5385, 0.3456, 0.6094, 0.8605, 0.0776, 0.2461]]])
<Size 0x7fce0114a970>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7fce014012e0>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.