Stack Structured Data

When tensors are organized into a tree structure, they often need to be stacked together, in the same way that torch.stack() stacks plain tensors.

Stack With Native PyTorch API

Here is a typical implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

# Number of items generated below; also the expected size of the stacked leading dimension.
B = 4


def get_item():
    """Create one sample as a nested dict of random tensors plus a bool flag.

    Returns:
        dict: keys ``'obs'`` (a dict holding ``'scalar'`` of shape ``(12,)``
        and ``'image'`` of shape ``(3, 32, 32)``), ``'action'`` (one random
        integer in ``[0, 10)``, shape ``(1,)``), ``'reward'`` (one uniform
        random value, shape ``(1,)``) and ``'done'`` (the plain Python
        ``False``).
    """
    # Draw the leaves in the same order as they appear in the result so the
    # random number stream is consumed identically.
    observation = {}
    observation['scalar'] = torch.randn(12)
    observation['image'] = torch.randn(3, 32, 32)

    sample = {'obs': observation}
    sample['action'] = torch.randint(0, 10, size=(1,))
    sample['reward'] = torch.rand(1)
    sample['done'] = False
    return sample


# Build a list of B items, one get_item() call per item.
data = [get_item() for _ in range(B)]


# define the recursive `stack` op for nested data
def stack(data, dim):
    """Recursively stack a list of identically structured items.

    The type of the first element of ``data`` decides how the list is handled:

    * ``torch.Tensor`` -- the list is handed to ``torch.stack`` along ``dim``;
    * ``dict`` -- every key of the first item is stacked recursively, giving
      a dict with the same keys;
    * ``bool`` -- the list is packed into a 1-D bool tensor (``dim`` is not
      used on this branch).

    Args:
        data: non-empty list of items sharing the same structure.
        dim: dimension passed to ``torch.stack`` for tensor leaves.

    Returns:
        A tensor, or a dict of stacked values mirroring the input structure.

    Raises:
        TypeError: if the first element is none of the supported types.
    """
    head = data[0]

    if isinstance(head, torch.Tensor):
        return torch.stack(data, dim)

    if isinstance(head, dict):
        stacked = {}
        for key in head.keys():
            # Gather this key's value from every item, then recurse on it.
            column = [entry[key] for entry in data]
            stacked[key] = stack(column, dim)
        return stacked

    if isinstance(head, bool):
        return torch.BoolTensor(data)

    raise TypeError("not support elem type: {}".format(type(head)))


