Stack Structured Data

When tensors are organized into tree structures, they often need to be stacked together, just as `torch.stack()` does for individual tensors in PyTorch.

Stack With Native PyTorch API

Here is a typical implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

# Batch size: how many samples are generated and then stacked together.
B = 4


def get_item():
    """Return one randomly generated sample as a nested dict.

    Layout (keys appear in this order):
        obs:
            scalar: float tensor of shape ``(12,)`` from ``torch.randn``.
            image:  float tensor of shape ``(3, 32, 32)`` from ``torch.randn``.
        action: integer tensor of shape ``(1,)`` drawn from ``[0, 10)``.
        reward: float tensor of shape ``(1,)`` drawn from ``[0, 1)``.
        done:   the plain Python bool ``False`` (not a tensor).
    """
    # Random draws happen in a fixed order (scalar, image, action, reward),
    # so seeding the global RNG reproduces the same sample.
    scalar = torch.randn(12)
    image = torch.randn(3, 32, 32)
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)
    return dict(
        obs=dict(scalar=scalar, image=image),
        action=action,
        reward=reward,
        done=False,
    )


data = [get_item() for _sample in range(B)]


# Hand-written recursive `stack` for nested dicts whose leaves are tensors or bools.
def stack(data, dim):
    """Stack a list of identically-structured samples into one batched structure.

    The structure of ``data[0]`` drives the recursion, so every element of
    ``data`` must share its layout:

    * ``torch.Tensor`` leaves are combined with ``torch.stack(data, dim)``.
    * ``dict`` nodes are stacked key by key, using the keys of ``data[0]``.
    * ``bool`` leaves become a 1-D ``torch.bool`` tensor (``dim`` is not used).

    Any other leaf type is rejected. Supporting a new type (e.g. ``int``)
    requires adding another branch here.

    Args:
        data: non-empty list of samples with the same nested structure.
        dim: dimension along which tensor leaves are stacked.

    Returns:
        A structure mirroring ``data[0]`` whose leaves are batched tensors.

    Raises:
        TypeError: if a leaf has a type other than the ones listed above.
        IndexError: if ``data`` is empty.
        KeyError: if a later sample lacks a key present in ``data[0]``.
    """
    elem = data[0]  # the first sample decides how the whole list is handled
    if isinstance(elem, torch.Tensor):
        return torch.stack(data, dim)
    elif isinstance(elem, dict):
        return {k: stack([item[k] for item in data], dim) for k in elem}
    elif isinstance(elem, bool):
        # torch.tensor(..., dtype=...) is preferred over the legacy
        # torch.BoolTensor type constructor.
        return torch.tensor(data, dtype=torch.bool)
    else:
        raise TypeError(f"not support elem type: {type(elem)}")


