Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy

import torch

import treetensor.torch as ttorch

# T: size of the leading dimension of each sample's 'a' tensor (see get_data).
# B: number of samples in the demo batch.
T, B = 3, 4


def with_nativetensor(batch_):
    """Apply the per-field custom operations using only the native torch API.

    For every sample dict in ``batch_``:

    * ``'a'``: converted to float and every second entry along the first
      dimension (indices 0, 2, ...) is collected.
    * ``'b'``: converted to float, transformed as ``b ** 2 + 1`` and reduced
      to its mean.
    * ``'c'``: a nested dict; any nested key other than ``'d'`` is reported
      as ignored. A fresh ``(3, 4, 5)`` gaussian noise tensor is stored in
      place under ``sample['c']['noise']``.
    * any other top-level key is reported as ignored.

    :param batch_: Non-empty list of sample dicts. The nested ``'c'`` dicts
        are mutated in place (the ``'noise'`` entry is added).
    :return: Tuple ``(batch_, mean_b, even_index_a)`` where ``mean_b`` is the
        average of the per-sample ``'b'`` means and ``even_index_a`` is the
        stack (along a new leading dimension) of the even-indexed ``'a'`` slices.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                even_index_a_list.append(v.float()[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                # Only 'd' is a known nested field; report anything else.
                for k1 in v:
                    if k1 != 'd':
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # The noise belongs to the nested 'c' field. It is attached in a separate
    # pass so the new key is neither reported as ignored nor inserted while the
    # dict is being iterated.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the per-field custom operations using the treetensor API.

    The sample dicts are converted to tree tensors and stacked into a single
    batched tree, so each operation is written once for the whole batch
    instead of once per sample and per key.

    :param batch_: List of sample dicts with fields ``'a'``, ``'b'`` and a
        nested ``'c'`` (as produced by ``get_data``).
    :return: Tuple ``(batch_, mean_b, even_index_a)`` where ``batch_`` is the
        processed batch split back into per-sample trees along dim 0,
        ``mean_b`` is the mean of the transformed ``'b'`` field and
        ``even_index_a`` holds the even-indexed slices of ``'a'`` along dim 1.
    """
    # Size the noise from the actual input instead of a module-level constant,
    # so batches of any length keep a consistent leading dimension.
    batch_size = len(batch_)
    batch_ = [ttorch.tensor(b) for b in batch_]
    batch_ = ttorch.stack(batch_)
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    # Dim 0 is the stacked batch dimension, so the per-sample [::2] moves to dim 1.
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Create one random sample with fields ``'a'``, ``'b'`` and nested ``'c'``.

    :return: Dict with ``'a'`` of shape ``(T, 8)`` and ``'b'`` of shape ``(6,)``
        drawn by ``torch.rand``, and ``'c'`` holding ``'d'``, a single random
        integer in ``[0, 10)`` drawn by ``torch.randint``.
    """
    # Tensors are drawn in the same order as the fields appear in the result.
    field_a = torch.rand(size=(T, 8))
    field_b = torch.rand(size=(6,))
    field_d = torch.randint(0, 10, size=(1,))
    nested = {'d': field_d}
    return {'a': field_a, 'b': field_b, 'c': nested}


if __name__ == "__main__":
    # Run both implementations on independent copies of the same random batch,
    # since each of them may modify its input.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take the absolute value element-wise BEFORE the max; otherwise a large
    # negative deviation would be hidden behind a small positive maximum.
    assert torch.abs(even_index_a0 - even_index_a1).max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following, and all the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
[{'a': tensor([[0.8586, 0.8745, 0.2379, 0.5585, 0.4995, 0.4912, 0.6396, 0.3216],
        [0.8695, 0.2370, 0.3972, 0.8458, 0.5265, 0.6883, 0.8920, 0.2215],
        [0.3978, 0.3845, 0.2481, 0.6023, 0.3278, 0.4616, 0.1159, 0.6446]]), 'b': tensor([0.0108, 0.5429, 0.9307, 0.0859, 0.9104, 0.6995]), 'c': {'d': tensor([6])}}, {'a': tensor([[0.3686, 0.3797, 0.9587, 0.0678, 0.8838, 0.7381, 0.0863, 0.2462],
        [0.5136, 0.8755, 0.4092, 0.7654, 0.2754, 0.3902, 0.9426, 0.3342],
        [0.4133, 0.3055, 0.6118, 0.5253, 0.6120, 0.2868, 0.4473, 0.9701]]), 'b': tensor([0.1724, 0.0885, 0.7045, 0.5513, 0.8351, 0.5251]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.1634, 0.9050, 0.1383, 0.4304, 0.6519, 0.7408, 0.8242, 0.4771],
        [0.5757, 0.9322, 0.4528, 0.7131, 0.8326, 0.8706, 0.0143, 0.4793],
        [0.6991, 0.2571, 0.7627, 0.8777, 0.9678, 0.5340, 0.9114, 0.8534]]), 'b': tensor([0.1379, 0.1931, 0.7572, 0.1832, 0.4489, 0.7069]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.5449, 0.3750, 0.1559, 0.1598, 0.8304, 0.5090, 0.8535, 0.4835],
        [0.7706, 0.3237, 0.6379, 0.6963, 0.8241, 0.8523, 0.1088, 0.6494],
        [0.4307, 0.0856, 0.3791, 0.7000, 0.2590, 0.6379, 0.9880, 0.7518]]), 'b': tensor([0.8498, 0.5255, 0.4614, 0.1046, 0.8997, 0.7453]), 'c': {'d': tensor([5])}}]



(<Tensor 0x7f4ed5b9f2e0>
├── 'a' --> tensor([[[0.8586, 0.8745, 0.2379, 0.5585, 0.4995, 0.4912, 0.6396, 0.3216],
│                    [0.8695, 0.2370, 0.3972, 0.8458, 0.5265, 0.6883, 0.8920, 0.2215],
│                    [0.3978, 0.3845, 0.2481, 0.6023, 0.3278, 0.4616, 0.1159, 0.6446]]])
├── 'b' --> tensor([[1.0001, 1.2947, 1.8663, 1.0074, 1.8288, 1.4892]])
└── 'c' --> <Tensor 0x7f4ed5b9f340>
    ├── 'd' --> tensor([[6.]])
    └── 'noise' --> tensor([[[[ 1.3119,  1.0073, -0.0826, -0.0412, -1.2066],
                              [ 1.1487, -0.9194, -1.0484,  1.0572, -0.3298],
                              [ 1.0939,  1.0182,  0.5228,  1.0135, -0.0601],
                              [ 0.6180, -0.4755, -1.2756, -0.3167, -0.5141]],
                    
                             [[ 1.4911,  0.2002, -0.2504,  0.1903,  0.0089],
                              [-1.3901,  0.5571, -0.1145,  1.1175, -0.5574],
                              [-0.2847, -1.4312, -1.4599, -0.0578,  0.7877],
                              [ 0.1002,  0.2290,  0.8220, -0.1056,  1.4887]],
                    
                             [[-0.0857,  1.6059, -2.1704, -0.4478,  0.7624],
                              [ 1.5118,  0.8130,  0.0370, -0.6705,  2.8412],
                              [-1.1921, -1.5076,  0.3315,  1.5197,  0.0273],
                              [ 0.6083, -2.3666, -0.6193, -0.1556,  0.7426]]]])
, <Tensor 0x7f4ed5b9f3a0>
├── 'a' --> tensor([[[0.3686, 0.3797, 0.9587, 0.0678, 0.8838, 0.7381, 0.0863, 0.2462],
│                    [0.5136, 0.8755, 0.4092, 0.7654, 0.2754, 0.3902, 0.9426, 0.3342],
│                    [0.4133, 0.3055, 0.6118, 0.5253, 0.6120, 0.2868, 0.4473, 0.9701]]])
├── 'b' --> tensor([[1.0297, 1.0078, 1.4964, 1.3040, 1.6973, 1.2758]])
└── 'c' --> <Tensor 0x7f4ed5b9f280>
    ├── 'd' --> tensor([[3.]])
    └── 'noise' --> tensor([[[[-0.6261,  1.0927, -0.8166, -1.0715,  0.3092],
                              [ 1.8608, -0.4651, -1.7187, -0.8299,  0.4223],
                              [-1.4070,  1.3266, -1.7648,  0.2992, -0.2652],
                              [ 0.7462,  0.1425, -0.2780, -0.7319,  0.1484]],
                    
                             [[-0.4234,  0.2794, -0.4522,  0.0267,  0.4890],
                              [ 1.3390,  0.1871, -0.9754,  0.9092, -0.5868],
                              [-0.4940,  0.1136, -2.5127,  0.5628, -0.8442],
                              [-0.2515,  1.7598, -0.3733,  0.0215, -0.0501]],
                    
                             [[-1.5443,  1.0383,  0.3005,  1.2338,  1.5016],
                              [-0.8317, -0.9826, -0.4229,  0.8076, -0.3408],
                              [-2.7238,  0.8420,  0.2547, -0.9996, -1.1935],
                              [-0.0824, -0.2600, -0.2532,  0.3305, -0.5450]]]])
, <Tensor 0x7f4ed5b9f400>
├── 'a' --> tensor([[[0.1634, 0.9050, 0.1383, 0.4304, 0.6519, 0.7408, 0.8242, 0.4771],
│                    [0.5757, 0.9322, 0.4528, 0.7131, 0.8326, 0.8706, 0.0143, 0.4793],
│                    [0.6991, 0.2571, 0.7627, 0.8777, 0.9678, 0.5340, 0.9114, 0.8534]]])
├── 'b' --> tensor([[1.0190, 1.0373, 1.5733, 1.0336, 1.2015, 1.4997]])
└── 'c' --> <Tensor 0x7f4ed5b9f3d0>
    ├── 'd' --> tensor([[1.]])
    └── 'noise' --> tensor([[[[-0.7697, -0.6644, -0.6809,  0.5710,  0.8831],
                              [ 0.5016, -1.6522,  0.6898,  0.9388, -2.8419],
                              [ 1.0921, -1.8615,  0.0894,  1.0522,  0.0959],
                              [ 0.1511, -0.5223,  0.0411, -0.3828, -0.8376]],
                    
                             [[ 0.2205, -0.8009,  2.1709,  1.3212,  1.9122],
                              [ 0.1257,  1.5379, -0.6085,  0.6063,  0.4691],
                              [ 0.9203,  0.6534, -0.6768, -0.8586,  0.1047],
                              [ 0.0038, -2.3927, -2.6818,  1.6763,  1.0318]],
                    
                             [[-0.3389, -1.0578, -0.3507,  0.4921,  0.0679],
                              [ 1.2042, -1.0514, -2.4790,  1.0759, -0.7012],
                              [ 0.3001, -2.3341,  0.5223, -0.7165, -0.3541],
                              [-0.0838,  0.2886,  0.9576, -0.5891,  0.4655]]]])
, <Tensor 0x7f4ed5b9f460>
├── 'a' --> tensor([[[0.5449, 0.3750, 0.1559, 0.1598, 0.8304, 0.5090, 0.8535, 0.4835],
│                    [0.7706, 0.3237, 0.6379, 0.6963, 0.8241, 0.8523, 0.1088, 0.6494],
│                    [0.4307, 0.0856, 0.3791, 0.7000, 0.2590, 0.6379, 0.9880, 0.7518]]])
├── 'b' --> tensor([[1.7221, 1.2761, 1.2129, 1.0109, 1.8095, 1.5554]])
└── 'c' --> <Tensor 0x7f4ed5b9f430>
    ├── 'd' --> tensor([[5.]])
    └── 'noise' --> tensor([[[[-0.1033, -0.5404, -0.5412,  0.6728, -0.0648],
                              [-1.2375,  1.0447, -0.0787, -1.6941,  0.6986],
                              [ 1.5592, -0.3467, -0.4237, -0.0275, -0.3738],
                              [-1.5760,  0.7306, -0.0981, -0.3530,  1.3364]],
                    
                             [[-1.0773, -2.0509, -0.0376, -1.1306, -0.2804],
                              [-1.6774, -1.1321, -0.0292, -0.8410, -0.0336],
                              [ 0.2438, -0.1650,  1.4451,  0.7929, -2.7484],
                              [ 1.3637,  0.7989,  1.5782,  0.9569,  0.3596]],
                    
                             [[ 0.1463,  1.9744,  0.2162, -1.7828, -0.3763],
                              [-0.7063,  0.0489,  1.0735, -0.5237,  1.1289],
                              [-0.8941,  0.8918,  1.3383,  0.0798,  0.6454],
                              [-0.4515,  0.3357, -0.5779,  0.8850,  1.2229]]]])
)
mean0 & mean1: tensor(1.3437) tensor(1.3437)


even_index_a0: tensor([[[0.8586, 0.8745, 0.2379, 0.5585, 0.4995, 0.4912, 0.6396, 0.3216],
         [0.3978, 0.3845, 0.2481, 0.6023, 0.3278, 0.4616, 0.1159, 0.6446]],

        [[0.3686, 0.3797, 0.9587, 0.0678, 0.8838, 0.7381, 0.0863, 0.2462],
         [0.4133, 0.3055, 0.6118, 0.5253, 0.6120, 0.2868, 0.4473, 0.9701]],

        [[0.1634, 0.9050, 0.1383, 0.4304, 0.6519, 0.7408, 0.8242, 0.4771],
         [0.6991, 0.2571, 0.7627, 0.8777, 0.9678, 0.5340, 0.9114, 0.8534]],

        [[0.5449, 0.3750, 0.1559, 0.1598, 0.8304, 0.5090, 0.8535, 0.4835],
         [0.4307, 0.0856, 0.3791, 0.7000, 0.2590, 0.6379, 0.9880, 0.7518]]])
even_index_a1: tensor([[[0.8586, 0.8745, 0.2379, 0.5585, 0.4995, 0.4912, 0.6396, 0.3216],
         [0.3978, 0.3845, 0.2481, 0.6023, 0.3278, 0.4616, 0.1159, 0.6446]],

        [[0.3686, 0.3797, 0.9587, 0.0678, 0.8838, 0.7381, 0.0863, 0.2462],
         [0.4133, 0.3055, 0.6118, 0.5253, 0.6120, 0.2868, 0.4473, 0.9701]],

        [[0.1634, 0.9050, 0.1383, 0.4304, 0.6519, 0.7408, 0.8242, 0.4771],
         [0.6991, 0.2571, 0.7627, 0.8777, 0.9678, 0.5340, 0.9114, 0.8534]],

        [[0.5449, 0.3750, 0.1559, 0.1598, 0.8304, 0.5090, 0.8535, 0.4835],
         [0.4307, 0.0856, 0.3791, 0.7000, 0.2590, 0.6379, 0.9880, 0.7518]]])
<Size 0x7f4f3c9157c0>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7f4ed5b9aee0>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.