Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy

import torch

import treetensor.torch as ttorch

T, B = 3, 4


def with_nativetensor(batch_):
    """Process a batch of nested dict samples using only the native torch API.

    Each sample is a mapping. The keys ``'a'``, ``'b'`` and ``'c'`` are
    handled explicitly; any other top-level key, and any key under ``'c'``
    other than ``'d'``, is reported with an ``ignore keys`` message.

    :param batch_: Sequence of dict samples holding torch tensors.
    :return: A tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is
        the input object itself, ``mean_b`` is the average over samples of
        ``mean(b ** 2 + 1)``, and ``even_index_a`` stacks, along a new
        leading dimension, every second entry (along the first dimension)
        of each sample's float-cast ``'a'``.
    """
    strided_a = []
    b_means = []
    for sample in batch_:
        for key, value in sample.items():
            if key == 'a':
                strided_a.append(value.float()[::2])
                continue
            if key == 'b':
                squared_plus_one = torch.pow(value.float(), 2) + 1.0
                b_means.append(squared_plus_one.mean())
                continue
            if key != 'c':
                print('ignore keys: {}'.format(key))
                continue
            for sub_key, sub_value in value.items():
                if sub_key != 'd':
                    print('ignore keys: {}'.format(sub_key))
                    continue
                # NOTE(review): the float cast is computed but never stored
                # back into the sample, so 'd' keeps its original dtype.
                sub_value.float()

    # NOTE(review): only a *top-level* 'd' entry receives noise here; if 'd'
    # lives under 'c' this never matches -- confirm whether that is intended.
    for sample in batch_:
        if 'd' in sample:
            sample['d']['noise'] = torch.randn(size=(3, 4, 5))

    even_index_a = torch.stack(strided_a, dim=0)
    mean_b = sum(b_means) / len(b_means)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Process a batch of nested dict samples using the treetensor API.

    The samples are converted to tree tensors, stacked into one batched
    tree, cast to float, and then transformed field by field: ``b`` becomes
    ``b ** 2 + 1`` and a random ``noise`` field is attached under ``c``.

    :param batch_: Iterable of nested dict samples holding torch tensors.
    :return: A tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is
        the result of splitting the processed tree along dim 0 into chunks
        of size 1, ``mean_b`` is the mean of the transformed ``b`` field,
        and ``even_index_a`` takes every second entry along dim 1 of the
        stacked ``a`` field.
    """
    trees = [ttorch.tensor(b) for b in batch_]
    # Size the noise from the actual number of samples instead of the
    # module-level ``B``; otherwise a batch of any other length would get a
    # noise tensor whose leading dimension disagrees with the other fields.
    batch_size = len(trees)
    batch_ = ttorch.stack(trees)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Create one random nested sample.

    :return: A dict with a float tensor ``'a'`` of shape ``(T, 8)``, a float
        tensor ``'b'`` of shape ``(6,)``, and a nested dict ``'c'`` whose
        ``'d'`` entry is a one-element integer tensor drawn from ``[0, 10)``.
    """
    # Tensors are drawn in the order a, b, d.
    field_a = torch.rand(size=(T, 8))
    field_b = torch.rand(size=(6,))
    field_d = torch.randint(0, 10, size=(1,))
    return {'a': field_a, 'b': field_b, 'c': {'d': field_d}}


if __name__ == "__main__":
    # Build B independent samples and run both implementations on separate
    # deep copies so neither can observe the other's in-place changes.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    # Both implementations must agree on the mean of the transformed 'b'.
    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take the absolute value element-wise *before* the max; taking the max
    # of the signed difference first would let large negative deviations
    # hide behind a zero or small positive entry.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    # Each implementation returns one entry per input sample.
    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following, and all the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
