Stack Structured Data

When tensors are organized into tree structures, they often need to be stacked together, just as torch.stack() does for individual tensors in PyTorch.

Stack With Native PyTorch API

Here is a typical implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

# Number of samples to generate and stack; becomes the leading dimension of every stacked leaf.
B = 4


def get_item():
    """Create one synthetic sample as a nested dict.

    Returns:
        A dict with keys, in order:
          - 'obs': dict with 'scalar' (tensor of shape (12,)) and
            'image' (tensor of shape (3, 32, 32)), both from torch.randn;
          - 'action': integer tensor of shape (1,) with values in [0, 10);
          - 'reward': tensor of shape (1,) from torch.rand;
          - 'done': the Python bool False.
    """
    # Draw the random tensors in the same order as the keys appear.
    scalar_feat = torch.randn(12)
    image_feat = torch.randn(3, 32, 32)
    act = torch.randint(0, 10, size=(1,))
    rew = torch.rand(1)

    observation = dict(scalar=scalar_feat, image=image_feat)
    return dict(obs=observation, action=act, reward=rew, done=False)


# Build B independent samples, each a nested dict produced by get_item().
data = [get_item() for _ in range(B)]


# execute `stack` op
def stack(data, dim):
    """Recursively stack a list of identically structured samples.

    The structure of ``data[0]`` drives the recursion:
      - tensors are combined with ``torch.stack`` along ``dim``;
      - dicts are stacked key by key (using the keys of the first element);
      - Python bools are collected into a 1-D ``torch.bool`` tensor
        (``dim`` is not used for this case).

    Args:
        data: non-empty list of samples that all share one structure.
        dim: dimension along which tensor leaves are stacked.

    Returns:
        A tensor, or a dict mirroring the structure of ``data[0]``.

    Raises:
        ValueError: if ``data`` is empty.
        TypeError: if a leaf has an unsupported type.
    """
    if not data:
        # Fail with a clear message instead of an opaque IndexError on data[0].
        raise ValueError("cannot stack an empty sequence")
    elem = data[0]
    if isinstance(elem, torch.Tensor):
        return torch.stack(data, dim)
    elif isinstance(elem, dict):
        return {k: stack([item[k] for item in data], dim) for k in elem}
    elif isinstance(elem, bool):
        # torch.tensor is the recommended factory; torch.BoolTensor is a
        # legacy type-specific constructor.
        return torch.tensor(data, dtype=torch.bool)
    else:
        raise TypeError("not support elem type: {}".format(type(elem)))


