Stack Structured Data

When tensors are organized in a tree structure, they often need to be stacked together, just as torch.stack() does for plain tensors.

Stack With Native PyTorch API

Here is a typical implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

# Batch size: the number of items generated and stacked together.
B = 4


def get_item():
    """Create one randomly generated, nested sample.

    Returns:
        A dict with the keys ``'obs'``, ``'action'``, ``'reward'`` and
        ``'done'``. ``'obs'`` is itself a dict holding a ``'scalar'`` tensor
        of shape ``(12,)`` and an ``'image'`` tensor of shape ``(3, 32, 32)``.
        ``'action'`` is an integer tensor of shape ``(1,)`` drawn from
        ``[0, 10)``, ``'reward'`` is a float tensor of shape ``(1,)`` and
        ``'done'`` is the plain Python bool ``False``.
    """
    # The leaves are created in the same order as they appear in the result,
    # so the sequence of random draws stays the same.
    observation = {}
    observation['scalar'] = torch.randn(12)
    observation['image'] = torch.randn(3, 32, 32)
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)
    return dict(obs=observation, action=action, reward=reward, done=False)


# Build a batch of B independently generated items.
data = [get_item() for _ in range(B)]


# execute `stack` op
def stack(data, dim):
    """Recursively stack a list of identically structured items.

    The type of the first element decides how the whole list is handled,
    so every element is expected to have the same structure as ``data[0]``.

    Args:
        data: Non-empty list of tensors, of dicts (possibly nested) or of
            Python bools.
        dim: Dimension passed to ``torch.stack`` for tensor leaves. It is
            not used for bool leaves, which always become a 1-D tensor.

    Returns:
        A stacked tensor for tensor input, a dict with the same keys as
        ``data[0]`` whose values are stacked recursively for dict input,
        or a 1-D ``torch.bool`` tensor for bool input.

    Raises:
        TypeError: If the first element is not a tensor, dict or bool.
    """
    elem = data[0]
    if isinstance(elem, torch.Tensor):
        return torch.stack(data, dim)
    if isinstance(elem, dict):
        # Gather the values stored under each key and stack them recursively.
        return {k: stack([item[k] for item in data], dim) for k in elem}
    if isinstance(elem, bool):
        # torch.tensor with an explicit dtype replaces the legacy
        # torch.BoolTensor constructor.
        return torch.tensor(data, dtype=torch.bool)
    raise TypeError("not support elem type: {}".format(type(elem)))


