Customized Operations For Different Fields

Here is another example: the same custom operations implemented with both the native torch API and the treetensor API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy

import torch

import treetensor.torch as ttorch

# Demo dimensions: every sample's 'a' field has T rows, and the demo batch
# is built from B samples.
T = 3
B = 4


def with_nativetensor(batch_):
    """Apply the per-field operations to a batch using only the native torch API.

    Each element of ``batch_`` is a dict that is handled key by key:

    * ``'a'``: cast to float; every second slice along dim 0 is collected.
    * ``'b'``: cast to float; ``b ** 2 + 1`` is computed and its mean recorded.
    * ``'c'``: a nested dict; its ``'d'`` entry is cast to float and stored
      back, other nested keys are reported and skipped, and a fresh
      ``'noise'`` tensor of shape ``(3, 4, 5)`` is attached.
    * any other key is reported via ``print`` and skipped.

    :param batch_: sequence of sample dicts; it is modified in place.
    :return: tuple ``(batch_, mean_b, even_index_a)`` where ``mean_b`` is the
        average of the per-sample ``'b'`` means and ``even_index_a`` stacks the
        collected ``'a'`` slices along a new leading dimension.
    :raises ZeroDivisionError: if no sample provides a ``'b'`` field (this
        includes an empty batch).
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for key, value in sample.items():
            if key == 'a':
                even_index_a_list.append(value.float()[::2])
            elif key == 'b':
                transformed = torch.pow(value.float(), 2) + 1.0
                mean_b_list.append(transformed.mean())
            elif key == 'c':
                for sub_key, sub_value in value.items():
                    if sub_key == 'd':
                        # Store the cast back; rebinding only the loop variable
                        # would leave the sample untouched.
                        value[sub_key] = sub_value.float()
                    else:
                        print('ignore keys: {}'.format(sub_key))
            else:
                print('ignore keys: {}'.format(key))

    # Noise lives inside the nested 'c' dict. It is added in a separate pass so
    # that no dict changes size while it is being iterated above.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the per-field operations to a batch using the treetensor API.

    The samples are converted to tree tensors, stacked into a single tree,
    cast to float, and then edited field by field: ``b`` becomes
    ``b ** 2 + 1`` and a random ``c.noise`` leaf is attached. The stacked tree
    is finally split back into chunks of size 1 along dim 0.

    :param batch_: iterable of nested sample dicts accepted by
        ``ttorch.tensor``.
    :return: tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        result of ``ttorch.split``, ``mean_b`` is the mean of the transformed
        ``b`` field, and ``even_index_a`` is ``a[:, ::2]`` of the stacked
        tree.
    """
    trees = [ttorch.tensor(sample) for sample in batch_]
    # Size the noise from the batch actually received rather than from a
    # module-level constant, so its leading dimension always matches the
    # stacked leaves.
    batch_size = len(trees)
    stacked = ttorch.stack(trees).float()
    stacked.b = ttorch.pow(stacked.b, 2) + 1.0
    stacked.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = stacked.b.mean()
    # Dim 0 is the stacked batch axis; take every second slice along dim 1.
    even_index_a = stacked.a[:, ::2]
    pieces = ttorch.split(stacked, split_size_or_sections=1, dim=0)
    return pieces, mean_b, even_index_a


def get_data():
    """Build one random sample as a nested dict.

    :return: dict with ``'a'`` (uniform random, shape ``(T, 8)``), ``'b'``
        (uniform random, shape ``(6,)``) and ``'c'``, a nested dict whose
        ``'d'`` entry holds one random integer drawn from ``[0, 10)``.
    """
    sample = {}
    sample['a'] = torch.rand(size=(T, 8))
    sample['b'] = torch.rand(size=(6,))
    sample['c'] = {'d': torch.randint(0, 10, size=(1,))}
    return sample


