Stack Structured Data

When tensors are organized into tree structures, they often need to be stacked together, just as torch.stack() does for plain tensors in torch.

Stack With Native PyTorch API

Here is a common implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

B = 4  # number of samples to generate; becomes the leading dim after stacking


def get_item():
    """Build one synthetic sample as a nested dict.

    The returned dict has the keys, in this order:

    * ``'obs'``: a sub-dict with ``'scalar'`` (a ``(12,)`` normal tensor)
      and ``'image'`` (a ``(3, 32, 32)`` normal tensor);
    * ``'action'``: an integer tensor of shape ``(1,)`` drawn from ``[0, 10)``;
    * ``'reward'``: a float tensor of shape ``(1,)`` drawn from ``[0, 1)``;
    * ``'done'``: the plain Python bool ``False``.
    """
    # Random draws happen in the order scalar -> image -> action -> reward,
    # so a fixed seed always yields the same sample.
    obs = {}
    obs['scalar'] = torch.randn(12)
    obs['image'] = torch.randn(3, 32, 32)
    return {
        'obs': obs,
        'action': torch.randint(0, 10, size=(1,)),
        'reward': torch.rand(1),
        'done': False,
    }


# Build a batch of B independent samples, all sharing the nested layout of get_item().
data = [get_item() for _ in range(B)]


# execute `stack` op
def stack(data, dim):
    """Recursively stack a list of identically-structured samples.

    The structure of ``data[0]`` drives the recursion, and the matching
    leaves of every sample are combined:

    * ``torch.Tensor`` leaves are combined with ``torch.stack(..., dim)``;
    * ``dict`` nodes are recursed into key by key, using the keys of the
      first sample;
    * ``bool`` leaves are collected into a 1-D ``torch.bool`` tensor
      (``dim`` is not applied to them).

    Args:
        data: non-empty list of samples sharing the same nested layout.
        dim: dimension along which tensor leaves are stacked.

    Returns:
        The stacked structure, mirroring the layout of ``data[0]``.

    Raises:
        ValueError: if ``data`` is empty.
        TypeError: if a leaf has a type not listed above.
    """
    # len() instead of truthiness so any sized sequence is accepted; with no
    # first element there is no structure to follow.
    if len(data) == 0:
        raise ValueError("cannot stack an empty sequence")
    elem = data[0]
    if isinstance(elem, torch.Tensor):
        return torch.stack(data, dim)
    elif isinstance(elem, dict):
        return {k: stack([item[k] for item in data], dim) for k in elem}
    elif isinstance(elem, bool):
        # torch.tensor with an explicit dtype replaces the legacy
        # torch.BoolTensor type constructor.
        return torch.tensor(data, dtype=torch.bool)
    else:
        raise TypeError("not support elem type: {}".format(type(elem)))


