Customized Operations for Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

# T: number of rows in each sample's 'a' tensor; B: number of samples per batch.
T, B = 3, 4


def with_nativetensor(batch_):
    """Apply the per-field operations to a list of nested dicts with plain torch.

    For every sample the following is written back into ``batch_`` (which is
    mutated in place), mirroring ``with_treetensor``:

    * ``'a'`` is cast to float;
    * ``'b'`` is cast to float and replaced by ``b ** 2 + 1``;
    * ``'c'``/``'d'`` is cast to float;
    * random noise of shape ``(3, 4, 5)`` is stored under ``'c'``/``'noise'``.

    Any other key is reported with a message and left untouched.

    :param batch_: List of samples shaped like the output of ``get_data()``.
    :return: ``(batch_, mean_b, even_index_a)`` where ``mean_b`` is the mean of
        the per-sample means of the transformed ``'b'`` values and
        ``even_index_a`` stacks the even-indexed rows of each ``'a'`` along a
        new leading dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            # Replacing the value of an existing key while iterating is safe:
            # the dict's size does not change.
            if k == 'a':
                v = v.float()
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # The noise belongs inside the nested 'c' dict (cf. ``batch_.c.noise`` in
    # ``with_treetensor``). It is added in a separate pass so the inner 'c'
    # loop above never sees a key inserted during its own iteration.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    # Every 'b' produced by get_data() has the same length, so the mean of the
    # per-sample means equals the mean over all elements.
    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field operations with the treetensor API.

    The list of samples is stacked into a single tree tensor so that each
    operation is written once for the whole batch. The result is then split
    back into per-sample trees, each keeping a leading dimension of size 1.

    :param batch_: List of samples shaped like the output of ``get_data()``.
    :return: ``(batch_, mean_b, even_index_a)``: the split per-sample trees,
        the mean of the transformed ``'b'`` values, and the even-indexed rows
        of ``'a'`` for every sample.
    """
    batch_size = len(batch_)  # the actual batch size, not the global B
    batch_ = [ttorch.tensor(b) for b in batch_]
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample: float 'a' (T x 8), float 'b' (6,), and an
    integer 'c'/'d' of shape (1,) drawn from [0, 10)."""
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,))
        }
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Each implementation gets its own deep copy because with_nativetensor
    # mutates its input in place.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take the absolute value before the max: abs(max(diff)) would miss
    # large negative differences.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look similar to the following (the exact values differ from run to run because the inputs are random), and all the assertions pass.
[{'a': tensor([[0.2909, 0.8774, 0.5245, 0.2834, 0.5976, 0.6227, 0.4479, 0.5658],
        [0.1502, 0.0066, 0.2932, 0.7347, 0.4004, 0.5038, 0.3965, 0.8740],
        [0.9986, 0.1072, 0.5157, 0.1964, 0.1072, 0.1204, 0.6982, 0.2265]]), 'b': tensor([0.6827, 0.6241, 0.4580, 0.2430, 0.0198, 0.4970]), 'c': {'d': tensor([7])}}, {'a': tensor([[0.2647, 0.5880, 0.1264, 0.1287, 0.3078, 0.3299, 0.1620, 0.6071],
        [0.5385, 0.2389, 0.3826, 0.3370, 0.1770, 0.9450, 0.2802, 0.8462],
        [0.7176, 0.5219, 0.2742, 0.4196, 0.4705, 0.5566, 0.0080, 0.3343]]), 'b': tensor([0.7906, 0.6741, 0.2049, 0.5595, 0.9478, 0.7007]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.8778, 0.1243, 0.9530, 0.1061, 0.8310, 0.0896, 0.8134, 0.3172],
        [0.0549, 0.6482, 0.8368, 0.7538, 0.2837, 0.1982, 0.6072, 0.8325],
        [0.1204, 0.1436, 0.3943, 0.2892, 0.8983, 0.7299, 0.5482, 0.5508]]), 'b': tensor([0.7508, 0.2203, 0.4806, 0.0496, 0.5695, 0.0591]), 'c': {'d': tensor([0])}}, {'a': tensor([[0.0139, 0.5196, 0.4118, 0.8342, 0.2534, 0.9744, 0.8006, 0.8289],
        [0.5989, 0.7918, 0.3672, 0.7529, 0.7780, 0.8163, 0.0953, 0.8385],
        [0.4482, 0.9899, 0.5385, 0.3456, 0.6094, 0.8605, 0.0776, 0.2461]]), 'b': tensor([0.3306, 0.7661, 0.6920, 0.6690, 0.2996, 0.6166]), 'c': {'d': tensor([3])}}]
(<Tensor 0x7fce010dd1f0>
├── 'a' --> tensor([[[0.2909, 0.8774, 0.5245, 0.2834, 0.5976, 0.6227, 0.4479, 0.5658],
│                    [0.1502, 0.0066, 0.2932, 0.7347, 0.4004, 0.5038, 0.3965, 0.8740],
│                    [0.9986, 0.1072, 0.5157, 0.1964, 0.1072, 0.1204, 0.6982, 0.2265]]])
├── 'b' --> tensor([[1.4660, 1.3895, 1.2097, 1.0590, 1.0004, 1.2470]])
└── 'c' --> <Tensor 0x7fce010dd250>
    ├── 'd' --> tensor([[7.]])
    └── 'noise' --> tensor([[[[-0.9183,  0.6631,  1.2223,  0.9913, -0.7930],
                              [-1.7144, -1.5954,  0.3963, -0.0340, -1.7042],
                              [-1.1111,  0.4383, -1.1122,  0.2265,  0.6809],
                              [-1.8948, -1.1574,  0.5751,  0.5739, -0.5946]],
                    
                             [[-0.3911,  0.5059, -1.5592,  0.8105, -0.4599],
                              [ 0.1479,  1.8968,  0.3216,  0.7993,  0.4830],
                              [-0.7883, -1.1408, -0.7055,  0.5743, -0.1041],
                              [ 0.6040, -1.2881,  1.9320,  0.8264,  1.0876]],
                    
                             [[-2.0007, -0.7446, -0.7728, -0.0177, -0.7398],
                              [-0.2534, -0.0327, -0.4782, -3.0775,  2.2030],
                              [-0.4641,  0.5182,  1.5898,  1.0946,  0.2854],
                              [-0.6921,  0.8066,  1.6027, -0.1465,  0.4847]]]])
, <Tensor 0x7fce010dd310>
├── 'a' --> tensor([[[0.2647, 0.5880, 0.1264, 0.1287, 0.3078, 0.3299, 0.1620, 0.6071],
│                    [0.5385, 0.2389, 0.3826, 0.3370, 0.1770, 0.9450, 0.2802, 0.8462],
│                    [0.7176, 0.5219, 0.2742, 0.4196, 0.4705, 0.5566, 0.0080, 0.3343]]])
├── 'b' --> tensor([[1.6251, 1.4544, 1.0420, 1.3130, 1.8983, 1.4910]])
└── 'c' --> <Tensor 0x7fce010dd190>
    ├── 'd' --> tensor([[3.]])
    └── 'noise' --> tensor([[[[ 2.1365,  0.2016,  1.4825, -0.3637, -0.0737],
                              [-1.7718, -0.8321, -1.6644,  0.1703,  1.6251],
                              [ 0.6342, -2.0468, -0.6453,  0.6979,  0.6614],
                              [-0.2988, -0.5733, -0.5838,  1.0056, -0.0442]],
                    
                             [[ 1.2284,  0.7206,  0.9217,  0.6188, -0.0728],
                              [ 0.4928,  0.5809,  0.0663,  1.6436,  1.4629],
                              [ 0.0755,  0.6302, -0.0471, -0.4710,  1.2841],
                              [ 0.8259, -0.6029, -1.3071,  1.2015,  1.1368]],
                    
                             [[ 0.8509, -0.8950, -0.4758, -0.2302,  0.6813],
                              [-0.1901,  1.1689, -0.7822,  0.4300,  0.6451],
                              [-2.7162,  0.5373,  0.2350, -0.5809,  2.3647],
                              [-0.5974, -0.7653,  1.2431, -0.0800, -0.8557]]]])
, <Tensor 0x7fce010dd370>
├── 'a' --> tensor([[[0.8778, 0.1243, 0.9530, 0.1061, 0.8310, 0.0896, 0.8134, 0.3172],
│                    [0.0549, 0.6482, 0.8368, 0.7538, 0.2837, 0.1982, 0.6072, 0.8325],
│                    [0.1204, 0.1436, 0.3943, 0.2892, 0.8983, 0.7299, 0.5482, 0.5508]]])
├── 'b' --> tensor([[1.5637, 1.0485, 1.2310, 1.0025, 1.3243, 1.0035]])
└── 'c' --> <Tensor 0x7fce010dd340>
    ├── 'd' --> tensor([[0.]])
    └── 'noise' --> tensor([[[[-1.7963,  0.2216, -0.8436,  0.6543, -0.3107],
                              [ 1.0097,  2.1742,  0.5098,  0.0165, -2.1901],
                              [-0.8981,  0.0152,  0.2663, -0.3163, -0.0566],
                              [-1.8476, -0.3444,  0.3505,  2.2829,  0.5392]],
                    
                             [[-0.9531,  1.4015, -0.9820,  0.1828, -1.0707],
                              [-0.6989, -0.5245, -0.1681, -1.2357, -0.0397],
                              [-0.0702,  0.1157,  0.0083, -0.3045, -0.1781],
                              [ 0.5893,  0.0633, -0.1190,  0.7433, -0.4827]],
                    
                             [[-0.8336,  0.1313,  0.1840,  0.1536, -0.1993],
                              [ 0.7144,  0.1805,  0.0081, -0.6981, -0.7325],
                              [ 0.3349,  0.5839, -0.5527,  1.0175, -0.4347],
                              [ 0.1843,  0.0505,  0.3090, -0.1539, -0.0230]]]])
, <Tensor 0x7fce010dd3d0>
├── 'a' --> tensor([[[0.0139, 0.5196, 0.4118, 0.8342, 0.2534, 0.9744, 0.8006, 0.8289],
│                    [0.5989, 0.7918, 0.3672, 0.7529, 0.7780, 0.8163, 0.0953, 0.8385],
│                    [0.4482, 0.9899, 0.5385, 0.3456, 0.6094, 0.8605, 0.0776, 0.2461]]])
├── 'b' --> tensor([[1.1093, 1.5869, 1.4788, 1.4476, 1.0898, 1.3802]])
└── 'c' --> <Tensor 0x7fce010dd3a0>
    ├── 'd' --> tensor([[3.]])
    └── 'noise' --> tensor([[[[ 0.1460,  0.5519,  0.1847, -0.2209, -2.3965],
                              [ 0.0868, -2.5259, -1.6959,  1.2219, -0.7213],
                              [ 0.2837, -0.7940, -1.1715,  1.0738,  0.1309],
                              [ 0.0484,  0.1060, -0.4734,  0.3928, -0.1920]],
                    
                             [[-0.1676,  1.2269,  1.3391,  1.1828, -1.3600],
                              [-1.9169, -1.2024,  0.8078,  0.2364, -0.6363],
                              [ 0.2980,  0.2787,  1.8287, -1.6175,  0.2046],
                              [ 0.0681, -0.7894, -0.0525, -2.3374, -1.0105]],
                    
                             [[ 1.4547, -1.1449,  0.3292,  1.1134,  1.1361],
                              [ 0.4465, -0.6099, -1.6036,  0.7052,  0.8402],
                              [-1.7581, -0.3325,  1.1684,  0.6402, -0.3822],
                              [ 0.8802,  0.9288,  1.2364, -0.6207, -0.9612]]]])
)
mean0 & mean1: tensor(1.3109) tensor(1.3109)
even_index_a0: tensor([[[0.2909, 0.8774, 0.5245, 0.2834, 0.5976, 0.6227, 0.4479, 0.5658],
         [0.9986, 0.1072, 0.5157, 0.1964, 0.1072, 0.1204, 0.6982, 0.2265]],
        [[0.2647, 0.5880, 0.1264, 0.1287, 0.3078, 0.3299, 0.1620, 0.6071],
         [0.7176, 0.5219, 0.2742, 0.4196, 0.4705, 0.5566, 0.0080, 0.3343]],
        [[0.8778, 0.1243, 0.9530, 0.1061, 0.8310, 0.0896, 0.8134, 0.3172],
         [0.1204, 0.1436, 0.3943, 0.2892, 0.8983, 0.7299, 0.5482, 0.5508]],
        [[0.0139, 0.5196, 0.4118, 0.8342, 0.2534, 0.9744, 0.8006, 0.8289],
         [0.4482, 0.9899, 0.5385, 0.3456, 0.6094, 0.8605, 0.0776, 0.2461]]])
even_index_a1: tensor([[[0.2909, 0.8774, 0.5245, 0.2834, 0.5976, 0.6227, 0.4479, 0.5658],
         [0.9986, 0.1072, 0.5157, 0.1964, 0.1072, 0.1204, 0.6982, 0.2265]],
        [[0.2647, 0.5880, 0.1264, 0.1287, 0.3078, 0.3299, 0.1620, 0.6071],
         [0.7176, 0.5219, 0.2742, 0.4196, 0.4705, 0.5566, 0.0080, 0.3343]],
        [[0.8778, 0.1243, 0.9530, 0.1061, 0.8310, 0.0896, 0.8134, 0.3172],
         [0.1204, 0.1436, 0.3943, 0.2892, 0.8983, 0.7299, 0.5482, 0.5508]],
        [[0.0139, 0.5196, 0.4118, 0.8342, 0.2534, 0.9744, 0.8006, 0.8289],
         [0.4482, 0.9899, 0.5385, 0.3456, 0.6094, 0.8605, 0.0776, 0.2461]]])
<Size 0x7fce0114a970>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7fce014012e0>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])
 | 
The implementation with the treetensor API is much simpler and clearer.