Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

import copy

import torch

import treetensor.torch as ttorch

# T: number of rows in each sample's 'a' field (see get_data);
# B: number of samples generated per batch in the demo below.
T, B = 3, 4


def with_nativetensor(batch_):
    """Apply the per-field operations to a list of nested dicts using plain torch.

    Every sample in ``batch_`` is modified in place:

    * ``'a'`` is cast to float and its even-indexed rows (``a[::2]``) are collected;
    * ``'b'`` is cast to float and replaced with ``b ** 2 + 1``;
    * ``'c'['d']`` is cast to float, and a random ``(3, 4, 5)`` tensor is stored
      as ``'c'['noise']``;
    * any other key (top-level or inside ``'c'``) is reported and left untouched.

    Args:
        batch_: list of dicts shaped like the output of ``get_data``.

    Returns:
        tuple: ``(batch_, mean_b, even_index_a)``.

        * ``batch_`` is the same list, updated in place.
        * ``mean_b`` averages the per-sample means of the transformed ``'b'``
          fields.
        * ``even_index_a`` stacks the collected even rows of ``'a'`` along a
          new leading dimension.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                # Write back: rebinding the loop variable alone would not
                # update the sample. Replacing an existing key is safe while
                # iterating the dict.
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                v = torch.pow(v.float(), 2) + 1.0
                sample[k] = v
                mean_b_list.append(v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))
        # The noise lives under 'c'. Add it only after iterating the sample,
        # so the nested dict is not resized mid-iteration.
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Apply the same per-field operations as ``with_nativetensor`` using treetensor.

    Each nested dict is converted to a tree tensor. The samples are stacked
    leaf-wise into one batched tree, every leaf is cast to float, and ``b`` is
    replaced with ``b ** 2 + 1``. A random ``(3, 4, 5)`` noise tensor is added
    per sample under ``c``. The batch is then split back into per-sample trees
    along dim 0.

    Args:
        batch_: list of dicts shaped like the output of ``get_data``.

    Returns:
        tuple: ``(batch_, mean_b, even_index_a)``.

        * ``batch_`` is the tuple of per-sample trees from ``ttorch.split``;
          each leaf keeps a leading dim of size 1.
        * ``mean_b`` is the mean of the batched ``b``.
        * ``even_index_a`` holds the even-indexed rows of every sample's ``a``.
    """
    # Size the noise from the actual batch rather than the module-level ``B``,
    # so the noise leaf always matches the other stacked leaves.
    batch_size = len(batch_)
    batch_ = ttorch.stack([ttorch.tensor(b) for b in batch_])
    batch_ = batch_.float()
    batch_.b = ttorch.pow(batch_.b, 2) + 1.0
    batch_.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Create one random sample as a nested dict of tensors.

    Keys:
        'a': float tensor of shape ``(T, 8)`` drawn from ``torch.rand``.
        'b': float tensor of shape ``(6,)`` drawn from ``torch.rand``.
        'c': nested dict whose 'd' is an integer tensor of shape ``(1,)``
             with values in ``[0, 10)``.
    """
    # Draw in the same order as the fields appear, so seeded runs are reproducible.
    field_a = torch.rand(size=(T, 8))
    field_b = torch.rand(size=(6,))
    field_d = torch.randint(0, 10, size=(1,))
    return {'a': field_a, 'b': field_b, 'c': {'d': field_d}}


if __name__ == "__main__":
    # Build one batch and run both implementations on independent deep copies,
    # so neither sees the other's in-place modifications.
    batch = [get_data() for _ in range(B)]
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take the absolute value element-wise *before* reducing. The
    # abs-of-max form would let large negative differences slip through.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following, and all the assertions should pass.

[{'a': tensor([[0.1018, 0.1229, 0.9646, 0.9868, 0.1387, 0.0339, 0.1518, 0.7598],
        [0.6458, 0.7569, 0.2063, 0.3366, 0.7734, 0.7482, 0.6035, 0.8083],
        [0.9991, 0.6408, 0.8581, 0.5391, 0.9900, 0.7575, 0.9536, 0.2315]]), 'b': tensor([0.6318, 0.3348, 0.1966, 0.4083, 0.5994, 0.5446]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.3198, 0.6870, 0.1427, 0.1506, 0.2779, 0.4413, 0.7380, 0.3396],
        [0.3949, 0.3718, 0.5813, 0.4857, 0.6542, 0.0823, 0.5097, 0.6795],
        [0.1236, 0.3174, 0.1848, 0.9525, 0.8351, 0.1266, 0.8205, 0.8777]]), 'b': tensor([0.3457, 0.1058, 0.0081, 0.0902, 0.2323, 0.8620]), 'c': {'d': tensor([6])}}, {'a': tensor([[0.6110, 0.7065, 0.9817, 0.5466, 0.9681, 0.5271, 0.9148, 0.8300],
        [0.9745, 0.3634, 0.6490, 0.7711, 0.1651, 0.7651, 0.0820, 0.6859],
        [0.0513, 0.8733, 0.0833, 0.0243, 0.4405, 0.7640, 0.9343, 0.4131]]), 'b': tensor([0.6757, 0.6218, 0.2433, 0.3836, 0.1499, 0.0585]), 'c': {'d': tensor([1])}}, {'a': tensor([[0.7970, 0.5926, 0.9834, 0.6448, 0.3721, 0.1004, 0.0332, 0.3718],
        [0.8415, 0.5384, 0.6109, 0.2469, 0.3713, 0.3119, 0.5707, 0.2317],
        [0.3114, 0.2227, 0.2023, 0.6401, 0.1025, 0.0098, 0.7681, 0.4003]]), 'b': tensor([0.9396, 0.3754, 0.5673, 0.0941, 0.7585, 0.3193]), 'c': {'d': tensor([6])}}]



(<Tensor 0x7f4c11e15ac0>
├── 'a' --> tensor([[[0.1018, 0.1229, 0.9646, 0.9868, 0.1387, 0.0339, 0.1518, 0.7598],
│                    [0.6458, 0.7569, 0.2063, 0.3366, 0.7734, 0.7482, 0.6035, 0.8083],
│                    [0.9991, 0.6408, 0.8581, 0.5391, 0.9900, 0.7575, 0.9536, 0.2315]]])
├── 'b' --> tensor([[1.3991, 1.1121, 1.0386, 1.1667, 1.3592, 1.2966]])
└── 'c' --> <Tensor 0x7f4c11e15b20>
    ├── 'd' --> tensor([[1.]])
    └── 'noise' --> tensor([[[[-0.5880,  2.2755,  0.7484,  0.0287,  0.3380],
                              [ 0.4480,  0.3728, -1.5667,  0.9229, -0.0397],
                              [ 1.1858,  0.5950,  1.3761, -0.5896, -0.3640],
                              [ 0.1571, -1.1151,  0.3706, -0.1462, -0.4969]],
                    
                             [[-1.2671,  1.5281,  2.4829, -0.5331,  0.5884],
                              [ 0.3366,  1.9647, -1.0251,  1.3016, -0.5457],
                              [-0.6922, -0.3392,  1.2724, -1.5485,  0.5084],
                              [ 0.6813, -0.4010,  1.5073,  1.5329, -0.5679]],
                    
                             [[ 0.8735, -0.3361, -0.1501, -1.1746,  0.4980],
                              [-1.8489,  1.6243, -1.3073, -1.1780,  0.9063],
                              [ 1.4177,  1.6821,  2.6253, -0.5108,  0.1157],
                              [ 2.6025, -2.1618,  2.0695, -0.3414,  1.9084]]]])
, <Tensor 0x7f4c11e15be0>
├── 'a' --> tensor([[[0.3198, 0.6870, 0.1427, 0.1506, 0.2779, 0.4413, 0.7380, 0.3396],
│                    [0.3949, 0.3718, 0.5813, 0.4857, 0.6542, 0.0823, 0.5097, 0.6795],
│                    [0.1236, 0.3174, 0.1848, 0.9525, 0.8351, 0.1266, 0.8205, 0.8777]]])
├── 'b' --> tensor([[1.1195, 1.0112, 1.0001, 1.0081, 1.0539, 1.7430]])
└── 'c' --> <Tensor 0x7f4c11e15a30>
    ├── 'd' --> tensor([[6.]])
    └── 'noise' --> tensor([[[[ 0.5339, -0.5147, -1.0812, -2.5419,  0.8775],
                              [-1.2760,  0.8684, -0.2687,  1.6961,  1.1588],
                              [-0.6761,  0.1458,  0.5305,  1.5527, -1.1650],
                              [ 1.0163,  0.5603,  0.9785, -0.1793,  0.0314]],
                    
                             [[-0.1252,  1.1299,  0.3169,  1.5040, -0.5238],
                              [ 0.2552,  0.3170,  0.9385, -0.4486, -0.7811],
                              [-0.1173, -0.9089, -2.1491,  0.3983, -1.2886],
                              [ 0.5881, -0.6633, -1.0676,  0.1214,  0.2026]],
                    
                             [[ 0.1844,  0.5130,  0.5463,  0.2661,  1.2480],
                              [ 0.7749, -0.6297,  1.2171,  1.2214, -0.0714],
                              [ 0.0407, -0.1869,  0.6574,  1.1993, -0.3189],
                              [-0.8170,  0.8329, -1.0978,  0.6092,  1.5736]]]])
, <Tensor 0x7f4c11e15c40>
├── 'a' --> tensor([[[0.6110, 0.7065, 0.9817, 0.5466, 0.9681, 0.5271, 0.9148, 0.8300],
│                    [0.9745, 0.3634, 0.6490, 0.7711, 0.1651, 0.7651, 0.0820, 0.6859],
│                    [0.0513, 0.8733, 0.0833, 0.0243, 0.4405, 0.7640, 0.9343, 0.4131]]])
├── 'b' --> tensor([[1.4565, 1.3867, 1.0592, 1.1471, 1.0225, 1.0034]])
└── 'c' --> <Tensor 0x7f4c11e15c10>
    ├── 'd' --> tensor([[1.]])
    └── 'noise' --> tensor([[[[ 0.0383,  0.8180,  0.3472,  1.3140, -0.7884],
                              [-1.1390,  0.3253, -0.4337, -0.9790,  0.3086],
                              [-1.9032,  0.3683, -1.0571, -0.4134,  0.0734],
                              [-0.8027, -3.0336, -0.6686, -0.1992, -0.5842]],
                    
                             [[-1.2057,  1.8317, -2.0265, -0.5646, -0.3632],
                              [-0.3409,  0.8284,  0.4916,  1.0882, -0.4506],
                              [-0.1837,  0.3143, -0.1099, -2.1298,  0.1751],
                              [-0.9796, -1.7407, -0.8685,  1.1067,  1.2986]],
                    
                             [[-1.7898,  0.6996,  1.1994,  0.2322,  0.6311],
                              [-0.7834,  0.9913, -1.0560,  1.2750,  0.3016],
                              [ 1.5431,  0.0910,  0.4831,  0.3254, -0.7582],
                              [ 1.9326,  0.6189,  1.0870, -0.6356,  0.7896]]]])
, <Tensor 0x7f4c11e15ca0>
├── 'a' --> tensor([[[0.7970, 0.5926, 0.9834, 0.6448, 0.3721, 0.1004, 0.0332, 0.3718],
│                    [0.8415, 0.5384, 0.6109, 0.2469, 0.3713, 0.3119, 0.5707, 0.2317],
│                    [0.3114, 0.2227, 0.2023, 0.6401, 0.1025, 0.0098, 0.7681, 0.4003]]])
├── 'b' --> tensor([[1.8828, 1.1409, 1.3218, 1.0089, 1.5753, 1.1019]])
└── 'c' --> <Tensor 0x7f4c11e15c70>
    ├── 'd' --> tensor([[6.]])
    └── 'noise' --> tensor([[[[ 0.0989,  0.0352, -1.6074,  0.6577, -0.6247],
                              [ 1.1357, -0.4504,  1.9912, -0.5223,  0.7360],
                              [-0.1203,  1.2087,  0.0854, -0.3729,  0.8584],
                              [-1.7039,  1.6009, -0.6874,  0.0073, -0.0476]],
                    
                             [[ 0.4003,  0.6010,  0.7245, -0.2537, -2.3287],
                              [ 0.9332, -0.9192, -0.6174, -0.7217,  1.3059],
                              [ 2.2037, -1.1794,  1.8018, -0.3159,  0.0279],
                              [-1.3474, -1.3609,  0.2150,  1.2202, -1.0155]],
                    
                             [[ 0.3434, -0.3591, -1.1696, -0.3053,  1.3126],
                              [-0.8759,  1.0126, -1.4366, -2.1002, -0.2968],
                              [-1.2088,  0.2267,  0.9093, -1.4701, -1.4230],
                              [ 0.3305, -0.8960,  1.1042,  2.1353, -1.1125]]]])
)
mean0 & mean1: tensor(1.2256) tensor(1.2256)


even_index_a0: tensor([[[0.1018, 0.1229, 0.9646, 0.9868, 0.1387, 0.0339, 0.1518, 0.7598],
         [0.9991, 0.6408, 0.8581, 0.5391, 0.9900, 0.7575, 0.9536, 0.2315]],

        [[0.3198, 0.6870, 0.1427, 0.1506, 0.2779, 0.4413, 0.7380, 0.3396],
         [0.1236, 0.3174, 0.1848, 0.9525, 0.8351, 0.1266, 0.8205, 0.8777]],

        [[0.6110, 0.7065, 0.9817, 0.5466, 0.9681, 0.5271, 0.9148, 0.8300],
         [0.0513, 0.8733, 0.0833, 0.0243, 0.4405, 0.7640, 0.9343, 0.4131]],

        [[0.7970, 0.5926, 0.9834, 0.6448, 0.3721, 0.1004, 0.0332, 0.3718],
         [0.3114, 0.2227, 0.2023, 0.6401, 0.1025, 0.0098, 0.7681, 0.4003]]])
even_index_a1: tensor([[[0.1018, 0.1229, 0.9646, 0.9868, 0.1387, 0.0339, 0.1518, 0.7598],
         [0.9991, 0.6408, 0.8581, 0.5391, 0.9900, 0.7575, 0.9536, 0.2315]],

        [[0.3198, 0.6870, 0.1427, 0.1506, 0.2779, 0.4413, 0.7380, 0.3396],
         [0.1236, 0.3174, 0.1848, 0.9525, 0.8351, 0.1266, 0.8205, 0.8777]],

        [[0.6110, 0.7065, 0.9817, 0.5466, 0.9681, 0.5271, 0.9148, 0.8300],
         [0.0513, 0.8733, 0.0833, 0.0243, 0.4405, 0.7640, 0.9343, 0.4131]],

        [[0.7970, 0.5926, 0.9834, 0.6448, 0.3721, 0.1004, 0.0332, 0.3718],
         [0.3114, 0.2227, 0.2023, 0.6401, 0.1025, 0.0098, 0.7681, 0.4003]]])
<Size 0x7f4c71ce3460>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7f4c11e8f460>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.