Customized Operations For Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

# T: number of rows in each sample's 'a' tensor; B: number of samples in the demo batch.
T, B = 3, 4


def with_nativetensor(batch_):
    """Apply per-field operations to a list of nested dicts using plain torch.

    For every sample: cast ``a`` to float and collect its even-indexed rows,
    replace ``b`` with ``b ** 2 + 1`` (as float), cast ``c.d`` to float, and
    attach a random ``c.noise`` tensor of shape (3, 4, 5). This mirrors what
    ``with_treetensor`` does, so the two results are comparable.

    Args:
        batch_: list of samples shaped like ``get_data()``'s output.
            It is modified in place.

    Returns:
        tuple ``(batch_, mean_b, even_index_a)``: the transformed list, the
        mean of all transformed ``b`` values, and the even-indexed rows of
        every ``a`` stacked along a new leading dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Iterate over a snapshot so writing values back is clearly safe.
        for k, v in list(sample.items()):
            if k == 'a':
                v = v.float()
                sample[k] = v  # write back: rebinding ``v`` alone is lost
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in list(v.items()):
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))
    # 'd' lives under 'c', so the noise belongs in the nested 'c' dict.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))
    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field operations as ``with_nativetensor`` via treetensor.

    The samples are stacked into a single tree tensor, transformed with
    whole-tree operations, and then split back into per-sample trees.

    Args:
        batch_: list of samples shaped like ``get_data()``'s output.

    Returns:
        tuple ``(batch_, mean_b, even_index_a)``: a sequence of per-sample
        tree tensors (each keeping a leading dimension of size 1), the mean
        of the transformed ``b`` values, and ``a[:, ::2]`` of the stacked batch.
    """
    # Use the real batch length rather than the global B.
    batch_size = len(batch_)
    batch_ = ttorch.stack([ttorch.tensor(b) for b in batch_])
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Return one random sample.

    The sample has ``'a'``, a (T, 8) uniform float tensor; ``'b'``, a (6,)
    uniform float tensor; and ``'c'``, a dict holding ``'d'``, a (1,) integer
    tensor drawn from [0, 10).
    """
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,)),
        },
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Deep copies keep the in-place native transform from affecting the treetensor run.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    assert torch.abs((even_index_a0 - even_index_a1).max()) < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert all(b['c']['noise'].shape == (3, 4, 5) for b in batch0)
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look like the following (the tensor values are random, so yours will differ), and all the assertions should pass.
[{'a': tensor([[0.4572, 0.6080, 0.6765, 0.4616, 0.7006, 0.7138, 0.8332, 0.5279],
[0.5234, 0.1879, 0.9367, 0.2713, 0.7517, 0.4578, 0.8312, 0.1275],
[0.5743, 0.4873, 0.6173, 0.8256, 0.9331, 0.5081, 0.7060, 0.8910]]), 'b': tensor([0.9206, 0.9682, 0.2597, 0.5851, 0.8320, 1.0000]), 'c': {'d': tensor([8])}}, {'a': tensor([[0.6489, 0.0320, 0.2625, 0.8833, 0.0545, 0.6756, 0.1604, 0.1900],
[0.5600, 0.1550, 0.4830, 0.9901, 0.7525, 0.7140, 0.2213, 0.3848],
[0.9487, 0.0209, 0.2177, 0.6618, 0.7286, 0.7615, 0.0964, 0.4275]]), 'b': tensor([0.2169, 0.8077, 0.7291, 0.2017, 0.9401, 0.4131]), 'c': {'d': tensor([4])}}, {'a': tensor([[0.7510, 0.6268, 0.8116, 0.1843, 0.0712, 0.7653, 0.0770, 0.9667],
[0.5327, 0.2731, 0.9119, 0.0957, 0.3421, 0.8421, 0.7524, 0.9841],
[0.4037, 0.5359, 0.0451, 0.8508, 0.2864, 0.8289, 0.6049, 0.8590]]), 'b': tensor([0.9390, 0.0167, 0.3553, 0.2527, 0.6955, 0.0793]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.2746, 0.6944, 0.0830, 0.9218, 0.3689, 0.6639, 0.9689, 0.8173],
[0.4369, 0.5265, 0.1490, 0.0346, 0.7360, 0.2842, 0.5820, 0.1121],
[0.8793, 0.7142, 0.8760, 0.7279, 0.4149, 0.9174, 0.4245, 0.6824]]), 'b': tensor([0.1762, 0.6831, 0.7020, 0.2706, 0.3201, 0.9509]), 'c': {'d': tensor([4])}}]
(<Tensor 0x7ffb457152e0>
├── 'a' --> tensor([[[0.4572, 0.6080, 0.6765, 0.4616, 0.7006, 0.7138, 0.8332, 0.5279],
│ [0.5234, 0.1879, 0.9367, 0.2713, 0.7517, 0.4578, 0.8312, 0.1275],
│ [0.5743, 0.4873, 0.6173, 0.8256, 0.9331, 0.5081, 0.7060, 0.8910]]])
├── 'b' --> tensor([[1.8476, 1.9375, 1.0674, 1.3423, 1.6922, 2.0000]])
└── 'c' --> <Tensor 0x7ffb45715340>
├── 'd' --> tensor([[8.]])
└── 'noise' --> tensor([[[[-4.6412e-01, -5.7780e-01, 8.0361e-02, 6.8379e-01, -1.9820e+00],
[ 1.2988e+00, -1.0754e+00, 2.8268e-01, 9.3715e-01, -3.8230e-01],
[ 5.1139e-01, 3.5472e-01, 3.1486e-01, -1.0898e+00, -1.4912e-01],
[-7.7645e-01, -9.3090e-01, -7.2468e-01, -6.2890e-01, 5.9210e-01]],
[[-1.6167e+00, -4.9676e-02, -1.3844e+00, 1.1189e+00, 5.4293e-01],
[-1.2420e+00, -8.8608e-01, 1.8914e+00, -4.5926e-02, 3.0714e-01],
[ 4.6076e-01, 2.2135e-01, 4.6556e-01, -8.7983e-04, 7.5775e-01],
[ 1.0634e+00, -1.3722e+00, 8.8384e-01, -9.6968e-01, 2.7805e-01]],
[[-1.0646e-01, 4.8121e-01, 3.7468e+00, -1.7515e+00, 6.9178e-01],
[-1.1562e-01, -2.0747e-01, 8.3824e-01, 8.6832e-02, -7.3484e-01],
[-4.9480e-01, 1.5194e+00, -6.6665e-02, 4.9116e-01, -4.9920e-01],
[ 6.5903e-02, -1.8747e+00, -1.3733e+00, -1.0089e+00, 6.9826e-01]]]])
, <Tensor 0x7ffb457153a0>
├── 'a' --> tensor([[[0.6489, 0.0320, 0.2625, 0.8833, 0.0545, 0.6756, 0.1604, 0.1900],
│ [0.5600, 0.1550, 0.4830, 0.9901, 0.7525, 0.7140, 0.2213, 0.3848],
│ [0.9487, 0.0209, 0.2177, 0.6618, 0.7286, 0.7615, 0.0964, 0.4275]]])
├── 'b' --> tensor([[1.0471, 1.6523, 1.5316, 1.0407, 1.8837, 1.1707]])
└── 'c' --> <Tensor 0x7ffb45715280>
├── 'd' --> tensor([[4.]])
└── 'noise' --> tensor([[[[-0.1838, 0.7267, 0.0637, 0.6531, -1.8037],
[ 0.8776, -0.4766, -1.1770, -0.9471, -1.0448],
[ 0.2301, -1.2404, 1.1525, -0.3125, -1.1866],
[-0.7215, 0.6348, 0.0648, 0.0371, 0.4146]],
[[ 0.6134, -0.5340, -1.0677, 0.2898, -0.6210],
[ 0.9603, 0.1196, 1.2600, -0.2438, -1.6112],
[-0.0975, 0.0962, 1.4514, 0.2641, -0.2844],
[-1.5464, -1.6153, -1.5432, -0.3866, 0.6795]],
[[-1.0840, 1.0936, 0.7873, 0.7940, 0.1981],
[ 0.3649, 1.2445, 0.2280, -1.5792, -0.1313],
[-0.6048, 2.2737, -0.9488, 0.1801, 1.1222],
[-1.1301, 0.1280, 0.2690, -0.1754, 0.8604]]]])
, <Tensor 0x7ffb45715400>
├── 'a' --> tensor([[[0.7510, 0.6268, 0.8116, 0.1843, 0.0712, 0.7653, 0.0770, 0.9667],
│ [0.5327, 0.2731, 0.9119, 0.0957, 0.3421, 0.8421, 0.7524, 0.9841],
│ [0.4037, 0.5359, 0.0451, 0.8508, 0.2864, 0.8289, 0.6049, 0.8590]]])
├── 'b' --> tensor([[1.8817, 1.0003, 1.1263, 1.0638, 1.4837, 1.0063]])
└── 'c' --> <Tensor 0x7ffb457153d0>
├── 'd' --> tensor([[3.]])
└── 'noise' --> tensor([[[[ 0.0319, -1.3346, -0.9664, 0.5136, 1.6280],
[-1.7270, -0.3694, -1.0170, 0.9421, 1.7422],
[ 0.6287, -0.1087, -0.6610, 0.5958, -1.8940],
[-1.0542, 0.0810, -0.6481, -1.7150, 0.5486]],
[[ 0.4111, 0.4410, 0.7241, -2.3636, -0.4487],
[ 1.4885, 0.3644, 0.6788, -0.9568, -1.0254],
[ 0.8320, -0.2660, 0.3278, -0.9613, 0.5035],
[-0.1423, -1.0684, -0.1095, 0.3794, -1.1159]],
[[ 0.1101, -1.6556, -0.2003, -0.0055, -1.6199],
[ 0.2687, -0.8249, -0.1211, -1.0027, -0.1828],
[-1.7109, 0.3344, 1.0112, 0.5179, 0.3586],
[-1.6385, 0.2950, -0.0268, -1.7857, 0.2965]]]])
, <Tensor 0x7ffb45715460>
├── 'a' --> tensor([[[0.2746, 0.6944, 0.0830, 0.9218, 0.3689, 0.6639, 0.9689, 0.8173],
│ [0.4369, 0.5265, 0.1490, 0.0346, 0.7360, 0.2842, 0.5820, 0.1121],
│ [0.8793, 0.7142, 0.8760, 0.7279, 0.4149, 0.9174, 0.4245, 0.6824]]])
├── 'b' --> tensor([[1.0310, 1.4666, 1.4927, 1.0732, 1.1024, 1.9043]])
└── 'c' --> <Tensor 0x7ffb45715430>
├── 'd' --> tensor([[4.]])
└── 'noise' --> tensor([[[[-0.1884, -1.1912, -0.0673, -0.3911, 0.6435],
[ 0.8307, -0.6875, -0.6923, -1.0944, 0.8886],
[-3.0533, 0.7851, -1.2864, -2.0481, -0.2371],
[-0.3512, 0.1200, 0.3074, -0.2534, 0.5491]],
[[-0.7389, 1.0999, 0.2123, 1.1953, 0.0311],
[-1.4438, -0.7222, 0.3789, -0.2664, 0.5498],
[ 0.5244, 0.4530, 1.4953, 0.9694, 1.0271],
[ 1.6876, -0.7055, 0.9537, 0.7722, 0.4099]],
[[ 0.1919, 1.8940, -0.5216, 0.9805, 0.4321],
[ 0.4256, 1.4371, 0.2084, -0.2604, -0.2164],
[-0.1987, 0.7131, 2.0972, 0.7461, -0.4851],
[-0.2310, -1.3240, -1.4171, -1.7092, 1.2812]]]])
)
mean0 & mean1: tensor(1.4102) tensor(1.4102)
even_index_a0: tensor([[[0.4572, 0.6080, 0.6765, 0.4616, 0.7006, 0.7138, 0.8332, 0.5279],
[0.5743, 0.4873, 0.6173, 0.8256, 0.9331, 0.5081, 0.7060, 0.8910]],
[[0.6489, 0.0320, 0.2625, 0.8833, 0.0545, 0.6756, 0.1604, 0.1900],
[0.9487, 0.0209, 0.2177, 0.6618, 0.7286, 0.7615, 0.0964, 0.4275]],
[[0.7510, 0.6268, 0.8116, 0.1843, 0.0712, 0.7653, 0.0770, 0.9667],
[0.4037, 0.5359, 0.0451, 0.8508, 0.2864, 0.8289, 0.6049, 0.8590]],
[[0.2746, 0.6944, 0.0830, 0.9218, 0.3689, 0.6639, 0.9689, 0.8173],
[0.8793, 0.7142, 0.8760, 0.7279, 0.4149, 0.9174, 0.4245, 0.6824]]])
even_index_a1: tensor([[[0.4572, 0.6080, 0.6765, 0.4616, 0.7006, 0.7138, 0.8332, 0.5279],
[0.5743, 0.4873, 0.6173, 0.8256, 0.9331, 0.5081, 0.7060, 0.8910]],
[[0.6489, 0.0320, 0.2625, 0.8833, 0.0545, 0.6756, 0.1604, 0.1900],
[0.9487, 0.0209, 0.2177, 0.6618, 0.7286, 0.7615, 0.0964, 0.4275]],
[[0.7510, 0.6268, 0.8116, 0.1843, 0.0712, 0.7653, 0.0770, 0.9667],
[0.4037, 0.5359, 0.0451, 0.8508, 0.2864, 0.8289, 0.6049, 0.8590]],
[[0.2746, 0.6944, 0.0830, 0.9218, 0.3689, 0.6639, 0.9689, 0.8173],
[0.8793, 0.7142, 0.8760, 0.7279, 0.4149, 0.9174, 0.4245, 0.6824]]])
<Size 0x7ffbac129f40>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7ffb45710ee0>
├── 'd' --> torch.Size([1, 1])
└── 'noise' --> torch.Size([1, 3, 4, 5])
|
The implementation with the treetensor API is much simpler and clearer.