Stack Structured Data

When tensors are organized in a tree structure, they often need to be stacked together, just as torch.stack() does for plain tensors.

Stack With Native PyTorch API

Here is a typical implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

# Batch size: the number of items that are generated and then stacked together.
B = 4


def get_item():
    """Create one randomly generated sample as a nested dict.

    Returns:
        A dict with four entries, in this order:

        * ``'obs'``: a sub-dict holding a ``'scalar'`` tensor of shape ``(12,)``
          and an ``'image'`` tensor of shape ``(3, 32, 32)``.
        * ``'action'``: a one-element integer tensor from ``torch.randint(0, 10)``.
        * ``'reward'``: a one-element tensor from ``torch.rand``.
        * ``'done'``: the plain Python bool ``False``.
    """
    # The tensors are drawn in the same order as they appear in the result
    # (scalar, image, action, reward).
    observation = dict(
        scalar=torch.randn(12),
        image=torch.randn(3, 32, 32),
    )
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)

    item = {'obs': observation}
    item['action'] = action
    item['reward'] = reward
    item['done'] = False
    return item


# Build the batch: a list of B independently generated nested dicts.
data = [get_item() for _ in range(B)]


# execute `stack` op
def stack(data, dim):
    """Recursively stack a list of identically structured items.

    The type of the first element decides how the whole list is handled:

    * ``torch.Tensor``: the tensors are joined with ``torch.stack`` along ``dim``.
    * ``dict``: every key of the first element is stacked recursively, so each
      item in ``data`` must provide those keys.
    * ``bool``: the flags are packed into a 1-D ``torch.bool`` tensor; ``dim``
      is not used in this case.

    Args:
        data: Non-empty list of items that share the same structure.
        dim: Dimension handed to ``torch.stack`` for the tensor leaves.

    Returns:
        A tensor, or a (possibly nested) dict of tensors that mirrors the
        structure of the items.

    Raises:
        IndexError: If ``data`` is empty (the first element is inspected).
        TypeError: If the first element is not a tensor, a dict or a bool.
    """
    elem = data[0]
    if isinstance(elem, torch.Tensor):
        return torch.stack(data, dim)
    if isinstance(elem, dict):
        return {key: stack([item[key] for item in data], dim) for key in elem}
    if isinstance(elem, bool):
        # torch.tensor with an explicit dtype replaces the legacy typed
        # constructor torch.BoolTensor and yields the same 1-D bool tensor.
        return torch.tensor(data, dtype=torch.bool)
    raise TypeError(f"not support elem type: {type(elem)}")