stacked_data = stack(data, dim=0)
# validate: each tensor leaf gains a leading batch dimension of size B,
# and the per-sample Python bools under 'done' become one (B,) bool tensor.
print(stacked_data)
assert stacked_data['obs']['scalar'].shape == (B, 12)
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertions pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
{'obs': {'scalar': tensor([[ 0.7932, -0.9222, -1.3446, -0.5559,  0.6458,  0.5098, -2.4559,  0.6424,
         -1.3009,  0.8222, -1.0717, -0.1113],
        [-0.2990, -0.6373,  1.1773,  0.5897,  0.1452,  0.3830, -0.0322, -0.4129,
          0.0603, -1.2046,  1.9595,  0.9456],
        [-0.8266, -0.6139, -1.5216,  0.2978,  1.1613, -1.2786, -1.0298, -1.0051,
          1.1522,  2.6098, -0.6124, -0.2839],
        [-1.4161,  0.4941, -1.0175,  1.2538,  2.2227, -1.2593, -0.4021,  0.7888,
          0.9283,  1.0343, -0.2475,  0.1977]]), 'image': tensor([[[[ 1.2380, -2.4631,  0.0337,  ..., -2.4148, -0.3557, -0.1164],
          [ 0.4452, -1.7648, -0.7155,  ..., -1.4488,  1.3118, -1.3697],
          [ 0.5380,  1.1713,  0.9900,  ...,  0.4440, -0.3192,  0.8238],
          ...,
          [ 0.1466,  0.5005, -1.0111,  ...,  0.0319, -0.3572,  1.9773],
          [-0.4833,  0.5371,  1.9037,  ..., -0.8849,  0.0126,  1.1997],
          [ 1.5593,  0.3725, -0.1779,  ...,  1.2821,  0.6637, -0.5054]],

         [[ 0.4508,  1.5284,  1.0313,  ..., -0.0660,  0.5233,  1.4902],
          [-0.1048,  0.2689,  2.1356,  ..., -0.2689,  0.4325, -0.1757],
          [ 0.2675, -0.7940, -1.6288,  ..., -1.3268, -1.3974, -0.6916],
          ...,
          [-1.1467, -0.0748,  0.0811,  ...,  0.3524, -0.1825,  0.3085],
          [ 0.9935,  1.4705, -1.0036,  ..., -0.3275, -0.4407,  0.7062],
          [ 0.0843,  0.8614,  1.2490,  ..., -0.7108,  0.6461, -0.7824]],

         [[-1.0162,  0.3024,  0.0580,  ...,  1.0351,  0.1336, -0.3573],
          [ 1.5491,  0.6748,  0.1653,  ...,  0.5827,  0.2876,  0.3452],
          [ 0.2527,  0.6918,  0.3304,  ...,  1.5127,  0.2724,  0.2934],
          ...,
          [-2.2597, -0.3824,  0.3502,  ...,  1.0090, -1.4997, -0.4550],
          [-1.0045, -0.1039, -0.2959,  ...,  1.9923,  2.4758, -0.5123],
          [-1.4138, -1.1566,  0.6517,  ..., -1.6892, -0.0840, -0.4827]]],


        [[[-2.0162, -0.9346, -0.9355,  ..., -0.0101,  0.0268, -1.0797],
          [ 0.5060, -0.2055,  0.2285,  ...,  0.1447, -0.5228, -0.1111],
          [ 0.4344,  0.1196, -0.6798,  ...,  0.4000, -2.0507, -0.0207],
          ...,
          [ 2.0747, -0.6992, -1.1428,  ..., -1.0759,  0.2787,  1.7821],
          [-0.0237, -0.6858,  0.2768,  ..., -0.3334, -0.0928,  0.2354],
          [-1.0940, -0.0413, -0.0037,  ..., -0.7189,  1.8493,  1.3413]],

         [[-1.0878,  1.4241, -1.0107,  ..., -2.0817,  0.2022, -0.0751],
          [ 0.7906,  0.5587, -0.0749,  ...,  0.0838, -0.4387, -1.0082],
          [ 0.9751,  0.4750,  1.4572,  ...,  0.0473,  0.5883,  0.0516],
          ...,
          [ 0.2110,  0.1640, -1.2474,  ...,  0.0637, -0.2625,  0.0121],
          [-0.6730,  0.7399,  0.5472,  ..., -0.5210,  0.5973,  0.5624],
          [-0.3648, -0.6528,  0.4191,  ..., -0.7186, -0.7375, -0.5261]],

         [[ 1.3876,  2.5972,  0.0335,  ...,  0.6089, -0.4094,  0.0573],
          [ 0.0937,  1.7867, -0.9991,  ..., -0.3316,  0.2051, -0.0547],
          [ 1.3818, -0.0746, -0.5710,  ..., -0.3588, -1.1402,  0.7733],
          ...,
          [-1.2990,  0.1522,  1.6240,  ...,  0.5511, -1.0414,  0.4646],
          [ 0.9846, -0.3556,  1.3440,  ...,  0.2240,  0.4705, -0.2544],
          [ 0.1860,  0.8185,  0.0352,  ..., -1.8402, -1.3069, -0.3065]]],


        [[[-0.1946, -0.0101, -0.2891,  ...,  0.5428, -0.9403,  0.0928],
          [-1.7674,  0.8160,  0.4752,  ..., -0.7384,  1.1939,  0.2459],
          [-0.8130, -1.1600, -0.5070,  ...,  1.0349, -0.2593, -0.1796],
          ...,
          [-0.4393,  0.8608, -0.0352,  ...,  0.5921, -0.7447, -0.9031],
          [ 0.7372,  1.2285,  0.5027,  ..., -1.8992,  1.6161, -0.6071],
          [ 0.1819,  0.1463, -0.0798,  ...,  0.5473,  1.8474,  0.0886]],

         [[ 1.0403, -1.0910,  0.3294,  ...,  2.1361,  1.0072, -0.5012],
          [ 1.3232, -0.6539, -0.3969,  ..., -1.7829,  1.7643,  0.4928],
          [-1.6885, -0.2454, -0.1873,  ..., -0.1517, -0.8394,  0.9339],
          ...,
          [-1.5397, -0.6594,  0.5164,  ..., -0.0815,  2.1328, -0.6738],
          [-0.1701, -1.0005,  0.0138,  ..., -0.2141,  0.5774,  0.1907],
          [ 2.4779,  0.1681, -0.1989,  ...,  2.4262,  0.7116, -0.7654]],

         [[-0.0310, -0.4705,  0.5795,  ...,  0.9993,  1.0736,  1.1848],
          [-0.4708, -0.4015, -0.7381,  ..., -0.4058, -0.5949, -2.1402],
          [-0.9960, -0.3734, -0.2418,  ..., -0.0599, -0.6076, -0.9015],
          ...,
          [-0.3184, -0.7032,  1.6884,  ..., -0.8805,  2.3494, -1.1748],
          [ 1.0097,  0.7583, -0.7440,  ..., -0.8854,  0.0278, -0.7806],
          [-0.8811, -0.6458, -0.2610,  ..., -0.1456, -0.5522,  0.3828]]],


        [[[-0.4092, -0.9741,  0.2104,  ..., -0.6552,  0.1605, -1.1088],
          [-2.0150,  0.6253, -1.6672,  ...,  0.3202,  0.6175, -0.1390],
          [-0.3905, -0.2333, -0.4288,  ...,  0.4944, -0.1823, -0.4771],
          ...,
          [ 0.1231, -0.3220,  0.1632,  ...,  0.4199,  0.9479,  0.3208],
          [-0.6596, -0.1658, -0.5778,  ..., -1.5156, -0.9119,  0.5667],
          [ 0.1516,  0.6772, -0.4716,  ...,  1.0023, -1.1107,  1.5858]],

         [[-0.8067,  0.4242,  0.3898,  ...,  0.7400, -0.3151, -0.0958],
          [ 0.0465, -0.4695,  1.4429,  ..., -1.2441,  0.2401,  0.0861],
          [ 0.7189, -1.5256, -0.5494,  ..., -0.2128, -1.3309, -0.9675],
          ...,
          [-0.1552,  1.2135, -0.0325,  ..., -0.2619,  0.3367, -0.0551],
          [-0.1335,  1.7441,  0.0512,  ..., -0.7586,  2.9495,  0.2414],
          [ 2.6994, -1.8738,  1.1975,  ...,  0.4408,  0.7963, -1.4215]],

         [[-0.9697,  0.1199,  1.2672,  ...,  0.0723,  0.7299, -0.2891],
          [ 0.1681,  0.0112, -0.5478,  ..., -0.6919, -0.6058, -0.3963],
          [ 0.2612,  0.7435,  0.2621,  ..., -0.3277, -1.2682,  0.4378],
          ...,
          [ 0.4943, -1.4571,  0.7201,  ...,  0.8832, -0.3063, -0.0889],
          [-1.3147, -0.5254, -1.3141,  ..., -0.0958, -0.9376,  0.3752],
          [ 0.2867, -1.6274,  1.1582,  ..., -0.1568,  1.1145,  0.5023]]]])}, 'action': tensor([[7],
        [8],
        [8],
        [0]]), 'reward': tensor([[0.5577],
        [0.1247],
        [0.8737],
        [0.1554]]), 'done': tensor([False, False, False, False])}

