Customized Operations for Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Process a batch of nested dicts using only the native torch API.

    Each sample in ``batch_`` is a dict shaped like the output of
    ``get_data``. Every sample is modified in place:

    * ``'a'`` is converted to float;
    * ``'b'`` is converted to float and replaced by ``b ** 2 + 1``;
    * ``'c' -> 'd'`` is converted to float;
    * ``'c' -> 'noise'`` is added as a ``torch.randn`` tensor of shape
      ``(3, 4, 5)``.

    Any other keys are reported with a print and otherwise ignored.

    Returns:
        ``(batch_, mean_b, even_index_a)``, where:

        * ``mean_b`` is the average of the per-sample means of the
          transformed ``'b'``;
        * ``even_index_a`` stacks the even-indexed rows of every ``'a'``
          along a new leading dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                # Write back so the returned batch holds the converted tensor
                # (replacing the value of an existing key is safe mid-iteration).
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                v = v.float()
                transformed_v = torch.pow(v, 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))
    # Add the noise in a separate pass: inserting a new key into sample['c']
    # while iterating over it above would raise "dictionary changed size".
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))
    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Process the same batch as ``with_nativetensor`` with the treetensor API.

    The steps are:

    1. Stack the samples into one tree and convert every leaf to float.
    2. Replace ``b`` with ``b ** 2 + 1`` and add a random ``c.noise`` leaf.
    3. Split the tree back into per-sample trees. Each one keeps a leading
       dimension of size 1.

    Returns:
        ``(batch_, mean_b, even_index_a)``, mirroring ``with_nativetensor``.
    """
    # Take the size from the input rather than the global ``B`` so the
    # function works for any batch length.
    batch_size = len(batch_)
    batch_ = [ttorch.tensor(b) for b in batch_]
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Return one random sample: ``a`` (T x 8), ``b`` (6,), ``c.d`` (int, 1)."""
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,))
        }
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare the largest *absolute* difference; abs(max(diff)) would let
    # negative deviations slip through.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look like the following (the tensor values are random, so they will differ between runs), and all the assertions pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.6796, 0.8915, 0.1787, 0.7100, 0.1271, 0.6469, 0.8311, 0.6974], [0.4011, 0.6651, 0.7264, 0.4408, 0.8056, 0.3359, 0.9172, 0.8944], [0.5826, 0.7113, 0.1235, 0.6089, 0.1268, 0.2385, 0.2617, 0.9369]]), 'b': tensor([0.8683, 0.6075, 0.1722, 0.4260, 0.3023, 0.8900]), 'c': {'d': tensor([8])}}, {'a': tensor([[0.3677, 0.1415, 0.0893, 0.2518, 0.4041, 0.0492, 0.1923, 0.9586], [0.7426, 0.5916, 0.2602, 0.1307, 0.2072, 0.5941, 0.3626, 0.4814], [0.1689, 0.1465, 0.3012, 0.5370, 0.9788, 0.1135, 0.1176, 0.7977]]), 'b': tensor([0.8620, 0.7494, 0.6395, 0.2251, 0.6765, 0.9543]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.2679, 0.1777, 0.1735, 0.3921, 0.2183, 0.4208, 0.0421, 0.6751], [0.9911, 0.3462, 0.8812, 0.8405, 0.3927, 0.9823, 0.9971, 0.8806], [0.7546, 0.0281, 0.2463, 0.4958, 0.7946, 0.8709, 0.1290, 0.1247]]), 'b': tensor([0.2163, 0.9633, 0.5466, 0.7943, 0.6659, 0.0497]), 'c': {'d': tensor([7])}}, {'a': tensor([[0.8043, 0.8778, 0.7910, 0.1306, 0.5225, 0.7919, 0.3902, 0.8837], [0.7651, 0.6421, 0.3727, 0.4784, 0.8094, 0.1772, 0.0986, 0.5673], [0.7576, 0.4917, 0.2139, 0.5751, 0.2011, 0.6180, 0.3734, 0.7506]]), 'b': tensor([0.2779, 0.9648, 0.7230, 0.2130, 0.3859, 0.4441]), 'c': {'d': tensor([4])}}] (<Tensor 0x7f4a8d72bdc0> ├── 'a' --> tensor([[[0.6796, 0.8915, 0.1787, 0.7100, 0.1271, 0.6469, 0.8311, 0.6974], │ [0.4011, 0.6651, 0.7264, 0.4408, 0.8056, 0.3359, 0.9172, 0.8944], │ [0.5826, 0.7113, 0.1235, 0.6089, 0.1268, 0.2385, 0.2617, 0.9369]]]) ├── 'b' --> tensor([[1.7539, 1.3691, 1.0296, 1.1814, 1.0914, 1.7922]]) └── 'c' --> <Tensor 0x7f4a8d72be20> ├── 'd' --> 
