Customized Operations For Different Fields

Here is another example of custom operations, implemented with both the native torch API and the treetensor API.

import copy

import torch

import treetensor.torch as ttorch

T = 3  # number of rows of each sample's 'a' tensor (see get_data)
B = 4  # number of samples in the demo batch


def with_nativetensor(batch_):
    """Process a batch of nested tensor dicts with the plain torch API.

    Every sample dict in ``batch_`` is modified in place:

    * ``'a'`` is cast to float and its even-indexed rows (``[::2]``) are
      collected;
    * ``'b'`` is cast to float and replaced by ``b ** 2 + 1``;
    * ``'c'['d']`` is cast to float (other keys under ``'c'`` are reported
      with a message and left untouched);
    * a ``torch.randn`` tensor of shape ``(3, 4, 5)`` is stored under
      ``'c'['noise']``.

    Any other top-level key is reported with a message and left untouched.

    Args:
        batch_: list of nested dicts of tensors, e.g. as built by ``get_data``.

    Returns:
        ``(batch_, mean_b, even_index_a)``: the mutated input list, the mean
        of all transformed ``'b'`` elements across the batch (0-d tensor), and
        the even-indexed rows of ``'a'`` stacked along a new leading dimension
        with one entry per sample.
    """
    b_values = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                v = v.float()
                sample[k] = v
                even_index_a_list.append(v[::2])
            elif k == 'b':
                transformed_v = torch.pow(v.float(), 2) + 1.0
                # Write the result back so the returned batch reflects it.
                sample[k] = transformed_v
                b_values.append(transformed_v.reshape(-1))
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        # Store the converted tensor; rebinding the loop
                        # variable alone would discard it.
                        v[k1] = v1.float()
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # Noise goes into the nested 'c' dict in a separate pass: inserting a new
    # key while iterating over that dict above would raise
    # "dictionary changed size during iteration".
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = torch.randn(size=(3, 4, 5))

    # Mean over every transformed 'b' element (not a mean of per-sample
    # means), so samples of different lengths are weighted correctly.
    mean_b = torch.cat(b_values).mean()
    even_index_a = torch.stack(even_index_a_list, dim=0)
    return batch_, mean_b, even_index_a


def with_treetensor(batch_):
    """Process a batch of nested tensor dicts with the treetensor API.

    The samples are converted to tree tensors and stacked along a new leading
    batch dimension. Every leaf is then cast to float, ``b`` is replaced by
    ``b ** 2 + 1``, and a ``c.noise`` tensor from ``ttorch.randn`` with shape
    ``(batch_size, 3, 4, 5)`` is attached.

    Args:
        batch_: list of nested dicts of tensors, e.g. as built by ``get_data``.

    Returns:
        ``(batch_, mean_b, even_index_a)``: the processed batch split back
        into per-sample slices along dim 0 (each keeping a leading dimension
        of size 1), the mean of the transformed ``b`` over the whole batch,
        and the even-indexed rows of ``a`` for every sample.
    """
    # Size the noise from the actual input instead of the module-level ``B``,
    # which only matches when the caller passes exactly B samples.
    batch_size = len(batch_)
    stacked = ttorch.stack([ttorch.tensor(b) for b in batch_])
    stacked = stacked.float()
    stacked.b = ttorch.pow(stacked.b, 2) + 1.0
    stacked.c.noise = ttorch.randn(size=(batch_size, 3, 4, 5))
    mean_b = stacked.b.mean()
    even_index_a = stacked.a[:, ::2]
    batch_ = ttorch.split(stacked, split_size_or_sections=1, dim=0)
    return batch_, mean_b, even_index_a


def get_data():
    """Build one random sample as a nested dict of tensors.

    Keys:
        'a': uniform random floats of shape ``(T, 8)``.
        'b': uniform random floats of shape ``(6,)``.
        'c': a sub-dict whose ``'d'`` entry is a one-element integer tensor
             drawn from ``[0, 10)``.
    """
    # Draw in the same order as the dict layout (a, b, then c.d) so the
    # random stream is consumed identically.
    field_a = torch.rand(size=(T, 8))
    field_b = torch.rand(size=(6,))
    field_d = torch.randint(0, 10, size=(1,))
    sample = {'a': field_a, 'b': field_b, 'c': {'d': field_d}}
    return sample


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]
    # Each implementation gets its own deep copy so in-place changes made by
    # one cannot leak into the other.
    batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
    print(batch0)
    print('\n\n')
    print(batch1)

    assert torch.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # Take the largest *absolute* difference; abs(max(diff)) would let large
    # negative differences slip through the check.
    assert (even_index_a0 - even_index_a1).abs().max() < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    assert len(batch0) == B
    assert len(batch1) == B
    assert isinstance(batch1[0], ttorch.Tensor)
    print(batch1[0].shape)

