Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy

import torch

import treetensor.torch as ttorch

# T is the leading dimension of field 'a' in every generated sample;
# B is the number of samples placed in the batch.
T, B = 3, 4


def with_nativetensor(batch_):
    """
    Process a batch of nested dict samples using only the native torch API.

    :param batch_: Sequence of dict samples. Field ``a`` and field ``b`` are \
        read as tensors, field ``c`` is read as a nested dict.
    :return: Tuple ``(batch_, mean_b, even_index_a)``: the input batch object \
        itself, the average over samples of ``mean(b ** 2 + 1)``, and the \
        stacked even-indexed rows of every ``a``.
    """
    b_means = []
    a_even_rows = []
    for sample in batch_:
        for key, value in sample.items():
            if key == 'a':
                # Keep every second slice along the first dimension.
                a_even_rows.append(value.float()[::2])
            elif key == 'b':
                squared_plus_one = torch.pow(value.float(), 2) + 1.0
                b_means.append(squared_plus_one.mean())
            elif key == 'c':
                for sub_key, sub_value in value.items():
                    if sub_key != 'd':
                        print('ignore keys: {}'.format(sub_key))
                    else:
                        # The converted tensor is not stored back, so the
                        # sample keeps its original 'd' entry.
                        sub_value.float()
            else:
                print('ignore keys: {}'.format(key))

    # Attach noise to any sample that has a top-level 'd' entry.
    # NOTE(review): samples built by get_data() keep 'd' nested under 'c', so
    # this branch does not fire for them -- confirm whether 'c' was intended.
    for sample in batch_:
        if 'd' in sample:
            sample['d']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(b_means) / len(b_means)
    even_index_a = torch.stack(a_even_rows, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """
    Process a batch of nested dict samples using the treetensor API.

    :param batch_: List of dict samples, each holding fields ``a``, ``b`` and \
        a nested dict ``c``.
    :return: Tuple ``(batch_, mean_b, even_index_a)``: the processed batch \
        split back into per-sample pieces along dim 0, the mean of the \
        transformed ``b`` field, and every second slice of ``a`` along dim 1 \
        of the stacked batch.
    """
    # Take the batch size from the input rather than from the module-level
    # constant, so the noise tensor always matches the stacked batch.
    batch_size = len(batch_)

    batch_ = [ttorch.tensor(b) for b in batch_]
    batch_ = ttorch.stack(batch_)
    # One call converts every field of the tree, nested ones included.
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    # Dim 0 is the batch dimension after stacking, so the per-sample
    # even-index selection happens on dim 1.
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """
    Build one random sample for the example.

    :return: Dict with ``a`` (uniform random tensor of shape ``(T, 8)``), \
        ``b`` (uniform random tensor of shape ``(6,)``) and ``c``, a nested \
        dict whose ``d`` entry is a single random integer drawn from [0, 10).
    """
    field_a = torch.rand(size=(T, 8))
    field_b = torch.rand(size=(6,))
    nested_c = {'d': torch.randint(0, 10, size=(1,))}
    return {'a': field_a, 'b': field_b, 'c': nested_c}


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Each implementation gets its own deep copy, so neither can see changes
    # made by the other.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take the absolute value element-wise before reducing; taking max() of
    # the signed difference first would hide large negative deviations.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following, and all the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
[{'a': tensor([[0.3966, 0.8244, 0.6097, 0.9855, 0.3798, 0.1565, 0.0098, 0.6476],
        [0.7319, 0.6274, 0.9433, 0.5831, 0.9184, 0.4625, 0.5958, 0.0988],
        [0.2903, 0.5619, 0.5450, 0.3228, 0.9717, 0.0112, 0.0610, 0.6176]]), 'b': tensor([0.1893, 0.4007, 0.8237, 0.2230, 0.6610, 0.9872]), 'c': {'d': tensor([4])}}, {'a': tensor([[0.3677, 0.3901, 0.2094, 0.3252, 0.2418, 0.2013, 0.3877, 0.8925],
        [0.9920, 0.1467, 0.1470, 0.9469, 0.8412, 0.6267, 0.9276, 0.3556],
        [0.9186, 0.9478, 0.0283, 0.2015, 0.2140, 0.2504, 0.2546, 0.5116]]), 'b': tensor([0.8362, 0.4179, 0.0113, 0.6096, 0.2867, 0.7525]), 'c': {'d': tensor([9])}}, {'a': tensor([[0.8125, 0.0259, 0.0690, 0.0772, 0.8263, 0.6004, 0.4874, 0.0964],
        [0.3606, 0.3776, 0.7780, 0.1085, 0.2932, 0.3797, 0.1329, 0.9739],
        [0.2839, 0.3744, 0.7001, 0.6329, 0.1951, 0.7454, 0.9891, 0.1892]]), 'b': tensor([0.7364, 0.5097, 0.4577, 0.9647, 0.4244, 0.2281]), 'c': {'d': tensor([0])}}, {'a': tensor([[0.1986, 0.1622, 0.4137, 0.1405, 0.9052, 0.2589, 0.3613, 0.9998],
        [0.1555, 0.0922, 0.4355, 0.9130, 0.4442, 0.6955, 0.8678, 0.7797],
        [0.6549, 0.6488, 0.5300, 0.4422, 0.9739, 0.3584, 0.5158, 0.0156]]), 'b': tensor([0.5408, 0.1078, 0.5927, 0.0799, 0.6125, 0.4272]), 'c': {'d': tensor([0])}}]



(<Tensor 0x7f9185d152e0>
├── 'a' --> tensor([[[0.3966, 0.8244, 0.6097, 0.9855, 0.3798, 0.1565, 0.0098, 0.6476],
│                    [0.7319, 0.6274, 0.9433, 0.5831, 0.9184, 0.4625, 0.5958, 0.0988],
│                    [0.2903, 0.5619, 0.5450, 0.3228, 0.9717, 0.0112, 0.0610, 0.6176]]])
├── 'b' --> tensor([[1.0358, 1.1605, 1.6784, 1.0497, 1.4369, 1.9745]])
└── 'c' --> <Tensor 0x7f9185d15340>
    ├── 'd' --> tensor([[4.]])
    └── 'noise' --> tensor([[[[ 0.2562, -0.6548,  0.5631,  0.7738, -0.4308],
                              [-1.8951, -0.1646, -1.1050,  0.5682, -0.1726],
                              [ 1.4759,  0.6835,  0.0414, -0.1966, -0.3878],
                              [ 1.8053,  1.8906, -0.8909, -0.0159, -0.1881]],
                    
                             [[-0.2894,  0.3312,  0.8342, -0.5001, -1.4112],
                              [ 1.2179,  0.0530, -1.9090,  1.1412, -0.5141],
                              [-1.0095,  1.7890, -0.6283, -1.4054,  1.0486],
                              [ 1.1588, -1.2241, -0.1511,  1.7479, -1.0159]],
                    
                             [[-0.5286,  0.0276, -0.1845,  0.7690,  0.3237],
                              [ 1.4234,  0.9676, -0.7948,  0.2739,  1.2180],
                              [ 0.1559, -1.5447, -1.1592, -1.3781,  0.4401],
                              [ 0.1588, -0.9652, -0.1542,  0.1192, -1.5558]]]])
, <Tensor 0x7f9185d153a0>
├── 'a' --> tensor([[[0.3677, 0.3901, 0.2094, 0.3252, 0.2418, 0.2013, 0.3877, 0.8925],
│                    [0.9920, 0.1467, 0.1470, 0.9469, 0.8412, 0.6267, 0.9276, 0.3556],
│                    [0.9186, 0.9478, 0.0283, 0.2015, 0.2140, 0.2504, 0.2546, 0.5116]]])
├── 'b' --> tensor([[1.6992, 1.1747, 1.0001, 1.3716, 1.0822, 1.5662]])
└── 'c' --> <Tensor 0x7f9185d15280>
    ├── 'd' --> tensor([[9.]])
    └── 'noise' --> tensor([[[[ 1.2127e-01, -9.7338e-02, -1.0641e+00, -2.2300e-01,  1.2607e-01],
                              [-1.4654e+00,  3.2072e-01, -1.4064e-01, -5.3356e-01,  4.7993e-01],
                              [-9.7642e-03,  6.9599e-01,  6.9635e-01, -7.9874e-01,  2.9000e-01],
                              [ 1.7661e+00,  6.4294e-01,  1.9043e-01,  8.6431e-01,  9.0994e-04]],
                    
                             [[ 1.1610e+00,  2.1799e-01,  8.5444e-01, -2.2679e+00, -6.6386e-02],
                              [ 1.5189e+00, -1.9001e-01, -4.4042e-01, -2.0036e-01, -1.5791e+00],
                              [ 7.7262e-01,  1.4848e+00, -1.5760e+00,  8.5489e-01,  3.5849e-01],
                              [ 6.9080e-01, -1.8826e-02, -1.2874e+00,  6.2864e-01, -2.4863e+00]],
                    
                             [[-1.1041e+00, -1.6640e-01, -1.5837e+00,  5.4099e-02,  1.1398e-01],
                              [ 2.6030e+00,  5.8517e-01, -1.0918e+00,  2.6184e+00,  3.9944e-01],
                              [ 5.3750e-02, -1.2791e+00,  9.5119e-01,  6.2373e-01,  9.7606e-02],
                              [-5.8499e-01, -6.5673e-01,  2.9786e-02, -6.3741e-01, -1.7399e-01]]]])
, <Tensor 0x7f9185d15400>
├── 'a' --> tensor([[[0.8125, 0.0259, 0.0690, 0.0772, 0.8263, 0.6004, 0.4874, 0.0964],
│                    [0.3606, 0.3776, 0.7780, 0.1085, 0.2932, 0.3797, 0.1329, 0.9739],
│                    [0.2839, 0.3744, 0.7001, 0.6329, 0.1951, 0.7454, 0.9891, 0.1892]]])
├── 'b' --> tensor([[1.5423, 1.2598, 1.2095, 1.9306, 1.1801, 1.0520]])
└── 'c' --> <Tensor 0x7f9185d153d0>
    ├── 'd' --> tensor([[0.]])
    └── 'noise' --> tensor([[[[ 2.0106e+00,  1.1619e+00, -1.6692e+00, -1.4201e+00,  5.1294e-01],
                              [ 2.4946e-01, -6.3604e-01, -2.1080e+00,  9.1723e-01,  3.5777e-02],
                              [-5.1351e-01,  1.0697e+00, -1.3922e+00, -3.2787e-01, -5.1948e-01],
                              [ 2.5014e-01,  4.4103e-01,  8.9880e-01, -1.9937e-01,  8.3412e-01]],
                    
                             [[-1.9194e+00,  8.0426e-01, -8.3870e-01,  1.5787e+00,  1.3566e-01],
                              [ 2.5430e+00, -1.8173e+00, -7.5265e-02, -5.3250e-02,  6.2273e-01],
                              [-3.4499e-01, -1.2787e+00,  7.6419e-01, -5.5824e-01, -7.7401e-01],
                              [ 3.0583e-01, -1.8303e-01,  1.1664e+00, -2.1249e-03,  8.2907e-01]],
                    
                             [[ 2.7347e+00,  1.0071e+00, -9.2523e-01, -1.5009e+00,  2.5365e-01],
                              [ 2.9202e-01,  6.2300e-01, -1.3104e+00, -1.4764e+00, -1.8070e+00],
                              [ 1.6874e-01, -2.0066e+00,  4.6249e-01, -1.3472e+00, -5.5880e-01],
                              [-2.7539e-01,  1.1854e-01, -1.3995e+00, -7.6103e-01, -2.8695e-01]]]])
, <Tensor 0x7f9185d15460>
├── 'a' --> tensor([[[0.1986, 0.1622, 0.4137, 0.1405, 0.9052, 0.2589, 0.3613, 0.9998],
│                    [0.1555, 0.0922, 0.4355, 0.9130, 0.4442, 0.6955, 0.8678, 0.7797],
│                    [0.6549, 0.6488, 0.5300, 0.4422, 0.9739, 0.3584, 0.5158, 0.0156]]])
├── 'b' --> tensor([[1.2924, 1.0116, 1.3513, 1.0064, 1.3752, 1.1825]])
└── 'c' --> <Tensor 0x7f9185d15430>
    ├── 'd' --> tensor([[0.]])
    └── 'noise' --> tensor([[[[ 1.4070, -0.6493,  0.1467,  1.6439,  0.5981],
                              [ 1.0487,  1.4318, -1.0580, -0.5041,  0.4501],
                              [-0.3174,  0.4318,  0.7212,  0.2747,  2.0658],
                              [-2.0758,  0.9897, -0.5453, -0.2644, -0.8066]],
                    
                             [[ 0.1859,  0.4662,  0.9531,  0.5400, -0.5427],
                              [ 0.1347,  0.6061,  0.1726, -0.6139, -1.7318],
                              [ 0.1542,  1.4777, -0.5712,  1.3656,  0.4189],
                              [-0.7131, -1.7418, -2.1632,  0.7465, -1.4006]],
                    
                             [[-1.2386, -1.2326,  1.6733,  0.2993,  0.1648],
                              [ 1.0379, -0.4832,  0.2289,  0.6771, -0.3427],
                              [-0.1972,  0.2718,  1.2705, -1.5191,  0.0776],
                              [ 0.1047, -0.0416, -0.1009, -1.7804,  0.2607]]]])
)
mean0 & mean1: tensor(1.3177) tensor(1.3177)


even_index_a0: tensor([[[0.3966, 0.8244, 0.6097, 0.9855, 0.3798, 0.1565, 0.0098, 0.6476],
         [0.2903, 0.5619, 0.5450, 0.3228, 0.9717, 0.0112, 0.0610, 0.6176]],

        [[0.3677, 0.3901, 0.2094, 0.3252, 0.2418, 0.2013, 0.3877, 0.8925],
         [0.9186, 0.9478, 0.0283, 0.2015, 0.2140, 0.2504, 0.2546, 0.5116]],

        [[0.8125, 0.0259, 0.0690, 0.0772, 0.8263, 0.6004, 0.4874, 0.0964],
         [0.2839, 0.3744, 0.7001, 0.6329, 0.1951, 0.7454, 0.9891, 0.1892]],

        [[0.1986, 0.1622, 0.4137, 0.1405, 0.9052, 0.2589, 0.3613, 0.9998],
         [0.6549, 0.6488, 0.5300, 0.4422, 0.9739, 0.3584, 0.5158, 0.0156]]])
even_index_a1: tensor([[[0.3966, 0.8244, 0.6097, 0.9855, 0.3798, 0.1565, 0.0098, 0.6476],
         [0.2903, 0.5619, 0.5450, 0.3228, 0.9717, 0.0112, 0.0610, 0.6176]],

        [[0.3677, 0.3901, 0.2094, 0.3252, 0.2418, 0.2013, 0.3877, 0.8925],
         [0.9186, 0.9478, 0.0283, 0.2015, 0.2140, 0.2504, 0.2546, 0.5116]],

        [[0.8125, 0.0259, 0.0690, 0.0772, 0.8263, 0.6004, 0.4874, 0.0964],
         [0.2839, 0.3744, 0.7001, 0.6329, 0.1951, 0.7454, 0.9891, 0.1892]],

        [[0.1986, 0.1622, 0.4137, 0.1405, 0.9052, 0.2589, 0.3613, 0.9998],
         [0.6549, 0.6488, 0.5300, 0.4422, 0.9739, 0.3584, 0.5158, 0.0156]]])
<Size 0x7f91eca957c0>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7f9185d10ee0>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.