if __name__ == "__main__":
    # Build one batch and feed an independent deep copy to each implementation
    # so that neither can observe the other's modifications.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    # Both implementations must agree on the mean of the transformed 'b'.
    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Check the largest absolute element-wise difference. Taking abs() after
    # max() would let negative deviations slip through whenever the largest
    # signed difference happens to be close to zero.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    # Each implementation returns one entry per input sample.
    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following (the exact values differ from run to run because the data is random), and all the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
[{'a': tensor([[0.1596, 0.4012, 0.3223, 0.4852, 0.8140, 0.1468, 0.5699, 0.1797],
        [0.2723, 0.2857, 0.5886, 0.8552, 0.5341, 0.3000, 0.0769, 0.0941],
        [0.1838, 0.3047, 0.4412, 0.8800, 0.8474, 0.5083, 0.8790, 0.9913]]), 'b': tensor([0.0547, 0.4228, 0.6144, 0.0303, 0.6479, 0.0555]), 'c': {'d': tensor([0])}}, {'a': tensor([[0.8184, 0.3520, 0.2519, 0.8049, 0.0454, 0.2602, 0.7012, 0.7521],
        [0.4884, 0.2814, 0.5682, 0.4806, 0.9580, 0.2219, 0.3575, 0.6630],
        [0.3724, 0.5604, 0.3354, 0.7325, 0.1509, 0.0982, 0.6868, 0.3617]]), 'b': tensor([0.3862, 0.3536, 0.8789, 0.2303, 0.2371, 0.9049]), 'c': {'d': tensor([4])}}, {'a': tensor([[0.9463, 0.1392, 0.6163, 0.9218, 0.6002, 0.8881, 0.8347, 0.1492],
        [0.0390, 0.6043, 0.2513, 0.3038, 0.1501, 0.3200, 0.7904, 0.9806],
        [0.7639, 0.5355, 0.9951, 0.9049, 0.7562, 0.7972, 0.0351, 0.0165]]), 'b': tensor([0.8851, 0.0581, 0.6590, 0.7748, 0.7943, 0.3452]), 'c': {'d': tensor([2])}}, {'a': tensor([[0.1397, 0.5560, 0.1809, 0.6830, 0.0409, 0.5592, 0.4417, 0.1280],
        [0.0140, 0.6646, 0.3147, 0.2864, 0.8664, 0.7113, 0.1141, 0.5964],
        [0.1541, 0.1119, 0.9381, 0.8998, 0.6656, 0.8518, 0.7366, 0.2615]]), 'b': tensor([0.7784, 0.7635, 0.8329, 0.2074, 0.4004, 0.3969]), 'c': {'d': tensor([3])}}]



