Stack Structured Data

When tensors are organized in a tree structure, they often need to be stacked together, just as torch.stack() does for plain tensors in PyTorch.

Stack With Native PyTorch API

Here is a typical implementation using only the native PyTorch API.

import torch

# Batch size: the number of samples that are generated and then stacked.
B = 4


def get_item():
    """Build one random sample as a nested dict.

    Returns:
        dict: a sample with four entries --
            ``'obs'``: a dict holding ``'scalar'`` (float tensor of shape
            ``(12,)``) and ``'image'`` (float tensor of shape ``(3, 32, 32)``);
            ``'action'``: an integer tensor of shape ``(1,)`` drawn from
            ``[0, 10)``;
            ``'reward'``: a float tensor of shape ``(1,)`` drawn from
            ``[0, 1)``;
            ``'done'``: the plain Python bool ``False``.
    """
    # Draw the tensors in the same order in which they appear in the result.
    scalar = torch.randn(12)
    image = torch.randn(3, 32, 32)
    observation = {'scalar': scalar, 'image': image}

    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)

    sample = {'obs': observation}
    sample['action'] = action
    sample['reward'] = reward
    sample['done'] = False
    return sample


# Build a batch of B independently generated samples (a list of nested dicts).
data = [get_item() for _ in range(B)]


# execute `stack` op
def stack(data, dim):
    """Recursively stack a list of identically structured items.

    The type of the first item decides how the whole list is handled:

    * ``torch.Tensor`` -- the items are joined with ``torch.stack`` along
      ``dim``.
    * ``dict`` -- the items are stacked key by key, using the keys of the
      first item, and a plain dict of the stacked values is returned.
    * ``bool`` -- the items are packed into a ``torch.BoolTensor``; ``dim``
      is not used in this case.

    Args:
        data: non-empty list of items sharing the same structure.
        dim: dimension passed on to ``torch.stack`` for tensor leaves.

    Returns:
        A tensor, or a (possibly nested) dict of tensors.

    Raises:
        TypeError: if the first item is neither a tensor, a dict nor a bool.
    """
    head = data[0]

    if isinstance(head, torch.Tensor):
        return torch.stack(data, dim)

    if isinstance(head, dict):
        stacked = {}
        for key in head.keys():
            # Gather the values stored under this key in every item, then
            # stack them with the same rules.
            column = [sample[key] for sample in data]
            stacked[key] = stack(column, dim)
        return stacked

    if isinstance(head, bool):
        return torch.BoolTensor(data)

    # Every other leaf type has to be added here by hand.
    raise TypeError("not support elem type: {}".format(type(head)))


# Stack the B samples along a new leading dimension (dim=0).
stacked_data = stack(data, dim=0)
# Validate: every tensor leaf gains a leading axis of size B, and the Python
# bools stored under 'done' become a bool tensor of shape (B,).
print(stacked_data)
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertion statements pass.

