Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

import copy

import torch

import treetensor.torch as ttorch

# T: number of rows in each sample's 'a' tensor; B: number of samples per batch.
T = 3
B = 4


def with_nativetensor(batch_):
    """Process a batch of nested sample dicts using plain torch operations.

    For every sample in ``batch_`` (a list of dicts as built by ``get_data``):

    * ``'a'`` is cast to float and its even-indexed rows (``[::2]``) are
      collected;
    * ``'b'`` is cast to float and replaced by ``b ** 2 + 1``, whose mean is
      collected;
    * ``'c'['d']`` is cast to float;
    * a ``'noise'`` tensor of shape (3, 4, 5) is added under ``'c'``.

    Any other keys are reported with ``print`` and left untouched.
    The sample dicts are modified in place.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)``. ``mean_b`` is the average
        of the per-sample means of the transformed ``'b'``. ``even_index_a``
        stacks the collected ``'a'`` rows along a new leading batch dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        # Replacing the value of an existing key while iterating is safe:
        # the dict's size does not change.
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                sample[k] = transformed_v
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # Add the noise in a separate pass: inserting a new key into 'c' while
    # iterating over it above would change the dict's size mid-iteration.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Process a batch of nested sample dicts using the treetensor API.

    The samples are stacked into one tree along a new leading dimension and
    cast to float. ``b`` is replaced by ``b ** 2 + 1``, and a ``noise`` leaf
    is added under ``c``. The tree is then split back into per-sample trees,
    each keeping a leading dimension of size 1.

    Returns:
        A tuple ``(batch_, mean_b, even_index_a)``:

        * ``batch_`` is the result of ``ttorch.split``;
        * ``mean_b`` is the mean of the transformed ``b``;
        * ``even_index_a`` is ``a[:, ::2]`` of the stacked tree.
    """
    # Use the actual batch length instead of the module-level constant B, so
    # the noise leaf always matches the leading dimension of the stacked leaves.
    batch_size = len(batch_)
    batch_ = ttorch.stack([ttorch.tensor(b) for b in batch_])
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Create one random sample as a nested dict of tensors.

    Keys:
        'a': uniform [0, 1) tensor of shape (T, 8)
        'b': uniform [0, 1) tensor of shape (6,)
        'c': dict holding 'd', an integer tensor of shape (1,) in [0, 10)
    """
    # Draw in the same order as the keys appear so the RNG sequence is unchanged.
    field_a = torch.rand(size=(T, 8))
    field_b = torch.rand(size=(6,))
    field_d = torch.randint(0, 10, size=(1,))
    return {'a': field_a, 'b': field_b, 'c': {'d': field_d}}


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Give each implementation its own deep copy so that any in-place change
    # one of them makes to the sample dicts cannot affect the other.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Check the largest absolute element-wise difference. abs() of the max
    # would let arbitrarily large negative differences pass.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following, and all the assertions pass.

[{'a': tensor([[0.4572, 0.6080, 0.6765, 0.4616, 0.7006, 0.7138, 0.8332, 0.5279],
        [0.5234, 0.1879, 0.9367, 0.2713, 0.7517, 0.4578, 0.8312, 0.1275],
        [0.5743, 0.4873, 0.6173, 0.8256, 0.9331, 0.5081, 0.7060, 0.8910]]), 'b': tensor([0.9206, 0.9682, 0.2597, 0.5851, 0.8320, 1.0000]), 'c': {'d': tensor([8])}}, {'a': tensor([[0.6489, 0.0320, 0.2625, 0.8833, 0.0545, 0.6756, 0.1604, 0.1900],
        [0.5600, 0.1550, 0.4830, 0.9901, 0.7525, 0.7140, 0.2213, 0.3848],
        [0.9487, 0.0209, 0.2177, 0.6618, 0.7286, 0.7615, 0.0964, 0.4275]]), 'b': tensor([0.2169, 0.8077, 0.7291, 0.2017, 0.9401, 0.4131]), 'c': {'d': tensor([4])}}, {'a': tensor([[0.7510, 0.6268, 0.8116, 0.1843, 0.0712, 0.7653, 0.0770, 0.9667],
        [0.5327, 0.2731, 0.9119, 0.0957, 0.3421, 0.8421, 0.7524, 0.9841],
        [0.4037, 0.5359, 0.0451, 0.8508, 0.2864, 0.8289, 0.6049, 0.8590]]), 'b': tensor([0.9390, 0.0167, 0.3553, 0.2527, 0.6955, 0.0793]), 'c': {'d': tensor([3])}}, {'a': tensor([[0.2746, 0.6944, 0.0830, 0.9218, 0.3689, 0.6639, 0.9689, 0.8173],
        [0.4369, 0.5265, 0.1490, 0.0346, 0.7360, 0.2842, 0.5820, 0.1121],
        [0.8793, 0.7142, 0.8760, 0.7279, 0.4149, 0.9174, 0.4245, 0.6824]]), 'b': tensor([0.1762, 0.6831, 0.7020, 0.2706, 0.3201, 0.9509]), 'c': {'d': tensor([4])}}]



(<Tensor 0x7ffb457152e0>
├── 'a' --> tensor([[[0.4572, 0.6080, 0.6765, 0.4616, 0.7006, 0.7138, 0.8332, 0.5279],
│                    [0.5234, 0.1879, 0.9367, 0.2713, 0.7517, 0.4578, 0.8312, 0.1275],
│                    [0.5743, 0.4873, 0.6173, 0.8256, 0.9331, 0.5081, 0.7060, 0.8910]]])
├── 'b' --> tensor([[1.8476, 1.9375, 1.0674, 1.3423, 1.6922, 2.0000]])
└── 'c' --> <Tensor 0x7ffb45715340>
    ├── 'd' --> tensor([[8.]])
    └── 'noise' --> tensor([[[[-4.6412e-01, -5.7780e-01,  8.0361e-02,  6.8379e-01, -1.9820e+00],
                              [ 1.2988e+00, -1.0754e+00,  2.8268e-01,  9.3715e-01, -3.8230e-01],
                              [ 5.1139e-01,  3.5472e-01,  3.1486e-01, -1.0898e+00, -1.4912e-01],
                              [-7.7645e-01, -9.3090e-01, -7.2468e-01, -6.2890e-01,  5.9210e-01]],
                    
                             [[-1.6167e+00, -4.9676e-02, -1.3844e+00,  1.1189e+00,  5.4293e-01],
                              [-1.2420e+00, -8.8608e-01,  1.8914e+00, -4.5926e-02,  3.0714e-01],
                              [ 4.6076e-01,  2.2135e-01,  4.6556e-01, -8.7983e-04,  7.5775e-01],
                              [ 1.0634e+00, -1.3722e+00,  8.8384e-01, -9.6968e-01,  2.7805e-01]],
                    
                             [[-1.0646e-01,  4.8121e-01,  3.7468e+00, -1.7515e+00,  6.9178e-01],
                              [-1.1562e-01, -2.0747e-01,  8.3824e-01,  8.6832e-02, -7.3484e-01],
                              [-4.9480e-01,  1.5194e+00, -6.6665e-02,  4.9116e-01, -4.9920e-01],
                              [ 6.5903e-02, -1.8747e+00, -1.3733e+00, -1.0089e+00,  6.9826e-01]]]])
, <Tensor 0x7ffb457153a0>
├── 'a' --> tensor([[[0.6489, 0.0320, 0.2625, 0.8833, 0.0545, 0.6756, 0.1604, 0.1900],
│                    [0.5600, 0.1550, 0.4830, 0.9901, 0.7525, 0.7140, 0.2213, 0.3848],
│                    [0.9487, 0.0209, 0.2177, 0.6618, 0.7286, 0.7615, 0.0964, 0.4275]]])
├── 'b' --> tensor([[1.0471, 1.6523, 1.5316, 1.0407, 1.8837, 1.1707]])
└── 'c' --> <Tensor 0x7ffb45715280>
    ├── 'd' --> tensor([[4.]])
    └── 'noise' --> tensor([[[[-0.1838,  0.7267,  0.0637,  0.6531, -1.8037],
                              [ 0.8776, -0.4766, -1.1770, -0.9471, -1.0448],
                              [ 0.2301, -1.2404,  1.1525, -0.3125, -1.1866],
                              [-0.7215,  0.6348,  0.0648,  0.0371,  0.4146]],
                    
                             [[ 0.6134, -0.5340, -1.0677,  0.2898, -0.6210],
                              [ 0.9603,  0.1196,  1.2600, -0.2438, -1.6112],
                              [-0.0975,  0.0962,  1.4514,  0.2641, -0.2844],
                              [-1.5464, -1.6153, -1.5432, -0.3866,  0.6795]],
                    
                             [[-1.0840,  1.0936,  0.7873,  0.7940,  0.1981],
                              [ 0.3649,  1.2445,  0.2280, -1.5792, -0.1313],
                              [-0.6048,  2.2737, -0.9488,  0.1801,  1.1222],
                              [-1.1301,  0.1280,  0.2690, -0.1754,  0.8604]]]])
, <Tensor 0x7ffb45715400>
├── 'a' --> tensor([[[0.7510, 0.6268, 0.8116, 0.1843, 0.0712, 0.7653, 0.0770, 0.9667],
│                    [0.5327, 0.2731, 0.9119, 0.0957, 0.3421, 0.8421, 0.7524, 0.9841],
│                    [0.4037, 0.5359, 0.0451, 0.8508, 0.2864, 0.8289, 0.6049, 0.8590]]])
├── 'b' --> tensor([[1.8817, 1.0003, 1.1263, 1.0638, 1.4837, 1.0063]])
└── 'c' --> <Tensor 0x7ffb457153d0>
    ├── 'd' --> tensor([[3.]])
    └── 'noise' --> tensor([[[[ 0.0319, -1.3346, -0.9664,  0.5136,  1.6280],
                              [-1.7270, -0.3694, -1.0170,  0.9421,  1.7422],
                              [ 0.6287, -0.1087, -0.6610,  0.5958, -1.8940],
                              [-1.0542,  0.0810, -0.6481, -1.7150,  0.5486]],
                    
                             [[ 0.4111,  0.4410,  0.7241, -2.3636, -0.4487],
                              [ 1.4885,  0.3644,  0.6788, -0.9568, -1.0254],
                              [ 0.8320, -0.2660,  0.3278, -0.9613,  0.5035],
                              [-0.1423, -1.0684, -0.1095,  0.3794, -1.1159]],
                    
                             [[ 0.1101, -1.6556, -0.2003, -0.0055, -1.6199],
                              [ 0.2687, -0.8249, -0.1211, -1.0027, -0.1828],
                              [-1.7109,  0.3344,  1.0112,  0.5179,  0.3586],
                              [-1.6385,  0.2950, -0.0268, -1.7857,  0.2965]]]])
, <Tensor 0x7ffb45715460>
├── 'a' --> tensor([[[0.2746, 0.6944, 0.0830, 0.9218, 0.3689, 0.6639, 0.9689, 0.8173],
│                    [0.4369, 0.5265, 0.1490, 0.0346, 0.7360, 0.2842, 0.5820, 0.1121],
│                    [0.8793, 0.7142, 0.8760, 0.7279, 0.4149, 0.9174, 0.4245, 0.6824]]])
├── 'b' --> tensor([[1.0310, 1.4666, 1.4927, 1.0732, 1.1024, 1.9043]])
└── 'c' --> <Tensor 0x7ffb45715430>
    ├── 'd' --> tensor([[4.]])
    └── 'noise' --> tensor([[[[-0.1884, -1.1912, -0.0673, -0.3911,  0.6435],
                              [ 0.8307, -0.6875, -0.6923, -1.0944,  0.8886],
                              [-3.0533,  0.7851, -1.2864, -2.0481, -0.2371],
                              [-0.3512,  0.1200,  0.3074, -0.2534,  0.5491]],
                    
                             [[-0.7389,  1.0999,  0.2123,  1.1953,  0.0311],
                              [-1.4438, -0.7222,  0.3789, -0.2664,  0.5498],
                              [ 0.5244,  0.4530,  1.4953,  0.9694,  1.0271],
                              [ 1.6876, -0.7055,  0.9537,  0.7722,  0.4099]],
                    
                             [[ 0.1919,  1.8940, -0.5216,  0.9805,  0.4321],
                              [ 0.4256,  1.4371,  0.2084, -0.2604, -0.2164],
                              [-0.1987,  0.7131,  2.0972,  0.7461, -0.4851],
                              [-0.2310, -1.3240, -1.4171, -1.7092,  1.2812]]]])
)
mean0 & mean1: tensor(1.4102) tensor(1.4102)


even_index_a0: tensor([[[0.4572, 0.6080, 0.6765, 0.4616, 0.7006, 0.7138, 0.8332, 0.5279],
         [0.5743, 0.4873, 0.6173, 0.8256, 0.9331, 0.5081, 0.7060, 0.8910]],

        [[0.6489, 0.0320, 0.2625, 0.8833, 0.0545, 0.6756, 0.1604, 0.1900],
         [0.9487, 0.0209, 0.2177, 0.6618, 0.7286, 0.7615, 0.0964, 0.4275]],

        [[0.7510, 0.6268, 0.8116, 0.1843, 0.0712, 0.7653, 0.0770, 0.9667],
         [0.4037, 0.5359, 0.0451, 0.8508, 0.2864, 0.8289, 0.6049, 0.8590]],

        [[0.2746, 0.6944, 0.0830, 0.9218, 0.3689, 0.6639, 0.9689, 0.8173],
         [0.8793, 0.7142, 0.8760, 0.7279, 0.4149, 0.9174, 0.4245, 0.6824]]])
even_index_a1: tensor([[[0.4572, 0.6080, 0.6765, 0.4616, 0.7006, 0.7138, 0.8332, 0.5279],
         [0.5743, 0.4873, 0.6173, 0.8256, 0.9331, 0.5081, 0.7060, 0.8910]],

        [[0.6489, 0.0320, 0.2625, 0.8833, 0.0545, 0.6756, 0.1604, 0.1900],
         [0.9487, 0.0209, 0.2177, 0.6618, 0.7286, 0.7615, 0.0964, 0.4275]],

        [[0.7510, 0.6268, 0.8116, 0.1843, 0.0712, 0.7653, 0.0770, 0.9667],
         [0.4037, 0.5359, 0.0451, 0.8508, 0.2864, 0.8289, 0.6049, 0.8590]],

        [[0.2746, 0.6944, 0.0830, 0.9218, 0.3689, 0.6639, 0.9689, 0.8173],
         [0.8793, 0.7142, 0.8760, 0.7279, 0.4149, 0.9174, 0.4245, 0.6824]]])
<Size 0x7ffbac129f40>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7ffb45710ee0>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.