tensor([[8.]]) └── 'noise' --> tensor([[[[-0.8381, 2.1321, -0.9444, -0.7191, -1.3926], [-0.2115, -0.5253, -0.7440, -1.7407, -0.7213], [ 0.2595, -1.0568, -0.2872, 0.2834, 0.8790], [ 0.1232, -1.4918, 0.9498, -0.7962, -0.0331]], [[-0.0877, -1.0657, 0.5894, -0.8115, -1.1342], [ 0.9418, -0.5228, -0.2236, 0.8553, -0.2079], [ 0.8717, 0.0194, -0.3866, -1.3372, 0.9505], [-0.5231, -0.1153, 1.2639, 1.8143, -2.3913]], [[-0.2092, -0.6510, 0.7318, 0.2611, -0.5800], [ 0.2776, -0.1129, 1.3485, -0.8305, -0.3350], [ 0.5349, -0.2376, -0.5073, -0.0369, -0.2382], [ 1.6290, -0.2432, 0.3592, -0.7307, -1.1146]]]]) , <Tensor 0x7f4a8d72bee0> ├── 'a' --> tensor([[[0.3677, 0.1415, 0.0893, 0.2518, 0.4041, 0.0492, 0.1923, 0.9586], │ [0.7426, 0.5916, 0.2602, 0.1307, 0.2072, 0.5941, 0.3626, 0.4814], │ [0.1689, 0.1465, 0.3012, 0.5370, 0.9788, 0.1135, 0.1176, 0.7977]]]) ├── 'b' --> tensor([[1.7430, 1.5616, 1.4090, 1.0507, 1.4577, 1.9107]]) └── 'c' --> <Tensor 0x7f4a8d72bd30> ├── 'd' --> tensor([[3.]]) └── 'noise' --> tensor([[[[ 2.0865, -0.3595, -0.6910, 1.0244, -0.0228], [ 0.9128, 0.4903, 0.1303, -1.5050, 1.0615], [-0.0071, -0.2558, 0.1073, 0.5586, 0.6103], [-0.3212, 1.2623, 1.6619, 0.6409, 0.5480]], [[-0.5059, -0.9238, 0.7028, -0.7919, 0.7945], [ 1.2895, 1.2693, -0.0658, -0.2758, -0.1439], [ 1.5293, 0.5293, -1.5634, -0.1730, 2.1593], [ 1.8063, 1.2966, -1.2386, -0.2025, 1.0203]], [[-0.3223, 0.7071, 1.2362, 0.9568, 0.5885], [ 1.3537, 0.0566, 1.3208, 0.4373, 1.7377], [-0.9590, 0.1773, -0.3827, -0.3538, -2.9457], [ 1.1134, 0.4801, 0.2854, 0.2599, 0.6335]]]]) , <Tensor 0x7f4a8d72bf40> ├── 'a' --> tensor([[[0.2679, 0.1777, 0.1735, 0.3921, 0.2183, 0.4208, 0.0421, 0.6751], │ [0.9911, 0.3462, 0.8812, 0.8405, 0.3927, 0.9823, 0.9971, 0.8806], │ [0.7546, 0.0281, 0.2463, 0.4958, 0.7946, 0.8709, 0.1290, 0.1247]]]) ├── 'b' --> tensor([[1.0468, 1.9280, 1.2988, 1.6309, 1.4434, 1.0025]]) └── 'c' --> <Tensor 0x7f4a8d72bf10> ├── 'd' --> tensor([[7.]]) └── 'noise' --> tensor([[[[ 0.0507, 0.8475, 1.9797, 1.5418, 
0.4959], [ 0.2704, 1.2632, -0.9198, -0.3680, 0.1904], [ 0.4642, 0.6332, 1.3865, -0.9788, -1.0222], [-0.9968, -0.8438, 1.1953, -0.3468, 1.8236]], [[-2.4549, -0.6488, 0.8985, 0.7699, -1.3897], [-0.2723, 1.1534, -1.5565, 0.2130, 1.5684], [-1.1921, 0.1810, -0.9803, -1.6161, -1.3742], [-0.4538, -0.0770, 1.6496, -1.5832, -0.5101]], [[ 1.4290, 1.1963, -0.4897, 0.0123, 1.0885], [-0.2512, -1.6909, 0.1550, 0.2726, -0.6209], [ 2.0050, 0.6278, 0.3968, 0.5575, 0.4668], [ 0.0547, 0.7338, 1.6606, 0.4362, -1.9252]]]]) , <Tensor 0x7f4a8d72bfa0> ├── 'a' --> tensor([[[0.8043, 0.8778, 0.7910, 0.1306, 0.5225, 0.7919, 0.3902, 0.8837], │ [0.7651, 0.6421, 0.3727, 0.4784, 0.8094, 0.1772, 0.0986, 0.5673], │ [0.7576, 0.4917, 0.2139, 0.5751, 0.2011, 0.6180, 0.3734, 0.7506]]]) ├── 'b' --> tensor([[1.0772, 1.9309, 1.5227, 1.0454, 1.1489, 1.1972]]) └── 'c' --> <Tensor 0x7f4a8d72bf70> ├── 'd' --> tensor([[4.]]) └── 'noise' --> tensor([[[[-8.6057e-01, 1.9909e+00, -1.1341e+00, -3.1219e-01, -1.0579e+00], [ 4.2417e-01, 9.0933e-01, 1.5546e+00, 2.4609e+00, -2.7219e+00], [ 3.5179e-01, 4.9726e-01, 6.9975e-01, -6.5159e-01, -1.1356e+00], [ 2.8656e-01, -1.7430e-02, 1.6347e+00, 7.6030e-01, -8.6411e-01]], [[ 6.5482e-01, -3.3193e-01, 1.8156e+00, -1.2931e-01, -1.5559e+00], [-3.2648e-01, 8.0178e-01, -3.0099e-01, 1.9234e+00, 5.6460e-01], [-3.3813e-01, -1.6584e+00, -6.7409e-01, 4.0787e-01, -3.6855e-01], [-1.2040e+00, 3.6019e-01, -6.1413e-01, 1.2220e+00, -9.5071e-01]], [[-6.4820e-01, 1.3130e+00, -3.1143e-01, 7.4179e-02, 9.6257e-01], [ 1.7304e+00, -4.4796e-01, 9.9969e-01, 1.5349e-01, 3.0027e-02], [ 7.0398e-01, -3.9019e-01, 1.1322e+00, -2.1520e+00, 4.9974e-01], [ 1.4728e+00, -1.9006e-01, -5.1278e-01, -6.7268e-01, 1.4833e-03]]]]) ) mean0 & mean1: tensor(1.4009) tensor(1.4009) even_index_a0: tensor([[[0.6796, 0.8915, 0.1787, 0.7100, 0.1271, 0.6469, 0.8311, 0.6974], [0.5826, 0.7113, 0.1235, 0.6089, 0.1268, 0.2385, 0.2617, 0.9369]], [[0.3677, 0.1415, 0.0893, 0.2518, 0.4041, 0.0492, 0.1923, 0.9586], [0.1689, 0.1465, 
0.3012, 0.5370, 0.9788, 0.1135, 0.1176, 0.7977]], [[0.2679, 0.1777, 0.1735, 0.3921, 0.2183, 0.4208, 0.0421, 0.6751], [0.7546, 0.0281, 0.2463, 0.4958, 0.7946, 0.8709, 0.1290, 0.1247]], [[0.8043, 0.8778, 0.7910, 0.1306, 0.5225, 0.7919, 0.3902, 0.8837], [0.7576, 0.4917, 0.2139, 0.5751, 0.2011, 0.6180, 0.3734, 0.7506]]]) even_index_a1: tensor([[[0.6796, 0.8915, 0.1787, 0.7100, 0.1271, 0.6469, 0.8311, 0.6974], [0.5826, 0.7113, 0.1235, 0.6089, 0.1268, 0.2385, 0.2617, 0.9369]], [[0.3677, 0.1415, 0.0893, 0.2518, 0.4041, 0.0492, 0.1923, 0.9586], [0.1689, 0.1465, 0.3012, 0.5370, 0.9788, 0.1135, 0.1176, 0.7977]], [[0.2679, 0.1777, 0.1735, 0.3921, 0.2183, 0.4208, 0.0421, 0.6751], [0.7546, 0.0281, 0.2463, 0.4958, 0.7946, 0.8709, 0.1290, 0.1247]], [[0.8043, 0.8778, 0.7910, 0.1306, 0.5225, 0.7919, 0.3902, 0.8837], [0.7576, 0.4917, 0.2139, 0.5751, 0.2011, 0.6180, 0.3734, 0.7506]]]) <Size 0x7f4aed616460> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7f4a8d7a8460> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.