# Merge the list of per-item dicts into one dict whose leaves are stacked along dim 0.
stacked_data = stack(data, dim=0)
# validate
print(stacked_data)
# Each tensor leaf gains a leading dimension of size B.
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
# The Python bool flags end up in a 1-D boolean tensor with one entry per item.
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertions pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
{'obs': {'scalar': tensor([[-1.5359, -1.8637,  0.5147,  0.2885,  3.0232, -1.5747, -0.9136, -0.5642,
         -0.1107, -0.6437, -0.6876, -2.1037],
        [-0.4803,  0.5412, -0.9797,  1.3545, -1.6375, -0.1112, -0.6469, -0.1644,
         -0.1071, -0.7095,  1.3717, -0.8619],
        [ 0.0595,  1.1289, -1.8427, -2.3444,  1.1401, -0.3017,  0.1171, -3.0900,
         -0.1801,  0.7504,  0.8529,  0.3215],
        [-2.2841, -0.9075,  1.3475, -0.8056,  1.3093,  1.6509, -0.0150, -2.0965,
          0.5121, -1.3069, -0.3245, -0.4751]]), 'image': tensor([[[[ 1.6610e+00, -6.6716e-01, -2.5674e-03,  ..., -1.1119e+00,
            1.0111e-01, -3.8947e-01],
          [-6.0846e-02, -4.2238e-01, -1.7890e+00,  ..., -5.1320e-01,
           -3.2174e-02,  1.5891e+00],
          [ 6.6530e-01,  3.7910e-01, -1.1383e+00,  ..., -2.8363e+00,
           -1.7331e+00,  8.3179e-01],
          ...,
          [-1.7364e+00,  8.8087e-02,  7.6660e-01,  ...,  7.0450e-02,
           -1.7959e-01,  3.5613e-02],
          [ 4.5644e-01, -8.5190e-01, -2.1841e-01,  ..., -3.5723e-01,
           -7.7395e-01,  1.1753e+00],
          [ 7.9634e-01, -9.4494e-01,  8.4530e-01,  ..., -8.8762e-01,
            2.0041e+00, -8.3095e-01]],

         [[ 3.9545e-01,  1.7853e-02,  1.2258e+00,  ...,  4.6404e-01,
            1.6273e+00,  6.7164e-01],
          [ 7.8439e-01, -1.7749e-02,  1.1747e+00,  ..., -1.4459e+00,
           -2.6065e-01, -2.0704e+00],
          [ 2.3344e-01, -3.5753e-01, -2.0137e-01,  ..., -3.0735e-01,
            1.0049e+00, -7.7714e-02],
          ...,
          [-5.1728e-01, -6.1153e-01,  3.1373e+00,  ...,  6.0743e-01,
            8.8225e-01, -1.7986e+00],
          [ 1.8084e-01, -3.8146e-01, -2.8826e-01,  ...,  9.1983e-01,
            1.3678e+00,  1.0265e+00],
          [-1.7859e+00,  1.5123e+00, -5.6433e-01,  ..., -7.3448e-01,
            6.5039e-01, -1.4252e-01]],

         [[ 3.3558e-01,  7.9159e-01, -7.8051e-01,  ..., -9.0632e-02,
           -4.1254e-01, -1.0340e+00],
          [ 1.4462e-01, -1.5996e-01,  1.5098e-01,  ..., -6.6997e-01,
           -1.5398e+00, -1.1684e+00],
          [ 3.0423e-02, -8.4161e-02, -8.2469e-01,  ..., -9.4156e-01,
           -6.0047e-01,  1.0833e+00],
          ...,
          [-1.2074e+00, -6.7881e-01, -8.8804e-01,  ...,  1.3672e+00,
           -2.3897e-01, -1.4595e+00],
          [-1.5093e+00,  1.9467e-01,  2.6247e-01,  ..., -2.4687e-01,
            8.1974e-01,  8.6477e-01],
          [-1.0264e+00,  4.6218e-01, -1.2689e+00,  ..., -1.4148e+00,
           -1.8802e+00,  1.1174e+00]]],


        [[[ 7.1722e-01,  2.7806e-01,  2.6875e-01,  ..., -4.0581e-01,
            7.5477e-01, -7.3399e-01],
          [ 1.2497e-01, -6.6935e-01, -1.1342e-01,  ..., -1.9850e+00,
            4.2486e-01,  4.6698e-01],
          [-1.7070e+00, -1.8608e+00, -2.9015e-01,  ..., -3.2438e-01,
            5.1105e-01,  6.0397e-01],
          ...,
          [ 2.9754e-01, -1.0941e+00, -1.7757e-01,  ...,  5.7503e-01,
            3.0076e-01, -2.7469e-02],
          [ 8.9610e-01, -4.6541e-01,  1.7398e+00,  ...,  7.6940e-01,
            9.9921e-01,  3.4635e-01],
          [-1.3408e+00, -1.1096e+00,  1.1439e+00,  ..., -1.0166e+00,
            1.5637e-01,  1.0217e+00]],

         [[ 6.9857e-01,  2.4762e-02, -1.6608e+00,  ...,  2.3986e+00,
            1.4482e+00, -1.1221e-01],
          [-1.2235e+00,  9.5569e-01,  8.0277e-01,  ...,  1.2390e-01,
            6.3446e-01, -4.8159e-01],
          [ 1.1137e-01, -4.1811e-01,  1.6600e+00,  ...,  9.5241e-01,
           -3.9915e-01, -3.4954e-01],
          ...,
          [-6.4935e-01,  3.8598e-01, -1.9078e+00,  ..., -1.0743e+00,
           -5.7636e-01, -9.9310e-02],
          [-6.7461e-01,  2.1996e+00,  6.1991e-01,  ...,  1.3554e+00,
            5.9770e-01, -4.0413e-01],
          [ 3.6842e-01,  2.8036e-01,  2.2252e-01,  ...,  2.9636e-01,
           -4.0549e-01,  2.7558e-01]],

         [[ 3.5620e-01,  6.0534e-01, -3.1990e-01,  ..., -2.7719e+00,
           -1.7094e+00, -3.4541e-02],
          [ 8.0527e-01, -2.9047e+00, -5.7613e-01,  ...,  4.7329e-01,
           -3.2298e-01,  5.8782e-01],
          [-8.8396e-01, -5.3898e-01, -8.7339e-02,  ...,  1.5157e+00,
            2.5611e-01, -1.2633e+00],
          ...,
          [-1.8839e+00, -8.8443e-01,  5.7534e-01,  ...,  2.1994e-01,
           -9.9814e-01, -6.8960e-01],
          [-4.5642e-01,  1.2131e+00,  7.1354e-01,  ...,  2.7715e-01,
           -1.1295e+00,  5.8713e-01],
          [ 3.3697e+00, -2.3707e-01,  8.7955e-01,  ...,  2.7682e-01,
           -1.1120e-01,  7.7729e-01]]],


        [[[-1.2380e+00,  2.2912e-01,  5.3711e-01,  ...,  6.1321e-02,
            2.9950e-01, -8.9248e-01],
          [ 1.5834e+00,  4.3304e-01, -1.2799e+00,  ...,  4.7555e-01,
            8.4607e-01, -1.6996e-01],
          [ 4.3139e-01, -1.9600e+00,  1.0313e+00,  ...,  9.2116e-01,
           -7.9623e-01, -1.0047e+00],
          ...,
          [-1.9196e+00, -1.3142e+00,  3.0070e-01,  ..., -7.7360e-01,
            1.6803e-01, -5.1410e-01],
          [-7.7753e-01, -5.5371e-02, -3.5943e-01,  ..., -5.0513e-01,
           -5.1643e-01, -1.5594e+00],
          [ 2.2311e+00,  5.3528e-01,  8.7125e-01,  ...,  2.2423e-01,
           -8.5476e-01, -1.4118e+00]],

         [[ 6.4727e-01,  3.2528e-02, -6.0068e-01,  ...,  8.2503e-01,
           -3.9128e-01,  1.0412e+00],
          [-5.6120e-03, -8.2240e-01,  1.0684e+00,  ...,  7.7922e-01,
           -1.8849e+00, -9.9019e-01],
          [ 2.1886e-01,  4.0195e-02, -1.2756e-01,  ..., -1.6997e-01,
           -8.5318e-01,  1.0684e+00],
          ...,
          [-1.5420e+00,  3.6812e-01,  5.6735e-01,  ...,  1.4701e+00,
           -6.5132e-02,  7.0469e-01],
          [-3.5373e-01, -3.0333e-01, -1.1018e+00,  ...,  1.0491e+00,
           -6.8210e-01,  2.9134e-01],
          [ 4.7230e-02, -1.1026e-01,  7.9086e-01,  ..., -5.4547e-01,
            1.9298e+00, -1.2211e+00]],

         [[-1.3403e-01, -7.8478e-01,  3.0228e-01,  ...,  2.5261e-01,
           -1.0661e+00,  3.5013e-01],
          [ 1.7791e-01,  8.1826e-01, -2.2315e-01,  ..., -1.3331e+00,
           -3.5003e-01, -4.1320e-01],
          [ 2.9800e-01,  4.4868e-01, -1.2710e+00,  ...,  7.0277e-01,
           -6.2180e-01, -2.0842e+00],
          ...,
          [ 5.5338e-01, -2.1307e+00, -7.8954e-01,  ...,  5.1934e-01,
            1.8357e-04,  2.2463e+00],
          [ 7.2245e-01,  9.7585e-03, -1.3792e+00,  ..., -1.3492e+00,
            1.1175e+00,  9.5284e-01],
          [-1.0783e+00, -1.6424e+00,  2.9628e-01,  ..., -3.7582e-01,
            6.3133e-01,  4.2903e-01]]],


        [[[-1.3136e+00,  5.4848e-01, -9.0181e-02,  ...,  1.7092e-01,
            1.2252e+00,  1.2149e-01],
          [ 4.0342e-01, -4.7260e-01, -1.2950e+00,  ..., -4.2847e-01,
            3.4140e-01,  7.4329e-01],
          [-2.0541e-01,  5.0495e-01,  3.0750e-01,  ...,  9.2752e-01,
            1.2537e+00,  6.9711e-02],
          ...,
          [-3.9146e-01,  2.7727e-01,  2.0646e+00,  ...,  1.7233e+00,
            8.8880e-01, -1.0380e-01],
          [-9.6824e-02, -1.5755e+00, -8.1249e-01,  ..., -9.0718e-01,
           -6.0762e-01, -5.3578e-01],
          [ 1.2190e+00, -5.5851e-01,  1.1164e+00,  ..., -1.0015e-01,
           -7.5845e-02, -1.6652e+00]],

         [[ 1.0584e+00, -8.3761e-01,  3.9262e-01,  ...,  1.8085e-01,
           -4.7859e-01, -1.1761e+00],
          [-2.3864e-01,  2.5746e-02,  2.1723e-01,  ...,  7.2606e-01,
            6.6296e-01,  8.9619e-01],
          [ 2.1951e+00,  9.7526e-01,  9.4685e-02,  ...,  1.2771e+00,
            8.1224e-01, -5.0146e-02],
          ...,
          [ 1.0919e-01,  4.0990e-01,  1.4036e+00,  ..., -1.2443e+00,
           -1.3147e-01,  1.0036e+00],
          [-3.5938e-01, -2.2983e-01,  3.6190e-01,  ...,  7.2910e-01,
           -5.7873e-01, -7.2678e-01],
          [-2.5017e-02, -2.5089e-01,  1.7725e+00,  ...,  9.0255e-01,
            1.3763e-01,  5.9141e-01]],

         [[-1.6782e+00,  1.0001e+00,  2.7471e+00,  ..., -6.1081e-02,
            1.6030e+00,  1.3063e+00],
          [ 2.7889e+00,  3.3484e-01,  6.0145e-01,  ..., -4.5082e-02,
           -1.0295e+00, -2.8439e-01],
          [ 1.0260e+00, -2.3331e-02, -4.2302e-01,  ...,  1.9972e+00,
           -2.5823e-01,  1.6595e+00],
          ...,
          [-3.0104e-01, -9.8066e-01,  1.3314e-01,  ...,  1.2524e-01,
           -8.2533e-01, -2.0243e-01],
          [-7.5341e-01, -1.9238e+00,  1.0450e+00,  ...,  2.0502e-01,
            1.0519e+00, -4.7291e-01],
          [-5.4682e-01, -1.5921e+00,  1.3838e+00,  ...,  7.9409e-01,
            6.9568e-01,  3.7058e-01]]]])}, 'action': tensor([[2],
        [0],
        [9],
        [5]]), 'reward': tensor([[0.0673],
        [0.3907],
        [0.7002],
        [0.3012]]), 'done': tensor([False, False, False, False])}

