Stack Structured Data

When tensors are organized into tree structures, they often need to be stacked together, just as torch.stack() does for plain tensors.

Stack With Native PyTorch API

Here is a typical implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

# Batch size: the number of sample items that are generated and then stacked.
B = 4


def get_item():
    """Create one random sample as a nested dict.

    Returns:
        dict: A mapping with the keys ``obs``, ``action``, ``reward`` and
        ``done``. ``obs`` is itself a dict holding a ``scalar`` tensor of
        shape ``(12,)`` and an ``image`` tensor of shape ``(3, 32, 32)``;
        ``action`` is an integer tensor of shape ``(1,)`` drawn from
        ``[0, 10)``; ``reward`` is a tensor of shape ``(1,)`` from
        ``torch.rand``; ``done`` is the plain Python value ``False``.
    """
    # The random draws are made in the same order as the fields are listed.
    observation = dict(
        scalar=torch.randn(12),
        image=torch.randn(3, 32, 32),
    )
    sample = {'obs': observation}
    sample['action'] = torch.randint(0, 10, size=(1,))
    sample['reward'] = torch.rand(1)
    sample['done'] = False
    return sample


# Build a batch of B samples; every sample is a nested dict with the same keys.
data = [get_item() for _ in range(B)]


# execute `stack` op
def stack(data, dim):
    """Recursively stack a list of identically-structured samples.

    The type of the first element decides how the whole list is handled:
    tensors go through ``torch.stack``, dicts are processed key by key, and
    Python bools are gathered into one boolean tensor.

    Args:
        data: Non-empty list of samples; all of them are expected to share
            the structure of ``data[0]``.
        dim: Dimension passed to ``torch.stack`` for tensor leaves. It is
            not used for bool leaves.

    Returns:
        A stacked tensor, a dict of stacked values, or a boolean tensor,
        mirroring the structure of the first sample.

    Raises:
        TypeError: If the first element is not a tensor, dict or bool.
    """
    first = data[0]

    if isinstance(first, torch.Tensor):
        return torch.stack(data, dim)

    if isinstance(first, dict):
        # Gather the values stored under each key and stack them separately.
        merged = {}
        for key in first.keys():
            column = [sample[key] for sample in data]
            merged[key] = stack(column, dim)
        return merged

    if isinstance(first, bool):
        return torch.BoolTensor(data)

    raise TypeError("not support elem type: {}".format(type(first)))