[{'a': tensor([[0.6796, 0.8915, 0.1787, 0.7100, 0.1271, 0.6469, 0.8311, 0.6974],
        [0.4011, 0.6651, 0.7264, 0.4408, 0.8056, 0.3359, 0.9172, 0.8944],
        [0.5826, 0.7113, 0.1235, 0.6089, 0.1268, 0.2385, 0.2617, 0.9369]]), 'b': tensor([0.8683, 0.6075, 0.1722, 0.4260, 0.3023, 0.8900]), 'c': {'d': tensor([8])}}, {'a': tensor([[0.3677, 0.1415, 0.0893, 0.2518, 0.4041, 0.0492, 0.1923, 0.9586],
        [0.7426, 0.5916, 0.2602, 0.1307, 0.2072, 0.5941, 0.3626, 0.4814],
        [0.1689, 0.1465, 0.3012, 0.5370, 0.9788, 0.1135, 0.1176, 0.7977]]), 'b': tensor([0.8620, 0.7494, 0.6395, 0.2251, 0.6765, 0.9543]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.2679, 0.1777, 0.1735, 0.3921, 0.2183, 0.4208, 0.0421, 0.6751],
        [0.9911, 0.3462, 0.8812, 0.8405, 0.3927, 0.9823, 0.9971, 0.8806],
        [0.7546, 0.0281, 0.2463, 0.4958, 0.7946, 0.8709, 0.1290, 0.1247]]), 'b': tensor([0.2163, 0.9633, 0.5466, 0.7943, 0.6659, 0.0497]), 'c': {'d': tensor([7])}}, {'a': tensor([[0.8043, 0.8778, 0.7910, 0.1306, 0.5225, 0.7919, 0.3902, 0.8837],
        [0.7651, 0.6421, 0.3727, 0.4784, 0.8094, 0.1772, 0.0986, 0.5673],
        [0.7576, 0.4917, 0.2139, 0.5751, 0.2011, 0.6180, 0.3734, 0.7506]]), 'b': tensor([0.2779, 0.9648, 0.7230, 0.2130, 0.3859, 0.4441]), 'c': {'d': tensor([4])}}]