As we can see, the structure-processing function (here, `stack`) has to be implemented in full by hand. Such code is not very clear, and because the supported types are hard-coded, supporting additional data types (such as integers) requires special modifications to the function.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown in the code below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

B = 4  # number of samples to generate; becomes the leading dim after stacking


def get_item():
    """Build one synthetic sample as a nested dict.

    The returned dict has the keys, in this order:

    * ``'obs'``: a sub-dict with ``'scalar'`` (a ``(12,)`` normal tensor)
      and ``'image'`` (a ``(3, 32, 32)`` normal tensor);
    * ``'action'``: an integer tensor of shape ``(1,)`` drawn from ``[0, 10)``;
    * ``'reward'``: a float tensor of shape ``(1,)`` drawn from ``[0, 1)``;
    * ``'done'``: the plain Python bool ``False``.
    """
    # Random draws happen in the order scalar -> image -> action -> reward,
    # so a fixed seed always yields the same sample.
    obs = {}
    obs['scalar'] = torch.randn(12)
    obs['image'] = torch.randn(3, 32, 32)
    return {
        'obs': obs,
        'action': torch.randint(0, 10, size=(1,)),
        'reward': torch.rand(1),
        'done': False,
    }


data = [get_item() for _ in range(B)]
# execute `stack` op
# Wrap each nested dict as a ttorch tree, then stack all trees leaf by leaf
# along dim 0 in a single call.
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate: leaves are reached by attribute access, and each one gains a
# leading batch dimension of size B ('done' ends up as a (B,) bool tensor).
print(stacked_data)
assert stacked_data.obs.scalar.shape == (B, 12)
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertions pass as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
<Tensor 0x7f41b67ae1f0>
├── 'action' --> tensor([[5],
│                        [0],
│                        [5],
│                        [2]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7f41b73f3b20>
│   ├── 'image' --> tensor([[[[-9.0018e-01,  1.4331e+00, -2.6982e+00,  ...,  9.5270e-01,
│   │                          -2.3200e+00,  2.0061e-01],
│   │                         [-7.3227e-01,  2.5837e-01,  1.4573e+00,  ..., -8.2640e-01,
│   │                           2.0012e-01, -8.6422e-01],
│   │                         [-2.3586e-01, -9.7709e-01,  1.4994e+00,  ..., -8.8790e-01,
│   │                           1.4724e+00,  9.9293e-03],
│   │                         ...,
│   │                         [-7.6320e-02,  1.7216e+00,  7.8413e-01,  ...,  1.4411e+00,
│   │                           8.6898e-01,  8.5641e-01],
│   │                         [ 8.3306e-01,  1.9879e-01, -4.2274e-01,  ..., -7.5366e-01,
│   │                           1.3348e+00,  5.3655e-01],
│   │                         [ 6.6418e-01,  5.6501e-01, -1.9327e-01,  ...,  1.7000e+00,
│   │                          -8.4681e-01, -1.0415e+00]],
│   │               
│   │                        [[-3.5870e-01,  1.8708e+00, -7.0289e-01,  ...,  5.8492e-01,
│   │                           5.7687e-01, -3.4394e-01],
│   │                         [ 3.5668e-02, -8.1907e-01, -5.2651e-01,  ..., -2.2884e+00,
│   │                           1.1385e+00, -5.4354e-01],
│   │                         [ 1.6017e+00,  4.9983e-01,  2.4461e-01,  ..., -1.6681e-01,
│   │                           1.0008e+00, -3.7325e-01],
│   │                         ...,
│   │                         [ 1.7583e+00, -1.0595e+00, -1.8706e-01,  ..., -2.2764e-01,
│   │                           1.9069e+00, -7.1246e-01],
│   │                         [ 2.9182e-01, -3.3004e-01,  1.4129e+00,  ..., -3.2604e-01,
│   │                           8.7452e-02,  1.0924e+00],
│   │                         [-6.9581e-01, -7.5228e-01,  2.1548e+00,  ..., -2.4424e-02,
│   │                           1.3193e+00, -7.8471e-01]],
│   │               
│   │                        [[ 9.6777e-01, -6.4119e-01,  7.7836e-01,  ...,  5.3110e-01,
│   │                           7.6621e-01, -1.8420e+00],
│   │                         [-7.9994e-01, -7.0083e-01, -2.1568e+00,  ...,  3.3342e-01,
│   │                           9.2908e-01,  1.4338e+00],
│   │                         [ 1.8233e-01, -1.2284e+00, -9.6041e-01,  ...,  1.9880e+00,
│   │                           1.3219e+00, -3.0737e-01],
│   │                         ...,
│   │                         [ 1.7319e+00, -2.3418e+00, -6.0352e-01,  ..., -1.3142e+00,
│   │                          -3.2159e-01,  7.3270e-01],
│   │                         [-1.2117e+00, -6.0486e-01,  1.6194e+00,  ..., -3.0595e-01,
│   │                           1.1062e-01,  7.6021e-01],
│   │                         [ 1.7458e+00,  1.0433e+00, -3.9791e-01,  ..., -3.7883e-01,
│   │                          -1.5831e+00, -4.8594e-01]]],
│   │               
│   │               
│   │                       [[[-2.2570e-01, -2.0785e+00, -5.1630e-01,  ..., -7.9495e-01,
│   │                          -8.1757e-01,  1.4740e+00],
│   │                         [-9.9191e-01,  5.4883e-01,  9.7951e-01,  ...,  8.9641e-01,
│   │                          -2.6024e-01, -3.6538e+00],
│   │                         [ 7.5449e-01,  2.6033e-02, -5.3318e-01,  ...,  1.3573e+00,
│   │                           5.2226e-01, -5.8673e-01],
│   │                         ...,
│   │                         [-5.1627e-01, -1.9321e+00,  4.7618e-01,  ...,  4.5441e-01,
│   │                          -3.2623e-02,  1.0679e+00],
│   │                         [ 1.6424e+00, -1.6767e+00, -1.1283e+00,  ...,  6.5228e-01,
│   │                          -5.7813e-01, -1.8367e+00],
│   │                         [ 1.2287e+00,  3.3300e-01, -5.8853e-01,  ...,  3.0922e+00,
│   │                           2.3950e+00, -3.6193e-02]],
│   │               
│   │                        [[ 6.9714e-01,  2.2439e-01, -7.4174e-01,  ..., -1.2950e-01,
│   │                           1.4317e-01, -6.9019e-01],
│   │                         [-5.0578e-01,  1.3750e+00, -3.8295e-01,  ..., -3.2069e-01,
│   │                          -2.9079e-01, -1.1588e-01],
│   │                         [ 1.5753e-01, -1.6261e-01, -8.6985e-01,  ..., -1.0257e+00,
│   │                           2.4558e-01,  1.2468e+00],
│   │                         ...,
│   │                         [ 9.5873e-01, -1.8259e+00, -1.2166e+00,  ...,  1.5918e+00,
│   │                          -1.6896e+00, -8.9217e-01],
│   │                         [-1.1260e+00, -2.5787e+00, -6.0019e-01,  ...,  1.3191e+00,
│   │                           8.4857e-01, -5.5990e-01],
│   │                         [-4.5927e-02,  1.9239e+00,  7.9091e-01,  ..., -9.1586e-01,
│   │                          -1.2105e-01, -1.7218e+00]],
│   │               
│   │                        [[ 3.7243e-01,  4.4061e-01,  2.2152e-01,  ...,  1.4713e+00,
│   │                           1.7862e+00, -1.7048e+00],
│   │                         [-4.4005e-01, -3.5211e-01,  4.8723e-01,  ...,  2.1826e+00,
│   │                          -1.0280e+00,  1.9704e-01],
│   │                         [ 5.1414e-01,  4.3469e-01, -4.7844e-01,  ..., -2.4866e-01,
│   │                          -6.6008e-01,  1.9696e-01],
│   │                         ...,
│   │                         [-1.5020e-01,  1.3227e+00, -1.1274e+00,  ..., -8.3746e-03,
│   │                          -4.9178e-01, -1.5900e-01],
│   │                         [ 2.7750e-01,  1.8937e+00,  2.5466e-01,  ...,  1.5497e+00,
│   │                           1.1546e+00,  2.2999e-01],
│   │                         [-1.2336e+00, -3.9497e-01,  1.9618e-01,  ...,  8.7128e-01,
│   │                           9.9361e-01,  4.9353e-01]]],
│   │               
│   │               
│   │                       [[[ 4.1093e-01, -2.7061e-01,  3.0286e-01,  ..., -5.1303e-02,
│   │                          -3.1289e-01, -9.2387e-01],
│   │                         [ 5.8699e-01,  1.7289e+00,  2.1932e+00,  ...,  5.1016e-01,
│   │                           1.1350e+00, -1.1008e+00],
│   │                         [ 1.2028e+00,  1.3375e+00,  2.4198e+00,  ...,  1.1021e+00,
│   │                           5.0144e-01,  9.3196e-01],
│   │                         ...,
│   │                         [ 2.1455e-01,  8.3995e-01,  8.4887e-01,  ..., -1.4144e+00,
│   │                          -1.0022e+00,  1.3298e+00],
│   │                         [-1.1129e+00,  8.8454e-01, -9.6536e-01,  ..., -5.7438e-01,
│   │                           7.3645e-01, -1.5947e-01],
│   │                         [-9.5826e-02,  2.1352e+00, -5.7948e-01,  ...,  1.5582e+00,
│   │                          -9.4527e-01,  9.2179e-01]],
│   │               
│   │                        [[ 3.8480e-01, -2.3764e+00, -1.4914e+00,  ..., -2.3150e-01,
│   │                           1.3181e+00, -1.3938e+00],
│   │                         [ 3.0355e+00, -2.6426e-02,  6.1563e-01,  ...,  1.3988e+00,
│   │                          -1.8974e+00,  1.0663e+00],
│   │                         [ 1.5087e+00,  1.6750e+00, -9.3298e-01,  ...,  2.1275e+00,
│   │                          -2.0282e-01,  9.5458e-01],
│   │                         ...,
│   │                         [-5.0105e-01,  7.3980e-01,  3.1559e-01,  ...,  1.9789e-01,
│   │                          -2.1448e-01,  3.7209e-01],
│   │                         [ 1.1327e+00, -2.4825e-01, -7.7777e-01,  ..., -1.0147e+00,
│   │                           3.1023e-01, -6.9260e-01],
│   │                         [ 4.2043e-01,  1.3579e-01,  4.1384e-01,  ..., -6.8451e-01,
│   │                          -2.3869e-01,  9.2342e-01]],
│   │               
│   │                        [[ 1.5819e+00,  2.3846e-02,  1.5621e+00,  ...,  3.6346e+00,
│   │                           1.0842e+00,  4.9328e-01],
│   │                         [ 7.9789e-01,  2.2789e-01, -4.1591e-01,  ..., -5.8387e-02,
│   │                           7.5209e-01, -5.4905e-01],
│   │                         [ 8.7766e-01,  5.3516e-01, -2.2979e-01,  ...,  7.6539e-01,
│   │                          -3.7189e-01,  1.1579e-01],
│   │                         ...,
│   │                         [ 3.8239e-01,  1.4717e+00, -8.6206e-01,  ..., -6.4551e-01,
│   │                           1.1295e+00, -1.4083e+00],
│   │                         [-1.9115e+00,  1.8455e-01, -5.7500e-01,  ...,  3.1396e+00,
│   │                          -7.3549e-01, -2.0597e+00],
│   │                         [ 7.1369e-01,  4.2059e-01, -3.2238e-01,  ..., -6.1477e-02,
│   │                          -9.2156e-01, -8.6877e-01]]],
│   │               
│   │               
│   │                       [[[-3.0921e-03, -6.0142e-01,  1.6273e+00,  ..., -7.6559e-02,
│   │                          -1.2123e-01, -6.8688e-01],
│   │                         [ 5.2613e-01, -1.0123e+00, -1.8765e+00,  ..., -6.0181e-01,
│   │                           5.4449e-01,  3.0890e-01],
│   │                         [ 6.6949e-01,  1.5806e-01,  9.1377e-01,  ..., -1.0309e+00,
│   │                           1.1952e+00, -3.5461e-01],
│   │                         ...,
│   │                         [ 2.3458e-01,  6.0245e-01, -4.4742e-01,  ..., -1.6361e+00,
│   │                          -5.2046e-01,  1.4693e-01],
│   │                         [-7.4344e-01,  4.0334e-01, -3.9413e-01,  ...,  2.9190e+00,
│   │                           8.6533e-01, -6.4462e-01],
│   │                         [ 1.4130e+00, -5.7617e-01, -8.3556e-01,  ..., -3.6027e-01,
│   │                           2.7963e-01,  1.0226e+00]],
│   │               
│   │                        [[-1.2985e+00,  2.9876e-01,  1.4098e+00,  ...,  5.8165e-01,
│   │                           4.9399e-01,  2.3263e-01],
│   │                         [-6.1401e-01, -6.8176e-01, -1.2204e-02,  ...,  1.3359e+00,
│   │                           5.9817e-01,  2.6285e-01],
│   │                         [ 1.7404e+00, -5.7152e-01, -4.4089e-01,  ..., -5.1100e-01,
│   │                          -3.1897e-01,  7.9213e-01],
│   │                         ...,
│   │                         [-1.0424e+00, -4.6219e-01, -1.9417e+00,  ...,  3.3510e-01,
│   │                           9.2694e-01, -1.2596e+00],
│   │                         [-1.4249e+00,  5.0887e-01, -3.7231e-01,  ...,  8.5508e-01,
│   │                           1.9456e+00,  8.1086e-01],
│   │                         [ 4.6129e-01,  1.4800e+00,  8.2697e-01,  ...,  1.2298e+00,
│   │                           1.2942e+00, -7.5299e-01]],
│   │               
│   │                        [[-4.2197e-02,  1.0379e+00, -3.7286e-01,  ..., -3.3648e-01,
│   │                           1.3749e+00, -1.2164e+00],
│   │                         [-8.7268e-01, -2.4208e-01, -1.5337e+00,  ..., -5.4504e-01,
│   │                          -9.0323e-01, -3.8476e-01],
│   │                         [-3.2578e-01, -5.6141e-01,  5.5550e-01,  ...,  1.9809e-01,
│   │                           9.1609e-02,  3.1676e-01],
│   │                         ...,
│   │                         [-2.9798e-02, -4.5428e-01, -1.7056e-01,  ..., -2.3765e-01,
│   │                           1.0142e+00,  1.1245e+00],
│   │                         [-1.3569e+00,  1.8312e+00,  5.5271e-01,  ...,  1.2056e+00,
│   │                           6.4420e-01, -1.8018e+00],
│   │                         [-7.1896e-02, -3.0519e-01, -1.7317e+00,  ..., -1.9754e-01,
│   │                          -2.9306e-01,  4.7570e-01]]]])
│   └── 'scalar' --> tensor([[-1.5883, -0.1771, -0.2445, -0.3159,  0.3222, -0.7895,  0.2759, -0.3163,
│                             -1.5060, -0.3449,  1.2375,  1.7028],
│                            [-0.7163, -0.9695, -0.2881,  1.2854,  0.5681,  0.0960,  0.0139,  0.5039,
│                             -0.3532, -0.7387,  1.2294,  0.1023],
│                            [-0.8173, -0.8824,  0.5326, -1.1704,  0.6222,  0.8051, -0.1758, -2.0343,
│                             -1.5388,  0.0624, -1.6759, -0.1480],
│                            [-0.2716, -0.9561, -0.1131, -0.7782,  0.4266, -0.7936,  0.1344, -1.7490,
│                             -1.0896,  0.5602, -0.1041, -0.8478]])
└── 'reward' --> tensor([[0.4802],
                         [0.9152],
                         [0.2030],
                         [0.1367]])

This code looks much simpler and clearer.