{'obs': {'scalar': tensor([[-0.2840, -0.3634,  0.9521, -0.5149,  0.0987, -0.7287,  0.1551, -0.6975,
         -0.4657,  0.6324, -0.4855,  0.2009],
        [ 1.7019, -0.7905,  1.1281,  1.6868,  0.9590,  0.0241, -0.2249, -0.7365,
         -1.6521,  0.2435, -0.4900, -0.3369],
        [-0.4088, -0.4597,  1.0224,  0.8411, -1.3451,  1.5354,  0.6608, -0.9282,
         -0.7652, -0.7484,  0.5435, -0.1357],
        [-0.6685, -0.2941, -1.5426,  2.2074,  1.6867, -0.4392, -1.4887, -0.9353,
          0.2300, -1.5421, -1.0665, -0.1565]]), 'image': tensor([[[[-4.5121e-02, -2.2680e-01, -1.1164e+00,  ...,  7.6033e-02,
            1.6391e-01, -2.9676e-01],
          [ 4.3215e-03,  1.0790e+00,  3.5419e-01,  ..., -2.0342e+00,
           -7.6382e-01,  6.5797e-01],
          [ 6.9889e-01,  9.0356e-01, -2.3687e-01,  ...,  4.4673e-01,
            1.0851e-01,  1.3564e+00],
          ...,
          [ 7.8343e-01, -2.0106e+00, -5.6503e-01,  ..., -2.1530e+00,
           -4.7780e-01, -2.2156e+00],
          [-1.0998e+00, -2.4053e+00,  5.7040e-01,  ...,  3.6227e-01,
           -4.2776e-01, -3.9227e-01],
          [-1.0050e+00,  8.0163e-01,  1.1930e+00,  ...,  2.4211e-01,
           -1.5429e+00, -1.6545e+00]],

         [[-2.6870e-01,  2.2828e+00, -1.7400e+00,  ..., -1.5348e-01,
           -3.2487e-01, -1.4432e-01],
          [-6.2969e-01,  9.8415e-01,  1.8722e-01,  ...,  9.5278e-01,
           -6.7011e-01,  1.7960e+00],
          [ 2.8753e-01, -1.8278e+00,  3.9951e-01,  ..., -1.8179e+00,
           -4.4990e-01, -1.3072e-01],
          ...,
          [-6.7291e-01, -6.7803e-04, -4.7174e-02,  ...,  5.5050e-01,
           -2.7602e+00, -5.0209e-01],
          [-9.0710e-02, -3.6761e-01,  1.5800e+00,  ...,  1.1335e+00,
            2.0917e+00,  6.3119e-01],
          [ 4.9747e-01,  2.2614e+00,  9.1086e-01,  ..., -4.0391e-01,
            4.1297e-01, -2.2987e-01]],

         [[-6.2910e-01,  1.1425e+00,  9.8083e-01,  ..., -1.3785e+00,
            9.5377e-01, -1.4359e+00],
          [-1.6253e+00, -5.2809e-01,  2.0963e+00,  ..., -7.6894e-01,
            8.3685e-01, -3.1836e-01],
          [-5.2703e-02, -1.4109e+00, -5.8247e-01,  ...,  1.3680e+00,
            7.2394e-01,  1.8591e-01],
          ...,
          [-4.7122e-02,  1.7136e+00, -4.6096e-01,  ..., -1.9979e+00,
           -9.3381e-02, -1.8513e-01],
          [-1.4061e+00,  3.4088e-01,  3.2430e+00,  ..., -1.4858e+00,
           -4.5998e-01, -9.2384e-01],
          [-7.5114e-01,  2.1499e-01, -8.6539e-01,  ...,  5.0539e-01,
           -3.2902e-01,  6.3032e-02]]],


        [[[-8.5675e-01, -9.3711e-01, -1.5717e+00,  ..., -1.6073e+00,
           -5.6652e-01,  1.8551e-01],
          [ 6.1701e-01, -4.4643e-01,  2.4150e-01,  ..., -2.0915e-01,
            2.4090e+00, -1.5672e+00],
          [-5.3781e-01,  8.7518e-01,  1.0100e+00,  ..., -1.1971e+00,
            9.3083e-01,  1.3852e+00],
          ...,
          [ 1.3390e-01,  1.2455e+00,  1.6701e+00,  ..., -6.5577e-01,
            3.0499e-01,  1.0548e+00],
          [ 7.5220e-01,  3.1135e-01,  4.1437e-01,  ..., -2.7704e-01,
           -4.3294e-01,  3.9899e-01],
          [-6.9253e-01,  2.8408e-02,  1.6630e+00,  ..., -2.2828e-01,
            9.2207e-01, -2.9157e-01]],

         [[-1.7668e-01, -7.6933e-01, -2.1550e-01,  ...,  2.1382e-01,
            3.9757e-01,  1.1892e+00],
          [-2.7735e-01,  9.4627e-01,  4.7046e-01,  ..., -2.2222e-01,
            1.2273e-01, -8.4070e-01],
          [ 3.9327e-02, -6.7895e-01,  3.4781e-01,  ...,  1.4598e+00,
            1.2063e+00,  1.2077e+00],
          ...,
          [ 1.4399e+00, -1.2146e+00,  3.1091e-01,  ...,  4.6194e-01,
           -4.5527e-01, -4.8440e-01],
          [ 8.7319e-01,  1.8735e+00,  1.0735e+00,  ..., -9.9458e-02,
           -8.5552e-02, -9.7367e-01],
          [ 9.1713e-01, -4.1817e-01,  1.0678e+00,  ...,  7.7337e-01,
           -6.8728e-01,  1.2765e+00]],

         [[ 1.8413e-02,  2.3515e-02, -2.9989e-01,  ...,  3.7097e-01,
            2.9614e-02, -2.7109e+00],
          [-2.9631e-01,  6.1463e-01, -1.1221e+00,  ...,  2.3268e+00,
           -4.9163e-01, -9.2458e-01],
          [ 2.6203e-01, -3.1708e-01,  1.5906e+00,  ...,  4.7732e-01,
           -1.1893e-01, -2.8443e-03],
          ...,
          [ 1.8601e-01,  5.8114e-01,  9.0267e-02,  ...,  1.4884e+00,
           -1.7130e-01,  4.2668e-01],
          [ 3.5927e-01,  4.6590e-01,  1.2273e+00,  ..., -9.9198e-01,
           -1.2984e+00, -1.0893e+00],
          [ 1.4063e+00,  1.5053e+00,  1.5832e+00,  ..., -3.4958e-01,
            8.6656e-01,  1.0481e-01]]],


        [[[-8.6967e-01,  3.5614e-01, -2.1566e-01,  ..., -4.6409e-01,
           -1.3361e+00, -6.0840e-01],
          [ 4.0915e-02,  1.2063e+00,  9.2919e-01,  ..., -1.8401e+00,
           -1.7136e+00, -9.0794e-01],
          [-3.9051e-01,  1.1263e-01,  9.1553e-01,  ..., -9.0324e-01,
           -9.4214e-01, -1.0102e+00],
          ...,
          [ 3.9434e-01,  2.0142e+00, -1.0307e+00,  ..., -3.4609e-01,
           -5.5551e-01, -5.1673e-01],
          [ 4.7953e-01,  2.0993e-01, -4.5370e-01,  ..., -6.5955e-01,
           -4.9829e-01, -6.5444e-01],
          [-1.7510e+00,  4.4026e-01, -4.6522e-01,  ..., -1.3531e+00,
           -3.1418e-01,  5.7820e-01]],

         [[ 1.1290e+00,  7.7504e-02,  1.8905e+00,  ...,  6.4466e-01,
           -1.3985e+00,  2.1244e+00],
          [ 1.4552e+00, -1.0331e+00,  9.4305e-01,  ..., -3.2792e-01,
           -7.6510e-01,  1.0919e+00],
          [-1.8828e+00, -6.7343e-01,  4.4475e-02,  ...,  1.4342e+00,
           -1.7831e+00,  8.9851e-01],
          ...,
          [ 1.5444e+00, -6.4220e-02, -1.6125e+00,  ...,  8.3701e-01,
            1.1188e+00,  8.5849e-01],
          [-1.1280e+00,  3.4798e-01, -9.4200e-01,  ...,  2.0550e-01,
            7.2233e-01, -2.0604e+00],
          [-1.1670e+00,  1.5780e-01,  6.6756e-01,  ...,  2.0367e-01,
           -4.9155e-02, -7.3329e-01]],

         [[ 7.1519e-01,  1.5907e+00,  6.0335e-01,  ..., -2.4679e-01,
           -2.2060e-01, -7.4976e-01],
          [ 7.4181e-01,  1.6541e-01, -2.7937e-01,  ..., -1.5639e+00,
           -1.2085e+00,  2.1309e-01],
          [ 6.6968e-01,  3.8764e-01, -7.9096e-02,  ...,  2.1396e+00,
           -2.1531e+00,  3.8993e-01],
          ...,
          [-7.6772e-01,  1.0094e+00, -1.0043e+00,  ...,  7.7668e-02,
           -9.0318e-01,  2.2951e+00],
          [ 6.6404e-01, -1.1080e+00,  2.1678e+00,  ...,  1.0081e+00,
           -7.7951e-01,  8.3822e-01],
          [-6.0536e-01,  1.0429e+00,  5.1661e-01,  ...,  7.3517e-01,
            4.5096e-01, -1.8055e-01]]],


        [[[-9.9052e-02, -5.0342e-01, -6.4595e-01,  ..., -2.6159e+00,
            1.6636e-01,  7.9794e-01],
          [ 9.2378e-01, -9.2294e-01, -7.6925e-01,  ...,  3.8585e-01,
           -2.1103e+00, -2.0487e-01],
          [ 1.3706e+00,  8.4382e-01,  2.5089e-01,  ...,  5.7964e-01,
            7.2823e-02,  1.8928e+00],
          ...,
          [-2.2277e+00,  5.2174e-01, -1.7592e+00,  ..., -1.7294e+00,
           -9.5507e-01,  8.6874e-01],
          [-1.0880e+00,  1.8372e+00,  2.9192e-01,  ...,  4.3466e-01,
           -1.0732e+00,  5.1229e-01],
          [ 1.4222e-01, -1.4210e-01,  7.0661e-01,  ...,  2.6488e-01,
           -3.1773e-01, -1.4608e+00]],

         [[-1.6424e+00,  3.2783e-02,  7.7209e-01,  ..., -2.0753e+00,
            6.8555e-01, -3.2479e-01],
          [ 6.3090e-01,  2.5724e+00, -1.8561e+00,  ..., -2.3817e-01,
           -1.3254e+00, -1.0785e+00],
          [ 1.0413e+00,  7.2824e-01,  5.2031e-01,  ..., -3.8122e-02,
           -2.0400e+00,  9.1134e-01],
          ...,
          [ 6.8952e-02,  1.9344e-01,  1.2753e+00,  ..., -3.5823e-01,
            3.1901e-01,  2.2574e-01],
          [-1.2684e+00, -1.3359e+00, -1.6363e+00,  ..., -2.2908e-01,
            2.1934e+00, -1.4993e-01],
          [-2.8702e-01,  5.0642e-01,  1.1062e+00,  ..., -8.0505e-01,
           -5.0712e-01,  5.4565e-01]],

         [[-1.9153e+00,  3.9362e-02,  4.3851e-02,  ...,  5.0572e-01,
           -1.9208e-01, -7.8702e-01],
          [-2.0129e-01,  2.0403e-01, -1.2106e+00,  ..., -1.6148e-01,
            7.6110e-01,  4.4937e-01],
          [ 1.3093e-01, -6.4560e-01, -2.2511e-01,  ..., -2.2961e+00,
            7.1832e-01, -1.1351e-01],
          ...,
          [-5.6873e-01,  5.0049e-01,  2.7740e+00,  ..., -9.9026e-02,
            1.0941e+00, -6.1444e-01],
          [ 2.0034e+00,  2.7231e-01, -3.6255e-01,  ..., -3.6286e-01,
           -5.1123e-01, -8.3009e-01],
          [ 1.5741e-01,  6.8545e-01, -1.9567e-01,  ..., -9.8548e-02,
           -9.9199e-01,  1.9214e+00]]]])}, 'action': tensor([[6],
        [8],
        [6],
        [3]]), 'reward': tensor([[0.3085],
        [0.9404],
        [0.3143],
        [0.2650]]), 'done': tensor([False, False, False, False])}

