Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy

import torch

import treetensor.torch as ttorch

# T: leading dimension of each sample's 'a' tensor (see get_data).
# B: number of samples in the batch built by the __main__ section.
T, B = 3, 4


def with_nativetensor(batch_):
    """Apply per-field operations to a batch of nested dicts with plain torch.

    For every sample in ``batch_`` (updated in place):

    * ``'a'`` is cast to float and every second slice along dim 0
      (``v[::2]``) is collected.
    * ``'b'`` is cast to float, replaced by ``b ** 2 + 1.0``, and the mean
      of the transformed tensor is collected.
    * ``'c'['d']`` is cast to float; any other key under ``'c'`` is reported
      with a print and left untouched.
    * a standard-normal ``'noise'`` tensor of shape ``(3, 4, 5)`` is attached
      under ``'c'``.
    * any other top-level key is reported with a print and left untouched.

    Args:
        batch_: list of dict samples with the keys described above.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        same list with its samples updated, ``mean_b`` is the average of the
        per-sample means of the transformed ``'b'`` tensors, and
        ``even_index_a`` stacks the collected ``'a'`` slices along a new
        leading dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Only existing keys are re-assigned while iterating, so the dict
        # size never changes during the loop.
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                # Write the transformed value back so the returned batch
                # actually carries it.
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # The noise lives under 'c'. It is added in a separate pass because
    # inserting a new key while iterating over that dict would raise.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the per-field operations to a batch using the treetensor API.

    The samples are converted with ``ttorch.tensor``, stacked into a single
    batched tree, cast to float, and then processed field by field:
    ``b`` becomes ``b ** 2 + 1.0`` and a ``randn`` noise tensor with shape
    ``(len(batch_), 3, 4, 5)`` is attached as ``c.noise``.

    Args:
        batch_: list of nested dict samples with keys ``'a'``, ``'b'`` and
            ``'c'``.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        result of splitting the batched tree along dim 0 into chunks of
        size 1, ``mean_b`` is the mean of the transformed ``b`` field, and
        ``even_index_a`` is ``a[:, ::2]`` of the batched tree.
    """
    trees = [ttorch.tensor(b) for b in batch_]
    # Take the batch size from the input instead of the module-level B, so
    # the noise's leading dimension always matches the stacked batch.
    batch_size = len(trees)
    batch_ = ttorch.stack(trees)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Create one random sample as a nested dict.

    Returns:
        A dict with ``'a'`` (uniform random, shape ``(T, 8)``), ``'b'``
        (uniform random, shape ``(6,)``) and ``'c'``, a nested dict whose
        ``'d'`` entry is a random integer tensor of shape ``(1,)`` drawn
        from ``[0, 10)``.
    """
    # Tensors are created in the order a -> b -> d so the random number
    # generator is consumed in the same sequence as a dict literal would.
    sample = {}
    sample['a'] = torch.rand(size=(T, 8))
    sample['b'] = torch.rand(size=(6,))
    sample['c'] = {'d': torch.randint(0, 10, size=(1,))}
    return sample


if __name__ == "__main__":
    # Each implementation receives its own deep copy of the same data so
    # that their results can be compared.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take abs() before max(): otherwise a large negative deviation is
    # hidden whenever any other element matches (max of the raw diff is 0).
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following, and all the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
[{'a': tensor([[0.2036, 0.2045, 0.6618, 0.7467, 0.9715, 0.8493, 0.1803, 0.1353],
        [0.3565, 0.5454, 0.2573, 0.7043, 0.3618, 0.4075, 0.2910, 0.6184],
        [0.6495, 0.0149, 0.2179, 0.4534, 0.0662, 0.8949, 0.7714, 0.0765]]), 'b': tensor([0.9334, 0.9717, 0.6869, 0.6752, 0.8431, 0.4989]), 'c': {'d': tensor([9])}}, {'a': tensor([[0.0783, 0.0599, 0.7010, 0.8613, 0.4005, 0.4793, 0.1911, 0.2949],
        [0.7930, 0.0374, 0.3266, 0.2080, 0.3605, 0.1079, 0.7367, 0.5571],
        [0.2806, 0.3294, 0.0518, 0.6544, 0.1504, 0.3349, 0.0295, 0.1594]]), 'b': tensor([0.4440, 0.2503, 0.3370, 0.5850, 0.0929, 0.7613]), 'c': {'d': tensor([4])}}, {'a': tensor([[0.8803, 0.2766, 0.7276, 0.9788, 0.4364, 0.0951, 0.8465, 0.6436],
        [0.0128, 0.4066, 0.5662, 0.2270, 0.3613, 0.2894, 0.9197, 0.0907],
        [0.8925, 0.5813, 0.4148, 0.1561, 0.9737, 0.9314, 0.9608, 0.0347]]), 'b': tensor([0.3093, 0.2831, 0.4555, 0.2733, 0.2475, 0.3501]), 'c': {'d': tensor([5])}}, {'a': tensor([[0.4850, 0.2218, 0.1663, 0.4499, 0.7313, 0.4146, 0.5337, 0.7840],
        [0.1885, 0.0900, 0.1045, 0.8186, 0.5989, 0.0149, 0.5255, 0.6977],
        [0.7255, 0.8035, 0.7201, 0.3028, 0.8759, 0.7901, 0.9948, 0.7760]]), 'b': tensor([0.5845, 0.2858, 0.6358, 0.8988, 0.7272, 0.6260]), 'c': {'d': tensor([5])}}]



(<Tensor 0x7f85d1f09dc0>
├── 'a' --> tensor([[[0.2036, 0.2045, 0.6618, 0.7467, 0.9715, 0.8493, 0.1803, 0.1353],
│                    [0.3565, 0.5454, 0.2573, 0.7043, 0.3618, 0.4075, 0.2910, 0.6184],
│                    [0.6495, 0.0149, 0.2179, 0.4534, 0.0662, 0.8949, 0.7714, 0.0765]]])
├── 'b' --> tensor([[1.8712, 1.9442, 1.4718, 1.4559, 1.7109, 1.2489]])
└── 'c' --> <Tensor 0x7f85d1f09e20>
    ├── 'd' --> tensor([[9.]])
    └── 'noise' --> tensor([[[[-1.3222, -0.4631, -1.1493,  0.2412,  1.7066],
                              [ 0.6862,  0.1240,  0.0722, -0.3273,  0.6220],
                              [ 1.0520, -0.0453,  0.0042,  0.3125,  1.0273],
                              [ 1.0052, -1.4199,  0.5473, -1.2724, -0.4614]],
                    
                             [[-0.8970,  0.6137,  0.7893,  0.7652,  0.6276],
                              [ 0.3589, -0.9416,  1.9306, -0.3673, -0.6270],
                              [ 1.3039, -0.0854,  0.0807, -0.3006,  0.5846],
                              [ 0.9596, -0.3978, -2.7955,  1.0804,  0.8696]],
                    
                             [[ 0.5054, -0.7677, -0.5725,  0.2164,  1.1939],
                              [ 0.2597,  0.6918,  1.1359,  1.7897, -0.0778],
                              [-0.1170, -1.7965, -1.0321,  0.9932, -1.1103],
                              [-0.7336,  0.0222,  0.7134,  0.1178,  0.2860]]]])
, <Tensor 0x7f85d1f09ee0>
├── 'a' --> tensor([[[0.0783, 0.0599, 0.7010, 0.8613, 0.4005, 0.4793, 0.1911, 0.2949],
│                    [0.7930, 0.0374, 0.3266, 0.2080, 0.3605, 0.1079, 0.7367, 0.5571],
│                    [0.2806, 0.3294, 0.0518, 0.6544, 0.1504, 0.3349, 0.0295, 0.1594]]])
├── 'b' --> tensor([[1.1971, 1.0627, 1.1136, 1.3422, 1.0086, 1.5797]])
└── 'c' --> <Tensor 0x7f85d1f09d30>
    ├── 'd' --> tensor([[4.]])
    └── 'noise' --> tensor([[[[ 2.5355, -1.2246,  1.1552,  1.1878, -0.8264],
                              [ 0.1038, -1.9889, -0.3179, -0.9583, -0.0319],
                              [ 0.8137,  0.9206,  1.5439, -0.4607,  1.3009],
                              [-0.4685, -0.0254,  0.3119, -1.2403, -1.3670]],
                    
                             [[-0.7800,  1.1956,  1.4925, -0.6537,  0.9278],
                              [ 0.8594,  0.5685,  0.9859, -0.7464, -1.3470],
                              [ 1.1382,  0.3734, -0.1874,  0.2087, -0.3723],
                              [-1.9231, -0.6158, -0.9289,  0.3163,  1.1917]],
                    
                             [[-1.1180,  0.0056, -1.4290, -0.8379, -0.4956],
                              [ 1.1919, -1.2769,  0.2825,  0.7284,  0.3196],
                              [ 0.7649, -1.2728, -1.7425, -0.1880, -0.7062],
                              [ 0.4846, -0.7361, -3.2466,  0.7879,  0.4578]]]])
, <Tensor 0x7f85d1f09f40>
├── 'a' --> tensor([[[0.8803, 0.2766, 0.7276, 0.9788, 0.4364, 0.0951, 0.8465, 0.6436],
│                    [0.0128, 0.4066, 0.5662, 0.2270, 0.3613, 0.2894, 0.9197, 0.0907],
│                    [0.8925, 0.5813, 0.4148, 0.1561, 0.9737, 0.9314, 0.9608, 0.0347]]])
├── 'b' --> tensor([[1.0957, 1.0802, 1.2075, 1.0747, 1.0613, 1.1226]])
└── 'c' --> <Tensor 0x7f85d1f09f10>
    ├── 'd' --> tensor([[5.]])
    └── 'noise' --> tensor([[[[ 1.0207,  0.3434, -0.8973,  0.1341, -0.4765],
                              [-0.7483,  0.3544, -1.0095,  1.2414,  0.0488],
                              [ 0.5660, -0.2623,  0.7233,  1.1281, -0.0232],
                              [-1.6810, -0.4228, -1.9214,  0.1472,  0.2862]],
                    
                             [[-0.0974, -1.3113,  0.5264, -2.0024, -0.3385],
                              [-0.5324,  0.9494,  0.5736, -1.0252, -0.5890],
                              [ 1.0508, -1.2013, -0.0814, -0.7333, -0.9893],
                              [-0.1241, -0.3741, -1.6691,  0.6177, -1.5291]],
                    
                             [[-2.3571, -0.5609, -0.9252, -0.5954,  0.1488],
                              [-1.2112, -0.7898,  0.7132, -0.7785, -0.2986],
                              [ 1.8967, -0.5564,  0.8275,  0.2167, -1.2242],
                              [ 0.7123,  1.0096, -0.7093,  0.2377,  0.9861]]]])
, <Tensor 0x7f85d1f09fa0>
├── 'a' --> tensor([[[0.4850, 0.2218, 0.1663, 0.4499, 0.7313, 0.4146, 0.5337, 0.7840],
│                    [0.1885, 0.0900, 0.1045, 0.8186, 0.5989, 0.0149, 0.5255, 0.6977],
│                    [0.7255, 0.8035, 0.7201, 0.3028, 0.8759, 0.7901, 0.9948, 0.7760]]])
├── 'b' --> tensor([[1.3417, 1.0817, 1.4043, 1.8078, 1.5288, 1.3919]])
└── 'c' --> <Tensor 0x7f85d1f09f70>
    ├── 'd' --> tensor([[5.]])
    └── 'noise' --> tensor([[[[ 0.3787, -1.1965, -1.2048, -0.3430,  0.6383],
                              [-0.7527, -1.5821, -1.4492,  1.0375, -0.0591],
                              [-0.1333, -1.5517, -0.5212, -1.8166, -2.1628],
                              [-3.0129, -0.8093, -0.7022, -0.3460,  1.9479]],
                    
                             [[ 1.4103,  0.3771,  1.0118, -0.4478,  0.7042],
                              [ 0.8111,  0.4685,  0.2835,  0.5010, -0.1655],
                              [ 1.2953,  0.0749, -0.5532, -0.5583, -0.0162],
                              [-0.5610,  0.8703, -0.4858,  0.0177,  1.0945]],
                    
                             [[ 0.0955, -1.1814, -1.7059,  1.0074, -1.0350],
                              [ 0.5058,  0.2503, -0.7092, -0.5157, -1.2040],
                              [ 1.0195, -0.7076, -1.3974, -0.0386, -1.8758],
                              [ 0.4644,  0.4292, -0.7586,  0.6937,  0.5348]]]])
)
mean0 & mean1: tensor(1.3419) tensor(1.3419)


even_index_a0: tensor([[[0.2036, 0.2045, 0.6618, 0.7467, 0.9715, 0.8493, 0.1803, 0.1353],
         [0.6495, 0.0149, 0.2179, 0.4534, 0.0662, 0.8949, 0.7714, 0.0765]],

        [[0.0783, 0.0599, 0.7010, 0.8613, 0.4005, 0.4793, 0.1911, 0.2949],
         [0.2806, 0.3294, 0.0518, 0.6544, 0.1504, 0.3349, 0.0295, 0.1594]],

        [[0.8803, 0.2766, 0.7276, 0.9788, 0.4364, 0.0951, 0.8465, 0.6436],
         [0.8925, 0.5813, 0.4148, 0.1561, 0.9737, 0.9314, 0.9608, 0.0347]],

        [[0.4850, 0.2218, 0.1663, 0.4499, 0.7313, 0.4146, 0.5337, 0.7840],
         [0.7255, 0.8035, 0.7201, 0.3028, 0.8759, 0.7901, 0.9948, 0.7760]]])
even_index_a1: tensor([[[0.2036, 0.2045, 0.6618, 0.7467, 0.9715, 0.8493, 0.1803, 0.1353],
         [0.6495, 0.0149, 0.2179, 0.4534, 0.0662, 0.8949, 0.7714, 0.0765]],

        [[0.0783, 0.0599, 0.7010, 0.8613, 0.4005, 0.4793, 0.1911, 0.2949],
         [0.2806, 0.3294, 0.0518, 0.6544, 0.1504, 0.3349, 0.0295, 0.1594]],

        [[0.8803, 0.2766, 0.7276, 0.9788, 0.4364, 0.0951, 0.8465, 0.6436],
         [0.8925, 0.5813, 0.4148, 0.1561, 0.9737, 0.9314, 0.9608, 0.0347]],

        [[0.4850, 0.2218, 0.1663, 0.4499, 0.7313, 0.4146, 0.5337, 0.7840],
         [0.7255, 0.8035, 0.7201, 0.3028, 0.8759, 0.7901, 0.9948, 0.7760]]])
<Size 0x7f8631dfc460>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7f85d1f86460>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.