# Stack the whole batch along a new leading dimension.
stacked_data = stack(data, dim=0)
# validate
print(stacked_data)
# Every tensor leaf gains a leading batch dimension of size B.
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
# The Python bools are collected into a 1-D boolean tensor of length B.
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertion statements should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
{'obs': {'scalar': tensor([[-0.0474, -0.6267, -1.3194,  1.3442, -2.4609, -0.5388,  0.0422, -1.0698,
         -1.1849, -0.3751, -1.1569,  1.3876],
        [-0.2088, -0.7598, -1.1054, -1.1143, -0.9409,  0.1055,  1.2774,  0.7490,
          0.1556,  1.1857,  0.2342,  0.9975],
        [ 0.7275,  0.0235, -1.3393,  0.8602,  1.8367, -0.5077, -0.8846,  1.8345,
         -0.3257, -1.1880,  0.0906, -0.4038],
        [-0.5391, -0.6842,  0.8453, -0.5388,  0.1375, -0.0336, -0.3076, -0.4212,
          0.8219,  1.1042,  0.1134, -0.5790]]), 'image': tensor([[[[-0.2751,  0.2849, -0.4398,  ..., -0.8361,  1.4208, -1.1693],
          [-1.3456,  0.8265, -1.4024,  ...,  1.0499, -1.1010, -1.9177],
          [-0.2765,  0.4968,  0.4232,  ...,  3.3506,  0.1763, -2.4639],
          ...,
          [ 1.8089, -0.0888,  0.2680,  ...,  0.8729, -0.3355, -0.7094],
          [-0.0260,  2.3524, -0.5009,  ...,  1.0931,  0.7457, -0.1578],
          [ 0.1746, -0.2618, -0.7880,  ..., -0.7283, -0.5056, -0.6450]],

         [[-2.4831,  0.6116, -0.2278,  ..., -0.2452, -0.6157,  1.0735],
          [-0.1020, -2.2370, -1.6688,  ...,  0.4849, -0.5533,  0.0878],
          [-0.3830,  0.1553,  1.2283,  ...,  0.0569, -1.4088, -0.4156],
          ...,
          [-0.7294,  0.8529,  0.7134,  ..., -0.9616, -1.5958, -1.4483],
          [-0.3974,  0.8537, -0.0476,  ..., -1.4329,  0.9812,  0.1319],
          [-0.6973,  1.5629, -0.2870,  ...,  1.4818, -1.0325, -0.2504]],

         [[-0.2226, -1.3849, -2.1680,  ..., -0.1532,  0.7338, -0.1767],
          [ 0.0933, -0.0511,  0.5087,  ..., -0.1556,  0.9099,  0.2634],
          [-0.0989, -0.7362,  1.0113,  ...,  1.1245,  1.1827, -0.4475],
          ...,
          [-0.1442, -0.5391, -2.1120,  ..., -0.6900, -0.3014, -0.1759],
          [-1.3388,  0.2005,  0.2434,  ...,  1.6742, -0.5577,  0.2453],
          [-0.3675, -0.5611,  0.6748,  ...,  0.1551,  0.8148,  0.7405]]],


        [[[-0.3502,  0.8338,  0.2708,  ..., -0.0123,  1.0926,  0.6610],
          [-0.0094, -1.0525,  0.3287,  ...,  0.7370,  0.8891,  1.0431],
          [ 0.1888,  0.3276,  0.5379,  ..., -0.7131,  1.9821,  0.7732],
          ...,
          [ 0.1138, -1.6056,  0.6366,  ..., -1.2302, -0.3910, -0.0265],
          [-0.4184, -0.9246,  0.6108,  ...,  0.2189,  0.3337, -0.2046],
          [-2.4134,  1.9054, -0.0702,  ...,  0.3756,  0.5060, -0.1964]],

         [[-0.5459,  0.7165,  0.7111,  ..., -0.3358, -2.8014, -0.3046],
          [ 0.3116,  0.9859,  0.2579,  ..., -1.6148,  1.8107, -0.1606],
          [ 0.3406,  2.0406,  0.3973,  ..., -0.4990,  0.4736,  0.5269],
          ...,
          [-0.2485, -0.4891,  2.5243,  ...,  1.8118, -0.0243,  0.6212],
          [ 1.1387, -0.1737,  1.0413,  ...,  0.8144, -0.3411,  0.1024],
          [-1.0099, -0.6349,  2.6201,  ...,  0.0811,  0.9929,  2.0772]],

         [[-0.7934, -1.2366,  0.5791,  ..., -0.0818,  0.4937,  0.4859],
          [-0.1064,  1.4198, -0.1747,  ...,  0.4362, -1.2629, -0.5691],
          [ 0.1293, -2.0610,  1.9908,  ...,  1.1718,  0.8963, -0.5334],
          ...,
          [ 0.3928, -0.0289,  2.1853,  ...,  0.2185,  0.3709,  0.1734],
          [-0.6801,  0.0650,  0.3967,  ...,  1.0298,  1.6349, -1.1922],
          [-0.6088, -0.7305,  0.6763,  ...,  0.1512,  1.6489, -1.4524]]],


        [[[-0.0420,  0.4647, -1.5241,  ...,  0.8838,  0.1826,  1.7378],
          [ 0.7445,  0.7431, -2.3356,  ..., -0.0118,  0.4899, -0.8337],
          [ 1.0866,  0.1677, -0.1634,  ...,  0.4219, -0.1374, -0.0456],
          ...,
          [ 0.8754,  0.5750,  1.4385,  ...,  0.7784,  0.4058,  0.4234],
          [ 0.6471, -0.3631,  0.6372,  ..., -0.1434, -1.7048,  1.1129],
          [ 0.3831,  2.0619,  0.5491,  ...,  1.7297,  0.2793, -0.7170]],

         [[ 1.5339, -1.3078, -1.3366,  ..., -0.6555,  0.2967,  0.2050],
          [ 2.0565, -0.6604,  1.7751,  ..., -0.9680,  0.0179, -0.7194],
          [ 1.2867, -0.6521,  0.8652,  ..., -0.3076,  0.0210, -1.6127],
          ...,
          [-0.9632, -0.2090, -0.6809,  ..., -0.9559,  2.3938, -0.5868],
          [-0.6778,  0.9912,  0.1721,  ...,  0.2588,  0.3010, -0.7527],
          [-0.5230,  0.3787, -0.5170,  ...,  0.3577,  2.4580,  0.5111]],

         [[-0.3902,  1.6515, -1.5147,  ..., -1.2826, -0.6861, -0.3882],
          [-1.8575,  0.2926, -0.1161,  ..., -0.8605, -0.6181, -1.0980],
          [-1.4343, -0.5434,  0.8706,  ..., -1.0140,  0.3919, -0.4772],
          ...,
          [ 1.6469, -0.7275, -0.9569,  ..., -1.1410,  0.2409, -0.6920],
          [-1.5096,  0.7378, -0.7718,  ...,  0.8376, -1.1837,  0.5662],
          [ 0.1438, -1.4634,  0.0567,  ...,  1.7126, -0.1545, -0.1851]]],


        [[[-1.9919, -2.0915,  1.6117,  ..., -0.1704, -0.4007,  1.0255],
          [-1.9141,  0.7429,  0.3941,  ...,  0.6496, -0.6669, -0.1110],
          [-1.2933,  0.3415,  0.6939,  ...,  0.6096,  0.1182, -2.1521],
          ...,
          [-0.3522, -0.7301,  0.5915,  ...,  1.0247,  0.4818, -2.8824],
          [-1.0129,  0.5479, -0.2106,  ..., -1.2869, -0.6228, -1.7448],
          [ 0.6500,  0.8102, -0.1323,  ..., -0.6634,  0.1760, -0.5788]],

         [[ 0.2172, -0.5834, -0.7828,  ...,  0.1899, -1.0895,  1.2113],
          [ 0.2995, -2.2824, -0.8321,  ...,  0.4185,  0.6189,  0.7172],
          [-0.2500,  0.7318, -0.7020,  ...,  0.1634, -0.7869,  0.3379],
          ...,
          [-0.8895, -0.2712, -0.4805,  ..., -0.9760, -0.2789, -1.1124],
          [ 1.8720,  0.0467, -0.1050,  ..., -0.1172, -1.0756,  1.1421],
          [ 0.5143, -1.2392,  0.8083,  ..., -1.3443,  0.7761, -1.8245]],

         [[ 0.3234, -0.2622, -0.8256,  ..., -2.1402, -0.5509, -0.7607],
          [ 0.7452,  0.6947,  0.7951,  ...,  0.0744, -0.4635, -0.6039],
          [ 0.5742,  0.3557, -0.0880,  ...,  0.7482, -0.0893,  0.4751],
          ...,
          [-1.3953,  1.3579,  0.0245,  ...,  0.0954, -0.1335,  0.9655],
          [-0.6046, -0.4598,  0.5230,  ..., -0.2763, -0.6108,  1.0981],
          [-0.3839, -0.0712, -1.4534,  ..., -0.0674, -0.7260, -0.2902]]]])}, 'action': tensor([[8],
        [3],
        [6],
        [6]]), 'reward': tensor([[0.1228],
        [0.6469],
        [0.5704],
        [0.6775]]), 'done': tensor([False, False, False, False])}

