Customized Operations For Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Process a batch of nested sample dicts with plain ``torch`` operations.

    For every sample this casts ``'a'``, ``'b'`` and ``'c' -> 'd'`` to float,
    replaces ``'b'`` with ``b ** 2 + 1`` and attaches a ``'c' -> 'noise'``
    tensor of shape ``(3, 4, 5)`` -- the same operations ``with_treetensor``
    applies, so that the two results can be compared.

    Args:
        batch_: list of sample dicts as produced by ``get_data``. The dicts
            are modified in place (the caller passes a deep copy).

    Returns:
        ``(batch_, mean_b, even_index_a)`` where ``mean_b`` is the mean of the
        per-sample means of the transformed ``'b'`` values and
        ``even_index_a`` stacks the even-indexed rows (``[::2]``) of every
        ``'a'`` along a new dim 0.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Iterate over a snapshot so converted values can be written back.
        for k, v in list(sample.items()):
            if k == 'a':
                v = v.float()
                sample[k] = v  # write back; rebinding ``v`` alone is lost
                even_index_a_list.append(v[::2])
            elif k == 'b':
                v = v.float()
                transformed_v = torch.pow(v, 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in list(v.items()):
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))
    # 'noise' lives under the nested 'c' dict (mirroring ``batch_.c.noise`` in
    # with_treetensor). 'd' is never a top-level key, so checking for it would
    # never add any noise.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))
    # With equal-length 'b' tensors (as get_data produces), the mean of the
    # per-sample means equals the mean over all elements.
    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Process the same batch of sample dicts with the treetensor API.

    Stacks the samples into one tree tensor with a new leading batch dim,
    casts every leaf to float, replaces ``b`` with ``b ** 2 + 1``, adds
    ``c.noise`` and splits the tree back into per-sample trees. Each
    per-sample tree keeps a leading dimension of size 1.

    Args:
        batch_: list of sample dicts as produced by ``get_data``.

    Returns:
        ``(batch_, mean_b, even_index_a)`` analogous to ``with_nativetensor``.
        Here ``batch_`` is the sequence returned by ``ttorch.split``.
    """
    batch_ = [ttorch.tensor(b) for b in batch_]
    # Use the actual batch length rather than the module-level ``B`` so the
    # function works for batches of any size.
    batch_size = len(batch_)
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Return one random sample: a nested dict of tensors."""
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,))
        }
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take the element-wise absolute value *before* the max: abs(max(diff))
    # would accept arbitrarily large negative differences.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look like the following, and all the assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.3966, 0.8244, 0.6097, 0.9855, 0.3798, 0.1565, 0.0098, 0.6476], [0.7319, 0.6274, 0.9433, 0.5831, 0.9184, 0.4625, 0.5958, 0.0988], [0.2903, 0.5619, 0.5450, 0.3228, 0.9717, 0.0112, 0.0610, 0.6176]]), 'b': tensor([0.1893, 0.4007, 0.8237, 0.2230, 0.6610, 0.9872]), 'c': {'d': tensor([4])}}, {'a': tensor([[0.3677, 0.3901, 0.2094, 0.3252, 0.2418, 0.2013, 0.3877, 0.8925], [0.9920, 0.1467, 0.1470, 0.9469, 0.8412, 0.6267, 0.9276, 0.3556], [0.9186, 0.9478, 0.0283, 0.2015, 0.2140, 0.2504, 0.2546, 0.5116]]), 'b': tensor([0.8362, 0.4179, 0.0113, 0.6096, 0.2867, 0.7525]), 'c': {'d': tensor([9])}}, {'a': tensor([[0.8125, 0.0259, 0.0690, 0.0772, 0.8263, 0.6004, 0.4874, 0.0964], [0.3606, 0.3776, 0.7780, 0.1085, 0.2932, 0.3797, 0.1329, 0.9739], [0.2839, 0.3744, 0.7001, 0.6329, 0.1951, 0.7454, 0.9891, 0.1892]]), 'b': tensor([0.7364, 0.5097, 0.4577, 0.9647, 0.4244, 0.2281]), 'c': {'d': tensor([0])}}, {'a': tensor([[0.1986, 0.1622, 0.4137, 0.1405, 0.9052, 0.2589, 0.3613, 0.9998], [0.1555, 0.0922, 0.4355, 0.9130, 0.4442, 0.6955, 0.8678, 0.7797], [0.6549, 0.6488, 0.5300, 0.4422, 0.9739, 0.3584, 0.5158, 0.0156]]), 'b': tensor([0.5408, 0.1078, 0.5927, 0.0799, 0.6125, 0.4272]), 'c': {'d': tensor([0])}}] (<Tensor 0x7f9185d152e0> ├── 'a' --> tensor([[[0.3966, 0.8244, 0.6097, 0.9855, 0.3798, 0.1565, 0.0098, 0.6476], │ [0.7319, 0.6274, 0.9433, 0.5831, 0.9184, 0.4625, 0.5958, 0.0988], │ [0.2903, 0.5619, 0.5450, 0.3228, 0.9717, 0.0112, 0.0610, 0.6176]]]) ├── 'b' --> tensor([[1.0358, 1.1605, 1.6784, 1.0497, 1.4369, 1.9745]]) └── 'c' --> <Tensor 0x7f9185d15340> ├── 'd' --> 