# Run the hand-written recursive stack over the whole list, along dim 0.
stacked_data = stack(data, dim=0)
# validate
print(stacked_data)
# Each tensor leaf should have gained a leading dimension of size B.
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
# The plain bool `done` flags should have become a 1-D bool tensor of length B.
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertion statements should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
{'obs': {'scalar': tensor([[-1.1621,  0.8622,  1.2698, -2.0655,  0.1466,  0.0613, -1.0280, -1.9768,
          0.8074,  0.7536,  1.3047,  0.1601],
        [-0.2898, -0.1905, -1.3473, -0.7830,  0.6574,  0.6699,  1.3833, -2.8235,
          0.8399,  0.0666,  0.4324,  0.7176],
        [-1.7322,  0.2681, -0.5126, -0.4104, -0.5424, -0.1381, -1.7524,  0.5402,
          0.4447,  1.9487,  0.8499, -0.3846],
        [ 0.1030, -0.0444,  0.4281,  0.2201,  1.5434,  1.0083, -0.0238,  0.1518,
          0.9856, -0.0491, -0.8394,  2.1917]]), 'image': tensor([[[[-3.5428e-01, -1.0710e+00, -2.1601e+00,  ...,  1.2529e-01,
            4.9434e-01,  3.2622e-01],
          [ 1.8345e+00,  2.5315e-01, -1.4162e+00,  ...,  5.8329e-01,
           -2.1226e-01, -8.0565e-01],
          [ 4.2671e-01,  6.1066e-02, -2.0281e+00,  ..., -1.3839e+00,
           -2.0064e+00, -1.2457e+00],
          ...,
          [-1.1036e-01, -3.4989e-01,  1.1018e+00,  ...,  9.3691e-01,
            9.7898e-03,  2.2422e-01],
          [ 1.7422e+00,  1.0930e+00,  1.1672e+00,  ..., -1.2259e+00,
           -3.4510e-01,  2.0295e+00],
          [ 5.4826e-01,  1.1657e+00,  7.7462e-01,  ...,  1.3120e+00,
           -3.9478e-01, -1.6001e+00]],

         [[ 1.8931e+00, -3.9137e-01,  1.4829e-02,  ...,  2.8188e-01,
            7.8420e-01,  1.4191e+00],
          [ 2.1036e-01,  1.1272e+00,  1.1758e+00,  ...,  1.0786e+00,
           -3.3983e-01,  8.9400e-01],
          [-5.7435e-01,  1.1320e+00,  9.2317e-01,  ...,  6.4011e-01,
            8.6618e-01,  1.6217e+00],
          ...,
          [ 3.5085e-01, -7.2272e-01, -1.3508e+00,  ..., -1.0040e+00,
            6.9446e-01, -7.9106e-01],
          [-7.9288e-01, -2.3790e-01,  4.9975e-01,  ..., -1.2600e+00,
           -3.6130e-01,  1.7833e+00],
          [ 1.2699e+00, -7.3960e-02,  9.8388e-01,  ...,  2.2297e+00,
            7.6955e-01,  1.4493e+00]],

         [[-1.2609e+00,  1.7659e-01,  1.6767e-01,  ...,  1.0007e+00,
           -9.8175e-04,  9.7172e-01],
          [-4.6968e-01, -9.0817e-02,  1.3981e+00,  ..., -8.2077e-01,
           -1.3292e+00,  1.0331e-03],
          [ 5.2134e-01,  2.1738e+00, -1.5413e+00,  ...,  1.2963e+00,
            2.1603e+00,  8.2639e-01],
          ...,
          [-9.0385e-01, -2.3104e+00,  5.1585e-01,  ..., -3.8838e-01,
            6.6063e-02,  2.0226e+00],
          [ 1.9784e+00, -2.3606e-01, -2.8733e-01,  ..., -3.5743e-01,
            1.0208e+00,  8.2738e-01],
          [ 2.7743e+00,  1.1639e+00,  9.5278e-01,  ...,  1.6589e+00,
            4.5166e-01, -1.1621e+00]]],


        [[[ 6.5336e-01, -3.2003e-01,  2.1246e+00,  ...,  7.6724e-01,
            5.7754e-01, -2.9259e-01],
          [-1.5965e+00,  1.0666e+00, -1.2113e-01,  ..., -2.6140e-01,
            2.8023e-01, -1.2202e-01],
          [ 5.4764e-01, -7.4232e-01, -8.5787e-01,  ...,  7.7067e-01,
           -6.8688e-01, -1.6664e+00],
          ...,
          [ 2.6449e-01,  1.1516e+00,  7.8035e-01,  ...,  1.6934e+00,
            6.5924e-01,  4.3997e-01],
          [-9.4063e-01, -1.2337e+00, -1.6775e-01,  ..., -3.7044e-01,
            1.7488e+00, -2.9105e-01],
          [-7.5115e-01, -1.8744e+00,  6.9431e-01,  ..., -1.5013e-01,
           -1.2597e+00,  6.2037e-01]],

         [[ 1.7819e+00, -2.1993e-01, -8.3123e-01,  ..., -1.1850e+00,
            7.4261e-02, -1.4837e+00],
          [-2.8852e-01, -1.8786e+00,  8.2738e-01,  ..., -3.6772e-01,
            4.0198e-01,  9.5076e-02],
          [ 1.3929e-01, -3.3608e-01,  8.9016e-01,  ...,  1.2677e+00,
            1.1003e+00, -1.0391e-01],
          ...,
          [ 1.1075e+00, -4.7174e-02, -9.3835e-01,  ...,  6.1765e-01,
            1.0434e+00,  1.1983e+00],
          [ 3.5018e-01,  4.7674e-01, -4.6201e-01,  ...,  7.5623e-01,
            1.2314e+00, -1.2950e-01],
          [ 6.3861e-01,  8.2503e-01, -4.5231e-02,  ...,  4.5765e-01,
           -4.2140e-02, -1.7759e+00]],

         [[ 8.9666e-01, -8.3493e-01,  9.9475e-01,  ...,  2.3814e+00,
           -1.0785e+00, -1.0002e+00],
          [ 9.4536e-01, -3.6996e-02,  1.3265e+00,  ..., -2.0757e-01,
            8.5976e-01,  1.5207e+00],
          [-7.9001e-01,  5.1010e-01, -6.7072e-01,  ...,  2.8837e-01,
            1.0037e+00, -1.1345e+00],
          ...,
          [ 1.4787e-01, -6.0716e-01, -5.9914e-02,  ..., -6.3741e-01,
            5.6569e-01,  1.1485e+00],
          [-8.4662e-01,  1.0033e+00,  1.6926e+00,  ..., -2.4898e+00,
           -8.7937e-01, -1.0672e+00],
          [-1.0355e+00,  7.3152e-02,  1.0899e+00,  ...,  1.2343e+00,
           -4.3748e-01,  1.9880e-01]]],


        [[[-1.0495e+00, -1.5049e+00,  1.6949e+00,  ...,  2.2314e-01,
           -4.1456e-01,  1.1701e+00],
          [ 8.7829e-01, -1.0837e-01, -4.0874e-01,  ...,  2.5799e-01,
           -8.9661e-01, -9.1289e-01],
          [ 2.7845e-01,  4.9237e-01,  3.7110e-01,  ..., -1.2698e+00,
            3.7347e-02, -7.7660e-01],
          ...,
          [ 4.3739e-01, -1.6133e+00, -2.0422e-01,  ..., -1.1729e-02,
            4.9386e-01,  1.0180e+00],
          [ 6.5150e-01, -1.6439e-02,  1.1549e+00,  ..., -1.4530e+00,
            1.1945e-01, -1.3890e+00],
          [-1.6089e+00,  6.2207e-03, -2.0555e-01,  ...,  9.2455e-03,
            2.2972e+00,  1.0571e+00]],

         [[-1.0437e+00, -2.0156e-01, -6.1245e-01,  ..., -6.2064e-01,
           -2.7846e-01, -3.7187e-01],
          [-7.3534e-02, -1.6522e+00,  8.8815e-01,  ...,  4.1850e-01,
           -4.6741e-01,  1.0419e+00],
          [ 1.0806e-01,  1.7259e+00,  1.5370e-01,  ...,  7.7863e-01,
           -4.0181e-01,  3.4241e-01],
          ...,
          [ 1.0277e+00, -1.8105e+00,  1.4991e-01,  ...,  5.9657e-01,
           -1.2410e+00,  3.5733e-01],
          [ 5.9746e-01,  2.6715e-01, -1.0661e+00,  ..., -1.1960e+00,
           -1.0199e+00, -6.0800e-01],
          [-1.5420e-01,  1.1825e+00,  5.5973e-01,  ...,  1.1288e+00,
            4.0003e-01,  3.6002e-01]],

         [[-6.6013e-01, -2.7366e-02, -3.3612e-01,  ..., -2.7557e-01,
            1.6774e-01,  2.3072e-01],
          [ 1.2095e+00,  5.9106e-01, -7.4129e-01,  ...,  4.2521e-01,
           -9.1576e-01,  1.9465e+00],
          [ 2.4623e+00,  1.2587e+00,  3.9799e-01,  ...,  6.5430e-01,
           -4.5958e-01,  1.2762e+00],
          ...,
          [-1.4557e+00,  8.0278e-01,  3.4651e-01,  ..., -1.2636e+00,
            1.1801e+00, -9.6556e-02],
          [-3.5309e-01, -1.9784e-01,  1.9258e+00,  ...,  6.3403e-01,
            1.4975e+00,  1.7229e+00],
          [-1.0197e+00, -6.1565e-01,  7.5604e-01,  ..., -3.9719e-01,
           -2.1157e+00,  2.8194e+00]]],


        [[[ 2.2984e-01, -2.7654e-01, -1.9212e+00,  ..., -3.5101e-01,
            8.8159e-01, -4.0119e-01],
          [ 6.1201e-01, -1.1325e+00, -6.1409e-01,  ...,  1.1833e+00,
           -5.1078e-01,  3.1845e-01],
          [ 1.0457e+00, -3.2835e-01, -1.4234e+00,  ...,  1.7654e+00,
            1.5985e+00,  9.6359e-02],
          ...,
          [ 1.5899e-01,  2.1034e+00,  1.9663e+00,  ..., -9.7495e-01,
            1.9532e+00,  7.4224e-01],
          [ 8.0707e-01,  1.9488e+00,  4.9595e-01,  ...,  2.3061e-01,
           -3.8538e-01, -2.8067e+00],
          [ 7.9545e-01,  5.3325e-01, -6.8003e-01,  ..., -9.4661e-01,
            4.4985e-01,  9.7887e-01]],

         [[-7.5883e-01,  7.1917e-02, -7.7786e-01,  ...,  1.3086e-01,
           -2.5226e+00,  2.8478e-01],
          [-5.3182e-01, -1.3182e+00, -1.0094e+00,  ..., -1.0953e+00,
            8.0949e-01, -1.6582e-01],
          [-3.6863e-01,  1.6644e-01, -1.5616e+00,  ...,  1.9987e+00,
            9.7860e-01, -5.5544e-01],
          ...,
          [ 1.3410e+00,  1.4618e-01,  1.0455e+00,  ...,  2.7058e-01,
            4.2303e-01,  8.8564e-01],
          [-1.8814e-01,  8.5378e-01,  8.2544e-01,  ..., -4.2194e-01,
            4.4802e-01,  6.2704e-01],
          [-8.9912e-01,  8.3177e-01, -6.4825e-01,  ...,  1.3941e+00,
           -2.6662e+00,  3.5464e-01]],

         [[ 1.2600e-01, -7.0550e-01,  7.9264e-01,  ...,  4.3575e-01,
            1.9276e-01,  1.6815e+00],
          [-2.6532e-01,  5.4715e-02,  1.8190e+00,  ...,  1.1267e+00,
            1.3940e+00, -7.0673e-01],
          [ 7.6817e-01,  2.7359e-01,  5.6427e-02,  ...,  1.5339e-01,
           -9.4139e-01, -1.1846e+00],
          ...,
          [-1.0522e+00, -5.6898e-01, -1.1531e+00,  ...,  6.5160e-01,
            2.8525e-01, -9.7959e-01],
          [ 1.0820e+00,  2.3070e-01,  1.0374e+00,  ..., -8.5300e-01,
            2.1119e-01, -6.0910e-01],
          [ 8.2257e-01, -1.3045e+00, -5.4486e-01,  ..., -6.1910e-01,
           -1.4841e-01,  2.5825e+00]]]])}, 'action': tensor([[8],
        [6],
        [5],
        [6]]), 'reward': tensor([[0.7377],
        [0.8645],
        [0.0472],
        [0.2678]]), 'done': tensor([False, False, False, False])}

