Customized Operations for Different Fields
Here is another example of custom per-field operations, implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Apply the per-field operations with plain torch, one sample at a time.

    Every sample dict in ``batch_`` is updated in place:

    - ``'a'`` is cast to float;
    - ``'b'`` is cast to float and replaced by ``b ** 2 + 1``;
    - ``'c'['d']`` is cast to float, and ``'c'['noise']`` is set to a fresh
      ``torch.randn`` tensor of size ``(3, 4, 5)``.

    Keys other than the ones above are reported with ``print`` and left as-is.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)``:

        - ``mean_b`` is the average of the per-sample means of the transformed
          ``'b'``. This equals the overall mean only when every ``'b'`` has
          the same length, as produced by ``get_data``.
        - ``even_index_a`` stacks the even-indexed rows (``a[::2]``) of every
          sample along a new dim 0.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Results must be stored back into ``sample``: rebinding the loop
        # variable alone silently discards them. Assigning to keys that
        # already exist does not change the dict's size, so doing it while
        # iterating is safe.
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))
    # The noise lives in the nested 'c' dict, mirroring ``batch_.c.noise`` in
    # the treetensor version. There is no top-level 'd' key to match on.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))
    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field operations with the treetensor API.

    The samples are first stacked into one tree of batched tensors. Every leaf
    is then cast to float, ``b`` is replaced by ``b ** 2 + 1``, and
    ``c.noise`` is attached. Finally the tree is split back into per-sample
    trees, each keeping a leading dimension of size 1.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)``:

        - ``batch_`` is the result of ``ttorch.split``.
        - ``mean_b`` is the mean of the transformed ``b`` across the batch.
        - ``even_index_a`` is ``a[:, ::2]`` of the stacked batch.
    """
    # Use the real number of samples instead of the global ``B`` so the
    # function works for any batch size.
    batch_size = len(batch_)
    batch_ = ttorch.stack([ttorch.tensor(b) for b in batch_])
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample.

    The sample has ``'a'`` of shape ``(T, 8)`` and ``'b'`` of shape ``(6,)``,
    both uniform in [0, 1). It also has a nested ``'c'`` dict whose ``'d'`` is
    a single-element integer tensor drawn from [0, 10).
    """
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,))
        }
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Each implementation gets its own deep copy because the native version
    # mutates its input in place.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    assert torch.abs((even_index_a0 - even_index_a1).max()) < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look like the following (the tensor values are random), and all the assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.8665, 0.9088, 0.7569, 0.6825, 0.1302, 0.6438, 0.3450, 0.2007], [0.4946, 0.0328, 0.6505, 0.2378, 0.2702, 0.2210, 0.3800, 0.7847], [0.2095, 0.2963, 0.7101, 0.6690, 0.6303, 0.2929, 0.1435, 0.6998]]), 'b': tensor([0.1214, 0.0274, 0.9256, 0.8215, 0.8466, 0.1902]), 'c': {'d': tensor([7])}}, {'a': tensor([[0.2431, 0.1269, 0.3689, 0.8099, 0.3914, 0.3421, 0.4355, 0.2371], [0.8309, 0.7894, 0.9701, 0.1638, 0.1176, 0.2241, 0.2241, 0.2522], [0.9711, 0.0218, 0.9972, 0.7963, 0.6596, 0.9119, 0.1892, 0.8936]]), 'b': tensor([0.8208, 0.6126, 0.4234, 0.9650, 0.0900, 0.2385]), 'c': {'d': tensor([9])}}, {'a': tensor([[0.5688, 0.3217, 0.8575, 0.8436, 0.1835, 0.8896, 0.0210, 0.5455], [0.9954, 0.7660, 0.1734, 0.3155, 0.3065, 0.6172, 0.0347, 0.4582], [0.4729, 0.9954, 0.5076, 0.7531, 0.3406, 0.5855, 0.5058, 0.1206]]), 'b': tensor([0.5111, 0.5044, 0.6388, 0.7865, 0.7909, 0.2100]), 'c': {'d': tensor([4])}}, {'a': tensor([[0.8107, 0.8194, 0.2188, 0.4298, 0.1951, 0.0313, 0.6070, 0.7073], [0.8138, 0.7121, 0.9815, 0.9346, 0.2885, 0.0585, 0.1913, 0.3806], [0.8669, 0.1006, 0.4896, 0.3018, 0.6727, 0.4490, 0.5455, 0.3220]]), 'b': tensor([0.2372, 0.8830, 0.6201, 0.6245, 0.6132, 0.4773]), 'c': {'d': tensor([7])}}] (<Tensor 0x7f237c162580> ├── 'a' --> tensor([[[0.8665, 0.9088, 0.7569, 0.6825, 0.1302, 0.6438, 0.3450, 0.2007], │ [0.4946, 0.0328, 0.6505, 0.2378, 0.2702, 0.2210, 0.3800, 0.7847], │ [0.2095, 0.2963, 0.7101, 0.6690, 0.6303, 0.2929, 0.1435, 0.6998]]]) ├── 'b' --> tensor([[1.0147, 1.0007, 1.8567, 1.6749, 1.7167, 1.0362]]) └── 'c' --> <Tensor 0x7f237c1625e0> ├── 'd' --> 
