Customized Operations For Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Process a batch of nested tensor dicts using only the native torch API.

    For each sample this collects every other row of ``'a'`` (cast to
    float) and the mean of ``'b' ** 2 + 1`` (cast to float), then attaches
    a fresh ``(3, 4, 5)`` normal-noise tensor at ``sample['c']['noise']``.
    The float casts and the transform of ``'b'`` are used only for the
    returned statistics; they are not written back into ``batch_``.

    Args:
        batch_: List of dicts laid out like the result of ``get_data()``.
            Mutated in place (the noise entry is added under ``'c'``).

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)``: the input list, the
        average of the per-sample means of transformed ``'b'``, and the
        strided ``'a'`` slices stacked along a new leading dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                even_index_a_list.append(v[::2])
            elif k == 'b':
                v = v.float()
                transformed_v = torch.pow(v, 2) + 1.0
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                # 'd' is the only expected nested key; report anything else.
                for k1 in v:
                    if k1 != 'd':
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # Second pass so the new 'noise' key is not reported as ignored above.
    # The noise belongs under the nested 'c' dict; matching on 'd' here
    # could never succeed because 'd' is not a top-level key.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Run the same processing as ``with_nativetensor`` via the treetensor API.

    The samples are converted to tree tensors, stacked into one batched
    tree, cast to float, and then handled field by field with attribute
    access instead of explicit loops over keys.

    Args:
        batch_: List of dicts laid out like the result of ``get_data()``.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)``: the batched tree split
        back into chunks of size 1 along dim 0, the mean of transformed
        ``'b'``, and every other row of ``'a'`` for each sample.
    """
    # Size the noise from the actual input rather than the module-level B.
    batch_size = len(batch_)
    batch_ = [ttorch.tensor(b) for b in batch_]
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    # Dim 0 is now the batch dimension, so the stride applies to dim 1.
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Return one random sample as a nested dict of tensors.

    Layout: ``'a'`` is uniform random of shape ``(T, 8)``, ``'b'`` is
    uniform random of shape ``(6,)``, and ``'c'`` is a nested dict whose
    ``'d'`` entry is a random integer tensor of shape ``(1,)`` drawn from
    ``[0, 10)``.
    """
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,))
        }
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Deep copies keep the two implementations from sharing mutated state.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Largest absolute difference: take abs() before max() so that a large
    # negative discrepancy cannot slip through.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look similar to the following (the exact values differ between runs because the inputs are random), and all the assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.0730, 0.2810, 0.7642, 0.3265, 0.8006, 0.1291, 0.4410, 0.5144], [0.2621, 0.1805, 0.5549, 0.1331, 0.6024, 0.3018, 0.3112, 0.3647], [0.6566, 0.9296, 0.2867, 0.5219, 0.4429, 0.5960, 0.9051, 0.1431]]), 'b': tensor([0.5313, 0.5180, 0.4947, 0.9723, 0.5365, 0.1009]), 'c': {'d': tensor([6])}}, {'a': tensor([[0.8982, 0.5922, 0.7774, 0.2902, 0.1801, 0.7970, 0.9584, 0.3310], [0.3842, 0.6901, 0.4248, 0.4139, 0.4342, 0.4262, 0.6012, 0.4525], [0.2149, 0.8399, 0.4481, 0.6800, 0.4614, 0.9342, 0.2913, 0.7659]]), 'b': tensor([0.6653, 0.9887, 0.7088, 0.3472, 0.3131, 0.4827]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.8102, 0.6209, 0.8538, 0.5519, 0.0497, 0.8272, 0.4520, 0.4337], [0.9679, 0.5910, 0.6493, 0.3332, 0.1804, 0.4862, 0.7153, 0.6898], [0.0253, 0.8591, 0.0775, 0.0992, 0.3756, 0.9941, 0.9854, 0.1591]]), 'b': tensor([0.0182, 0.7272, 0.7945, 0.7043, 0.3251, 0.6554]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.9523, 0.1690, 0.1028, 0.6150, 0.6532, 0.8465, 0.0777, 0.1573], [0.1165, 0.3488, 0.7171, 0.5685, 0.3000, 0.3334, 0.8046, 0.0748], [0.8846, 0.9219, 0.5501, 0.1917, 0.0404, 0.4133, 0.0725, 0.9250]]), 'b': tensor([0.8136, 0.9406, 0.3927, 0.1013, 0.9801, 0.3492]), 'c': {'d': tensor([0])}}] (<Tensor 0x7fc906fdac10> ├── 'a' --> tensor([[[0.0730, 0.2810, 0.7642, 0.3265, 0.8006, 0.1291, 0.4410, 0.5144], │ [0.2621, 0.1805, 0.5549, 0.1331, 0.6024, 0.3018, 0.3112, 0.3647], │ [0.6566, 0.9296, 0.2867, 0.5219, 0.4429, 0.5960, 0.9051, 0.1431]]]) ├── 'b' --> tensor([[1.2822, 1.2684, 1.2447, 1.9455, 1.2878, 1.0102]]) └── 'c' --> <Tensor 0x7fc906fdac70> ├── 'd' --> 