(<Tensor 0x7f4a8d72bdc0>
├── 'a' --> tensor([[[0.6796, 0.8915, 0.1787, 0.7100, 0.1271, 0.6469, 0.8311, 0.6974],
│                    [0.4011, 0.6651, 0.7264, 0.4408, 0.8056, 0.3359, 0.9172, 0.8944],
│                    [0.5826, 0.7113, 0.1235, 0.6089, 0.1268, 0.2385, 0.2617, 0.9369]]])
├── 'b' --> tensor([[1.7539, 1.3691, 1.0296, 1.1814, 1.0914, 1.7922]])
└── 'c' --> <Tensor 0x7f4a8d72be20>
    ├── 'd' --> tensor([[8.]])
    └── 'noise' --> tensor([[[[-0.8381,  2.1321, -0.9444, -0.7191, -1.3926],
                              [-0.2115, -0.5253, -0.7440, -1.7407, -0.7213],
                              [ 0.2595, -1.0568, -0.2872,  0.2834,  0.8790],
                              [ 0.1232, -1.4918,  0.9498, -0.7962, -0.0331]],
                    
                             [[-0.0877, -1.0657,  0.5894, -0.8115, -1.1342],
                              [ 0.9418, -0.5228, -0.2236,  0.8553, -0.2079],
                              [ 0.8717,  0.0194, -0.3866, -1.3372,  0.9505],
                              [-0.5231, -0.1153,  1.2639,  1.8143, -2.3913]],
                    
                             [[-0.2092, -0.6510,  0.7318,  0.2611, -0.5800],
                              [ 0.2776, -0.1129,  1.3485, -0.8305, -0.3350],
                              [ 0.5349, -0.2376, -0.5073, -0.0369, -0.2382],
                              [ 1.6290, -0.2432,  0.3592, -0.7307, -1.1146]]]])
, <Tensor 0x7f4a8d72bee0>
├── 'a' --> tensor([[[0.3677, 0.1415, 0.0893, 0.2518, 0.4041, 0.0492, 0.1923, 0.9586],
│                    [0.7426, 0.5916, 0.2602, 0.1307, 0.2072, 0.5941, 0.3626, 0.4814],
│                    [0.1689, 0.1465, 0.3012, 0.5370, 0.9788, 0.1135, 0.1176, 0.7977]]])
├── 'b' --> tensor([[1.7430, 1.5616, 1.4090, 1.0507, 1.4577, 1.9107]])
└── 'c' --> <Tensor 0x7f4a8d72bd30>
    ├── 'd' --> tensor([[3.]])
    └── 'noise' --> tensor([[[[ 2.0865, -0.3595, -0.6910,  1.0244, -0.0228],
                              [ 0.9128,  0.4903,  0.1303, -1.5050,  1.0615],
                              [-0.0071, -0.2558,  0.1073,  0.5586,  0.6103],
                              [-0.3212,  1.2623,  1.6619,  0.6409,  0.5480]],
                    
                             [[-0.5059, -0.9238,  0.7028, -0.7919,  0.7945],
                              [ 1.2895,  1.2693, -0.0658, -0.2758, -0.1439],
                              [ 1.5293,  0.5293, -1.5634, -0.1730,  2.1593],
                              [ 1.8063,  1.2966, -1.2386, -0.2025,  1.0203]],
                    
                             [[-0.3223,  0.7071,  1.2362,  0.9568,  0.5885],
                              [ 1.3537,  0.0566,  1.3208,  0.4373,  1.7377],
                              [-0.9590,  0.1773, -0.3827, -0.3538, -2.9457],
                              [ 1.1134,  0.4801,  0.2854,  0.2599,  0.6335]]]])
, <Tensor 0x7f4a8d72bf40>
├── 'a' --> tensor([[[0.2679, 0.1777, 0.1735, 0.3921, 0.2183, 0.4208, 0.0421, 0.6751],
│                    [0.9911, 0.3462, 0.8812, 0.8405, 0.3927, 0.9823, 0.9971, 0.8806],
│                    [0.7546, 0.0281, 0.2463, 0.4958, 0.7946, 0.8709, 0.1290, 0.1247]]])
├── 'b' --> tensor([[1.0468, 1.9280, 1.2988, 1.6309, 1.4434, 1.0025]])
└── 'c' --> <Tensor 0x7f4a8d72bf10>
    ├── 'd' --> tensor([[7.]])
    └── 'noise' --> tensor([[[[ 0.0507,  0.8475,  1.9797,  1.5418,  0.4959],
                              [ 0.2704,  1.2632, -0.9198, -0.3680,  0.1904],
                              [ 0.4642,  0.6332,  1.3865, -0.9788, -1.0222],
                              [-0.9968, -0.8438,  1.1953, -0.3468,  1.8236]],
                    
                             [[-2.4549, -0.6488,  0.8985,  0.7699, -1.3897],
                              [-0.2723,  1.1534, -1.5565,  0.2130,  1.5684],
                              [-1.1921,  0.1810, -0.9803, -1.6161, -1.3742],
                              [-0.4538, -0.0770,  1.6496, -1.5832, -0.5101]],
                    
                             [[ 1.4290,  1.1963, -0.4897,  0.0123,  1.0885],
                              [-0.2512, -1.6909,  0.1550,  0.2726, -0.6209],
                              [ 2.0050,  0.6278,  0.3968,  0.5575,  0.4668],
                              [ 0.0547,  0.7338,  1.6606,  0.4362, -1.9252]]]])
, <Tensor 0x7f4a8d72bfa0>
├── 'a' --> tensor([[[0.8043, 0.8778, 0.7910, 0.1306, 0.5225, 0.7919, 0.3902, 0.8837],
│                    [0.7651, 0.6421, 0.3727, 0.4784, 0.8094, 0.1772, 0.0986, 0.5673],
│                    [0.7576, 0.4917, 0.2139, 0.5751, 0.2011, 0.6180, 0.3734, 0.7506]]])
├── 'b' --> tensor([[1.0772, 1.9309, 1.5227, 1.0454, 1.1489, 1.1972]])
└── 'c' --> <Tensor 0x7f4a8d72bf70>
    ├── 'd' --> tensor([[4.]])
    └── 'noise' --> tensor([[[[-8.6057e-01,  1.9909e+00, -1.1341e+00, -3.1219e-01, -1.0579e+00],
                              [ 4.2417e-01,  9.0933e-01,  1.5546e+00,  2.4609e+00, -2.7219e+00],
                              [ 3.5179e-01,  4.9726e-01,  6.9975e-01, -6.5159e-01, -1.1356e+00],
                              [ 2.8656e-01, -1.7430e-02,  1.6347e+00,  7.6030e-01, -8.6411e-01]],
                    
                             [[ 6.5482e-01, -3.3193e-01,  1.8156e+00, -1.2931e-01, -1.5559e+00],
                              [-3.2648e-01,  8.0178e-01, -3.0099e-01,  1.9234e+00,  5.6460e-01],
                              [-3.3813e-01, -1.6584e+00, -6.7409e-01,  4.0787e-01, -3.6855e-01],
                              [-1.2040e+00,  3.6019e-01, -6.1413e-01,  1.2220e+00, -9.5071e-01]],
                    
                             [[-6.4820e-01,  1.3130e+00, -3.1143e-01,  7.4179e-02,  9.6257e-01],
                              [ 1.7304e+00, -4.4796e-01,  9.9969e-01,  1.5349e-01,  3.0027e-02],
                              [ 7.0398e-01, -3.9019e-01,  1.1322e+00, -2.1520e+00,  4.9974e-01],
                              [ 1.4728e+00, -1.9006e-01, -5.1278e-01, -6.7268e-01,  1.4833e-03]]]])
)
mean0 & mean1: tensor(1.4009) tensor(1.4009)