We can see that the structure-processing function (such as the stack function above) has to be fully implemented by hand. This code is not very clear, and because the supported types are hard-coded, supporting more data types (such as integers) requires special modifications to the function.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown in the code below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

# Batch size: the number of items generated and stacked together.
B = 4


def get_item():
    """Create one randomly generated, nested sample.

    Returns:
        A dict with the keys ``'obs'``, ``'action'``, ``'reward'`` and
        ``'done'``. ``'obs'`` is itself a dict holding a ``'scalar'`` tensor
        of shape ``(12,)`` and an ``'image'`` tensor of shape ``(3, 32, 32)``.
        ``'action'`` is an integer tensor of shape ``(1,)`` drawn from
        ``[0, 10)``, ``'reward'`` is a float tensor of shape ``(1,)`` and
        ``'done'`` is the plain Python bool ``False``.
    """
    # The leaves are created in the same order as they appear in the result,
    # so the sequence of random draws stays the same.
    observation = {}
    observation['scalar'] = torch.randn(12)
    observation['image'] = torch.randn(3, 32, 32)
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)
    return dict(obs=observation, action=action, reward=reward, done=False)


# Build a batch of B independently generated items.
data = [get_item() for _ in range(B)]
# execute `stack` op
# Convert each nested dict with ttorch.tensor, then stack the converted
# items with a single ttorch.stack call.
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate
print(stacked_data)
# The stacked result exposes its children as attributes; every leaf has a
# leading batch dimension of size B.
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
# The Python bools end up as a 1-D boolean tensor of length B.
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertion statements should pass as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
<Tensor 0x7f3d243ae190>
├── 'action' --> tensor([[7],
│                        [7],
│                        [2],
│                        [2]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7f3d243ae2b0>
│   ├── 'image' --> tensor([[[[ 0.0791,  0.8872,  1.0240,  ..., -0.4558, -0.2995,  0.4920],
│   │                         [-0.4198, -1.7697, -0.6815,  ...,  0.4069,  2.3706,  0.4846],
│   │                         [ 0.2284,  0.3191,  0.4235,  ..., -0.6834,  0.1046,  0.4281],
│   │                         ...,
│   │                         [-0.0241, -0.2738,  0.0994,  ..., -0.4086,  1.3089,  1.1076],
│   │                         [-1.6397,  2.2323, -1.9866,  ..., -1.1885, -1.1003,  0.5095],
│   │                         [-0.1616,  0.4194,  0.3447,  ..., -1.4410, -1.0595, -1.3054]],
│   │               
│   │                        [[ 0.5532,  1.7987,  0.6725,  ...,  0.9116,  0.1138, -0.5087],
│   │                         [-1.2482, -0.1449, -0.6920,  ...,  0.4026, -0.3073, -1.2992],
│   │                         [-0.8882,  1.0068, -0.6378,  ..., -1.6998,  0.0469,  1.3722],
│   │                         ...,
│   │                         [-0.1030, -1.0744,  0.0287,  ..., -0.9524, -0.2263, -0.0691],
│   │                         [ 1.9076, -0.8239, -0.3963,  ..., -1.6327,  0.1308, -0.2509],
│   │                         [ 1.1278, -1.4004,  1.1440,  ..., -0.7515,  0.2153, -1.8422]],
│   │               
│   │                        [[ 0.3734,  0.2164, -0.7244,  ...,  1.2590, -0.0717, -1.1597],
│   │                         [ 0.1956,  0.6762,  0.2630,  ...,  0.3518, -0.3724,  1.9386],
│   │                         [ 0.5327, -0.3009, -0.8345,  ..., -0.4405,  1.2826,  1.3088],
│   │                         ...,
│   │                         [ 0.6457, -1.0980, -0.9002,  ..., -0.8903, -0.8940, -0.6419],
│   │                         [ 0.9915, -1.3307,  0.4343,  ..., -2.2540, -0.6882, -0.4414],
│   │                         [ 0.8994, -0.5798,  0.5876,  ..., -0.3843,  0.1621, -0.3746]]],
│   │               
│   │               
│   │                       [[[ 0.7907,  1.2970,  0.9571,  ..., -0.5582, -0.0405, -1.1310],
│   │                         [ 0.1079, -1.2159, -0.6550,  ...,  1.2299, -1.3267, -2.4603],
│   │                         [ 0.3104,  0.1867,  0.1131,  ...,  1.1507,  0.9352, -0.2045],
│   │                         ...,
│   │                         [-0.7419, -2.8449, -1.3869,  ..., -0.6817,  0.6444, -0.7880],
│   │                         [-1.4280, -0.8149, -0.2399,  ...,  0.4458,  1.4619,  0.1235],
│   │                         [-0.2581, -0.7713, -0.8334,  ...,  0.4113, -0.6817,  0.4658]],
│   │               
│   │                        [[ 0.9982, -1.3032, -0.3881,  ...,  0.9156,  0.9392,  1.6742],
│   │                         [-1.2746,  0.4936, -0.1683,  ...,  1.7532,  1.1859,  0.5688],
│   │                         [ 0.9160,  0.6627, -1.8103,  ...,  0.4949, -0.9998, -1.3738],
│   │                         ...,
│   │                         [-0.9601,  0.0071,  1.1889,  ...,  1.4923,  0.3814, -0.7239],
│   │                         [ 0.7037, -0.1911, -0.4539,  ..., -0.1205,  1.4515,  0.7515],
│   │                         [ 0.9311,  0.5985, -0.0129,  ..., -0.3412,  0.9672, -1.4772]],
│   │               
│   │                        [[ 0.5502,  0.6046,  1.2751,  ..., -2.0525, -0.0460,  1.8959],
│   │                         [ 1.1351,  0.6971, -0.5836,  ..., -0.3437, -2.7175,  0.3434],
│   │                         [-0.7622,  1.4680, -0.0728,  ...,  0.6859, -0.6063, -0.0471],
│   │                         ...,
│   │                         [-0.0408, -1.2061,  0.1526,  ..., -0.1962, -1.4930, -1.3408],
│   │                         [-0.0765,  0.6524, -0.5823,  ..., -0.4107,  0.1003,  1.1323],
│   │                         [-0.1797,  0.4413, -0.5075,  ..., -0.8266,  1.1323, -0.8279]]],
│   │               
│   │               
│   │                       [[[ 1.2919,  0.8392, -0.4312,  ...,  0.1050, -0.3016, -0.0981],
│   │                         [-0.6781,  0.0393,  0.2945,  ...,  1.8350, -1.2779,  0.3858],
│   │                         [ 1.9833, -0.8696,  2.0524,  ...,  0.2332,  1.6123, -2.2659],
│   │                         ...,
│   │                         [ 1.5953, -1.1926,  0.3764,  ..., -1.5459, -0.6217,  0.1397],
│   │                         [-0.7837, -0.3437,  0.5292,  ...,  0.7078,  0.1553, -0.2891],
│   │                         [-0.6039,  0.9226, -0.4776,  ...,  1.6451, -1.0019,  1.7044]],
│   │               
│   │                        [[ 0.0923, -0.3256, -2.1885,  ..., -0.3502, -0.9551, -1.7428],
│   │                         [ 1.3288,  1.2030,  1.3103,  ...,  1.7191,  0.2865,  0.2143],
│   │                         [-0.8596, -0.5981, -1.4051,  ..., -0.0115,  1.0409, -0.8187],
│   │                         ...,
│   │                         [-0.4436, -1.0942, -0.5818,  ...,  0.4918,  0.0229,  0.3424],
│   │                         [-1.6137, -1.7928, -0.9347,  ..., -1.6439, -0.8557,  1.1684],
│   │                         [-0.5163,  0.1093, -0.1803,  ...,  1.3647,  0.0103, -0.6510]],
│   │               
│   │                        [[ 0.5011, -1.0931,  0.9437,  ..., -0.5579, -1.9619, -0.2903],
│   │                         [ 0.7258, -0.1732, -1.5473,  ..., -0.4777,  0.2735, -1.0263],
│   │                         [ 0.4865,  1.1720, -0.6594,  ..., -0.4230,  1.5394,  1.4809],
│   │                         ...,
│   │                         [ 0.4357, -0.8231,  2.9393,  ...,  0.5074, -1.4053,  0.3930],
│   │                         [ 0.8048,  1.6432,  0.1634,  ..., -0.7126,  0.8041, -0.0140],
│   │                         [ 0.0584,  1.8815,  0.2343,  ...,  0.4873,  0.6774, -0.6150]]],
│   │               
│   │               
│   │                       [[[ 1.3604, -1.6832, -0.1336,  ...,  0.7399,  0.1142, -1.3361],
│   │                         [-1.5409,  0.6699,  0.0171,  ..., -0.3744,  0.7404, -0.4100],
│   │                         [-0.1642,  1.1593,  0.7386,  ..., -0.9751,  0.0185, -0.1900],
│   │                         ...,
│   │                         [ 0.1972, -0.5856, -0.3571,  ...,  1.3776,  0.4157, -0.1640],
│   │                         [ 0.2938,  1.2468, -1.6340,  ..., -0.7121,  0.0604,  0.3688],
│   │                         [ 1.0029, -2.3024,  0.3247,  ...,  0.0377, -0.6577, -0.2586]],
│   │               
│   │                        [[ 0.9143,  0.1542, -0.6300,  ...,  0.3324,  0.7772, -1.6399],
│   │                         [ 0.1387, -1.1285,  0.8991,  ..., -0.1102, -1.7244, -0.2450],
│   │                         [-0.1183,  0.2169, -0.1358,  ...,  0.9768, -0.9488, -0.4885],
│   │                         ...,
│   │                         [-0.3008, -0.0661, -1.0024,  ...,  0.8578, -0.2944, -1.7844],
│   │                         [ 0.3122,  1.1041,  0.9070,  ...,  1.3342, -1.9862,  0.0636],
│   │                         [-2.1131,  1.6279, -0.6652,  ..., -1.1855,  0.7817,  0.2524]],
│   │               
│   │                        [[-0.2601,  0.4819,  0.2600,  ..., -0.0747,  1.7166,  0.5367],
│   │                         [ 0.0428, -0.6690, -1.7239,  ...,  0.9543,  2.2662,  0.0496],
│   │                         [-1.3442, -0.8918, -0.0961,  ...,  2.7132, -0.1731, -0.2380],
│   │                         ...,
│   │                         [-2.8024, -0.6169, -0.6925,  ..., -0.0106,  0.3174, -1.9551],
│   │                         [-0.9647,  0.4923,  2.9837,  ..., -0.3070,  0.4370, -0.6259],
│   │                         [-0.3058, -0.3659, -0.5358,  ..., -0.6162, -1.9454, -0.3422]]]])
│   └── 'scalar' --> tensor([[-1.4763, -0.0544,  0.7519, -1.0781, -0.4852,  1.7914, -1.3496,  1.0081,
│                              0.5050,  0.9409, -0.2093, -0.3679],
│                            [ 0.5313,  0.4912, -0.3377,  1.7403, -0.6472,  0.9690, -1.1610, -0.0227,
│                             -1.4047, -0.3975,  0.8537, -1.2538],
│                            [-0.9429,  0.3886, -0.8699,  1.1482,  0.5062, -0.7958,  0.1384,  0.4363,
│                              1.6568,  2.3146, -0.8395,  1.0176],
│                            [-1.6640, -0.9070,  0.6325,  0.6311,  0.5759,  0.4054, -0.1846, -0.8301,
│                              0.0672,  2.1676,  0.1712, -0.2086]])
└── 'reward' --> tensor([[0.4076],
                         [0.9543],
                         [0.0243],
                         [0.8797]])

This code looks much simpler and clearer.