Stack Structured Data

When tensors are organized into tree structures, they often need to be stacked together, just as `torch.stack()` does for plain tensors in PyTorch.

Stack With Native PyTorch API

Here is a typical implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

B = 4  # batch size: how many samples are generated and then stacked


def get_item():
    """Create one synthetic sample as a nested dict.

    Returns:
        A dict with, in this key order:
        ``'obs'`` -- a dict holding ``'scalar'`` (``torch.randn(12)``) and
        ``'image'`` (``torch.randn(3, 32, 32)``);
        ``'action'`` -- ``torch.randint(0, 10, size=(1,))``;
        ``'reward'`` -- ``torch.rand(1)``;
        ``'done'`` -- the plain Python bool ``False``.
    """
    # Draw the random tensors in the same order the nested literal would
    # evaluate them, so seeded runs yield identical samples.
    scalar = torch.randn(12)
    image = torch.randn(3, 32, 32)
    observation = dict(scalar=scalar, image=image)
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)
    return dict(obs=observation, action=action, reward=reward, done=False)


data = [get_item() for _sample in range(B)]


# execute `stack` op
def stack(data, dim=0):
    """Recursively stack a list of identically-structured trees.

    The structure is taken from ``data[0]``:
    tensor leaves are combined with ``torch.stack``; dicts are processed
    key by key (keys come from the first element); Python ``bool`` leaves
    are gathered into a 1-D ``torch.bool`` tensor.

    Args:
        data: non-empty list of tensors, bools, or (nested) dicts of these.
        dim: dimension along which tensor leaves are stacked. Defaults to
            ``0``, matching ``torch.stack``. Bool leaves ignore ``dim`` and
            always become a 1-D tensor of length ``len(data)``.

    Returns:
        A tree with the same structure as ``data[0]`` whose leaves are
        stacked tensors.

    Raises:
        IndexError: if ``data`` is empty (``data[0]`` is read first).
        KeyError: if a later dict lacks a key present in the first one.
        TypeError: for any leaf type other than tensor, dict or bool.
    """
    elem = data[0]
    if isinstance(elem, torch.Tensor):
        return torch.stack(data, dim)
    if isinstance(elem, dict):
        return {k: stack([item[k] for item in data], dim) for k in elem}
    if isinstance(elem, bool):
        # `torch.tensor` with an explicit dtype replaces the legacy
        # `torch.BoolTensor` constructor and yields the same values.
        return torch.tensor(data, dtype=torch.bool)
    raise TypeError(f"not support elem type: {type(elem)}")