# Stack the B samples along a new leading dimension (dim=0).
stacked_data = stack(data, dim=0)
# validate
print(stacked_data)
# Every tensor leaf gains a leading batch dimension of size B.
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
# The Python bools are collected into a single 1-D boolean tensor.
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertion statements should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
{'obs': {'scalar': tensor([[-2.4885, -0.1894,  1.0505, -0.5696, -1.3865,  1.2687, -2.2627, -1.4003,
          0.2217, -0.4253,  0.6571,  0.2556],
        [-1.6167,  0.5180, -1.9645,  0.6586,  1.7927, -0.5265,  0.3769, -0.0284,
         -0.2526,  0.3659,  0.7190,  1.1841],
        [-0.7709, -0.3557,  1.0179,  0.1447,  0.9058, -0.3972,  0.9837,  0.1803,
          1.1207, -0.4273,  0.0616, -1.5844],
        [ 0.5991, -0.7407,  0.4536,  0.1069,  0.8117, -1.4940,  0.7157, -0.3310,
         -0.6182, -1.6858,  1.8456,  0.1732]]), 'image': tensor([[[[ 0.4978, -1.2501, -0.2585,  ...,  0.9171, -0.9232,  0.4134],
          [-0.8668, -1.1759,  0.8838,  ...,  0.7510, -0.6181,  0.0936],
          [-0.8729,  1.5218,  2.2357,  ..., -1.7105, -0.3870, -0.7484],
          ...,
          [-0.0285, -0.7611,  0.6748,  ...,  1.0388,  0.2471, -2.2912],
          [ 1.3451,  0.2561,  0.4394,  ...,  0.9724, -0.2808, -2.1780],
          [-0.1965,  0.3809,  1.3289,  ...,  0.1252,  0.6652, -0.8572]],

         [[ 0.4101,  1.1065, -0.2862,  ..., -1.2804, -1.5356, -0.0364],
          [-1.5546, -1.0120,  0.9758,  ..., -0.2480, -0.9406, -0.2451],
          [ 0.8384, -0.8384,  2.3684,  ...,  1.2431,  0.0674,  0.1893],
          ...,
          [ 0.2290,  0.2318, -0.3985,  ..., -2.6407,  1.6062,  1.0113],
          [-1.1343,  0.8138, -1.0180,  ...,  1.2344,  0.4174,  0.2905],
          [ 0.0656, -0.3539,  1.5355,  ..., -1.8539, -0.7397, -1.6315]],

         [[ 0.4509,  0.9131,  0.9687,  ..., -0.6503, -0.3681, -0.8998],
          [ 0.1032,  0.2697, -0.9172,  ...,  0.4166,  1.5272,  1.1800],
          [ 1.4287,  1.1963,  1.0067,  ...,  0.8998, -0.0643,  0.5635],
          ...,
          [-0.3159,  0.3187, -1.0065,  ...,  0.3267, -1.1541,  1.6143],
          [-1.3905, -0.6715,  0.2177,  ..., -0.0546, -0.1315, -0.0987],
          [ 1.0319,  0.8301, -0.5162,  ...,  0.0639, -0.1456,  1.0828]]],


        [[[-0.0918,  0.0803,  2.6462,  ..., -0.7681,  1.2222, -0.8572],
          [-0.5196,  1.3421, -0.2328,  ..., -1.0880,  0.7320,  2.3672],
          [ 0.5456,  0.9193, -0.2553,  ...,  1.5758,  0.6394,  0.0414],
          ...,
          [ 0.8551,  0.0557, -0.2813,  ...,  0.9962,  0.0257, -0.7037],
          [-0.0888,  0.0296,  0.0246,  ...,  0.4145, -2.0659, -0.3468],
          [-0.6369, -2.5740, -0.4460,  ..., -0.2904,  0.2416,  0.6949]],

         [[-1.4777, -0.2630, -1.2332,  ..., -0.0765,  0.1558,  0.5217],
          [-0.2586, -1.8858, -1.6222,  ..., -0.2315, -0.1977, -0.8328],
          [-1.3002, -0.0452, -0.5856,  ...,  0.2939,  0.1649, -1.0120],
          ...,
          [-0.4112, -0.3250, -0.6504,  ..., -0.6918,  0.9288,  0.0184],
          [-0.5738,  0.0869, -0.8330,  ..., -1.0286, -1.2463,  0.2466],
          [-0.4157, -0.1999,  0.3587,  ...,  0.1341,  1.1860,  0.3728]],

         [[ 0.7094, -0.5768,  0.9085,  ...,  0.7164,  0.6091, -0.4209],
          [ 1.4925, -0.7576, -0.1510,  ...,  1.6984, -0.7491,  0.1250],
          [ 1.4881,  0.3677,  0.9488,  ...,  0.3255, -0.0304, -0.2892],
          ...,
          [ 0.5784,  0.6722, -0.4441,  ..., -0.2097, -0.8890,  0.7273],
          [ 0.8981, -0.8770,  0.0075,  ..., -1.1683,  0.1248, -0.5278],
          [-0.1051, -0.0663, -0.9547,  ..., -1.8686, -0.3838,  0.9184]]],


        [[[ 0.0277,  1.7713,  0.4556,  ...,  0.8780, -0.4100, -2.4852],
          [-0.1821,  0.4623, -0.7123,  ..., -0.0690,  1.4253,  0.2717],
          [-0.4959, -0.1001, -0.2545,  ...,  0.4245, -0.2398,  0.0861],
          ...,
          [ 1.6672,  0.2831, -0.7439,  ..., -0.1039, -1.4552, -0.2802],
          [-0.6393,  0.6663, -1.2996,  ...,  0.2319, -0.8950,  0.0110],
          [-0.7170, -1.1117, -0.4357,  ..., -1.7127, -0.4674, -0.6336]],

         [[ 0.9503,  0.1914,  0.9174,  ...,  1.1520, -1.8099,  0.3725],
          [-0.8349, -0.2042, -0.1120,  ..., -0.0283,  0.3038, -2.7491],
          [ 0.1278,  0.0234, -0.4974,  ...,  0.6057, -0.3488,  0.8233],
          ...,
          [-0.6157,  1.0636,  0.0954,  ..., -0.0259, -0.5212,  1.4392],
          [ 0.8266, -0.2745, -1.1778,  ..., -0.5911, -0.5791, -1.6650],
          [ 1.0428,  1.6942, -0.4652,  ..., -0.1240, -0.4113,  0.0909]],

         [[ 0.1097,  0.8598,  1.0994,  ...,  2.6004,  0.5754, -0.0178],
          [-0.4325, -0.0482,  0.1976,  ...,  2.3090,  1.6262,  0.0279],
          [-0.6250,  1.0215,  1.3977,  ..., -0.3246,  0.9411,  0.2160],
          ...,
          [-0.2375,  0.7877,  1.8226,  ..., -0.6274,  0.1500,  0.5809],
          [ 0.0385, -0.4142,  1.3438,  ..., -0.1054,  0.3152,  0.8567],
          [-1.0169, -0.7359, -0.7361,  ...,  0.2474,  0.1202,  1.1428]]],


        [[[-1.0234, -1.3896,  0.3372,  ...,  0.0209, -0.2851, -0.1365],
          [ 1.0168, -1.0918, -0.1256,  ..., -0.6947, -0.7739,  0.8017],
          [-0.3967,  1.7463,  0.3932,  ..., -0.1429,  0.2764,  0.5867],
          ...,
          [ 1.4034,  0.4194,  0.1746,  ..., -3.3174, -1.5328,  0.0111],
          [-1.0120, -0.0721, -0.5863,  ..., -0.4522,  0.0414,  2.2797],
          [-0.0784, -0.8473,  2.1855,  ...,  0.3322, -0.9017, -0.8769]],

         [[-0.5934,  0.1361, -0.7575,  ...,  0.4584, -1.3004,  0.4093],
          [ 0.9002, -0.8695,  0.5391,  ...,  0.5897,  0.5335, -0.2362],
          [-1.0964, -0.4823,  0.0381,  ..., -1.9960,  1.9149, -0.5002],
          ...,
          [-1.6471,  0.9005,  2.4646,  ..., -2.0139,  0.1132,  1.0323],
          [-0.7244,  1.2568,  0.1826,  ..., -0.5564,  0.3448, -0.8192],
          [-0.8679, -0.5046, -0.5917,  ..., -0.1500,  1.0475, -0.2927]],

         [[-0.4847,  0.0867, -0.4310,  ..., -0.5368, -0.2088,  1.1243],
          [-1.4942, -0.4517, -0.6754,  ..., -1.5830,  0.1005, -0.0155],
          [-0.4695,  0.3367, -0.5569,  ..., -0.7086,  0.4225,  0.2779],
          ...,
          [ 1.2799, -0.6381,  0.3642,  ...,  1.4593,  0.7397,  0.0202],
          [ 0.3282,  0.7489, -0.1701,  ..., -0.0997,  2.1312, -0.3191],
          [-1.0411,  1.1142, -0.0842,  ...,  1.2986, -0.2909, -1.1934]]]])}, 'action': tensor([[0],
        [4],
        [2],
        [3]]), 'reward': tensor([[0.7910],
        [0.0778],
        [0.6917],
        [0.9253]]), 'done': tensor([False, False, False, False])}

