Customized Operations For Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy
import torch
import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Apply the per-field operations to a list of nested dicts with plain torch.

    Each sample is a dict shaped like the output of ``get_data``:
    ``{'a': Tensor, 'b': Tensor, 'c': {'d': Tensor}}``.  Every sample is
    modified in place so the result mirrors ``with_treetensor``: all tensors
    are cast to float, ``b`` is replaced by ``b ** 2 + 1`` and a random
    ``noise`` tensor is added under ``c``.  Unknown keys are reported and
    left untouched.

    Returns:
        The modified list, the mean of the transformed ``b`` values and the
        even-indexed rows of every ``a`` stacked along a new leading dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Only values of existing keys are replaced here, so mutating the dict
        # while iterating over it is safe.
        for k, v in sample.items():
            if k == 'a':
                sample[k] = v.float()
                even_index_a_list.append(sample[k][::2])
            elif k == 'b':
                sample[k] = torch.pow(v.float(), 2) + 1.0
                mean_b_list.append(sample[k].mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))
    # 'noise' lives under 'c' (like ``batch_.c.noise`` in the treetensor
    # version).  It is added after the loop above so no dict grows while it
    # is being iterated.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))
    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field operations with the treetensor API.

    The samples are stacked into a single tree tensor, so each operation is
    written once for the whole batch instead of once per sample and field.

    Returns:
        A tuple of per-sample tree tensors (each keeps a leading dimension of
        size 1, as produced by ``ttorch.split``), the mean of the transformed
        ``b`` values and the even-indexed rows of ``a`` for every sample.
    """
    batch_size = len(batch_)
    batch_ = ttorch.stack([ttorch.tensor(b) for b in batch_])
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    # One (3, 4, 5) noise tensor per sample; use the actual batch size rather
    # than the module-level constant B.
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Return one random sample.

    ``a`` is a (T, 8) tensor and ``b`` a (6,) tensor, both drawn uniformly
    from [0, 1).  ``c['d']`` is a (1,) integer tensor drawn from [0, 10).
    """
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,))
        }
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Deep copies so both implementations start from identical, untouched data
    # (the native version modifies its input in place).
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)
    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')
    # Compare the largest absolute difference.  abs(max(diff)) would ignore
    # large negative differences whenever the largest signed one is ~0.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)
    assert len(batch0) == B
    assert len(batch1) == B
    # Both implementations must also produce the same transformed 'b' field.
    for sample0, sample1 in zip(batch0, batch1):
        assert torch.allclose(sample0['b'], sample1.b[0])
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look like the example below (the tensors are generated randomly, so the exact values will differ), and all the assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.4529, 0.4845, 0.1233, 0.7727, 0.7163, 0.0018, 0.1644, 0.1536], [0.5998, 0.5347, 0.9329, 0.4599, 0.4904, 0.3431, 0.5684, 0.5958], [0.6771, 0.6037, 0.8489, 0.3818, 0.6618, 0.1331, 0.0240, 0.8878]]), 'b': tensor([0.2393, 0.0322, 0.8971, 0.4516, 0.1839, 0.5778]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.2874, 0.2397, 0.1155, 0.6106, 0.2545, 0.3252, 0.6645, 0.2612], [0.7247, 0.1225, 0.2925, 0.6812, 0.5074, 0.1101, 0.6757, 0.4603], [0.1935, 0.7102, 0.5032, 0.3095, 0.1066, 0.0958, 0.3336, 0.3753]]), 'b': tensor([0.5000, 0.8592, 0.3209, 0.7188, 0.3207, 0.7602]), 'c': {'d': tensor([0])}}, {'a': tensor([[0.3435, 0.4422, 0.0218, 0.8240, 0.1337, 0.8707, 0.2374, 0.8759], [0.8631, 0.3800, 0.5087, 0.8458, 0.3355, 0.3682, 0.3533, 0.8632], [0.2984, 0.6672, 0.9125, 0.8459, 0.6046, 0.6513, 0.1312, 0.7935]]), 'b': tensor([0.7767, 0.8433, 0.1131, 0.4656, 0.8797, 0.6381]), 'c': {'d': tensor([9])}}, {'a': tensor([[0.3697, 0.8184, 0.2459, 0.7395, 0.6239, 0.6514, 0.2199, 0.8016], [0.1031, 0.6713, 0.0948, 0.8860, 0.7416, 0.7898, 0.6888, 0.8017], [0.2897, 0.4645, 0.3337, 0.8838, 0.4184, 0.5080, 0.7330, 0.3640]]), 'b': tensor([0.4962, 0.1462, 0.7673, 0.3855, 0.2231, 0.2650]), 'c': {'d': tensor([7])}}] (<Tensor 0x7fedcf41e520> ├── 'a' --> tensor([[[0.4529, 0.4845, 0.1233, 0.7727, 0.7163, 0.0018, 0.1644, 0.1536], │ [0.5998, 0.5347, 0.9329, 0.4599, 0.4904, 0.3431, 0.5684, 0.5958], │ [0.6771, 0.6037, 0.8489, 0.3818, 0.6618, 0.1331, 0.0240, 0.8878]]]) ├── 'b' --> tensor([[1.0573, 1.0010, 1.8047, 1.2039, 1.0338, 1.3338]]) └── 'c' --> <Tensor 0x7fedcf41e580> ├── 'd' --> 
tensor([[3.]]) └── 'noise' --> tensor([[[[-0.6636, -0.0559, -1.7296, 0.1993, -0.4144], [-0.9026, -0.1666, -1.0286, 0.9501, 0.2277], [ 0.1565, 0.7128, 1.3450, 0.1384, -0.2809], [ 2.3120, -0.6240, -0.0100, 0.0555, 0.1233]], [[ 0.7889, -0.3542, 0.5173, -0.1994, -0.2877], [ 0.5132, -0.2618, -0.2381, -0.5973, 2.0943], [ 2.5564, 0.3201, -0.6498, -0.7667, 0.2432], [-0.3787, -0.5635, 0.7659, 1.1013, -0.3738]], [[-2.0796, -1.2503, -0.8508, 0.4856, 0.3095], [-1.3404, 1.1604, -1.2025, 0.7917, -0.6804], [-1.3208, 0.5036, -1.2873, 0.1188, -0.5987], [ 1.3959, -1.5887, -1.0825, 0.2393, -1.0344]]]]) , <Tensor 0x7fedcf41e640> ├── 'a' --> tensor([[[0.2874, 0.2397, 0.1155, 0.6106, 0.2545, 0.3252, 0.6645, 0.2612], │ [0.7247, 0.1225, 0.2925, 0.6812, 0.5074, 0.1101, 0.6757, 0.4603], │ [0.1935, 0.7102, 0.5032, 0.3095, 0.1066, 0.0958, 0.3336, 0.3753]]]) ├── 'b' --> tensor([[1.2500, 1.7382, 1.1030, 1.5166, 1.1029, 1.5779]]) └── 'c' --> <Tensor 0x7fedcf41e490> ├── 'd' --> tensor([[0.]]) └── 'noise' --> tensor([[[[-0.6272, -0.8048, -0.1083, -2.0494, -0.5725], [ 0.4207, 0.5677, -2.1537, 0.0689, -0.9418], [ 0.3334, 0.0102, -2.2345, -0.4336, 0.6982], [-0.7987, -1.1848, -0.3990, 0.9738, 1.0983]], [[-0.1831, -1.0681, 1.5724, 0.8108, -0.6631], [-0.1058, -0.2816, -0.4718, 1.8463, -2.0032], [-1.1256, -0.7793, 0.2796, 0.9440, -0.1748], [-0.0289, -1.1661, -1.0366, 0.2706, -0.0884]], [[ 0.3446, -1.2585, 0.0536, -1.0360, -2.0640], [-0.2337, -0.2699, -0.0259, -0.5141, 0.7299], [-0.3413, -0.0165, 0.3234, 1.4532, 0.8804], [ 0.6048, -0.4461, 0.2000, -0.8728, 1.8801]]]]) , <Tensor 0x7fedcf41e6a0> ├── 'a' --> tensor([[[0.3435, 0.4422, 0.0218, 0.8240, 0.1337, 0.8707, 0.2374, 0.8759], │ [0.8631, 0.3800, 0.5087, 0.8458, 0.3355, 0.3682, 0.3533, 0.8632], │ [0.2984, 0.6672, 0.9125, 0.8459, 0.6046, 0.6513, 0.1312, 0.7935]]]) ├── 'b' --> tensor([[1.6033, 1.7111, 1.0128, 1.2168, 1.7738, 1.4071]]) └── 'c' --> <Tensor 0x7fedcf41e670> ├── 'd' --> tensor([[9.]]) └── 'noise' --> tensor([[[[ 0.9333, 1.1424, -0.7474, 
-0.8829, 0.9201], [ 0.2274, -1.8195, -1.6136, -0.2165, -1.0574], [ 1.4707, -1.4699, -0.7836, -0.4376, 1.3080], [-0.9974, 0.2809, 0.4782, -1.5480, -0.7322]], [[-0.1923, 0.1196, -0.6996, 0.5642, -0.3245], [-0.7999, 1.7541, 0.7113, -0.3174, 1.5693], [-0.5462, -1.2245, 0.2500, 0.4545, 0.5209], [ 1.0194, 1.3523, -0.7063, 0.0980, 0.2298]], [[ 0.5091, 1.2884, -2.4875, -0.6895, 1.3439], [-0.1998, 0.0077, 0.9488, 0.1471, 1.2440], [ 1.7300, 1.9349, -0.3384, 0.1602, 1.2618], [-0.0880, -1.4244, 0.1738, 0.2489, 0.2008]]]]) , <Tensor 0x7fedcf41e700> ├── 'a' --> tensor([[[0.3697, 0.8184, 0.2459, 0.7395, 0.6239, 0.6514, 0.2199, 0.8016], │ [0.1031, 0.6713, 0.0948, 0.8860, 0.7416, 0.7898, 0.6888, 0.8017], │ [0.2897, 0.4645, 0.3337, 0.8838, 0.4184, 0.5080, 0.7330, 0.3640]]]) ├── 'b' --> tensor([[1.2462, 1.0214, 1.5887, 1.1486, 1.0498, 1.0702]]) └── 'c' --> <Tensor 0x7fedcf41e6d0> ├── 'd' --> tensor([[7.]]) └── 'noise' --> tensor([[[[ 1.3262, 3.0024, -1.3932, -1.0703, 2.0210], [-0.2846, 1.1873, -0.5566, 0.4753, 0.6220], [ 1.3426, 0.1599, 1.3555, -1.1109, 1.4186], [-1.2408, 1.6457, -1.3557, 0.4189, -1.4630]], [[-1.0050, 0.7207, -0.9922, -0.0952, 0.5409], [-0.7450, 1.3094, -1.5485, 0.3075, 1.0038], [ 0.4184, -0.3876, -0.3017, 0.2980, -1.4328], [-0.2267, -0.9949, -0.4961, -1.3194, -1.5610]], [[-0.6511, -0.1034, -0.9375, 0.9415, -1.5765], [ 0.2072, -0.7496, 0.4781, -1.7320, 1.1552], [-1.0410, 0.3267, -0.0955, 0.3992, 0.1344], [-0.0522, 1.7607, -2.5081, -1.0226, -0.6805]]]]) ) mean0 & mean1: tensor(1.3155) tensor(1.3155) even_index_a0: tensor([[[0.4529, 0.4845, 0.1233, 0.7727, 0.7163, 0.0018, 0.1644, 0.1536], [0.6771, 0.6037, 0.8489, 0.3818, 0.6618, 0.1331, 0.0240, 0.8878]], [[0.2874, 0.2397, 0.1155, 0.6106, 0.2545, 0.3252, 0.6645, 0.2612], [0.1935, 0.7102, 0.5032, 0.3095, 0.1066, 0.0958, 0.3336, 0.3753]], [[0.3435, 0.4422, 0.0218, 0.8240, 0.1337, 0.8707, 0.2374, 0.8759], [0.2984, 0.6672, 0.9125, 0.8459, 0.6046, 0.6513, 0.1312, 0.7935]], [[0.3697, 0.8184, 0.2459, 0.7395, 0.6239, 0.6514, 
0.2199, 0.8016], [0.2897, 0.4645, 0.3337, 0.8838, 0.4184, 0.5080, 0.7330, 0.3640]]]) even_index_a1: tensor([[[0.4529, 0.4845, 0.1233, 0.7727, 0.7163, 0.0018, 0.1644, 0.1536], [0.6771, 0.6037, 0.8489, 0.3818, 0.6618, 0.1331, 0.0240, 0.8878]], [[0.2874, 0.2397, 0.1155, 0.6106, 0.2545, 0.3252, 0.6645, 0.2612], [0.1935, 0.7102, 0.5032, 0.3095, 0.1066, 0.0958, 0.3336, 0.3753]], [[0.3435, 0.4422, 0.0218, 0.8240, 0.1337, 0.8707, 0.2374, 0.8759], [0.2984, 0.6672, 0.9125, 0.8459, 0.6046, 0.6513, 0.1312, 0.7935]], [[0.3697, 0.8184, 0.2459, 0.7395, 0.6239, 0.6514, 0.2199, 0.8016], [0.2897, 0.4645, 0.3337, 0.8838, 0.4184, 0.5080, 0.7330, 0.3640]]]) <Size 0x7fee2f323460> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7fedcf49d100> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.