# Stack the B samples leaf by leaf along a new leading dimension.
stacked_data = stack(data, dim=0)
# validate: every leaf gains a leading batch dimension of size B
print(stacked_data)
assert stacked_data['obs']['scalar'].shape == (B, 12)
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
# The Python bool leaves are turned into a 1-D bool tensor.
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
{'obs': {'scalar': tensor([[-0.3351, -0.1259,  0.3573,  0.1006, -0.1950,  0.1696, -0.3207, -0.8324,
          0.1065, -0.7528, -0.5869, -0.2768],
        [ 1.2297, -1.5818,  0.0084,  0.3185,  1.1294,  0.6031,  0.5275,  0.8198,
          0.5236,  0.5929, -1.6149,  1.0130],
        [-1.0182,  0.4831, -1.1071,  0.3499, -0.2571, -0.4848, -0.7927,  0.4253,
         -0.2021,  1.0703,  1.1100,  0.3981],
        [ 0.5490,  0.5042,  0.2389,  0.5608,  0.7692, -0.8354, -1.4992,  0.7702,
          0.4011, -0.0440,  0.0415,  1.8720]]), 'image': tensor([[[[ 1.3229, -0.8716, -0.5201,  ..., -1.7171, -0.3480, -0.6178],
          [ 0.5856, -1.9395, -0.1946,  ..., -0.2466, -0.0354, -0.8401],
          [-0.8362, -0.1538,  0.3545,  ..., -0.0742,  0.0822, -1.3594],
          ...,
          [-0.2681,  0.5599,  0.0481,  ...,  0.8415,  0.9783,  0.8391],
          [-0.5444,  0.7178,  1.3271,  ...,  2.1216, -1.0232,  0.6937],
          [ 1.2708, -0.6082,  0.0975,  ...,  0.5321,  0.5657,  0.6502]],

         [[-0.9882, -0.9879,  0.8478,  ...,  0.8742,  0.3866,  0.6720],
          [-0.0458,  0.1999, -1.0143,  ...,  0.1094,  0.8357, -0.5042],
          [-0.6810, -1.2305, -0.1122,  ...,  0.3082, -0.0464,  0.3935],
          ...,
          [ 0.4467,  0.0236,  0.5386,  ..., -0.2159, -0.2341, -0.5941],
          [-0.4505,  0.8014,  0.0963,  ...,  1.2086,  1.4461,  0.1361],
          [ 0.3138,  0.4180,  1.0563,  ..., -1.3622, -0.0071,  0.1185]],

         [[ 0.0205,  1.0267,  0.8875,  ...,  0.0946,  0.0641, -0.3391],
          [-0.2682,  0.3358, -1.5016,  ..., -0.1353,  1.3412,  3.1268],
          [-0.4951, -1.0584,  1.7760,  ...,  0.1599, -2.1287, -1.8060],
          ...,
          [-0.4016,  0.4433,  1.2721,  ..., -1.0459, -0.6383, -2.4636],
          [-1.2933, -1.1499, -0.9965,  ..., -1.4237,  0.5473, -0.2568],
          [-0.0671, -0.4307,  1.9886,  ...,  0.5697,  0.7320, -0.4353]]],


        [[[-0.1503, -1.8366, -0.4838,  ..., -1.1936,  1.0478, -1.7940],
          [ 0.3380,  0.8890,  0.9268,  ...,  0.0724, -0.1561, -0.9094],
          [ 0.1479,  0.8062, -0.3749,  ..., -1.3912, -2.3529,  0.8405],
          ...,
          [ 1.5979,  0.8244,  0.0859,  ...,  0.9944,  0.1761,  0.2597],
          [ 0.0925,  0.2674, -0.7231,  ..., -0.3440,  0.0486, -0.6796],
          [ 0.5608,  0.7240,  1.2998,  ...,  0.2972,  0.7439, -2.1746]],

         [[ 0.0820,  0.9264, -1.2592,  ..., -2.0586, -0.9630,  0.8584],
          [ 1.5254,  1.2367, -0.0197,  ...,  0.9132,  0.8712, -1.5988],
          [-1.7202,  0.3157,  1.3335,  ..., -0.6468, -0.3373,  0.1145],
          ...,
          [-1.2744,  1.4311, -0.3819,  ...,  0.4480,  0.4738,  0.9918],
          [-1.1632,  0.6067, -0.9438,  ..., -0.4843, -1.0377,  1.3053],
          [ 1.7335, -1.1148,  1.5145,  ...,  1.1556,  0.8116, -0.5298]],

         [[-0.4737,  0.9869, -0.4846,  ..., -1.9288,  0.0879, -1.4880],
          [-0.9600,  2.0672,  0.7097,  ..., -0.1599, -0.8630,  1.5076],
          [ 0.3192, -0.4560, -0.8168,  ..., -2.0192,  0.5181, -2.0826],
          ...,
          [ 0.6605, -0.1408, -0.5407,  ...,  3.1973, -0.4022,  0.8287],
          [ 0.3709,  0.0871,  0.8658,  ..., -1.1408, -0.1646, -0.2891],
          [ 0.7675, -1.3801,  1.7737,  ...,  0.3924,  0.5483, -2.2317]]],


        [[[ 0.0230, -1.4104, -1.7528,  ...,  1.2160, -0.8705, -0.2576],
          [ 0.3742,  0.3509, -0.7124,  ..., -0.8648, -0.3988, -0.0461],
          [-0.1556,  0.1907, -0.4975,  ...,  0.6398, -0.5950,  1.1047],
          ...,
          [ 0.1687, -0.4023, -1.5695,  ...,  2.0084,  0.3169,  0.7543],
          [ 0.8956, -1.1573,  0.9436,  ...,  0.5605, -0.2544, -0.2831],
          [-0.9700,  1.1746, -1.6967,  ...,  2.8828,  0.3744, -1.0704]],

         [[ 0.9853,  0.2705, -0.5625,  ..., -2.5664,  0.3964, -0.7360],
          [-2.2696, -0.5592, -0.9178,  ..., -0.8681,  1.2429, -1.4404],
          [-0.2250,  0.4140,  0.1918,  ...,  0.2766, -0.0757, -0.3960],
          ...,
          [ 1.3200, -0.6153, -0.1150,  ...,  1.6429,  0.4845, -0.8107],
          [-0.9539, -0.0348, -1.0374,  ...,  0.1507, -0.9086, -0.6902],
          [ 1.1785,  0.1917, -1.2539,  ...,  0.3257,  1.2774,  0.6315]],

         [[-0.1545,  1.0468,  1.0480,  ..., -0.8229, -0.7238,  0.1174],
          [ 1.6378, -0.6208,  0.0344,  ...,  0.3586,  0.2879, -0.4291],
          [-0.8060, -0.1069, -1.2261,  ...,  1.4295,  0.0420,  1.2709],
          ...,
          [-0.4132, -1.4151,  0.0343,  ...,  0.7979, -0.5116, -1.4211],
          [-0.3901, -0.7760, -1.3162,  ..., -0.5963, -0.1922, -2.1373],
          [-1.1634, -1.4445, -0.9949,  ...,  0.2960, -0.7559, -0.5371]]],


        [[[ 2.0506, -0.2867,  0.5433,  ...,  0.7519,  0.4510,  0.8075],
          [-0.2614, -0.3523, -0.7353,  ...,  0.1183, -1.3313,  0.0712],
          [-1.8622, -2.0756, -0.8768,  ...,  1.2949,  1.2531,  0.4125],
          ...,
          [-0.5050,  1.4302, -0.0805,  ..., -0.4584,  1.2567,  0.8355],
          [ 2.6209,  0.9807,  0.1869,  ..., -1.5986,  1.4966,  0.1247],
          [-0.8346,  0.5572,  1.5136,  ..., -0.2018,  0.9896, -0.5918]],

         [[-0.8786,  1.3171, -0.2521,  ...,  0.7857, -0.5526,  1.9196],
          [ 0.4546, -1.0081, -1.6379,  ...,  2.4751,  0.8986,  0.1311],
          [ 0.6860,  1.8421,  0.5672,  ...,  0.8133, -1.8296,  0.4899],
          ...,
          [-1.4292,  0.7916,  0.4430,  ...,  0.6200, -0.3771,  0.2293],
          [-0.8640,  1.2343, -1.6491,  ..., -0.1046, -0.2318,  0.6356],
          [-2.2208,  1.5240, -1.8812,  ..., -0.7712,  0.3497, -1.3062]],

         [[-1.0601, -0.6546,  0.8033,  ...,  0.3120, -1.2025, -1.4959],
          [ 0.1110,  1.1086,  0.2919,  ...,  1.3319, -0.0159,  0.4849],
          [-1.4752, -0.2678, -1.3847,  ...,  1.2644, -1.9096, -0.1499],
          ...,
          [-2.3981, -1.0582,  1.4940,  ...,  0.6826,  0.2384, -0.5825],
          [ 0.5245,  0.2936, -0.9051,  ..., -1.3163,  0.8272, -0.6536],
          [-0.1354, -0.1840,  0.1083,  ..., -0.3179, -1.4856, -0.6886]]]])}, 'action': tensor([[1],
        [3],
        [1],
        [2]]), 'reward': tensor([[0.9338],
        [0.9579],
        [0.3797],
        [0.9969]]), 'done': tensor([False, False, False, False])}

