Customized Operations For Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Apply per-field operations to a batch using only the native torch API.

    ``batch_`` is a list of nested dicts shaped like the output of
    ``get_data()``: keys ``'a'`` and ``'b'`` map to tensors and ``'c'`` maps
    to a dict holding ``'d'``.

    Returns a tuple ``(batch_, mean_b, even_index_a)`` where

    * ``batch_`` is the same list, with a ``'noise'`` tensor of size
      ``(3, 4, 5)`` attached under ``'c'`` of every sample that has a ``'c'``
      entry,
    * ``mean_b`` is the average over samples of ``mean(b ** 2 + 1)``,
    * ``even_index_a`` stacks ``a[::2]`` of every sample along a new dim 0.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                # Keep every second row of this sample's 'a'.
                even_index_a_list.append(v[::2])
            elif k == 'b':
                v = v.float()
                transformed_v = torch.pow(v, 2) + 1.0
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        # NOTE(review): the converted tensor is not written
                        # back into the batch; only the local name changes.
                        v1 = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # Attach the noise under 'c', mirroring `batch_.c.noise` in
    # with_treetensor. The previous check compared top-level keys against
    # 'd', which only exists inside 'c', so the noise was never added.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field operations using the treetensor API.

    The list of nested dicts is converted to tree tensors, stacked along a
    new dim 0 and cast to float, so every field can be processed in one
    statement. The result is split back along dim 0 into chunks of size 1.

    Returns a tuple ``(batch_, mean_b, even_index_a)`` analogous to
    ``with_nativetensor``.
    """
    batch_ = [ttorch.tensor(b) for b in batch_]
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(B, 3, 4, 5))
    mean_b = batch_.b.mean()
    # Dim 0 is the stacked batch dim, so the stride-2 slice applies to dim 1.
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample: a nested dict with fields 'a', 'b' and 'c.d'."""
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,))
        }
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Deep copies keep the two implementations from sharing mutated inputs.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare the largest absolute element-wise difference; taking abs() of
    # the signed maximum would let large negative differences pass.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look like the following, and all of the assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.0686, 0.1086, 0.1260, 0.0717, 0.1534, 0.4248, 0.8107, 0.9792], [0.8358, 0.7622, 0.0328, 0.2995, 0.4200, 0.4186, 0.7946, 0.6256], [0.5847, 0.1426, 0.1226, 0.4180, 0.4402, 0.9689, 0.9807, 0.1741]]), 'b': tensor([0.2408, 0.0534, 0.2229, 0.3362, 0.0190, 0.5164]), 'c': {'d': tensor([0])}}, {'a': tensor([[0.5001, 0.0582, 0.2682, 0.9194, 0.8338, 0.3869, 0.4686, 0.1196], [0.8403, 0.9976, 0.7340, 0.1658, 0.1127, 0.6743, 0.1626, 0.5464], [0.0086, 0.5789, 0.5836, 0.5645, 0.2969, 0.5460, 0.3157, 0.0195]]), 'b': tensor([0.5788, 0.7546, 0.1636, 0.9155, 0.6175, 0.2375]), 'c': {'d': tensor([8])}}, {'a': tensor([[0.8284, 0.8438, 0.7916, 0.1016, 0.2882, 0.1145, 0.6095, 0.0657], [0.4021, 0.4337, 0.7934, 0.3543, 0.9150, 0.0930, 0.7815, 0.1391], [0.0771, 0.6045, 0.6823, 0.1701, 0.2156, 0.5855, 0.6355, 0.9394]]), 'b': tensor([0.9815, 0.2230, 0.6183, 0.0756, 0.1584, 0.6588]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.0669, 0.6374, 0.1598, 0.7754, 0.2292, 0.0127, 0.4531, 0.9546], [0.1285, 0.1093, 0.6841, 0.0386, 0.5381, 0.5292, 0.2525, 0.4983], [0.4484, 0.7647, 0.6624, 0.1767, 0.2220, 0.4530, 0.6271, 0.0049]]), 'b': tensor([0.0129, 0.8459, 0.0295, 0.2026, 0.2082, 0.3370]), 'c': {'d': tensor([4])}}] (<Tensor 0x7fea7bb5cac0> ├── 'a' --> tensor([[[0.0686, 0.1086, 0.1260, 0.0717, 0.1534, 0.4248, 0.8107, 0.9792], │ [0.8358, 0.7622, 0.0328, 0.2995, 0.4200, 0.4186, 0.7946, 0.6256], │ [0.5847, 0.1426, 0.1226, 0.4180, 0.4402, 0.9689, 0.9807, 0.1741]]]) ├── 'b' --> tensor([[1.0580, 1.0029, 1.0497, 1.1130, 1.0004, 1.2667]]) └── 'c' --> <Tensor 0x7fea7bb5cb20> ├── 'd' --> 
