Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy

import torch

import treetensor.torch as ttorch

# T: leading dimension of each sample's 'a' tensor; B: number of samples in the demo batch.
T, B = 3, 4


def with_nativetensor(batch_):
    """Process a batch of nested dicts field by field with the plain torch API.

    Each sample is expected to be a dict with the keys ``'a'`` and ``'b'``
    (tensors) and ``'c'`` (a nested dict holding the tensor ``'d'``). Any other
    key is reported on stdout and skipped.

    Per-field operations:

    * ``'a'``: cast to float and keep every second row.
    * ``'b'``: cast to float, apply ``x ** 2 + 1`` and take the mean.
    * ``'c'['d']``: cast to float, stored back into the sample.
    * ``'c'['noise']``: a fresh ``torch.randn`` tensor of shape ``(3, 4, 5)``
      is attached to every sample that has a ``'c'`` entry.

    :param batch_: Non-empty list of sample dicts; the samples are modified in place.
    :return: Tuple ``(batch_, mean_b, even_index_a)`` where ``mean_b`` is the
        average of the per-sample means of the transformed ``'b'`` and
        ``even_index_a`` stacks the even-indexed rows of ``'a'`` along a new
        leading dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                even_index_a_list.append(v.float()[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        # Store the converted tensor back; rebinding the loop
                        # variable alone would discard the conversion.
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # The noise lives under the nested 'c' dict. It is attached in a second pass
    # so that the nested dict is never resized while it is being iterated.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Process a batch of nested dicts with the treetensor API.

    The samples are converted to tree tensors and stacked along a new leading
    dimension, so every operation below is applied to the whole tree (or one of
    its fields) at once.

    :param batch_: List of sample dicts with the fields ``'a'``, ``'b'`` and the
        nested field ``'c'``.
    :return: Tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        processed tree split back along dimension 0 into chunks of size 1,
        ``mean_b`` is the mean of the transformed ``b`` field and
        ``even_index_a`` holds every second row of ``a`` for each sample.
    """
    # Take the batch size from the input instead of the module-level constant,
    # so the noise always lines up with the stacked leading dimension.
    batch_size = len(batch_)
    stacked = ttorch.stack([ttorch.tensor(b) for b in batch_])
    stacked = stacked.float()
    stacked.b = ttorch.pow(stacked.b, 2) + 1.0
    stacked.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = stacked.b.mean()
    even_index_a = stacked.a[:, ::2]
    batch_ = ttorch.split(stacked, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample for the demo.

    :return: Dict with ``'a'`` (uniform random tensor of shape ``(T, 8)``),
        ``'b'`` (uniform random tensor of shape ``(6,)``) and ``'c'``, a nested
        dict whose ``'d'`` entry is a single random integer drawn from ``[0, 10)``.
    """
    sample = {}
    sample['a'] = torch.rand(size=(T, 8))
    sample['b'] = torch.rand(size=(6,))
    nested = {'d': torch.randint(0, 10, size=(1,))}
    sample['c'] = nested
    return sample


if __name__ == "__main__":
    # Build one batch and hand an independent deep copy to each implementation,
    # so both start from exactly the same data.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Compare the largest absolute element-wise difference. Taking abs() of the
    # signed maximum would let large negative differences pass unnoticed.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following, and all the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
[{'a': tensor([[0.4529, 0.4845, 0.1233, 0.7727, 0.7163, 0.0018, 0.1644, 0.1536],
        [0.5998, 0.5347, 0.9329, 0.4599, 0.4904, 0.3431, 0.5684, 0.5958],
        [0.6771, 0.6037, 0.8489, 0.3818, 0.6618, 0.1331, 0.0240, 0.8878]]), 'b': tensor([0.2393, 0.0322, 0.8971, 0.4516, 0.1839, 0.5778]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.2874, 0.2397, 0.1155, 0.6106, 0.2545, 0.3252, 0.6645, 0.2612],
        [0.7247, 0.1225, 0.2925, 0.6812, 0.5074, 0.1101, 0.6757, 0.4603],
        [0.1935, 0.7102, 0.5032, 0.3095, 0.1066, 0.0958, 0.3336, 0.3753]]), 'b': tensor([0.5000, 0.8592, 0.3209, 0.7188, 0.3207, 0.7602]), 'c': {'d': tensor([0])}}, {'a': tensor([[0.3435, 0.4422, 0.0218, 0.8240, 0.1337, 0.8707, 0.2374, 0.8759],
        [0.8631, 0.3800, 0.5087, 0.8458, 0.3355, 0.3682, 0.3533, 0.8632],
        [0.2984, 0.6672, 0.9125, 0.8459, 0.6046, 0.6513, 0.1312, 0.7935]]), 'b': tensor([0.7767, 0.8433, 0.1131, 0.4656, 0.8797, 0.6381]), 'c': {'d': tensor([9])}}, {'a': tensor([[0.3697, 0.8184, 0.2459, 0.7395, 0.6239, 0.6514, 0.2199, 0.8016],
        [0.1031, 0.6713, 0.0948, 0.8860, 0.7416, 0.7898, 0.6888, 0.8017],
        [0.2897, 0.4645, 0.3337, 0.8838, 0.4184, 0.5080, 0.7330, 0.3640]]), 'b': tensor([0.4962, 0.1462, 0.7673, 0.3855, 0.2231, 0.2650]), 'c': {'d': tensor([7])}}]



(<Tensor 0x7fedcf41e520>
├── 'a' --> tensor([[[0.4529, 0.4845, 0.1233, 0.7727, 0.7163, 0.0018, 0.1644, 0.1536],
│                    [0.5998, 0.5347, 0.9329, 0.4599, 0.4904, 0.3431, 0.5684, 0.5958],
│                    [0.6771, 0.6037, 0.8489, 0.3818, 0.6618, 0.1331, 0.0240, 0.8878]]])
├── 'b' --> tensor([[1.0573, 1.0010, 1.8047, 1.2039, 1.0338, 1.3338]])
└── 'c' --> <Tensor 0x7fedcf41e580>
    ├── 'd' --> tensor([[3.]])
    └── 'noise' --> tensor([[[[-0.6636, -0.0559, -1.7296,  0.1993, -0.4144],
                              [-0.9026, -0.1666, -1.0286,  0.9501,  0.2277],
                              [ 0.1565,  0.7128,  1.3450,  0.1384, -0.2809],
                              [ 2.3120, -0.6240, -0.0100,  0.0555,  0.1233]],
                    
                             [[ 0.7889, -0.3542,  0.5173, -0.1994, -0.2877],
                              [ 0.5132, -0.2618, -0.2381, -0.5973,  2.0943],
                              [ 2.5564,  0.3201, -0.6498, -0.7667,  0.2432],
                              [-0.3787, -0.5635,  0.7659,  1.1013, -0.3738]],
                    
                             [[-2.0796, -1.2503, -0.8508,  0.4856,  0.3095],
                              [-1.3404,  1.1604, -1.2025,  0.7917, -0.6804],
                              [-1.3208,  0.5036, -1.2873,  0.1188, -0.5987],
                              [ 1.3959, -1.5887, -1.0825,  0.2393, -1.0344]]]])
, <Tensor 0x7fedcf41e640>
├── 'a' --> tensor([[[0.2874, 0.2397, 0.1155, 0.6106, 0.2545, 0.3252, 0.6645, 0.2612],
│                    [0.7247, 0.1225, 0.2925, 0.6812, 0.5074, 0.1101, 0.6757, 0.4603],
│                    [0.1935, 0.7102, 0.5032, 0.3095, 0.1066, 0.0958, 0.3336, 0.3753]]])
├── 'b' --> tensor([[1.2500, 1.7382, 1.1030, 1.5166, 1.1029, 1.5779]])
└── 'c' --> <Tensor 0x7fedcf41e490>
    ├── 'd' --> tensor([[0.]])
    └── 'noise' --> tensor([[[[-0.6272, -0.8048, -0.1083, -2.0494, -0.5725],
                              [ 0.4207,  0.5677, -2.1537,  0.0689, -0.9418],
                              [ 0.3334,  0.0102, -2.2345, -0.4336,  0.6982],
                              [-0.7987, -1.1848, -0.3990,  0.9738,  1.0983]],
                    
                             [[-0.1831, -1.0681,  1.5724,  0.8108, -0.6631],
                              [-0.1058, -0.2816, -0.4718,  1.8463, -2.0032],
                              [-1.1256, -0.7793,  0.2796,  0.9440, -0.1748],
                              [-0.0289, -1.1661, -1.0366,  0.2706, -0.0884]],
                    
                             [[ 0.3446, -1.2585,  0.0536, -1.0360, -2.0640],
                              [-0.2337, -0.2699, -0.0259, -0.5141,  0.7299],
                              [-0.3413, -0.0165,  0.3234,  1.4532,  0.8804],
                              [ 0.6048, -0.4461,  0.2000, -0.8728,  1.8801]]]])
, <Tensor 0x7fedcf41e6a0>
├── 'a' --> tensor([[[0.3435, 0.4422, 0.0218, 0.8240, 0.1337, 0.8707, 0.2374, 0.8759],
│                    [0.8631, 0.3800, 0.5087, 0.8458, 0.3355, 0.3682, 0.3533, 0.8632],
│                    [0.2984, 0.6672, 0.9125, 0.8459, 0.6046, 0.6513, 0.1312, 0.7935]]])
├── 'b' --> tensor([[1.6033, 1.7111, 1.0128, 1.2168, 1.7738, 1.4071]])
└── 'c' --> <Tensor 0x7fedcf41e670>
    ├── 'd' --> tensor([[9.]])
    └── 'noise' --> tensor([[[[ 0.9333,  1.1424, -0.7474, -0.8829,  0.9201],
                              [ 0.2274, -1.8195, -1.6136, -0.2165, -1.0574],
                              [ 1.4707, -1.4699, -0.7836, -0.4376,  1.3080],
                              [-0.9974,  0.2809,  0.4782, -1.5480, -0.7322]],
                    
                             [[-0.1923,  0.1196, -0.6996,  0.5642, -0.3245],
                              [-0.7999,  1.7541,  0.7113, -0.3174,  1.5693],
                              [-0.5462, -1.2245,  0.2500,  0.4545,  0.5209],
                              [ 1.0194,  1.3523, -0.7063,  0.0980,  0.2298]],
                    
                             [[ 0.5091,  1.2884, -2.4875, -0.6895,  1.3439],
                              [-0.1998,  0.0077,  0.9488,  0.1471,  1.2440],
                              [ 1.7300,  1.9349, -0.3384,  0.1602,  1.2618],
                              [-0.0880, -1.4244,  0.1738,  0.2489,  0.2008]]]])
, <Tensor 0x7fedcf41e700>
├── 'a' --> tensor([[[0.3697, 0.8184, 0.2459, 0.7395, 0.6239, 0.6514, 0.2199, 0.8016],
│                    [0.1031, 0.6713, 0.0948, 0.8860, 0.7416, 0.7898, 0.6888, 0.8017],
│                    [0.2897, 0.4645, 0.3337, 0.8838, 0.4184, 0.5080, 0.7330, 0.3640]]])
├── 'b' --> tensor([[1.2462, 1.0214, 1.5887, 1.1486, 1.0498, 1.0702]])
└── 'c' --> <Tensor 0x7fedcf41e6d0>
    ├── 'd' --> tensor([[7.]])
    └── 'noise' --> tensor([[[[ 1.3262,  3.0024, -1.3932, -1.0703,  2.0210],
                              [-0.2846,  1.1873, -0.5566,  0.4753,  0.6220],
                              [ 1.3426,  0.1599,  1.3555, -1.1109,  1.4186],
                              [-1.2408,  1.6457, -1.3557,  0.4189, -1.4630]],
                    
                             [[-1.0050,  0.7207, -0.9922, -0.0952,  0.5409],
                              [-0.7450,  1.3094, -1.5485,  0.3075,  1.0038],
                              [ 0.4184, -0.3876, -0.3017,  0.2980, -1.4328],
                              [-0.2267, -0.9949, -0.4961, -1.3194, -1.5610]],
                    
                             [[-0.6511, -0.1034, -0.9375,  0.9415, -1.5765],
                              [ 0.2072, -0.7496,  0.4781, -1.7320,  1.1552],
                              [-1.0410,  0.3267, -0.0955,  0.3992,  0.1344],
                              [-0.0522,  1.7607, -2.5081, -1.0226, -0.6805]]]])
)
mean0 & mean1: tensor(1.3155) tensor(1.3155)


even_index_a0: tensor([[[0.4529, 0.4845, 0.1233, 0.7727, 0.7163, 0.0018, 0.1644, 0.1536],
         [0.6771, 0.6037, 0.8489, 0.3818, 0.6618, 0.1331, 0.0240, 0.8878]],

        [[0.2874, 0.2397, 0.1155, 0.6106, 0.2545, 0.3252, 0.6645, 0.2612],
         [0.1935, 0.7102, 0.5032, 0.3095, 0.1066, 0.0958, 0.3336, 0.3753]],

        [[0.3435, 0.4422, 0.0218, 0.8240, 0.1337, 0.8707, 0.2374, 0.8759],
         [0.2984, 0.6672, 0.9125, 0.8459, 0.6046, 0.6513, 0.1312, 0.7935]],

        [[0.3697, 0.8184, 0.2459, 0.7395, 0.6239, 0.6514, 0.2199, 0.8016],
         [0.2897, 0.4645, 0.3337, 0.8838, 0.4184, 0.5080, 0.7330, 0.3640]]])
even_index_a1: tensor([[[0.4529, 0.4845, 0.1233, 0.7727, 0.7163, 0.0018, 0.1644, 0.1536],
         [0.6771, 0.6037, 0.8489, 0.3818, 0.6618, 0.1331, 0.0240, 0.8878]],

        [[0.2874, 0.2397, 0.1155, 0.6106, 0.2545, 0.3252, 0.6645, 0.2612],
         [0.1935, 0.7102, 0.5032, 0.3095, 0.1066, 0.0958, 0.3336, 0.3753]],

        [[0.3435, 0.4422, 0.0218, 0.8240, 0.1337, 0.8707, 0.2374, 0.8759],
         [0.2984, 0.6672, 0.9125, 0.8459, 0.6046, 0.6513, 0.1312, 0.7935]],

        [[0.3697, 0.8184, 0.2459, 0.7395, 0.6239, 0.6514, 0.2199, 0.8016],
         [0.2897, 0.4645, 0.3337, 0.8838, 0.4184, 0.5080, 0.7330, 0.3640]]])
<Size 0x7fee2f323460>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7fedcf49d100>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.