Stack Structured Data

When tensors are organized into tree structures, they often need to be stacked together, just as `torch.stack()` does for plain tensors in PyTorch.

Stack With Native PyTorch API

Here is a typical implementation using the native PyTorch API.

import torch

B = 4  # batch size: number of samples generated by get_item() and stacked together


def get_item():
    """Create one synthetic sample as a nested dict.

    Returns:
        A dict with four keys:
          - 'obs': a dict holding 'scalar' (12 standard-normal values) and
            'image' (a (3, 32, 32) standard-normal tensor);
          - 'action': a shape-(1,) integer tensor with a value in [0, 10);
          - 'reward': a shape-(1,) float tensor with a value in [0, 1);
          - 'done': the plain Python bool False (not a tensor).
    """
    # Draw the observation tensors first so the random-number stream is
    # consumed in the same order as the fields appear in the result.
    observation = dict(
        scalar=torch.randn(12),
        image=torch.randn(3, 32, 32),
    )
    return dict(
        obs=observation,
        action=torch.randint(0, 10, size=(1,)),
        reward=torch.rand(1),
        done=False,
    )


# One independent sample per batch slot; this list of nested dicts is what gets stacked.
data = [get_item() for _sample_index in range(B)]


# execute `stack` op
def stack(data, dim):
    """Recursively stack a list of same-structured samples into one batch.

    Args:
        data: non-empty list of samples sharing one structure. Each sample is
            a ``torch.Tensor``, a ``bool``, or a ``dict`` whose values are
            (recursively) such samples. The type of the first element decides
            how the whole list is handled.
        dim: dimension along which tensors are stacked (forwarded to
            ``torch.stack``).

    Returns:
        A structure mirroring ``data[0]``: tensors are combined with
        ``torch.stack``, dicts are stacked key by key, and bools become a 1-D
        ``torch.bool`` tensor of length ``len(data)``.

    Raises:
        TypeError: if the element type is not one of the supported types.
    """
    elem = data[0]
    if isinstance(elem, torch.Tensor):
        return torch.stack(data, dim)
    elif isinstance(elem, dict):
        # Keys (and their order) come from the first sample, so every other
        # sample must provide at least these keys.
        return {k: stack([item[k] for item in data], dim) for k in elem}
    elif isinstance(elem, bool):
        # Python bools have no dimensions of their own, so ``dim`` is not used
        # here. ``torch.tensor`` replaces the legacy ``torch.BoolTensor``
        # type constructor.
        return torch.tensor(data, dtype=torch.bool)
    else:
        raise TypeError("not support elem type: {}".format(type(elem)))


stacked_data = stack(data, dim=0)
# validate: every leaf, including the bool 'done' flag, gains a leading batch dimension of size B
print(stacked_data)
assert stacked_data['obs']['scalar'].shape == (B, 12)
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertions pass.

{'obs': {'scalar': tensor([[ 7.8326e-01, -6.8331e-01, -9.0649e-01, -5.5945e-01, -1.0393e+00,
         -6.6517e-01,  9.5066e-01,  8.2402e-01,  1.5185e+00,  2.1209e-01,
         -4.1231e-01,  1.2537e-02],
        [-1.4019e+00,  6.1605e-01, -6.8361e-01,  7.8026e-01, -5.9984e-01,
          7.6129e-01,  1.4648e+00, -2.5011e-01,  1.7202e+00,  7.0563e-01,
          1.3625e-01, -7.0724e-01],
        [-4.9307e-01,  2.0924e-01,  3.3170e+00,  5.0752e-01,  8.5068e-01,
         -7.7194e-01, -1.1396e+00,  8.2016e-01,  7.2547e-01,  7.2704e-01,
          2.6749e-01,  8.3760e-01],
        [-3.5336e-01,  2.1490e-04, -1.1430e+00,  7.9597e-01, -1.1177e-01,
         -3.6127e-01, -1.0061e+00,  6.7657e-01,  2.3435e-01,  4.6530e-01,
          6.8542e-01,  1.6713e+00]]), 'image': tensor([[[[ 7.5967e-01,  2.3677e+00, -1.6185e-01,  ..., -8.2355e-01,
           -3.9295e-02, -1.9955e-02],
          [ 1.6472e+00,  6.1612e-01,  2.0368e-01,  ...,  1.5203e-01,
            1.0231e+00,  4.0338e-01],
          [-1.2747e-01,  7.7251e-01,  1.5336e+00,  ...,  1.3333e+00,
            5.3994e-01, -6.5639e-01],
          ...,
          [-3.1113e-01, -7.4817e-01,  9.6644e-01,  ...,  4.2423e-01,
            8.7326e-01,  8.0140e-01],
          [-8.2530e-02, -1.0227e+00, -1.6268e-01,  ..., -2.8627e-02,
            1.5870e+00, -4.8415e-01],
          [ 6.0471e-01,  7.5864e-01,  1.3898e-01,  ..., -2.4629e-01,
           -1.3195e+00, -6.0302e-01]],

         [[-1.3228e-01,  1.6346e+00, -5.0586e-01,  ..., -4.3679e-01,
           -5.2344e-01, -1.6767e+00],
          [-3.7857e-02, -1.2849e+00, -5.4202e-01,  ..., -1.2185e-01,
           -3.7975e-01, -9.1056e-01],
          [ 5.8344e-01, -2.0190e+00,  1.9997e-01,  ...,  6.5804e-02,
            7.0049e-02,  1.0491e+00],
          ...,
          [-1.4207e+00, -7.0291e-02,  2.8696e-01,  ..., -1.6500e+00,
           -1.2731e+00,  8.5969e-01],
          [-1.1368e+00,  1.1385e+00, -1.5195e-01,  ..., -6.2081e-01,
            5.7777e-01,  1.3055e+00],
          [-1.1427e-02, -6.9848e-01, -2.6561e+00,  ...,  1.4812e+00,
            3.5360e-01, -1.6545e+00]],

         [[-1.2383e+00,  1.9037e+00, -8.5530e-02,  ...,  7.7439e-01,
            7.2406e-01, -6.9666e-01],
          [-3.1571e-01, -2.8172e+00, -5.0036e-01,  ...,  1.0867e+00,
           -9.5219e-01,  1.1142e+00],
          [ 9.6173e-01, -1.0206e+00, -4.9838e-01,  ..., -1.7236e-01,
            3.3256e-01,  2.4114e+00],
          ...,
          [ 3.9102e-01,  1.2670e+00, -1.5322e+00,  ..., -6.0356e-01,
           -4.6033e-03, -3.5892e-02],
          [ 1.8968e+00,  2.2487e-01, -1.6317e-01,  ..., -2.5686e-01,
            6.9060e-01, -3.0022e-01],
          [-2.3545e-01,  1.8353e+00,  1.3482e-01,  ...,  3.3161e-01,
            2.9025e-01, -2.0178e+00]]],


        [[[ 1.1689e+00, -2.1013e+00, -8.3609e-01,  ...,  1.5565e+00,
            4.0571e-01, -2.7282e-01],
          [-4.0807e-01,  1.1532e-01,  1.6904e+00,  ..., -1.1915e+00,
            1.1675e+00, -2.8828e-01],
          [-1.4670e+00, -1.7190e+00, -8.0480e-01,  ...,  9.3016e-01,
            8.0495e-01, -9.5870e-01],
          ...,
          [-6.9530e-01, -6.4153e-01,  1.2936e+00,  ...,  4.1152e-01,
           -3.0769e-01, -1.7801e+00],
          [ 9.6147e-02, -6.3399e-01,  4.1660e-01,  ..., -1.8923e+00,
            1.2498e+00, -3.6506e-01],
          [-1.8689e-01,  6.0071e-01, -2.7068e-01,  ...,  1.2407e-01,
            1.7060e+00, -8.7069e-02]],

         [[ 9.3445e-01, -1.0811e+00,  2.0771e-01,  ...,  3.7677e-01,
           -6.2284e-02,  1.0065e+00],
          [ 2.9841e-02, -8.3726e-01, -2.9467e-01,  ...,  1.4418e+00,
            9.1296e-01,  6.6567e-02],
          [-5.1044e-01,  4.6408e-01,  1.4243e+00,  ..., -2.5163e+00,
           -1.4670e+00,  4.6528e-01],
          ...,
          [-6.9441e-01,  1.5699e-01, -8.1140e-01,  ..., -1.4479e+00,
           -4.8954e-01, -9.4979e-01],
          [-2.5636e-01,  1.1317e-01, -1.9187e+00,  ..., -1.9319e+00,
            7.2291e-01, -8.9293e-01],
          [-6.6313e-01, -1.0031e+00,  1.1510e-01,  ...,  5.5238e-01,
            8.9021e-01, -8.8792e-01]],

         [[-2.1538e-01, -1.6874e+00,  1.0778e+00,  ...,  9.8324e-01,
            1.5315e-01, -1.7844e-01],
          [ 7.1233e-01, -1.5829e+00, -2.8057e-01,  ...,  2.2528e-01,
            9.3379e-01,  2.0393e+00],
          [-1.7813e-01, -1.3809e+00, -8.6806e-01,  ...,  1.3456e+00,
            6.6897e-01,  2.7392e+00],
          ...,
          [ 1.6174e-03,  1.9014e+00, -1.8459e-01,  ..., -2.4190e+00,
            4.7605e-01, -1.3525e+00],
          [ 8.8451e-01, -8.9904e-01,  3.2420e-01,  ..., -1.5503e+00,
            4.6346e-01,  5.4471e-01],
          [-1.0496e+00,  7.9747e-01, -1.7142e+00,  ...,  2.5746e-01,
            6.8495e-01, -9.8599e-02]]],


        [[[-8.5057e-01, -1.1023e+00,  9.8774e-01,  ..., -3.6780e-01,
           -1.0927e+00,  1.0582e+00],
          [-2.0736e-01, -4.1865e-02,  1.0726e+00,  ..., -1.2956e+00,
            1.4910e+00, -4.3768e-01],
          [ 1.2296e+00,  1.3141e+00, -1.5074e+00,  ..., -1.6973e-01,
           -7.5717e-01,  5.8909e-01],
          ...,
          [ 1.3031e+00, -1.8145e+00,  2.4158e-01,  ...,  8.2723e-01,
           -4.2276e-01, -1.0698e+00],
          [-1.7393e+00,  1.0518e+00,  2.5781e+00,  ..., -1.4798e+00,
           -1.9492e+00, -6.1218e-01],
          [-1.0169e+00,  5.7333e-01, -1.0733e+00,  ..., -4.4049e-01,
            1.4505e+00,  1.7466e+00]],

         [[ 4.0453e-01, -1.3628e+00,  1.0815e+00,  ...,  1.1570e+00,
            1.2956e-01,  1.3223e+00],
          [ 4.1455e-01,  6.6520e-02, -4.7606e-01,  ...,  8.1480e-01,
           -7.2155e-01,  1.1066e+00],
          [ 9.2124e-01,  6.0358e-01,  1.5758e-01,  ...,  1.0077e+00,
           -1.3079e+00, -1.1441e+00],
          ...,
          [-5.4330e-01, -1.2901e+00,  6.0966e-01,  ...,  6.6400e-01,
           -8.1563e-01, -1.8732e-01],
          [-4.1601e-01,  5.9637e-01, -1.1697e-01,  ..., -1.1065e+00,
           -2.8916e-01,  2.0305e+00],
          [ 2.5941e+00,  3.8134e-01, -2.0759e+00,  ..., -5.1752e-01,
            1.2203e+00,  6.6424e-01]],

         [[ 1.5178e+00, -1.4107e-01,  2.0787e-02,  ..., -2.7419e-01,
            1.7594e+00,  3.4112e-01],
          [ 2.2713e+00,  1.2645e+00,  1.4167e-01,  ..., -1.3624e+00,
           -3.9471e-01, -1.0551e+00],
          [-5.6896e-01,  3.2535e-01,  6.7375e-01,  ..., -4.3356e-01,
           -8.9049e-01,  6.2434e-01],
          ...,
          [ 1.1771e+00, -1.5237e+00,  1.4692e+00,  ..., -9.1008e-01,
            2.4977e-01,  1.8976e-01],
          [ 3.5521e-01,  4.4842e-01, -2.1372e-01,  ..., -3.1800e-01,
            5.7764e-01,  9.8962e-01],
          [ 2.4482e-01, -3.0076e-01, -1.5345e+00,  ...,  6.1054e-01,
            1.1108e-01,  5.7744e-01]]],


        [[[ 3.5499e-01, -1.6768e-01, -9.9104e-01,  ...,  1.4003e+00,
            3.3791e-01,  2.4207e-01],
          [ 1.0070e+00, -3.8741e-01, -1.1791e+00,  ...,  1.4492e+00,
            9.9354e-02,  5.7079e-01],
          [-3.6632e-02,  1.4448e+00,  7.6475e-01,  ..., -7.4662e-01,
           -3.9348e-01, -1.4111e-01],
          ...,
          [-1.1064e+00, -1.3870e+00, -1.2940e+00,  ..., -6.7773e-01,
           -5.1018e-01, -1.3664e+00],
          [ 6.0045e-01, -4.3023e-01,  9.9434e-01,  ..., -1.5988e-01,
           -2.9792e-01,  1.6647e+00],
          [ 1.2545e+00, -7.1398e-01,  2.0364e-01,  ..., -7.2289e-01,
            3.9850e-01, -5.7919e-01]],

         [[ 2.8893e-01, -5.9618e-01, -5.6142e-01,  ...,  9.4555e-01,
            5.1033e-01, -1.1653e+00],
          [ 9.2532e-01, -1.4905e+00, -2.7413e-02,  ..., -1.5008e-01,
           -1.6422e-01, -8.4218e-01],
          [-2.4663e-01, -7.1457e-02,  1.0222e-01,  ..., -1.2784e-01,
            6.7937e-01,  5.0327e-01],
          ...,
          [ 1.9103e+00, -4.7451e-01,  1.0062e+00,  ...,  8.2645e-01,
            3.6270e-01, -1.2332e+00],
          [ 6.6634e-01, -9.9578e-02, -9.6929e-01,  ...,  4.0602e-01,
            1.3630e+00,  4.4063e-01],
          [ 6.1514e-01,  1.0401e+00,  6.5788e-01,  ...,  3.3027e-01,
           -4.7771e-01, -2.7703e-01]],

         [[ 1.7985e+00,  1.4773e+00,  2.3674e-01,  ...,  4.8540e-01,
            1.1563e+00,  4.9573e-01],
          [-1.2525e+00,  1.7834e-01, -1.8430e+00,  ...,  1.2985e+00,
            1.8026e-01,  8.6401e-01],
          [-5.1601e-01, -2.4202e-01,  3.4039e-01,  ...,  1.6713e+00,
           -1.7097e-01,  1.8218e+00],
          ...,
          [-1.0369e+00, -5.0906e-01,  5.6614e-01,  ...,  1.6036e+00,
            2.7071e-01, -8.3267e-01],
          [ 2.7301e-01,  6.6019e-01,  8.5434e-01,  ..., -4.9730e-01,
            9.4621e-01,  1.2293e+00],
          [ 1.4025e+00, -1.2155e+00, -1.2006e-01,  ...,  5.3345e-01,
           -5.7231e-01,  1.1151e+00]]]])}, 'action': tensor([[6],
        [2],
        [1],
        [9]]), 'reward': tensor([[0.9627],
        [0.5962],
        [0.3364],
        [0.1356]]), 'done': tensor([False, False, False, False])}

As we can see, the structure-processing function (here, `stack`) has to be implemented entirely by hand. The code is not very clear, and because the supported types are hard-coded, supporting additional data types (such as integers) requires modifying the function itself.

Stack With TreeTensor API

The same workflow can be implemented with the TreeTensor API, as shown below.

import torch

import treetensor.torch as ttorch

B = 4  # batch size: number of samples generated by get_item() and stacked together


def get_item():
    """Create one synthetic sample as a nested dict.

    Returns:
        A dict with four keys:
          - 'obs': a dict holding 'scalar' (12 standard-normal values) and
            'image' (a (3, 32, 32) standard-normal tensor);
          - 'action': a shape-(1,) integer tensor with a value in [0, 10);
          - 'reward': a shape-(1,) float tensor with a value in [0, 1);
          - 'done': the plain Python bool False (not a tensor).
    """
    # Draw the observation tensors first so the random-number stream is
    # consumed in the same order as the fields appear in the result.
    observation = dict(
        scalar=torch.randn(12),
        image=torch.randn(3, 32, 32),
    )
    return dict(
        obs=observation,
        action=torch.randint(0, 10, size=(1,)),
        reward=torch.rand(1),
        done=False,
    )


data = [get_item() for _ in range(B)]
# execute `stack` op: wrap each nested dict as a tree of tensors, then stack the trees leaf by leaf
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate: every leaf, including the 'done' flag, gains a leading batch dimension of size B
print(stacked_data)
assert stacked_data.obs.scalar.shape == (B, 12)
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertions pass as well.

<Tensor 0x7f854d56e190>
├── 'action' --> tensor([[4],
│                        [1],
│                        [7],
│                        [0]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7f854d56e2b0>
│   ├── 'image' --> tensor([[[[ 2.7768e-01,  7.2269e-01, -1.3915e+00,  ..., -1.1070e+00,
│   │                          -1.5589e+00,  3.4227e-01],
│   │                         [ 7.1008e-01,  7.1199e-02, -1.3941e+00,  ..., -1.5262e-02,
│   │                           1.9739e+00,  3.8984e-01],
│   │                         [ 3.2797e-01,  9.5656e-01, -1.0387e+00,  ...,  9.4736e-01,
│   │                          -1.2969e-01, -4.4648e-01],
│   │                         ...,
│   │                         [-1.5421e+00, -1.4744e+00,  1.1028e+00,  ...,  1.2473e-01,
│   │                          -2.8918e-01, -1.5246e+00],
│   │                         [ 7.2730e-01, -6.7164e-01,  9.1795e-01,  ..., -7.4616e-01,
│   │                          -1.2868e+00, -1.0479e+00],
│   │                         [ 2.5345e+00,  1.8525e+00,  7.8746e-01,  ...,  7.3047e-01,
│   │                           1.0833e-01,  9.9055e-02]],
│   │               
│   │                        [[ 1.3207e+00,  7.9623e-01, -4.6628e-01,  ...,  5.0219e-01,
│   │                          -9.4181e-01,  4.3514e-02],
│   │                         [ 6.1025e-02, -5.9720e-01,  1.1635e-01,  ..., -1.1111e-01,
│   │                          -3.1795e-01, -2.7077e-01],
│   │                         [-3.2185e-01,  7.2285e-01, -1.1239e+00,  ..., -2.3607e-01,
│   │                          -1.2155e-01,  2.0366e+00],
│   │                         ...,
│   │                         [-1.7039e+00, -2.2239e-01,  7.4453e-01,  ...,  5.3811e-01,
│   │                           1.8190e-01, -3.8538e-01],
│   │                         [-1.2629e+00, -1.3389e+00,  1.8053e+00,  ...,  2.9688e-01,
│   │                           5.8880e-01,  3.9439e-01],
│   │                         [ 4.0840e-01, -7.5275e-01,  1.1271e+00,  ..., -1.7865e-01,
│   │                           2.6663e-01,  3.6330e-01]],
│   │               
│   │                        [[ 8.9963e-01,  1.8445e+00, -9.2875e-01,  ..., -7.3726e-01,
│   │                           1.6722e+00, -1.3347e+00],
│   │                         [-3.5709e-01,  7.4005e-01, -1.0521e+00,  ..., -1.2865e+00,
│   │                           1.6662e+00,  1.8456e+00],
│   │                         [ 1.8223e+00, -1.0299e+00,  1.9987e+00,  ..., -1.5522e+00,
│   │                          -7.9742e-02,  1.6922e+00],
│   │                         ...,
│   │                         [ 4.8612e-01,  1.4085e-01,  1.7902e+00,  ...,  1.4983e+00,
│   │                           4.2752e-01, -2.7627e+00],
│   │                         [ 1.4083e+00,  2.6287e-01, -2.5619e-02,  ..., -7.2635e-01,
│   │                          -2.1798e-01, -1.5597e+00],
│   │                         [-1.0473e-01,  9.4573e-01,  1.7007e+00,  ...,  6.0289e-01,
│   │                          -4.8393e-01,  3.4755e-01]]],
│   │               
│   │               
│   │                       [[[ 7.7897e-01,  1.1320e-01, -2.1114e-01,  ...,  1.7084e+00,
│   │                           1.0117e-01,  7.8978e-03],
│   │                         [-4.8677e-01,  2.3906e-01,  7.7334e-01,  ...,  4.6409e-01,
│   │                          -4.7031e-01, -3.6635e-01],
│   │                         [-7.5587e-02,  1.5644e+00,  2.3391e-01,  ...,  6.7585e-01,
│   │                          -2.8346e-01,  5.6481e-01],
│   │                         ...,
│   │                         [ 2.3618e-01, -1.5815e-01, -9.7550e-01,  ...,  5.8330e-01,
│   │                           1.1823e+00,  1.0578e+00],
│   │                         [-7.8364e-01, -1.5866e+00, -1.6449e+00,  ...,  1.0233e+00,
│   │                          -5.0236e-01,  5.4849e-01],
│   │                         [ 4.4972e-01, -1.4466e+00, -1.8819e+00,  ...,  4.3116e-01,
│   │                          -1.2527e-01, -1.1814e+00]],
│   │               
│   │                        [[ 1.2845e+00,  1.2730e+00,  2.7647e-01,  ..., -5.2878e-01,
│   │                          -9.5056e-01,  1.8760e+00],
│   │                         [-1.6317e+00, -4.7470e-01, -2.4282e-01,  ..., -1.2668e+00,
│   │                          -1.8132e+00, -1.5144e-01],
│   │                         [-1.4870e+00, -3.3494e-01,  2.5189e+00,  ...,  2.7161e-01,
│   │                          -7.0971e-02, -1.2140e+00],
│   │                         ...,
│   │                         [ 8.2977e-01, -5.2122e-01, -1.6120e+00,  ..., -4.2821e-01,
│   │                           1.0515e+00,  4.3227e-01],
│   │                         [-6.4725e-01, -5.5695e-01, -6.4863e-01,  ...,  6.5223e-01,
│   │                          -9.5231e-01,  2.2416e+00],
│   │                         [-1.1557e+00, -1.2960e-01,  1.0956e-01,  ..., -2.4274e-01,
│   │                          -4.7423e-01, -6.8311e-01]],
│   │               
│   │                        [[ 7.8800e-01,  4.3260e-01, -1.2404e+00,  ...,  3.4317e-01,
│   │                           3.2559e-01,  6.1748e-01],
│   │                         [ 8.9122e-01,  8.0310e-01,  1.0648e+00,  ..., -1.0518e+00,
│   │                          -1.8672e+00,  4.7385e-01],
│   │                         [ 8.7452e-01, -6.2313e-02,  2.5838e+00,  ...,  3.1579e-01,
│   │                           8.5292e-01, -1.1236e+00],
│   │                         ...,
│   │                         [-9.7012e-01, -1.7455e+00, -7.3819e-01,  ...,  2.9028e-01,
│   │                          -2.8360e+00,  1.1007e+00],
│   │                         [ 1.0910e+00, -4.5006e-01,  4.0116e-01,  ..., -1.4809e+00,
│   │                          -1.2549e-01, -1.8702e+00],
│   │                         [ 7.5707e-01, -1.1873e+00,  2.7364e-01,  ...,  2.6473e-01,
│   │                          -4.0248e-02,  5.4633e-01]]],
│   │               
│   │               
│   │                       [[[-3.2781e-01, -1.2978e+00,  6.0700e-01,  ...,  3.1904e-01,
│   │                           1.3643e+00,  5.0313e-01],
│   │                         [-4.0699e-01,  1.9325e-01,  2.8137e-03,  ...,  9.2814e-01,
│   │                          -1.7251e+00,  1.0935e-01],
│   │                         [ 1.1301e+00,  4.1916e-01, -8.0110e-01,  ..., -3.0072e-01,
│   │                          -5.7206e-01,  9.5668e-01],
│   │                         ...,
│   │                         [ 1.4335e+00, -2.7020e-01, -3.1803e-01,  ...,  7.9595e-01,
│   │                           1.2687e-01,  2.0036e+00],
│   │                         [ 1.4315e+00,  1.2912e-01, -1.4107e+00,  ...,  7.7235e-01,
│   │                           4.2746e-01, -4.7502e-01],
│   │                         [ 7.4702e-01,  1.7163e+00, -5.1897e-01,  ...,  8.4231e-01,
│   │                          -3.6242e-01,  1.3959e+00]],
│   │               
│   │                        [[-5.5850e-01, -3.1159e-01, -3.1202e-01,  ..., -2.1453e-01,
│   │                          -8.6975e-01, -5.8275e-01],
│   │                         [-2.0519e+00,  1.5186e+00,  1.3738e+00,  ...,  1.6691e+00,
│   │                           1.1561e+00, -1.4051e-01],
│   │                         [ 3.0153e-01, -1.6306e+00, -4.4136e-01,  ...,  5.8454e-01,
│   │                           2.4728e-01,  9.1023e-01],
│   │                         ...,
│   │                         [-8.7527e-01,  1.2645e+00, -1.1091e+00,  ...,  7.5054e-01,
│   │                           1.3943e+00,  1.8248e-01],
│   │                         [-2.7783e-01,  1.0737e+00, -3.7681e-01,  ..., -1.3046e+00,
│   │                          -3.7334e-02,  8.9705e-02],
│   │                         [ 1.0340e-01,  4.7047e-01,  6.6843e-01,  ..., -2.1243e+00,
│   │                          -4.4899e-01,  1.4269e+00]],
│   │               
│   │                        [[-1.8311e+00,  1.9211e-01,  3.7694e-01,  ..., -7.3557e-01,
│   │                           6.5412e-01,  2.9185e-01],
│   │                         [ 9.6965e-01,  2.5162e-01,  7.9837e-01,  ...,  1.6130e-01,
│   │                          -1.1855e+00, -7.1552e-01],
│   │                         [-1.8326e+00,  2.9881e-01,  9.2038e-01,  ...,  3.0970e-01,
│   │                           2.2107e-01,  6.4244e-01],
│   │                         ...,
│   │                         [ 2.4874e+00, -1.5158e+00,  3.3830e-01,  ..., -1.7782e-01,
│   │                           1.0157e+00, -1.0347e-01],
│   │                         [ 4.5586e-01,  2.2183e-01, -3.2258e-01,  ..., -1.2041e+00,
│   │                           3.6287e-01, -3.2270e+00],
│   │                         [-2.9722e-01,  1.2296e+00, -2.2016e+00,  ...,  1.9688e+00,
│   │                           1.5530e+00, -4.2646e-01]]],
│   │               
│   │               
│   │                       [[[-6.1490e-01, -1.1774e+00, -2.8286e-01,  ...,  1.9360e+00,
│   │                          -1.3951e+00,  3.5584e-01],
│   │                         [-7.9607e-01,  5.4468e-01, -1.1136e+00,  ..., -6.3716e-01,
│   │                          -3.3734e-01,  5.3892e-01],
│   │                         [-7.1590e-01, -5.7831e-01,  9.4852e-01,  ..., -6.1943e-01,
│   │                          -1.2086e+00,  1.8259e+00],
│   │                         ...,
│   │                         [-1.4050e+00, -9.4474e-01,  1.0682e+00,  ..., -9.8072e-01,
│   │                           7.2787e-01,  1.1122e+00],
│   │                         [-1.1221e+00, -1.0004e+00,  1.1748e+00,  ...,  2.0072e-01,
│   │                           7.5808e-01, -1.1780e+00],
│   │                         [-4.8905e-01,  2.2637e+00, -5.7236e-01,  ..., -3.0184e-01,
│   │                          -8.9296e-01, -7.1598e-01]],
│   │               
│   │                        [[ 4.1308e-01,  1.1907e-01,  1.3611e-01,  ...,  4.4816e-01,
│   │                          -8.3799e-01,  9.1052e-01],
│   │                         [ 1.4973e+00, -1.4097e+00,  1.1275e+00,  ..., -1.6520e+00,
│   │                           4.3156e-01,  1.0900e+00],
│   │                         [-1.0643e-01,  1.1581e+00, -1.2586e+00,  ...,  1.7217e+00,
│   │                          -4.0327e-01, -2.0471e+00],
│   │                         ...,
│   │                         [-5.7999e-01, -2.5171e-01,  1.7430e+00,  ..., -1.0088e+00,
│   │                          -1.0629e-01, -1.9427e-01],
│   │                         [ 1.8873e+00,  5.0113e-03,  1.8853e+00,  ..., -4.2813e-01,
│   │                          -5.2380e-01,  4.6839e-01],
│   │                         [ 2.9441e+00,  3.2780e-01,  1.0412e-01,  ..., -9.9263e-01,
│   │                          -2.2214e-01, -1.4519e-01]],
│   │               
│   │                        [[ 8.0442e-01,  6.2019e-01, -1.1874e+00,  ...,  7.4780e-01,
│   │                          -2.1727e-01,  7.3025e-01],
│   │                         [ 6.2545e-01,  9.4898e-01, -3.6161e-01,  ...,  2.2039e-02,
│   │                           1.4667e+00, -1.3711e+00],
│   │                         [-1.0311e+00, -2.6868e-01, -2.9878e-01,  ...,  5.9559e-01,
│   │                           5.7911e-01,  1.2790e+00],
│   │                         ...,
│   │                         [ 1.0289e-01, -9.9213e-01,  1.2444e+00,  ...,  8.2043e-01,
│   │                          -1.3571e+00, -4.7111e-01],
│   │                         [-2.0202e-01, -1.6877e-01, -1.2767e+00,  ...,  5.2516e-02,
│   │                           2.5144e-01,  1.5627e+00],
│   │                         [ 7.8594e-01, -1.4140e+00, -6.1526e-01,  ..., -1.6988e+00,
│   │                          -1.5050e+00,  5.9991e-01]]]])
│   └── 'scalar' --> tensor([[-2.2491e+00, -1.0993e+00, -1.3367e+00,  4.5959e-02, -3.2067e-01,
│                              1.0183e+00, -4.6599e-01,  2.0997e+00, -2.3903e-01,  1.6686e-02,
│                             -1.7695e-01,  1.3608e+00],
│                            [-6.9283e-01, -1.3584e+00, -1.4731e+00,  1.0708e-01, -3.0125e-02,
│                             -4.2124e-01,  1.2719e-01, -1.2341e+00, -6.2120e-01, -7.3276e-01,
│                             -5.6019e-01, -6.2395e-01],
│                            [ 5.6889e-01, -4.4921e-04,  8.7905e-01, -1.3959e+00, -1.2668e-02,
│                              5.1552e-01, -1.0081e+00,  6.8289e-01,  1.6223e+00,  8.8416e-02,
│                              5.7134e-01,  6.2713e-01],
│                            [-1.5665e+00,  1.6148e+00, -1.7422e+00, -1.9278e+00,  8.8783e-01,
│                             -1.1776e+00, -8.7499e-01, -1.3110e+00,  1.7290e+00, -2.9704e-01,
│                             -1.4291e+00, -5.5408e-01]])
└── 'reward' --> tensor([[0.3571],
                         [0.4010],
                         [0.6740],
                         [0.6191]])

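This version is much simpler and clearer: the tree structure is handled by the library, so no hand-written recursion or per-type special cases are needed.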