Customized Operations for Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

# T: number of rows in each sample's 'a' field; B: number of samples per batch.
T, B = 3, 4


def with_nativetensor(batch_):
    """Apply the per-field operations to ``batch_`` using plain torch calls.

    ``batch_`` is a list of samples laid out like ``get_data()`` output. Each
    sample is modified in place so that the result mirrors ``with_treetensor``:

    * ``'a'`` is cast to float;
    * ``'b'`` is cast to float and replaced by ``b ** 2 + 1``;
    * ``'c'['d']`` is cast to float;
    * a standard-normal ``'noise'`` tensor of shape ``(3, 4, 5)`` is added
      under ``'c'``.

    Any other key is reported with a print and left unchanged.

    Returns:
        ``(batch_, mean_b, even_index_a)``: the modified list, the mean of the
        transformed ``'b'`` values, and the even-indexed rows of every ``'a'``
        stacked along a new leading (batch) dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Snapshot the items because values are replaced while looping.
        for k, v in list(sample.items()):
            if k == 'a':
                v = v.float()
                sample[k] = v  # write back; rebinding `v` alone changes nothing
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in list(v.items()):
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
                # 'noise' lives under 'c', matching where with_treetensor puts it.
                v['noise'] = torch.randn(size=(3, 4, 5))
            else:
                print('ignore keys: {}'.format(k))
    # Every sample's 'b' has the same length, so the mean of the per-sample
    # means equals the overall mean computed by with_treetensor.
    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field operations as ``with_nativetensor`` via treetensor.

    The samples are converted to trees and stacked along a new dim 0. They are
    processed with whole-tree operations, then split back into one tree per
    sample, each keeping a leading dimension of size 1.

    Returns:
        ``(batch_, mean_b, even_index_a)`` with the same meaning as in
        ``with_nativetensor``. Here ``batch_`` is the sequence returned by
        ``ttorch.split``.
    """
    batch_size = len(batch_)
    batch_ = ttorch.stack([ttorch.tensor(b) for b in batch_])
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    # Size the noise by the actual batch instead of the global B.
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Return one random sample: float 'a' (T, 8), float 'b' (6,), int 'c'/'d' (1,)."""
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,)),
        },
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare the largest *absolute* difference; abs(max(diff)) would let a
    # large negative difference slip through.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    # The native path now attaches noise under 'c', like the treetensor path.
    assert batch0[0]['c']['noise'].shape == (3, 4, 5)
    print(batch1[0].shape)
The output should look like the following, and all the assertions pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.1760, 0.9876, 0.7666, 0.6312, 0.1349, 0.4530, 0.2687, 0.9414], [0.8164, 0.5808, 0.3442, 0.3910, 0.5593, 0.3062, 0.1630, 0.2680], [0.8270, 0.5847, 0.2073, 0.1759, 0.1533, 0.9498, 0.6257, 0.6740]]), 'b': tensor([0.5486, 0.2000, 0.2006, 0.3559, 0.4115, 0.9562]), 'c': {'d': tensor([9])}}, {'a': tensor([[0.6299, 0.4004, 0.4542, 0.7033, 0.5401, 0.3561, 0.3821, 0.1780], [0.3425, 0.5969, 0.5505, 0.6960, 0.4339, 0.9264, 0.5039, 0.7249], [0.0891, 0.3546, 0.2275, 0.9833, 0.5123, 0.2282, 0.7045, 0.4442]]), 'b': tensor([0.4095, 0.3095, 0.8428, 0.8548, 0.7888, 0.5654]), 'c': {'d': tensor([2])}}, {'a': tensor([[0.7502, 0.8676, 0.9999, 0.7419, 0.8389, 0.0293, 0.1430, 0.7604], [0.6477, 0.8092, 0.2047, 0.1347, 0.1611, 0.8959, 0.7691, 0.3033], [0.8967, 0.5410, 0.6066, 0.7471, 0.4575, 0.2954, 0.4237, 0.7751]]), 'b': tensor([0.1823, 0.2579, 0.7932, 0.9217, 0.2090, 0.0825]), 'c': {'d': tensor([0])}}, {'a': tensor([[0.1884, 0.1080, 0.1079, 0.4079, 0.0615, 0.3103, 0.5896, 0.2523], [0.1338, 0.0235, 0.6606, 0.3652, 0.8009, 0.3431, 0.9553, 0.6687], [0.3617, 0.2567, 0.5046, 0.0802, 0.5889, 0.4226, 0.0626, 0.3299]]), 'b': tensor([0.5523, 0.4252, 0.6987, 0.2098, 0.3236, 0.6150]), 'c': {'d': tensor([0])}}] (<Tensor 0x7fc8ed333580> ├── 'a' --> tensor([[[0.1760, 0.9876, 0.7666, 0.6312, 0.1349, 0.4530, 0.2687, 0.9414], │ [0.8164, 0.5808, 0.3442, 0.3910, 0.5593, 0.3062, 0.1630, 0.2680], │ [0.8270, 0.5847, 0.2073, 0.1759, 0.1533, 0.9498, 0.6257, 0.6740]]]) ├── 'b' --> tensor([[1.3009, 1.0400, 1.0402, 1.1267, 1.1693, 1.9144]]) └── 'c' --> <Tensor 0x7fc8ed3335e0> ├── 'd' --> 
tensor([[9.]]) └── 'noise' --> tensor([[[[ 0.2810, 1.4643, -0.9719, 0.3324, 0.1614], [-2.4645, 0.0850, -0.5767, -1.1389, 0.4888], [-2.3588, -0.4756, -1.1212, 0.8527, -0.3302], [ 0.7773, -1.2100, 1.1747, 0.9436, 1.3061]], [[-0.5174, 0.4250, -1.0929, 0.5150, 0.0363], [ 0.2181, -0.8368, 0.6428, -0.7121, -0.5522], [-0.1242, 0.8859, 1.4753, 0.6179, 0.0561], [ 0.0670, -1.6503, 0.4668, -0.1533, -1.2691]], [[-0.5205, -0.5433, 0.5092, -0.3479, 0.9470], [-0.2283, -0.4997, 0.4974, -1.4480, -2.0320], [-0.9231, 0.3258, -0.4301, 1.3455, -0.5889], [-1.1355, 0.8925, -1.5451, -0.2237, -0.6219]]]]) , <Tensor 0x7fc8ed3336a0> ├── 'a' --> tensor([[[0.6299, 0.4004, 0.4542, 0.7033, 0.5401, 0.3561, 0.3821, 0.1780], │ [0.3425, 0.5969, 0.5505, 0.6960, 0.4339, 0.9264, 0.5039, 0.7249], │ [0.0891, 0.3546, 0.2275, 0.9833, 0.5123, 0.2282, 0.7045, 0.4442]]]) ├── 'b' --> tensor([[1.1677, 1.0958, 1.7103, 1.7307, 1.6222, 1.3197]]) └── 'c' --> <Tensor 0x7fc8ed3334c0> ├── 'd' --> tensor([[2.]]) └── 'noise' --> tensor([[[[-1.1136, -1.0463, 0.2548, 0.1471, -0.3970], [ 1.8490, 1.5962, -0.6584, -1.9240, 0.1384], [ 0.0654, 0.2813, -1.2626, -0.1989, -1.7316], [-1.2335, -1.1391, -0.2785, -0.6234, -1.7957]], [[-0.2992, 2.4934, 0.1978, 1.3581, -0.0760], [ 1.6081, 0.7025, 1.6493, 0.4080, -0.4362], [-0.1329, 0.2666, 1.8063, -1.3381, -0.5317], [-1.6255, 0.5464, 1.0031, -0.6533, 0.1872]], [[-0.2006, 0.1562, 0.3206, -1.1773, 0.5603], [-1.0208, -2.2813, 1.7826, -0.8397, 0.1486], [-1.3046, -1.7730, 2.0844, -1.9943, -0.2563], [ 0.1521, -0.7617, 0.9459, 0.2308, 0.7644]]]]) , <Tensor 0x7fc8ed333700> ├── 'a' --> tensor([[[0.7502, 0.8676, 0.9999, 0.7419, 0.8389, 0.0293, 0.1430, 0.7604], │ [0.6477, 0.8092, 0.2047, 0.1347, 0.1611, 0.8959, 0.7691, 0.3033], │ [0.8967, 0.5410, 0.6066, 0.7471, 0.4575, 0.2954, 0.4237, 0.7751]]]) ├── 'b' --> tensor([[1.0332, 1.0665, 1.6292, 1.8496, 1.0437, 1.0068]]) └── 'c' --> <Tensor 0x7fc8ed3336d0> ├── 'd' --> tensor([[0.]]) └── 'noise' --> tensor([[[[ 0.1881, 1.3646, -1.3650, 1.7777, 1.1769], 
[-0.2026, -0.4343, -2.5894, 0.8237, -0.1804], [ 0.7298, 1.2785, -1.0835, -1.1562, -2.3360], [ 2.3419, 0.3254, -0.9012, -1.5318, 0.3249]], [[-0.9911, -1.3211, 0.7636, 1.3447, -0.9473], [ 0.4414, 0.3585, -1.6078, -0.0313, 0.8511], [ 2.0062, 0.8407, 1.4273, 1.4176, -0.8433], [ 1.6407, -0.5180, 0.6023, -0.3928, -2.2764]], [[ 2.6079, -1.3053, -1.6408, -0.7152, 1.5177], [ 1.1402, -1.3566, -1.5336, 1.6338, 0.0867], [-0.3611, 0.3741, 0.6614, 1.9662, 0.3121], [ 0.1861, -0.1407, 1.1927, -0.6658, 1.1299]]]]) , <Tensor 0x7fc8ed333760> ├── 'a' --> tensor([[[0.1884, 0.1080, 0.1079, 0.4079, 0.0615, 0.3103, 0.5896, 0.2523], │ [0.1338, 0.0235, 0.6606, 0.3652, 0.8009, 0.3431, 0.9553, 0.6687], │ [0.3617, 0.2567, 0.5046, 0.0802, 0.5889, 0.4226, 0.0626, 0.3299]]]) ├── 'b' --> tensor([[1.3050, 1.1808, 1.4882, 1.0440, 1.1047, 1.3782]]) └── 'c' --> <Tensor 0x7fc8ed333730> ├── 'd' --> tensor([[0.]]) └── 'noise' --> tensor([[[[-2.4965, 0.3270, 1.2949, -0.1508, 0.5471], [ 1.1143, 0.7643, -0.5936, -0.2020, 0.8068], [ 0.3543, 1.0698, -0.0111, -0.0924, -1.7070], [-0.0966, 0.0233, 0.3531, -1.0438, -0.8790]], [[ 1.8667, 0.2731, 2.7688, -0.8521, -0.7153], [-0.1571, 1.7009, 1.0061, 0.0157, -1.3572], [-0.1685, 1.5744, -0.3597, -0.3344, 0.2629], [ 0.3098, 1.9305, -0.2067, -1.2160, 0.2546]], [[ 0.3821, -1.5974, -1.7580, -0.5292, 1.7847], [-2.1281, 0.2113, 1.0670, -0.8068, -0.9332], [-1.3835, -0.5270, -1.7189, -0.9176, -0.5866], [-0.9154, -1.4674, 2.5163, 0.1607, -1.0668]]]]) ) mean0 & mean1: tensor(1.3070) tensor(1.3070) even_index_a0: tensor([[[0.1760, 0.9876, 0.7666, 0.6312, 0.1349, 0.4530, 0.2687, 0.9414], [0.8270, 0.5847, 0.2073, 0.1759, 0.1533, 0.9498, 0.6257, 0.6740]], [[0.6299, 0.4004, 0.4542, 0.7033, 0.5401, 0.3561, 0.3821, 0.1780], [0.0891, 0.3546, 0.2275, 0.9833, 0.5123, 0.2282, 0.7045, 0.4442]], [[0.7502, 0.8676, 0.9999, 0.7419, 0.8389, 0.0293, 0.1430, 0.7604], [0.8967, 0.5410, 0.6066, 0.7471, 0.4575, 0.2954, 0.4237, 0.7751]], [[0.1884, 0.1080, 0.1079, 0.4079, 0.0615, 0.3103, 0.5896, 
0.2523], [0.3617, 0.2567, 0.5046, 0.0802, 0.5889, 0.4226, 0.0626, 0.3299]]]) even_index_a1: tensor([[[0.1760, 0.9876, 0.7666, 0.6312, 0.1349, 0.4530, 0.2687, 0.9414], [0.8270, 0.5847, 0.2073, 0.1759, 0.1533, 0.9498, 0.6257, 0.6740]], [[0.6299, 0.4004, 0.4542, 0.7033, 0.5401, 0.3561, 0.3821, 0.1780], [0.0891, 0.3546, 0.2275, 0.9833, 0.5123, 0.2282, 0.7045, 0.4442]], [[0.7502, 0.8676, 0.9999, 0.7419, 0.8389, 0.0293, 0.1430, 0.7604], [0.8967, 0.5410, 0.6066, 0.7471, 0.4575, 0.2954, 0.4237, 0.7751]], [[0.1884, 0.1080, 0.1079, 0.4079, 0.0615, 0.3103, 0.5896, 0.2523], [0.3617, 0.2567, 0.5046, 0.0802, 0.5889, 0.4226, 0.0626, 0.3299]]]) <Size 0x7fc95322e430> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7fc8ed3ad940> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.