stacked_data = stack(data, dim=0)
# validate: every leaf gains a leading batch dimension of size B
print(stacked_data)
assert stacked_data['obs']['scalar'].shape == (B, 12)
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all the assertions pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
{'obs': {'scalar': tensor([[-0.5439,  1.5525, -2.5364,  0.4296,  0.2306, -1.6162, -0.5098, -0.3007,
         -0.9911, -0.1652, -1.4215,  1.7829],
        [-0.3130,  0.8931,  1.0154,  1.5778,  0.6681,  0.9067,  0.3437, -0.7323,
          1.2826, -1.3228, -0.3680,  1.1760],
        [-0.5315,  0.0356,  0.2379,  0.6581, -0.5206, -0.2500,  0.1329, -0.0054,
         -0.6031, -0.9142,  0.1459,  0.2512],
        [-1.9744, -0.4854,  1.5225,  0.9087, -0.3122,  1.0963, -0.8389,  0.7343,
          0.3420, -0.3607,  1.2227,  0.2868]]), 'image': tensor([[[[-1.0355, -2.3575,  0.7406,  ..., -0.4346,  1.4851, -0.8481],
          [ 1.2878, -0.7349, -0.3034,  ...,  0.7290,  0.8731,  0.0377],
          [-0.8369, -0.3971, -2.1880,  ..., -1.7327, -1.6128, -0.0922],
          ...,
          [ 1.0049, -1.3276,  1.2966,  ...,  0.4423,  0.3873, -1.4188],
          [ 0.9000, -0.6823,  1.5857,  ...,  3.1329,  1.5153,  0.6293],
          [-0.2699, -0.6091,  0.5405,  ...,  0.6354,  0.0171, -0.3068]],

         [[ 1.8242, -1.4873,  1.7514,  ...,  0.4835,  1.4030, -1.1831],
          [ 0.0241, -0.1789, -1.2658,  ...,  0.8562, -0.7430, -0.9667],
          [ 1.4516,  0.9898,  0.6914,  ..., -0.1246,  0.7257, -0.0312],
          ...,
          [ 1.0995,  0.9870, -0.5597,  ..., -0.5211, -1.7528,  0.6516],
          [ 0.0648,  2.3567, -0.8047,  ...,  0.5057, -2.7344,  0.3598],
          [-1.3461, -0.5811, -0.9836,  ...,  0.8184,  0.0543,  0.5045]],

         [[ 0.6419,  0.5425,  1.0155,  ..., -0.0896, -0.0247,  0.2880],
          [-0.0685, -1.2104, -1.0465,  ...,  1.9754,  0.4169,  2.7332],
          [ 1.2864,  0.4891, -1.0690,  ..., -1.3406, -0.7768,  0.3359],
          ...,
          [-0.1180, -0.4897,  0.1138,  ..., -1.2769,  0.6272,  0.3767],
          [ 0.4202, -2.0327, -0.4379,  ...,  0.3923,  1.4139, -0.6590],
          [ 0.3935, -0.9722,  0.7480,  ...,  0.9070, -1.8967,  0.5447]]],


        [[[-0.4684,  0.4955, -0.7595,  ...,  1.0047, -0.7355,  0.8434],
          [ 0.5379, -1.1372,  2.5876,  ..., -1.2932,  1.7314,  0.1961],
          [-1.0318, -0.3406,  0.0349,  ...,  1.6182, -0.6486, -0.1076],
          ...,
          [ 0.7020,  0.0742,  0.1570,  ...,  1.6998,  1.2663, -0.3665],
          [ 0.3755,  0.1074, -0.4854,  ...,  0.9299, -1.2508,  0.5001],
          [-0.2960,  0.6556, -1.6620,  ...,  0.1858,  0.3402, -0.2965]],

         [[-0.3216, -0.1692, -1.7769,  ..., -1.2512,  0.3778, -0.9057],
          [ 0.1462, -1.0335,  2.6344,  ...,  1.6624,  1.2709, -0.5711],
          [-1.1572,  0.9810,  1.2451,  ..., -0.9154,  0.8500,  0.5509],
          ...,
          [-0.1871,  1.3897, -0.6686,  ...,  0.0512, -0.3331, -1.5550],
          [ 1.8065, -0.1077, -0.2139,  ...,  0.6769, -0.7738,  0.5418],
          [-0.8979,  2.5827,  1.1243,  ..., -1.0532,  0.6558,  0.2633]],

         [[ 0.3315, -0.5877,  1.8886,  ..., -0.3790, -0.5413,  1.3775],
          [ 0.0484,  0.4370,  0.6870,  ...,  1.4169,  2.1045,  0.1620],
          [ 0.5599,  0.4293, -0.9194,  ...,  0.2937, -0.6722, -1.7155],
          ...,
          [-2.1214,  0.0862, -0.5391,  ..., -1.2697, -0.3438, -0.4146],
          [ 1.2405, -0.2034,  0.9086,  ...,  0.5504,  1.6138,  2.1500],
          [-0.3980,  0.5023, -0.7237,  ..., -0.6129, -0.6409,  0.1773]]],


        [[[ 0.8545, -1.1236, -0.1327,  ...,  0.5171,  0.8301,  1.6672],
          [-0.3385,  2.3121, -0.4159,  ..., -0.5434, -1.1837, -1.3264],
          [ 0.2922,  1.9178, -1.4020,  ...,  0.7444, -0.8822, -1.3208],
          ...,
          [ 0.9485,  0.6859, -0.0648,  ..., -0.4033, -0.0648,  1.1805],
          [ 0.0881,  0.3743,  0.6833,  ...,  0.2128, -0.3396, -1.0739],
          [-1.0450,  1.0900, -0.2487,  ...,  1.6288, -0.2440,  0.7266]],

         [[-0.4752, -0.6603, -2.3606,  ...,  1.0126,  0.8204,  0.4657],
          [ 0.4427,  0.5316, -0.3936,  ...,  1.0050, -1.5727,  2.3134],
          [ 0.2490, -0.9648, -0.3544,  ..., -1.4491, -1.4617, -0.1078],
          ...,
          [-1.6874, -1.4677, -0.1728,  ...,  0.5630, -1.3154, -1.4742],
          [ 0.7241,  1.3261, -0.1397,  ...,  1.1315,  0.7148, -1.7773],
          [-0.1018,  0.1336, -0.0907,  ...,  0.7593,  0.7086, -0.6521]],

         [[-0.3958, -2.0302, -0.4027,  ..., -0.1296,  1.0292,  1.1047],
          [ 1.7339, -1.1807,  0.7617,  ..., -0.5440,  0.5825,  0.3263],
          [ 1.3539, -0.3701, -0.1236,  ...,  0.7388, -1.2894,  2.4045],
          ...,
          [-1.7638, -0.2596, -1.1081,  ..., -0.7817,  0.9688,  1.6874],
          [-0.8217,  0.3193,  1.3100,  ..., -0.6221, -0.4938, -0.4989],
          [ 0.9137, -0.0173,  0.5279,  ..., -0.5549, -1.7686, -0.6575]]],


        [[[ 2.2258,  0.2731, -0.8125,  ...,  1.5199, -0.4135, -0.0206],
          [-1.3395, -1.1469,  0.5421,  ..., -0.5712,  0.0421,  1.6956],
          [-0.4454, -1.1940,  0.3461,  ..., -0.2274,  0.4773,  0.8016],
          ...,
          [-0.8675, -0.4071,  0.1694,  ...,  0.1774,  0.7157, -0.2496],
          [ 0.2533,  0.2742, -0.1155,  ..., -0.9778,  1.9997, -0.3028],
          [-0.3587, -0.2820,  0.7471,  ...,  0.0123, -0.0397,  0.1725]],

         [[ 0.7101,  0.8042,  1.1798,  ..., -0.3023,  0.5312, -0.4634],
          [ 0.7873, -0.8710,  1.0618,  ..., -0.1434, -0.4286, -0.2138],
          [ 0.9074, -2.2004, -0.2696,  ...,  0.0903, -1.9921, -2.3305],
          ...,
          [-0.9613, -0.9245,  1.2956,  ..., -1.0975, -0.0289,  1.3185],
          [-1.9933,  1.9845,  0.2679,  ..., -1.2519,  0.1491, -1.1180],
          [-0.4625,  0.2501, -0.9873,  ..., -1.2893, -0.2134, -0.8583]],

         [[ 1.0140, -1.0176,  1.6116,  ...,  1.0195,  0.1082,  0.2236],
          [-0.6933,  0.6982, -0.7072,  ...,  1.2718,  0.6256,  1.9671],
          [-1.0142,  1.3395,  0.4521,  ..., -0.9776,  0.1666, -0.3324],
          ...,
          [ 0.4892,  0.2550, -0.5979,  ..., -0.8815, -1.1647,  1.3134],
          [ 0.1590, -1.5928, -1.0149,  ..., -0.4396, -0.6207, -2.0290],
          [-0.4404, -1.1780, -0.3131,  ...,  0.2115, -0.9426, -0.6174]]]])}, 'action': tensor([[0],
        [5],
        [5],
        [2]]), 'reward': tensor([[0.6205],
        [0.0102],
        [0.0746],
        [0.1737]]), 'done': tensor([False, False, False, False])}

