Customized Operations For Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Apply the per-field custom operations using plain dicts and torch.

    For every sample dict in ``batch_``:

    * ``'a'`` is cast to float and every second row (``v[::2]``) is collected.
    * ``'b'`` is cast to float, transformed as ``b ** 2 + 1`` and its mean is
      collected. The transformed tensor is not stored back into the sample.
    * ``'c'`` is a nested dict; a ``'noise'`` tensor of size ``(3, 4, 5)`` is
      attached to it. Nested keys other than ``'d'`` are reported as ignored.
    * Any other top-level key is reported as ignored.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)``: the (mutated) list of
        sample dicts, the average of the per-sample means of the transformed
        ``'b'``, and the collected ``'a'`` slices stacked along a new dim 0.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                even_index_a_list.append(v[::2])
            elif k == 'b':
                v = v.float()
                transformed_v = torch.pow(v, 2) + 1.0
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        # The cast is local only; the tensor stored in the
                        # sample keeps its original dtype.
                        v1 = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # Attach the noise in a second pass so that the new key is not visited
    # (and reported as ignored) by the loop above. The noise belongs to the
    # nested dict under 'c'; 'd' is a key inside that dict, not a top-level
    # key, so testing for 'd' here would never match.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field custom operations using the treetensor API.

    The sample dicts are converted to tree tensors and stacked into a single
    batched tree, so that each operation is written once for the whole batch.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)``: the batched tree split
        back into chunks of size 1 along dim 0, the mean of the transformed
        ``b`` field, and ``a[:, ::2]`` of the batched tree.
    """
    # Use the actual number of samples for the noise batch dimension rather
    # than relying on the module-level constant B.
    batch_size = len(batch_)

    batch_ = [ttorch.tensor(b) for b in batch_]
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample: 'a' of size (T, 8), 'b' of size (6,) and a
    nested dict 'c' holding an integer tensor 'd' of size (1,)."""
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,))
        }
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Each implementation gets its own deep copy because both mutate their input.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare the largest absolute difference; taking abs() of the signed
    # maximum would let large negative differences slip through.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look like the following (the random values will differ from run to run), and all the assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.1596, 0.4012, 0.3223, 0.4852, 0.8140, 0.1468, 0.5699, 0.1797], [0.2723, 0.2857, 0.5886, 0.8552, 0.5341, 0.3000, 0.0769, 0.0941], [0.1838, 0.3047, 0.4412, 0.8800, 0.8474, 0.5083, 0.8790, 0.9913]]), 'b': tensor([0.0547, 0.4228, 0.6144, 0.0303, 0.6479, 0.0555]), 'c': {'d': tensor([0])}}, {'a': tensor([[0.8184, 0.3520, 0.2519, 0.8049, 0.0454, 0.2602, 0.7012, 0.7521], [0.4884, 0.2814, 0.5682, 0.4806, 0.9580, 0.2219, 0.3575, 0.6630], [0.3724, 0.5604, 0.3354, 0.7325, 0.1509, 0.0982, 0.6868, 0.3617]]), 'b': tensor([0.3862, 0.3536, 0.8789, 0.2303, 0.2371, 0.9049]), 'c': {'d': tensor([4])}}, {'a': tensor([[0.9463, 0.1392, 0.6163, 0.9218, 0.6002, 0.8881, 0.8347, 0.1492], [0.0390, 0.6043, 0.2513, 0.3038, 0.1501, 0.3200, 0.7904, 0.9806], [0.7639, 0.5355, 0.9951, 0.9049, 0.7562, 0.7972, 0.0351, 0.0165]]), 'b': tensor([0.8851, 0.0581, 0.6590, 0.7748, 0.7943, 0.3452]), 'c': {'d': tensor([2])}}, {'a': tensor([[0.1397, 0.5560, 0.1809, 0.6830, 0.0409, 0.5592, 0.4417, 0.1280], [0.0140, 0.6646, 0.3147, 0.2864, 0.8664, 0.7113, 0.1141, 0.5964], [0.1541, 0.1119, 0.9381, 0.8998, 0.6656, 0.8518, 0.7366, 0.2615]]), 'b': tensor([0.7784, 0.7635, 0.8329, 0.2074, 0.4004, 0.3969]), 'c': {'d': tensor([3])}}] (<Tensor 0x7f632b0f4dc0> ├── 'a' --> tensor([[[0.1596, 0.4012, 0.3223, 0.4852, 0.8140, 0.1468, 0.5699, 0.1797], │ [0.2723, 0.2857, 0.5886, 0.8552, 0.5341, 0.3000, 0.0769, 0.0941], │ [0.1838, 0.3047, 0.4412, 0.8800, 0.8474, 0.5083, 0.8790, 0.9913]]]) ├── 'b' --> tensor([[1.0030, 1.1788, 1.3775, 1.0009, 1.4197, 1.0031]]) └── 'c' --> <Tensor 0x7f632b0f4e20> ├── 'd' --> 