We can see that the function that processes the structure (here, the function stack) has to be implemented entirely by hand. The resulting code is not very clear, and because the supported types are hard-coded, supporting additional data types (such as integers) requires modifying the function itself.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown in the code below.

import torch

import treetensor.torch as ttorch

# Batch size: the number of samples that are generated and then stacked.
B = 4


def get_item():
    """Build one random sample as a nested dict.

    Returns:
        dict: a sample with four entries --
            ``'obs'``: a dict holding ``'scalar'`` (float tensor of shape
            ``(12,)``) and ``'image'`` (float tensor of shape ``(3, 32, 32)``);
            ``'action'``: an integer tensor of shape ``(1,)`` drawn from
            ``[0, 10)``;
            ``'reward'``: a float tensor of shape ``(1,)`` drawn from
            ``[0, 1)``;
            ``'done'``: the plain Python bool ``False``.
    """
    # Draw the tensors in the same order in which they appear in the result.
    scalar = torch.randn(12)
    image = torch.randn(3, 32, 32)
    observation = {'scalar': scalar, 'image': image}

    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)

    sample = {'obs': observation}
    sample['action'] = action
    sample['reward'] = reward
    sample['done'] = False
    return sample


# Build a batch of B independently generated samples (a list of nested dicts).
data = [get_item() for _ in range(B)]
# Execute the `stack` op: convert each nested dict with ttorch.tensor, then
# stack the converted samples along dim 0 with a single ttorch.stack call.
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# Validate: leaves are reached by attribute access (e.g. `.obs.image`); every
# leaf gains a leading axis of size B, and 'done' ends up as a bool tensor.
print(stacked_data)
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertion statements pass here as well.

