Customized Operations for Different Fields
Here is another example of custom operations, implemented with both the native torch API and the treetensor API.
import copy

import torch
import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Apply the per-field operations using plain torch and Python loops.

    For every sample dict in ``batch_`` this casts ``a`` and ``c.d`` to float,
    replaces ``b`` with ``b ** 2 + 1`` and attaches a random ``c.noise``
    tensor, writing the results back into ``batch_``. Unknown keys are
    reported with a print and left untouched.

    Returns:
        (batch_, mean_b, even_index_a): the mutated list of samples, the mean
        of the transformed ``b`` values, and the even-indexed rows of each
        ``a`` stacked along a new leading batch dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                sample[k] = v  # write back so the returned batch holds floats
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))
    # 'noise' belongs under the nested 'c' dict (mirroring ``batch_.c.noise``
    # in the treetensor version). It is added in a separate pass because
    # inserting a new key while iterating over ``v.items()`` above would raise
    # "dictionary changed size during iteration".
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))
    # every ``b`` has the same length, so the mean of per-sample means equals
    # the mean over all elements
    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field operations using the treetensor API.

    Returns:
        (batch_, mean_b, even_index_a): the processed batch split back into
        per-sample tree tensors (each keeping a leading dimension of size 1),
        the mean of the transformed ``b`` values, and the even-indexed rows
        of ``a`` for every sample.
    """
    batch_size = len(batch_)  # derive from the input rather than global B
    batch_ = ttorch.stack([ttorch.tensor(b) for b in batch_])
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample: ``a`` is (T, 8), ``b`` is (6,), ``c.d`` is (1,)."""
    return {
        'a': torch.rand(size=(T, 8)),
        'b': torch.rand(size=(6,)),
        'c': {
            'd': torch.randint(0, 10, size=(1,))
        }
    }


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # abs() must come before max(): the max of a signed difference can hide
    # arbitrarily large negative deviations
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)
The output should look similar to the following (the values are random), and all of the assertions should pass.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | [{'a': tensor([[0.7766, 0.5611, 0.2536, 0.6818, 0.7486, 0.6041, 0.5483, 0.7180], [0.3113, 0.9901, 0.6145, 0.2825, 0.9314, 0.7540, 0.6752, 0.8161], [0.8894, 0.4415, 0.1984, 0.4616, 0.3255, 0.2886, 0.1082, 0.0120]]), 'b': tensor([0.3847, 0.7117, 0.6698, 0.3810, 0.6242, 0.7978]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.0242, 0.0353, 0.8126, 0.3531, 0.3268, 0.9762, 0.3606, 0.1387], [0.2568, 0.1814, 0.0016, 0.4845, 0.3869, 0.1839, 0.6877, 0.6100], [0.8865, 0.5986, 0.5261, 0.2205, 0.6694, 0.3079, 0.6498, 0.8052]]), 'b': tensor([0.4525, 0.2294, 0.6572, 0.5407, 0.7217, 0.8113]), 'c': {'d': tensor([5])}}, {'a': tensor([[0.5047, 0.1707, 0.0935, 0.5198, 0.2858, 0.7896, 0.7454, 0.7980], [0.9075, 0.3477, 0.5370, 0.2072, 0.1164, 0.2555, 0.4791, 0.8106], [0.1644, 0.4568, 0.5529, 0.6836, 0.5587, 0.8452, 0.2835, 0.3440]]), 'b': tensor([0.6000, 0.5349, 0.8243, 0.1235, 0.2358, 0.3924]), 'c': {'d': tensor([4])}}, {'a': tensor([[0.0312, 0.3731, 0.5947, 0.8834, 0.8944, 0.6921, 0.2843, 0.8457], [0.9184, 0.3797, 0.0462, 0.6986, 0.5699, 0.3366, 0.3529, 0.7713], [0.6949, 0.3342, 0.0112, 0.2913, 0.7899, 0.0386, 0.2649, 0.3976]]), 'b': tensor([0.4561, 0.7853, 0.4153, 0.6906, 0.7903, 0.4928]), 'c': {'d': tensor([2])}}] (<Tensor 0x7f91ea905ca0> ├── 'a' --> tensor([[[0.7766, 0.5611, 0.2536, 0.6818, 0.7486, 0.6041, 0.5483, 0.7180], │ [0.3113, 0.9901, 0.6145, 0.2825, 0.9314, 0.7540, 0.6752, 0.8161], │ [0.8894, 0.4415, 0.1984, 0.4616, 0.3255, 0.2886, 0.1082, 0.0120]]]) ├── 'b' --> tensor([[1.1480, 1.5065, 1.4486, 1.1452, 1.3896, 1.6366]]) └── 'c' --> <Tensor 0x7f91ea905d00> ├── 'd' --> 
tensor([[1.]]) └── 'noise' --> tensor([[[[ 2.5136, -0.1081, 0.8018, 1.5091, -1.0896], [ 2.1947, 0.9641, 0.3989, -0.5170, 0.0855], [ 0.8565, -0.6989, 0.2717, 0.8485, -0.0664], [-0.3098, -0.9206, 0.5248, 0.0264, 1.5740]], [[-1.4064, -0.4015, 0.1006, 0.0507, 1.5793], [ 0.9215, 1.6179, -1.5108, 1.2909, -0.2709], [-0.0231, 0.3169, -0.6971, -0.8785, 0.0584], [ 0.8124, 0.2928, 1.4596, 0.6832, -1.9324]], [[-1.5192, -1.3798, -1.1759, 1.1067, -1.4371], [-1.3846, 0.1428, 0.8906, -0.2528, -0.4610], [ 1.3863, 0.4864, 0.2585, 0.2342, -0.0341], [-0.5841, -0.9200, 0.1922, -0.6996, 0.9480]]]]) , <Tensor 0x7f91ea905dc0> ├── 'a' --> tensor([[[0.0242, 0.0353, 0.8126, 0.3531, 0.3268, 0.9762, 0.3606, 0.1387], │ [0.2568, 0.1814, 0.0016, 0.4845, 0.3869, 0.1839, 0.6877, 0.6100], │ [0.8865, 0.5986, 0.5261, 0.2205, 0.6694, 0.3079, 0.6498, 0.8052]]]) ├── 'b' --> tensor([[1.2048, 1.0526, 1.4320, 1.2923, 1.5209, 1.6582]]) └── 'c' --> <Tensor 0x7f91ea905be0> ├── 'd' --> tensor([[5.]]) └── 'noise' --> tensor([[[[-0.1860, 0.5182, 0.6956, 0.4537, -0.2303], [-0.5841, 2.5433, -0.7179, 0.5835, -0.9927], [ 0.2585, 0.1612, 1.5680, 0.4102, -2.1762], [-1.5873, 0.7842, 0.5519, 0.4897, -0.9798]], [[ 0.2232, 1.1629, 0.8133, -1.4461, -0.0508], [-1.0642, -2.1902, -0.5686, -0.2585, 0.1476], [-1.2493, 0.6766, -1.0951, 0.9376, 1.7562], [-0.0064, 0.4575, -1.1243, -0.0211, 0.2175]], [[-0.0507, -0.0816, -2.4400, -2.5221, -0.2645], [-0.4867, 0.5980, -0.2298, -1.0104, -0.3548], [ 1.5903, 0.2890, 0.5077, 0.7891, -0.8798], [ 1.2358, 1.6742, 0.8530, 0.4171, -1.4960]]]]) , <Tensor 0x7f91ea905e20> ├── 'a' --> tensor([[[0.5047, 0.1707, 0.0935, 0.5198, 0.2858, 0.7896, 0.7454, 0.7980], │ [0.9075, 0.3477, 0.5370, 0.2072, 0.1164, 0.2555, 0.4791, 0.8106], │ [0.1644, 0.4568, 0.5529, 0.6836, 0.5587, 0.8452, 0.2835, 0.3440]]]) ├── 'b' --> tensor([[1.3600, 1.2861, 1.6795, 1.0152, 1.0556, 1.1540]]) └── 'c' --> <Tensor 0x7f91ea905df0> ├── 'd' --> tensor([[4.]]) └── 'noise' --> tensor([[[[-0.2597, 0.5502, -0.3550, 1.2803, -0.4777], 
[-1.9271, -0.0587, 0.9955, -0.3581, -1.3456], [-0.5616, 0.5569, 0.0172, -0.4264, 0.2685], [-0.7921, -0.3920, 0.7851, 0.7612, 0.3785]], [[-0.2208, -0.9577, 0.4848, -0.4689, -1.2425], [ 0.9249, 1.6592, 0.8347, -1.2518, -0.1103], [-1.1229, 0.4504, -0.1091, -1.8570, 0.5973], [ 0.8977, 0.5524, 0.8264, -1.1807, -0.8822]], [[-1.0476, -0.4988, 1.3481, 0.3397, -1.0965], [ 0.4100, -0.3561, 0.3236, -0.7633, 0.6493], [ 0.4271, -0.2514, -0.5729, 0.6740, -0.5706], [ 0.7090, -0.2594, 0.6588, 2.5359, 0.3798]]]]) , <Tensor 0x7f91ea905e80> ├── 'a' --> tensor([[[0.0312, 0.3731, 0.5947, 0.8834, 0.8944, 0.6921, 0.2843, 0.8457], │ [0.9184, 0.3797, 0.0462, 0.6986, 0.5699, 0.3366, 0.3529, 0.7713], │ [0.6949, 0.3342, 0.0112, 0.2913, 0.7899, 0.0386, 0.2649, 0.3976]]]) ├── 'b' --> tensor([[1.2080, 1.6168, 1.1724, 1.4769, 1.6246, 1.2428]]) └── 'c' --> <Tensor 0x7f91ea905e50> ├── 'd' --> tensor([[2.]]) └── 'noise' --> tensor([[[[-2.2400, -1.6115, -1.4326, 1.6931, -1.3931], [ 1.6587, -0.2218, 1.1462, -0.4680, -0.0203], [ 0.1144, -0.9507, 0.6894, 1.0788, 0.5769], [-0.2663, 0.9609, 0.6088, -0.5087, 0.6183]], [[-0.2387, -0.7638, 0.3636, -0.7947, 0.1206], [ 0.1529, 0.6614, -0.0161, -0.2323, -0.2046], [-0.1206, -0.8588, 1.2333, 1.5618, -0.7350], [ 1.3395, -0.9671, -0.7410, 0.7517, -0.1983]], [[ 1.4318, -0.2031, -1.7052, -0.7655, 0.3964], [-0.1339, -0.4975, 0.3394, -0.6810, 0.6485], [-0.6155, -0.6125, 1.0350, 0.2076, -0.1171], [-0.0900, 1.0732, -0.0580, -1.2796, 0.3661]]]]) ) mean0 & mean1: tensor(1.3470) tensor(1.3470) even_index_a0: tensor([[[0.7766, 0.5611, 0.2536, 0.6818, 0.7486, 0.6041, 0.5483, 0.7180], [0.8894, 0.4415, 0.1984, 0.4616, 0.3255, 0.2886, 0.1082, 0.0120]], [[0.0242, 0.0353, 0.8126, 0.3531, 0.3268, 0.9762, 0.3606, 0.1387], [0.8865, 0.5986, 0.5261, 0.2205, 0.6694, 0.3079, 0.6498, 0.8052]], [[0.5047, 0.1707, 0.0935, 0.5198, 0.2858, 0.7896, 0.7454, 0.7980], [0.1644, 0.4568, 0.5529, 0.6836, 0.5587, 0.8452, 0.2835, 0.3440]], [[0.0312, 0.3731, 0.5947, 0.8834, 0.8944, 0.6921, 0.2843, 
0.8457], [0.6949, 0.3342, 0.0112, 0.2913, 0.7899, 0.0386, 0.2649, 0.3976]]]) even_index_a1: tensor([[[0.7766, 0.5611, 0.2536, 0.6818, 0.7486, 0.6041, 0.5483, 0.7180], [0.8894, 0.4415, 0.1984, 0.4616, 0.3255, 0.2886, 0.1082, 0.0120]], [[0.0242, 0.0353, 0.8126, 0.3531, 0.3268, 0.9762, 0.3606, 0.1387], [0.8865, 0.5986, 0.5261, 0.2205, 0.6694, 0.3079, 0.6498, 0.8052]], [[0.5047, 0.1707, 0.0935, 0.5198, 0.2858, 0.7896, 0.7454, 0.7980], [0.1644, 0.4568, 0.5529, 0.6836, 0.5587, 0.8452, 0.2835, 0.3440]], [[0.0312, 0.3731, 0.5947, 0.8834, 0.8944, 0.6921, 0.2843, 0.8457], [0.6949, 0.3342, 0.0112, 0.2913, 0.7899, 0.0386, 0.2649, 0.3976]]]) <Size 0x7f92507ee430> ├── 'a' --> torch.Size([1, 3, 8]) ├── 'b' --> torch.Size([1, 6]) └── 'c' --> <Size 0x7f91ea9930d0> ├── 'd' --> torch.Size([1, 1]) └── 'noise' --> torch.Size([1, 3, 4, 5]) |
The implementation with the treetensor API is much simpler and clearer.