The output should look like the following (the tensor values are random and will differ between runs), and all the assertions should pass.

[{'a': tensor([[0.8053, 0.7445, 0.0793, 0.3184, 0.2756, 0.8252, 0.3445, 0.0752],
        [0.7182, 0.5773, 0.0218, 0.0600, 0.9528, 0.9006, 0.4050, 0.3537],
        [0.5020, 0.0985, 0.2068, 0.3102, 0.8672, 0.8716, 0.3657, 0.0718]]), 'b': tensor([0.5740, 0.5592, 0.7248, 0.7438, 0.5457, 0.7120]), 'c': {'d': tensor([0])}}, {'a': tensor([[0.5308, 0.7016, 0.6871, 0.0403, 0.9346, 0.9358, 0.0203, 0.9230],
        [0.2047, 0.4180, 0.1184, 0.1707, 0.3309, 0.0488, 0.1547, 0.3050],
        [0.8321, 0.4029, 0.9024, 0.6745, 0.2132, 0.8717, 0.2416, 0.9677]]), 'b': tensor([0.9478, 0.9531, 0.5402, 0.4392, 0.3258, 0.4395]), 'c': {'d': tensor([5])}}, {'a': tensor([[0.7133, 0.4646, 0.2921, 0.7726, 0.6104, 0.5284, 0.2939, 0.3569],
        [0.8337, 0.0498, 0.7647, 0.9300, 0.3821, 0.8607, 0.4000, 0.6861],
        [0.8760, 0.1902, 0.0465, 0.0177, 0.7976, 0.9254, 0.6127, 0.8822]]), 'b': tensor([0.5722, 0.8562, 0.5941, 0.8629, 0.7372, 0.4695]), 'c': {'d': tensor([7])}}, {'a': tensor([[0.1882, 0.1776, 0.0793, 0.3815, 0.4872, 0.0375, 0.6545, 0.2889],
        [0.0571, 0.8010, 0.0199, 0.3678, 0.5668, 0.4459, 0.4310, 0.0160],
        [0.6870, 0.2658, 0.9489, 0.0055, 0.5520, 0.0831, 0.1160, 0.0748]]), 'b': tensor([0.3212, 0.9066, 0.8391, 0.7212, 0.3519, 0.4270]), 'c': {'d': tensor([7])}}]