(<Tensor 0x7f632b0f4dc0>
├── 'a' --> tensor([[[0.1596, 0.4012, 0.3223, 0.4852, 0.8140, 0.1468, 0.5699, 0.1797],
│                    [0.2723, 0.2857, 0.5886, 0.8552, 0.5341, 0.3000, 0.0769, 0.0941],
│                    [0.1838, 0.3047, 0.4412, 0.8800, 0.8474, 0.5083, 0.8790, 0.9913]]])
├── 'b' --> tensor([[1.0030, 1.1788, 1.3775, 1.0009, 1.4197, 1.0031]])
└── 'c' --> <Tensor 0x7f632b0f4e20>
    ├── 'd' --> tensor([[0.]])
    └── 'noise' --> tensor([[[[-1.6472, -2.1012,  0.5953,  1.2593,  0.4807],
                              [-1.7966, -1.3679,  0.2307,  1.3067,  0.6097],
                              [ 0.1925, -1.0879, -0.1310,  0.7680, -0.3777],
                              [-0.0476, -1.5376,  1.2984,  1.0988, -0.1543]],
                    
                             [[-0.4359, -1.0796,  0.2791,  1.5900, -0.2499],
                              [ 0.5715, -0.5642, -1.2343,  0.7848, -0.9442],
                              [ 0.2103,  0.3017,  0.0613, -0.7353, -0.5311],
                              [ 0.0271,  0.6253, -0.0628, -0.5104, -0.3841]],
                    
                             [[ 0.0466, -2.4251, -1.4401,  1.4015, -1.7317],
                              [ 0.3690, -0.1695, -0.1779,  0.3515,  1.5173],
                              [ 1.2173,  1.1610,  0.7979, -0.4319,  0.0112],
                              [-1.6619,  1.0395, -0.3872, -0.9098,  2.0577]]]])
, <Tensor 0x7f632b0f4ee0>
├── 'a' --> tensor([[[0.8184, 0.3520, 0.2519, 0.8049, 0.0454, 0.2602, 0.7012, 0.7521],
│                    [0.4884, 0.2814, 0.5682, 0.4806, 0.9580, 0.2219, 0.3575, 0.6630],
│                    [0.3724, 0.5604, 0.3354, 0.7325, 0.1509, 0.0982, 0.6868, 0.3617]]])
├── 'b' --> tensor([[1.1492, 1.1250, 1.7725, 1.0531, 1.0562, 1.8188]])
└── 'c' --> <Tensor 0x7f632b0f4d30>
    ├── 'd' --> tensor([[4.]])
    └── 'noise' --> tensor([[[[ 0.7921, -0.0755, -0.8151,  0.8850,  1.4933],
                              [ 1.2783, -0.4826,  1.3904, -1.5220, -0.8982],
                              [-1.1363, -0.3053,  0.0885,  0.2277, -0.7028],
                              [-0.9773, -0.0277,  0.6444, -1.2285, -0.2365]],
                    
                             [[-0.3434, -0.8346,  0.0528,  0.0435, -0.5553],
                              [ 0.7556, -1.4243, -0.5564,  0.2192, -0.3744],
                              [ 0.2173, -0.0042,  0.3533, -2.2362, -0.4302],
                              [-0.0620, -0.8077, -1.3023, -1.5293, -0.3642]],
                    
                             [[ 0.9234,  0.6934,  0.3226,  0.6812,  0.4780],
                              [ 0.7483,  0.9376, -0.7160, -0.8707,  2.6354],
                              [ 1.7357, -0.3347,  1.0475,  1.2589, -1.0588],
                              [-1.0726, -2.3005,  0.7702,  1.2074, -0.6260]]]])
, <Tensor 0x7f632b0f4f40>
├── 'a' --> tensor([[[0.9463, 0.1392, 0.6163, 0.9218, 0.6002, 0.8881, 0.8347, 0.1492],
│                    [0.0390, 0.6043, 0.2513, 0.3038, 0.1501, 0.3200, 0.7904, 0.9806],
│                    [0.7639, 0.5355, 0.9951, 0.9049, 0.7562, 0.7972, 0.0351, 0.0165]]])
├── 'b' --> tensor([[1.7834, 1.0034, 1.4343, 1.6003, 1.6309, 1.1191]])
└── 'c' --> <Tensor 0x7f632b0f4f10>
    ├── 'd' --> tensor([[2.]])
    └── 'noise' --> tensor([[[[ 0.9213,  0.0383,  0.2955,  1.3049, -1.9328],
                              [-1.0894, -0.5289, -0.2533, -1.6255, -0.8783],
                              [-0.4002,  0.6349,  2.1597, -0.0473, -0.2817],
                              [-0.0421,  0.3092,  0.9032,  0.8430,  0.7661]],
                    
                             [[ 0.1034,  1.1913,  0.2249,  0.5569, -0.0987],
                              [ 0.3374,  0.1335,  1.9032, -1.1055, -0.1811],
                              [-0.7397,  0.4564, -0.3752,  0.7368,  0.5090],
                              [-1.0403, -1.3333, -0.0523,  1.1758,  0.2347]],
                    
                             [[ 0.4211, -1.6655, -0.4795,  0.4314,  0.8669],
                              [ 0.2273,  0.3823,  1.9700,  0.5738,  1.8101],
                              [ 2.1773, -0.0461,  0.3586, -1.5405, -0.4960],
                              [-0.6693, -0.3611,  0.0776, -0.8386,  0.3934]]]])
, <Tensor 0x7f632b0f4fa0>
├── 'a' --> tensor([[[0.1397, 0.5560, 0.1809, 0.6830, 0.0409, 0.5592, 0.4417, 0.1280],
│                    [0.0140, 0.6646, 0.3147, 0.2864, 0.8664, 0.7113, 0.1141, 0.5964],
│                    [0.1541, 0.1119, 0.9381, 0.8998, 0.6656, 0.8518, 0.7366, 0.2615]]])
├── 'b' --> tensor([[1.6059, 1.5829, 1.6938, 1.0430, 1.1603, 1.1575]])
└── 'c' --> <Tensor 0x7f632b0f4f70>
    ├── 'd' --> tensor([[3.]])
    └── 'noise' --> tensor([[[[-0.2365, -1.1637, -0.8815, -0.7596,  0.3517],
                              [-1.4273,  0.2471,  0.7383, -1.1357, -1.0635],
                              [-0.5137,  1.2854,  0.9059,  1.2296,  0.4843],
                              [-0.3079,  1.8051, -0.5554, -1.5409,  0.5662]],
                    
                             [[ 0.0711, -0.6803, -0.2576, -0.5582, -0.1279],
                              [ 0.0366,  0.9816, -0.6424,  1.1566, -0.1251],
                              [ 0.3017,  1.1231,  0.1248, -0.3778,  1.1984],
                              [-1.0851, -1.1335, -0.3563, -0.4848, -1.0288]],
                    
                             [[-0.4021,  0.5456, -0.7004,  1.3560, -0.4658],
                              [ 0.2582, -0.5490,  1.0653,  0.3482,  0.9825],
                              [ 2.2991, -0.6494,  0.6641,  1.0891, -1.8204],
                              [-0.5928,  0.7025,  1.4344,  0.4906,  2.3079]]]])
)
mean0 & mean1: tensor(1.3239) tensor(1.3239)


