Customized Operations For Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy

import torch

import treetensor.torch as ttorch

# T: number of rows in each sample's 'a' tensor; B: number of samples in a batch.
T, B = 3, 4


def with_nativetensor(batch_):
    """Apply the per-field operations using only plain torch calls.

    Args:
        batch_: list of nested dicts shaped like the output of ``get_data()``
            (top-level keys 'a', 'b' and 'c', where 'c' is a dict holding 'd').

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        input list, ``mean_b`` is the average over samples of
        ``mean(b ** 2 + 1)`` and ``even_index_a`` stacks, along a new leading
        dimension, every second row of each sample's 'a'.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                # Step 2 along the first dimension: rows 0, 2, ...
                even_index_a_list.append(v[::2])
            elif k == 'b':
                v = v.float()
                transformed_v = torch.pow(v, 2) + 1.0
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        # The converted tensor is only bound to the local name;
                        # the tensor stored in the dict is left as it was.
                        v1 = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    for sample in batch_:
        # NOTE(review): samples built by get_data() have the top-level keys
        # 'a', 'b' and 'c' only, so this branch never runs for them and no
        # 'noise' entry is attached on the native path (with_treetensor puts
        # its noise under 'c'). Behaviour is kept as-is so the printed native
        # batch is unchanged -- confirm whether 'c' was intended here.
        if 'd' in sample:
            sample['d']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field operations using the treetensor API.

    Args:
        batch_: list of nested dicts shaped like the output of ``get_data()``.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        result of splitting the stacked tree back into chunks of size 1 along
        dim 0, ``mean_b`` is the mean of the transformed 'b' field and
        ``even_index_a`` is every second index along dim 1 of the stacked 'a'.
    """
    batch_ = [ttorch.tensor(b) for b in batch_]
    # Stack the per-sample trees into a single tree, then cast it to float.
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(B, 3, 4, 5))
    mean_b = batch_.b.mean()
    # Dim 0 is the stacked sample dimension, so the step-2 slice is on dim 1.
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample: a nested dict of tensors with keys 'a', 'b', 'c'."""
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,)),
        },
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Each implementation gets its own deep copy so neither sees the other's edits.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare the largest absolute difference. Taking abs() of the signed
    # maximum would let large negative differences slip through.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look like the following, and all the assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.8586, 0.8745, 0.2379, 0.5585, 0.4995, 0.4912, 0.6396, 0.3216], [0.8695, 0.2370, 0.3972, 0.8458, 0.5265, 0.6883, 0.8920, 0.2215], [0.3978, 0.3845, 0.2481, 0.6023, 0.3278, 0.4616, 0.1159, 0.6446]]), 'b': tensor([0.0108, 0.5429, 0.9307, 0.0859, 0.9104, 0.6995]), 'c': {'d': tensor([6])}}, {'a': tensor([[0.3686, 0.3797, 0.9587, 0.0678, 0.8838, 0.7381, 0.0863, 0.2462], [0.5136, 0.8755, 0.4092, 0.7654, 0.2754, 0.3902, 0.9426, 0.3342], [0.4133, 0.3055, 0.6118, 0.5253, 0.6120, 0.2868, 0.4473, 0.9701]]), 'b': tensor([0.1724, 0.0885, 0.7045, 0.5513, 0.8351, 0.5251]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.1634, 0.9050, 0.1383, 0.4304, 0.6519, 0.7408, 0.8242, 0.4771], [0.5757, 0.9322, 0.4528, 0.7131, 0.8326, 0.8706, 0.0143, 0.4793], [0.6991, 0.2571, 0.7627, 0.8777, 0.9678, 0.5340, 0.9114, 0.8534]]), 'b': tensor([0.1379, 0.1931, 0.7572, 0.1832, 0.4489, 0.7069]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.5449, 0.3750, 0.1559, 0.1598, 0.8304, 0.5090, 0.8535, 0.4835], [0.7706, 0.3237, 0.6379, 0.6963, 0.8241, 0.8523, 0.1088, 0.6494], [0.4307, 0.0856, 0.3791, 0.7000, 0.2590, 0.6379, 0.9880, 0.7518]]), 'b': tensor([0.8498, 0.5255, 0.4614, 0.1046, 0.8997, 0.7453]), 'c': {'d': tensor([5])}}] (<Tensor 0x7f4ed5b9f2e0> ├── 'a' --> tensor([[[0.8586, 0.8745, 0.2379, 0.5585, 0.4995, 0.4912, 0.6396, 0.3216], │ [0.8695, 0.2370, 0.3972, 0.8458, 0.5265, 0.6883, 0.8920, 0.2215], │ [0.3978, 0.3845, 0.2481, 0.6023, 0.3278, 0.4616, 0.1159, 0.6446]]]) ├── 'b' --> tensor([[1.0001, 1.2947, 1.8663, 1.0074, 1.8288, 1.4892]]) └── 'c' --> <Tensor 0x7f4ed5b9f340> ├── 'd' --> 
tensor([[6.]]) └── 'noise' --> tensor([[[[ 1.3119, 1.0073, -0.0826, -0.0412, -1.2066], [ 1.1487, -0.9194, -1.0484, 1.0572, -0.3298], [ 1.0939, 1.0182, 0.5228, 1.0135, -0.0601], [ 0.6180, -0.4755, -1.2756, -0.3167, -0.5141]], [[ 1.4911, 0.2002, -0.2504, 0.1903, 0.0089], [-1.3901, 0.5571, -0.1145, 1.1175, -0.5574], [-0.2847, -1.4312, -1.4599, -0.0578, 0.7877], [ 0.1002, 0.2290, 0.8220, -0.1056, 1.4887]], [[-0.0857, 1.6059, -2.1704, -0.4478, 0.7624], [ 1.5118, 0.8130, 0.0370, -0.6705, 2.8412], [-1.1921, -1.5076, 0.3315, 1.5197, 0.0273], [ 0.6083, -2.3666, -0.6193, -0.1556, 0.7426]]]]) , <Tensor 0x7f4ed5b9f3a0> ├── 'a' --> tensor([[[0.3686, 0.3797, 0.9587, 0.0678, 0.8838, 0.7381, 0.0863, 0.2462], │ [0.5136, 0.8755, 0.4092, 0.7654, 0.2754, 0.3902, 0.9426, 0.3342], │ [0.4133, 0.3055, 0.6118, 0.5253, 0.6120, 0.2868, 0.4473, 0.9701]]]) ├── 'b' --> tensor([[1.0297, 1.0078, 1.4964, 1.3040, 1.6973, 1.2758]]) └── 'c' --> <Tensor 0x7f4ed5b9f280> ├── 'd' --> tensor([[3.]]) └── 'noise' --> tensor([[[[-0.6261, 1.0927, -0.8166, -1.0715, 0.3092], [ 1.8608, -0.4651, -1.7187, -0.8299, 0.4223], [-1.4070, 1.3266, -1.7648, 0.2992, -0.2652], [ 0.7462, 0.1425, -0.2780, -0.7319, 0.1484]], [[-0.4234, 0.2794, -0.4522, 0.0267, 0.4890], [ 1.3390, 0.1871, -0.9754, 0.9092, -0.5868], [-0.4940, 0.1136, -2.5127, 0.5628, -0.8442], [-0.2515, 1.7598, -0.3733, 0.0215, -0.0501]], [[-1.5443, 1.0383, 0.3005, 1.2338, 1.5016], [-0.8317, -0.9826, -0.4229, 0.8076, -0.3408], [-2.7238, 0.8420, 0.2547, -0.9996, -1.1935], [-0.0824, -0.2600, -0.2532, 0.3305, -0.5450]]]]) , <Tensor 0x7f4ed5b9f400> ├── 'a' --> tensor([[[0.1634, 0.9050, 0.1383, 0.4304, 0.6519, 0.7408, 0.8242, 0.4771], │ [0.5757, 0.9322, 0.4528, 0.7131, 0.8326, 0.8706, 0.0143, 0.4793], │ [0.6991, 0.2571, 0.7627, 0.8777, 0.9678, 0.5340, 0.9114, 0.8534]]]) ├── 'b' --> tensor([[1.0190, 1.0373, 1.5733, 1.0336, 1.2015, 1.4997]]) └── 'c' --> <Tensor 0x7f4ed5b9f3d0> ├── 'd' --> tensor([[1.]]) └── 'noise' --> tensor([[[[-0.7697, -0.6644, -0.6809, 0.5710, 
0.8831], [ 0.5016, -1.6522, 0.6898, 0.9388, -2.8419], [ 1.0921, -1.8615, 0.0894, 1.0522, 0.0959], [ 0.1511, -0.5223, 0.0411, -0.3828, -0.8376]], [[ 0.2205, -0.8009, 2.1709, 1.3212, 1.9122], [ 0.1257, 1.5379, -0.6085, 0.6063, 0.4691], [ 0.9203, 0.6534, -0.6768, -0.8586, 0.1047], [ 0.0038, -2.3927, -2.6818, 1.6763, 1.0318]], [[-0.3389, -1.0578, -0.3507, 0.4921, 0.0679], [ 1.2042, -1.0514, -2.4790, 1.0759, -0.7012], [ 0.3001, -2.3341, 0.5223, -0.7165, -0.3541], [-0.0838, 0.2886, 0.9576, -0.5891, 0.4655]]]]) , <Tensor 0x7f4ed5b9f460> ├── 'a' --> tensor([[[0.5449, 0.3750, 0.1559, 0.1598, 0.8304, 0.5090, 0.8535, 0.4835], │ [0.7706, 0.3237, 0.6379, 0.6963, 0.8241, 0.8523, 0.1088, 0.6494], │ [0.4307, 0.0856, 0.3791, 0.7000, 0.2590, 0.6379, 0.9880, 0.7518]]]) ├── 'b' --> tensor([[1.7221, 1.2761, 1.2129, 1.0109, 1.8095, 1.5554]]) └── 'c' --> <Tensor 0x7f4ed5b9f430> ├── 'd' --> tensor([[5.]]) └── 'noise' --> tensor([[[[-0.1033, -0.5404, -0.5412, 0.6728, -0.0648], [-1.2375, 1.0447, -0.0787, -1.6941, 0.6986], [ 1.5592, -0.3467, -0.4237, -0.0275, -0.3738], [-1.5760, 0.7306, -0.0981, -0.3530, 1.3364]], [[-1.0773, -2.0509, -0.0376, -1.1306, -0.2804], [-1.6774, -1.1321, -0.0292, -0.8410, -0.0336], [ 0.2438, -0.1650, 1.4451, 0.7929, -2.7484], [ 1.3637, 0.7989, 1.5782, 0.9569, 0.3596]], [[ 0.1463, 1.9744, 0.2162, -1.7828, -0.3763], [-0.7063, 0.0489, 1.0735, -0.5237, 1.1289], [-0.8941, 0.8918, 1.3383, 0.0798, 0.6454], [-0.4515, 0.3357, -0.5779, 0.8850, 1.2229]]]]) ) mean0 & mean1: tensor(1.3437) tensor(1.3437) even_index_a0: tensor([[[0.8586, 0.8745, 0.2379, 0.5585, 0.4995, 0.4912, 0.6396, 0.3216], [0.3978, 0.3845, 0.2481, 0.6023, 0.3278, 0.4616, 0.1159, 0.6446]], [[0.3686, 0.3797, 0.9587, 0.0678, 0.8838, 0.7381, 0.0863, 0.2462], [0.4133, 0.3055, 0.6118, 0.5253, 0.6120, 0.2868, 0.4473, 0.9701]], [[0.1634, 0.9050, 0.1383, 0.4304, 0.6519, 0.7408, 0.8242, 0.4771], [0.6991, 0.2571, 0.7627, 0.8777, 0.9678, 0.5340, 0.9114, 0.8534]], [[0.5449, 0.3750, 0.1559, 0.1598, 0.8304, 0.5090, 0.8535, 
0.4835], [0.4307, 0.0856, 0.3791, 0.7000, 0.2590, 0.6379, 0.9880, 0.7518]]]) even_index_a1: tensor([[[0.8586, 0.8745, 0.2379, 0.5585, 0.4995, 0.4912, 0.6396, 0.3216], [0.3978, 0.3845, 0.2481, 0.6023, 0.3278, 0.4616, 0.1159, 0.6446]], [[0.3686, 0.3797, 0.9587, 0.0678, 0.8838, 0.7381, 0.0863, 0.2462], [0.4133, 0.3055, 0.6118, 0.5253, 0.6120, 0.2868, 0.4473, 0.9701]], [[0.1634, 0.9050, 0.1383, 0.4304, 0.6519, 0.7408, 0.8242, 0.4771], [0.6991, 0.2571, 0.7627, 0.8777, 0.9678, 0.5340, 0.9114, 0.8534]], [[0.5449, 0.3750, 0.1559, 0.1598, 0.8304, 0.5090, 0.8535, 0.4835], [0.4307, 0.0856, 0.3791, 0.7000, 0.2590, 0.6379, 0.9880, 0.7518]]]) <Size 0x7f4f3c9157c0> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7f4ed5b9aee0> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.