tensor([[0.]]) └── 'noise' --> tensor([[[[ 1.4809, -0.8466, 1.0720, 0.7674, 1.2988], [-0.1379, -1.1622, -0.4715, 0.2469, 0.2314], [-0.4344, -2.3258, -0.1021, 0.5908, -0.0135], [ 0.2125, -0.7161, -0.6925, -0.8530, -0.3846]], [[ 0.6735, -0.9661, 0.5631, 0.9926, -3.1729], [-0.5627, 0.5281, 1.4474, -1.0707, 1.2655], [ 2.8633, -1.2859, 0.3502, 0.0993, 1.5777], [ 1.0557, -0.0891, 0.8168, -1.4842, 0.1666]], [[ 0.3919, -0.0320, 1.8884, -0.6711, -1.8297], [ 0.1443, 1.5710, -0.0992, -0.6423, 0.6008], [ 0.7714, 0.4063, 0.8384, -1.2446, 0.7972], [-0.6708, -0.1289, -0.4705, -1.0139, 1.0656]]]]) , <Tensor 0x7fea7bb5cbe0> ├── 'a' --> tensor([[[0.5001, 0.0582, 0.2682, 0.9194, 0.8338, 0.3869, 0.4686, 0.1196], │ [0.8403, 0.9976, 0.7340, 0.1658, 0.1127, 0.6743, 0.1626, 0.5464], │ [0.0086, 0.5789, 0.5836, 0.5645, 0.2969, 0.5460, 0.3157, 0.0195]]]) ├── 'b' --> tensor([[1.3350, 1.5694, 1.0268, 1.8381, 1.3813, 1.0564]]) └── 'c' --> <Tensor 0x7fea7bb5ca30> ├── 'd' --> tensor([[8.]]) └── 'noise' --> tensor([[[[-0.4667, 0.2000, 1.2140, 0.4890, 0.5628], [ 0.3238, 1.4185, -0.4402, -0.5429, 0.5033], [ 0.5419, -0.5719, 1.0255, 0.6894, -0.3771], [-0.7727, -2.1226, 0.0043, 0.2789, -0.8093]], [[-0.8213, -1.9765, 0.6290, 0.7305, -0.6955], [-2.0336, -0.3545, -0.6026, 0.6712, -1.2341], [ 0.9422, -1.2487, 1.7422, -0.4708, 0.5907], [ 0.1950, -1.3476, -0.0215, 0.6312, -0.7755]], [[-1.0473, -0.0140, 0.1878, 0.0064, 1.8081], [-2.4747, -0.3494, 0.0361, -0.6426, 0.4053], [-0.2242, 0.2586, 0.4661, -1.7540, -1.9929], [-0.8916, 0.7854, 1.5631, 0.2244, -0.4089]]]]) , <Tensor 0x7fea7bb5cc40> ├── 'a' --> tensor([[[0.8284, 0.8438, 0.7916, 0.1016, 0.2882, 0.1145, 0.6095, 0.0657], │ [0.4021, 0.4337, 0.7934, 0.3543, 0.9150, 0.0930, 0.7815, 0.1391], │ [0.0771, 0.6045, 0.6823, 0.1701, 0.2156, 0.5855, 0.6355, 0.9394]]]) ├── 'b' --> tensor([[1.9633, 1.0497, 1.3823, 1.0057, 1.0251, 1.4340]]) └── 'c' --> <Tensor 0x7fea7bb5cc10> ├── 'd' --> tensor([[1.]]) └── 'noise' --> tensor([[[[-0.8266, -0.2700, -0.5078, 0.6761, 
0.2158], [ 0.6552, -0.5522, -0.0266, 1.0813, -1.5548], [ 2.3831, -0.8522, 0.6716, -0.2994, 0.6217], [-0.0827, 1.1236, 0.5660, -0.1817, 1.2789]], [[-1.3301, 1.3541, 0.5394, 0.1224, 0.3887], [-1.5058, -1.1427, 1.2118, 1.2243, 0.2612], [ 0.4264, -1.6123, -0.9390, -1.1020, -0.8110], [-0.0263, 0.1119, 0.5974, 0.2575, -0.8753]], [[ 1.7161, 1.6931, -2.6318, -0.2364, 0.1694], [-0.7443, 0.9577, 1.1957, -0.5674, -0.4003], [-1.4331, 0.6762, 1.2686, 0.6880, -0.4489], [-0.9069, -0.4411, 0.9527, 1.6256, -1.3127]]]]) , <Tensor 0x7fea7bb5cca0> ├── 'a' --> tensor([[[0.0669, 0.6374, 0.1598, 0.7754, 0.2292, 0.0127, 0.4531, 0.9546], │ [0.1285, 0.1093, 0.6841, 0.0386, 0.5381, 0.5292, 0.2525, 0.4983], │ [0.4484, 0.7647, 0.6624, 0.1767, 0.2220, 0.4530, 0.6271, 0.0049]]]) ├── 'b' --> tensor([[1.0002, 1.7155, 1.0009, 1.0410, 1.0434, 1.1136]]) └── 'c' --> <Tensor 0x7fea7bb5cc70> ├── 'd' --> tensor([[4.]]) └── 'noise' --> tensor([[[[-0.9848, -0.8801, -2.0229, 0.2040, 1.0911], [ 0.1213, 3.1088, -1.0881, 2.1307, -0.0602], [-0.5257, -0.5790, 0.5905, 0.6199, 0.0638], [-1.9652, -1.0409, 0.6534, 2.0500, 1.4129]], [[-0.9417, -1.2216, -0.6902, 0.6150, 0.1717], [ 0.5321, -0.1839, -1.9115, 1.6711, -0.6858], [-2.2324, -0.3537, 1.4024, 0.8769, -1.0161], [-0.6217, -0.3269, 0.1266, -0.5366, -0.3699]], [[ 0.0931, 0.4600, -0.0660, -0.3847, 0.5237], [ 0.7623, 0.6981, 1.2001, 1.2765, 0.6914], [ 0.9582, -0.2392, -1.2116, 0.0220, -1.5339], [-1.7506, 0.0362, 0.5130, 0.5271, -0.0502]]]]) ) mean0 & mean1: tensor(1.2280) tensor(1.2280) even_index_a0: tensor([[[0.0686, 0.1086, 0.1260, 0.0717, 0.1534, 0.4248, 0.8107, 0.9792], [0.5847, 0.1426, 0.1226, 0.4180, 0.4402, 0.9689, 0.9807, 0.1741]], [[0.5001, 0.0582, 0.2682, 0.9194, 0.8338, 0.3869, 0.4686, 0.1196], [0.0086, 0.5789, 0.5836, 0.5645, 0.2969, 0.5460, 0.3157, 0.0195]], [[0.8284, 0.8438, 0.7916, 0.1016, 0.2882, 0.1145, 0.6095, 0.0657], [0.0771, 0.6045, 0.6823, 0.1701, 0.2156, 0.5855, 0.6355, 0.9394]], [[0.0669, 0.6374, 0.1598, 0.7754, 0.2292, 0.0127, 0.4531, 
0.9546], [0.4484, 0.7647, 0.6624, 0.1767, 0.2220, 0.4530, 0.6271, 0.0049]]]) even_index_a1: tensor([[[0.0686, 0.1086, 0.1260, 0.0717, 0.1534, 0.4248, 0.8107, 0.9792], [0.5847, 0.1426, 0.1226, 0.4180, 0.4402, 0.9689, 0.9807, 0.1741]], [[0.5001, 0.0582, 0.2682, 0.9194, 0.8338, 0.3869, 0.4686, 0.1196], [0.0086, 0.5789, 0.5836, 0.5645, 0.2969, 0.5460, 0.3157, 0.0195]], [[0.8284, 0.8438, 0.7916, 0.1016, 0.2882, 0.1145, 0.6095, 0.0657], [0.0771, 0.6045, 0.6823, 0.1701, 0.2156, 0.5855, 0.6355, 0.9394]], [[0.0669, 0.6374, 0.1598, 0.7754, 0.2292, 0.0127, 0.4531, 0.9546], [0.4484, 0.7647, 0.6624, 0.1767, 0.2220, 0.4530, 0.6271, 0.0049]]]) <Size 0x7feadba56460> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7fea7bbd7460> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.