stacked_data = stack(data, dim=0)
# validate: every leaf gains a leading batch dimension of size B
print(stacked_data)
assert stacked_data['obs']['scalar'].shape == (B, 12)
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertions pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
{'obs': {'scalar': tensor([[-0.7858,  1.4707, -1.0140,  0.3425,  0.9461,  1.2568, -0.5284, -0.5226,
         -0.2006, -0.5780,  0.1889, -0.1170],
        [ 0.0133,  0.1080,  1.9860, -0.0293,  0.4622, -0.4333, -0.2437, -0.5971,
         -0.4327, -0.8991,  0.4368, -0.3291],
        [-0.6527, -0.4469,  0.4806,  1.4929,  0.9057,  1.0086, -0.1577, -1.0740,
         -1.5294,  2.7303, -0.8944,  2.3386],
        [ 1.4629, -1.3995,  1.2038, -0.0669,  1.2354,  0.2135, -0.0766,  1.1779,
          0.9396, -0.3776, -0.7392, -0.0039]]), 'image': tensor([[[[-2.4659e-01, -1.7185e+00,  2.2048e+00,  ..., -6.3458e-01,
           -1.3307e+00,  5.2473e-01],
          [ 4.6504e-01,  1.4175e+00, -1.7332e+00,  ...,  9.5714e-01,
           -2.7737e-01, -1.6383e-01],
          [ 6.0394e-01, -2.7145e+00,  3.2075e-01,  ...,  4.3352e-01,
            4.4158e-01,  1.2664e+00],
          ...,
          [ 7.2162e-01, -6.4988e-01, -1.4958e+00,  ..., -7.2869e-02,
           -6.4163e-01,  7.0806e-01],
          [-4.0786e-01,  9.3613e-01, -8.7288e-01,  ..., -7.8312e-01,
           -8.6041e-01, -7.8952e-02],
          [ 2.2861e-01, -1.5923e+00, -4.0406e-01,  ...,  7.4344e-01,
            5.3471e-01, -7.8023e-01]],

         [[ 6.5521e-01,  7.7477e-01, -3.9315e-01,  ...,  4.0745e-01,
           -1.7927e+00, -7.6946e-01],
          [ 1.8634e-01,  2.7587e-01, -1.5150e-02,  ..., -1.0080e+00,
            3.3304e-01,  4.1316e-01],
          [-2.0981e+00,  8.9304e-03, -9.6102e-01,  ..., -2.4784e-01,
           -6.5200e-02,  9.6248e-01],
          ...,
          [ 2.9297e-01,  4.0522e-01, -1.0095e-01,  ..., -6.5406e-01,
           -7.3121e-01,  9.7984e-01],
          [ 3.1364e-01, -9.1747e-01,  2.5849e-01,  ...,  2.1181e-01,
            3.9146e-01, -3.9142e-01],
          [ 1.4762e-01, -2.4625e-01,  1.2718e+00,  ...,  4.3475e-01,
           -3.2120e-01, -1.6193e-03]],

         [[ 1.3476e+00,  1.1615e+00,  8.6449e-01,  ..., -4.0798e-01,
            5.1508e-01, -6.3406e-03],
          [ 1.2268e+00, -8.4481e-01,  9.3971e-01,  ...,  4.3556e-01,
            9.8673e-01, -6.9953e-02],
          [-7.4190e-01, -7.6777e-01, -5.1350e-01,  ..., -3.9492e-01,
            9.6998e-01,  7.2279e-01],
          ...,
          [-2.6256e-01,  6.9896e-01,  1.1152e+00,  ..., -6.1269e-01,
            1.9845e-01, -2.1720e+00],
          [ 6.5521e-01, -4.2927e-01, -4.3635e-01,  ...,  1.2591e+00,
           -2.3400e+00, -2.2657e-01],
          [-2.0195e+00,  8.0540e-01, -4.0448e-01,  ..., -1.3057e+00,
           -1.5000e+00,  1.5276e+00]]],


        [[[-1.1149e+00, -4.3801e-02,  5.5974e-01,  ...,  2.4219e+00,
           -9.0018e-01, -1.1646e+00],
          [ 3.2169e-01,  3.3518e-02,  7.2638e-02,  ..., -6.2362e-01,
           -2.6447e-02, -1.1346e-01],
          [ 5.5555e-01,  4.4253e-01, -1.0729e+00,  ...,  5.7889e-01,
           -3.7286e-01, -9.2874e-01],
          ...,
          [-9.2832e-01, -3.3340e-01, -4.2290e-01,  ...,  8.0095e-01,
           -6.2390e-01, -1.0970e+00],
          [ 6.5387e-01, -9.3481e-03, -1.8161e-01,  ..., -9.2834e-03,
           -9.1717e-02, -1.4855e-01],
          [-9.7934e-01,  8.7183e-01, -9.2308e-01,  ..., -4.8028e-01,
            8.1331e-01,  3.0530e-01]],

         [[-2.6336e+00,  1.4878e+00, -3.4662e-01,  ..., -1.5026e+00,
            1.0350e+00,  6.7063e-01],
          [-9.6939e-01,  3.9733e-01, -6.2983e-02,  ..., -9.0607e-01,
            8.0565e-01, -4.0444e-01],
          [-1.4645e+00, -9.2370e-01,  1.2829e+00,  ..., -6.1346e-01,
           -6.3182e-01, -1.1295e+00],
          ...,
          [-2.3609e-01, -4.9302e-01, -2.4148e+00,  ..., -4.5901e-01,
           -1.2246e-01, -2.6093e-01],
          [ 2.4740e-01,  4.6137e-01,  3.1498e-01,  ...,  1.9607e-02,
           -9.7382e-01,  1.0212e+00],
          [-3.0173e-01, -5.6957e-01, -6.9360e-01,  ...,  9.7358e-01,
           -7.9809e-01, -4.6035e-01]],

         [[-5.8716e-02,  1.3388e+00,  1.2717e+00,  ...,  1.1048e+00,
           -2.7483e-02,  3.4618e-01],
          [ 2.3906e-01,  1.5526e+00, -1.0565e+00,  ..., -6.0341e-01,
            1.6294e+00,  7.6178e-01],
          [-1.6971e-01,  7.2057e-01,  3.4025e-01,  ..., -3.9501e-02,
            1.8495e+00,  5.9107e-01],
          ...,
          [ 1.2788e+00, -8.3949e-01, -6.0674e-01,  ..., -6.8831e-01,
            1.6703e+00, -1.2544e-01],
          [-8.5759e-01,  7.2581e-01, -3.3760e-01,  ...,  1.5534e+00,
           -7.9265e-01,  2.6507e+00],
          [ 1.8689e+00,  1.5379e+00,  9.0901e-02,  ...,  6.9986e-01,
           -1.0397e+00,  1.6670e-01]]],


        [[[ 3.7998e-01, -2.1333e+00, -1.1548e-01,  ...,  1.7599e+00,
            7.9498e-02, -1.4721e+00],
          [-1.7937e+00,  2.1941e+00,  1.2975e+00,  ..., -3.1042e-01,
           -2.3745e-01,  1.8411e+00],
          [ 9.6877e-02,  4.6438e-01,  2.6683e+00,  ...,  2.4183e+00,
           -1.7441e-01,  6.6555e-01],
          ...,
          [-3.4976e-01,  1.5864e-01, -3.7359e-01,  ...,  8.3493e-01,
            6.1088e-01,  3.4430e-01],
          [-5.8951e-01,  1.5485e+00, -1.6572e-01,  ...,  8.1930e-01,
            2.8922e-02,  3.0849e+00],
          [ 1.0942e+00, -1.0089e-02, -1.1290e+00,  ..., -8.1528e-01,
            1.4440e+00,  1.5700e+00]],

         [[-2.6796e-01, -1.6170e+00,  2.9437e-02,  ...,  1.3213e+00,
            1.4277e+00, -5.6638e-01],
          [ 4.3161e-01, -4.6236e-01, -2.6068e-01,  ..., -1.3690e+00,
           -1.2225e+00, -3.2934e+00],
          [-6.2500e-01,  3.9914e-01, -1.9641e+00,  ...,  7.0300e-01,
            9.5558e-01, -5.5686e-01],
          ...,
          [ 8.0336e-01, -1.5005e+00, -7.9467e-01,  ...,  4.9953e-01,
           -1.5310e+00,  3.6723e+00],
          [-1.7362e+00,  9.1720e-01,  1.4141e-01,  ..., -7.2332e-01,
            1.5085e+00, -9.2864e-01],
          [-1.1214e-02,  6.0045e-02, -1.5762e+00,  ..., -2.6273e-02,
            3.9949e-01, -9.5913e-01]],

         [[ 2.0140e+00,  1.4902e+00,  1.1214e+00,  ...,  1.9922e+00,
           -9.4417e-01, -3.5837e-01],
          [ 7.3942e-01, -1.9667e+00, -1.4590e+00,  ..., -2.0315e-01,
           -7.0945e-01,  1.0442e+00],
          [-1.8128e+00,  8.2397e-01,  2.0722e-01,  ..., -1.9210e+00,
           -1.0287e-01,  6.6719e-01],
          ...,
          [ 2.0776e-02, -2.5954e+00, -1.8228e-01,  ..., -1.8908e+00,
           -7.6734e-01,  9.2329e-01],
          [-1.3040e-01,  8.6745e-01,  3.3627e-01,  ...,  1.5454e+00,
            1.4802e+00, -1.9542e+00],
          [-8.4142e-01,  7.8260e-01,  2.7708e-01,  ...,  3.8767e-01,
           -2.0852e-01, -6.4420e-01]]],


        [[[-1.3228e+00, -1.5100e-01, -1.4488e+00,  ...,  2.3245e+00,
           -1.1512e+00, -7.6958e-01],
          [-4.5921e-01, -1.3463e+00, -1.9701e+00,  ...,  1.6080e+00,
            4.7170e-02, -2.8114e-01],
          [-4.3051e-01,  8.1597e-01, -9.5485e-01,  ..., -1.4957e-01,
            2.9732e-01,  1.2486e+00],
          ...,
          [ 4.4421e-01, -4.2911e-01,  9.8360e-01,  ...,  5.7245e-01,
            1.9190e+00, -6.4668e-01],
          [-1.0912e+00, -5.4077e-01, -6.8238e-01,  ...,  1.3725e+00,
            2.4469e-01, -5.1720e-01],
          [ 7.3529e-01,  4.6675e-01,  6.2338e-01,  ..., -1.3810e-01,
            1.1567e-01, -1.7377e+00]],

         [[-1.0339e+00,  6.6975e-01,  2.5359e+00,  ..., -1.0351e+00,
            1.3270e-01,  1.0043e+00],
          [-1.2821e+00,  1.5787e+00,  1.5811e+00,  ..., -1.1534e+00,
            5.9315e-01,  1.9910e+00],
          [-8.0427e-01, -8.8088e-01,  1.1389e+00,  ..., -7.4648e-01,
            1.0220e-01,  3.5086e-01],
          ...,
          [-1.1831e+00, -6.0290e-01, -1.3301e+00,  ..., -4.7929e-03,
           -6.3689e-01, -4.9730e-02],
          [ 3.2546e-01,  2.1362e-01,  2.5618e-02,  ...,  2.9573e-01,
           -4.0383e-01,  2.2898e-01],
          [-1.6504e-01, -5.1286e-01,  8.7477e-02,  ...,  2.5851e-02,
            7.1362e-01, -3.4253e-01]],

         [[ 1.6580e-01,  1.2562e+00, -1.8057e+00,  ..., -5.6336e-01,
            1.7791e+00, -1.0500e-02],
          [-8.7682e-01, -1.4825e-01, -9.5634e-01,  ...,  1.1160e+00,
            6.7715e-01, -8.7418e-01],
          [ 1.6844e-01,  3.6256e-01, -4.5611e-01,  ...,  1.8643e-01,
           -6.6307e-01, -8.5916e-01],
          ...,
          [ 5.6802e-01, -2.6814e-01,  1.2839e+00,  ...,  4.4790e-01,
           -1.3898e+00,  4.1646e-01],
          [-1.5383e+00, -4.9714e-02, -4.2053e-01,  ...,  3.6633e-01,
           -8.7118e-01, -1.5102e+00],
          [-5.7814e-01, -9.2667e-01,  9.1067e-01,  ...,  4.6444e-01,
           -5.0832e-01,  1.1588e+00]]]])}, 'action': tensor([[5],
        [5],
        [5],
        [1]]), 'reward': tensor([[0.2205],
        [0.7495],
        [0.5091],
        [0.5871]]), 'done': tensor([False, False, False, False])}

