Customized Operations for Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

# T: number of rows in each sample's 'a' tensor; B: number of samples per batch.
T, B = 3, 4


def with_nativetensor(batch_):
    """Process a batch of nested dicts using only the native torch API.

    Performs the same per-field work as ``with_treetensor``: 'a', 'b' and
    'c'/'d' are cast to float, 'b' is replaced by ``b ** 2 + 1``, and a
    random ``noise`` tensor of shape (3, 4, 5) is added under 'c'. The dicts
    in ``batch_`` are updated in place.

    Args:
        batch_: list of samples shaped like the output of ``get_data``.

    Returns:
        ``(batch_, mean_b, even_index_a)`` where ``mean_b`` is the mean of the
        per-sample means of the transformed 'b', and ``even_index_a`` stacks
        the even-indexed rows of every sample's 'a' along a new dim 0.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            # Overwriting the value of an existing key while iterating is
            # allowed; only adding/removing keys would break the iteration.
            if k == 'a':
                v = v.float()
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                v = v.float()
                transformed_v = torch.pow(v, 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        # Store the cast back; rebinding v1 alone would drop it.
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))
    # 'noise' belongs in the nested 'c' dict (as in with_treetensor); it is
    # added in a separate pass so no dict gains a key while being iterated.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))
    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Process a batch of nested dicts with the treetensor API.

    Args:
        batch_: list of samples shaped like the output of ``get_data``.

    Returns:
        ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the processed
        batch split back into per-sample trees (each keeps a leading dim of
        size 1), ``mean_b`` is the mean of the transformed 'b', and
        ``even_index_a`` holds the even-indexed rows of every sample's 'a'.
    """
    batch_size = len(batch_)
    batch_ = [ttorch.tensor(b) for b in batch_]
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    # Size the noise from the actual batch instead of the global B.
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Return one random sample: 'a' (T, 8) float, 'b' (6,) float, 'c'/'d' (1,) int."""
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,))
        }
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take the max of the absolute difference; abs(max(diff)) would let a
    # large negative difference pass.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)

    # Both implementations must produce the same deterministic fields and
    # a noise tensor of matching shape for every sample.
    for native, tree in zip(batch0, batch1):
        assert torch.allclose(native['a'], tree.a[0])
        assert torch.allclose(native['b'], tree.b[0])
        assert torch.allclose(native['c']['d'], tree.c.d[0])
        assert native['c']['noise'].shape == tree.c.noise.shape[1:]

    print(batch1[0].shape)
The output should look like the following (the tensors are random, so the exact values differ between runs), and all the assertions pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.7952, 0.5317, 0.9969, 0.7247, 0.9008, 0.8212, 0.3001, 0.8340], [0.0502, 0.1068, 0.0729, 0.3611, 0.7296, 0.5931, 0.6764, 0.2331], [0.4727, 0.5710, 0.9302, 0.8417, 0.7715, 0.8744, 0.8541, 0.4062]]), 'b': tensor([0.9890, 0.8908, 0.2642, 0.1434, 0.1721, 0.9819]), 'c': {'d': tensor([4])}}, {'a': tensor([[0.1144, 0.8584, 0.1021, 0.4512, 0.2160, 0.4735, 0.1162, 0.4354], [0.2487, 0.6985, 0.9192, 0.3302, 0.8952, 0.1003, 0.1221, 0.0799], [0.8819, 0.8667, 0.0678, 0.5979, 0.7199, 0.5001, 0.9542, 0.4657]]), 'b': tensor([0.4046, 0.8683, 0.6061, 0.8716, 0.4036, 0.3148]), 'c': {'d': tensor([8])}}, {'a': tensor([[0.3129, 0.4439, 0.8225, 0.1990, 0.4839, 0.6035, 0.6834, 0.5054], [0.3481, 0.6482, 0.3350, 0.7787, 0.2522, 0.7648, 0.1572, 0.7762], [0.5089, 0.3098, 0.3571, 0.4281, 0.0979, 0.2945, 0.0717, 0.9534]]), 'b': tensor([0.6999, 0.9815, 0.6931, 0.8468, 0.4924, 0.6062]), 'c': {'d': tensor([7])}}, {'a': tensor([[0.5043, 0.1589, 0.5069, 0.4711, 0.3651, 0.7958, 0.4787, 0.3208], [0.4928, 0.9831, 0.0118, 0.6966, 0.4535, 0.9948, 0.7176, 0.1944], [0.0527, 0.1514, 0.3471, 0.2261, 0.1730, 0.7510, 0.1243, 0.8186]]), 'b': tensor([0.9240, 0.4537, 0.9679, 0.1696, 0.8909, 0.2051]), 'c': {'d': tensor([8])}}] (<Tensor 0x7f83ecb97ac0> ├── 'a' --> tensor([[[0.7952, 0.5317, 0.9969, 0.7247, 0.9008, 0.8212, 0.3001, 0.8340], │ [0.0502, 0.1068, 0.0729, 0.3611, 0.7296, 0.5931, 0.6764, 0.2331], │ [0.4727, 0.5710, 0.9302, 0.8417, 0.7715, 0.8744, 0.8541, 0.4062]]]) ├── 'b' --> tensor([[1.9782, 1.7935, 1.0698, 1.0206, 1.0296, 1.9642]]) └── 'c' --> <Tensor 0x7f83ecb97b20> ├── 'd' --> 
tensor([[4.]]) └── 'noise' --> tensor([[[[-0.9922, -0.8831, -0.9031, -1.4641, -1.3487], [ 0.3582, -1.2870, -0.0889, 0.4798, 0.4101], [-1.5796, -0.0642, -1.0601, 2.1795, 0.0919], [-1.1331, -0.7478, -1.6202, 0.2408, 0.4960]], [[-0.2332, -0.9402, 1.2484, 0.6176, -1.2069], [-0.5010, 0.7441, -0.3839, 0.8156, 0.1747], [ 1.3137, 0.7718, -0.7034, 0.2337, -0.4377], [-0.0994, 0.3264, -0.2980, -1.2229, 0.9893]], [[ 1.1173, -0.2840, -0.8270, -0.3214, 2.0599], [-0.1047, 1.6224, -0.3308, 0.9336, 0.6434], [-0.2882, -0.8809, -0.4204, 1.6479, -0.2479], [ 0.4165, 0.8377, -0.3197, 0.1841, 1.6500]]]]) , <Tensor 0x7f83ecb97be0> ├── 'a' --> tensor([[[0.1144, 0.8584, 0.1021, 0.4512, 0.2160, 0.4735, 0.1162, 0.4354], │ [0.2487, 0.6985, 0.9192, 0.3302, 0.8952, 0.1003, 0.1221, 0.0799], │ [0.8819, 0.8667, 0.0678, 0.5979, 0.7199, 0.5001, 0.9542, 0.4657]]]) ├── 'b' --> tensor([[1.1637, 1.7539, 1.3673, 1.7596, 1.1629, 1.0991]]) └── 'c' --> <Tensor 0x7f83ecb97a30> ├── 'd' --> tensor([[8.]]) └── 'noise' --> tensor([[[[ 0.5359, -1.1872, -1.1349, -1.2241, -0.2247], [-1.4811, 0.5311, -0.5148, 1.7360, -0.6674], [-1.5773, 0.3944, -0.5607, -0.4603, 0.6399], [-0.3212, -0.4318, 0.0031, 1.1059, -0.9834]], [[-0.6793, -0.5348, -0.0504, -1.5032, 0.1691], [ 1.6129, -0.6823, 1.0471, -1.1235, 0.4317], [-0.4445, 0.0780, 0.3752, 1.8468, 1.1199], [ 0.6156, 0.4850, 0.2734, -0.8555, 0.1752]], [[-0.1423, 0.0679, -0.5144, -1.3756, -0.8489], [-0.7540, -1.1874, -0.3342, -0.2126, 0.8937], [ 0.0817, 0.0894, 1.2505, 0.2829, -0.0346], [-1.2883, -1.3640, -0.8059, 0.2859, 0.6248]]]]) , <Tensor 0x7f83ecb97c40> ├── 'a' --> tensor([[[0.3129, 0.4439, 0.8225, 0.1990, 0.4839, 0.6035, 0.6834, 0.5054], │ [0.3481, 0.6482, 0.3350, 0.7787, 0.2522, 0.7648, 0.1572, 0.7762], │ [0.5089, 0.3098, 0.3571, 0.4281, 0.0979, 0.2945, 0.0717, 0.9534]]]) ├── 'b' --> tensor([[1.4898, 1.9633, 1.4804, 1.7171, 1.2424, 1.3674]]) └── 'c' --> <Tensor 0x7f83ecb97c10> ├── 'd' --> tensor([[7.]]) └── 'noise' --> tensor([[[[ 0.5697, -0.5898, -0.8598, -0.5347, 
1.8237], [-1.5553, -0.3725, -0.7192, 0.9180, -1.4548], [ 1.4326, 0.0348, 0.1363, 1.8379, 0.4643], [ 0.0743, -0.4870, 0.6392, -0.0841, 0.3341]], [[-0.4257, 0.6298, 1.9045, -1.0348, 0.8331], [ 0.8333, 0.7284, -0.0489, -0.6122, 0.1061], [-0.0823, -1.4911, -1.6681, 2.9037, -0.0814], [-1.3788, 0.2873, 0.1422, 1.1302, 0.6525]], [[ 1.3702, 0.3678, 1.2022, 0.3144, 0.3844], [ 0.7469, 0.0259, 0.0321, 0.6302, -1.0288], [ 0.4854, 0.9360, -0.1610, 0.4796, 0.5281], [-0.3088, 0.4554, -0.6296, 0.0379, 0.4916]]]]) , <Tensor 0x7f83ecb97ca0> ├── 'a' --> tensor([[[0.5043, 0.1589, 0.5069, 0.4711, 0.3651, 0.7958, 0.4787, 0.3208], │ [0.4928, 0.9831, 0.0118, 0.6966, 0.4535, 0.9948, 0.7176, 0.1944], │ [0.0527, 0.1514, 0.3471, 0.2261, 0.1730, 0.7510, 0.1243, 0.8186]]]) ├── 'b' --> tensor([[1.8538, 1.2059, 1.9369, 1.0288, 1.7937, 1.0420]]) └── 'c' --> <Tensor 0x7f83ecb97c70> ├── 'd' --> tensor([[8.]]) └── 'noise' --> tensor([[[[ 1.6036e-01, 6.7931e-01, 1.4153e+00, 7.3368e-04, 2.9788e-01], [ 8.1740e-01, -8.6354e-01, -1.5628e+00, 2.3292e-01, -6.9122e-01], [ 1.5631e-01, -1.1406e+00, 2.9363e-01, 1.1081e+00, 7.3348e-01], [ 5.8475e-01, 2.4910e-01, -3.2656e-01, 1.2840e+00, -7.0661e-01]], [[-1.3394e+00, -8.0459e-01, 8.6640e-01, 1.9354e+00, 1.1346e+00], [-2.0040e-01, 6.3933e-01, -1.4943e+00, 7.8732e-01, -8.9331e-01], [-1.5937e+00, 4.8034e-01, -1.3983e+00, -7.2016e-01, 5.6987e-01], [ 8.3162e-01, -1.1837e+00, 1.4762e+00, 9.6509e-01, 2.4423e-02]], [[ 5.0932e-01, -1.0427e+00, 1.0486e+00, 3.6652e-01, 8.3352e-01], [ 6.6314e-02, -1.0065e+00, -1.1592e-01, -2.3717e-01, -6.5655e-02], [ 5.7803e-02, -6.4236e-01, 8.9233e-01, 9.9751e-01, 1.1912e-01], [ 6.5045e-02, -1.7835e-01, -9.1286e-01, -7.8407e-01, 5.5585e-01]]]]) ) mean0 & mean1: tensor(1.4702) tensor(1.4702) even_index_a0: tensor([[[0.7952, 0.5317, 0.9969, 0.7247, 0.9008, 0.8212, 0.3001, 0.8340], [0.4727, 0.5710, 0.9302, 0.8417, 0.7715, 0.8744, 0.8541, 0.4062]], [[0.1144, 0.8584, 0.1021, 0.4512, 0.2160, 0.4735, 0.1162, 0.4354], [0.8819, 0.8667, 0.0678, 
0.5979, 0.7199, 0.5001, 0.9542, 0.4657]], [[0.3129, 0.4439, 0.8225, 0.1990, 0.4839, 0.6035, 0.6834, 0.5054], [0.5089, 0.3098, 0.3571, 0.4281, 0.0979, 0.2945, 0.0717, 0.9534]], [[0.5043, 0.1589, 0.5069, 0.4711, 0.3651, 0.7958, 0.4787, 0.3208], [0.0527, 0.1514, 0.3471, 0.2261, 0.1730, 0.7510, 0.1243, 0.8186]]]) even_index_a1: tensor([[[0.7952, 0.5317, 0.9969, 0.7247, 0.9008, 0.8212, 0.3001, 0.8340], [0.4727, 0.5710, 0.9302, 0.8417, 0.7715, 0.8744, 0.8541, 0.4062]], [[0.1144, 0.8584, 0.1021, 0.4512, 0.2160, 0.4735, 0.1162, 0.4354], [0.8819, 0.8667, 0.0678, 0.5979, 0.7199, 0.5001, 0.9542, 0.4657]], [[0.3129, 0.4439, 0.8225, 0.1990, 0.4839, 0.6035, 0.6834, 0.5054], [0.5089, 0.3098, 0.3571, 0.4281, 0.0979, 0.2945, 0.0717, 0.9534]], [[0.5043, 0.1589, 0.5069, 0.4711, 0.3651, 0.7958, 0.4787, 0.3208], [0.0527, 0.1514, 0.3471, 0.2261, 0.1730, 0.7510, 0.1243, 0.8186]]]) <Size 0x7f844caa3460> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7f83ecc12460> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.