Customized Operations for Different Fields
Here is another example of custom per-field operations, implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

# T: number of rows in each sample's 'a' field; B: number of samples per batch.
T, B = 3, 4


def with_nativetensor(batch_):
    """Apply the per-field operations to a list of sample dicts with plain torch.

    For every sample, ``a`` is cast to float, ``b`` is cast to float and
    replaced by ``b ** 2 + 1``, ``c.d`` is cast to float, and ``c.noise`` is
    set to standard-normal noise of shape (3, 4, 5). Unknown keys are reported
    with a print and left untouched. The dicts in ``batch_`` are updated in
    place.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)``: the updated list, the mean
        of the per-sample means of the transformed ``b``, and the even-indexed
        rows of every ``a`` stacked along a new leading dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Snapshot the items because values are written back while iterating.
        for k, v in list(sample.items()):
            if k == 'a':
                v = v.float()
                sample[k] = v  # write back so the returned batch is float
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v  # mirror ttorch: b is replaced
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in list(v.items()):
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
                # Noise lives under 'c', matching ``batch_.c.noise`` in with_treetensor.
                v['noise'] = torch.randn(size=(3, 4, 5))
            else:
                print('ignore keys: {}'.format(k))
    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field operations as ``with_nativetensor`` with treetensor.

    The samples are stacked into one tree tensor, so every field gains a
    leading batch dimension, and all fields are cast to float in a single
    call. The result is split back into per-sample tree tensors, each keeping
    a leading dimension of size 1.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)`` analogous to
        ``with_nativetensor``.
    """
    batch_size = len(batch_)  # size the noise to the actual batch, not the global B
    batch_ = ttorch.stack([ttorch.tensor(b) for b in batch_])
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Return one random sample: float 'a' of shape (T, 8), float 'b' of
    shape (6,), and a nested 'c' holding an integer 'd' of shape (1,)."""
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,)),
        },
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Max of the absolute differences: abs(max(diff)) could hide large negative deviations.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look similar to the following (the data is random, so the exact values will differ between runs), and all assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.1018, 0.1229, 0.9646, 0.9868, 0.1387, 0.0339, 0.1518, 0.7598], [0.6458, 0.7569, 0.2063, 0.3366, 0.7734, 0.7482, 0.6035, 0.8083], [0.9991, 0.6408, 0.8581, 0.5391, 0.9900, 0.7575, 0.9536, 0.2315]]), 'b': tensor([0.6318, 0.3348, 0.1966, 0.4083, 0.5994, 0.5446]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.3198, 0.6870, 0.1427, 0.1506, 0.2779, 0.4413, 0.7380, 0.3396], [0.3949, 0.3718, 0.5813, 0.4857, 0.6542, 0.0823, 0.5097, 0.6795], [0.1236, 0.3174, 0.1848, 0.9525, 0.8351, 0.1266, 0.8205, 0.8777]]), 'b': tensor([0.3457, 0.1058, 0.0081, 0.0902, 0.2323, 0.8620]), 'c': {'d': tensor([6])}}, {'a': tensor([[0.6110, 0.7065, 0.9817, 0.5466, 0.9681, 0.5271, 0.9148, 0.8300], [0.9745, 0.3634, 0.6490, 0.7711, 0.1651, 0.7651, 0.0820, 0.6859], [0.0513, 0.8733, 0.0833, 0.0243, 0.4405, 0.7640, 0.9343, 0.4131]]), 'b': tensor([0.6757, 0.6218, 0.2433, 0.3836, 0.1499, 0.0585]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.7970, 0.5926, 0.9834, 0.6448, 0.3721, 0.1004, 0.0332, 0.3718], [0.8415, 0.5384, 0.6109, 0.2469, 0.3713, 0.3119, 0.5707, 0.2317], [0.3114, 0.2227, 0.2023, 0.6401, 0.1025, 0.0098, 0.7681, 0.4003]]), 'b': tensor([0.9396, 0.3754, 0.5673, 0.0941, 0.7585, 0.3193]), 'c': {'d': tensor([6])}}] (<Tensor 0x7f4c11e15ac0> ├── 'a' --> tensor([[[0.1018, 0.1229, 0.9646, 0.9868, 0.1387, 0.0339, 0.1518, 0.7598], │ [0.6458, 0.7569, 0.2063, 0.3366, 0.7734, 0.7482, 0.6035, 0.8083], │ [0.9991, 0.6408, 0.8581, 0.5391, 0.9900, 0.7575, 0.9536, 0.2315]]]) ├── 'b' --> tensor([[1.3991, 1.1121, 1.0386, 1.1667, 1.3592, 1.2966]]) └── 'c' --> <Tensor 0x7f4c11e15b20> ├── 'd' --> 
tensor([[1.]]) └── 'noise' --> tensor([[[[-0.5880, 2.2755, 0.7484, 0.0287, 0.3380], [ 0.4480, 0.3728, -1.5667, 0.9229, -0.0397], [ 1.1858, 0.5950, 1.3761, -0.5896, -0.3640], [ 0.1571, -1.1151, 0.3706, -0.1462, -0.4969]], [[-1.2671, 1.5281, 2.4829, -0.5331, 0.5884], [ 0.3366, 1.9647, -1.0251, 1.3016, -0.5457], [-0.6922, -0.3392, 1.2724, -1.5485, 0.5084], [ 0.6813, -0.4010, 1.5073, 1.5329, -0.5679]], [[ 0.8735, -0.3361, -0.1501, -1.1746, 0.4980], [-1.8489, 1.6243, -1.3073, -1.1780, 0.9063], [ 1.4177, 1.6821, 2.6253, -0.5108, 0.1157], [ 2.6025, -2.1618, 2.0695, -0.3414, 1.9084]]]]) , <Tensor 0x7f4c11e15be0> ├── 'a' --> tensor([[[0.3198, 0.6870, 0.1427, 0.1506, 0.2779, 0.4413, 0.7380, 0.3396], │ [0.3949, 0.3718, 0.5813, 0.4857, 0.6542, 0.0823, 0.5097, 0.6795], │ [0.1236, 0.3174, 0.1848, 0.9525, 0.8351, 0.1266, 0.8205, 0.8777]]]) ├── 'b' --> tensor([[1.1195, 1.0112, 1.0001, 1.0081, 1.0539, 1.7430]]) └── 'c' --> <Tensor 0x7f4c11e15a30> ├── 'd' --> tensor([[6.]]) └── 'noise' --> tensor([[[[ 0.5339, -0.5147, -1.0812, -2.5419, 0.8775], [-1.2760, 0.8684, -0.2687, 1.6961, 1.1588], [-0.6761, 0.1458, 0.5305, 1.5527, -1.1650], [ 1.0163, 0.5603, 0.9785, -0.1793, 0.0314]], [[-0.1252, 1.1299, 0.3169, 1.5040, -0.5238], [ 0.2552, 0.3170, 0.9385, -0.4486, -0.7811], [-0.1173, -0.9089, -2.1491, 0.3983, -1.2886], [ 0.5881, -0.6633, -1.0676, 0.1214, 0.2026]], [[ 0.1844, 0.5130, 0.5463, 0.2661, 1.2480], [ 0.7749, -0.6297, 1.2171, 1.2214, -0.0714], [ 0.0407, -0.1869, 0.6574, 1.1993, -0.3189], [-0.8170, 0.8329, -1.0978, 0.6092, 1.5736]]]]) , <Tensor 0x7f4c11e15c40> ├── 'a' --> tensor([[[0.6110, 0.7065, 0.9817, 0.5466, 0.9681, 0.5271, 0.9148, 0.8300], │ [0.9745, 0.3634, 0.6490, 0.7711, 0.1651, 0.7651, 0.0820, 0.6859], │ [0.0513, 0.8733, 0.0833, 0.0243, 0.4405, 0.7640, 0.9343, 0.4131]]]) ├── 'b' --> tensor([[1.4565, 1.3867, 1.0592, 1.1471, 1.0225, 1.0034]]) └── 'c' --> <Tensor 0x7f4c11e15c10> ├── 'd' --> tensor([[1.]]) └── 'noise' --> tensor([[[[ 0.0383, 0.8180, 0.3472, 1.3140, -0.7884], 
[-1.1390, 0.3253, -0.4337, -0.9790, 0.3086], [-1.9032, 0.3683, -1.0571, -0.4134, 0.0734], [-0.8027, -3.0336, -0.6686, -0.1992, -0.5842]], [[-1.2057, 1.8317, -2.0265, -0.5646, -0.3632], [-0.3409, 0.8284, 0.4916, 1.0882, -0.4506], [-0.1837, 0.3143, -0.1099, -2.1298, 0.1751], [-0.9796, -1.7407, -0.8685, 1.1067, 1.2986]], [[-1.7898, 0.6996, 1.1994, 0.2322, 0.6311], [-0.7834, 0.9913, -1.0560, 1.2750, 0.3016], [ 1.5431, 0.0910, 0.4831, 0.3254, -0.7582], [ 1.9326, 0.6189, 1.0870, -0.6356, 0.7896]]]]) , <Tensor 0x7f4c11e15ca0> ├── 'a' --> tensor([[[0.7970, 0.5926, 0.9834, 0.6448, 0.3721, 0.1004, 0.0332, 0.3718], │ [0.8415, 0.5384, 0.6109, 0.2469, 0.3713, 0.3119, 0.5707, 0.2317], │ [0.3114, 0.2227, 0.2023, 0.6401, 0.1025, 0.0098, 0.7681, 0.4003]]]) ├── 'b' --> tensor([[1.8828, 1.1409, 1.3218, 1.0089, 1.5753, 1.1019]]) └── 'c' --> <Tensor 0x7f4c11e15c70> ├── 'd' --> tensor([[6.]]) └── 'noise' --> tensor([[[[ 0.0989, 0.0352, -1.6074, 0.6577, -0.6247], [ 1.1357, -0.4504, 1.9912, -0.5223, 0.7360], [-0.1203, 1.2087, 0.0854, -0.3729, 0.8584], [-1.7039, 1.6009, -0.6874, 0.0073, -0.0476]], [[ 0.4003, 0.6010, 0.7245, -0.2537, -2.3287], [ 0.9332, -0.9192, -0.6174, -0.7217, 1.3059], [ 2.2037, -1.1794, 1.8018, -0.3159, 0.0279], [-1.3474, -1.3609, 0.2150, 1.2202, -1.0155]], [[ 0.3434, -0.3591, -1.1696, -0.3053, 1.3126], [-0.8759, 1.0126, -1.4366, -2.1002, -0.2968], [-1.2088, 0.2267, 0.9093, -1.4701, -1.4230], [ 0.3305, -0.8960, 1.1042, 2.1353, -1.1125]]]]) ) mean0 & mean1: tensor(1.2256) tensor(1.2256) even_index_a0: tensor([[[0.1018, 0.1229, 0.9646, 0.9868, 0.1387, 0.0339, 0.1518, 0.7598], [0.9991, 0.6408, 0.8581, 0.5391, 0.9900, 0.7575, 0.9536, 0.2315]], [[0.3198, 0.6870, 0.1427, 0.1506, 0.2779, 0.4413, 0.7380, 0.3396], [0.1236, 0.3174, 0.1848, 0.9525, 0.8351, 0.1266, 0.8205, 0.8777]], [[0.6110, 0.7065, 0.9817, 0.5466, 0.9681, 0.5271, 0.9148, 0.8300], [0.0513, 0.8733, 0.0833, 0.0243, 0.4405, 0.7640, 0.9343, 0.4131]], [[0.7970, 0.5926, 0.9834, 0.6448, 0.3721, 0.1004, 0.0332, 0.3718], 
[0.3114, 0.2227, 0.2023, 0.6401, 0.1025, 0.0098, 0.7681, 0.4003]]]) even_index_a1: tensor([[[0.1018, 0.1229, 0.9646, 0.9868, 0.1387, 0.0339, 0.1518, 0.7598], [0.9991, 0.6408, 0.8581, 0.5391, 0.9900, 0.7575, 0.9536, 0.2315]], [[0.3198, 0.6870, 0.1427, 0.1506, 0.2779, 0.4413, 0.7380, 0.3396], [0.1236, 0.3174, 0.1848, 0.9525, 0.8351, 0.1266, 0.8205, 0.8777]], [[0.6110, 0.7065, 0.9817, 0.5466, 0.9681, 0.5271, 0.9148, 0.8300], [0.0513, 0.8733, 0.0833, 0.0243, 0.4405, 0.7640, 0.9343, 0.4131]], [[0.7970, 0.5926, 0.9834, 0.6448, 0.3721, 0.1004, 0.0332, 0.3718], [0.3114, 0.2227, 0.2023, 0.6401, 0.1025, 0.0098, 0.7681, 0.4003]]]) <Size 0x7f4c71ce3460> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7f4c11e8f460> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.