As we can see, the structure-processing function (here, `stack`) has to be implemented entirely by hand. This code is not very clear. Because the supported types are hard-coded, supporting additional data types (such as integers) requires special modifications to the function.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

B = 4  # batch size: how many samples are generated and then stacked


def get_item():
    """Create one synthetic sample as a nested dict.

    Returns:
        A dict with, in this key order:
        ``'obs'`` -- a dict holding ``'scalar'`` (``torch.randn(12)``) and
        ``'image'`` (``torch.randn(3, 32, 32)``);
        ``'action'`` -- ``torch.randint(0, 10, size=(1,))``;
        ``'reward'`` -- ``torch.rand(1)``;
        ``'done'`` -- the plain Python bool ``False``.
    """
    # Draw the random tensors in the same order the nested literal would
    # evaluate them, so seeded runs yield identical samples.
    scalar = torch.randn(12)
    image = torch.randn(3, 32, 32)
    observation = dict(scalar=scalar, image=image)
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)
    return dict(obs=observation, action=action, reward=reward, done=False)


data = [get_item() for _sample in range(B)]
# execute `stack` op
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate: every leaf gains a leading batch dimension of size B
print(stacked_data)
assert stacked_data.obs.scalar.shape == (B, 12)
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertions pass as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
<Tensor 0x7fbce4e6e190>
├── 'action' --> tensor([[6],
│                        [7],
│                        [7],
│                        [8]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7fbce4e6e2b0>
│   ├── 'image' --> tensor([[[[-1.0246e+00, -8.8430e-01, -1.2725e+00,  ...,  4.8009e-01,
│   │                          -3.9986e-01, -1.9462e+00],
│   │                         [-1.7302e+00,  1.9129e+00, -1.1278e+00,  ..., -9.2231e-02,
│   │                          -9.4902e-01,  4.5159e-01],
│   │                         [ 4.1292e-01, -5.7044e-01,  9.1422e-01,  ..., -1.0290e+00,
│   │                          -1.9491e+00,  6.3939e-01],
│   │                         ...,
│   │                         [ 1.6791e+00,  4.3329e-01,  1.0022e+00,  ...,  9.8396e-01,
│   │                           2.7916e-01, -1.0737e+00],
│   │                         [ 2.9107e+00, -1.7424e+00,  2.8302e-01,  ..., -1.2364e+00,
│   │                          -5.8209e-01, -1.4043e+00],
│   │                         [-1.3203e+00, -9.0617e-01, -6.5877e-01,  ...,  6.4202e-01,
│   │                           1.0257e-01, -1.2502e+00]],
│   │               
│   │                        [[ 3.3720e-01,  2.2884e+00, -1.0749e+00,  ..., -4.9531e-01,
│   │                           4.7366e-01, -2.0852e-01],
│   │                         [ 3.8023e-01,  1.5508e+00, -1.3958e+00,  ...,  2.0193e+00,
│   │                          -5.8085e-01, -3.9019e-02],
│   │                         [-2.5457e-01, -9.1367e-01, -3.9968e-01,  ..., -4.5006e-01,
│   │                           2.5135e-01, -3.2401e-01],
│   │                         ...,
│   │                         [-6.8582e-01, -3.0150e+00,  5.3803e-01,  ...,  5.1124e-01,
│   │                          -4.0379e-01,  1.7022e+00],
│   │                         [ 2.4487e-01, -1.2174e-02,  1.9573e-01,  ...,  1.2873e+00,
│   │                          -2.6438e-03,  2.1685e+00],
│   │                         [-3.1065e+00, -9.0986e-01, -7.1704e-02,  ..., -1.3690e+00,
│   │                          -9.6121e-01,  1.6493e+00]],
│   │               
│   │                        [[ 1.0562e+00,  8.8660e-01, -4.0473e-02,  ..., -1.2579e+00,
│   │                          -1.1151e-01, -4.7223e-01],
│   │                         [-1.4035e-01,  1.2632e+00,  2.0630e-01,  ...,  8.4397e-01,
│   │                          -7.3512e-01,  3.0108e-01],
│   │                         [-1.5472e+00,  8.6881e-01,  3.8060e+00,  ...,  1.1473e+00,
│   │                           7.6952e-02, -2.6615e-01],
│   │                         ...,
│   │                         [ 1.8450e-01,  6.7810e-02, -4.1651e-02,  ..., -5.6022e-01,
│   │                          -7.1901e-01,  8.8822e-03],
│   │                         [-2.7449e-02,  4.8721e-01, -1.3217e+00,  ...,  5.9462e-01,
│   │                          -1.0834e+00, -9.3535e-01],
│   │                         [ 1.0771e-01,  3.6619e-01, -4.9350e-01,  ...,  1.2455e+00,
│   │                          -9.4500e-01, -4.3979e-01]]],
│   │               
│   │               
│   │                       [[[-9.9957e-01,  2.6408e+00, -5.8568e-01,  ...,  6.9457e-02,
│   │                           4.0335e-01, -3.4052e-02],
│   │                         [ 9.3848e-01,  7.3566e-01, -1.0368e+00,  ..., -8.0928e-01,
│   │                           1.9038e+00, -2.1124e-01],
│   │                         [-5.1587e-01,  1.2943e-01, -3.6656e-01,  ..., -2.9761e+00,
│   │                          -1.3710e+00, -7.3658e-01],
│   │                         ...,
│   │                         [-8.1212e-02,  8.1294e-01,  2.0697e-01,  ...,  3.5471e-01,
│   │                          -2.1863e-02, -2.0153e-02],
│   │                         [-4.6929e-01, -1.0739e+00, -1.0317e+00,  ..., -1.8212e+00,
│   │                           1.5082e+00,  1.5142e+00],
│   │                         [-7.9227e-01,  4.6795e-01,  8.4578e-01,  ..., -2.1940e-01,
│   │                           2.6860e-01,  4.7401e-01]],
│   │               
│   │                        [[-7.7893e-01, -2.6731e-01,  6.0581e-01,  ...,  1.0476e+00,
│   │                           5.0203e-02,  4.7208e-01],
│   │                         [ 1.8065e+00,  8.1147e-01, -1.3274e+00,  ..., -1.5226e-01,
│   │                          -1.2736e+00, -4.1463e-01],
│   │                         [-3.6127e-01,  2.9274e-01, -2.5918e-01,  ..., -1.9267e+00,
│   │                           6.9801e-01,  4.9498e-01],
│   │                         ...,
│   │                         [ 9.7332e-01, -4.0105e-01, -2.7137e-01,  ..., -1.0338e+00,
│   │                           6.8956e-02, -4.5122e-01],
│   │                         [ 1.6762e+00, -8.0366e-01, -4.0861e-01,  ..., -4.4810e-01,
│   │                          -1.3389e-01, -1.0669e+00],
│   │                         [-6.2461e-01,  3.5509e-02,  6.3614e-01,  ..., -4.1386e-01,
│   │                          -3.5633e-01,  1.1718e+00]],
│   │               
│   │                        [[-9.4570e-01,  7.1720e-01,  1.1973e+00,  ...,  2.7508e+00,
│   │                          -8.4332e-01,  1.7858e-01],
│   │                         [-5.5591e-01,  1.1040e+00,  1.5426e+00,  ...,  8.7637e-01,
│   │                          -1.6209e+00, -1.3327e+00],
│   │                         [ 1.9170e+00,  6.7683e-01, -2.7034e-01,  ..., -4.0786e-01,
│   │                          -5.3200e-01, -9.0168e-02],
│   │                         ...,
│   │                         [-3.0385e-02,  3.6850e-01, -3.4176e-01,  ...,  1.1125e+00,
│   │                           1.6566e-02,  9.7451e-01],
│   │                         [ 1.2883e+00,  1.0864e+00, -9.8243e-01,  ..., -5.3393e-01,
│   │                           6.9060e-02, -2.5049e-01],
│   │                         [-2.0220e+00, -7.2638e-01, -9.2629e-02,  ..., -4.4089e-01,
│   │                          -8.8905e-01, -8.6600e-01]]],
│   │               
│   │               
│   │                       [[[ 8.3656e-01, -1.9921e+00,  3.0603e-01,  ..., -4.8074e-01,
│   │                          -1.2629e+00, -1.1213e+00],
│   │                         [ 2.3799e-01,  2.6111e-01,  6.2038e-01,  ..., -1.5259e+00,
│   │                           1.8924e+00, -1.4035e+00],
│   │                         [ 2.8834e-01,  2.2810e-01, -1.3176e+00,  ..., -6.1899e-01,
│   │                           2.3440e-01, -1.8566e-02],
│   │                         ...,
│   │                         [ 8.4749e-02, -1.4309e+00, -1.1970e+00,  ..., -8.5013e-01,
│   │                           2.0141e+00, -8.3226e-01],
│   │                         [-6.0416e-01,  7.8808e-01,  6.2459e-01,  ...,  1.5436e+00,
│   │                          -1.2380e+00,  8.9732e-01],
│   │                         [ 4.7690e-01,  1.1065e-01,  7.7136e-01,  ...,  2.9811e-01,
│   │                           2.7168e-01,  2.6345e-02]],
│   │               
│   │                        [[ 1.1992e+00,  7.2626e-01, -4.7406e-01,  ..., -6.2764e-01,
│   │                          -4.6584e-01,  6.7850e-01],
│   │                         [-4.8615e-01, -1.0980e+00,  2.2046e+00,  ..., -1.7566e+00,
│   │                           3.1901e-01,  1.3555e+00],
│   │                         [-9.9374e-01, -1.1082e+00,  9.4718e-02,  ...,  1.4147e-03,
│   │                           6.8525e-01,  1.2805e+00],
│   │                         ...,
│   │                         [ 1.2790e+00,  1.3898e+00, -2.3188e+00,  ...,  6.4948e-01,
│   │                           2.0650e-02, -9.1545e-01],
│   │                         [ 2.6767e-01, -6.4679e-02,  7.7324e-01,  ..., -1.0083e+00,
│   │                          -1.2407e+00, -1.2468e+00],
│   │                         [ 6.6515e-01, -3.3858e-01,  2.3060e+00,  ...,  5.5070e-01,
│   │                          -1.5220e+00,  1.9295e+00]],
│   │               
│   │                        [[ 2.2590e+00, -1.8627e-01,  2.2707e+00,  ..., -5.0040e-01,
│   │                           5.6782e-03, -1.1826e+00],
│   │                         [ 2.1465e+00,  7.1557e-01, -2.6532e-01,  ..., -4.0490e-01,
│   │                          -4.9034e-01,  2.4351e-01],
│   │                         [ 2.4209e-01,  5.6923e-01,  9.2301e-01,  ..., -1.8959e-01,
│   │                           3.1919e-02, -1.1459e+00],
│   │                         ...,
│   │                         [ 2.9622e-01, -5.9302e-01,  3.6436e+00,  ...,  2.0232e+00,
│   │                           1.8814e+00,  1.5992e+00],
│   │                         [ 1.0715e+00,  2.8858e+00,  5.1627e-01,  ..., -1.5071e+00,
│   │                          -1.2885e+00,  8.3872e-01],
│   │                         [ 2.8005e-01, -1.8957e+00, -1.4291e+00,  ..., -5.3541e-01,
│   │                          -6.3762e-01, -6.2037e-01]]],
│   │               
│   │               
│   │                       [[[-1.4532e-01,  1.6583e-01, -6.7800e-01,  ..., -5.5735e-01,
│   │                           6.2348e-01,  1.7030e+00],
│   │                         [-2.1194e+00, -4.6013e-01, -4.3043e-01,  ...,  7.9954e-02,
│   │                          -1.9377e+00, -3.1011e-02],
│   │                         [ 3.1651e+00,  2.9665e-01,  5.5388e-02,  ..., -9.7278e-01,
│   │                           4.6199e-01, -7.1926e-01],
│   │                         ...,
│   │                         [-1.8500e-01,  1.0583e+00, -5.3590e-01,  ...,  1.8176e-01,
│   │                          -3.3309e-01,  2.7907e-01],
│   │                         [ 1.0312e+00,  6.0475e-01,  3.3776e-01,  ..., -3.8960e-02,
│   │                          -4.8853e-01,  1.1011e+00],
│   │                         [-4.9237e-01,  3.6887e-01, -1.2112e+00,  ...,  9.6820e-02,
│   │                           1.3762e-01, -3.3375e-01]],
│   │               
│   │                        [[-4.2384e-01,  7.1212e-01, -9.9415e-01,  ...,  3.3916e-01,
│   │                          -1.1783e+00,  1.6344e+00],
│   │                         [ 8.6838e-01,  1.8300e+00, -1.3059e+00,  ..., -7.7257e-01,
│   │                          -6.6901e-01, -1.3052e+00],
│   │                         [-1.5296e+00, -2.9463e-01,  4.1431e-02,  ...,  4.3528e-01,
│   │                           5.9916e-01, -4.4934e-01],
│   │                         ...,
│   │                         [ 2.6772e-01,  1.7160e+00, -1.8159e+00,  ..., -2.9415e-02,
│   │                          -3.1289e-01,  4.1249e-01],
│   │                         [-8.9727e-02,  7.3756e-01, -1.8032e-01,  ...,  1.7773e+00,
│   │                           7.9362e-01, -3.6837e-01],
│   │                         [ 7.0765e-01,  1.9979e-01,  1.9445e+00,  ...,  1.1912e+00,
│   │                          -8.8822e-01,  1.4295e-01]],
│   │               
│   │                        [[-2.8180e-02,  5.8417e-01, -8.5739e-01,  ..., -8.6050e-01,
│   │                          -6.7480e-01,  1.3177e+00],
│   │                         [ 1.2668e+00,  7.6879e-01, -1.2169e+00,  ...,  3.6852e-01,
│   │                           7.0144e-01,  8.7712e-01],
│   │                         [-1.8074e+00,  3.4534e-01, -1.1876e+00,  ..., -1.6073e+00,
│   │                           1.2646e+00,  1.2472e+00],
│   │                         ...,
│   │                         [ 2.3117e+00,  7.2074e-01, -1.5890e-01,  ...,  1.5906e+00,
│   │                          -1.9592e+00, -2.4748e-01],
│   │                         [ 3.6790e-01, -8.6337e-02, -1.5528e+00,  ...,  6.9861e-01,
│   │                           6.7337e-01, -8.6109e-01],
│   │                         [ 6.8111e-01,  5.2480e-01, -1.0565e-01,  ..., -8.6345e-01,
│   │                           1.5344e+00, -2.2042e-01]]]])
│   └── 'scalar' --> tensor([[-1.7710, -2.1136,  0.1759,  1.0415,  0.1290, -1.5161, -0.0271,  1.3272,
│                             -0.0061, -2.2568, -2.4636,  1.1810],
│                            [ 0.5186, -0.5596,  1.1712, -0.5116, -0.1917, -1.6485,  0.1522,  0.5870,
│                              2.9673,  0.7821, -0.5502,  0.7267],
│                            [ 0.4105,  0.1730, -0.7573,  0.2456,  0.3332, -0.7820, -0.7678,  0.4240,
│                             -0.1314, -0.8639,  0.1638, -0.5347],
│                            [-0.7781, -0.2391, -0.9212,  0.3216,  0.0664,  0.7816,  0.0865,  0.3464,
│                              1.8122,  1.2214,  1.7158, -0.2674]])
└── 'reward' --> tensor([[0.7772],
                         [0.1684],
                         [0.6344],
                         [0.6294]])

This code looks much simpler and clearer.