As we can see, the structure-processing function (such as stack above) has to be implemented entirely by hand. This code is not very clear, and because the supported types are hard-coded, supporting additional data types (such as integers) requires modifying the function specifically for each new type.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown in the code below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

# Number of samples to generate and stack; becomes the leading dimension of every stacked leaf.
B = 4


def get_item():
    """Create one synthetic sample as a nested dict.

    Returns:
        A dict with keys, in order:
          - 'obs': dict with 'scalar' (tensor of shape (12,)) and
            'image' (tensor of shape (3, 32, 32)), both from torch.randn;
          - 'action': integer tensor of shape (1,) with values in [0, 10);
          - 'reward': tensor of shape (1,) from torch.rand;
          - 'done': the Python bool False.
    """
    # Draw the random tensors in the same order as the keys appear.
    scalar_feat = torch.randn(12)
    image_feat = torch.randn(3, 32, 32)
    act = torch.randint(0, 10, size=(1,))
    rew = torch.rand(1)

    observation = dict(scalar=scalar_feat, image=image_feat)
    return dict(obs=observation, action=act, reward=rew, done=False)


# Build B independent samples, each a nested dict produced by get_item().
data = [get_item() for _ in range(B)]
# execute `stack` op
# Wrap each nested dict in a treetensor Tensor, then stack all samples
# leaf by leaf along a new leading dimension.
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate: every leaf gains a leading batch dimension of size B
print(stacked_data)
assert stacked_data.obs.scalar.shape == (B, 12)
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertions pass as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
<Tensor 0x7efd11b6e190>
├── 'action' --> tensor([[9],
│                        [6],
│                        [7],
│                        [4]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7efd11b6e2b0>
│   ├── 'image' --> tensor([[[[-0.5732,  0.8344, -1.0223,  ..., -0.2090,  0.1845, -0.9798],
│   │                         [ 0.3833,  0.4508,  0.9805,  ...,  0.8677, -0.0809, -0.8127],
│   │                         [ 0.2668,  2.2395,  0.0136,  ..., -0.7457, -0.8102, -1.0628],
│   │                         ...,
│   │                         [ 0.4891,  0.5430, -1.1001,  ..., -3.0651,  0.4237,  0.1526],
│   │                         [-0.6681, -0.5715, -1.9152,  ..., -1.1833,  0.8430, -1.3117],
│   │                         [-0.1768, -0.5922,  0.0052,  ...,  0.2508, -0.2034, -0.6893]],
│   │               
│   │                        [[-0.7054, -0.5051, -0.5287,  ...,  0.0790, -0.5077, -0.1399],
│   │                         [ 1.0243, -0.0588,  0.0838,  ..., -0.5730, -0.9745, -0.8683],
│   │                         [ 0.7260, -0.7596, -0.1146,  ..., -1.4850, -1.9563,  2.1057],
│   │                         ...,
│   │                         [-1.0210,  0.5583, -1.1609,  ..., -0.1452,  1.0786,  1.7359],
│   │                         [ 0.4482,  0.7848, -1.1571,  ...,  1.1232, -0.6887,  1.5728],
│   │                         [ 1.5050, -1.8953,  1.0109,  ...,  0.8700,  0.8622,  0.2754]],
│   │               
│   │                        [[-1.1503, -0.7084,  1.0231,  ..., -0.5393,  0.9846, -0.1232],
│   │                         [-1.3168, -1.3793, -0.8136,  ...,  0.1536, -0.2737,  1.6785],
│   │                         [-0.8723,  1.8593, -1.3534,  ..., -0.6998,  0.8589, -0.7367],
│   │                         ...,
│   │                         [ 0.0724,  1.1717, -1.6241,  ...,  1.6922,  0.2225, -0.2648],
│   │                         [-0.3165,  1.0434,  1.8459,  ...,  0.2657, -1.0708,  0.2261],
│   │                         [-0.7817,  1.0338,  0.8447,  ..., -0.1364,  1.8607,  0.0036]]],
│   │               
│   │               
│   │                       [[[ 0.6728, -2.2775,  1.0129,  ..., -0.4786,  0.9160, -1.7867],
│   │                         [-1.2547, -0.2507, -0.1652,  ..., -1.7247, -0.5768,  0.6350],
│   │                         [-1.2193, -1.5167, -0.4552,  ..., -0.4723,  0.9449, -1.1327],
│   │                         ...,
│   │                         [-1.2971, -0.5279,  1.1395,  ..., -0.7726,  2.3904, -0.2529],
│   │                         [ 0.1931,  0.2596, -1.7800,  ...,  0.5067,  0.6260, -0.1827],
│   │                         [ 0.4150,  1.3600,  1.8252,  ...,  0.1649,  1.1977, -0.6525]],
│   │               
│   │                        [[-1.7694,  0.5962, -1.9379,  ...,  0.4922, -0.1481,  0.6446],
│   │                         [ 0.1537,  0.4932,  0.2944,  ..., -3.0711,  0.2793,  0.6164],
│   │                         [ 2.0914, -0.6013, -0.5937,  ...,  1.3267, -0.1042,  0.6393],
│   │                         ...,
│   │                         [-0.6306,  1.5955,  1.2344,  ...,  0.2686, -0.3381, -0.5856],
│   │                         [ 0.0046, -1.2674,  0.6883,  ..., -0.5549, -0.6916,  0.2441],
│   │                         [ 2.6666,  0.4556,  0.2041,  ...,  2.6959,  0.6204,  1.0789]],
│   │               
│   │                        [[-0.6672,  0.3087,  0.0972,  ...,  0.8252, -0.3511,  1.2537],
│   │                         [ 1.2895,  0.9178, -0.4619,  ...,  1.7720,  0.5443,  0.9298],
│   │                         [ 0.4141,  0.8033, -0.3142,  ...,  0.5016, -0.4243,  2.4158],
│   │                         ...,
│   │                         [-2.2795,  0.9444,  0.3314,  ...,  0.1576, -0.9113, -0.3169],
│   │                         [ 1.5382,  1.4105, -0.0833,  ..., -0.7289, -0.7923, -0.4795],
│   │                         [ 0.8599,  0.2141,  0.1497,  ...,  0.4588, -0.5701, -1.1283]]],
│   │               
│   │               
│   │                       [[[-0.2001,  1.9614,  0.1852,  ..., -0.4629,  0.8015, -1.8775],
│   │                         [ 0.3143, -1.9862,  0.3923,  ..., -0.2104,  0.7495, -1.1515],
│   │                         [ 1.0600, -0.6314, -0.0457,  ..., -0.0987, -1.5044, -1.1107],
│   │                         ...,
│   │                         [ 1.2137, -0.1378, -1.0158,  ...,  1.6571, -0.2726,  0.2322],
│   │                         [ 0.0416, -1.1558,  1.0013,  ..., -1.1214,  1.5013,  0.3246],
│   │                         [-2.3091,  1.0148, -1.3923,  ..., -0.0211,  0.7656,  2.6124]],
│   │               
│   │                        [[ 0.7376,  0.2116,  2.1040,  ..., -0.3533, -1.4093,  0.3420],
│   │                         [-0.6422,  0.3422, -0.2772,  ..., -1.3639,  0.5644, -0.9539],
│   │                         [-0.7717, -0.4928,  0.2333,  ..., -0.4962, -1.6556,  1.2008],
│   │                         ...,
│   │                         [ 0.9790, -0.0728, -2.6756,  ...,  0.2060, -1.4609, -0.9001],
│   │                         [ 0.2462, -1.0579, -0.9658,  ...,  0.6712,  0.0615, -0.1638],
│   │                         [ 0.7070,  0.9415, -0.4385,  ..., -0.8157, -1.0987, -0.5419]],
│   │               
│   │                        [[ 0.7663, -1.1431,  0.2525,  ...,  0.1915,  0.1587, -0.6593],
│   │                         [ 0.5508,  0.8743, -0.1899,  ..., -0.2319,  0.6989,  1.7095],
│   │                         [ 1.3306,  1.1043, -0.8372,  ...,  0.2038,  0.6524,  1.2261],
│   │                         ...,
│   │                         [-0.8433,  0.2687, -0.6223,  ...,  0.2569, -1.4367,  0.4794],
│   │                         [ 0.1051,  0.5794,  1.7141,  ..., -0.8876, -0.2867,  0.1891],
│   │                         [-1.3062,  0.8728, -1.5480,  ...,  0.3719, -0.7660, -1.5097]]],
│   │               
│   │               
│   │                       [[[ 0.3837,  0.1884,  1.8520,  ..., -1.1670, -0.2771,  0.0285],
│   │                         [ 2.3019, -0.4867, -0.5563,  ...,  1.4447, -0.7029, -0.5985],
│   │                         [ 1.3024, -0.0033,  0.7293,  ..., -0.8781, -0.4363,  0.9175],
│   │                         ...,
│   │                         [ 1.0268,  1.0299, -0.4278,  ..., -0.9385,  1.5609,  0.9430],
│   │                         [ 0.4735, -0.4662, -0.0858,  ...,  0.8537, -0.2668,  1.8159],
│   │                         [ 0.6604,  0.3323,  2.2560,  ..., -0.3285, -0.6398, -0.0751]],
│   │               
│   │                        [[ 0.1686, -0.0583,  1.4614,  ..., -0.1291, -0.3572, -1.1013],
│   │                         [-0.7564, -0.3031, -0.1124,  ...,  0.2009, -0.1057,  0.0472],
│   │                         [ 0.0804,  0.1531, -0.2332,  ..., -2.6403, -0.8711,  0.2703],
│   │                         ...,
│   │                         [-0.1483,  1.5654,  0.8513,  ...,  0.8361,  1.3659,  1.2603],
│   │                         [-0.7314,  0.4022,  0.8142,  ...,  0.3308, -1.5659,  1.4795],
│   │                         [ 0.4246,  0.7734, -0.4845,  ..., -1.2918,  0.5764, -1.4315]],
│   │               
│   │                        [[ 0.1185,  0.9395, -0.8611,  ..., -1.2157, -0.0884,  0.3314],
│   │                         [-0.4321, -1.8121,  0.3759,  ...,  0.9538, -2.8811, -0.6340],
│   │                         [-1.4063, -0.1336, -0.2102,  ...,  0.9600,  0.1759, -1.2172],
│   │                         ...,
│   │                         [-0.0128,  1.0064, -0.4070,  ...,  0.7015, -0.5696, -1.0145],
│   │                         [-0.8918, -0.5549,  0.3510,  ...,  1.0406, -0.5270,  0.2893],
│   │                         [-0.8304, -1.4494, -1.2213,  ...,  0.3438,  1.0227,  0.6940]]]])
│   └── 'scalar' --> tensor([[ 0.2119,  0.1033,  0.6893,  1.0892,  0.6119, -2.2826,  3.1246,  1.8825,
│                              1.2022, -1.3559,  0.4848,  0.9887],
│                            [-1.1223, -0.1629,  1.8560,  1.1081,  0.2372, -1.6475,  0.7757, -0.1036,
│                              0.5017, -1.5180, -2.0523, -0.2046],
│                            [-0.1956, -0.3936,  1.0658,  0.3773, -0.4860, -2.3401, -0.9749, -0.6971,
│                              0.2317, -0.2357,  0.4269,  1.8999],
│                            [ 2.1771, -0.8896,  1.9440,  0.3170,  0.3320, -0.1950, -0.4934, -1.8298,
│                              0.2721, -0.0442, -0.6164,  0.0381]])
└── 'reward' --> tensor([[0.1355],
                         [0.6663],
                         [0.4162],
                         [0.8796]])

This code looks much simpler and clearer.