As we can see, the function that processes the structure (such as the stack function above) has to be implemented in full by hand. The resulting code is not very clear, and because the supported types are hard-coded, supporting additional data types (such as integers) requires modifying the function specifically for them.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown in the code below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

# Batch size: the number of items that are generated and then stacked together.
B = 4


def get_item():
    """Create one randomly generated sample as a nested dict.

    Returns:
        A dict with four entries, in this order:

        * ``'obs'``: a sub-dict holding a ``'scalar'`` tensor of shape ``(12,)``
          and an ``'image'`` tensor of shape ``(3, 32, 32)``.
        * ``'action'``: a one-element integer tensor from ``torch.randint(0, 10)``.
        * ``'reward'``: a one-element tensor from ``torch.rand``.
        * ``'done'``: the plain Python bool ``False``.
    """
    # The tensors are drawn in the same order as they appear in the result
    # (scalar, image, action, reward).
    observation = dict(
        scalar=torch.randn(12),
        image=torch.randn(3, 32, 32),
    )
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)

    item = {'obs': observation}
    item['action'] = action
    item['reward'] = reward
    item['done'] = False
    return item


# Build the batch: a list of B independently generated nested dicts.
data = [get_item() for _ in range(B)]
# execute `stack` op
# Wrap each nested dict with ttorch.tensor, then stack the whole list with a single
# ttorch.stack call; no hand-written recursive helper is needed here.
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate
print(stacked_data)
# The stacked result is accessed by attribute (e.g. `.obs.image`) rather than by key.
# Each tensor leaf has a leading dimension of size B.
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
# The `done` flags form a 1-D boolean tensor with one entry per item.
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertions pass here as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
<Tensor 0x7f67b3933b20>
├── 'action' --> tensor([[6],
│                        [5],
│                        [9],
│                        [1]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7f674cb97ee0>
│   ├── 'image' --> tensor([[[[ 1.5096e+00, -2.5924e+00, -1.2985e+00,  ...,  4.1741e-01,
│   │                          -7.1140e-01, -5.9812e-01],
│   │                         [ 1.8629e+00,  1.3916e+00, -5.3894e-01,  ...,  1.5501e+00,
│   │                          -5.2787e-02,  1.5425e+00],
│   │                         [-7.7700e-01, -9.6292e-02, -1.2389e+00,  ..., -3.7741e-03,
│   │                          -1.0594e+00,  8.5284e-01],
│   │                         ...,
│   │                         [ 2.2239e+00,  3.5538e-01, -2.0282e-01,  ...,  1.9719e-01,
│   │                          -2.1059e+00,  1.6486e+00],
│   │                         [-1.7360e+00,  5.4855e-01,  6.4995e-01,  ...,  1.4169e+00,
│   │                           1.9953e+00, -1.2208e+00],
│   │                         [-2.4211e+00,  3.2995e-01,  1.4672e+00,  ...,  1.2811e+00,
│   │                          -6.5591e-02, -8.9733e-01]],
│   │               
│   │                        [[ 1.0635e+00, -1.6979e+00,  1.2968e+00,  ..., -1.2375e+00,
│   │                           1.9904e+00,  9.4620e-01],
│   │                         [-4.4523e-01,  2.0002e+00,  7.0469e-02,  ...,  4.8900e-01,
│   │                          -5.1857e-01, -2.0646e+00],
│   │                         [ 4.1355e-01, -1.7892e+00, -5.6645e-01,  ...,  4.6617e-01,
│   │                           7.5357e-01,  7.9125e-01],
│   │                         ...,
│   │                         [-3.8282e-01, -4.4059e-01, -2.5233e-02,  ..., -1.1164e+00,
│   │                          -1.3296e+00,  8.9965e-02],
│   │                         [-1.2467e+00,  9.2784e-01, -8.3671e-01,  ..., -4.6411e-01,
│   │                           4.9811e-01,  1.3054e-01],
│   │                         [-7.6199e-01, -3.1625e-01, -2.5220e-01,  ...,  4.2622e-01,
│   │                          -2.9130e-01, -5.7645e-01]],
│   │               
│   │                        [[-8.2144e-01,  6.9815e-01, -1.1575e-01,  ...,  1.0502e+00,
│   │                          -4.9453e-02,  4.9258e-01],
│   │                         [-3.4399e-01,  7.3040e-02,  2.8991e-01,  ...,  9.3092e-01,
│   │                           2.4635e-01,  3.2799e+00],
│   │                         [ 1.3362e+00,  1.0832e-01,  2.3803e+00,  ..., -1.8618e-01,
│   │                          -1.2159e+00,  4.3580e-01],
│   │                         ...,
│   │                         [-4.3481e-01, -3.4832e-01,  1.8724e+00,  ..., -8.7521e-01,
│   │                           9.7985e-01, -2.6546e-01],
│   │                         [ 6.8399e-01, -4.6111e-01,  1.2375e+00,  ..., -4.3992e-01,
│   │                           7.7478e-01, -1.0183e+00],
│   │                         [ 1.4008e+00,  5.0574e-01, -1.0997e+00,  ...,  1.5970e+00,
│   │                          -1.5338e-01,  3.0924e-01]]],
│   │               
│   │               
│   │                       [[[-9.8478e-01,  2.0089e+00,  6.4709e-01,  ...,  2.8368e-01,
│   │                          -8.1748e-01, -1.7254e+00],
│   │                         [ 5.9830e-02,  1.7519e+00, -5.6280e-01,  ...,  1.8556e+00,
│   │                          -4.9722e-01, -3.5804e-02],
│   │                         [-7.6478e-02,  8.0604e-01, -7.2962e-01,  ...,  1.0894e+00,
│   │                           2.1358e-01, -1.0206e+00],
│   │                         ...,
│   │                         [ 8.2956e-01, -2.9412e-01,  1.8450e+00,  ..., -3.2224e-01,
│   │                          -2.1809e-01,  1.3046e-01],
│   │                         [-8.4717e-01,  6.6377e-01,  1.5835e+00,  ..., -1.7277e+00,
│   │                           1.1663e+00,  2.1333e-01],
│   │                         [ 9.6119e-01, -1.6126e+00,  1.1323e+00,  ..., -5.3644e-03,
│   │                          -1.2464e-01,  1.0318e+00]],
│   │               
│   │                        [[-9.8105e-01, -1.4832e+00,  1.1134e-01,  ...,  1.7990e+00,
│   │                          -7.8146e-01, -2.8395e-01],
│   │                         [ 7.5591e-01,  8.6817e-01,  7.2440e-01,  ..., -4.2268e-01,
│   │                           1.3767e-01, -3.7290e-01],
│   │                         [-1.5976e-01,  7.8546e-01,  5.3195e-01,  ...,  1.7872e+00,
│   │                           9.8206e-01,  2.1067e+00],
│   │                         ...,
│   │                         [ 1.1195e+00,  2.1198e-02,  3.1971e-01,  ...,  1.6998e-01,
│   │                          -4.6388e-01, -1.4809e-01],
│   │                         [-1.4483e+00, -1.3471e+00, -6.9803e-01,  ...,  2.3547e-01,
│   │                          -2.3611e+00, -7.0573e-01],
│   │                         [ 9.2485e-01, -6.4721e-01,  3.8117e-02,  ..., -2.5686e+00,
│   │                           4.5731e-01,  3.4798e-01]],
│   │               
│   │                        [[-3.5736e-01,  7.6402e-01,  2.6200e-01,  ...,  5.7103e-01,
│   │                           9.9594e-01, -7.4228e-01],
│   │                         [-8.1190e-01, -1.2320e+00, -3.3982e-01,  ..., -5.5238e-01,
│   │                          -1.0162e+00, -9.5020e-01],
│   │                         [ 4.0352e-01, -7.3156e-01,  4.7960e-01,  ..., -5.1536e-01,
│   │                           3.9957e-01, -6.6148e-01],
│   │                         ...,
│   │                         [-3.2271e-01, -1.7166e-01, -7.7436e-01,  ..., -4.5268e-01,
│   │                           2.7514e-01, -1.2822e+00],
│   │                         [ 1.6007e+00, -9.3195e-01, -1.1446e+00,  ..., -9.0853e-02,
│   │                           2.7956e+00, -1.4393e+00],
│   │                         [ 1.0222e-01, -5.8582e-01, -1.3279e+00,  ..., -1.5334e-01,
│   │                           7.2398e-01, -1.4310e+00]]],
│   │               
│   │               
│   │                       [[[ 1.5704e-01, -1.7927e-01, -2.9434e-01,  ..., -1.0364e+00,
│   │                          -3.4321e-01, -1.2803e+00],
│   │                         [-1.8595e+00, -1.9555e+00,  6.8676e-01,  ..., -5.0676e-01,
│   │                          -7.8089e-01,  5.3276e-01],
│   │                         [ 4.8389e-01,  2.2939e+00,  1.3774e-01,  ...,  7.2845e-01,
│   │                           3.8985e-01, -2.2856e+00],
│   │                         ...,
│   │                         [ 1.3233e+00,  3.5061e-01,  2.2866e+00,  ..., -2.6043e-01,
│   │                          -2.4169e+00, -1.1341e+00],
│   │                         [-1.9741e-01, -2.4740e+00, -8.6849e-01,  ...,  6.1612e-01,
│   │                           3.9695e-01,  2.8495e-01],
│   │                         [ 1.6165e+00,  1.3103e+00, -5.2383e-01,  ..., -7.0674e-01,
│   │                          -5.6912e-01,  2.2296e-01]],
│   │               
│   │                        [[-1.4132e+00,  3.3454e-01, -1.7366e+00,  ..., -2.8714e-01,
│   │                          -1.0491e+00,  1.1373e+00],
│   │                         [-6.8534e-01,  4.8490e-01,  1.6887e+00,  ...,  2.1171e-01,
│   │                           6.4883e-01,  1.9160e+00],
│   │                         [-6.7286e-01,  7.7061e-01, -5.6293e-01,  ..., -6.3796e-01,
│   │                           1.1811e+00, -3.7864e-01],
│   │                         ...,
│   │                         [ 4.8821e-01, -7.9141e-01,  6.6298e-01,  ..., -4.0639e-01,
│   │                          -5.6672e-01,  6.0164e-01],
│   │                         [ 1.6093e+00,  9.5938e-01, -2.7474e-01,  ...,  4.7161e-01,
│   │                           1.3705e+00,  3.5498e-01],
│   │                         [-3.2744e-01, -4.7183e-01, -1.4731e+00,  ..., -1.2090e+00,
│   │                           1.7744e+00, -1.7377e-01]],
│   │               
│   │                        [[-7.9389e-01, -1.7566e-01,  1.1357e+00,  ..., -4.1468e-01,
│   │                          -1.9169e+00,  2.4644e+00],
│   │                         [-2.8844e+00, -9.2150e-01,  4.0514e-01,  ...,  4.3330e-01,
│   │                          -3.6057e-01, -1.0335e+00],
│   │                         [-6.9818e-02, -3.8488e-01,  1.1589e+00,  ..., -8.3541e-01,
│   │                          -1.3521e+00, -6.7402e-01],
│   │                         ...,
│   │                         [ 2.0459e+00,  7.6486e-01,  1.6111e-01,  ...,  8.8801e-01,
│   │                           4.5756e-04,  5.9226e-01],
│   │                         [-3.3977e-01,  1.4775e-01, -6.0727e-02,  ..., -1.4168e+00,
│   │                           5.4174e-01, -5.8324e-01],
│   │                         [-6.2560e-01,  4.0346e-01, -3.7235e-01,  ..., -1.9103e+00,
│   │                           3.0439e+00, -9.0031e-01]]],
│   │               
│   │               
│   │                       [[[ 5.8669e-01, -2.1086e+00,  1.0279e+00,  ...,  1.3415e-01,
│   │                          -3.0481e-01,  1.0725e+00],
│   │                         [ 1.1445e+00, -7.9624e-01, -6.9794e-01,  ..., -3.9854e-01,
│   │                           1.8777e+00,  1.2824e+00],
│   │                         [ 1.5032e+00,  2.4661e-01, -2.0454e+00,  ..., -4.0469e-02,
│   │                           4.1951e-01,  9.8916e-01],
│   │                         ...,
│   │                         [-8.6445e-01, -1.4327e-01,  4.5598e-01,  ...,  4.3978e-01,
│   │                           3.3533e-01, -3.5248e-01],
│   │                         [-6.1360e-01,  4.5503e-01,  6.2883e-01,  ..., -5.8277e-01,
│   │                           1.2092e-01,  4.2707e-01],
│   │                         [-4.0543e-01,  1.6866e+00, -1.2542e-01,  ..., -6.2828e-01,
│   │                           6.3958e-01, -1.3999e+00]],
│   │               
│   │                        [[-4.2511e-01, -3.8454e-01,  1.0229e+00,  ..., -9.9500e-01,
│   │                           3.4999e-01, -8.7348e-01],
│   │                         [ 6.4947e-02,  2.8181e+00,  7.9655e-01,  ..., -1.0276e+00,
│   │                          -1.4040e-01, -8.4645e-01],
│   │                         [-7.8912e-01, -1.4459e-01, -1.1645e-01,  ..., -1.2434e+00,
│   │                           5.5709e-01, -4.0519e-01],
│   │                         ...,
│   │                         [-1.5093e+00,  5.4285e-01,  1.2440e-01,  ..., -2.0399e+00,
│   │                           8.6725e-01, -7.4004e-01],
│   │                         [-1.1745e+00, -1.4943e-01, -5.6547e-01,  ...,  4.3259e-01,
│   │                          -1.1197e+00, -1.3377e+00],
│   │                         [ 3.6998e-01,  1.5458e+00, -8.3772e-01,  ...,  1.2480e+00,
│   │                           1.1511e+00, -6.7690e-01]],
│   │               
│   │                        [[-1.3623e+00,  7.4699e-01, -8.0144e-01,  ...,  8.3665e-01,
│   │                           5.7975e-01,  5.3105e-01],
│   │                         [ 1.0468e-01,  1.1169e-02,  8.8025e-01,  ...,  2.1460e-01,
│   │                           6.0716e-01,  1.7213e+00],
│   │                         [-1.4123e-01,  3.9434e-01,  7.4176e-01,  ...,  1.2088e+00,
│   │                           1.6890e+00, -9.7466e-01],
│   │                         ...,
│   │                         [-1.3751e+00,  8.6904e-02,  4.9621e-01,  ..., -8.1262e-01,
│   │                          -6.9502e-01,  1.0489e+00],
│   │                         [-8.3935e-01, -1.2021e+00, -9.7100e-02,  ..., -1.3282e+00,
│   │                           9.0960e-01, -5.5755e-01],
│   │                         [-2.3514e-01,  7.6488e-01, -3.4905e-01,  ..., -1.1052e+00,
│   │                           1.0690e+00, -9.9615e-01]]]])
│   └── 'scalar' --> tensor([[ 1.0388, -0.5856, -0.2967,  0.3854, -0.7474, -1.2775, -0.5164,  0.2292,
│                             -1.3179, -0.5619,  0.7937,  1.9121],
│                            [ 0.9892,  0.8679,  0.9397,  0.2684, -0.7935, -1.2487, -0.8953,  0.7985,
│                             -0.4398, -1.3260, -0.1225,  0.8914],
│                            [-0.1819, -1.2881, -1.7916,  1.0303,  0.0899, -0.1217, -1.0786,  0.6543,
│                              0.2724, -1.6755,  1.6962, -0.5407],
│                            [-1.0845,  1.0362,  1.1486,  1.7401,  0.8451, -1.6223, -0.7594, -0.5500,
│                             -0.2333,  1.6559, -0.4189,  0.5876]])
└── 'reward' --> tensor([[0.2597],
                         [0.5514],
                         [0.0990],
                         [0.1488]])

This code looks much simpler and clearer.