Customized Operations For Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy

import torch

import treetensor.torch as ttorch

# T: number of rows in field 'a' of each sample; B: number of samples in a batch.
T, B = 3, 4


def with_nativetensor(batch_):
    """Process a list of nested sample dicts using only the native torch API.

    Each sample is handled field by field:

    * ``'a'`` is cast to float and every other row (``v[::2]``) is collected.
    * ``'b'`` is cast to float, transformed as ``b ** 2 + 1`` and its mean is
      collected.
    * ``'c'`` is a nested dict; its ``'d'`` entry is cast to float locally and
      any other nested key is reported as ignored.
    * any other top-level key is reported as ignored.

    Args:
        batch_: list of sample dicts shaped like the ones built by
            ``get_data``.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        input list, ``mean_b`` is the average of the per-sample means of the
        transformed ``'b'`` field, and ``even_index_a`` stacks the selected
        rows of ``'a'`` along a new leading dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                even_index_a_list.append(v[::2])
            elif k == 'b':
                v = v.float()
                transformed_v = torch.pow(v, 2) + 1.0
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        # The float copy is only bound to the local name; the
                        # tensor stored in the sample is left untouched.
                        v1 = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # NOTE(review): samples built by get_data only have the top-level keys
    # 'a', 'b' and 'c', so this branch never matches and no 'noise' entry is
    # attached here. The behaviour is kept as-is so that the printed batch0
    # still matches the documented output; attaching noise under 'c' would
    # mirror with_treetensor -- TODO confirm the intended behaviour.
    for sample in batch_:
        for k in sample:
            if k == 'd':
                sample[k]['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Process the same batch with the treetensor API.

    The samples are converted to tree tensors and stacked into one batched
    tree, so every field is processed with a single expression.

    Args:
        batch_: list of sample dicts shaped like the ones built by
            ``get_data``.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        result of splitting the processed tree along dim 0 with split size 1,
        ``mean_b`` is the mean of the transformed ``'b'`` field, and
        ``even_index_a`` is ``batch_.a[:, ::2]``.
    """
    batch_ = [ttorch.tensor(b) for b in batch_]
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(B, 3, 4, 5))
    mean_b = batch_.b.mean()
    # Dim 0 is the stacked sample dimension, so the step-2 slice is applied
    # to dim 1, matching v[::2] on a single sample in with_nativetensor.
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample: a nested dict of torch tensors."""
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,)),
        },
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Each implementation receives its own deep copy of the batch.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take the absolute value BEFORE the max: the max of the signed
    # differences would hide a large negative deviation.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look like the following (the exact values will differ from run to run because the data is randomly generated), and all the assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.2036, 0.2045, 0.6618, 0.7467, 0.9715, 0.8493, 0.1803, 0.1353], [0.3565, 0.5454, 0.2573, 0.7043, 0.3618, 0.4075, 0.2910, 0.6184], [0.6495, 0.0149, 0.2179, 0.4534, 0.0662, 0.8949, 0.7714, 0.0765]]), 'b': tensor([0.9334, 0.9717, 0.6869, 0.6752, 0.8431, 0.4989]), 'c': {'d': tensor([9])}}, {'a': tensor([[0.0783, 0.0599, 0.7010, 0.8613, 0.4005, 0.4793, 0.1911, 0.2949], [0.7930, 0.0374, 0.3266, 0.2080, 0.3605, 0.1079, 0.7367, 0.5571], [0.2806, 0.3294, 0.0518, 0.6544, 0.1504, 0.3349, 0.0295, 0.1594]]), 'b': tensor([0.4440, 0.2503, 0.3370, 0.5850, 0.0929, 0.7613]), 'c': {'d': tensor([4])}}, {'a': tensor([[0.8803, 0.2766, 0.7276, 0.9788, 0.4364, 0.0951, 0.8465, 0.6436], [0.0128, 0.4066, 0.5662, 0.2270, 0.3613, 0.2894, 0.9197, 0.0907], [0.8925, 0.5813, 0.4148, 0.1561, 0.9737, 0.9314, 0.9608, 0.0347]]), 'b': tensor([0.3093, 0.2831, 0.4555, 0.2733, 0.2475, 0.3501]), 'c': {'d': tensor([5])}}, {'a': tensor([[0.4850, 0.2218, 0.1663, 0.4499, 0.7313, 0.4146, 0.5337, 0.7840], [0.1885, 0.0900, 0.1045, 0.8186, 0.5989, 0.0149, 0.5255, 0.6977], [0.7255, 0.8035, 0.7201, 0.3028, 0.8759, 0.7901, 0.9948, 0.7760]]), 'b': tensor([0.5845, 0.2858, 0.6358, 0.8988, 0.7272, 0.6260]), 'c': {'d': tensor([5])}}] (<Tensor 0x7f85d1f09dc0> ├── 'a' --> tensor([[[0.2036, 0.2045, 0.6618, 0.7467, 0.9715, 0.8493, 0.1803, 0.1353], │ [0.3565, 0.5454, 0.2573, 0.7043, 0.3618, 0.4075, 0.2910, 0.6184], │ [0.6495, 0.0149, 0.2179, 0.4534, 0.0662, 0.8949, 0.7714, 0.0765]]]) ├── 'b' --> tensor([[1.8712, 1.9442, 1.4718, 1.4559, 1.7109, 1.2489]]) └── 'c' --> <Tensor 0x7f85d1f09e20> ├── 'd' --> 
tensor([[9.]]) └── 'noise' --> tensor([[[[-1.3222, -0.4631, -1.1493, 0.2412, 1.7066], [ 0.6862, 0.1240, 0.0722, -0.3273, 0.6220], [ 1.0520, -0.0453, 0.0042, 0.3125, 1.0273], [ 1.0052, -1.4199, 0.5473, -1.2724, -0.4614]], [[-0.8970, 0.6137, 0.7893, 0.7652, 0.6276], [ 0.3589, -0.9416, 1.9306, -0.3673, -0.6270], [ 1.3039, -0.0854, 0.0807, -0.3006, 0.5846], [ 0.9596, -0.3978, -2.7955, 1.0804, 0.8696]], [[ 0.5054, -0.7677, -0.5725, 0.2164, 1.1939], [ 0.2597, 0.6918, 1.1359, 1.7897, -0.0778], [-0.1170, -1.7965, -1.0321, 0.9932, -1.1103], [-0.7336, 0.0222, 0.7134, 0.1178, 0.2860]]]]) , <Tensor 0x7f85d1f09ee0> ├── 'a' --> tensor([[[0.0783, 0.0599, 0.7010, 0.8613, 0.4005, 0.4793, 0.1911, 0.2949], │ [0.7930, 0.0374, 0.3266, 0.2080, 0.3605, 0.1079, 0.7367, 0.5571], │ [0.2806, 0.3294, 0.0518, 0.6544, 0.1504, 0.3349, 0.0295, 0.1594]]]) ├── 'b' --> tensor([[1.1971, 1.0627, 1.1136, 1.3422, 1.0086, 1.5797]]) └── 'c' --> <Tensor 0x7f85d1f09d30> ├── 'd' --> tensor([[4.]]) └── 'noise' --> tensor([[[[ 2.5355, -1.2246, 1.1552, 1.1878, -0.8264], [ 0.1038, -1.9889, -0.3179, -0.9583, -0.0319], [ 0.8137, 0.9206, 1.5439, -0.4607, 1.3009], [-0.4685, -0.0254, 0.3119, -1.2403, -1.3670]], [[-0.7800, 1.1956, 1.4925, -0.6537, 0.9278], [ 0.8594, 0.5685, 0.9859, -0.7464, -1.3470], [ 1.1382, 0.3734, -0.1874, 0.2087, -0.3723], [-1.9231, -0.6158, -0.9289, 0.3163, 1.1917]], [[-1.1180, 0.0056, -1.4290, -0.8379, -0.4956], [ 1.1919, -1.2769, 0.2825, 0.7284, 0.3196], [ 0.7649, -1.2728, -1.7425, -0.1880, -0.7062], [ 0.4846, -0.7361, -3.2466, 0.7879, 0.4578]]]]) , <Tensor 0x7f85d1f09f40> ├── 'a' --> tensor([[[0.8803, 0.2766, 0.7276, 0.9788, 0.4364, 0.0951, 0.8465, 0.6436], │ [0.0128, 0.4066, 0.5662, 0.2270, 0.3613, 0.2894, 0.9197, 0.0907], │ [0.8925, 0.5813, 0.4148, 0.1561, 0.9737, 0.9314, 0.9608, 0.0347]]]) ├── 'b' --> tensor([[1.0957, 1.0802, 1.2075, 1.0747, 1.0613, 1.1226]]) └── 'c' --> <Tensor 0x7f85d1f09f10> ├── 'd' --> tensor([[5.]]) └── 'noise' --> tensor([[[[ 1.0207, 0.3434, -0.8973, 0.1341, 
-0.4765], [-0.7483, 0.3544, -1.0095, 1.2414, 0.0488], [ 0.5660, -0.2623, 0.7233, 1.1281, -0.0232], [-1.6810, -0.4228, -1.9214, 0.1472, 0.2862]], [[-0.0974, -1.3113, 0.5264, -2.0024, -0.3385], [-0.5324, 0.9494, 0.5736, -1.0252, -0.5890], [ 1.0508, -1.2013, -0.0814, -0.7333, -0.9893], [-0.1241, -0.3741, -1.6691, 0.6177, -1.5291]], [[-2.3571, -0.5609, -0.9252, -0.5954, 0.1488], [-1.2112, -0.7898, 0.7132, -0.7785, -0.2986], [ 1.8967, -0.5564, 0.8275, 0.2167, -1.2242], [ 0.7123, 1.0096, -0.7093, 0.2377, 0.9861]]]]) , <Tensor 0x7f85d1f09fa0> ├── 'a' --> tensor([[[0.4850, 0.2218, 0.1663, 0.4499, 0.7313, 0.4146, 0.5337, 0.7840], │ [0.1885, 0.0900, 0.1045, 0.8186, 0.5989, 0.0149, 0.5255, 0.6977], │ [0.7255, 0.8035, 0.7201, 0.3028, 0.8759, 0.7901, 0.9948, 0.7760]]]) ├── 'b' --> tensor([[1.3417, 1.0817, 1.4043, 1.8078, 1.5288, 1.3919]]) └── 'c' --> <Tensor 0x7f85d1f09f70> ├── 'd' --> tensor([[5.]]) └── 'noise' --> tensor([[[[ 0.3787, -1.1965, -1.2048, -0.3430, 0.6383], [-0.7527, -1.5821, -1.4492, 1.0375, -0.0591], [-0.1333, -1.5517, -0.5212, -1.8166, -2.1628], [-3.0129, -0.8093, -0.7022, -0.3460, 1.9479]], [[ 1.4103, 0.3771, 1.0118, -0.4478, 0.7042], [ 0.8111, 0.4685, 0.2835, 0.5010, -0.1655], [ 1.2953, 0.0749, -0.5532, -0.5583, -0.0162], [-0.5610, 0.8703, -0.4858, 0.0177, 1.0945]], [[ 0.0955, -1.1814, -1.7059, 1.0074, -1.0350], [ 0.5058, 0.2503, -0.7092, -0.5157, -1.2040], [ 1.0195, -0.7076, -1.3974, -0.0386, -1.8758], [ 0.4644, 0.4292, -0.7586, 0.6937, 0.5348]]]]) ) mean0 & mean1: tensor(1.3419) tensor(1.3419) even_index_a0: tensor([[[0.2036, 0.2045, 0.6618, 0.7467, 0.9715, 0.8493, 0.1803, 0.1353], [0.6495, 0.0149, 0.2179, 0.4534, 0.0662, 0.8949, 0.7714, 0.0765]], [[0.0783, 0.0599, 0.7010, 0.8613, 0.4005, 0.4793, 0.1911, 0.2949], [0.2806, 0.3294, 0.0518, 0.6544, 0.1504, 0.3349, 0.0295, 0.1594]], [[0.8803, 0.2766, 0.7276, 0.9788, 0.4364, 0.0951, 0.8465, 0.6436], [0.8925, 0.5813, 0.4148, 0.1561, 0.9737, 0.9314, 0.9608, 0.0347]], [[0.4850, 0.2218, 0.1663, 0.4499, 0.7313, 
0.4146, 0.5337, 0.7840], [0.7255, 0.8035, 0.7201, 0.3028, 0.8759, 0.7901, 0.9948, 0.7760]]]) even_index_a1: tensor([[[0.2036, 0.2045, 0.6618, 0.7467, 0.9715, 0.8493, 0.1803, 0.1353], [0.6495, 0.0149, 0.2179, 0.4534, 0.0662, 0.8949, 0.7714, 0.0765]], [[0.0783, 0.0599, 0.7010, 0.8613, 0.4005, 0.4793, 0.1911, 0.2949], [0.2806, 0.3294, 0.0518, 0.6544, 0.1504, 0.3349, 0.0295, 0.1594]], [[0.8803, 0.2766, 0.7276, 0.9788, 0.4364, 0.0951, 0.8465, 0.6436], [0.8925, 0.5813, 0.4148, 0.1561, 0.9737, 0.9314, 0.9608, 0.0347]], [[0.4850, 0.2218, 0.1663, 0.4499, 0.7313, 0.4146, 0.5337, 0.7840], [0.7255, 0.8035, 0.7201, 0.3028, 0.8759, 0.7901, 0.9948, 0.7760]]]) <Size 0x7f8631dfc460> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7f85d1f86460> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.