Customized Operations For Different Fields

Here is another example of custom operations applied to different fields, implemented with both the native torch API and the treetensor API.

import copy

import torch

import treetensor.torch as ttorch

# Leading dimension of each sample's 'a' tensor (see get_data).
T = 3
# Number of samples generated per batch in the demo.
B = 4


def with_nativetensor(batch_):
    """Apply the per-field operations to a list of nested dicts using plain torch.

    Each sample in ``batch_`` is modified in place:

    * ``'a'`` is cast to float, and its even-indexed rows (``v[::2]``) are
      collected.
    * ``'b'`` is cast to float and replaced by ``b ** 2 + 1``, and the mean of
      the transformed tensor is collected.
    * ``'c'['d']`` is cast to float.
    * ``'c'['noise']`` is set to a freshly sampled ``randn(3, 4, 5)`` tensor.

    Unrecognized keys are reported with ``print`` and left untouched.

    Args:
        batch_: list of dicts with keys ``'a'``, ``'b'`` and ``'c'``, where
            ``'c'`` is itself a dict holding ``'d'``.

    Returns:
        ``(batch_, mean_b, even_index_a)``:

        * ``batch_`` is the same, now-modified list.
        * ``mean_b`` is the average of the per-sample means of the
          transformed ``'b'``.
        * ``even_index_a`` stacks the collected rows of ``'a'`` along a new
          leading dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Reassigning existing keys while iterating is safe: the dict's size
        # does not change.
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                # Store the transformed tensor back; previously it was only
                # used for the mean and then discarded.
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        # Write the cast back; rebinding the loop variable alone
                        # left the integer tensor in place.
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
                # 'noise' belongs under 'c' (the former top-level `k == 'd'`
                # check never matched, so noise was never added).
                v['noise'] = torch.randn(size=(3, 4, 5))
            else:
                print('ignore keys: {}'.format(k))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the per-field operations to a list of nested dicts using treetensor.

    The samples are converted to tree tensors and stacked along a new leading
    batch dimension. Each operation below is then a single call over the
    whole batch:

    * every leaf is cast to float;
    * ``'b'`` is replaced by ``b ** 2 + 1``;
    * ``'c'.noise`` is filled with ``randn`` of shape ``(batch, 3, 4, 5)``.

    Args:
        batch_: list of nested dicts of tensors with matching structure and
            shapes, so they can be stacked.

    Returns:
        ``(batch_, mean_b, even_index_a)``:

        * ``batch_`` is the stacked tree split back into per-sample trees,
          each keeping a leading dimension of size 1.
        * ``mean_b`` is the mean over all transformed ``'b'`` values.
        * ``even_index_a`` holds the even-indexed rows of each sample's ``'a'``.
    """
    # Size the noise from the actual input rather than the module-level B,
    # so the noise leaf always matches the stacked batch dimension.
    batch_size = len(batch_)
    batch_ = ttorch.stack([ttorch.tensor(b) for b in batch_])
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    # Dim 0 is the batch dimension; slice rows within each sample.
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample as a nested dict of tensors.

    Keys:
        ``'a'``: ``torch.rand`` tensor of shape ``(T, 8)``.
        ``'b'``: ``torch.rand`` tensor of shape ``(6,)``.
        ``'c'``: dict whose ``'d'`` is a ``torch.randint`` tensor of shape
        ``(1,)`` with values in ``[0, 10)``.
    """
    # Sample in the same order as the keys appear so the RNG stream is unchanged.
    field_a = torch.rand(size=(T, 8))
    field_b = torch.rand(size=(6,))
    field_d = torch.randint(0, 10, size=(1,))
    return dict(a=field_a, b=field_b, c=dict(d=field_d))


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Give each implementation its own deep copy so neither can affect the
    # other's input.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take the largest *absolute* difference. abs(max(diff)) would let large
    # negative differences slip through whenever some entry's diff is >= 0.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following (the tensor values are random and will differ between runs), and all the assertions should pass.

[{'a': tensor([[0.7766, 0.5611, 0.2536, 0.6818, 0.7486, 0.6041, 0.5483, 0.7180],
        [0.3113, 0.9901, 0.6145, 0.2825, 0.9314, 0.7540, 0.6752, 0.8161],
        [0.8894, 0.4415, 0.1984, 0.4616, 0.3255, 0.2886, 0.1082, 0.0120]]), 'b': tensor([0.3847, 0.7117, 0.6698, 0.3810, 0.6242, 0.7978]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.0242, 0.0353, 0.8126, 0.3531, 0.3268, 0.9762, 0.3606, 0.1387],
        [0.2568, 0.1814, 0.0016, 0.4845, 0.3869, 0.1839, 0.6877, 0.6100],
        [0.8865, 0.5986, 0.5261, 0.2205, 0.6694, 0.3079, 0.6498, 0.8052]]), 'b': tensor([0.4525, 0.2294, 0.6572, 0.5407, 0.7217, 0.8113]), 'c': {'d': tensor([5])}}, {'a': tensor([[0.5047, 0.1707, 0.0935, 0.5198, 0.2858, 0.7896, 0.7454, 0.7980],
        [0.9075, 0.3477, 0.5370, 0.2072, 0.1164, 0.2555, 0.4791, 0.8106],
        [0.1644, 0.4568, 0.5529, 0.6836, 0.5587, 0.8452, 0.2835, 0.3440]]), 'b': tensor([0.6000, 0.5349, 0.8243, 0.1235, 0.2358, 0.3924]), 'c': {'d': tensor([4])}}, {'a': tensor([[0.0312, 0.3731, 0.5947, 0.8834, 0.8944, 0.6921, 0.2843, 0.8457],
        [0.9184, 0.3797, 0.0462, 0.6986, 0.5699, 0.3366, 0.3529, 0.7713],
        [0.6949, 0.3342, 0.0112, 0.2913, 0.7899, 0.0386, 0.2649, 0.3976]]), 'b': tensor([0.4561, 0.7853, 0.4153, 0.6906, 0.7903, 0.4928]), 'c': {'d': tensor([2])}}]



(<Tensor 0x7f91ea905ca0>
├── 'a' --> tensor([[[0.7766, 0.5611, 0.2536, 0.6818, 0.7486, 0.6041, 0.5483, 0.7180],
│                    [0.3113, 0.9901, 0.6145, 0.2825, 0.9314, 0.7540, 0.6752, 0.8161],
│                    [0.8894, 0.4415, 0.1984, 0.4616, 0.3255, 0.2886, 0.1082, 0.0120]]])
├── 'b' --> tensor([[1.1480, 1.5065, 1.4486, 1.1452, 1.3896, 1.6366]])
└── 'c' --> <Tensor 0x7f91ea905d00>
    ├── 'd' --> tensor([[1.]])
    └── 'noise' --> tensor([[[[ 2.5136, -0.1081,  0.8018,  1.5091, -1.0896],
                              [ 2.1947,  0.9641,  0.3989, -0.5170,  0.0855],
                              [ 0.8565, -0.6989,  0.2717,  0.8485, -0.0664],
                              [-0.3098, -0.9206,  0.5248,  0.0264,  1.5740]],
                    
                             [[-1.4064, -0.4015,  0.1006,  0.0507,  1.5793],
                              [ 0.9215,  1.6179, -1.5108,  1.2909, -0.2709],
                              [-0.0231,  0.3169, -0.6971, -0.8785,  0.0584],
                              [ 0.8124,  0.2928,  1.4596,  0.6832, -1.9324]],
                    
                             [[-1.5192, -1.3798, -1.1759,  1.1067, -1.4371],
                              [-1.3846,  0.1428,  0.8906, -0.2528, -0.4610],
                              [ 1.3863,  0.4864,  0.2585,  0.2342, -0.0341],
                              [-0.5841, -0.9200,  0.1922, -0.6996,  0.9480]]]])
, <Tensor 0x7f91ea905dc0>
├── 'a' --> tensor([[[0.0242, 0.0353, 0.8126, 0.3531, 0.3268, 0.9762, 0.3606, 0.1387],
│                    [0.2568, 0.1814, 0.0016, 0.4845, 0.3869, 0.1839, 0.6877, 0.6100],
│                    [0.8865, 0.5986, 0.5261, 0.2205, 0.6694, 0.3079, 0.6498, 0.8052]]])
├── 'b' --> tensor([[1.2048, 1.0526, 1.4320, 1.2923, 1.5209, 1.6582]])
└── 'c' --> <Tensor 0x7f91ea905be0>
    ├── 'd' --> tensor([[5.]])
    └── 'noise' --> tensor([[[[-0.1860,  0.5182,  0.6956,  0.4537, -0.2303],
                              [-0.5841,  2.5433, -0.7179,  0.5835, -0.9927],
                              [ 0.2585,  0.1612,  1.5680,  0.4102, -2.1762],
                              [-1.5873,  0.7842,  0.5519,  0.4897, -0.9798]],
                    
                             [[ 0.2232,  1.1629,  0.8133, -1.4461, -0.0508],
                              [-1.0642, -2.1902, -0.5686, -0.2585,  0.1476],
                              [-1.2493,  0.6766, -1.0951,  0.9376,  1.7562],
                              [-0.0064,  0.4575, -1.1243, -0.0211,  0.2175]],
                    
                             [[-0.0507, -0.0816, -2.4400, -2.5221, -0.2645],
                              [-0.4867,  0.5980, -0.2298, -1.0104, -0.3548],
                              [ 1.5903,  0.2890,  0.5077,  0.7891, -0.8798],
                              [ 1.2358,  1.6742,  0.8530,  0.4171, -1.4960]]]])
, <Tensor 0x7f91ea905e20>
├── 'a' --> tensor([[[0.5047, 0.1707, 0.0935, 0.5198, 0.2858, 0.7896, 0.7454, 0.7980],
│                    [0.9075, 0.3477, 0.5370, 0.2072, 0.1164, 0.2555, 0.4791, 0.8106],
│                    [0.1644, 0.4568, 0.5529, 0.6836, 0.5587, 0.8452, 0.2835, 0.3440]]])
├── 'b' --> tensor([[1.3600, 1.2861, 1.6795, 1.0152, 1.0556, 1.1540]])
└── 'c' --> <Tensor 0x7f91ea905df0>
    ├── 'd' --> tensor([[4.]])
    └── 'noise' --> tensor([[[[-0.2597,  0.5502, -0.3550,  1.2803, -0.4777],
                              [-1.9271, -0.0587,  0.9955, -0.3581, -1.3456],
                              [-0.5616,  0.5569,  0.0172, -0.4264,  0.2685],
                              [-0.7921, -0.3920,  0.7851,  0.7612,  0.3785]],
                    
                             [[-0.2208, -0.9577,  0.4848, -0.4689, -1.2425],
                              [ 0.9249,  1.6592,  0.8347, -1.2518, -0.1103],
                              [-1.1229,  0.4504, -0.1091, -1.8570,  0.5973],
                              [ 0.8977,  0.5524,  0.8264, -1.1807, -0.8822]],
                    
                             [[-1.0476, -0.4988,  1.3481,  0.3397, -1.0965],
                              [ 0.4100, -0.3561,  0.3236, -0.7633,  0.6493],
                              [ 0.4271, -0.2514, -0.5729,  0.6740, -0.5706],
                              [ 0.7090, -0.2594,  0.6588,  2.5359,  0.3798]]]])
, <Tensor 0x7f91ea905e80>
├── 'a' --> tensor([[[0.0312, 0.3731, 0.5947, 0.8834, 0.8944, 0.6921, 0.2843, 0.8457],
│                    [0.9184, 0.3797, 0.0462, 0.6986, 0.5699, 0.3366, 0.3529, 0.7713],
│                    [0.6949, 0.3342, 0.0112, 0.2913, 0.7899, 0.0386, 0.2649, 0.3976]]])
├── 'b' --> tensor([[1.2080, 1.6168, 1.1724, 1.4769, 1.6246, 1.2428]])
└── 'c' --> <Tensor 0x7f91ea905e50>
    ├── 'd' --> tensor([[2.]])
    └── 'noise' --> tensor([[[[-2.2400, -1.6115, -1.4326,  1.6931, -1.3931],
                              [ 1.6587, -0.2218,  1.1462, -0.4680, -0.0203],
                              [ 0.1144, -0.9507,  0.6894,  1.0788,  0.5769],
                              [-0.2663,  0.9609,  0.6088, -0.5087,  0.6183]],
                    
                             [[-0.2387, -0.7638,  0.3636, -0.7947,  0.1206],
                              [ 0.1529,  0.6614, -0.0161, -0.2323, -0.2046],
                              [-0.1206, -0.8588,  1.2333,  1.5618, -0.7350],
                              [ 1.3395, -0.9671, -0.7410,  0.7517, -0.1983]],
                    
                             [[ 1.4318, -0.2031, -1.7052, -0.7655,  0.3964],
                              [-0.1339, -0.4975,  0.3394, -0.6810,  0.6485],
                              [-0.6155, -0.6125,  1.0350,  0.2076, -0.1171],
                              [-0.0900,  1.0732, -0.0580, -1.2796,  0.3661]]]])
)
mean0 & mean1: tensor(1.3470) tensor(1.3470)


even_index_a0: tensor([[[0.7766, 0.5611, 0.2536, 0.6818, 0.7486, 0.6041, 0.5483, 0.7180],
         [0.8894, 0.4415, 0.1984, 0.4616, 0.3255, 0.2886, 0.1082, 0.0120]],

        [[0.0242, 0.0353, 0.8126, 0.3531, 0.3268, 0.9762, 0.3606, 0.1387],
         [0.8865, 0.5986, 0.5261, 0.2205, 0.6694, 0.3079, 0.6498, 0.8052]],

        [[0.5047, 0.1707, 0.0935, 0.5198, 0.2858, 0.7896, 0.7454, 0.7980],
         [0.1644, 0.4568, 0.5529, 0.6836, 0.5587, 0.8452, 0.2835, 0.3440]],

        [[0.0312, 0.3731, 0.5947, 0.8834, 0.8944, 0.6921, 0.2843, 0.8457],
         [0.6949, 0.3342, 0.0112, 0.2913, 0.7899, 0.0386, 0.2649, 0.3976]]])
even_index_a1: tensor([[[0.7766, 0.5611, 0.2536, 0.6818, 0.7486, 0.6041, 0.5483, 0.7180],
         [0.8894, 0.4415, 0.1984, 0.4616, 0.3255, 0.2886, 0.1082, 0.0120]],

        [[0.0242, 0.0353, 0.8126, 0.3531, 0.3268, 0.9762, 0.3606, 0.1387],
         [0.8865, 0.5986, 0.5261, 0.2205, 0.6694, 0.3079, 0.6498, 0.8052]],

        [[0.5047, 0.1707, 0.0935, 0.5198, 0.2858, 0.7896, 0.7454, 0.7980],
         [0.1644, 0.4568, 0.5529, 0.6836, 0.5587, 0.8452, 0.2835, 0.3440]],

        [[0.0312, 0.3731, 0.5947, 0.8834, 0.8944, 0.6921, 0.2843, 0.8457],
         [0.6949, 0.3342, 0.0112, 0.2913, 0.7899, 0.0386, 0.2649, 0.3976]]])
<Size 0x7f92507ee430>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7f91ea9930d0>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.