tensor([[7.]]) └── 'noise' --> tensor([[[[-0.3026, 0.3336, 0.7056, -1.2039, 1.7792], [-0.6242, 0.0618, -0.5512, -0.1455, 1.1428], [ 1.1051, 2.3970, 0.9906, 0.4791, 0.2180], [-1.5307, -0.6763, 0.0491, -1.1093, 0.3926]], [[ 1.2435, 0.2536, -0.5028, -0.8716, 1.1416], [-0.2970, 0.1877, 0.4230, -0.4240, -0.4371], [-0.8912, 1.6403, 0.2320, -1.6659, -0.3558], [ 0.6275, 0.2561, -1.5862, 2.0024, -0.9043]], [[ 0.7452, -1.8395, 0.3822, -0.9894, -1.5323], [-0.0126, 0.0525, -0.0981, -0.8024, 1.4610], [-1.0592, 0.4134, 0.6497, 0.1778, -0.3963], [ 0.2844, 1.7031, 0.6817, -1.0866, 0.7724]]]]) , <Tensor 0x7f237c1626a0> ├── 'a' --> tensor([[[0.2431, 0.1269, 0.3689, 0.8099, 0.3914, 0.3421, 0.4355, 0.2371], │ [0.8309, 0.7894, 0.9701, 0.1638, 0.1176, 0.2241, 0.2241, 0.2522], │ [0.9711, 0.0218, 0.9972, 0.7963, 0.6596, 0.9119, 0.1892, 0.8936]]]) ├── 'b' --> tensor([[1.6738, 1.3753, 1.1793, 1.9312, 1.0081, 1.0569]]) └── 'c' --> <Tensor 0x7f237c1624c0> ├── 'd' --> tensor([[9.]]) └── 'noise' --> tensor([[[[-2.5049e+00, 2.0461e-01, -6.6002e-01, 1.8798e-01, 8.0707e-01], [ 5.5676e-01, 1.0060e+00, -9.4148e-01, -1.4495e+00, -6.5691e-01], [ 2.1647e+00, -8.5329e-01, -1.3162e+00, 3.0015e-01, -4.9607e-01], [-8.6826e-01, -1.3946e+00, 1.2714e+00, -3.2209e-01, 9.2902e-01]], [[ 1.5400e+00, -1.1072e+00, 4.0012e-01, 1.4014e-04, 3.5665e-01], [ 6.9737e-01, 4.0163e-01, -1.3559e+00, 2.0701e-01, 6.3983e-01], [ 1.3846e-01, -1.3418e+00, -4.2567e-01, 6.8998e-01, -3.8492e-01], [ 3.6196e-01, -1.5220e-01, 6.6868e-01, 2.6043e-01, -4.0372e-02]], [[ 6.1624e-01, -4.3484e-01, -1.7773e+00, 3.4712e-01, 2.7517e-01], [ 1.1534e+00, -9.2221e-01, 1.7659e+00, 1.2852e+00, -2.8716e-01], [ 5.9528e-02, -1.1638e+00, -2.4237e-01, -1.0312e+00, 1.2858e+00], [ 3.6574e-01, -6.7653e-02, 4.4200e-01, -7.2800e-01, 2.7208e-01]]]]) , <Tensor 0x7f237c162700> ├── 'a' --> tensor([[[0.5688, 0.3217, 0.8575, 0.8436, 0.1835, 0.8896, 0.0210, 0.5455], │ [0.9954, 0.7660, 0.1734, 0.3155, 0.3065, 0.6172, 0.0347, 0.4582], │ [0.4729, 0.9954, 0.5076, 0.7531, 
0.3406, 0.5855, 0.5058, 0.1206]]]) ├── 'b' --> tensor([[1.2612, 1.2544, 1.4080, 1.6186, 1.6255, 1.0441]]) └── 'c' --> <Tensor 0x7f237c1626d0> ├── 'd' --> tensor([[4.]]) └── 'noise' --> tensor([[[[ 0.1353, -1.8480, -1.6621, 0.5799, -0.2800], [ 1.4138, -0.0113, 0.0177, 0.8198, 0.7301], [-0.6666, 0.0304, -0.6770, 0.5411, 0.5198], [ 1.5511, -0.1262, -0.6161, -1.1997, -1.5366]], [[ 0.1982, 1.4697, 1.3387, 0.8497, -0.5123], [ 0.4804, -0.9265, 0.8640, 0.4346, 0.1220], [ 2.0205, 0.1645, -0.3352, -1.1066, -0.7150], [ 0.5400, -0.4626, -1.1799, 0.5766, -0.6324]], [[ 2.1123, 1.2872, 1.4614, 0.7326, 1.3558], [-0.1035, -0.5072, -1.4503, -1.8749, 1.9906], [-0.0079, 0.3056, 1.6099, 0.0699, 0.6805], [-0.8504, -0.3356, 0.7516, 1.1599, 0.3744]]]]) , <Tensor 0x7f237c162760> ├── 'a' --> tensor([[[0.8107, 0.8194, 0.2188, 0.4298, 0.1951, 0.0313, 0.6070, 0.7073], │ [0.8138, 0.7121, 0.9815, 0.9346, 0.2885, 0.0585, 0.1913, 0.3806], │ [0.8669, 0.1006, 0.4896, 0.3018, 0.6727, 0.4490, 0.5455, 0.3220]]]) ├── 'b' --> tensor([[1.0563, 1.7797, 1.3846, 1.3900, 1.3760, 1.2278]]) └── 'c' --> <Tensor 0x7f237c162730> ├── 'd' --> tensor([[7.]]) └── 'noise' --> tensor([[[[ 1.8565, 0.9559, -2.7744, 0.9325, -0.5216], [ 0.4288, -1.1126, -0.7462, 1.5362, 1.4256], [-0.9266, -0.8704, -1.3405, -0.8918, -0.2227], [-0.5473, -1.3886, -0.3050, -1.6887, -1.3981]], [[-0.9450, 0.9875, -0.8314, -0.3334, 0.5684], [ 0.5428, 0.3821, -0.6291, 0.8887, 0.6346], [-0.4497, -0.5102, 0.5783, -0.1734, 1.0046], [-0.5322, 1.5718, 0.1697, -0.4856, -0.5880]], [[-0.6278, 1.3811, -0.0196, 0.8296, 0.6133], [-0.7989, 0.4975, -0.6005, 1.6146, 0.4410], [-0.2954, -0.7136, -0.2561, -0.0071, 0.8017], [ 0.4188, -1.1521, -0.9709, -1.2668, -0.3882]]]]) ) mean0 & mean1: tensor(1.3729) tensor(1.3729) even_index_a0: tensor([[[0.8665, 0.9088, 0.7569, 0.6825, 0.1302, 0.6438, 0.3450, 0.2007], [0.2095, 0.2963, 0.7101, 0.6690, 0.6303, 0.2929, 0.1435, 0.6998]], [[0.2431, 0.1269, 0.3689, 0.8099, 0.3914, 0.3421, 0.4355, 0.2371], [0.9711, 0.0218, 0.9972, 
0.7963, 0.6596, 0.9119, 0.1892, 0.8936]], [[0.5688, 0.3217, 0.8575, 0.8436, 0.1835, 0.8896, 0.0210, 0.5455], [0.4729, 0.9954, 0.5076, 0.7531, 0.3406, 0.5855, 0.5058, 0.1206]], [[0.8107, 0.8194, 0.2188, 0.4298, 0.1951, 0.0313, 0.6070, 0.7073], [0.8669, 0.1006, 0.4896, 0.3018, 0.6727, 0.4490, 0.5455, 0.3220]]]) even_index_a1: tensor([[[0.8665, 0.9088, 0.7569, 0.6825, 0.1302, 0.6438, 0.3450, 0.2007], [0.2095, 0.2963, 0.7101, 0.6690, 0.6303, 0.2929, 0.1435, 0.6998]], [[0.2431, 0.1269, 0.3689, 0.8099, 0.3914, 0.3421, 0.4355, 0.2371], [0.9711, 0.0218, 0.9972, 0.7963, 0.6596, 0.9119, 0.1892, 0.8936]], [[0.5688, 0.3217, 0.8575, 0.8436, 0.1835, 0.8896, 0.0210, 0.5455], [0.4729, 0.9954, 0.5076, 0.7531, 0.3406, 0.5855, 0.5058, 0.1206]], [[0.8107, 0.8194, 0.2188, 0.4298, 0.1951, 0.0313, 0.6070, 0.7073], [0.8669, 0.1006, 0.4896, 0.3018, 0.6727, 0.4490, 0.5455, 0.3220]]]) <Size 0x7f23e202e430> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7f237c1ed940> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.