Customized Operations For Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Process a batch with plain torch operations, one sample at a time.

    For every sample dict in ``batch_``:

    * ``'a'`` is cast to float and its even-indexed rows (``[::2]``) are
      collected;
    * ``'b'`` is cast to float and replaced with ``b ** 2 + 1``, whose mean
      is collected;
    * ``'c'['d']`` is cast to float, and ``'c'['noise']`` is set to a fresh
      ``(3, 4, 5)`` tensor drawn with ``torch.randn``.

    The converted values are written back into ``batch_`` (the list and its
    dicts are mutated in place), so the result mirrors ``with_treetensor``.
    Any other key is reported with a ``print`` and left untouched.

    Returns:
        ``(batch_, mean_b, even_index_a)``. ``mean_b`` is the average of the
        per-sample means of the transformed ``'b'``. ``even_index_a`` stacks
        the even rows of every ``'a'`` along a new leading dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Reassigning an existing key while iterating ``items()`` is allowed
        # (the dict size does not change); rebinding ``v`` alone would be lost.
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))
    # Add the noise in a separate pass: inserting a new key into ``'c'``
    # while iterating over it above would raise a RuntimeError.
    # ``'d'`` lives under ``'c'``, so test for ``'c'`` at the top level.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))
    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Process the same batch with treetensor operations applied to the whole tree.

    The list of sample dicts is stacked into a single tree tensor, cast to
    float, and transformed field-wise. Noise of shape
    ``(len(batch_), 3, 4, 5)`` is added under ``c.noise``. The tree is then
    split back into per-sample trees along dim 0.

    Returns:
        ``(batch_, mean_b, even_index_a)``. ``batch_`` is the result of
        ``ttorch.split`` with ``split_size_or_sections=1``. ``mean_b`` is the
        mean of the transformed ``b``. ``even_index_a`` holds the even rows
        of ``a`` for every sample.
    """
    # Record the length before stacking; use it instead of the global ``B``
    # so batches of any size work.
    batch_size = len(batch_)
    batch_ = ttorch.stack([ttorch.tensor(b) for b in batch_])
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample dict with fields ``a``, ``b`` and ``c.d``."""
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,))
        }
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take the max of the absolute differences; the abs of the max would let
    # large negative differences slip through.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look similar to the following (the tensor values are random), and all assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.2909, 0.8774, 0.5245, 0.2834, 0.5976, 0.6227, 0.4479, 0.5658], [0.1502, 0.0066, 0.2932, 0.7347, 0.4004, 0.5038, 0.3965, 0.8740], [0.9986, 0.1072, 0.5157, 0.1964, 0.1072, 0.1204, 0.6982, 0.2265]]), 'b': tensor([0.6827, 0.6241, 0.4580, 0.2430, 0.0198, 0.4970]), 'c': {'d': tensor([7])}}, {'a': tensor([[0.2647, 0.5880, 0.1264, 0.1287, 0.3078, 0.3299, 0.1620, 0.6071], [0.5385, 0.2389, 0.3826, 0.3370, 0.1770, 0.9450, 0.2802, 0.8462], [0.7176, 0.5219, 0.2742, 0.4196, 0.4705, 0.5566, 0.0080, 0.3343]]), 'b': tensor([0.7906, 0.6741, 0.2049, 0.5595, 0.9478, 0.7007]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.8778, 0.1243, 0.9530, 0.1061, 0.8310, 0.0896, 0.8134, 0.3172], [0.0549, 0.6482, 0.8368, 0.7538, 0.2837, 0.1982, 0.6072, 0.8325], [0.1204, 0.1436, 0.3943, 0.2892, 0.8983, 0.7299, 0.5482, 0.5508]]), 'b': tensor([0.7508, 0.2203, 0.4806, 0.0496, 0.5695, 0.0591]), 'c': {'d': tensor([0])}}, {'a': tensor([[0.0139, 0.5196, 0.4118, 0.8342, 0.2534, 0.9744, 0.8006, 0.8289], [0.5989, 0.7918, 0.3672, 0.7529, 0.7780, 0.8163, 0.0953, 0.8385], [0.4482, 0.9899, 0.5385, 0.3456, 0.6094, 0.8605, 0.0776, 0.2461]]), 'b': tensor([0.3306, 0.7661, 0.6920, 0.6690, 0.2996, 0.6166]), 'c': {'d': tensor([3])}}] (<Tensor 0x7fce010dd1f0> ├── 'a' --> tensor([[[0.2909, 0.8774, 0.5245, 0.2834, 0.5976, 0.6227, 0.4479, 0.5658], │ [0.1502, 0.0066, 0.2932, 0.7347, 0.4004, 0.5038, 0.3965, 0.8740], │ [0.9986, 0.1072, 0.5157, 0.1964, 0.1072, 0.1204, 0.6982, 0.2265]]]) ├── 'b' --> tensor([[1.4660, 1.3895, 1.2097, 1.0590, 1.0004, 1.2470]]) └── 'c' --> <Tensor 0x7fce010dd250> ├── 'd' --> 