<Tensor 0x7f5db34cab50>
├── 'action' --> tensor([[7],
│                        [3],
│                        [7],
│                        [4]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7f5e705f3b20>
│   ├── 'image' --> tensor([[[[ 6.0491e-01,  6.2035e-01, -1.4157e+00,  ..., -1.1852e+00,
│   │                          -2.9777e-01,  1.3498e+00],
│   │                         [-3.9143e-01,  2.2735e+00,  4.6481e-01,  ..., -2.6099e-01,
│   │                          -4.2337e-01, -1.0870e+00],
│   │                         [-9.5710e-01,  7.8952e-01, -1.3639e+00,  ..., -3.9440e-01,
│   │                          -1.4676e+00, -1.1285e+00],
│   │                         ...,
│   │                         [-9.2814e-01,  2.0619e+00,  6.5631e-01,  ..., -2.7601e+00,
│   │                           3.8497e-01, -1.6172e+00],
│   │                         [ 1.5359e+00,  3.3934e-01, -1.8725e+00,  ..., -1.0296e+00,
│   │                          -6.7341e-01,  2.4159e-01],
│   │                         [-3.9835e-01,  1.2958e+00,  6.1941e-01,  ..., -4.5101e-01,
│   │                           4.8585e-01,  1.3921e-01]],
│   │               
│   │                        [[ 2.2765e-01, -2.4643e-01, -2.1107e+00,  ...,  2.5699e+00,
│   │                           4.2809e-01, -1.1317e+00],
│   │                         [-9.8382e-01, -7.5618e-01,  9.2320e-01,  ..., -2.9370e+00,
│   │                          -1.1677e+00, -7.1868e-01],
│   │                         [ 1.2856e+00,  1.3135e+00,  1.5204e+00,  ..., -1.9319e-01,
│   │                          -2.8591e+00, -3.7769e-01],
│   │                         ...,
│   │                         [-1.7890e+00, -1.4832e+00, -1.0813e+00,  ..., -7.4906e-01,
│   │                           7.0562e-03, -1.6395e+00],
│   │                         [-8.6999e-01,  5.9350e-01,  1.0873e+00,  ...,  8.0207e-01,
│   │                           1.1662e-01,  5.9827e-01],
│   │                         [ 1.1007e+00,  3.0420e-01, -3.8219e-01,  ...,  1.2563e-01,
│   │                           2.4563e-02,  8.5376e-01]],
│   │               
│   │                        [[-1.5578e+00,  1.1582e+00, -1.0329e-02,  ..., -4.4652e-02,
│   │                          -1.4290e+00,  2.5182e-01],
│   │                         [ 9.2543e-01,  9.6208e-02,  4.3906e-01,  ..., -1.9175e-01,
│   │                          -4.9882e-01, -8.7490e-01],
│   │                         [ 5.4497e-01, -6.0647e-01,  7.1136e-01,  ...,  1.9480e+00,
│   │                           4.8616e-01, -4.0973e-02],
│   │                         ...,
│   │                         [ 1.2451e+00, -1.2325e+00, -1.1121e-01,  ...,  1.1665e+00,
│   │                           5.6378e-01,  1.8703e-01],
│   │                         [ 9.5143e-01, -3.8321e-01, -9.8722e-01,  ..., -8.2767e-01,
│   │                          -4.4235e-01,  7.4089e-01],
│   │                         [-6.8629e-01, -7.0036e-01, -7.0488e-01,  ...,  3.4360e-01,
│   │                          -4.7981e-01,  2.7686e-01]]],
│   │               
│   │               
│   │                       [[[ 1.2345e+00, -4.3545e-01,  3.8050e-01,  ..., -1.9059e-01,
│   │                          -2.3013e-01,  2.0597e-01],
│   │                         [ 1.3890e+00, -4.0402e-01, -8.4775e-01,  ..., -1.1564e+00,
│   │                           6.0894e-01,  7.8214e-01],
│   │                         [ 1.9876e+00, -1.0117e+00,  4.2992e-01,  ...,  2.7501e-01,
│   │                          -2.4991e-01, -5.2593e-01],
│   │                         ...,
│   │                         [-8.9472e-02, -9.1473e-01,  7.5796e-01,  ..., -9.4867e-01,
│   │                           2.7643e-01, -5.7385e-03],
│   │                         [ 1.9462e+00,  1.4143e+00, -6.2590e-01,  ..., -5.7457e-01,
│   │                          -1.6119e+00,  1.6538e+00],
│   │                         [ 1.1706e+00,  1.3673e+00, -5.8799e-01,  ...,  3.9720e-01,
│   │                           4.8804e-02, -5.4600e-01]],
│   │               
│   │                        [[ 1.0991e+00, -1.4352e+00, -1.5521e-02,  ...,  3.9412e-01,
│   │                           6.3778e-01,  1.0133e+00],
│   │                         [-1.6794e+00, -2.2764e-01, -8.3136e-01,  ..., -1.6760e-01,
│   │                           1.5344e-01,  1.2149e+00],
│   │                         [-1.6497e+00,  5.1475e-01, -7.9886e-01,  ..., -1.7998e-01,
│   │                           1.0497e-01, -8.3682e-01],
│   │                         ...,
│   │                         [-6.8097e-01,  1.4262e+00, -3.4783e-01,  ...,  2.2377e-01,
│   │                           1.3951e-01,  1.1335e+00],
│   │                         [ 2.9878e-03, -1.0000e+00,  1.0392e+00,  ..., -1.0403e+00,
│   │                          -9.5043e-01, -1.3742e+00],
│   │                         [ 2.0765e+00, -3.7355e-02, -5.3457e-01,  ..., -8.7327e-01,
│   │                          -5.8363e-01,  1.3274e+00]],
│   │               
│   │                        [[-4.2527e-01,  6.1542e-01,  1.5684e+00,  ..., -5.2865e-01,
│   │                           4.6098e-02, -2.1230e-01],
│   │                         [ 7.6377e-01, -3.4257e-01, -8.7081e-01,  ..., -9.6008e-01,
│   │                           3.9936e-01,  1.4612e-01],
│   │                         [-5.3492e-01, -5.1198e-01,  1.2734e-01,  ..., -2.1578e+00,
│   │                           1.8435e-01,  7.0498e-01],
│   │                         ...,
│   │                         [ 1.0484e+00,  2.0422e+00,  3.0042e+00,  ...,  9.1131e-01,
│   │                          -8.0172e-01, -1.9654e+00],
│   │                         [ 1.1605e+00,  7.3602e-01,  9.6949e-01,  ..., -5.1946e-01,
│   │                           9.8158e-01, -3.5733e-02],
│   │                         [ 7.2839e-01, -7.2189e-02, -3.5335e-01,  ..., -9.0549e-01,
│   │                          -6.9994e-01, -5.3743e-01]]],
│   │               
│   │               
│   │                       [[[ 3.1480e-01, -7.0003e-01,  1.1350e+00,  ..., -9.6894e-01,
│   │                          -5.3728e-01, -4.6991e-01],
│   │                         [ 1.6919e-03,  3.5995e-02,  2.3327e-02,  ...,  9.7522e-01,
│   │                           1.6597e+00, -1.4420e+00],
│   │                         [ 8.9391e-01, -9.9064e-01,  1.3430e+00,  ...,  1.5148e+00,
│   │                           5.9051e-01,  3.9515e-01],
│   │                         ...,
│   │                         [ 5.7241e-01, -9.4326e-01, -9.6073e-01,  ...,  7.8491e-02,
│   │                           3.6360e-01,  1.4721e+00],
│   │                         [-5.3042e-02, -8.7747e-01,  7.8472e-01,  ..., -2.3527e-02,
│   │                          -2.7887e-01, -5.3696e-01],
│   │                         [-2.3875e+00, -1.5519e+00, -2.2774e+00,  ..., -1.9712e+00,
│   │                           5.8805e-01,  1.2947e+00]],
│   │               
│   │                        [[ 2.1716e-02, -5.7773e-01, -2.3881e-01,  ...,  1.0357e-01,
│   │                          -1.7584e-01, -2.3245e-01],
│   │                         [-5.6528e-02, -3.3935e-01,  5.7944e-01,  ...,  8.8092e-01,
│   │                          -3.6546e-01, -1.2299e+00],
│   │                         [-2.7874e-01,  1.0380e+00,  9.8950e-01,  ...,  5.2625e-01,
│   │                           2.7703e-01, -5.0338e-01],
│   │                         ...,
│   │                         [ 6.0702e-01,  1.3428e+00, -1.7610e+00,  ..., -8.9197e-01,
│   │                          -9.9336e-01, -1.4691e+00],
│   │                         [-6.3594e-01, -2.0257e+00, -2.8456e+00,  ...,  2.3086e-01,
│   │                           9.4520e-02,  1.7244e+00],
│   │                         [-3.1241e-01,  5.9017e-01, -1.6905e-01,  ...,  1.5259e+00,
│   │                           3.5403e-01,  2.1085e+00]],
│   │               
│   │                        [[-5.7899e-01, -1.1956e+00, -8.2613e-01,  ..., -8.2445e-01,
│   │                           4.8076e-01,  2.0298e+00],
│   │                         [ 1.9169e+00, -8.0087e-01, -1.3284e-01,  ..., -2.4424e-01,
│   │                          -4.2067e-01,  1.6571e+00],
│   │                         [-1.2796e-01,  1.0960e-02,  1.6551e-01,  ...,  1.1280e-01,
│   │                          -9.9208e-02, -1.5824e-01],
│   │                         ...,
│   │                         [ 3.7393e-01,  6.5658e-01, -1.3129e+00,  ..., -2.4924e-01,
│   │                           2.3331e+00, -4.9498e-01],
│   │                         [-2.3894e-01, -3.1400e-01, -1.0016e+00,  ..., -3.4044e-02,
│   │                           2.1645e-01,  5.8772e-01],
│   │                         [ 1.0274e-01,  1.2117e-01,  4.8145e-01,  ...,  4.4367e-01,
│   │                          -7.8213e-01, -7.1751e-01]]],
│   │               
│   │               
│   │                       [[[ 4.3093e-01, -8.9663e-01,  9.7779e-01,  ..., -8.5037e-01,
│   │                          -2.1401e-01,  3.4561e-01],
│   │                         [ 7.6804e-02,  2.8668e-01, -2.4880e-01,  ...,  1.9435e-01,
│   │                          -2.1577e+00,  3.6904e-01],
│   │                         [ 1.3436e+00, -1.9077e+00,  2.0240e+00,  ..., -2.5718e-01,
│   │                          -3.5409e-01, -4.2510e-01],
│   │                         ...,
│   │                         [ 1.8166e-01,  1.5542e+00,  1.3729e+00,  ..., -1.1488e-01,
│   │                           9.5525e-01,  4.3516e-01],
│   │                         [-4.4698e-01,  3.2768e+00, -7.4483e-01,  ...,  1.6262e-01,
│   │                           6.8086e-01,  2.0269e-01],
│   │                         [-1.2208e+00,  5.2528e-01,  2.1715e-01,  ..., -4.4392e-01,
│   │                          -3.0471e-01, -7.4893e-01]],
│   │               
│   │                        [[ 2.0017e+00,  9.2287e-02, -5.8657e-01,  ..., -2.7426e-02,
│   │                          -8.2376e-01,  1.7340e-01],
│   │                         [ 4.3111e-01,  1.0809e+00,  8.3582e-02,  ..., -1.5822e+00,
│   │                           1.7136e+00, -3.4180e-01],
│   │                         [ 1.3194e-01,  5.7046e-01,  7.0476e-01,  ...,  1.5565e+00,
│   │                           9.5292e-01,  3.4529e-01],
│   │                         ...,
│   │                         [ 1.6582e-01, -2.5548e-01,  1.1658e+00,  ..., -6.1540e-02,
│   │                           1.3968e+00,  2.3458e-01],
│   │                         [ 8.3413e-01,  4.9961e-01, -7.6383e-01,  ..., -5.6934e-01,
│   │                           1.7811e+00,  2.8406e-01],
│   │                         [-2.0900e-01,  2.2198e-01,  2.0040e-02,  ...,  5.9347e-01,
│   │                          -1.2270e+00, -8.2168e-01]],
│   │               
│   │                        [[ 8.3771e-01, -1.3873e+00,  6.5440e-02,  ...,  5.2534e-01,
│   │                          -7.4022e-01, -1.8233e+00],
│   │                         [-4.0800e-02, -8.9228e-01, -1.5278e-01,  ..., -6.9822e-01,
│   │                          -2.9530e-02, -1.2972e+00],
│   │                         [ 6.7113e-01, -2.1716e-01, -4.5959e-01,  ..., -6.7453e-01,
│   │                           2.5205e+00,  1.1726e+00],
│   │                         ...,
│   │                         [-1.8662e-01,  1.1956e+00, -3.7570e-01,  ...,  1.7713e-01,
│   │                          -6.7643e-01, -2.1710e-01],
│   │                         [-8.9215e-01,  5.8553e-01,  2.6335e-01,  ...,  1.0285e-01,
│   │                           1.1924e+00,  8.9541e-01],
│   │                         [-1.8109e+00,  7.6420e-01, -1.5329e+00,  ..., -3.2874e-01,
│   │                           1.1832e+00,  4.3480e-01]]]])
│   └── 'scalar' --> tensor([[-0.2577, -1.3878,  0.0943, -0.9186,  1.5606,  2.7561,  0.9963,  1.6120,
│                             -0.3115,  1.4316,  0.5285,  1.1511],
│                            [-1.2707,  0.2827,  0.4468, -0.8755, -0.3063, -0.4284,  1.8732, -1.3902,
│                             -1.0280, -1.1760, -1.0110, -0.0351],
│                            [-0.4873,  1.1478,  1.5819,  0.3801, -0.4244,  0.4763, -0.7645, -0.8707,
│                              0.8633, -0.3963, -1.4016, -0.0655],
│                            [-1.9009, -0.4979,  1.3584,  0.1752,  1.2464,  0.5961,  0.6831, -1.0763,
│                             -2.0390,  0.5420, -0.6129,  1.5327]])
└── 'reward' --> tensor([[0.9951],
                         [0.1780],
                         [0.7281],
                         [0.4226]])

This code looks much simpler and clearer.