We can see that the function that processes the structure (such as the stack function above) has to be implemented in full by hand. This code is not very clear, and because the supported types are hard-coded, supporting more data types (such as integers) requires making special modifications to the function.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown in the code below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

# Batch size: the number of sample items that are generated and then stacked.
B = 4


def get_item():
    """Create one random sample as a nested dict.

    Returns:
        dict: A mapping with the keys ``obs``, ``action``, ``reward`` and
        ``done``. ``obs`` is itself a dict holding a ``scalar`` tensor of
        shape ``(12,)`` and an ``image`` tensor of shape ``(3, 32, 32)``;
        ``action`` is an integer tensor of shape ``(1,)`` drawn from
        ``[0, 10)``; ``reward`` is a tensor of shape ``(1,)`` from
        ``torch.rand``; ``done`` is the plain Python value ``False``.
    """
    # The random draws are made in the same order as the fields are listed.
    observation = dict(
        scalar=torch.randn(12),
        image=torch.randn(3, 32, 32),
    )
    sample = {'obs': observation}
    sample['action'] = torch.randint(0, 10, size=(1,))
    sample['reward'] = torch.rand(1)
    sample['done'] = False
    return sample


# Build a batch of B samples; every sample is a nested dict with the same keys.
data = [get_item() for _ in range(B)]
# execute `stack` op
# Convert each nested dict with ttorch.tensor first, so that one ttorch.stack
# call handles the whole tree.
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate
print(stacked_data)
# Nested fields of the stacked result are read through attribute access.
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
# The Python bools end up as a single 1-D boolean tensor.
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertion statements should pass here as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
<Tensor 0x7fed44eae1f0>
├── 'action' --> tensor([[2],
│                        [3],
│                        [0],
│                        [7]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7fed457f3b20>
│   ├── 'image' --> tensor([[[[ 2.0416e+00, -5.9925e-01, -1.2449e+00,  ...,  3.0378e-01,
│   │                           8.0941e-01,  2.5898e-02],
│   │                         [ 1.2635e-01,  7.6078e-01, -8.3621e-01,  ..., -2.2284e+00,
│   │                           2.2912e-01, -4.8490e-01],
│   │                         [ 6.8345e-01, -3.1160e-01, -3.1799e-01,  ...,  5.5821e-01,
│   │                           8.5515e-01, -1.8264e+00],
│   │                         ...,
│   │                         [-2.5850e-01, -8.2596e-01, -1.2165e+00,  ..., -2.4003e+00,
│   │                           6.4898e-01,  3.4618e-01],
│   │                         [ 1.0162e+00, -5.7869e-01,  3.5027e-02,  ..., -5.3370e-01,
│   │                           3.7601e-01,  1.2790e+00],
│   │                         [ 2.4595e-01, -3.2430e-01, -1.4574e+00,  ...,  8.2505e-01,
│   │                          -1.3795e+00,  1.1111e+00]],
│   │               
│   │                        [[ 1.9650e+00,  3.7545e-01, -5.6306e-02,  ..., -1.5407e+00,
│   │                          -1.2354e+00, -2.0725e+00],
│   │                         [ 7.2166e-01,  6.2287e-01,  1.5451e+00,  ..., -5.2576e-01,
│   │                          -3.3847e-01,  5.4572e-01],
│   │                         [ 3.8211e-01, -1.8942e+00,  1.2981e-01,  ..., -1.1306e+00,
│   │                           8.2827e-01, -1.4117e+00],
│   │                         ...,
│   │                         [-3.0986e-01,  8.3433e-01, -9.1089e-01,  ...,  1.8090e+00,
│   │                           1.4611e+00, -9.0840e-01],
│   │                         [ 4.1999e-01, -7.6364e-01, -3.1119e-01,  ...,  3.5400e-01,
│   │                          -1.4407e-01,  3.3688e-01],
│   │                         [-9.2383e-01,  9.1741e-02,  9.1685e-01,  ...,  1.1685e+00,
│   │                           1.4084e+00,  2.0603e+00]],
│   │               
│   │                        [[ 1.7175e+00, -8.3146e-01,  5.8421e-01,  ..., -2.3445e-01,
│   │                           6.9048e-01,  1.5010e+00],
│   │                         [ 1.1013e-01,  1.6321e+00,  3.2406e-01,  ...,  2.5666e-01,
│   │                          -7.1023e-01,  6.3358e-01],
│   │                         [ 2.7281e-01,  2.9113e-02,  1.5413e+00,  ...,  2.5555e-01,
│   │                          -4.0019e-01, -8.1130e-01],
│   │                         ...,
│   │                         [ 1.0454e+00,  1.0125e+00,  1.1726e+00,  ..., -1.9043e-01,
│   │                           9.8482e-01,  2.3942e+00],
│   │                         [-1.1372e+00,  3.0568e-01,  1.0446e+00,  ..., -1.6331e+00,
│   │                           6.9946e-01, -3.7304e-01],
│   │                         [ 1.8728e+00,  5.0260e-01,  6.7302e-01,  ..., -5.2926e-01,
│   │                          -3.4716e-01, -8.8449e-01]]],
│   │               
│   │               
│   │                       [[[ 7.5380e-01,  1.5068e+00, -1.2491e+00,  ..., -4.2131e-01,
│   │                          -2.0570e+00, -1.4849e-01],
│   │                         [-1.2820e-01, -1.6323e+00,  3.9871e-01,  ..., -7.8880e-01,
│   │                           2.5858e+00,  8.6370e-01],
│   │                         [ 9.6607e-01, -1.1814e+00,  5.1013e-01,  ...,  1.6661e+00,
│   │                          -1.8071e+00, -5.7397e-01],
│   │                         ...,
│   │                         [ 4.7556e-01, -5.8294e-01, -4.6873e-01,  ...,  6.2887e-02,
│   │                          -2.0871e+00,  6.2909e-01],
│   │                         [ 1.1660e+00, -1.8203e+00,  7.3137e-01,  ...,  9.0955e-01,
│   │                           8.7783e-01,  5.2051e-01],
│   │                         [ 1.0795e+00, -7.8222e-01, -5.2214e-01,  ...,  1.0444e+00,
│   │                           1.1908e+00, -9.4662e-01]],
│   │               
│   │                        [[-1.3826e+00,  3.5153e-01, -6.5782e-01,  ...,  7.8405e-01,
│   │                           1.5605e+00,  8.3677e-01],
│   │                         [ 3.5947e-01,  2.1041e-01,  3.7396e-01,  ...,  1.9038e+00,
│   │                          -3.2423e-01,  9.3666e-01],
│   │                         [ 7.4932e-01, -5.9710e-01, -1.8353e+00,  ...,  2.2123e+00,
│   │                          -6.8229e-01, -3.4608e-01],
│   │                         ...,
│   │                         [ 1.0183e+00,  1.3325e+00,  6.6054e-02,  ..., -1.4577e-01,
│   │                           6.3652e-01,  9.8351e-01],
│   │                         [ 2.2618e+00,  1.2337e+00,  1.7791e+00,  ...,  6.6020e-02,
│   │                           3.3089e+00,  3.9350e-01],
│   │                         [-2.2728e+00,  7.3721e-02,  1.1925e+00,  ...,  3.0889e-01,
│   │                           1.0032e-01, -2.4847e-01]],
│   │               
│   │                        [[ 9.7788e-01, -5.3582e-02, -1.1819e+00,  ..., -9.4883e-02,
│   │                           1.3618e+00, -1.8994e-01],
│   │                         [ 8.4481e-01, -5.4782e-01, -4.9596e-01,  ..., -4.1341e-01,
│   │                           7.2254e-01, -5.3765e-01],
│   │                         [-3.3592e-01,  2.1065e+00, -5.8536e-01,  ...,  1.1618e+00,
│   │                           4.9817e-01, -1.6466e+00],
│   │                         ...,
│   │                         [ 1.3575e+00,  2.4364e-01, -2.9475e-01,  ...,  2.5356e-01,
│   │                           9.8970e-01,  9.8877e-01],
│   │                         [-3.3773e-01, -7.7994e-01, -7.4169e-01,  ..., -3.0205e-01,
│   │                          -6.8611e-01,  9.5522e-01],
│   │                         [ 5.7925e-01,  1.1265e+00,  1.4228e+00,  ..., -5.3546e-01,
│   │                           7.0475e-01, -7.5621e-01]]],
│   │               
│   │               
│   │                       [[[-8.0855e-01,  2.3513e-01,  1.1527e+00,  ..., -9.8934e-01,
│   │                           7.5847e-01, -1.8872e-01],
│   │                         [ 1.8957e+00, -3.6922e-03, -1.4885e-01,  ...,  8.5574e-01,
│   │                           9.8840e-01, -2.1442e+00],
│   │                         [ 2.7207e-03,  2.2655e+00,  7.2555e-01,  ...,  1.0556e+00,
│   │                           6.7363e-01, -3.5936e-01],
│   │                         ...,
│   │                         [-8.7440e-01,  2.4647e-01,  2.0277e-01,  ...,  5.9892e-02,
│   │                           1.0237e+00, -1.8514e+00],
│   │                         [-4.8821e-01,  2.0842e-01,  1.3916e+00,  ..., -2.0656e+00,
│   │                           2.1805e+00, -7.5499e-01],
│   │                         [-8.4422e-01, -8.0240e-01,  1.5227e+00,  ..., -2.3180e-02,
│   │                          -7.4690e-01, -5.8966e-01]],
│   │               
│   │                        [[-1.2899e-01, -1.4767e+00, -6.3458e-01,  ...,  1.7832e+00,
│   │                           3.2020e-01,  8.7597e-03],
│   │                         [ 6.6648e-01, -2.7510e-01,  6.2410e-01,  ..., -2.2456e-02,
│   │                           1.6846e-01,  2.3513e-01],
│   │                         [-9.9967e-01, -9.4929e-01, -1.5338e-01,  ..., -5.0835e-01,
│   │                           5.9787e-01, -2.8127e-01],
│   │                         ...,
│   │                         [-1.7742e+00, -1.5492e+00, -1.7196e-01,  ...,  4.6514e-02,
│   │                           2.0010e-01, -9.3684e-01],
│   │                         [ 1.3208e+00,  4.5860e-01, -1.5585e-01,  ...,  1.1450e-01,
│   │                           1.1283e+00,  9.5827e-01],
│   │                         [-2.3822e-02, -1.9480e+00, -1.4155e+00,  ..., -3.8373e-01,
│   │                          -8.4783e-01,  8.4188e-01]],
│   │               
│   │                        [[-2.0908e-01,  4.9941e-01, -2.7140e-01,  ..., -1.4369e+00,
│   │                           5.0763e-01,  1.0217e+00],
│   │                         [-6.6292e-01,  4.5342e-01, -2.1183e+00,  ..., -2.5214e+00,
│   │                           1.2689e-01,  8.3478e-01],
│   │                         [-4.9568e-01,  3.3643e-01,  2.0124e-01,  ..., -1.1518e+00,
│   │                          -8.1403e-01,  3.1354e-01],
│   │                         ...,
│   │                         [-1.4396e+00, -2.6709e-01,  1.5399e-01,  ...,  5.1925e-01,
│   │                           2.7985e-01, -1.1648e+00],
│   │                         [ 7.6089e-01,  3.9733e-01, -3.4816e-01,  ...,  1.7702e+00,
│   │                          -1.5398e+00,  2.1652e-01],
│   │                         [-1.8386e-01,  3.8159e-01,  6.5463e-01,  ...,  6.1243e-02,
│   │                           4.6117e-01,  4.0854e-01]]],
│   │               
│   │               
│   │                       [[[ 7.5697e-01,  1.4649e-01,  7.1335e-01,  ...,  6.7925e-01,
│   │                           1.1384e-01,  1.0391e+00],
│   │                         [-1.1720e+00, -1.0798e+00,  2.0388e-01,  ...,  5.2457e-03,
│   │                           2.4723e-02,  6.5672e-01],
│   │                         [-3.5749e-01,  1.1721e+00,  2.3129e-01,  ...,  4.2506e-01,
│   │                           1.4432e+00, -8.9156e-01],
│   │                         ...,
│   │                         [-1.3140e+00,  1.8301e+00,  2.1734e-01,  ..., -4.4548e-02,
│   │                          -6.3429e-01, -3.3424e-01],
│   │                         [-9.5603e-01,  1.2492e+00,  1.4574e-01,  ...,  1.4821e+00,
│   │                          -4.8355e-01,  9.2583e-01],
│   │                         [ 5.6586e-01,  2.6708e-01,  1.4191e+00,  ..., -1.7003e+00,
│   │                          -2.1445e+00,  2.4752e-01]],
│   │               
│   │                        [[-1.3650e+00, -8.6744e-01,  2.4141e-01,  ..., -1.3019e+00,
│   │                           1.8370e+00,  1.7584e+00],
│   │                         [-8.2306e-02,  6.1359e-01, -4.6356e-01,  ..., -1.3901e+00,
│   │                          -4.8452e-01, -1.9530e-01],
│   │                         [ 6.2243e-01,  1.0755e+00,  1.1706e+00,  ...,  1.5467e+00,
│   │                          -2.3187e+00,  3.6055e-01],
│   │                         ...,
│   │                         [ 6.8893e-01, -1.0562e+00, -7.2646e-01,  ...,  3.4252e-01,
│   │                           4.0341e-02, -3.7157e-01],
│   │                         [-6.4628e-01,  4.9822e-01, -8.2966e-02,  ..., -2.4957e-01,
│   │                           4.9634e-01,  8.8230e-01],
│   │                         [ 2.0077e-01, -3.6408e-01, -7.9086e-01,  ...,  1.8752e+00,
│   │                          -1.4534e-01, -1.3230e+00]],
│   │               
│   │                        [[ 7.3649e-01, -1.3535e+00, -7.4957e-01,  ..., -2.0602e+00,
│   │                          -6.2822e-02, -1.2352e+00],
│   │                         [-3.8954e-01, -1.0413e+00,  7.5852e-01,  ...,  7.9690e-01,
│   │                          -2.6351e-01, -5.6950e-01],
│   │                         [ 6.4250e-01, -6.6629e-01,  7.0530e-01,  ..., -2.0414e-01,
│   │                          -1.9864e+00, -4.7203e-01],
│   │                         ...,
│   │                         [ 2.8049e-01,  1.3581e+00,  7.3969e-01,  ..., -4.3716e-01,
│   │                           1.6524e+00,  1.6888e+00],
│   │                         [-9.3943e-01,  2.8415e-01, -9.0094e-01,  ...,  2.7012e+00,
│   │                           2.8689e-01, -1.5171e+00],
│   │                         [-3.6873e-02, -1.9304e+00,  4.4248e-01,  ..., -3.0628e-01,
│   │                           2.8470e-02, -2.8992e-01]]]])
│   └── 'scalar' --> tensor([[-0.7870, -1.3864,  0.6002, -0.0935, -1.9169,  0.5229,  0.3602, -1.3059,
│                             -1.1352,  1.8609, -0.0462,  0.5435],
│                            [ 1.4911, -1.3666, -0.7389,  1.6854, -1.3274, -0.1017,  0.8480, -0.7213,
│                              0.0038, -0.6813, -2.1859, -0.7159],
│                            [-0.9157, -0.8866, -0.6136,  0.4199,  0.0781, -1.3129,  0.0330, -1.5632,
│                             -0.5328, -2.3863, -1.5438,  1.3180],
│                            [-1.0152, -2.2117, -1.2667,  1.4642,  3.0571,  2.4963, -1.0515, -0.6528,
│                             -0.0684,  1.5205,  1.8214, -0.5784]])
└── 'reward' --> tensor([[0.9162],
                         [0.4419],
                         [0.8313],
                         [0.9997]])

This code looks much simpler and clearer.