tensor([[4.]]) └── 'noise' --> tensor([[[[ 0.2562, -0.6548, 0.5631, 0.7738, -0.4308], [-1.8951, -0.1646, -1.1050, 0.5682, -0.1726], [ 1.4759, 0.6835, 0.0414, -0.1966, -0.3878], [ 1.8053, 1.8906, -0.8909, -0.0159, -0.1881]], [[-0.2894, 0.3312, 0.8342, -0.5001, -1.4112], [ 1.2179, 0.0530, -1.9090, 1.1412, -0.5141], [-1.0095, 1.7890, -0.6283, -1.4054, 1.0486], [ 1.1588, -1.2241, -0.1511, 1.7479, -1.0159]], [[-0.5286, 0.0276, -0.1845, 0.7690, 0.3237], [ 1.4234, 0.9676, -0.7948, 0.2739, 1.2180], [ 0.1559, -1.5447, -1.1592, -1.3781, 0.4401], [ 0.1588, -0.9652, -0.1542, 0.1192, -1.5558]]]]) , <Tensor 0x7f9185d153a0> ├── 'a' --> tensor([[[0.3677, 0.3901, 0.2094, 0.3252, 0.2418, 0.2013, 0.3877, 0.8925], │ [0.9920, 0.1467, 0.1470, 0.9469, 0.8412, 0.6267, 0.9276, 0.3556], │ [0.9186, 0.9478, 0.0283, 0.2015, 0.2140, 0.2504, 0.2546, 0.5116]]]) ├── 'b' --> tensor([[1.6992, 1.1747, 1.0001, 1.3716, 1.0822, 1.5662]]) └── 'c' --> <Tensor 0x7f9185d15280> ├── 'd' --> tensor([[9.]]) └── 'noise' --> tensor([[[[ 1.2127e-01, -9.7338e-02, -1.0641e+00, -2.2300e-01, 1.2607e-01], [-1.4654e+00, 3.2072e-01, -1.4064e-01, -5.3356e-01, 4.7993e-01], [-9.7642e-03, 6.9599e-01, 6.9635e-01, -7.9874e-01, 2.9000e-01], [ 1.7661e+00, 6.4294e-01, 1.9043e-01, 8.6431e-01, 9.0994e-04]], [[ 1.1610e+00, 2.1799e-01, 8.5444e-01, -2.2679e+00, -6.6386e-02], [ 1.5189e+00, -1.9001e-01, -4.4042e-01, -2.0036e-01, -1.5791e+00], [ 7.7262e-01, 1.4848e+00, -1.5760e+00, 8.5489e-01, 3.5849e-01], [ 6.9080e-01, -1.8826e-02, -1.2874e+00, 6.2864e-01, -2.4863e+00]], [[-1.1041e+00, -1.6640e-01, -1.5837e+00, 5.4099e-02, 1.1398e-01], [ 2.6030e+00, 5.8517e-01, -1.0918e+00, 2.6184e+00, 3.9944e-01], [ 5.3750e-02, -1.2791e+00, 9.5119e-01, 6.2373e-01, 9.7606e-02], [-5.8499e-01, -6.5673e-01, 2.9786e-02, -6.3741e-01, -1.7399e-01]]]]) , <Tensor 0x7f9185d15400> ├── 'a' --> tensor([[[0.8125, 0.0259, 0.0690, 0.0772, 0.8263, 0.6004, 0.4874, 0.0964], │ [0.3606, 0.3776, 0.7780, 0.1085, 0.2932, 0.3797, 0.1329, 0.9739], │ [0.2839, 0.3744, 0.7001, 
0.6329, 0.1951, 0.7454, 0.9891, 0.1892]]]) ├── 'b' --> tensor([[1.5423, 1.2598, 1.2095, 1.9306, 1.1801, 1.0520]]) └── 'c' --> <Tensor 0x7f9185d153d0> ├── 'd' --> tensor([[0.]]) └── 'noise' --> tensor([[[[ 2.0106e+00, 1.1619e+00, -1.6692e+00, -1.4201e+00, 5.1294e-01], [ 2.4946e-01, -6.3604e-01, -2.1080e+00, 9.1723e-01, 3.5777e-02], [-5.1351e-01, 1.0697e+00, -1.3922e+00, -3.2787e-01, -5.1948e-01], [ 2.5014e-01, 4.4103e-01, 8.9880e-01, -1.9937e-01, 8.3412e-01]], [[-1.9194e+00, 8.0426e-01, -8.3870e-01, 1.5787e+00, 1.3566e-01], [ 2.5430e+00, -1.8173e+00, -7.5265e-02, -5.3250e-02, 6.2273e-01], [-3.4499e-01, -1.2787e+00, 7.6419e-01, -5.5824e-01, -7.7401e-01], [ 3.0583e-01, -1.8303e-01, 1.1664e+00, -2.1249e-03, 8.2907e-01]], [[ 2.7347e+00, 1.0071e+00, -9.2523e-01, -1.5009e+00, 2.5365e-01], [ 2.9202e-01, 6.2300e-01, -1.3104e+00, -1.4764e+00, -1.8070e+00], [ 1.6874e-01, -2.0066e+00, 4.6249e-01, -1.3472e+00, -5.5880e-01], [-2.7539e-01, 1.1854e-01, -1.3995e+00, -7.6103e-01, -2.8695e-01]]]]) , <Tensor 0x7f9185d15460> ├── 'a' --> tensor([[[0.1986, 0.1622, 0.4137, 0.1405, 0.9052, 0.2589, 0.3613, 0.9998], │ [0.1555, 0.0922, 0.4355, 0.9130, 0.4442, 0.6955, 0.8678, 0.7797], │ [0.6549, 0.6488, 0.5300, 0.4422, 0.9739, 0.3584, 0.5158, 0.0156]]]) ├── 'b' --> tensor([[1.2924, 1.0116, 1.3513, 1.0064, 1.3752, 1.1825]]) └── 'c' --> <Tensor 0x7f9185d15430> ├── 'd' --> tensor([[0.]]) └── 'noise' --> tensor([[[[ 1.4070, -0.6493, 0.1467, 1.6439, 0.5981], [ 1.0487, 1.4318, -1.0580, -0.5041, 0.4501], [-0.3174, 0.4318, 0.7212, 0.2747, 2.0658], [-2.0758, 0.9897, -0.5453, -0.2644, -0.8066]], [[ 0.1859, 0.4662, 0.9531, 0.5400, -0.5427], [ 0.1347, 0.6061, 0.1726, -0.6139, -1.7318], [ 0.1542, 1.4777, -0.5712, 1.3656, 0.4189], [-0.7131, -1.7418, -2.1632, 0.7465, -1.4006]], [[-1.2386, -1.2326, 1.6733, 0.2993, 0.1648], [ 1.0379, -0.4832, 0.2289, 0.6771, -0.3427], [-0.1972, 0.2718, 1.2705, -1.5191, 0.0776], [ 0.1047, -0.0416, -0.1009, -1.7804, 0.2607]]]]) ) mean0 & mean1: tensor(1.3177) tensor(1.3177) 
even_index_a0: tensor([[[0.3966, 0.8244, 0.6097, 0.9855, 0.3798, 0.1565, 0.0098, 0.6476], [0.2903, 0.5619, 0.5450, 0.3228, 0.9717, 0.0112, 0.0610, 0.6176]], [[0.3677, 0.3901, 0.2094, 0.3252, 0.2418, 0.2013, 0.3877, 0.8925], [0.9186, 0.9478, 0.0283, 0.2015, 0.2140, 0.2504, 0.2546, 0.5116]], [[0.8125, 0.0259, 0.0690, 0.0772, 0.8263, 0.6004, 0.4874, 0.0964], [0.2839, 0.3744, 0.7001, 0.6329, 0.1951, 0.7454, 0.9891, 0.1892]], [[0.1986, 0.1622, 0.4137, 0.1405, 0.9052, 0.2589, 0.3613, 0.9998], [0.6549, 0.6488, 0.5300, 0.4422, 0.9739, 0.3584, 0.5158, 0.0156]]]) even_index_a1: tensor([[[0.3966, 0.8244, 0.6097, 0.9855, 0.3798, 0.1565, 0.0098, 0.6476], [0.2903, 0.5619, 0.5450, 0.3228, 0.9717, 0.0112, 0.0610, 0.6176]], [[0.3677, 0.3901, 0.2094, 0.3252, 0.2418, 0.2013, 0.3877, 0.8925], [0.9186, 0.9478, 0.0283, 0.2015, 0.2140, 0.2504, 0.2546, 0.5116]], [[0.8125, 0.0259, 0.0690, 0.0772, 0.8263, 0.6004, 0.4874, 0.0964], [0.2839, 0.3744, 0.7001, 0.6329, 0.1951, 0.7454, 0.9891, 0.1892]], [[0.1986, 0.1622, 0.4137, 0.1405, 0.9052, 0.2589, 0.3613, 0.9998], [0.6549, 0.6488, 0.5300, 0.4422, 0.9739, 0.3584, 0.5158, 0.0156]]]) <Size 0x7f91eca957c0> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7f9185d10ee0> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.