tensor([[0.]]) └── 'noise' --> tensor([[[[-1.6472, -2.1012, 0.5953, 1.2593, 0.4807], [-1.7966, -1.3679, 0.2307, 1.3067, 0.6097], [ 0.1925, -1.0879, -0.1310, 0.7680, -0.3777], [-0.0476, -1.5376, 1.2984, 1.0988, -0.1543]], [[-0.4359, -1.0796, 0.2791, 1.5900, -0.2499], [ 0.5715, -0.5642, -1.2343, 0.7848, -0.9442], [ 0.2103, 0.3017, 0.0613, -0.7353, -0.5311], [ 0.0271, 0.6253, -0.0628, -0.5104, -0.3841]], [[ 0.0466, -2.4251, -1.4401, 1.4015, -1.7317], [ 0.3690, -0.1695, -0.1779, 0.3515, 1.5173], [ 1.2173, 1.1610, 0.7979, -0.4319, 0.0112], [-1.6619, 1.0395, -0.3872, -0.9098, 2.0577]]]]) , <Tensor 0x7f632b0f4ee0> ├── 'a' --> tensor([[[0.8184, 0.3520, 0.2519, 0.8049, 0.0454, 0.2602, 0.7012, 0.7521], │ [0.4884, 0.2814, 0.5682, 0.4806, 0.9580, 0.2219, 0.3575, 0.6630], │ [0.3724, 0.5604, 0.3354, 0.7325, 0.1509, 0.0982, 0.6868, 0.3617]]]) ├── 'b' --> tensor([[1.1492, 1.1250, 1.7725, 1.0531, 1.0562, 1.8188]]) └── 'c' --> <Tensor 0x7f632b0f4d30> ├── 'd' --> tensor([[4.]]) └── 'noise' --> tensor([[[[ 0.7921, -0.0755, -0.8151, 0.8850, 1.4933], [ 1.2783, -0.4826, 1.3904, -1.5220, -0.8982], [-1.1363, -0.3053, 0.0885, 0.2277, -0.7028], [-0.9773, -0.0277, 0.6444, -1.2285, -0.2365]], [[-0.3434, -0.8346, 0.0528, 0.0435, -0.5553], [ 0.7556, -1.4243, -0.5564, 0.2192, -0.3744], [ 0.2173, -0.0042, 0.3533, -2.2362, -0.4302], [-0.0620, -0.8077, -1.3023, -1.5293, -0.3642]], [[ 0.9234, 0.6934, 0.3226, 0.6812, 0.4780], [ 0.7483, 0.9376, -0.7160, -0.8707, 2.6354], [ 1.7357, -0.3347, 1.0475, 1.2589, -1.0588], [-1.0726, -2.3005, 0.7702, 1.2074, -0.6260]]]]) , <Tensor 0x7f632b0f4f40> ├── 'a' --> tensor([[[0.9463, 0.1392, 0.6163, 0.9218, 0.6002, 0.8881, 0.8347, 0.1492], │ [0.0390, 0.6043, 0.2513, 0.3038, 0.1501, 0.3200, 0.7904, 0.9806], │ [0.7639, 0.5355, 0.9951, 0.9049, 0.7562, 0.7972, 0.0351, 0.0165]]]) ├── 'b' --> tensor([[1.7834, 1.0034, 1.4343, 1.6003, 1.6309, 1.1191]]) └── 'c' --> <Tensor 0x7f632b0f4f10> ├── 'd' --> tensor([[2.]]) └── 'noise' --> tensor([[[[ 0.9213, 0.0383, 0.2955, 1.3049, 
-1.9328], [-1.0894, -0.5289, -0.2533, -1.6255, -0.8783], [-0.4002, 0.6349, 2.1597, -0.0473, -0.2817], [-0.0421, 0.3092, 0.9032, 0.8430, 0.7661]], [[ 0.1034, 1.1913, 0.2249, 0.5569, -0.0987], [ 0.3374, 0.1335, 1.9032, -1.1055, -0.1811], [-0.7397, 0.4564, -0.3752, 0.7368, 0.5090], [-1.0403, -1.3333, -0.0523, 1.1758, 0.2347]], [[ 0.4211, -1.6655, -0.4795, 0.4314, 0.8669], [ 0.2273, 0.3823, 1.9700, 0.5738, 1.8101], [ 2.1773, -0.0461, 0.3586, -1.5405, -0.4960], [-0.6693, -0.3611, 0.0776, -0.8386, 0.3934]]]]) , <Tensor 0x7f632b0f4fa0> ├── 'a' --> tensor([[[0.1397, 0.5560, 0.1809, 0.6830, 0.0409, 0.5592, 0.4417, 0.1280], │ [0.0140, 0.6646, 0.3147, 0.2864, 0.8664, 0.7113, 0.1141, 0.5964], │ [0.1541, 0.1119, 0.9381, 0.8998, 0.6656, 0.8518, 0.7366, 0.2615]]]) ├── 'b' --> tensor([[1.6059, 1.5829, 1.6938, 1.0430, 1.1603, 1.1575]]) └── 'c' --> <Tensor 0x7f632b0f4f70> ├── 'd' --> tensor([[3.]]) └── 'noise' --> tensor([[[[-0.2365, -1.1637, -0.8815, -0.7596, 0.3517], [-1.4273, 0.2471, 0.7383, -1.1357, -1.0635], [-0.5137, 1.2854, 0.9059, 1.2296, 0.4843], [-0.3079, 1.8051, -0.5554, -1.5409, 0.5662]], [[ 0.0711, -0.6803, -0.2576, -0.5582, -0.1279], [ 0.0366, 0.9816, -0.6424, 1.1566, -0.1251], [ 0.3017, 1.1231, 0.1248, -0.3778, 1.1984], [-1.0851, -1.1335, -0.3563, -0.4848, -1.0288]], [[-0.4021, 0.5456, -0.7004, 1.3560, -0.4658], [ 0.2582, -0.5490, 1.0653, 0.3482, 0.9825], [ 2.2991, -0.6494, 0.6641, 1.0891, -1.8204], [-0.5928, 0.7025, 1.4344, 0.4906, 2.3079]]]]) ) mean0 & mean1: tensor(1.3239) tensor(1.3239) even_index_a0: tensor([[[0.1596, 0.4012, 0.3223, 0.4852, 0.8140, 0.1468, 0.5699, 0.1797], [0.1838, 0.3047, 0.4412, 0.8800, 0.8474, 0.5083, 0.8790, 0.9913]], [[0.8184, 0.3520, 0.2519, 0.8049, 0.0454, 0.2602, 0.7012, 0.7521], [0.3724, 0.5604, 0.3354, 0.7325, 0.1509, 0.0982, 0.6868, 0.3617]], [[0.9463, 0.1392, 0.6163, 0.9218, 0.6002, 0.8881, 0.8347, 0.1492], [0.7639, 0.5355, 0.9951, 0.9049, 0.7562, 0.7972, 0.0351, 0.0165]], [[0.1397, 0.5560, 0.1809, 0.6830, 0.0409, 0.5592, 0.4417, 
0.1280], [0.1541, 0.1119, 0.9381, 0.8998, 0.6656, 0.8518, 0.7366, 0.2615]]]) even_index_a1: tensor([[[0.1596, 0.4012, 0.3223, 0.4852, 0.8140, 0.1468, 0.5699, 0.1797], [0.1838, 0.3047, 0.4412, 0.8800, 0.8474, 0.5083, 0.8790, 0.9913]], [[0.8184, 0.3520, 0.2519, 0.8049, 0.0454, 0.2602, 0.7012, 0.7521], [0.3724, 0.5604, 0.3354, 0.7325, 0.1509, 0.0982, 0.6868, 0.3617]], [[0.9463, 0.1392, 0.6163, 0.9218, 0.6002, 0.8881, 0.8347, 0.1492], [0.7639, 0.5355, 0.9951, 0.9049, 0.7562, 0.7972, 0.0351, 0.0165]], [[0.1397, 0.5560, 0.1809, 0.6830, 0.0409, 0.5592, 0.4417, 0.1280], [0.1541, 0.1119, 0.9381, 0.8998, 0.6656, 0.8518, 0.7366, 0.2615]]]) <Size 0x7f638affb460> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7f632b171460> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation using the treetensor API is much simpler and clearer.