As we can see, the function that processes the structure (here, the stack function) has to be implemented entirely by hand. The resulting code is not very clear, and because the supported types are hard-coded, supporting any additional data type (such as integers) requires modifying the function itself.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown in the code below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

# Number of items generated below; also the expected size of the stacked leading dimension.
B = 4


def get_item():
    """Create one sample as a nested dict of random tensors plus a bool flag.

    Returns:
        dict: keys ``'obs'`` (a dict holding ``'scalar'`` of shape ``(12,)``
        and ``'image'`` of shape ``(3, 32, 32)``), ``'action'`` (one random
        integer in ``[0, 10)``, shape ``(1,)``), ``'reward'`` (one uniform
        random value, shape ``(1,)``) and ``'done'`` (the plain Python
        ``False``).
    """
    # Generate the leaves first, in the same order as the result lists them,
    # so the random number stream is consumed identically.
    scalar_obs = torch.randn(12)
    image_obs = torch.randn(3, 32, 32)
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)

    observation = {'scalar': scalar_obs, 'image': image_obs}
    return {'obs': observation, 'action': action, 'reward': reward, 'done': False}


# Build a list of B items, one get_item() call per item.
data = [get_item() for _ in range(B)]
# execute `stack` op
# Each nested dict is first passed through ttorch.tensor, then the whole list
# goes to a single ttorch.stack call; no hand-written recursion is needed here.
# NOTE(review): presumably ttorch.tensor converts every leaf (including the
# plain bool `done`) to a tensor -- the dtype assertion below relies on it.
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate
print(stacked_data)
# Leaves are reached by attribute access on the stacked result.
# Each one should carry a leading dimension of size B.
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertion statements should pass here as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
<Tensor 0x7f557106e190>
├── 'action' --> tensor([[3],
│                        [5],
│                        [0],
│                        [9]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7f5571ef3b20>
│   ├── 'image' --> tensor([[[[ 4.5305e-01, -3.0336e-01, -1.6102e+00,  ...,  1.5841e+00,
│   │                           2.0629e+00,  1.0902e-01],
│   │                         [-2.7962e-01, -5.4670e-01,  1.6099e-01,  ...,  8.1282e-01,
│   │                           7.9952e-01, -6.0495e-01],
│   │                         [ 1.1155e-02,  8.4526e-01,  1.0347e+00,  ..., -3.5932e-01,
│   │                          -1.5659e+00,  1.3315e-01],
│   │                         ...,
│   │                         [ 6.0212e-03,  1.7415e+00,  3.9781e-01,  ..., -8.7747e-01,
│   │                           1.0913e-01,  2.2692e+00],
│   │                         [ 1.0508e-01, -2.3612e-01,  2.6880e-01,  ..., -6.0721e-02,
│   │                           1.1314e+00, -8.8264e-01],
│   │                         [ 1.3416e+00,  1.6671e-01,  1.2110e+00,  ..., -1.3156e+00,
│   │                           3.8785e-01, -7.8287e-01]],
│   │               
│   │                        [[ 9.2502e-01,  4.9699e-01, -7.4416e-01,  ...,  9.6407e-01,
│   │                          -1.7177e+00, -1.2746e-01],
│   │                         [-6.3110e-02, -5.3069e-01, -1.3758e+00,  ...,  1.8929e-01,
│   │                          -1.5771e+00,  8.3413e-02],
│   │                         [ 1.1359e+00,  1.2509e-01, -4.7094e-01,  ...,  1.2450e+00,
│   │                          -5.3627e-01,  3.9740e-01],
│   │                         ...,
│   │                         [-6.1291e-02, -9.4952e-01,  1.4534e+00,  ..., -9.5106e-01,
│   │                          -1.6712e-01, -2.1231e-01],
│   │                         [-8.2346e-01,  1.2020e+00, -5.6791e-01,  ...,  4.7281e-01,
│   │                          -1.0816e+00,  2.9828e-01],
│   │                         [-8.5188e-02,  5.7983e-01,  1.2777e+00,  ...,  1.7275e+00,
│   │                          -1.2654e+00, -2.0131e-01]],
│   │               
│   │                        [[-5.3384e-01,  4.9830e-01,  1.6513e-01,  ...,  1.8487e-01,
│   │                          -6.4973e-01,  4.5378e-01],
│   │                         [ 1.7814e+00,  4.5737e-01,  3.7691e-01,  ..., -7.0164e-01,
│   │                          -6.2657e-02, -1.1186e+00],
│   │                         [-5.7041e-01,  8.5518e-04, -1.2308e-01,  ..., -1.4112e+00,
│   │                          -1.1722e+00, -1.2931e+00],
│   │                         ...,
│   │                         [ 3.9679e-01,  3.2742e-01, -6.9487e-01,  ..., -6.5645e-01,
│   │                          -1.7745e-01, -2.0744e-01],
│   │                         [ 3.7309e-01,  2.2747e-02, -2.2039e+00,  ...,  2.1003e-01,
│   │                           4.9526e-01,  7.7278e-01],
│   │                         [ 1.1841e+00, -1.3013e+00,  9.0826e-01,  ...,  4.3452e-01,
│   │                          -2.9804e-01,  3.0763e+00]]],
│   │               
│   │               
│   │                       [[[ 3.3523e-01,  1.6094e+00,  3.0035e-01,  ...,  3.0878e-01,
│   │                          -3.8340e-01, -6.8601e-01],
│   │                         [-2.1227e-02,  1.4496e+00, -5.4613e-01,  ..., -3.1780e-01,
│   │                          -1.6857e+00, -8.3415e-01],
│   │                         [ 3.0169e-01, -1.3458e+00,  2.3000e-01,  ..., -1.5856e+00,
│   │                          -2.1587e-02, -5.7282e-01],
│   │                         ...,
│   │                         [ 6.4079e-01,  1.7333e+00,  2.1134e-01,  ...,  7.5056e-02,
│   │                          -1.7695e-01,  1.2730e-01],
│   │                         [-3.5883e-01, -1.0269e+00,  4.3036e-01,  ...,  1.5077e+00,
│   │                          -4.3318e-02,  6.0996e-01],
│   │                         [ 2.8908e-01,  3.5347e-01,  1.5299e-01,  ..., -1.8552e+00,
│   │                          -7.7469e-01,  1.4434e-02]],
│   │               
│   │                        [[ 1.5832e+00,  9.1332e-01,  3.3736e-01,  ..., -2.4340e-01,
│   │                          -9.4892e-02, -4.5676e-01],
│   │                         [-6.9907e-01, -1.6804e+00,  7.1835e-01,  ...,  2.3717e-01,
│   │                           3.1614e-01, -1.8293e+00],
│   │                         [ 2.0353e-01,  8.3531e-01,  4.3858e-01,  ..., -6.1167e-01,
│   │                           1.7666e-01, -1.5947e-01],
│   │                         ...,
│   │                         [-9.5353e-01,  9.6057e-01, -1.0227e+00,  ..., -3.0698e-01,
│   │                          -8.7517e-01, -1.3401e+00],
│   │                         [-1.6547e+00,  6.4976e-01, -7.7527e-01,  ...,  2.2822e-01,
│   │                           1.1827e+00,  6.2907e-01],
│   │                         [-1.5347e+00,  1.1366e+00,  2.0060e+00,  ..., -2.9645e-01,
│   │                          -2.4798e-01, -9.1440e-01]],
│   │               
│   │                        [[ 6.7731e-01, -1.2372e-01, -1.4714e+00,  ..., -1.7713e+00,
│   │                           1.3586e-01,  6.5763e-01],
│   │                         [ 2.0879e+00, -1.4936e+00, -5.4164e-02,  ...,  9.1828e-01,
│   │                           2.5872e+00, -5.6008e-01],
│   │                         [-1.2007e+00,  8.4864e-01,  9.1895e-01,  ..., -1.3270e+00,
│   │                          -1.1167e+00, -3.4064e-02],
│   │                         ...,
│   │                         [ 5.2187e-02,  5.2889e-01,  3.1232e-01,  ...,  4.7025e-01,
│   │                           5.2799e-01, -1.1737e+00],
│   │                         [-2.9236e-01,  1.4282e+00,  2.3062e-01,  ..., -2.6570e-01,
│   │                           1.1106e+00, -1.3901e+00],
│   │                         [ 1.9271e+00,  7.0249e-01,  2.1241e+00,  ..., -1.0437e+00,
│   │                          -9.6009e-02,  1.1584e+00]]],
│   │               
│   │               
│   │                       [[[-6.4074e-01,  8.8370e-02,  1.6059e+00,  ...,  8.3893e-01,
│   │                           1.2309e+00,  1.1106e+00],
│   │                         [-5.5425e-01, -8.2489e-01,  6.2072e-01,  ...,  1.1677e-01,
│   │                           1.0454e+00,  1.0070e+00],
│   │                         [ 3.4931e-01, -1.0036e+00,  4.0032e-01,  ...,  1.8615e+00,
│   │                           5.9167e-01,  9.9133e-01],
│   │                         ...,
│   │                         [ 3.4256e-01, -1.7520e-01, -1.6834e-01,  ..., -2.0434e+00,
│   │                           2.2654e-01, -5.3918e-01],
│   │                         [-1.9090e-01, -1.4652e-01,  6.3472e-02,  ...,  8.5850e-01,
│   │                           8.8871e-01, -7.8788e-01],
│   │                         [-4.1618e-01,  7.7465e-01,  1.9021e-01,  ...,  1.3071e+00,
│   │                          -7.0832e-01,  1.9594e-01]],
│   │               
│   │                        [[ 7.3759e-01, -1.2796e-01,  7.7650e-01,  ..., -9.0348e-01,
│   │                          -5.5450e-01, -1.5004e+00],
│   │                         [ 4.2839e-01, -3.3463e-01,  8.1024e-01,  ..., -1.2504e+00,
│   │                           1.2227e+00, -2.4110e-02],
│   │                         [-4.5792e-01,  6.0587e-01,  2.3643e-01,  ...,  2.2757e-01,
│   │                          -7.6852e-01,  9.4963e-01],
│   │                         ...,
│   │                         [-2.2462e-01, -7.7750e-03,  8.0273e-01,  ..., -4.6045e-01,
│   │                          -1.4387e+00, -6.3549e-01],
│   │                         [-1.1712e+00,  1.2857e+00,  2.4670e+00,  ..., -1.1469e+00,
│   │                          -1.6900e+00, -4.5593e-03],
│   │                         [-3.4158e-01,  1.8648e-01, -3.1790e-01,  ...,  1.7572e+00,
│   │                           4.2377e-01,  6.0646e-01]],
│   │               
│   │                        [[-9.8104e-01, -7.6587e-02,  1.2252e+00,  ..., -3.0363e-01,
│   │                          -2.5290e+00, -2.9548e-01],
│   │                         [ 1.2126e+00,  1.1443e+00, -5.6300e-01,  ...,  7.0297e-01,
│   │                          -1.1962e-01, -1.0543e+00],
│   │                         [ 1.0650e+00,  1.3232e+00,  1.3519e+00,  ..., -1.7350e-01,
│   │                           3.1324e-01, -1.6650e+00],
│   │                         ...,
│   │                         [-8.1496e-01, -1.0059e+00,  1.5679e+00,  ...,  1.0814e+00,
│   │                           5.7148e-01,  9.1896e-01],
│   │                         [ 1.9223e+00, -3.5209e-02,  1.9954e+00,  ...,  6.8514e-01,
│   │                          -1.6827e+00,  1.8084e-01],
│   │                         [ 9.3977e-01,  3.2010e-01, -1.9909e-01,  ...,  4.8631e-01,
│   │                          -7.1255e-01,  2.7735e-01]]],
│   │               
│   │               
│   │                       [[[ 1.5712e+00,  1.5210e-01, -1.6820e+00,  ..., -8.2562e-01,
│   │                           1.7781e-01,  2.3375e-02],
│   │                         [-2.3743e-01,  8.4960e-01,  1.1401e+00,  ..., -1.5590e+00,
│   │                           4.4488e-01, -1.3853e+00],
│   │                         [-1.5048e+00,  8.5534e-01, -6.3591e-01,  ..., -1.0457e+00,
│   │                          -1.3380e+00, -2.0372e-01],
│   │                         ...,
│   │                         [ 1.6393e+00, -1.0433e-01,  2.4292e-01,  ..., -2.2212e-01,
│   │                          -2.7331e-01,  6.7706e-01],
│   │                         [-1.4391e+00,  5.3963e-01, -7.0966e-01,  ..., -1.0590e-01,
│   │                          -2.8383e-01,  5.1172e-01],
│   │                         [ 1.3692e+00, -9.6058e-01,  8.4043e-01,  ...,  2.5890e-01,
│   │                          -6.8725e-01, -4.4703e-01]],
│   │               
│   │                        [[ 1.7553e-01,  3.6553e-02,  2.0117e-01,  ..., -9.1705e-01,
│   │                           1.0142e+00,  2.8441e-01],
│   │                         [ 7.7580e-01,  2.4734e-02, -7.4700e-01,  ..., -9.0560e-01,
│   │                          -5.1516e-01, -1.2296e+00],
│   │                         [ 1.6646e-01,  1.2060e+00, -7.0133e-01,  ...,  2.9796e-02,
│   │                           3.9324e-01, -1.1462e+00],
│   │                         ...,
│   │                         [-7.7911e-01,  1.2364e+00,  9.0312e-01,  ...,  7.6247e-01,
│   │                          -3.9345e-01,  4.9417e-01],
│   │                         [ 1.2497e+00,  8.3425e-01,  9.6441e-01,  ..., -1.2144e+00,
│   │                           6.4846e-01, -6.6589e-03],
│   │                         [-3.5097e-01, -8.8290e-02,  8.4920e-01,  ...,  1.9532e+00,
│   │                           4.9720e-01, -8.4429e-01]],
│   │               
│   │                        [[ 2.7980e-01, -4.1947e-02, -6.3996e-01,  ..., -1.6186e-01,
│   │                           8.0445e-01,  3.9946e-01],
│   │                         [-1.5958e-01, -9.3063e-03,  8.8107e-01,  ..., -3.3843e-01,
│   │                           1.0142e-01,  5.4543e-01],
│   │                         [ 6.6808e-01, -1.5891e-01,  1.2833e+00,  ...,  5.6917e-01,
│   │                           8.2898e-01,  3.2452e-01],
│   │                         ...,
│   │                         [-9.4703e-01,  1.4934e+00, -8.0702e-01,  ..., -6.9392e-02,
│   │                          -2.6725e-01,  1.0014e+00],
│   │                         [-8.2691e-02,  8.2520e-01, -6.6341e-01,  ...,  6.2946e-01,
│   │                           8.1387e-01, -2.3411e+00],
│   │                         [ 3.3730e+00,  4.9355e-01,  1.4211e-01,  ..., -5.6572e-01,
│   │                          -1.0447e+00,  1.8830e-01]]]])
│   └── 'scalar' --> tensor([[-0.9729, -0.5464, -0.6358,  0.0401, -0.8908,  0.7758,  0.3859, -0.9056,
│                             -1.6813, -0.2444,  0.8474, -0.5752],
│                            [-0.4740, -0.0424,  0.2064,  0.3482,  0.2146, -1.2817, -0.7719, -0.2434,
│                             -0.0908,  1.6274,  1.7456, -0.3639],
│                            [ 0.9763,  0.2411, -0.7345,  0.6713, -2.3894, -0.6765, -0.0462,  1.5050,
│                              0.5170,  0.0492,  0.5740,  0.6226],
│                            [-1.7079,  1.0654, -1.1028, -0.3045, -0.0606,  0.0557,  0.7524, -0.6574,
│                              0.2073, -1.3913,  0.1328,  0.4099]])
└── 'reward' --> tensor([[0.9866],
                         [0.2933],
                         [0.2967],
                         [0.0326]])

This code looks much simpler and clearer.