Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy

import torch

import treetensor.torch as ttorch

# T: size of the first dimension of field 'a' in each generated sample.
# B: number of samples in the demo batch.
T, B = 3, 4


def with_nativetensor(batch_):
    """Apply the per-field custom operations using only the native torch API.

    Each sample in ``batch_`` is a dict and is modified in place:

    * ``'a'``: cast to float and stored back; every second row (``v[::2]``)
      is collected.
    * ``'b'``: cast to float and replaced by ``b ** 2 + 1``; the mean of the
      transformed tensor is collected.
    * ``'c'``: a nested dict. ``'d'`` is cast to float and stored back, any
      other nested key is reported with a message, and a fresh ``'noise'``
      tensor of shape ``(3, 4, 5)`` is attached.
    * any other top-level key is reported with a message and left untouched.

    Args:
        batch_: list of sample dicts with the layout described above.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)``: the modified list, the
        average of the per-sample means of the transformed ``'b'``, and the
        collected ``'a'`` rows stacked along a new leading dimension.

    Raises:
        ZeroDivisionError: if no sample contains a ``'b'`` field (this
            includes an empty batch).
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Iterate over snapshots so entries can be reassigned (and 'noise'
        # added) without disturbing the iteration.
        for k, v in list(sample.items()):
            if k == 'a':
                v = v.float()
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in list(v.items()):
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
                # The noise belongs to the nested 'c' dict; 'd' is never a
                # top-level key, so checking for it there would never match.
                v['noise'] = torch.randn(size=(3, 4, 5))
            else:
                print('ignore keys: {}'.format(k))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the per-field custom operations using the treetensor API.

    The samples are converted to tree tensors and stacked into one batched
    tree, so each operation is written once for the whole batch: cast to
    float, replace ``b`` by ``b ** 2 + 1``, attach ``c.noise``, then take the
    mean of ``b`` and every second row of ``a``.

    Args:
        batch_: list of (possibly nested) dicts of tensors sharing the same
            structure.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)``: the batched tree split
        back into chunks of size 1 along dim 0, the mean of the transformed
        ``b``, and ``a[:, ::2]`` of the batched tree.
    """
    # Size the noise from the input itself rather than the module-level ``B``
    # so the function also works for batches of any other length.
    batch_size = len(batch_)
    batch_ = [ttorch.tensor(b) for b in batch_]
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    # Dim 0 is the stacked batch dimension, so the row stride applies to dim 1.
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Create one random sample for the demo.

    Returns:
        A dict with ``'a'`` (uniform random, shape ``(T, 8)``), ``'b'``
        (uniform random, shape ``(6,)``) and a nested dict ``'c'`` holding
        ``'d'`` (one random integer drawn from ``[0, 10)``).
    """
    # Draw the tensors in the same order as the keys appear in the result.
    field_a = torch.rand(size=(T, 8))
    field_b = torch.rand(size=(6,))
    field_d = torch.randint(0, 10, size=(1,))

    sample = dict(a=field_a, b=field_b)
    sample['c'] = dict(d=field_d)
    return sample


if __name__ == "__main__":
    # Build one random batch and hand each implementation its own deep copy,
    # so neither call can affect the input seen by the other.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    # Both implementations must agree on the mean of the transformed 'b'.
    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take the absolute value element-wise BEFORE the max; the max of the
    # signed difference would hide arbitrarily large negative deviations.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    # Both results hold one entry per sample; the treetensor result consists
    # of tree tensors.
    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following (the exact values differ between runs because the data is random), and all the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
[{'a': tensor([[0.0730, 0.2810, 0.7642, 0.3265, 0.8006, 0.1291, 0.4410, 0.5144],
        [0.2621, 0.1805, 0.5549, 0.1331, 0.6024, 0.3018, 0.3112, 0.3647],
        [0.6566, 0.9296, 0.2867, 0.5219, 0.4429, 0.5960, 0.9051, 0.1431]]), 'b': tensor([0.5313, 0.5180, 0.4947, 0.9723, 0.5365, 0.1009]), 'c': {'d': tensor([6])}}, {'a': tensor([[0.8982, 0.5922, 0.7774, 0.2902, 0.1801, 0.7970, 0.9584, 0.3310],
        [0.3842, 0.6901, 0.4248, 0.4139, 0.4342, 0.4262, 0.6012, 0.4525],
        [0.2149, 0.8399, 0.4481, 0.6800, 0.4614, 0.9342, 0.2913, 0.7659]]), 'b': tensor([0.6653, 0.9887, 0.7088, 0.3472, 0.3131, 0.4827]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.8102, 0.6209, 0.8538, 0.5519, 0.0497, 0.8272, 0.4520, 0.4337],
        [0.9679, 0.5910, 0.6493, 0.3332, 0.1804, 0.4862, 0.7153, 0.6898],
        [0.0253, 0.8591, 0.0775, 0.0992, 0.3756, 0.9941, 0.9854, 0.1591]]), 'b': tensor([0.0182, 0.7272, 0.7945, 0.7043, 0.3251, 0.6554]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.9523, 0.1690, 0.1028, 0.6150, 0.6532, 0.8465, 0.0777, 0.1573],
        [0.1165, 0.3488, 0.7171, 0.5685, 0.3000, 0.3334, 0.8046, 0.0748],
        [0.8846, 0.9219, 0.5501, 0.1917, 0.0404, 0.4133, 0.0725, 0.9250]]), 'b': tensor([0.8136, 0.9406, 0.3927, 0.1013, 0.9801, 0.3492]), 'c': {'d': tensor([0])}}]



(<Tensor 0x7fc906fdac10>
├── 'a' --> tensor([[[0.0730, 0.2810, 0.7642, 0.3265, 0.8006, 0.1291, 0.4410, 0.5144],
│                    [0.2621, 0.1805, 0.5549, 0.1331, 0.6024, 0.3018, 0.3112, 0.3647],
│                    [0.6566, 0.9296, 0.2867, 0.5219, 0.4429, 0.5960, 0.9051, 0.1431]]])
├── 'b' --> tensor([[1.2822, 1.2684, 1.2447, 1.9455, 1.2878, 1.0102]])
└── 'c' --> <Tensor 0x7fc906fdac70>
    ├── 'd' --> tensor([[6.]])
    └── 'noise' --> tensor([[[[-2.0164e+00,  1.6905e+00,  1.6358e+00,  1.1691e+00,  5.7766e-01],
                              [ 1.3527e+00, -1.0914e+00, -4.6705e-02,  2.0240e+00, -1.8224e+00],
                              [ 5.2425e-01, -1.7212e-01, -3.2210e-01,  1.3389e+00, -1.0577e+00],
                              [-1.3967e-02,  1.0815e+00,  5.0475e-01, -5.8208e-01,  6.0741e-02]],
                    
                             [[ 4.3895e-01,  1.3226e+00, -6.5097e-01,  1.1502e-01,  4.6786e-01],
                              [-8.2064e-04,  9.1445e-02, -4.7237e-01, -2.4750e-01, -1.9809e-01],
                              [ 7.3672e-01, -8.6755e-01,  3.4016e-02, -4.4958e-02, -1.8747e+00],
                              [-5.5719e-01,  5.0298e-01, -1.3750e+00,  8.4211e-01,  1.0541e+00]],
                    
                             [[-1.9202e-01,  7.3596e-01, -1.5025e+00,  6.4057e-01, -1.2579e+00],
                              [-7.7307e-01, -6.7083e-01,  1.7407e+00, -1.0105e-01,  2.7708e-01],
                              [-4.9228e-01,  1.7454e+00,  2.3372e-01, -1.0487e+00, -1.8510e+00],
                              [ 2.3535e+00, -1.7593e+00,  2.1347e-01, -6.7427e-01, -1.1731e+00]]]])
, <Tensor 0x7fc906fdad30>
├── 'a' --> tensor([[[0.8982, 0.5922, 0.7774, 0.2902, 0.1801, 0.7970, 0.9584, 0.3310],
│                    [0.3842, 0.6901, 0.4248, 0.4139, 0.4342, 0.4262, 0.6012, 0.4525],
│                    [0.2149, 0.8399, 0.4481, 0.6800, 0.4614, 0.9342, 0.2913, 0.7659]]])
├── 'b' --> tensor([[1.4426, 1.9774, 1.5024, 1.1205, 1.0980, 1.2330]])
└── 'c' --> <Tensor 0x7fc906fdaac0>
    ├── 'd' --> tensor([[3.]])
    └── 'noise' --> tensor([[[[-1.5450, -1.6905,  0.4312, -0.2604, -0.4865],
                              [ 0.5222, -0.4960, -0.7188, -0.7926, -0.6251],
                              [ 0.8556,  1.6708, -0.5835, -0.1483, -0.5698],
                              [ 1.6547, -0.7003,  1.0590, -2.4264,  0.5141]],
                    
                             [[-0.0606,  0.1895, -1.0684, -0.7247,  1.3095],
                              [ 2.5292, -1.1486,  0.2688,  0.3017,  0.3331],
                              [-0.7841,  1.5898,  1.6363, -0.7597,  0.2299],
                              [ 0.5379,  0.1701,  0.1874, -0.1444,  1.3577]],
                    
                             [[-0.3337, -0.1685, -0.8499, -1.2491,  1.3864],
                              [-0.5003,  0.8336, -1.8690, -1.2475,  0.2396],
                              [-0.1194,  1.3889, -0.6624, -1.2659, -0.1942],
                              [-0.1443, -0.8676,  0.7845,  0.0346,  0.8877]]]])
, <Tensor 0x7fc906fdad90>
├── 'a' --> tensor([[[0.8102, 0.6209, 0.8538, 0.5519, 0.0497, 0.8272, 0.4520, 0.4337],
│                    [0.9679, 0.5910, 0.6493, 0.3332, 0.1804, 0.4862, 0.7153, 0.6898],
│                    [0.0253, 0.8591, 0.0775, 0.0992, 0.3756, 0.9941, 0.9854, 0.1591]]])
├── 'b' --> tensor([[1.0003, 1.5289, 1.6312, 1.4960, 1.1057, 1.4296]])
└── 'c' --> <Tensor 0x7fc906fdad60>
    ├── 'd' --> tensor([[3.]])
    └── 'noise' --> tensor([[[[-0.6307,  0.0423, -0.7410, -0.3917,  0.1529],
                              [ 0.6046, -0.0461,  0.7372, -0.3655,  0.2533],
                              [ 1.0064, -0.9557,  1.0176,  0.8430,  1.5863],
                              [ 1.4941,  1.6549,  0.3781, -0.2377, -3.0946]],
                    
                             [[-1.3730,  1.1420,  2.0576,  0.6310, -1.4223],
                              [-1.6611,  0.0794,  1.6946,  1.6277,  0.8059],
                              [ 1.7959, -0.4943, -0.2074,  2.0674, -0.3033],
                              [-1.2095, -0.1526, -0.1072, -1.4623,  0.3611]],
                    
                             [[-1.4317,  0.2062,  0.7165, -1.0422,  1.3787],
                              [ 0.0730,  0.4072, -1.7002,  0.1608,  0.2913],
                              [-0.3936, -1.0239, -0.2620, -1.4381,  1.4095],
                              [ 1.1906,  0.3589,  2.8231,  0.8444,  0.6052]]]])
, <Tensor 0x7fc906fdadf0>
├── 'a' --> tensor([[[0.9523, 0.1690, 0.1028, 0.6150, 0.6532, 0.8465, 0.0777, 0.1573],
│                    [0.1165, 0.3488, 0.7171, 0.5685, 0.3000, 0.3334, 0.8046, 0.0748],
│                    [0.8846, 0.9219, 0.5501, 0.1917, 0.0404, 0.4133, 0.0725, 0.9250]]])
├── 'b' --> tensor([[1.6619, 1.8847, 1.1542, 1.0103, 1.9607, 1.1220]])
└── 'c' --> <Tensor 0x7fc906fdadc0>
    ├── 'd' --> tensor([[0.]])
    └── 'noise' --> tensor([[[[ 0.0185, -0.4831,  0.5030,  1.8792, -0.5789],
                              [ 0.3683, -0.1999,  0.0887,  0.1022, -0.8752],
                              [ 0.6978, -0.5402,  1.5988, -0.2379,  1.3199],
                              [-1.1288, -1.6968, -1.5964, -1.8739, -0.6625]],
                    
                             [[ 1.0150, -0.8857, -1.3007, -0.2133, -1.0358],
                              [-0.3572, -0.3062,  0.4154,  0.2206,  0.0947],
                              [-1.3657, -1.2397,  0.2615, -1.0847,  1.7383],
                              [-0.1126, -0.3927,  0.5227,  1.6310,  0.2789]],
                    
                             [[-0.5726,  0.9945,  0.6983,  0.0200, -0.8776],
                              [-0.8363, -0.1905, -1.7081,  0.6175,  1.2286],
                              [-0.5026,  0.7750, -0.1982, -1.6114, -0.2191],
                              [-0.8025, -1.7732,  0.6458,  0.8507,  0.2071]]]])
)
mean0 & mean1: tensor(1.3916) tensor(1.3916)


even_index_a0: tensor([[[0.0730, 0.2810, 0.7642, 0.3265, 0.8006, 0.1291, 0.4410, 0.5144],
         [0.6566, 0.9296, 0.2867, 0.5219, 0.4429, 0.5960, 0.9051, 0.1431]],

        [[0.8982, 0.5922, 0.7774, 0.2902, 0.1801, 0.7970, 0.9584, 0.3310],
         [0.2149, 0.8399, 0.4481, 0.6800, 0.4614, 0.9342, 0.2913, 0.7659]],

        [[0.8102, 0.6209, 0.8538, 0.5519, 0.0497, 0.8272, 0.4520, 0.4337],
         [0.0253, 0.8591, 0.0775, 0.0992, 0.3756, 0.9941, 0.9854, 0.1591]],

        [[0.9523, 0.1690, 0.1028, 0.6150, 0.6532, 0.8465, 0.0777, 0.1573],
         [0.8846, 0.9219, 0.5501, 0.1917, 0.0404, 0.4133, 0.0725, 0.9250]]])
even_index_a1: tensor([[[0.0730, 0.2810, 0.7642, 0.3265, 0.8006, 0.1291, 0.4410, 0.5144],
         [0.6566, 0.9296, 0.2867, 0.5219, 0.4429, 0.5960, 0.9051, 0.1431]],

        [[0.8982, 0.5922, 0.7774, 0.2902, 0.1801, 0.7970, 0.9584, 0.3310],
         [0.2149, 0.8399, 0.4481, 0.6800, 0.4614, 0.9342, 0.2913, 0.7659]],

        [[0.8102, 0.6209, 0.8538, 0.5519, 0.0497, 0.8272, 0.4520, 0.4337],
         [0.0253, 0.8591, 0.0775, 0.0992, 0.3756, 0.9941, 0.9854, 0.1591]],

        [[0.9523, 0.1690, 0.1028, 0.6150, 0.6532, 0.8465, 0.0777, 0.1573],
         [0.8846, 0.9219, 0.5501, 0.1917, 0.0404, 0.4133, 0.0725, 0.9250]]])
<Size 0x7fc906fdab20>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7fc90719c7c0>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.