As shown above, the structure-processing function (here, `stack`) has to be implemented entirely by hand. The code is not very clear, and because the supported types are hard-coded, supporting additional data types (such as integers) requires modifying the function itself.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

B = 4  # number of samples to generate and stack (the batch size)


def get_item():
    """Build a single random sample as a nested dict.

    The dict has keys ``obs`` (itself a dict with a ``(12,)`` ``scalar`` tensor
    and a ``(3, 32, 32)`` ``image`` tensor), ``action`` (a ``(1,)`` integer tensor
    in ``[0, 10)``), ``reward`` (a ``(1,)`` tensor in ``[0, 1)``) and ``done``
    (the Python bool ``False``).
    """
    # The observation tensors are drawn first, then action, then reward,
    # which keeps the RNG consumption order fixed.
    obs = {
        'scalar': torch.randn(12),
        'image': torch.randn(3, 32, 32),
    }
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)
    return {'obs': obs, 'action': action, 'reward': reward, 'done': False}


data = [get_item() for _sample in range(B)]
# execute `stack` op: wrap each nested dict as a treetensor, then stack them all
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate: leaves are reachable as attributes and carry a leading batch dim of size B
print(stacked_data)
assert stacked_data.obs.scalar.shape == (B, 12)
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all the assertions pass as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
<Tensor 0x7f530074eb20>
├── 'action' --> tensor([[8],
│                        [9],
│                        [4],
│                        [2]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7f5299c11ee0>
│   ├── 'image' --> tensor([[[[ 1.7141e-02, -3.5320e-01, -8.1959e-01,  ...,  1.7016e-01,
│   │                          -1.4725e+00, -5.2514e-01],
│   │                         [ 1.2207e+00, -7.8884e-01, -1.1706e+00,  ...,  5.1108e-01,
│   │                          -1.9477e-01, -1.8075e+00],
│   │                         [-1.4449e+00,  1.7678e+00, -1.0331e+00,  ...,  1.0396e+00,
│   │                          -1.5172e+00, -2.9562e-01],
│   │                         ...,
│   │                         [ 1.2865e+00, -3.2242e-01,  2.1704e-01,  ..., -9.5423e-01,
│   │                          -2.2839e+00, -3.7863e-01],
│   │                         [-1.2139e+00, -4.6141e-01,  9.5783e-02,  ...,  2.1576e-01,
│   │                           1.6223e+00, -2.0468e-01],
│   │                         [-2.8119e-01,  4.1379e-01,  4.2238e-01,  ..., -3.8607e-01,
│   │                           1.0180e-03, -1.1905e+00]],
│   │               
│   │                        [[ 1.1226e-01,  5.6945e-01, -7.8552e-01,  ...,  3.5852e-01,
│   │                           1.7758e+00,  3.8584e-01],
│   │                         [-1.3407e-01,  9.5511e-01,  5.4485e-01,  ..., -2.3804e-02,
│   │                          -2.0653e-01,  2.9108e-01],
│   │                         [-1.8569e+00,  1.7579e-01,  1.7710e+00,  ..., -1.2855e+00,
│   │                           1.1359e+00,  9.0804e-02],
│   │                         ...,
│   │                         [-7.7216e-01,  4.7848e-01, -5.8516e-02,  ...,  8.7314e-01,
│   │                           1.1225e+00, -6.6931e-01],
│   │                         [-3.3770e-01, -3.1907e-01,  1.8266e-01,  ..., -3.5387e-02,
│   │                          -1.3989e-01,  4.1891e-01],
│   │                         [ 1.9748e+00,  1.7973e-01, -3.6963e-01,  ...,  1.5167e+00,
│   │                          -2.9495e-01, -8.8386e-01]],
│   │               
│   │                        [[-1.0222e+00,  4.4665e-01,  9.0526e-01,  ...,  2.6522e+00,
│   │                           3.1822e-01, -3.4724e-01],
│   │                         [-7.2270e-01, -1.5879e+00, -1.5155e-01,  ..., -5.5952e-01,
│   │                           1.6414e-01,  1.8491e-01],
│   │                         [-6.0023e-02,  1.6337e+00,  2.1997e+00,  ...,  1.9177e+00,
│   │                           1.0469e+00, -2.3642e-01],
│   │                         ...,
│   │                         [-6.7499e-01,  1.1351e+00, -1.4166e+00,  ...,  3.8554e-01,
│   │                          -8.1761e-01, -1.0600e+00],
│   │                         [ 3.4747e-01,  1.0231e+00,  6.4722e-01,  ...,  4.0558e-02,
│   │                          -7.0103e-01, -1.6931e-01],
│   │                         [-8.2901e-01, -9.0920e-01,  8.9393e-01,  ..., -1.2780e+00,
│   │                          -5.2976e-01, -4.9009e-01]]],
│   │               
│   │               
│   │                       [[[-1.4514e+00, -1.2475e+00, -5.0951e-01,  ..., -6.6949e-02,
│   │                           1.0540e+00,  2.9280e-01],
│   │                         [-6.7701e-01,  7.7114e-01, -1.4780e+00,  ..., -1.5032e+00,
│   │                           9.2649e-02,  7.0381e-01],
│   │                         [ 1.1677e+00,  8.5350e-01,  1.1849e+00,  ..., -1.4840e+00,
│   │                          -9.2543e-01, -6.7476e-01],
│   │                         ...,
│   │                         [ 5.4510e-01,  4.1561e-01,  1.0003e+00,  ..., -6.1577e-02,
│   │                          -9.9260e-01, -3.8822e-01],
│   │                         [-9.5394e-02, -8.1892e-01,  2.6548e-01,  ..., -1.9137e+00,
│   │                           2.3651e-01, -3.5038e-01],
│   │                         [ 1.1033e+00,  9.9671e-01, -2.1790e+00,  ..., -4.1062e-01,
│   │                          -2.3804e+00, -9.6401e-02]],
│   │               
│   │                        [[-5.5404e-01,  7.1094e-02,  7.5669e-01,  ..., -1.1730e+00,
│   │                           1.4899e+00, -4.4760e-01],
│   │                         [ 1.4159e+00,  6.9721e-01, -7.4730e-01,  ...,  1.5125e-01,
│   │                          -3.5389e-01,  1.1318e+00],
│   │                         [-1.2084e+00, -2.7884e+00, -5.9933e-01,  ..., -2.1919e+00,
│   │                          -7.9053e-01,  1.6682e+00],
│   │                         ...,
│   │                         [-5.8191e-01, -7.6311e-01, -9.5547e-01,  ..., -1.2395e+00,
│   │                           9.7180e-01,  1.4802e+00],
│   │                         [ 1.0860e+00,  5.9196e-01, -4.4113e-01,  ..., -4.2583e-02,
│   │                          -1.0151e+00,  1.3984e+00],
│   │                         [ 4.0421e-01, -9.9210e-01, -1.6687e-01,  ..., -1.4863e+00,
│   │                           9.7970e-01,  3.7317e-01]],
│   │               
│   │                        [[-6.5122e-01, -9.1908e-02,  1.3201e+00,  ...,  7.0011e-01,
│   │                          -1.4464e+00, -8.0954e-01],
│   │                         [-3.9242e-01, -1.4340e-01, -7.0892e-01,  ...,  8.7134e-01,
│   │                          -7.1900e-02, -2.6464e-01],
│   │                         [-5.2406e-01, -2.4794e+00, -1.3888e+00,  ..., -5.3002e-01,
│   │                           1.3412e+00,  2.6652e-01],
│   │                         ...,
│   │                         [ 1.6153e+00, -1.3309e+00,  6.4746e-01,  ..., -1.5601e+00,
│   │                          -1.2026e+00, -9.8497e-01],
│   │                         [-2.0237e+00, -7.3026e-01, -3.1804e-01,  ..., -5.9028e-01,
│   │                          -1.0709e+00,  6.4530e-01],
│   │                         [-5.8922e-01, -1.0129e-01,  1.6494e+00,  ..., -1.3154e+00,
│   │                          -2.2868e-01, -8.1147e-03]]],
│   │               
│   │               
│   │                       [[[-9.1397e-01,  3.6510e-01, -5.3084e-01,  ..., -6.3248e-01,
│   │                          -1.0458e+00,  1.4122e-01],
│   │                         [ 2.7955e-01,  2.0059e+00, -7.4364e-01,  ...,  1.8276e+00,
│   │                          -6.5129e-01, -1.3266e-01],
│   │                         [ 6.2949e-02, -1.0731e+00,  6.2269e-01,  ..., -8.4487e-01,
│   │                           1.0352e+00,  3.9690e-02],
│   │                         ...,
│   │                         [ 6.6964e-01,  3.0586e-01, -4.4805e-01,  ..., -7.2597e-02,
│   │                           6.3003e-01,  1.1989e+00],
│   │                         [ 2.1032e+00,  3.5461e-01, -9.9990e-01,  ..., -5.1090e-01,
│   │                          -1.2749e-02,  2.9062e-01],
│   │                         [ 1.0360e+00,  7.7144e-01,  2.8145e-02,  ..., -6.4996e-01,
│   │                          -2.0149e+00, -9.5965e-01]],
│   │               
│   │                        [[ 1.8540e+00, -9.9252e-02, -1.0720e+00,  ..., -1.1357e+00,
│   │                           1.8413e+00,  1.3561e+00],
│   │                         [-4.0232e-01, -5.1660e-01,  3.7654e-01,  ...,  8.2641e-02,
│   │                           1.0970e+00, -5.6366e-01],
│   │                         [-6.8471e-01, -4.5682e-01, -5.2672e-01,  ...,  8.7769e-01,
│   │                           3.7786e-01,  1.4260e+00],
│   │                         ...,
│   │                         [ 1.1285e+00, -1.3212e+00,  2.5492e-01,  ..., -3.3181e-01,
│   │                          -1.2761e+00, -1.5223e+00],
│   │                         [ 2.2034e-01, -9.5642e-01, -1.3952e+00,  ..., -1.5355e+00,
│   │                          -8.8444e-01,  8.0211e-01],
│   │                         [-8.1752e-01,  2.6213e-01,  3.3672e-01,  ...,  4.6514e-01,
│   │                           2.5935e-01, -8.8353e-01]],
│   │               
│   │                        [[-1.9074e-03,  5.7054e-01, -1.5113e+00,  ...,  8.5527e-01,
│   │                           1.1988e+00, -1.5118e+00],
│   │                         [-9.8853e-01,  1.7142e-01,  6.6264e-01,  ...,  8.1053e-01,
│   │                          -1.7684e+00,  1.7861e+00],
│   │                         [ 1.4883e+00,  1.0056e+00,  1.3530e+00,  ..., -2.2767e-01,
│   │                           8.2962e-02, -6.7558e-01],
│   │                         ...,
│   │                         [-7.4116e-02,  3.8104e+00,  4.3449e-01,  ...,  1.8868e-01,
│   │                           1.1739e+00, -2.7602e-01],
│   │                         [-2.1932e+00,  5.4622e-01, -1.1945e-02,  ..., -1.2092e+00,
│   │                          -1.3727e+00,  8.8010e-01],
│   │                         [ 2.8076e-01, -8.9285e-01,  1.2662e+00,  ...,  1.8903e+00,
│   │                          -1.0978e+00, -1.9518e+00]]],
│   │               
│   │               
│   │                       [[[-2.0331e+00,  5.7457e-01, -1.2417e+00,  ...,  1.1525e+00,
│   │                           6.0338e-01,  3.1732e-01],
│   │                         [ 2.3089e-01,  8.6192e-02, -2.8353e-01,  ...,  4.5224e-01,
│   │                          -1.4627e+00, -5.6552e-01],
│   │                         [-1.1575e+00, -7.6618e-01, -2.8471e+00,  ...,  1.2460e+00,
│   │                          -3.8716e-01,  7.5730e-01],
│   │                         ...,
│   │                         [-9.3802e-01, -9.7460e-01,  5.0438e-02,  ...,  4.8982e-01,
│   │                          -1.7945e+00, -9.6353e-02],
│   │                         [-6.6366e-01,  1.2828e-01, -7.4450e-01,  ...,  1.5780e-01,
│   │                           1.3645e+00, -2.7127e+00],
│   │                         [ 1.0671e+00,  6.6669e-01,  8.7063e-01,  ...,  1.7169e+00,
│   │                           2.1508e+00, -1.3380e+00]],
│   │               
│   │                        [[ 6.7575e-01,  8.0231e-01,  1.4304e+00,  ..., -8.3732e-01,
│   │                          -5.7430e-01, -6.4904e-01],
│   │                         [ 1.3922e+00, -2.6879e-01, -2.1517e+00,  ...,  2.0281e+00,
│   │                           1.6727e+00,  8.7112e-01],
│   │                         [ 3.4302e-01,  7.6941e-01, -3.5132e-01,  ...,  1.0489e+00,
│   │                           1.9925e-01,  1.7552e+00],
│   │                         ...,
│   │                         [-4.4358e-01,  2.9077e-01, -1.1450e+00,  ..., -2.5655e-01,
│   │                           1.2108e+00,  6.7779e-02],
│   │                         [ 1.3177e-01, -1.1058e+00, -1.5541e+00,  ...,  1.3203e+00,
│   │                          -4.5277e-01, -9.5318e-01],
│   │                         [ 8.1782e-01,  8.2245e-01,  1.1387e+00,  ..., -9.7383e-01,
│   │                           9.9616e-01, -8.4592e-01]],
│   │               
│   │                        [[ 1.2132e+00, -2.0329e-01,  4.6961e-03,  ...,  3.8667e-01,
│   │                           9.0289e-01,  1.4949e+00],
│   │                         [-1.7887e+00, -5.9950e-01,  2.0701e+00,  ...,  1.5585e-02,
│   │                           1.8926e+00,  8.1583e-01],
│   │                         [-4.5342e-01,  2.9010e-01, -7.0527e-01,  ...,  1.1800e-01,
│   │                          -5.5185e-01,  9.8383e-01],
│   │                         ...,
│   │                         [-1.2631e+00,  3.2361e+00,  4.5160e-01,  ...,  9.0120e-01,
│   │                           6.2643e-01,  5.5287e-01],
│   │                         [-1.2656e-01, -1.1531e+00, -5.3638e-02,  ...,  4.6664e-01,
│   │                           1.3861e+00, -5.1134e-01],
│   │                         [-1.3833e+00,  9.4645e-01,  7.8542e-02,  ...,  1.0083e+00,
│   │                          -1.3574e+00,  9.0209e-01]]]])
│   └── 'scalar' --> tensor([[ 1.3279,  0.2754,  0.2602,  1.0422,  1.5870, -1.4383, -0.5448, -0.0147,
│                              0.9718,  1.7773,  0.6544,  0.4573],
│                            [-0.7747,  1.4149, -0.0556, -0.1000,  0.6594,  1.0650,  0.4387,  0.5963,
│                             -1.9563,  0.9282,  1.2347,  0.0158],
│                            [ 0.3438,  0.2292,  2.0654,  0.3097,  0.4312,  1.0012,  0.2881, -0.4297,
│                             -0.1351, -0.1354,  0.7835, -1.4891],
│                            [-0.2885,  0.2340, -0.3504,  0.1722, -0.4840, -0.6614,  1.0234, -0.9891,
│                             -0.7920,  1.1529,  0.4562,  0.9874]])
└── 'reward' --> tensor([[0.8738],
                         [0.0396],
                         [0.1549],
                         [0.1968]])

This code looks much simpler and clearer.