tensor([[6.]]) └── 'noise' --> tensor([[[[-2.0164e+00, 1.6905e+00, 1.6358e+00, 1.1691e+00, 5.7766e-01], [ 1.3527e+00, -1.0914e+00, -4.6705e-02, 2.0240e+00, -1.8224e+00], [ 5.2425e-01, -1.7212e-01, -3.2210e-01, 1.3389e+00, -1.0577e+00], [-1.3967e-02, 1.0815e+00, 5.0475e-01, -5.8208e-01, 6.0741e-02]], [[ 4.3895e-01, 1.3226e+00, -6.5097e-01, 1.1502e-01, 4.6786e-01], [-8.2064e-04, 9.1445e-02, -4.7237e-01, -2.4750e-01, -1.9809e-01], [ 7.3672e-01, -8.6755e-01, 3.4016e-02, -4.4958e-02, -1.8747e+00], [-5.5719e-01, 5.0298e-01, -1.3750e+00, 8.4211e-01, 1.0541e+00]], [[-1.9202e-01, 7.3596e-01, -1.5025e+00, 6.4057e-01, -1.2579e+00], [-7.7307e-01, -6.7083e-01, 1.7407e+00, -1.0105e-01, 2.7708e-01], [-4.9228e-01, 1.7454e+00, 2.3372e-01, -1.0487e+00, -1.8510e+00], [ 2.3535e+00, -1.7593e+00, 2.1347e-01, -6.7427e-01, -1.1731e+00]]]]) , <Tensor 0x7fc906fdad30> ├── 'a' --> tensor([[[0.8982, 0.5922, 0.7774, 0.2902, 0.1801, 0.7970, 0.9584, 0.3310], │ [0.3842, 0.6901, 0.4248, 0.4139, 0.4342, 0.4262, 0.6012, 0.4525], │ [0.2149, 0.8399, 0.4481, 0.6800, 0.4614, 0.9342, 0.2913, 0.7659]]]) ├── 'b' --> tensor([[1.4426, 1.9774, 1.5024, 1.1205, 1.0980, 1.2330]]) └── 'c' --> <Tensor 0x7fc906fdaac0> ├── 'd' --> tensor([[3.]]) └── 'noise' --> tensor([[[[-1.5450, -1.6905, 0.4312, -0.2604, -0.4865], [ 0.5222, -0.4960, -0.7188, -0.7926, -0.6251], [ 0.8556, 1.6708, -0.5835, -0.1483, -0.5698], [ 1.6547, -0.7003, 1.0590, -2.4264, 0.5141]], [[-0.0606, 0.1895, -1.0684, -0.7247, 1.3095], [ 2.5292, -1.1486, 0.2688, 0.3017, 0.3331], [-0.7841, 1.5898, 1.6363, -0.7597, 0.2299], [ 0.5379, 0.1701, 0.1874, -0.1444, 1.3577]], [[-0.3337, -0.1685, -0.8499, -1.2491, 1.3864], [-0.5003, 0.8336, -1.8690, -1.2475, 0.2396], [-0.1194, 1.3889, -0.6624, -1.2659, -0.1942], [-0.1443, -0.8676, 0.7845, 0.0346, 0.8877]]]]) , <Tensor 0x7fc906fdad90> ├── 'a' --> tensor([[[0.8102, 0.6209, 0.8538, 0.5519, 0.0497, 0.8272, 0.4520, 0.4337], │ [0.9679, 0.5910, 0.6493, 0.3332, 0.1804, 0.4862, 0.7153, 0.6898], │ [0.0253, 0.8591, 0.0775, 
0.0992, 0.3756, 0.9941, 0.9854, 0.1591]]]) ├── 'b' --> tensor([[1.0003, 1.5289, 1.6312, 1.4960, 1.1057, 1.4296]]) └── 'c' --> <Tensor 0x7fc906fdad60> ├── 'd' --> tensor([[3.]]) └── 'noise' --> tensor([[[[-0.6307, 0.0423, -0.7410, -0.3917, 0.1529], [ 0.6046, -0.0461, 0.7372, -0.3655, 0.2533], [ 1.0064, -0.9557, 1.0176, 0.8430, 1.5863], [ 1.4941, 1.6549, 0.3781, -0.2377, -3.0946]], [[-1.3730, 1.1420, 2.0576, 0.6310, -1.4223], [-1.6611, 0.0794, 1.6946, 1.6277, 0.8059], [ 1.7959, -0.4943, -0.2074, 2.0674, -0.3033], [-1.2095, -0.1526, -0.1072, -1.4623, 0.3611]], [[-1.4317, 0.2062, 0.7165, -1.0422, 1.3787], [ 0.0730, 0.4072, -1.7002, 0.1608, 0.2913], [-0.3936, -1.0239, -0.2620, -1.4381, 1.4095], [ 1.1906, 0.3589, 2.8231, 0.8444, 0.6052]]]]) , <Tensor 0x7fc906fdadf0> ├── 'a' --> tensor([[[0.9523, 0.1690, 0.1028, 0.6150, 0.6532, 0.8465, 0.0777, 0.1573], │ [0.1165, 0.3488, 0.7171, 0.5685, 0.3000, 0.3334, 0.8046, 0.0748], │ [0.8846, 0.9219, 0.5501, 0.1917, 0.0404, 0.4133, 0.0725, 0.9250]]]) ├── 'b' --> tensor([[1.6619, 1.8847, 1.1542, 1.0103, 1.9607, 1.1220]]) └── 'c' --> <Tensor 0x7fc906fdadc0> ├── 'd' --> tensor([[0.]]) └── 'noise' --> tensor([[[[ 0.0185, -0.4831, 0.5030, 1.8792, -0.5789], [ 0.3683, -0.1999, 0.0887, 0.1022, -0.8752], [ 0.6978, -0.5402, 1.5988, -0.2379, 1.3199], [-1.1288, -1.6968, -1.5964, -1.8739, -0.6625]], [[ 1.0150, -0.8857, -1.3007, -0.2133, -1.0358], [-0.3572, -0.3062, 0.4154, 0.2206, 0.0947], [-1.3657, -1.2397, 0.2615, -1.0847, 1.7383], [-0.1126, -0.3927, 0.5227, 1.6310, 0.2789]], [[-0.5726, 0.9945, 0.6983, 0.0200, -0.8776], [-0.8363, -0.1905, -1.7081, 0.6175, 1.2286], [-0.5026, 0.7750, -0.1982, -1.6114, -0.2191], [-0.8025, -1.7732, 0.6458, 0.8507, 0.2071]]]]) ) mean0 & mean1: tensor(1.3916) tensor(1.3916) even_index_a0: tensor([[[0.0730, 0.2810, 0.7642, 0.3265, 0.8006, 0.1291, 0.4410, 0.5144], [0.6566, 0.9296, 0.2867, 0.5219, 0.4429, 0.5960, 0.9051, 0.1431]], [[0.8982, 0.5922, 0.7774, 0.2902, 0.1801, 0.7970, 0.9584, 0.3310], [0.2149, 0.8399, 0.4481, 
0.6800, 0.4614, 0.9342, 0.2913, 0.7659]], [[0.8102, 0.6209, 0.8538, 0.5519, 0.0497, 0.8272, 0.4520, 0.4337], [0.0253, 0.8591, 0.0775, 0.0992, 0.3756, 0.9941, 0.9854, 0.1591]], [[0.9523, 0.1690, 0.1028, 0.6150, 0.6532, 0.8465, 0.0777, 0.1573], [0.8846, 0.9219, 0.5501, 0.1917, 0.0404, 0.4133, 0.0725, 0.9250]]]) even_index_a1: tensor([[[0.0730, 0.2810, 0.7642, 0.3265, 0.8006, 0.1291, 0.4410, 0.5144], [0.6566, 0.9296, 0.2867, 0.5219, 0.4429, 0.5960, 0.9051, 0.1431]], [[0.8982, 0.5922, 0.7774, 0.2902, 0.1801, 0.7970, 0.9584, 0.3310], [0.2149, 0.8399, 0.4481, 0.6800, 0.4614, 0.9342, 0.2913, 0.7659]], [[0.8102, 0.6209, 0.8538, 0.5519, 0.0497, 0.8272, 0.4520, 0.4337], [0.0253, 0.8591, 0.0775, 0.0992, 0.3756, 0.9941, 0.9854, 0.1591]], [[0.9523, 0.1690, 0.1028, 0.6150, 0.6532, 0.8465, 0.0777, 0.1573], [0.8846, 0.9219, 0.5501, 0.1917, 0.0404, 0.4133, 0.0725, 0.9250]]]) <Size 0x7fc906fdab20> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7fc90719c7c0> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.