Stack Structured Data

When tensors are organized into tree structures, they often need to be stacked together, just as `torch.stack()` does for plain tensors in PyTorch.

Stack With Native PyTorch API

Here is a common implementation using the native PyTorch API.

import torch

B = 4


def get_item():
    """Create one synthetic sample as a freshly allocated nested dict.

    Layout of the returned dict:

    * ``obs['scalar']``: ``torch.randn(12)``
    * ``obs['image']``: ``torch.randn(3, 32, 32)``
    * ``action``: ``torch.randint(0, 10, size=(1,))``
    * ``reward``: ``torch.rand(1)``
    * ``done``: the plain Python bool ``False`` (deliberately not a tensor)
    """
    # Random draws happen in a fixed order (scalar, image, action, reward),
    # so a seeded generator reproduces exactly the same sample.
    scalar = torch.randn(12)
    image = torch.randn(3, 32, 32)
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)
    return dict(
        obs=dict(scalar=scalar, image=image),
        action=action,
        reward=reward,
        done=False,
    )


# A batch of B independent samples that will be stacked below.
data = [get_item() for _sample in range(B)]


# execute `stack` op
def stack(data, dim):
    """Recursively stack a list of identically structured samples.

    The structure is inferred from ``data[0]``:

    * ``torch.Tensor`` leaves are combined with ``torch.stack(data, dim)``;
    * ``dict`` nodes are stacked key by key, using the keys of the first
      element (every other element must contain at least those keys);
    * ``bool`` leaves are collected into a 1-D ``torch.bool`` tensor
      (``dim`` is not used in this case).

    Args:
        data: non-empty list of samples sharing the same structure.
        dim: dimension along which tensor leaves are stacked.

    Returns:
        A tensor, or a dict mirroring the input structure with stacked leaves.

    Raises:
        IndexError: if ``data`` is empty.
        TypeError: if the element type is not one of the supported types.
    """
    if not data:
        # The element type comes from data[0]; fail with a clear message
        # instead of a bare "list index out of range".
        raise IndexError("cannot stack an empty sequence")
    elem = data[0]
    if isinstance(elem, torch.Tensor):
        return torch.stack(data, dim)
    elif isinstance(elem, dict):
        return {k: stack([item[k] for item in data], dim) for k in elem}
    elif isinstance(elem, bool):
        # torch.tensor with an explicit dtype replaces the legacy
        # torch.BoolTensor constructor.
        return torch.tensor(data, dtype=torch.bool)
    else:
        raise TypeError(f"not support elem type: {type(elem)}")


stacked_data = stack(data, dim=0)
# validate: every leaf gains a leading batch dimension of size B
print(stacked_data)
assert stacked_data['obs']['scalar'].shape == (B, 12)
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
# `done` started as plain Python bools and must come back as a bool tensor
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertions pass.

{'obs': {'scalar': tensor([[ 1.2353,  0.1033, -0.0788,  0.3673,  1.5222, -0.1400,  0.8218,  0.9439,
         -0.3665, -0.5841, -0.1208, -0.0191],
        [ 0.9209,  0.8111, -1.4704, -1.1986,  0.0567, -0.5912, -0.2208, -0.2141,
          0.3711,  0.4984,  0.9925, -0.0610],
        [-0.7212, -1.1298, -0.9845, -1.1398, -0.2015, -0.2729,  0.3626, -0.4082,
          0.2805,  0.2625, -0.0895,  0.1263],
        [ 1.1465,  0.2838,  0.8829,  1.1629, -1.1590,  0.6518,  1.2072, -1.2327,
         -1.4515, -1.4140, -1.7307,  0.2184]]), 'image': tensor([[[[ 0.3335, -1.3065,  0.0261,  ..., -1.4714,  0.5597,  0.1360],
          [ 0.4465,  0.3327, -0.5750,  ..., -0.5399, -0.9296, -0.8244],
          [ 0.9897, -0.1005, -0.8503,  ..., -0.4076, -0.7188,  0.4234],
          ...,
          [-0.1135,  0.2755, -1.4059,  ...,  0.1674,  0.2051,  0.1162],
          [ 0.2281, -0.7233, -0.8744,  ..., -0.9613,  0.5847,  1.5659],
          [-0.1295,  0.5437,  0.5021,  ..., -0.5251, -1.9206,  0.0096]],

         [[-1.3309, -1.7988,  0.2972,  ...,  0.6385,  0.9122,  0.1015],
          [ 1.8630, -0.9602, -0.4700,  ..., -0.2996,  0.2686, -1.4156],
          [ 0.2450,  1.2799,  0.8750,  ..., -0.3291,  1.5261, -0.7365],
          ...,
          [ 0.0395,  0.8058,  1.9806,  ..., -0.0413,  0.0883, -0.9220],
          [-0.2305,  0.1159, -0.7544,  ..., -2.0230, -1.4943, -1.1350],
          [ 0.1117, -0.5315,  0.5944,  ...,  0.4788,  0.0436, -0.5985]],

         [[-0.2830,  2.0172,  0.0190,  ..., -1.4725, -0.3333, -0.2303],
          [ 1.5592,  1.4137,  1.6759,  ..., -0.2728, -1.6782,  1.0658],
          [-1.7359, -1.1321,  0.4097,  ...,  0.4581, -0.4147,  0.4442],
          ...,
          [-1.3449, -1.4193, -0.8890,  ...,  0.6328, -0.5590, -0.4163],
          [-1.1040,  2.1431,  2.3058,  ...,  0.1790,  0.5301,  1.4601],
          [-1.4273,  0.1942, -0.7758,  ...,  0.1424,  1.1765, -0.4766]]],


        [[[ 0.9146,  0.3598, -0.2676,  ..., -0.4067,  0.2047, -0.4180],
          [-0.4818,  0.0106, -1.4120,  ..., -0.1858, -0.6795, -1.4480],
          [ 0.8597,  1.4061,  1.1302,  ..., -0.8007,  0.2118, -0.5153],
          ...,
          [ 0.7371,  0.6654, -0.6888,  ...,  0.1592,  0.7637,  0.4083],
          [-1.0556,  0.8332,  0.4223,  ..., -1.3296, -0.2946,  0.1338],
          [ 0.5926, -0.2620,  0.8530,  ..., -1.0015, -0.3341, -1.6636]],

         [[-2.2510, -0.3986,  0.2453,  ...,  0.7056,  1.7304,  1.3177],
          [-0.0552, -0.2365,  0.2881,  ..., -0.4662, -0.0364, -0.0245],
          [ 0.1124,  1.0734,  0.7138,  ..., -2.1025,  0.8246,  0.7277],
          ...,
          [-0.7811,  0.4033, -0.1675,  ..., -0.7742, -0.7077, -0.0896],
          [-1.1604,  1.2810,  0.8874,  ...,  0.3938, -0.2955, -0.0441],
          [-0.8316,  0.0604,  0.0685,  ...,  1.2924, -0.6923, -1.1963]],

         [[ 0.6020, -0.7626,  1.1469,  ...,  1.8116, -0.3537,  0.7856],
          [-1.4319,  1.0362,  1.6734,  ...,  1.4811, -1.0094,  0.1670],
          [-0.6686,  0.9627, -0.2125,  ..., -1.4370, -0.4441,  0.0290],
          ...,
          [-0.9022, -0.6266, -0.5947,  ...,  2.8485,  0.0414,  1.0336],
          [ 0.5311, -1.5195,  2.5945,  ..., -1.5883, -0.1427, -0.1882],
          [-1.1037, -1.5452,  0.7929,  ..., -0.9975,  1.2051, -0.9413]]],


        [[[-0.5478, -0.4352,  1.0987,  ..., -0.0724, -1.2327, -0.7871],
          [ 0.9770,  1.7980, -0.6592,  ...,  0.6271,  0.2800,  0.3116],
          [-0.5589, -1.6918, -0.1333,  ..., -0.8579,  0.8355, -0.7909],
          ...,
          [ 1.5151,  1.2520, -1.0537,  ..., -1.4403, -0.8065,  0.2727],
          [ 0.1922, -1.8616,  0.0978,  ...,  0.5034,  1.5134, -1.5712],
          [-0.6260, -0.4096, -0.8920,  ..., -0.6702, -0.7610,  0.5073]],

         [[ 1.6124,  0.6425,  0.9418,  ..., -0.3850, -0.2208, -1.3104],
          [ 0.0473,  0.9078,  0.3801,  ...,  1.9370,  0.2786,  0.1005],
          [-0.3935, -2.0966,  0.0998,  ...,  0.6551,  0.0832,  0.2786],
          ...,
          [ 1.0032,  0.3113,  1.0083,  ...,  0.7403, -0.8312, -2.0272],
          [ 1.0482, -2.1290, -0.7063,  ...,  0.3639,  1.0533,  1.8313],
          [-1.4896, -0.1280, -0.9721,  ...,  0.3050, -0.2711, -1.1549]],

         [[ 1.1161,  1.1090,  2.0915,  ...,  0.1841, -0.7887, -2.2378],
          [ 0.5129,  1.9897, -0.4782,  ...,  0.2218, -1.0197, -1.0471],
          [-0.1897,  0.3928,  0.3233,  ...,  0.0418,  1.0819, -1.3346],
          ...,
          [-0.5513, -0.5820, -0.3209,  ..., -0.3596, -2.1402, -0.4233],
          [-1.9409,  0.1363, -0.1097,  ...,  0.5669,  0.5416, -0.9868],
          [ 0.3515, -0.1241,  0.1267,  ...,  1.1888, -1.0528, -1.1948]]],


        [[[ 0.1024, -1.1462, -0.4391,  ..., -0.1048,  0.5747, -0.7521],
          [ 0.0155,  0.6030,  1.1195,  ..., -2.1208,  0.6655,  1.2259],
          [ 1.2126,  0.6675,  0.2437,  ..., -0.6546,  0.0316, -0.6425],
          ...,
          [ 1.5437, -0.4587, -0.1753,  ...,  0.0680, -3.1304,  0.4386],
          [ 0.3996, -0.9759, -1.4080,  ..., -1.6013,  0.6383,  0.8671],
          [-0.5478, -0.0450, -0.0663,  ..., -1.0068,  0.3400, -1.2715]],

         [[ 0.0880,  0.0899, -0.7186,  ..., -1.3469, -2.2685,  1.1066],
          [-1.4203,  0.3606,  1.9161,  ..., -0.3989, -0.8464, -0.1847],
          [ 0.0992, -1.0557, -0.1339,  ..., -0.2093,  1.1721, -1.8974],
          ...,
          [-1.9481,  2.9970,  2.6994,  ..., -0.3116,  0.7401,  1.5904],
          [ 0.1872, -0.8145,  1.3964,  ...,  0.7552, -0.6232, -0.9321],
          [ 0.4413,  1.7538,  0.5035,  ...,  2.1893, -1.2340, -0.0737]],

         [[ 0.2138, -0.8470,  0.7586,  ..., -2.7832, -0.2185,  0.0269],
          [ 0.9355, -1.3381, -0.4485,  ..., -1.9584,  0.2992,  1.0370],
          [ 0.3561, -0.5251,  0.2807,  ..., -0.3152,  0.1020,  0.3501],
          ...,
          [ 1.4126, -1.9330, -0.3176,  ..., -0.1806,  0.1100,  0.8590],
          [-0.5786, -0.3379, -0.6208,  ...,  0.0932,  1.2057, -0.3077],
          [-0.2188,  0.7052, -0.4103,  ..., -1.2760,  0.3201,  0.5910]]]])}, 'action': tensor([[8],
        [6],
        [3],
        [5]]), 'reward': tensor([[0.7108],
        [0.8147],
        [0.3506],
        [0.5117]]), 'done': tensor([False, False, False, False])}

As we can see, the structure-processing function (here, `stack`) has to be implemented entirely by hand. The resulting code is not very clear. Because each supported type is hard-coded, supporting an additional data type (such as plain integers) requires modifying the function itself.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown below.

import torch

import treetensor.torch as ttorch

B = 4


def get_item():
    """Build a single random sample as a new nested dict.

    Keys, in insertion order: ``obs`` (with ``scalar`` = ``torch.randn(12)``
    and ``image`` = ``torch.randn(3, 32, 32)``), ``action``
    (``torch.randint(0, 10, size=(1,))``), ``reward`` (``torch.rand(1)``)
    and ``done`` (the plain Python bool ``False``).
    """
    # Fill the dicts step by step; the random draws keep the order
    # scalar -> image -> action -> reward so seeded runs are reproducible.
    obs = {}
    obs['scalar'] = torch.randn(12)
    obs['image'] = torch.randn(3, 32, 32)
    item = {'obs': obs}
    item['action'] = torch.randint(0, 10, size=(1,))
    item['reward'] = torch.rand(1)
    item['done'] = False
    return item


# B independent samples, still as plain nested dicts.
data = [get_item() for _sample in range(B)]
# execute `stack` op
# ttorch.tensor converts each nested dict (including the plain bool) into a
# tree of tensors; ttorch.stack then stacks every leaf along `dim`.
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate: every leaf gains a leading batch dimension of size B
print(stacked_data)
assert stacked_data.obs.scalar.shape == (B, 12)
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertions pass as well.

<Tensor 0x7f7ed7e2e190>
├── 'action' --> tensor([[1],
│                        [4],
│                        [1],
│                        [9]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7f7ed7e2e2b0>
│   ├── 'image' --> tensor([[[[-6.5067e-01, -5.7017e-02,  1.3348e+00,  ..., -1.1105e+00,
│   │                          -1.9789e+00, -4.5939e-01],
│   │                         [-1.4942e+00, -3.8909e-01,  1.9084e+00,  ..., -7.9206e-01,
│   │                           2.9019e-01, -5.6907e-01],
│   │                         [ 3.0546e-01, -1.9415e+00, -6.3637e-01,  ..., -1.3183e-01,
│   │                          -1.2035e+00, -2.8270e-01],
│   │                         ...,
│   │                         [ 7.0679e-02,  1.9211e+00,  4.4281e-01,  ..., -1.8972e+00,
│   │                           9.8271e-01,  1.3561e+00],
│   │                         [ 3.8094e-01, -1.8031e+00,  2.3186e-02,  ..., -6.9736e-01,
│   │                          -6.8178e-01, -5.6575e-01],
│   │                         [-2.6018e-01,  1.1391e-01,  1.7080e+00,  ...,  2.5416e-01,
│   │                           1.1239e-01,  8.5587e-01]],
│   │               
│   │                        [[ 3.5841e-02, -1.0841e+00, -8.1357e-01,  ...,  1.0408e+00,
│   │                           1.5137e+00,  4.6669e-01],
│   │                         [-3.7593e-01, -4.4182e-03,  1.2054e+00,  ...,  7.3049e-01,
│   │                           2.3633e-01, -6.2146e-01],
│   │                         [-7.2992e-01, -1.8667e+00,  1.1283e+00,  ..., -8.1484e-01,
│   │                          -1.7670e+00,  6.0650e-01],
│   │                         ...,
│   │                         [ 6.5336e-01,  1.2182e+00, -3.4346e-01,  ...,  1.8065e+00,
│   │                           5.9048e-02,  2.7013e-01],
│   │                         [ 8.9180e-01,  4.9900e-01, -9.2269e-01,  ...,  1.1265e-01,
│   │                           6.3621e-01,  2.3077e-01],
│   │                         [-1.0900e+00, -2.7075e-01, -7.2448e-01,  ..., -1.6313e+00,
│   │                          -2.2749e-01, -3.8969e-01]],
│   │               
│   │                        [[-6.6002e-01,  8.0073e-01, -1.8583e-01,  ..., -1.6530e+00,
│   │                          -1.0803e+00, -7.5254e-01],
│   │                         [-4.9484e-02, -1.5223e+00,  1.2138e+00,  ..., -8.1763e-01,
│   │                           1.3795e+00,  4.2996e-02],
│   │                         [ 2.2986e+00,  1.1630e+00, -8.9653e-02,  ...,  5.6633e-01,
│   │                          -7.4397e-02, -3.8360e-01],
│   │                         ...,
│   │                         [-2.9382e-01,  1.2292e+00, -2.5730e-01,  ...,  6.0046e-02,
│   │                           1.7697e-01, -2.4991e-01],
│   │                         [-1.7776e+00, -6.9649e-01, -1.1052e+00,  ..., -1.3405e+00,
│   │                          -6.1992e-01,  4.3417e-01],
│   │                         [-6.5567e-01,  1.0582e+00,  2.8457e+00,  ...,  1.0724e+00,
│   │                          -1.5049e+00, -1.7024e-01]]],
│   │               
│   │               
│   │                       [[[-2.5400e-01, -1.3619e+00,  3.8326e-01,  ..., -5.2720e-01,
│   │                           1.0081e+00,  8.1495e-01],
│   │                         [ 1.2168e+00, -2.4751e-01, -1.4769e+00,  ..., -9.0399e-01,
│   │                          -3.4580e-01, -3.7792e-01],
│   │                         [-4.0533e-01,  8.9454e-01,  9.4218e-01,  ..., -6.3098e-04,
│   │                          -1.8149e+00, -1.8157e+00],
│   │                         ...,
│   │                         [-1.2883e+00,  1.7482e+00, -7.9928e-01,  ..., -1.2684e-01,
│   │                          -1.9047e+00,  5.3503e-01],
│   │                         [-1.0586e+00,  1.6642e+00, -6.9757e-01,  ...,  2.0572e-01,
│   │                          -1.5871e+00, -1.3664e+00],
│   │                         [-1.3114e+00, -5.6127e-03,  2.4721e+00,  ...,  9.4424e-02,
│   │                           4.2446e-01,  6.6568e-01]],
│   │               
│   │                        [[ 6.6559e-01, -9.1791e-02, -3.9473e-01,  ..., -1.4638e+00,
│   │                          -2.6762e-01,  2.3911e-02],
│   │                         [-2.2307e-01,  3.9621e-01,  3.2369e-01,  ...,  3.9045e-01,
│   │                          -8.8678e-01, -4.9955e-01],
│   │                         [-6.0989e-01, -1.3537e-01,  9.7192e-01,  ...,  4.9585e-01,
│   │                           5.0144e-01,  1.3295e+00],
│   │                         ...,
│   │                         [ 6.1883e-01, -5.8183e-01,  2.4139e-01,  ..., -1.1012e+00,
│   │                          -1.4963e+00,  5.0656e-01],
│   │                         [ 4.7936e-02, -8.1351e-01,  1.8778e+00,  ...,  3.4093e-01,
│   │                          -5.8294e-01, -1.0132e+00],
│   │                         [-7.1835e-01,  3.5500e-01, -2.0428e+00,  ..., -6.5871e-01,
│   │                           4.6334e-01, -2.6424e-01]],
│   │               
│   │                        [[-4.9320e-01,  2.5458e-01, -3.6310e-01,  ...,  5.8579e-01,
│   │                           6.2967e-02,  2.6136e-01],
│   │                         [ 2.7235e-01, -6.6473e-01, -5.5563e-01,  ..., -2.2000e-01,
│   │                           3.5085e-01,  1.5374e+00],
│   │                         [-6.6595e-01, -4.9000e-01,  1.4402e+00,  ..., -2.0986e+00,
│   │                          -1.0589e+00, -7.6051e-02],
│   │                         ...,
│   │                         [ 1.0236e+00, -2.4821e-01,  1.6759e+00,  ...,  6.8597e-01,
│   │                          -2.0588e-01,  7.0033e-01],
│   │                         [ 4.5180e-01,  4.4919e-01, -8.9658e-01,  ...,  8.5825e-01,
│   │                          -2.1104e-01,  3.2880e-01],
│   │                         [-1.8963e-01, -8.5202e-01, -7.8316e-01,  ..., -1.2966e+00,
│   │                           5.9421e-01,  7.3266e-01]]],
│   │               
│   │               
│   │                       [[[ 2.5090e-01,  3.1097e+00, -2.0726e+00,  ..., -1.1221e+00,
│   │                          -1.7222e+00, -6.2985e-01],
│   │                         [-1.0651e-01, -3.3995e-01,  2.4290e-01,  ...,  2.1829e-01,
│   │                          -2.1395e+00, -2.4241e-01],
│   │                         [-6.0132e-01, -1.4614e-01, -4.9323e-01,  ..., -3.9859e-02,
│   │                           6.2971e-01, -1.0381e+00],
│   │                         ...,
│   │                         [-1.8512e+00,  1.9662e-01,  1.3120e-01,  ...,  1.6327e+00,
│   │                          -4.5661e-01, -4.1747e-01],
│   │                         [ 1.0196e+00,  1.5984e+00, -1.9554e-01,  ...,  1.0019e+00,
│   │                           3.0147e-01,  1.0334e-01],
│   │                         [ 1.9358e+00,  1.2909e-01, -7.1994e-02,  ...,  1.5809e+00,
│   │                           3.6263e-01,  4.9806e-01]],
│   │               
│   │                        [[-8.1348e-01, -6.3421e-02, -5.9149e-01,  ..., -1.4494e+00,
│   │                           7.4378e-01, -2.7964e+00],
│   │                         [ 1.0237e+00, -2.5409e-02, -3.2456e-01,  ...,  4.6357e-01,
│   │                          -2.4108e-01, -4.9019e-01],
│   │                         [-6.7226e-02, -1.6025e-01, -1.1548e+00,  ...,  8.7820e-01,
│   │                           1.8348e-01,  8.4725e-01],
│   │                         ...,
│   │                         [ 1.2021e+00,  2.6264e-01,  1.2604e+00,  ..., -5.1962e-01,
│   │                          -1.0610e+00, -3.1482e-01],
│   │                         [-6.1719e-01, -1.6781e+00, -9.1082e-01,  ...,  1.1717e+00,
│   │                           6.1630e-01,  8.7576e-01],
│   │                         [-8.8643e-01, -6.1786e-01, -2.6284e-01,  ...,  7.6767e-01,
│   │                          -2.7001e-01, -2.6676e-02]],
│   │               
│   │                        [[ 1.4624e+00, -2.0786e+00, -8.3197e-01,  ..., -5.8758e-01,
│   │                           9.5089e-01,  9.9594e-01],
│   │                         [ 5.1441e-01, -7.3228e-01, -1.0505e+00,  ..., -4.6881e-01,
│   │                          -8.2299e-01,  1.4207e-01],
│   │                         [ 6.9117e-01,  7.7686e-01,  1.5675e+00,  ...,  3.6168e-01,
│   │                           1.1759e+00, -6.9354e-01],
│   │                         ...,
│   │                         [-4.3398e-01,  2.2861e-01, -1.6391e+00,  ..., -3.1763e-01,
│   │                           1.2390e+00, -9.0526e-01],
│   │                         [-7.5701e-01,  2.8755e-01,  2.7837e-01,  ..., -3.9153e-01,
│   │                           1.0674e-01,  9.4768e-01],
│   │                         [ 8.9926e-01, -1.7745e+00, -1.7205e+00,  ..., -4.4676e-01,
│   │                          -3.1382e-01, -7.5050e-01]]],
│   │               
│   │               
│   │                       [[[ 4.9637e-01,  8.2664e-01,  1.4861e+00,  ..., -2.0576e-01,
│   │                           1.9123e-02, -2.4364e-01],
│   │                         [-4.0719e-01, -9.2935e-01, -1.1414e+00,  ..., -2.2103e-01,
│   │                          -3.1725e-01, -1.0175e+00],
│   │                         [ 5.1094e-02,  2.7139e-01,  3.0121e-01,  ...,  3.4371e-01,
│   │                           2.9004e-01, -4.0349e-01],
│   │                         ...,
│   │                         [-7.1851e-01, -1.0919e+00,  1.0409e+00,  ..., -8.4823e-02,
│   │                           1.7234e+00, -4.6345e-02],
│   │                         [-1.0423e+00, -1.2923e+00,  2.4411e-01,  ...,  8.3934e-01,
│   │                          -8.7717e-01, -5.0312e-01],
│   │                         [ 1.2418e+00,  1.5908e+00,  4.8031e-01,  ...,  5.8055e-01,
│   │                          -8.5589e-01,  2.6312e-01]],
│   │               
│   │                        [[-1.1426e-01, -6.9286e-02,  4.4598e-01,  ...,  1.3482e+00,
│   │                          -6.7631e-01,  1.4084e+00],
│   │                         [ 2.2362e-01,  6.1128e-01,  1.4168e+00,  ..., -1.7923e+00,
│   │                          -3.3763e-01, -7.1333e-01],
│   │                         [-8.8015e-01, -2.0147e+00, -6.1399e-01,  ...,  3.8723e-01,
│   │                          -8.1193e-01,  1.9102e+00],
│   │                         ...,
│   │                         [ 8.8305e-01,  5.1360e-01, -1.0769e+00,  ...,  1.6497e+00,
│   │                           2.7970e-01,  6.3625e-02],
│   │                         [-1.4371e+00,  1.3043e+00,  9.3124e-03,  ...,  4.7523e-02,
│   │                           1.9467e-01,  1.5000e+00],
│   │                         [-1.0681e+00, -1.5133e+00,  1.9540e+00,  ...,  9.8661e-02,
│   │                          -5.2596e-01, -1.7673e+00]],
│   │               
│   │                        [[-7.5222e-01,  1.9047e+00, -1.0827e+00,  ..., -1.3944e-02,
│   │                          -5.5678e-01,  2.5135e+00],
│   │                         [-3.5125e-01, -1.2564e-01,  8.8505e-01,  ...,  7.7109e-01,
│   │                          -1.7165e+00, -3.7643e-01],
│   │                         [-1.0743e+00, -8.0511e-01,  8.4417e-01,  ...,  1.2801e+00,
│   │                           2.3557e-01, -8.8140e-02],
│   │                         ...,
│   │                         [-6.7643e-01, -3.3399e-02,  2.8724e-01,  ..., -4.1695e-01,
│   │                          -1.0534e+00,  6.8937e-01],
│   │                         [-1.1684e+00,  3.8485e-01, -1.1519e+00,  ...,  3.3425e-01,
│   │                           1.1960e+00, -3.4992e-01],
│   │                         [-7.5796e-01, -9.3403e-01,  3.3890e-01,  ..., -7.2021e-01,
│   │                           5.9619e-01,  1.7029e-01]]]])
│   └── 'scalar' --> tensor([[ 9.7821e-01, -7.0237e-01,  5.5582e-01, -1.3539e+00,  9.1369e-01,
│                              1.9138e+00, -4.9497e-01, -5.3560e-01,  6.8744e-01, -1.0841e+00,
│                             -1.3985e+00, -1.6926e-01],
│                            [ 1.6206e+00,  2.4936e-01,  2.4617e-01, -2.5374e+00,  1.0992e+00,
│                             -1.9476e+00, -1.2746e-01,  7.3098e-01,  2.4362e-01, -6.7332e-01,
│                             -2.0711e+00, -1.2680e+00],
│                            [ 1.0915e-01,  9.4545e-01,  8.3464e-01,  1.3476e-01, -4.2530e-01,
│                             -9.1630e-01,  1.3294e+00, -7.9084e-01, -2.5933e-01, -7.0110e-01,
│                              1.0540e+00, -2.0960e-03],
│                            [ 1.0235e+00,  4.6681e-01,  6.9040e-01, -1.5396e-01, -1.0360e+00,
│                              2.8806e-01,  1.7229e+00,  1.4992e+00,  1.1843e+00, -6.4562e-01,
│                             -1.4174e+00, -5.8246e-01]])
└── 'reward' --> tensor([[0.0512],
                         [0.4139],
                         [0.6208],
                         [0.2947]])

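This code is much simpler and clearer. `ttorch.tensor` turns each nested dict into a tree of tensors, and `ttorch.stack` stacks every leaf, so no type-specific recursion has to be written by hand.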