Customized Operations For Different Fields
Here is another example of custom operations implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

# T: number of rows in each sample's 'a'; B: number of samples in the demo batch.
T, B = 3, 4


def with_nativetensor(batch_):
    """Apply per-field operations to a list of nested dicts using plain torch.

    Each sample in ``batch_`` is modified in place:

    * ``'a'`` is cast to float, and its even-indexed rows (``v[::2]``) are collected.
    * ``'b'`` is cast to float and replaced by ``b ** 2 + 1``; the mean of the
      transformed tensor is collected.
    * ``'c'``: its ``'d'`` entry is cast to float, other sub-keys are reported
      and left untouched, and a ``'noise'`` entry of shape ``(3, 4, 5)`` is added.
    * Any other top-level key is reported and left untouched.

    Returns:
        ``(batch_, mean_b, even_index_a)``. ``mean_b`` is the average of the
        per-sample means of the transformed ``'b'``. ``even_index_a`` stacks every
        sample's even-indexed ``'a'`` rows along a new leading dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                sample[k] = v  # write back so the returned batch is transformed
                even_index_a_list.append(v[::2])
            elif k == 'b':
                v = torch.pow(v.float(), 2) + 1.0
                sample[k] = v
                mean_b_list.append(v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # 'noise' belongs in the nested 'c' dict (not under a top-level 'd').
    # It is added after the loop above so no dict grows while being iterated.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field operations as ``with_nativetensor`` via treetensor.

    The samples are stacked into a single tree tensor and cast to float. Then
    ``'b'`` becomes ``b ** 2 + 1`` and a ``'noise'`` leaf is added under ``'c'``.
    The tree is finally split back into one tree per sample, with a leading
    dimension of 1 on every leaf.

    Returns:
        ``(batch_, mean_b, even_index_a)``. ``batch_`` is the split tree tensors,
        ``mean_b`` is the mean of the transformed ``'b'``, and ``even_index_a`` is
        the even-indexed ``'a'`` rows for every sample.
    """
    # Use the actual input length rather than the module-level B.
    batch_size = len(batch_)
    batch_ = ttorch.stack([ttorch.tensor(b) for b in batch_])
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample: 'a' is (T, 8), 'b' is (6,), and 'c'.'d' is a 1-element int tensor."""
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,)),
        },
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Deep copies so neither implementation sees the other's in-place changes.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare the largest absolute difference. abs(max(diff)) would let large
    # negative deviations slip through.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look like the example below (the tensors are random, so the exact values will differ between runs), and all the assertions pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.4572, 0.6080, 0.6765, 0.4616, 0.7006, 0.7138, 0.8332, 0.5279], [0.5234, 0.1879, 0.9367, 0.2713, 0.7517, 0.4578, 0.8312, 0.1275], [0.5743, 0.4873, 0.6173, 0.8256, 0.9331, 0.5081, 0.7060, 0.8910]]), 'b': tensor([0.9206, 0.9682, 0.2597, 0.5851, 0.8320, 1.0000]), 'c': {'d': tensor([8])}}, {'a': tensor([[0.6489, 0.0320, 0.2625, 0.8833, 0.0545, 0.6756, 0.1604, 0.1900], [0.5600, 0.1550, 0.4830, 0.9901, 0.7525, 0.7140, 0.2213, 0.3848], [0.9487, 0.0209, 0.2177, 0.6618, 0.7286, 0.7615, 0.0964, 0.4275]]), 'b': tensor([0.2169, 0.8077, 0.7291, 0.2017, 0.9401, 0.4131]), 'c': {'d': tensor([4])}}, {'a': tensor([[0.7510, 0.6268, 0.8116, 0.1843, 0.0712, 0.7653, 0.0770, 0.9667], [0.5327, 0.2731, 0.9119, 0.0957, 0.3421, 0.8421, 0.7524, 0.9841], [0.4037, 0.5359, 0.0451, 0.8508, 0.2864, 0.8289, 0.6049, 0.8590]]), 'b': tensor([0.9390, 0.0167, 0.3553, 0.2527, 0.6955, 0.0793]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.2746, 0.6944, 0.0830, 0.9218, 0.3689, 0.6639, 0.9689, 0.8173], [0.4369, 0.5265, 0.1490, 0.0346, 0.7360, 0.2842, 0.5820, 0.1121], [0.8793, 0.7142, 0.8760, 0.7279, 0.4149, 0.9174, 0.4245, 0.6824]]), 'b': tensor([0.1762, 0.6831, 0.7020, 0.2706, 0.3201, 0.9509]), 'c': {'d': tensor([4])}}] (<Tensor 0x7ffb457152e0> ├── 'a' --> tensor([[[0.4572, 0.6080, 0.6765, 0.4616, 0.7006, 0.7138, 0.8332, 0.5279], │ [0.5234, 0.1879, 0.9367, 0.2713, 0.7517, 0.4578, 0.8312, 0.1275], │ [0.5743, 0.4873, 0.6173, 0.8256, 0.9331, 0.5081, 0.7060, 0.8910]]]) ├── 'b' --> tensor([[1.8476, 1.9375, 1.0674, 1.3423, 1.6922, 2.0000]]) └── 'c' --> <Tensor 0x7ffb45715340> ├── 'd' --> 
tensor([[8.]]) └── 'noise' --> tensor([[[[-4.6412e-01, -5.7780e-01, 8.0361e-02, 6.8379e-01, -1.9820e+00], [ 1.2988e+00, -1.0754e+00, 2.8268e-01, 9.3715e-01, -3.8230e-01], [ 5.1139e-01, 3.5472e-01, 3.1486e-01, -1.0898e+00, -1.4912e-01], [-7.7645e-01, -9.3090e-01, -7.2468e-01, -6.2890e-01, 5.9210e-01]], [[-1.6167e+00, -4.9676e-02, -1.3844e+00, 1.1189e+00, 5.4293e-01], [-1.2420e+00, -8.8608e-01, 1.8914e+00, -4.5926e-02, 3.0714e-01], [ 4.6076e-01, 2.2135e-01, 4.6556e-01, -8.7983e-04, 7.5775e-01], [ 1.0634e+00, -1.3722e+00, 8.8384e-01, -9.6968e-01, 2.7805e-01]], [[-1.0646e-01, 4.8121e-01, 3.7468e+00, -1.7515e+00, 6.9178e-01], [-1.1562e-01, -2.0747e-01, 8.3824e-01, 8.6832e-02, -7.3484e-01], [-4.9480e-01, 1.5194e+00, -6.6665e-02, 4.9116e-01, -4.9920e-01], [ 6.5903e-02, -1.8747e+00, -1.3733e+00, -1.0089e+00, 6.9826e-01]]]]) , <Tensor 0x7ffb457153a0> ├── 'a' --> tensor([[[0.6489, 0.0320, 0.2625, 0.8833, 0.0545, 0.6756, 0.1604, 0.1900], │ [0.5600, 0.1550, 0.4830, 0.9901, 0.7525, 0.7140, 0.2213, 0.3848], │ [0.9487, 0.0209, 0.2177, 0.6618, 0.7286, 0.7615, 0.0964, 0.4275]]]) ├── 'b' --> tensor([[1.0471, 1.6523, 1.5316, 1.0407, 1.8837, 1.1707]]) └── 'c' --> <Tensor 0x7ffb45715280> ├── 'd' --> tensor([[4.]]) └── 'noise' --> tensor([[[[-0.1838, 0.7267, 0.0637, 0.6531, -1.8037], [ 0.8776, -0.4766, -1.1770, -0.9471, -1.0448], [ 0.2301, -1.2404, 1.1525, -0.3125, -1.1866], [-0.7215, 0.6348, 0.0648, 0.0371, 0.4146]], [[ 0.6134, -0.5340, -1.0677, 0.2898, -0.6210], [ 0.9603, 0.1196, 1.2600, -0.2438, -1.6112], [-0.0975, 0.0962, 1.4514, 0.2641, -0.2844], [-1.5464, -1.6153, -1.5432, -0.3866, 0.6795]], [[-1.0840, 1.0936, 0.7873, 0.7940, 0.1981], [ 0.3649, 1.2445, 0.2280, -1.5792, -0.1313], [-0.6048, 2.2737, -0.9488, 0.1801, 1.1222], [-1.1301, 0.1280, 0.2690, -0.1754, 0.8604]]]]) , <Tensor 0x7ffb45715400> ├── 'a' --> tensor([[[0.7510, 0.6268, 0.8116, 0.1843, 0.0712, 0.7653, 0.0770, 0.9667], │ [0.5327, 0.2731, 0.9119, 0.0957, 0.3421, 0.8421, 0.7524, 0.9841], │ [0.4037, 0.5359, 0.0451, 0.8508, 
0.2864, 0.8289, 0.6049, 0.8590]]]) ├── 'b' --> tensor([[1.8817, 1.0003, 1.1263, 1.0638, 1.4837, 1.0063]]) └── 'c' --> <Tensor 0x7ffb457153d0> ├── 'd' --> tensor([[3.]]) └── 'noise' --> tensor([[[[ 0.0319, -1.3346, -0.9664, 0.5136, 1.6280], [-1.7270, -0.3694, -1.0170, 0.9421, 1.7422], [ 0.6287, -0.1087, -0.6610, 0.5958, -1.8940], [-1.0542, 0.0810, -0.6481, -1.7150, 0.5486]], [[ 0.4111, 0.4410, 0.7241, -2.3636, -0.4487], [ 1.4885, 0.3644, 0.6788, -0.9568, -1.0254], [ 0.8320, -0.2660, 0.3278, -0.9613, 0.5035], [-0.1423, -1.0684, -0.1095, 0.3794, -1.1159]], [[ 0.1101, -1.6556, -0.2003, -0.0055, -1.6199], [ 0.2687, -0.8249, -0.1211, -1.0027, -0.1828], [-1.7109, 0.3344, 1.0112, 0.5179, 0.3586], [-1.6385, 0.2950, -0.0268, -1.7857, 0.2965]]]]) , <Tensor 0x7ffb45715460> ├── 'a' --> tensor([[[0.2746, 0.6944, 0.0830, 0.9218, 0.3689, 0.6639, 0.9689, 0.8173], │ [0.4369, 0.5265, 0.1490, 0.0346, 0.7360, 0.2842, 0.5820, 0.1121], │ [0.8793, 0.7142, 0.8760, 0.7279, 0.4149, 0.9174, 0.4245, 0.6824]]]) ├── 'b' --> tensor([[1.0310, 1.4666, 1.4927, 1.0732, 1.1024, 1.9043]]) └── 'c' --> <Tensor 0x7ffb45715430> ├── 'd' --> tensor([[4.]]) └── 'noise' --> tensor([[[[-0.1884, -1.1912, -0.0673, -0.3911, 0.6435], [ 0.8307, -0.6875, -0.6923, -1.0944, 0.8886], [-3.0533, 0.7851, -1.2864, -2.0481, -0.2371], [-0.3512, 0.1200, 0.3074, -0.2534, 0.5491]], [[-0.7389, 1.0999, 0.2123, 1.1953, 0.0311], [-1.4438, -0.7222, 0.3789, -0.2664, 0.5498], [ 0.5244, 0.4530, 1.4953, 0.9694, 1.0271], [ 1.6876, -0.7055, 0.9537, 0.7722, 0.4099]], [[ 0.1919, 1.8940, -0.5216, 0.9805, 0.4321], [ 0.4256, 1.4371, 0.2084, -0.2604, -0.2164], [-0.1987, 0.7131, 2.0972, 0.7461, -0.4851], [-0.2310, -1.3240, -1.4171, -1.7092, 1.2812]]]]) ) mean0 & mean1: tensor(1.4102) tensor(1.4102) even_index_a0: tensor([[[0.4572, 0.6080, 0.6765, 0.4616, 0.7006, 0.7138, 0.8332, 0.5279], [0.5743, 0.4873, 0.6173, 0.8256, 0.9331, 0.5081, 0.7060, 0.8910]], [[0.6489, 0.0320, 0.2625, 0.8833, 0.0545, 0.6756, 0.1604, 0.1900], [0.9487, 0.0209, 0.2177, 
0.6618, 0.7286, 0.7615, 0.0964, 0.4275]], [[0.7510, 0.6268, 0.8116, 0.1843, 0.0712, 0.7653, 0.0770, 0.9667], [0.4037, 0.5359, 0.0451, 0.8508, 0.2864, 0.8289, 0.6049, 0.8590]], [[0.2746, 0.6944, 0.0830, 0.9218, 0.3689, 0.6639, 0.9689, 0.8173], [0.8793, 0.7142, 0.8760, 0.7279, 0.4149, 0.9174, 0.4245, 0.6824]]]) even_index_a1: tensor([[[0.4572, 0.6080, 0.6765, 0.4616, 0.7006, 0.7138, 0.8332, 0.5279], [0.5743, 0.4873, 0.6173, 0.8256, 0.9331, 0.5081, 0.7060, 0.8910]], [[0.6489, 0.0320, 0.2625, 0.8833, 0.0545, 0.6756, 0.1604, 0.1900], [0.9487, 0.0209, 0.2177, 0.6618, 0.7286, 0.7615, 0.0964, 0.4275]], [[0.7510, 0.6268, 0.8116, 0.1843, 0.0712, 0.7653, 0.0770, 0.9667], [0.4037, 0.5359, 0.0451, 0.8508, 0.2864, 0.8289, 0.6049, 0.8590]], [[0.2746, 0.6944, 0.0830, 0.9218, 0.3689, 0.6639, 0.9689, 0.8173], [0.8793, 0.7142, 0.8760, 0.7279, 0.4149, 0.9174, 0.4245, 0.6824]]]) <Size 0x7ffbac129f40> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7ffb45710ee0> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.