Customized Operations for Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Apply the per-field operations to a list of nested dicts with plain torch.

    For every sample:
      * ``a`` is cast to float and its even-indexed rows are collected;
      * ``b`` is cast to float and replaced by ``b ** 2 + 1``;
      * ``c.d`` is cast to float and a random ``c.noise`` tensor is attached.
    Unknown keys are reported and left alone. The samples are modified in
    place, so pass a copy if the originals are still needed.

    Returns:
        (batch_, mean_b, even_index_a): the updated samples, the mean of the
        transformed ``b`` values, and the stacked even rows of every ``a``.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Re-assigning existing keys while iterating ``items()`` is safe: the
        # dict never changes size inside this loop.
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                v = torch.pow(v.float(), 2) + 1.0
                sample[k] = v
                mean_b_list.append(v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # Attach the noise in a second pass: adding a new key to ``c`` while it
    # is being iterated above would raise "dictionary changed size".
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field operations as ``with_nativetensor`` with treetensor.

    The samples are stacked into one tree tensor (leading dimension = batch
    size), processed field by field, and split back into per-sample trees.

    Returns:
        (batch_, mean_b, even_index_a): a sequence of per-sample trees (each
        with a leading dimension of 1), the mean of the transformed ``b``
        values, and the even rows of every ``a``.
    """
    batch_size = len(batch_)
    batch_ = [ttorch.tensor(b) for b in batch_]
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    # One noise slice per sample, matching the stacked leading dimension.
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample: ``a`` (T, 8), ``b`` (6,), and nested ``c.d`` (1,) ints."""
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,))
        }
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare the largest *absolute* difference; abs(max(diff)) ignores
    # negative differences as soon as any single element matches.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look like the following (the tensor values are random, so they will differ between runs), and all the assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.8053, 0.7445, 0.0793, 0.3184, 0.2756, 0.8252, 0.3445, 0.0752], [0.7182, 0.5773, 0.0218, 0.0600, 0.9528, 0.9006, 0.4050, 0.3537], [0.5020, 0.0985, 0.2068, 0.3102, 0.8672, 0.8716, 0.3657, 0.0718]]), 'b': tensor([0.5740, 0.5592, 0.7248, 0.7438, 0.5457, 0.7120]), 'c': {'d': tensor([0])}}, {'a': tensor([[0.5308, 0.7016, 0.6871, 0.0403, 0.9346, 0.9358, 0.0203, 0.9230], [0.2047, 0.4180, 0.1184, 0.1707, 0.3309, 0.0488, 0.1547, 0.3050], [0.8321, 0.4029, 0.9024, 0.6745, 0.2132, 0.8717, 0.2416, 0.9677]]), 'b': tensor([0.9478, 0.9531, 0.5402, 0.4392, 0.3258, 0.4395]), 'c': {'d': tensor([5])}}, {'a': tensor([[0.7133, 0.4646, 0.2921, 0.7726, 0.6104, 0.5284, 0.2939, 0.3569], [0.8337, 0.0498, 0.7647, 0.9300, 0.3821, 0.8607, 0.4000, 0.6861], [0.8760, 0.1902, 0.0465, 0.0177, 0.7976, 0.9254, 0.6127, 0.8822]]), 'b': tensor([0.5722, 0.8562, 0.5941, 0.8629, 0.7372, 0.4695]), 'c': {'d': tensor([7])}}, {'a': tensor([[0.1882, 0.1776, 0.0793, 0.3815, 0.4872, 0.0375, 0.6545, 0.2889], [0.0571, 0.8010, 0.0199, 0.3678, 0.5668, 0.4459, 0.4310, 0.0160], [0.6870, 0.2658, 0.9489, 0.0055, 0.5520, 0.0831, 0.1160, 0.0748]]), 'b': tensor([0.3212, 0.9066, 0.8391, 0.7212, 0.3519, 0.4270]), 'c': {'d': tensor([7])}}] (<Tensor 0x7f747b6b64f0> ├── 'a' --> tensor([[[0.8053, 0.7445, 0.0793, 0.3184, 0.2756, 0.8252, 0.3445, 0.0752], │ [0.7182, 0.5773, 0.0218, 0.0600, 0.9528, 0.9006, 0.4050, 0.3537], │ [0.5020, 0.0985, 0.2068, 0.3102, 0.8672, 0.8716, 0.3657, 0.0718]]]) ├── 'b' --> tensor([[1.3295, 1.3127, 1.5254, 1.5532, 1.2978, 1.5069]]) └── 'c' --> <Tensor 0x7f747b6b6550> ├── 'd' --> 