even_index_a0: tensor([[[0.6796, 0.8915, 0.1787, 0.7100, 0.1271, 0.6469, 0.8311, 0.6974],
         [0.5826, 0.7113, 0.1235, 0.6089, 0.1268, 0.2385, 0.2617, 0.9369]],

        [[0.3677, 0.1415, 0.0893, 0.2518, 0.4041, 0.0492, 0.1923, 0.9586],
         [0.1689, 0.1465, 0.3012, 0.5370, 0.9788, 0.1135, 0.1176, 0.7977]],

        [[0.2679, 0.1777, 0.1735, 0.3921, 0.2183, 0.4208, 0.0421, 0.6751],
         [0.7546, 0.0281, 0.2463, 0.4958, 0.7946, 0.8709, 0.1290, 0.1247]],

        [[0.8043, 0.8778, 0.7910, 0.1306, 0.5225, 0.7919, 0.3902, 0.8837],
         [0.7576, 0.4917, 0.2139, 0.5751, 0.2011, 0.6180, 0.3734, 0.7506]]])
even_index_a1: tensor([[[0.6796, 0.8915, 0.1787, 0.7100, 0.1271, 0.6469, 0.8311, 0.6974],
         [0.5826, 0.7113, 0.1235, 0.6089, 0.1268, 0.2385, 0.2617, 0.9369]],

        [[0.3677, 0.1415, 0.0893, 0.2518, 0.4041, 0.0492, 0.1923, 0.9586],
         [0.1689, 0.1465, 0.3012, 0.5370, 0.9788, 0.1135, 0.1176, 0.7977]],

        [[0.2679, 0.1777, 0.1735, 0.3921, 0.2183, 0.4208, 0.0421, 0.6751],
         [0.7546, 0.0281, 0.2463, 0.4958, 0.7946, 0.8709, 0.1290, 0.1247]],

        [[0.8043, 0.8778, 0.7910, 0.1306, 0.5225, 0.7919, 0.3902, 0.8837],
         [0.7576, 0.4917, 0.2139, 0.5751, 0.2011, 0.6180, 0.3734, 0.7506]]])
<Size 0x7f4aed616460>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7f4a8d7a8460>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.