Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy

import torch

import treetensor.torch as ttorch

# Sizes shared by the example: ``T`` is the leading dimension of field 'a'
# in every generated sample, ``B`` is the number of samples in the batch.
T = 3
B = 4


def with_nativetensor(batch_):
    """Apply the per-field operations with plain dicts and native torch calls.

    Each sample in ``batch_`` is a dict and is modified in place:

    * ``'a'`` is converted to float; every second entry along its first
      dimension is collected.
    * ``'b'`` is converted to float and replaced by ``b ** 2 + 1``; its mean
      is collected.
    * ``'c'`` is a nested dict: its ``'d'`` entry is converted to float and a
      fresh ``'noise'`` tensor of shape ``(3, 4, 5)`` is attached. Other
      nested keys are reported and left untouched.
    * Any other top-level key is reported and left untouched.

    :param batch_: Sequence of sample dicts (mutated in place).
    :return: Tuple ``(batch_, mean_b, even_index_a)`` where ``mean_b`` is the
        average of the per-sample means of the transformed ``'b'`` and
        ``even_index_a`` stacks the collected ``'a'`` slices along a new
        leading dimension.
    :raises ZeroDivisionError: If no sample contains a ``'b'`` field.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Iterate over a snapshot of the items because the values are written
        # back into ``sample`` while looping.
        for k, v in list(sample.items()):
            if k == 'a':
                v = v.float()
                sample[k] = v  # persist the conversion, not just the local name
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in list(v.items()):
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
                # The noise belongs next to 'd' inside the nested 'c' dict;
                # 'd' is never a top-level key, so it must be attached here.
                v['noise'] = torch.randn(size=(3, 4, 5))
            else:
                print('ignore keys: {}'.format(k))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field operations using the treetensor API.

    The samples are wrapped with ``ttorch.tensor``, stacked into one tree,
    converted to float, and then processed field by field: ``b`` is replaced
    by ``b ** 2 + 1`` and a random ``noise`` entry is attached under ``c``.
    Finally the stacked tree is split back along dimension 0 into pieces of
    size 1.

    :param batch_: Iterable of nested sample dicts with fields ``a``, ``b``
        and ``c``.
    :return: Tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        result of ``ttorch.split``, ``mean_b`` is the mean of the transformed
        ``b`` field and ``even_index_a`` is ``a[:, ::2]`` of the stacked tree.
    """
    batch_ = [ttorch.tensor(b) for b in batch_]
    # Size the noise from the actual input rather than the module-level ``B``
    # so its leading dimension always matches the stacked batch.
    batch_size = len(batch_)
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    # Dimension 0 is the stacked batch dimension, so ``::2`` is applied to
    # the first dimension of each original sample's ``a``.
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Create one random sample as a nested dict.

    :return: Dict with ``'a'`` (uniform random, shape ``(T, 8)``), ``'b'``
        (uniform random, shape ``(6,)``) and ``'c'``, a nested dict whose
        ``'d'`` entry is a random integer tensor in ``[0, 10)`` of shape
        ``(1,)``.
    """
    sample = {}
    sample['a'] = torch.rand(size=(T, 8))
    sample['b'] = torch.rand(size=(6,))
    nested = {}
    nested['d'] = torch.randint(0, 10, size=(1,))
    sample['c'] = nested
    return sample


if __name__ == "__main__":
    # Build B random samples; each implementation gets its own deep copy so
    # neither can observe the other's modifications.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    # Both implementations must agree on the mean of the transformed 'b'.
    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare the largest absolute element-wise difference. Taking abs()
    # after max() would let large negative differences pass unnoticed.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    # Both results must still contain one entry per input sample.
    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following, and all the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
[{'a': tensor([[0.7952, 0.5317, 0.9969, 0.7247, 0.9008, 0.8212, 0.3001, 0.8340],
        [0.0502, 0.1068, 0.0729, 0.3611, 0.7296, 0.5931, 0.6764, 0.2331],
        [0.4727, 0.5710, 0.9302, 0.8417, 0.7715, 0.8744, 0.8541, 0.4062]]), 'b': tensor([0.9890, 0.8908, 0.2642, 0.1434, 0.1721, 0.9819]), 'c': {'d': tensor([4])}}, {'a': tensor([[0.1144, 0.8584, 0.1021, 0.4512, 0.2160, 0.4735, 0.1162, 0.4354],
        [0.2487, 0.6985, 0.9192, 0.3302, 0.8952, 0.1003, 0.1221, 0.0799],
        [0.8819, 0.8667, 0.0678, 0.5979, 0.7199, 0.5001, 0.9542, 0.4657]]), 'b': tensor([0.4046, 0.8683, 0.6061, 0.8716, 0.4036, 0.3148]), 'c': {'d': tensor([8])}}, {'a': tensor([[0.3129, 0.4439, 0.8225, 0.1990, 0.4839, 0.6035, 0.6834, 0.5054],
        [0.3481, 0.6482, 0.3350, 0.7787, 0.2522, 0.7648, 0.1572, 0.7762],
        [0.5089, 0.3098, 0.3571, 0.4281, 0.0979, 0.2945, 0.0717, 0.9534]]), 'b': tensor([0.6999, 0.9815, 0.6931, 0.8468, 0.4924, 0.6062]), 'c': {'d': tensor([7])}}, {'a': tensor([[0.5043, 0.1589, 0.5069, 0.4711, 0.3651, 0.7958, 0.4787, 0.3208],
        [0.4928, 0.9831, 0.0118, 0.6966, 0.4535, 0.9948, 0.7176, 0.1944],
        [0.0527, 0.1514, 0.3471, 0.2261, 0.1730, 0.7510, 0.1243, 0.8186]]), 'b': tensor([0.9240, 0.4537, 0.9679, 0.1696, 0.8909, 0.2051]), 'c': {'d': tensor([8])}}]



(<Tensor 0x7f83ecb97ac0>
├── 'a' --> tensor([[[0.7952, 0.5317, 0.9969, 0.7247, 0.9008, 0.8212, 0.3001, 0.8340],
│                    [0.0502, 0.1068, 0.0729, 0.3611, 0.7296, 0.5931, 0.6764, 0.2331],
│                    [0.4727, 0.5710, 0.9302, 0.8417, 0.7715, 0.8744, 0.8541, 0.4062]]])
├── 'b' --> tensor([[1.9782, 1.7935, 1.0698, 1.0206, 1.0296, 1.9642]])
└── 'c' --> <Tensor 0x7f83ecb97b20>
    ├── 'd' --> tensor([[4.]])
    └── 'noise' --> tensor([[[[-0.9922, -0.8831, -0.9031, -1.4641, -1.3487],
                              [ 0.3582, -1.2870, -0.0889,  0.4798,  0.4101],
                              [-1.5796, -0.0642, -1.0601,  2.1795,  0.0919],
                              [-1.1331, -0.7478, -1.6202,  0.2408,  0.4960]],
                    
                             [[-0.2332, -0.9402,  1.2484,  0.6176, -1.2069],
                              [-0.5010,  0.7441, -0.3839,  0.8156,  0.1747],
                              [ 1.3137,  0.7718, -0.7034,  0.2337, -0.4377],
                              [-0.0994,  0.3264, -0.2980, -1.2229,  0.9893]],
                    
                             [[ 1.1173, -0.2840, -0.8270, -0.3214,  2.0599],
                              [-0.1047,  1.6224, -0.3308,  0.9336,  0.6434],
                              [-0.2882, -0.8809, -0.4204,  1.6479, -0.2479],
                              [ 0.4165,  0.8377, -0.3197,  0.1841,  1.6500]]]])
, <Tensor 0x7f83ecb97be0>
├── 'a' --> tensor([[[0.1144, 0.8584, 0.1021, 0.4512, 0.2160, 0.4735, 0.1162, 0.4354],
│                    [0.2487, 0.6985, 0.9192, 0.3302, 0.8952, 0.1003, 0.1221, 0.0799],
│                    [0.8819, 0.8667, 0.0678, 0.5979, 0.7199, 0.5001, 0.9542, 0.4657]]])
├── 'b' --> tensor([[1.1637, 1.7539, 1.3673, 1.7596, 1.1629, 1.0991]])
└── 'c' --> <Tensor 0x7f83ecb97a30>
    ├── 'd' --> tensor([[8.]])
    └── 'noise' --> tensor([[[[ 0.5359, -1.1872, -1.1349, -1.2241, -0.2247],
                              [-1.4811,  0.5311, -0.5148,  1.7360, -0.6674],
                              [-1.5773,  0.3944, -0.5607, -0.4603,  0.6399],
                              [-0.3212, -0.4318,  0.0031,  1.1059, -0.9834]],
                    
                             [[-0.6793, -0.5348, -0.0504, -1.5032,  0.1691],
                              [ 1.6129, -0.6823,  1.0471, -1.1235,  0.4317],
                              [-0.4445,  0.0780,  0.3752,  1.8468,  1.1199],
                              [ 0.6156,  0.4850,  0.2734, -0.8555,  0.1752]],
                    
                             [[-0.1423,  0.0679, -0.5144, -1.3756, -0.8489],
                              [-0.7540, -1.1874, -0.3342, -0.2126,  0.8937],
                              [ 0.0817,  0.0894,  1.2505,  0.2829, -0.0346],
                              [-1.2883, -1.3640, -0.8059,  0.2859,  0.6248]]]])
, <Tensor 0x7f83ecb97c40>
├── 'a' --> tensor([[[0.3129, 0.4439, 0.8225, 0.1990, 0.4839, 0.6035, 0.6834, 0.5054],
│                    [0.3481, 0.6482, 0.3350, 0.7787, 0.2522, 0.7648, 0.1572, 0.7762],
│                    [0.5089, 0.3098, 0.3571, 0.4281, 0.0979, 0.2945, 0.0717, 0.9534]]])
├── 'b' --> tensor([[1.4898, 1.9633, 1.4804, 1.7171, 1.2424, 1.3674]])
└── 'c' --> <Tensor 0x7f83ecb97c10>
    ├── 'd' --> tensor([[7.]])
    └── 'noise' --> tensor([[[[ 0.5697, -0.5898, -0.8598, -0.5347,  1.8237],
                              [-1.5553, -0.3725, -0.7192,  0.9180, -1.4548],
                              [ 1.4326,  0.0348,  0.1363,  1.8379,  0.4643],
                              [ 0.0743, -0.4870,  0.6392, -0.0841,  0.3341]],
                    
                             [[-0.4257,  0.6298,  1.9045, -1.0348,  0.8331],
                              [ 0.8333,  0.7284, -0.0489, -0.6122,  0.1061],
                              [-0.0823, -1.4911, -1.6681,  2.9037, -0.0814],
                              [-1.3788,  0.2873,  0.1422,  1.1302,  0.6525]],
                    
                             [[ 1.3702,  0.3678,  1.2022,  0.3144,  0.3844],
                              [ 0.7469,  0.0259,  0.0321,  0.6302, -1.0288],
                              [ 0.4854,  0.9360, -0.1610,  0.4796,  0.5281],
                              [-0.3088,  0.4554, -0.6296,  0.0379,  0.4916]]]])
, <Tensor 0x7f83ecb97ca0>
├── 'a' --> tensor([[[0.5043, 0.1589, 0.5069, 0.4711, 0.3651, 0.7958, 0.4787, 0.3208],
│                    [0.4928, 0.9831, 0.0118, 0.6966, 0.4535, 0.9948, 0.7176, 0.1944],
│                    [0.0527, 0.1514, 0.3471, 0.2261, 0.1730, 0.7510, 0.1243, 0.8186]]])
├── 'b' --> tensor([[1.8538, 1.2059, 1.9369, 1.0288, 1.7937, 1.0420]])
└── 'c' --> <Tensor 0x7f83ecb97c70>
    ├── 'd' --> tensor([[8.]])
    └── 'noise' --> tensor([[[[ 1.6036e-01,  6.7931e-01,  1.4153e+00,  7.3368e-04,  2.9788e-01],
                              [ 8.1740e-01, -8.6354e-01, -1.5628e+00,  2.3292e-01, -6.9122e-01],
                              [ 1.5631e-01, -1.1406e+00,  2.9363e-01,  1.1081e+00,  7.3348e-01],
                              [ 5.8475e-01,  2.4910e-01, -3.2656e-01,  1.2840e+00, -7.0661e-01]],
                    
                             [[-1.3394e+00, -8.0459e-01,  8.6640e-01,  1.9354e+00,  1.1346e+00],
                              [-2.0040e-01,  6.3933e-01, -1.4943e+00,  7.8732e-01, -8.9331e-01],
                              [-1.5937e+00,  4.8034e-01, -1.3983e+00, -7.2016e-01,  5.6987e-01],
                              [ 8.3162e-01, -1.1837e+00,  1.4762e+00,  9.6509e-01,  2.4423e-02]],
                    
                             [[ 5.0932e-01, -1.0427e+00,  1.0486e+00,  3.6652e-01,  8.3352e-01],
                              [ 6.6314e-02, -1.0065e+00, -1.1592e-01, -2.3717e-01, -6.5655e-02],
                              [ 5.7803e-02, -6.4236e-01,  8.9233e-01,  9.9751e-01,  1.1912e-01],
                              [ 6.5045e-02, -1.7835e-01, -9.1286e-01, -7.8407e-01,  5.5585e-01]]]])
)
mean0 & mean1: tensor(1.4702) tensor(1.4702)


even_index_a0: tensor([[[0.7952, 0.5317, 0.9969, 0.7247, 0.9008, 0.8212, 0.3001, 0.8340],
         [0.4727, 0.5710, 0.9302, 0.8417, 0.7715, 0.8744, 0.8541, 0.4062]],

        [[0.1144, 0.8584, 0.1021, 0.4512, 0.2160, 0.4735, 0.1162, 0.4354],
         [0.8819, 0.8667, 0.0678, 0.5979, 0.7199, 0.5001, 0.9542, 0.4657]],

        [[0.3129, 0.4439, 0.8225, 0.1990, 0.4839, 0.6035, 0.6834, 0.5054],
         [0.5089, 0.3098, 0.3571, 0.4281, 0.0979, 0.2945, 0.0717, 0.9534]],

        [[0.5043, 0.1589, 0.5069, 0.4711, 0.3651, 0.7958, 0.4787, 0.3208],
         [0.0527, 0.1514, 0.3471, 0.2261, 0.1730, 0.7510, 0.1243, 0.8186]]])
even_index_a1: tensor([[[0.7952, 0.5317, 0.9969, 0.7247, 0.9008, 0.8212, 0.3001, 0.8340],
         [0.4727, 0.5710, 0.9302, 0.8417, 0.7715, 0.8744, 0.8541, 0.4062]],

        [[0.1144, 0.8584, 0.1021, 0.4512, 0.2160, 0.4735, 0.1162, 0.4354],
         [0.8819, 0.8667, 0.0678, 0.5979, 0.7199, 0.5001, 0.9542, 0.4657]],

        [[0.3129, 0.4439, 0.8225, 0.1990, 0.4839, 0.6035, 0.6834, 0.5054],
         [0.5089, 0.3098, 0.3571, 0.4281, 0.0979, 0.2945, 0.0717, 0.9534]],

        [[0.5043, 0.1589, 0.5069, 0.4711, 0.3651, 0.7958, 0.4787, 0.3208],
         [0.0527, 0.1514, 0.3471, 0.2261, 0.1730, 0.7510, 0.1243, 0.8186]]])
<Size 0x7f844caa3460>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7f83ecc12460>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.