(<Tensor 0x7f747b6b64f0>
├── 'a' --> tensor([[[0.8053, 0.7445, 0.0793, 0.3184, 0.2756, 0.8252, 0.3445, 0.0752],
│                    [0.7182, 0.5773, 0.0218, 0.0600, 0.9528, 0.9006, 0.4050, 0.3537],
│                    [0.5020, 0.0985, 0.2068, 0.3102, 0.8672, 0.8716, 0.3657, 0.0718]]])
├── 'b' --> tensor([[1.3295, 1.3127, 1.5254, 1.5532, 1.2978, 1.5069]])
└── 'c' --> <Tensor 0x7f747b6b6550>
    ├── 'd' --> tensor([[0.]])
    └── 'noise' --> tensor([[[[-6.8501e-01,  4.6697e-01,  1.7599e+00,  1.2905e-01, -4.8906e-01],
                              [ 6.8946e-01,  1.7629e-01, -1.3696e+00,  1.1340e-03,  8.4019e-01],
                              [-1.2656e-01,  1.7946e+00,  1.2009e+00, -1.5991e+00,  1.8196e-01],
                              [ 1.2336e+00,  2.5105e-01,  6.7169e-02, -1.6914e+00, -3.5321e-01]],
                    
                             [[-4.8059e-01, -9.3635e-01, -1.2018e+00,  1.1593e+00, -3.0129e-01],
                              [ 1.0256e+00, -1.1489e+00,  4.6913e-01, -3.7563e-01, -9.6251e-01],
                              [-2.3017e+00,  1.1598e+00,  1.7594e+00, -2.4179e-01, -8.0267e-01],
                              [-5.4404e-01,  1.0516e+00, -6.2866e-01,  1.4423e+00, -1.6218e-02]],
                    
                             [[ 1.3444e+00, -6.3613e-01,  1.1578e+00, -3.5533e-01, -2.5801e-01],
                              [-8.2113e-01, -1.8916e+00,  9.3748e-01, -2.0547e+00,  1.5202e+00],
                              [-3.1873e-01,  2.6639e-02,  1.6233e+00,  4.8557e-01, -1.6574e+00],
                              [-8.8723e-02, -2.0476e+00,  1.1273e+00, -1.1973e+00, -9.3022e-01]]]])
, <Tensor 0x7f747b6b6610>
├── 'a' --> tensor([[[0.5308, 0.7016, 0.6871, 0.0403, 0.9346, 0.9358, 0.0203, 0.9230],
│                    [0.2047, 0.4180, 0.1184, 0.1707, 0.3309, 0.0488, 0.1547, 0.3050],
│                    [0.8321, 0.4029, 0.9024, 0.6745, 0.2132, 0.8717, 0.2416, 0.9677]]])
├── 'b' --> tensor([[1.8983, 1.9084, 1.2918, 1.1929, 1.1061, 1.1931]])
└── 'c' --> <Tensor 0x7f747b6b6430>
    ├── 'd' --> tensor([[5.]])
    └── 'noise' --> tensor([[[[ 2.9140e+00,  1.5177e+00, -1.6288e+00,  1.9607e-01,  2.1942e-01],
                              [ 2.3884e+00, -5.4665e-01,  1.5765e-01,  2.4808e-01, -1.6871e+00],
                              [ 5.9751e-01,  1.9206e-01,  1.0923e+00, -7.0509e-01, -2.5146e-01],
                              [ 8.7729e-01,  1.3504e+00, -7.5557e-01, -1.3701e-01, -5.8436e-01]],
                    
                             [[ 5.3209e-01, -1.3542e+00,  1.5011e+00, -8.6212e-01,  1.7817e+00],
                              [ 7.2772e-01, -1.6131e+00, -1.0709e+00, -9.3635e-01,  4.2625e-01],
                              [ 1.2432e+00,  1.5236e+00, -4.9223e-01,  8.4649e-01,  1.4618e+00],
                              [ 4.6363e-01,  6.8346e-01, -8.6416e-01,  1.7699e-01,  1.6948e+00]],
                    
                             [[ 5.6625e-01,  9.9995e-01,  1.0111e-01,  4.8011e-04,  9.1063e-01],
                              [-4.7770e-01,  1.7299e+00,  9.1445e-01, -2.8423e+00, -4.2896e-01],
                              [ 1.0115e+00, -6.0075e-01,  9.9652e-02,  9.1871e-01,  9.7710e-01],
                              [-6.0919e-02, -1.8909e-01,  9.6011e-01, -1.5709e-01,  1.3108e-01]]]])
, <Tensor 0x7f747b6b6670>
├── 'a' --> tensor([[[0.7133, 0.4646, 0.2921, 0.7726, 0.6104, 0.5284, 0.2939, 0.3569],
│                    [0.8337, 0.0498, 0.7647, 0.9300, 0.3821, 0.8607, 0.4000, 0.6861],
│                    [0.8760, 0.1902, 0.0465, 0.0177, 0.7976, 0.9254, 0.6127, 0.8822]]])
├── 'b' --> tensor([[1.3274, 1.7331, 1.3530, 1.7447, 1.5434, 1.2204]])
└── 'c' --> <Tensor 0x7f747b6b6640>
    ├── 'd' --> tensor([[7.]])
    └── 'noise' --> tensor([[[[-0.3763, -0.7753, -0.4001,  0.0321,  0.2464],
                              [-0.7008,  0.1068, -0.1251,  2.1447,  0.1137],
                              [ 1.2222, -0.0875, -1.0401, -0.3119,  1.0039],
                              [-0.4405,  0.6470,  0.9651, -0.2678, -0.9768]],
                    
                             [[-0.7289,  0.9036,  0.0264, -0.0244,  0.0816],
                              [-0.7036, -0.0935,  0.7025, -1.4874,  0.4035],
                              [ 0.6441,  0.4094, -0.3827,  0.0542, -1.2560],
                              [-1.2502,  0.4918,  0.4921,  0.9951, -1.8664]],
                    
                             [[-1.5643, -1.6597,  0.0849,  1.6072, -2.0404],
                              [-0.0367, -0.0865,  0.7401,  1.2236,  0.9519],
                              [-0.5182, -0.0128, -0.1113,  0.6622,  0.3690],
                              [-0.0197, -0.1582, -1.0284, -1.7841,  0.0648]]]])
, <Tensor 0x7f747b6b66d0>
├── 'a' --> tensor([[[0.1882, 0.1776, 0.0793, 0.3815, 0.4872, 0.0375, 0.6545, 0.2889],
│                    [0.0571, 0.8010, 0.0199, 0.3678, 0.5668, 0.4459, 0.4310, 0.0160],
│                    [0.6870, 0.2658, 0.9489, 0.0055, 0.5520, 0.0831, 0.1160, 0.0748]]])
├── 'b' --> tensor([[1.1032, 1.8219, 1.7040, 1.5201, 1.1238, 1.1823]])
└── 'c' --> <Tensor 0x7f747b6b66a0>
    ├── 'd' --> tensor([[7.]])
    └── 'noise' --> tensor([[[[ 1.0370,  1.3408, -0.4759,  0.3349,  0.9377],
                              [ 0.1157, -0.8470, -0.0512, -0.3285, -0.6335],
                              [ 0.6802,  0.6821, -0.2242,  0.5975, -0.5512],
                              [ 1.8505, -0.0149,  0.4083, -0.9496, -1.5367]],
                    
                             [[-1.0895,  0.5305, -0.4091,  0.1896, -0.3201],
                              [-1.7025,  1.3947, -0.5616, -0.7856,  1.7454],
                              [-1.3787, -1.3461, -0.0710,  0.9998, -1.4652],
                              [-0.6107,  1.1829,  0.9042,  1.1213, -0.4547]],
                    
                             [[-2.7062, -0.9415, -0.3098,  0.5568,  0.2289],
                              [ 0.7780, -0.8436,  0.0879, -1.7352, -0.1688],
                              [-1.3940, -0.8162,  1.5604, -1.2692,  1.2674],
                              [ 1.1491,  1.8704, -0.5248, -1.1473, -0.4696]]]])
)
mean0 & mean1: tensor(1.4372) tensor(1.4372)