tensor([[0.]]) └── 'noise' --> tensor([[[[-6.8501e-01, 4.6697e-01, 1.7599e+00, 1.2905e-01, -4.8906e-01], [ 6.8946e-01, 1.7629e-01, -1.3696e+00, 1.1340e-03, 8.4019e-01], [-1.2656e-01, 1.7946e+00, 1.2009e+00, -1.5991e+00, 1.8196e-01], [ 1.2336e+00, 2.5105e-01, 6.7169e-02, -1.6914e+00, -3.5321e-01]], [[-4.8059e-01, -9.3635e-01, -1.2018e+00, 1.1593e+00, -3.0129e-01], [ 1.0256e+00, -1.1489e+00, 4.6913e-01, -3.7563e-01, -9.6251e-01], [-2.3017e+00, 1.1598e+00, 1.7594e+00, -2.4179e-01, -8.0267e-01], [-5.4404e-01, 1.0516e+00, -6.2866e-01, 1.4423e+00, -1.6218e-02]], [[ 1.3444e+00, -6.3613e-01, 1.1578e+00, -3.5533e-01, -2.5801e-01], [-8.2113e-01, -1.8916e+00, 9.3748e-01, -2.0547e+00, 1.5202e+00], [-3.1873e-01, 2.6639e-02, 1.6233e+00, 4.8557e-01, -1.6574e+00], [-8.8723e-02, -2.0476e+00, 1.1273e+00, -1.1973e+00, -9.3022e-01]]]]) , <Tensor 0x7f747b6b6610> ├── 'a' --> tensor([[[0.5308, 0.7016, 0.6871, 0.0403, 0.9346, 0.9358, 0.0203, 0.9230], │ [0.2047, 0.4180, 0.1184, 0.1707, 0.3309, 0.0488, 0.1547, 0.3050], │ [0.8321, 0.4029, 0.9024, 0.6745, 0.2132, 0.8717, 0.2416, 0.9677]]]) ├── 'b' --> tensor([[1.8983, 1.9084, 1.2918, 1.1929, 1.1061, 1.1931]]) └── 'c' --> <Tensor 0x7f747b6b6430> ├── 'd' --> tensor([[5.]]) └── 'noise' --> tensor([[[[ 2.9140e+00, 1.5177e+00, -1.6288e+00, 1.9607e-01, 2.1942e-01], [ 2.3884e+00, -5.4665e-01, 1.5765e-01, 2.4808e-01, -1.6871e+00], [ 5.9751e-01, 1.9206e-01, 1.0923e+00, -7.0509e-01, -2.5146e-01], [ 8.7729e-01, 1.3504e+00, -7.5557e-01, -1.3701e-01, -5.8436e-01]], [[ 5.3209e-01, -1.3542e+00, 1.5011e+00, -8.6212e-01, 1.7817e+00], [ 7.2772e-01, -1.6131e+00, -1.0709e+00, -9.3635e-01, 4.2625e-01], [ 1.2432e+00, 1.5236e+00, -4.9223e-01, 8.4649e-01, 1.4618e+00], [ 4.6363e-01, 6.8346e-01, -8.6416e-01, 1.7699e-01, 1.6948e+00]], [[ 5.6625e-01, 9.9995e-01, 1.0111e-01, 4.8011e-04, 9.1063e-01], [-4.7770e-01, 1.7299e+00, 9.1445e-01, -2.8423e+00, -4.2896e-01], [ 1.0115e+00, -6.0075e-01, 9.9652e-02, 9.1871e-01, 9.7710e-01], [-6.0919e-02, -1.8909e-01, 9.6011e-01, 
-1.5709e-01, 1.3108e-01]]]]) , <Tensor 0x7f747b6b6670> ├── 'a' --> tensor([[[0.7133, 0.4646, 0.2921, 0.7726, 0.6104, 0.5284, 0.2939, 0.3569], │ [0.8337, 0.0498, 0.7647, 0.9300, 0.3821, 0.8607, 0.4000, 0.6861], │ [0.8760, 0.1902, 0.0465, 0.0177, 0.7976, 0.9254, 0.6127, 0.8822]]]) ├── 'b' --> tensor([[1.3274, 1.7331, 1.3530, 1.7447, 1.5434, 1.2204]]) └── 'c' --> <Tensor 0x7f747b6b6640> ├── 'd' --> tensor([[7.]]) └── 'noise' --> tensor([[[[-0.3763, -0.7753, -0.4001, 0.0321, 0.2464], [-0.7008, 0.1068, -0.1251, 2.1447, 0.1137], [ 1.2222, -0.0875, -1.0401, -0.3119, 1.0039], [-0.4405, 0.6470, 0.9651, -0.2678, -0.9768]], [[-0.7289, 0.9036, 0.0264, -0.0244, 0.0816], [-0.7036, -0.0935, 0.7025, -1.4874, 0.4035], [ 0.6441, 0.4094, -0.3827, 0.0542, -1.2560], [-1.2502, 0.4918, 0.4921, 0.9951, -1.8664]], [[-1.5643, -1.6597, 0.0849, 1.6072, -2.0404], [-0.0367, -0.0865, 0.7401, 1.2236, 0.9519], [-0.5182, -0.0128, -0.1113, 0.6622, 0.3690], [-0.0197, -0.1582, -1.0284, -1.7841, 0.0648]]]]) , <Tensor 0x7f747b6b66d0> ├── 'a' --> tensor([[[0.1882, 0.1776, 0.0793, 0.3815, 0.4872, 0.0375, 0.6545, 0.2889], │ [0.0571, 0.8010, 0.0199, 0.3678, 0.5668, 0.4459, 0.4310, 0.0160], │ [0.6870, 0.2658, 0.9489, 0.0055, 0.5520, 0.0831, 0.1160, 0.0748]]]) ├── 'b' --> tensor([[1.1032, 1.8219, 1.7040, 1.5201, 1.1238, 1.1823]]) └── 'c' --> <Tensor 0x7f747b6b66a0> ├── 'd' --> tensor([[7.]]) └── 'noise' --> tensor([[[[ 1.0370, 1.3408, -0.4759, 0.3349, 0.9377], [ 0.1157, -0.8470, -0.0512, -0.3285, -0.6335], [ 0.6802, 0.6821, -0.2242, 0.5975, -0.5512], [ 1.8505, -0.0149, 0.4083, -0.9496, -1.5367]], [[-1.0895, 0.5305, -0.4091, 0.1896, -0.3201], [-1.7025, 1.3947, -0.5616, -0.7856, 1.7454], [-1.3787, -1.3461, -0.0710, 0.9998, -1.4652], [-0.6107, 1.1829, 0.9042, 1.1213, -0.4547]], [[-2.7062, -0.9415, -0.3098, 0.5568, 0.2289], [ 0.7780, -0.8436, 0.0879, -1.7352, -0.1688], [-1.3940, -0.8162, 1.5604, -1.2692, 1.2674], [ 1.1491, 1.8704, -0.5248, -1.1473, -0.4696]]]]) ) mean0 & mean1: tensor(1.4372) tensor(1.4372) 
even_index_a0: tensor([[[0.8053, 0.7445, 0.0793, 0.3184, 0.2756, 0.8252, 0.3445, 0.0752], [0.5020, 0.0985, 0.2068, 0.3102, 0.8672, 0.8716, 0.3657, 0.0718]], [[0.5308, 0.7016, 0.6871, 0.0403, 0.9346, 0.9358, 0.0203, 0.9230], [0.8321, 0.4029, 0.9024, 0.6745, 0.2132, 0.8717, 0.2416, 0.9677]], [[0.7133, 0.4646, 0.2921, 0.7726, 0.6104, 0.5284, 0.2939, 0.3569], [0.8760, 0.1902, 0.0465, 0.0177, 0.7976, 0.9254, 0.6127, 0.8822]], [[0.1882, 0.1776, 0.0793, 0.3815, 0.4872, 0.0375, 0.6545, 0.2889], [0.6870, 0.2658, 0.9489, 0.0055, 0.5520, 0.0831, 0.1160, 0.0748]]]) even_index_a1: tensor([[[0.8053, 0.7445, 0.0793, 0.3184, 0.2756, 0.8252, 0.3445, 0.0752], [0.5020, 0.0985, 0.2068, 0.3102, 0.8672, 0.8716, 0.3657, 0.0718]], [[0.5308, 0.7016, 0.6871, 0.0403, 0.9346, 0.9358, 0.0203, 0.9230], [0.8321, 0.4029, 0.9024, 0.6745, 0.2132, 0.8717, 0.2416, 0.9677]], [[0.7133, 0.4646, 0.2921, 0.7726, 0.6104, 0.5284, 0.2939, 0.3569], [0.8760, 0.1902, 0.0465, 0.0177, 0.7976, 0.9254, 0.6127, 0.8822]], [[0.1882, 0.1776, 0.0793, 0.3815, 0.4872, 0.0375, 0.6545, 0.2889], [0.6870, 0.2658, 0.9489, 0.0055, 0.5520, 0.0831, 0.1160, 0.0748]]]) <Size 0x7f74e7cee3d0> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7f747b7318b0> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.