Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy

import torch

import treetensor.torch as ttorch

# T is the leading dimension of field 'a' in every generated sample.
T = 3
# B is the number of samples that make up one batch.
B = 4


def with_nativetensor(batch_):
    """Apply the per-field custom operations using plain torch and dict code.

    Every sample dict in ``batch_`` is updated in place:

    * ``'a'`` is cast to float; its even-indexed rows (``v[::2]``) are collected.
    * ``'b'`` is cast to float and replaced by ``b ** 2 + 1``; the mean of the
      transformed tensor is collected.
    * ``'c'``: the nested ``'d'`` tensor is cast to float and a fresh
      ``'noise'`` tensor of shape ``(3, 4, 5)`` is attached next to it.

    Any other key (top-level, or nested under ``'c'``) is reported with an
    ``ignore keys`` message and left untouched.

    Args:
        batch_: List of sample dicts. The list and its dicts are modified.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)``: the updated list, the
        average of the per-sample ``'b'`` means, and the collected even-indexed
        rows of ``'a'`` stacked along a new leading dimension.

    Raises:
        ZeroDivisionError: If no ``'b'`` field was found (e.g. empty batch).
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Iterate over a snapshot so the dict can be safely updated in the loop.
        for k, v in list(sample.items()):
            if k == 'a':
                v = v.float()
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in list(v.items()):
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
                # The noise belongs inside the nested 'c' dict, beside 'd'.
                v['noise'] = torch.randn(size=(3, 4, 5))
            else:
                print('ignore keys: {}'.format(k))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the per-field custom operations using the treetensor API.

    The samples are converted to tree tensors and stacked into one tree, so
    each operation below is written once for the whole batch instead of once
    per sample and per key.

    Args:
        batch_: List of sample dicts holding fields ``'a'``, ``'b'`` and a
            nested ``'c'`` subtree.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)``: the processed batch split
        back into size-1 pieces along dim 0, the mean of the transformed
        ``'b'`` field, and every second row of ``'a'`` for each sample.
    """
    # Size the noise from the actual input so the function also works for
    # batches whose length differs from the module-level constant B.
    batch_size = len(batch_)
    batch_ = [ttorch.tensor(b) for b in batch_]
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    # Dim 0 is the stacked batch dim, so the row slice moves to dim 1.
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Create one random sample with fields ``'a'``, ``'b'`` and nested ``'c'``.

    Returns:
        A dict with ``'a'`` of shape ``(T, 8)`` and ``'b'`` of shape ``(6,)``
        drawn by ``torch.rand``, and ``'c'`` holding ``'d'``, one random
        integer from ``torch.randint(0, 10, ...)`` with shape ``(1,)``.
    """
    sample = {}
    sample['a'] = torch.rand(size=(T, 8))
    sample['b'] = torch.rand(size=(6,))
    nested = {'d': torch.randint(0, 10, size=(1,))}
    sample['c'] = nested
    return sample


if __name__ == "__main__":
    # Both implementations get their own deep copy of the same random batch,
    # so neither can see changes made by the other.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Use the largest absolute element-wise difference: taking abs() of the
    # signed maximum would let large negative differences go unnoticed.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following, and all the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
[{'a': tensor([[0.8665, 0.9088, 0.7569, 0.6825, 0.1302, 0.6438, 0.3450, 0.2007],
        [0.4946, 0.0328, 0.6505, 0.2378, 0.2702, 0.2210, 0.3800, 0.7847],
        [0.2095, 0.2963, 0.7101, 0.6690, 0.6303, 0.2929, 0.1435, 0.6998]]), 'b': tensor([0.1214, 0.0274, 0.9256, 0.8215, 0.8466, 0.1902]), 'c': {'d': tensor([7])}}, {'a': tensor([[0.2431, 0.1269, 0.3689, 0.8099, 0.3914, 0.3421, 0.4355, 0.2371],
        [0.8309, 0.7894, 0.9701, 0.1638, 0.1176, 0.2241, 0.2241, 0.2522],
        [0.9711, 0.0218, 0.9972, 0.7963, 0.6596, 0.9119, 0.1892, 0.8936]]), 'b': tensor([0.8208, 0.6126, 0.4234, 0.9650, 0.0900, 0.2385]), 'c': {'d': tensor([9])}}, {'a': tensor([[0.5688, 0.3217, 0.8575, 0.8436, 0.1835, 0.8896, 0.0210, 0.5455],
        [0.9954, 0.7660, 0.1734, 0.3155, 0.3065, 0.6172, 0.0347, 0.4582],
        [0.4729, 0.9954, 0.5076, 0.7531, 0.3406, 0.5855, 0.5058, 0.1206]]), 'b': tensor([0.5111, 0.5044, 0.6388, 0.7865, 0.7909, 0.2100]), 'c': {'d': tensor([4])}}, {'a': tensor([[0.8107, 0.8194, 0.2188, 0.4298, 0.1951, 0.0313, 0.6070, 0.7073],
        [0.8138, 0.7121, 0.9815, 0.9346, 0.2885, 0.0585, 0.1913, 0.3806],
        [0.8669, 0.1006, 0.4896, 0.3018, 0.6727, 0.4490, 0.5455, 0.3220]]), 'b': tensor([0.2372, 0.8830, 0.6201, 0.6245, 0.6132, 0.4773]), 'c': {'d': tensor([7])}}]



(<Tensor 0x7f237c162580>
├── 'a' --> tensor([[[0.8665, 0.9088, 0.7569, 0.6825, 0.1302, 0.6438, 0.3450, 0.2007],
│                    [0.4946, 0.0328, 0.6505, 0.2378, 0.2702, 0.2210, 0.3800, 0.7847],
│                    [0.2095, 0.2963, 0.7101, 0.6690, 0.6303, 0.2929, 0.1435, 0.6998]]])
├── 'b' --> tensor([[1.0147, 1.0007, 1.8567, 1.6749, 1.7167, 1.0362]])
└── 'c' --> <Tensor 0x7f237c1625e0>
    ├── 'd' --> tensor([[7.]])
    └── 'noise' --> tensor([[[[-0.3026,  0.3336,  0.7056, -1.2039,  1.7792],
                              [-0.6242,  0.0618, -0.5512, -0.1455,  1.1428],
                              [ 1.1051,  2.3970,  0.9906,  0.4791,  0.2180],
                              [-1.5307, -0.6763,  0.0491, -1.1093,  0.3926]],
                    
                             [[ 1.2435,  0.2536, -0.5028, -0.8716,  1.1416],
                              [-0.2970,  0.1877,  0.4230, -0.4240, -0.4371],
                              [-0.8912,  1.6403,  0.2320, -1.6659, -0.3558],
                              [ 0.6275,  0.2561, -1.5862,  2.0024, -0.9043]],
                    
                             [[ 0.7452, -1.8395,  0.3822, -0.9894, -1.5323],
                              [-0.0126,  0.0525, -0.0981, -0.8024,  1.4610],
                              [-1.0592,  0.4134,  0.6497,  0.1778, -0.3963],
                              [ 0.2844,  1.7031,  0.6817, -1.0866,  0.7724]]]])
, <Tensor 0x7f237c1626a0>
├── 'a' --> tensor([[[0.2431, 0.1269, 0.3689, 0.8099, 0.3914, 0.3421, 0.4355, 0.2371],
│                    [0.8309, 0.7894, 0.9701, 0.1638, 0.1176, 0.2241, 0.2241, 0.2522],
│                    [0.9711, 0.0218, 0.9972, 0.7963, 0.6596, 0.9119, 0.1892, 0.8936]]])
├── 'b' --> tensor([[1.6738, 1.3753, 1.1793, 1.9312, 1.0081, 1.0569]])
└── 'c' --> <Tensor 0x7f237c1624c0>
    ├── 'd' --> tensor([[9.]])
    └── 'noise' --> tensor([[[[-2.5049e+00,  2.0461e-01, -6.6002e-01,  1.8798e-01,  8.0707e-01],
                              [ 5.5676e-01,  1.0060e+00, -9.4148e-01, -1.4495e+00, -6.5691e-01],
                              [ 2.1647e+00, -8.5329e-01, -1.3162e+00,  3.0015e-01, -4.9607e-01],
                              [-8.6826e-01, -1.3946e+00,  1.2714e+00, -3.2209e-01,  9.2902e-01]],
                    
                             [[ 1.5400e+00, -1.1072e+00,  4.0012e-01,  1.4014e-04,  3.5665e-01],
                              [ 6.9737e-01,  4.0163e-01, -1.3559e+00,  2.0701e-01,  6.3983e-01],
                              [ 1.3846e-01, -1.3418e+00, -4.2567e-01,  6.8998e-01, -3.8492e-01],
                              [ 3.6196e-01, -1.5220e-01,  6.6868e-01,  2.6043e-01, -4.0372e-02]],
                    
                             [[ 6.1624e-01, -4.3484e-01, -1.7773e+00,  3.4712e-01,  2.7517e-01],
                              [ 1.1534e+00, -9.2221e-01,  1.7659e+00,  1.2852e+00, -2.8716e-01],
                              [ 5.9528e-02, -1.1638e+00, -2.4237e-01, -1.0312e+00,  1.2858e+00],
                              [ 3.6574e-01, -6.7653e-02,  4.4200e-01, -7.2800e-01,  2.7208e-01]]]])
, <Tensor 0x7f237c162700>
├── 'a' --> tensor([[[0.5688, 0.3217, 0.8575, 0.8436, 0.1835, 0.8896, 0.0210, 0.5455],
│                    [0.9954, 0.7660, 0.1734, 0.3155, 0.3065, 0.6172, 0.0347, 0.4582],
│                    [0.4729, 0.9954, 0.5076, 0.7531, 0.3406, 0.5855, 0.5058, 0.1206]]])
├── 'b' --> tensor([[1.2612, 1.2544, 1.4080, 1.6186, 1.6255, 1.0441]])
└── 'c' --> <Tensor 0x7f237c1626d0>
    ├── 'd' --> tensor([[4.]])
    └── 'noise' --> tensor([[[[ 0.1353, -1.8480, -1.6621,  0.5799, -0.2800],
                              [ 1.4138, -0.0113,  0.0177,  0.8198,  0.7301],
                              [-0.6666,  0.0304, -0.6770,  0.5411,  0.5198],
                              [ 1.5511, -0.1262, -0.6161, -1.1997, -1.5366]],
                    
                             [[ 0.1982,  1.4697,  1.3387,  0.8497, -0.5123],
                              [ 0.4804, -0.9265,  0.8640,  0.4346,  0.1220],
                              [ 2.0205,  0.1645, -0.3352, -1.1066, -0.7150],
                              [ 0.5400, -0.4626, -1.1799,  0.5766, -0.6324]],
                    
                             [[ 2.1123,  1.2872,  1.4614,  0.7326,  1.3558],
                              [-0.1035, -0.5072, -1.4503, -1.8749,  1.9906],
                              [-0.0079,  0.3056,  1.6099,  0.0699,  0.6805],
                              [-0.8504, -0.3356,  0.7516,  1.1599,  0.3744]]]])
, <Tensor 0x7f237c162760>
├── 'a' --> tensor([[[0.8107, 0.8194, 0.2188, 0.4298, 0.1951, 0.0313, 0.6070, 0.7073],
│                    [0.8138, 0.7121, 0.9815, 0.9346, 0.2885, 0.0585, 0.1913, 0.3806],
│                    [0.8669, 0.1006, 0.4896, 0.3018, 0.6727, 0.4490, 0.5455, 0.3220]]])
├── 'b' --> tensor([[1.0563, 1.7797, 1.3846, 1.3900, 1.3760, 1.2278]])
└── 'c' --> <Tensor 0x7f237c162730>
    ├── 'd' --> tensor([[7.]])
    └── 'noise' --> tensor([[[[ 1.8565,  0.9559, -2.7744,  0.9325, -0.5216],
                              [ 0.4288, -1.1126, -0.7462,  1.5362,  1.4256],
                              [-0.9266, -0.8704, -1.3405, -0.8918, -0.2227],
                              [-0.5473, -1.3886, -0.3050, -1.6887, -1.3981]],
                    
                             [[-0.9450,  0.9875, -0.8314, -0.3334,  0.5684],
                              [ 0.5428,  0.3821, -0.6291,  0.8887,  0.6346],
                              [-0.4497, -0.5102,  0.5783, -0.1734,  1.0046],
                              [-0.5322,  1.5718,  0.1697, -0.4856, -0.5880]],
                    
                             [[-0.6278,  1.3811, -0.0196,  0.8296,  0.6133],
                              [-0.7989,  0.4975, -0.6005,  1.6146,  0.4410],
                              [-0.2954, -0.7136, -0.2561, -0.0071,  0.8017],
                              [ 0.4188, -1.1521, -0.9709, -1.2668, -0.3882]]]])
)
mean0 & mean1: tensor(1.3729) tensor(1.3729)


even_index_a0: tensor([[[0.8665, 0.9088, 0.7569, 0.6825, 0.1302, 0.6438, 0.3450, 0.2007],
         [0.2095, 0.2963, 0.7101, 0.6690, 0.6303, 0.2929, 0.1435, 0.6998]],

        [[0.2431, 0.1269, 0.3689, 0.8099, 0.3914, 0.3421, 0.4355, 0.2371],
         [0.9711, 0.0218, 0.9972, 0.7963, 0.6596, 0.9119, 0.1892, 0.8936]],

        [[0.5688, 0.3217, 0.8575, 0.8436, 0.1835, 0.8896, 0.0210, 0.5455],
         [0.4729, 0.9954, 0.5076, 0.7531, 0.3406, 0.5855, 0.5058, 0.1206]],

        [[0.8107, 0.8194, 0.2188, 0.4298, 0.1951, 0.0313, 0.6070, 0.7073],
         [0.8669, 0.1006, 0.4896, 0.3018, 0.6727, 0.4490, 0.5455, 0.3220]]])
even_index_a1: tensor([[[0.8665, 0.9088, 0.7569, 0.6825, 0.1302, 0.6438, 0.3450, 0.2007],
         [0.2095, 0.2963, 0.7101, 0.6690, 0.6303, 0.2929, 0.1435, 0.6998]],

        [[0.2431, 0.1269, 0.3689, 0.8099, 0.3914, 0.3421, 0.4355, 0.2371],
         [0.9711, 0.0218, 0.9972, 0.7963, 0.6596, 0.9119, 0.1892, 0.8936]],

        [[0.5688, 0.3217, 0.8575, 0.8436, 0.1835, 0.8896, 0.0210, 0.5455],
         [0.4729, 0.9954, 0.5076, 0.7531, 0.3406, 0.5855, 0.5058, 0.1206]],

        [[0.8107, 0.8194, 0.2188, 0.4298, 0.1951, 0.0313, 0.6070, 0.7073],
         [0.8669, 0.1006, 0.4896, 0.3018, 0.6727, 0.4490, 0.5455, 0.3220]]])
<Size 0x7f23e202e430>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7f237c1ed940>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.