even_index_a0: tensor([[[0.8053, 0.7445, 0.0793, 0.3184, 0.2756, 0.8252, 0.3445, 0.0752],
         [0.5020, 0.0985, 0.2068, 0.3102, 0.8672, 0.8716, 0.3657, 0.0718]],

        [[0.5308, 0.7016, 0.6871, 0.0403, 0.9346, 0.9358, 0.0203, 0.9230],
         [0.8321, 0.4029, 0.9024, 0.6745, 0.2132, 0.8717, 0.2416, 0.9677]],

        [[0.7133, 0.4646, 0.2921, 0.7726, 0.6104, 0.5284, 0.2939, 0.3569],
         [0.8760, 0.1902, 0.0465, 0.0177, 0.7976, 0.9254, 0.6127, 0.8822]],

        [[0.1882, 0.1776, 0.0793, 0.3815, 0.4872, 0.0375, 0.6545, 0.2889],
         [0.6870, 0.2658, 0.9489, 0.0055, 0.5520, 0.0831, 0.1160, 0.0748]]])
even_index_a1: tensor([[[0.8053, 0.7445, 0.0793, 0.3184, 0.2756, 0.8252, 0.3445, 0.0752],
         [0.5020, 0.0985, 0.2068, 0.3102, 0.8672, 0.8716, 0.3657, 0.0718]],

        [[0.5308, 0.7016, 0.6871, 0.0403, 0.9346, 0.9358, 0.0203, 0.9230],
         [0.8321, 0.4029, 0.9024, 0.6745, 0.2132, 0.8717, 0.2416, 0.9677]],

        [[0.7133, 0.4646, 0.2921, 0.7726, 0.6104, 0.5284, 0.2939, 0.3569],
         [0.8760, 0.1902, 0.0465, 0.0177, 0.7976, 0.9254, 0.6127, 0.8822]],

        [[0.1882, 0.1776, 0.0793, 0.3815, 0.4872, 0.0375, 0.6545, 0.2889],
         [0.6870, 0.2658, 0.9489, 0.0055, 0.5520, 0.0831, 0.1160, 0.0748]]])
<Size 0x7f74e7cee3d0>
├── 'a' --> torch.Size([1, 3, 8])
├── 'b' --> torch.Size([1, 6])
└── 'c' --> <Size 0x7f747b7318b0>
    ├── 'd' --> torch.Size([1, 1])
    └── 'noise' --> torch.Size([1, 3, 4, 5])

The implementation with the treetensor API is much simpler and clearer.