tensor([[7.]]) └── 'noise' --> tensor([[[[-0.9183, 0.6631, 1.2223, 0.9913, -0.7930], [-1.7144, -1.5954, 0.3963, -0.0340, -1.7042], [-1.1111, 0.4383, -1.1122, 0.2265, 0.6809], [-1.8948, -1.1574, 0.5751, 0.5739, -0.5946]], [[-0.3911, 0.5059, -1.5592, 0.8105, -0.4599], [ 0.1479, 1.8968, 0.3216, 0.7993, 0.4830], [-0.7883, -1.1408, -0.7055, 0.5743, -0.1041], [ 0.6040, -1.2881, 1.9320, 0.8264, 1.0876]], [[-2.0007, -0.7446, -0.7728, -0.0177, -0.7398], [-0.2534, -0.0327, -0.4782, -3.0775, 2.2030], [-0.4641, 0.5182, 1.5898, 1.0946, 0.2854], [-0.6921, 0.8066, 1.6027, -0.1465, 0.4847]]]]) , <Tensor 0x7fce010dd310> ├── 'a' --> tensor([[[0.2647, 0.5880, 0.1264, 0.1287, 0.3078, 0.3299, 0.1620, 0.6071], │ [0.5385, 0.2389, 0.3826, 0.3370, 0.1770, 0.9450, 0.2802, 0.8462], │ [0.7176, 0.5219, 0.2742, 0.4196, 0.4705, 0.5566, 0.0080, 0.3343]]]) ├── 'b' --> tensor([[1.6251, 1.4544, 1.0420, 1.3130, 1.8983, 1.4910]]) └── 'c' --> <Tensor 0x7fce010dd190> ├── 'd' --> tensor([[3.]]) └── 'noise' --> tensor([[[[ 2.1365, 0.2016, 1.4825, -0.3637, -0.0737], [-1.7718, -0.8321, -1.6644, 0.1703, 1.6251], [ 0.6342, -2.0468, -0.6453, 0.6979, 0.6614], [-0.2988, -0.5733, -0.5838, 1.0056, -0.0442]], [[ 1.2284, 0.7206, 0.9217, 0.6188, -0.0728], [ 0.4928, 0.5809, 0.0663, 1.6436, 1.4629], [ 0.0755, 0.6302, -0.0471, -0.4710, 1.2841], [ 0.8259, -0.6029, -1.3071, 1.2015, 1.1368]], [[ 0.8509, -0.8950, -0.4758, -0.2302, 0.6813], [-0.1901, 1.1689, -0.7822, 0.4300, 0.6451], [-2.7162, 0.5373, 0.2350, -0.5809, 2.3647], [-0.5974, -0.7653, 1.2431, -0.0800, -0.8557]]]]) , <Tensor 0x7fce010dd370> ├── 'a' --> tensor([[[0.8778, 0.1243, 0.9530, 0.1061, 0.8310, 0.0896, 0.8134, 0.3172], │ [0.0549, 0.6482, 0.8368, 0.7538, 0.2837, 0.1982, 0.6072, 0.8325], │ [0.1204, 0.1436, 0.3943, 0.2892, 0.8983, 0.7299, 0.5482, 0.5508]]]) ├── 'b' --> tensor([[1.5637, 1.0485, 1.2310, 1.0025, 1.3243, 1.0035]]) └── 'c' --> <Tensor 0x7fce010dd340> ├── 'd' --> tensor([[0.]]) └── 'noise' --> tensor([[[[-1.7963, 0.2216, -0.8436, 0.6543, -0.3107], [ 
1.0097, 2.1742, 0.5098, 0.0165, -2.1901], [-0.8981, 0.0152, 0.2663, -0.3163, -0.0566], [-1.8476, -0.3444, 0.3505, 2.2829, 0.5392]], [[-0.9531, 1.4015, -0.9820, 0.1828, -1.0707], [-0.6989, -0.5245, -0.1681, -1.2357, -0.0397], [-0.0702, 0.1157, 0.0083, -0.3045, -0.1781], [ 0.5893, 0.0633, -0.1190, 0.7433, -0.4827]], [[-0.8336, 0.1313, 0.1840, 0.1536, -0.1993], [ 0.7144, 0.1805, 0.0081, -0.6981, -0.7325], [ 0.3349, 0.5839, -0.5527, 1.0175, -0.4347], [ 0.1843, 0.0505, 0.3090, -0.1539, -0.0230]]]]) , <Tensor 0x7fce010dd3d0> ├── 'a' --> tensor([[[0.0139, 0.5196, 0.4118, 0.8342, 0.2534, 0.9744, 0.8006, 0.8289], │ [0.5989, 0.7918, 0.3672, 0.7529, 0.7780, 0.8163, 0.0953, 0.8385], │ [0.4482, 0.9899, 0.5385, 0.3456, 0.6094, 0.8605, 0.0776, 0.2461]]]) ├── 'b' --> tensor([[1.1093, 1.5869, 1.4788, 1.4476, 1.0898, 1.3802]]) └── 'c' --> <Tensor 0x7fce010dd3a0> ├── 'd' --> tensor([[3.]]) └── 'noise' --> tensor([[[[ 0.1460, 0.5519, 0.1847, -0.2209, -2.3965], [ 0.0868, -2.5259, -1.6959, 1.2219, -0.7213], [ 0.2837, -0.7940, -1.1715, 1.0738, 0.1309], [ 0.0484, 0.1060, -0.4734, 0.3928, -0.1920]], [[-0.1676, 1.2269, 1.3391, 1.1828, -1.3600], [-1.9169, -1.2024, 0.8078, 0.2364, -0.6363], [ 0.2980, 0.2787, 1.8287, -1.6175, 0.2046], [ 0.0681, -0.7894, -0.0525, -2.3374, -1.0105]], [[ 1.4547, -1.1449, 0.3292, 1.1134, 1.1361], [ 0.4465, -0.6099, -1.6036, 0.7052, 0.8402], [-1.7581, -0.3325, 1.1684, 0.6402, -0.3822], [ 0.8802, 0.9288, 1.2364, -0.6207, -0.9612]]]]) ) mean0 & mean1: tensor(1.3109) tensor(1.3109) even_index_a0: tensor([[[0.2909, 0.8774, 0.5245, 0.2834, 0.5976, 0.6227, 0.4479, 0.5658], [0.9986, 0.1072, 0.5157, 0.1964, 0.1072, 0.1204, 0.6982, 0.2265]], [[0.2647, 0.5880, 0.1264, 0.1287, 0.3078, 0.3299, 0.1620, 0.6071], [0.7176, 0.5219, 0.2742, 0.4196, 0.4705, 0.5566, 0.0080, 0.3343]], [[0.8778, 0.1243, 0.9530, 0.1061, 0.8310, 0.0896, 0.8134, 0.3172], [0.1204, 0.1436, 0.3943, 0.2892, 0.8983, 0.7299, 0.5482, 0.5508]], [[0.0139, 0.5196, 0.4118, 0.8342, 0.2534, 0.9744, 0.8006, 0.8289], 
[0.4482, 0.9899, 0.5385, 0.3456, 0.6094, 0.8605, 0.0776, 0.2461]]]) even_index_a1: tensor([[[0.2909, 0.8774, 0.5245, 0.2834, 0.5976, 0.6227, 0.4479, 0.5658], [0.9986, 0.1072, 0.5157, 0.1964, 0.1072, 0.1204, 0.6982, 0.2265]], [[0.2647, 0.5880, 0.1264, 0.1287, 0.3078, 0.3299, 0.1620, 0.6071], [0.7176, 0.5219, 0.2742, 0.4196, 0.4705, 0.5566, 0.0080, 0.3343]], [[0.8778, 0.1243, 0.9530, 0.1061, 0.8310, 0.0896, 0.8134, 0.3172], [0.1204, 0.1436, 0.3943, 0.2892, 0.8983, 0.7299, 0.5482, 0.5508]], [[0.0139, 0.5196, 0.4118, 0.8342, 0.2534, 0.9744, 0.8006, 0.8289], [0.4482, 0.9899, 0.5385, 0.3456, 0.6094, 0.8605, 0.0776, 0.2461]]]) <Size 0x7fce0114a970> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7fce014012e0> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.