even_index_a0: tensor([[[0.1596, 0.4012, 0.3223, 0.4852, 0.8140, 0.1468, 0.5699, 0.1797],
         [0.1838, 0.3047, 0.4412, 0.8800, 0.8474, 0.5083, 0.8790, 0.9913]],

        [[0.8184, 0.3520, 0.2519, 0.8049, 0.0454, 0.2602, 0.7012, 0.7521],
         [0.3724, 0.5604, 0.3354, 0.7325, 0.1509, 0.0982, 0.6868, 0.3617]],

        [[0.9463, 0.1392, 0.6163, 0.9218, 0.6002, 0.8881, 0.8347, 0.1492],
         [0.7639, 0.5355, 0.9951, 0.9049, 0.7562, 0.7972, 0.0351, 0.0165]],

        [[0.1397, 0.5560, 0.1809, 0.6830, 0.0409, 0.5592, 0.4417, 0.1280],
         [0.1541, 0.1119, 0.9381, 0.8998, 0.6656, 0.8518, 0.7366, 0.2615]]])
even_index_a1: tensor([[[0.1596, 0.4012, 0.3223, 0.4852, 0.8140, 0.1468, 0.5699, 0.1797],
         [0.1838, 0.3047, 0.4412, 0.8800, 0.8474, 0.5083, 0.8790, 0.9913]],

        [[0.8184, 0.3520, 0.2519, 0.8049, 0.0454, 0.2602, 0.7012, 0.7521],
         [0.3724, 0.5604, 0.3354, 0.7325, 0.1509, 0.0982, 0.6868, 0.3617]],

        [[0.9463, 0.1392, 0.6163, 0.9218, 0.6002, 0.8881, 0.8347, 0.1492],
         [0.7639, 0.5355, 0.9951, 0.9049, 0.7562, 0.7972, 0.0351, 0.0165]],

        [[0.1397, 0.5560, 0.1809, 0.6830, 0.0409, 0.5592, 0.4417, 0.1280],
         [0.1541, 0.1119, 0.9381, 0.8998, 0.6656, 0.8518, 0.7366, 0.2615]]])
<Size 0x7f638affb460>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7f632b171460>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.