Stack Structured Data

When tensors form tree structures, they often need to be stacked together, just as torch.stack() does for plain tensors in PyTorch.

Stack With Native PyTorch API

Here is a common implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

B = 4  # number of samples that are generated and then stacked into one batch


def get_item():
    """Create one random sample as a nested dict.

    The sample holds an ``obs`` sub-dict (a 12-element ``scalar`` vector and
    a 3x32x32 ``image``), a single-element integer ``action`` in [0, 10),
    a single-element ``reward`` in [0, 1), and a Python bool ``done`` flag.
    """
    observation = dict(
        scalar=torch.randn(12),
        image=torch.randn(3, 32, 32),
    )
    sample = dict(
        obs=observation,
        action=torch.randint(0, 10, size=(1,)),
        reward=torch.rand(1),
        done=False,
    )
    return sample


data = [get_item() for _sample_index in range(B)]


# execute `stack` op
def stack(data, dim=0):
    """Recursively stack a list of identically structured samples.

    The structure of ``data[0]`` drives the recursion:

    * tensors are combined with ``torch.stack`` along ``dim``;
    * dicts are stacked key by key, using the keys of the first element;
    * Python bools are collected into a 1-D ``torch.bool`` tensor
      (``dim`` is not used for bool leaves).

    Args:
        data: List of samples sharing one structure (tensors, bools, or
            nested dicts of them).
        dim: Dimension along which tensor leaves are stacked. Defaults to 0,
            matching ``torch.stack``.

    Returns:
        The stacked result, nested the same way as ``data[0]``.

    Raises:
        IndexError: If ``data`` is empty (the first element is inspected).
        TypeError: If a leaf has an unsupported type.
    """
    elem = data[0]
    if isinstance(elem, torch.Tensor):
        return torch.stack(data, dim)
    if isinstance(elem, dict):
        return {key: stack([item[key] for item in data], dim) for key in elem}
    if isinstance(elem, bool):
        # ``torch.tensor`` is the recommended factory; the legacy
        # ``torch.BoolTensor`` constructor is discouraged.
        return torch.tensor(data, dtype=torch.bool)
    raise TypeError(f"not support elem type: {type(elem)}")


stacked_data = stack(data, dim=0)
# validate: every leaf gains a leading batch dimension of size B
print(stacked_data)
assert stacked_data['obs']['scalar'].shape == (B, 12)
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertions should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
{'obs': {'scalar': tensor([[ 0.1397,  0.6405, -0.1065, -1.1629,  0.5231,  0.9056, -0.1507, -1.7048,
         -0.2469, -1.0054, -0.9550, -1.1450],
        [ 1.0209,  0.0216, -1.1412,  0.0491,  0.0738,  0.3148,  0.1355,  0.2929,
         -0.8198, -1.4694, -0.1573, -0.9861],
        [ 0.2445,  0.6007,  1.1668, -1.0248, -1.0325, -0.8107,  0.0357,  1.0705,
          0.0232,  1.0331, -0.6108, -0.9902],
        [-0.1379, -1.0400, -0.0272, -0.2126, -0.1775,  1.3071,  0.2960,  0.5582,
         -0.1762, -0.1064, -1.8103, -0.5512]]), 'image': tensor([[[[-0.4326, -0.7533,  1.3030,  ..., -0.0109, -1.3185,  0.4280],
          [ 0.0829,  1.8533,  0.4719,  ..., -1.4861,  0.9368, -0.3270],
          [ 0.0865,  2.1496, -2.2310,  ..., -0.1413,  0.0513, -0.8146],
          ...,
          [ 1.2246, -1.0787,  0.2966,  ..., -0.4757, -0.1887,  1.8956],
          [ 0.4589, -0.5499, -1.6829,  ...,  0.0662,  1.2070, -0.5624],
          [ 0.6902,  0.6997, -0.0379,  ..., -1.4360,  0.9977,  0.9972]],

         [[ 0.9697,  0.8253, -0.0932,  ..., -0.8852,  0.0272, -0.6993],
          [-0.6284, -1.0246,  0.1896,  ...,  0.6979, -0.2650,  0.1428],
          [ 1.6016,  0.3169, -1.6228,  ...,  0.8058,  0.5034, -0.4536],
          ...,
          [ 0.2277,  0.1729, -1.7334,  ..., -0.5023, -1.7850, -0.0436],
          [-0.4577,  0.4500,  0.2988,  ...,  0.1346, -0.1855,  0.4559],
          [-1.0016, -0.2723,  0.6820,  ..., -0.4777,  1.0206,  0.3506]],

         [[-1.2456, -0.4237,  1.7488,  ...,  0.7549, -1.9808,  1.3219],
          [-0.6636,  2.4511, -1.8104,  ..., -1.4115, -0.5333,  1.4488],
          [ 0.1103,  0.6296, -3.5886,  ..., -1.6756, -1.2094,  1.4864],
          ...,
          [ 0.2035,  1.8552,  1.7461,  ..., -0.5683,  0.4337, -1.0125],
          [-1.0999,  0.7351,  0.9670,  ..., -1.2773,  0.6330, -1.3748],
          [ 0.9220,  0.0421, -1.3660,  ..., -0.0952, -1.6436,  1.7216]]],


        [[[ 0.9363,  0.5040,  1.5544,  ...,  1.2042, -0.1832, -0.4466],
          [ 0.4407,  0.7017,  1.6812,  ..., -0.1884, -0.4789, -0.5568],
          [-0.6145,  0.6685,  0.7179,  ...,  0.6435, -0.1549, -1.0045],
          ...,
          [-0.9856, -2.3838,  1.2274,  ..., -0.7904,  1.1943,  1.4193],
          [-0.6781, -0.0087, -0.4051,  ...,  0.0899,  0.1502,  1.4208],
          [-1.4857,  0.1149, -0.7937,  ...,  1.3302,  0.0909,  0.3794]],

         [[ 0.2771, -0.3458, -0.4623,  ...,  0.9665, -1.6586,  0.0382],
          [ 0.1610,  0.4554,  1.0350,  ...,  1.0425, -0.8898, -0.8385],
          [-0.2168,  0.8106,  2.9952,  ..., -0.9019,  1.6985,  1.4863],
          ...,
          [ 0.3996,  2.1054, -0.6097,  ...,  0.4439,  0.1815, -1.1867],
          [-0.8484,  1.9241,  0.8726,  ..., -1.0289,  0.6405, -0.0570],
          [ 0.2523, -0.1573,  1.5907,  ..., -0.7193, -0.7303, -0.0348]],

         [[-0.4859,  0.5751,  0.7009,  ...,  0.5231, -0.7711,  0.6045],
          [-0.0569, -0.6150, -1.8177,  ...,  0.9122, -0.0825,  0.8939],
          [ 0.9208, -1.6379,  1.2936,  ..., -0.4959, -0.1478, -0.2250],
          ...,
          [ 0.0937, -0.4282,  1.0310,  ...,  1.5147, -0.2100, -0.5660],
          [-0.6032,  0.0183,  0.1624,  ..., -1.9868,  3.0873, -1.3496],
          [-0.1165, -0.1251, -0.4253,  ..., -0.5325,  0.2728, -0.4957]]],


        [[[-0.2025, -1.7991,  0.7274,  ...,  0.3584,  0.5634, -0.0968],
          [ 0.7564,  1.1777, -1.2208,  ...,  1.1896, -0.9108, -0.3351],
          [-0.7297, -1.3296,  0.9482,  ..., -0.0066, -0.1052,  1.9654],
          ...,
          [ 1.1949, -0.1847, -0.1534,  ..., -1.9562, -0.0485,  1.6045],
          [ 0.4007,  0.9837,  0.2224,  ...,  0.9843,  1.3565,  1.1623],
          [ 0.5606, -0.2475,  0.9669,  ..., -0.8402, -0.6158,  1.1201]],

         [[ 0.5337,  0.3381, -0.1501,  ...,  0.7289, -0.4547, -1.5452],
          [ 0.3005,  1.2970,  0.4330,  ...,  0.4411, -0.5475,  0.9232],
          [ 0.7228,  0.2439, -0.6652,  ..., -0.2928, -1.8252, -1.1603],
          ...,
          [ 0.9397,  0.3396, -2.4223,  ...,  0.0322,  0.7380,  0.0104],
          [-1.2077,  0.5957,  0.1988,  ..., -2.6827, -0.2878, -0.5763],
          [-0.6298, -1.2216,  2.1716,  ...,  0.3720, -2.7269, -1.0368]],

         [[ 1.4139, -1.0098, -1.5199,  ..., -0.5062, -0.3032, -0.2048],
          [-0.5048,  0.6404, -0.5466,  ..., -0.7273, -0.1171,  0.0866],
          [ 0.8982,  0.1121, -0.7716,  ..., -2.7176, -1.1292,  0.1612],
          ...,
          [ 0.3561,  0.9135, -0.8747,  ...,  0.0966,  0.6425,  0.8907],
          [ 1.0174, -1.6984,  0.1494,  ..., -1.2426, -1.3041,  0.3580],
          [ 0.8126, -0.0934, -1.1277,  ..., -0.8676, -0.7366,  0.1619]]],


        [[[ 1.1737,  0.0895, -0.3679,  ...,  0.5905,  0.0775, -1.8457],
          [-0.9588,  0.1628,  0.8413,  ..., -0.3381, -0.0639, -0.3274],
          [ 0.5388,  0.3772,  0.1911,  ..., -1.1654, -0.8043, -0.2119],
          ...,
          [ 1.3184, -0.3690,  0.0170,  ...,  0.4111, -1.8371,  0.4956],
          [-1.4490, -2.0621,  1.3116,  ..., -1.6003,  0.1822, -1.0809],
          [ 1.3079, -0.4324, -0.0147,  ..., -0.5249,  0.0305, -0.8150]],

         [[-0.6108,  1.2957, -0.6788,  ..., -0.7383,  0.0433,  0.3279],
          [ 0.1887, -0.4337, -1.2826,  ...,  1.0833,  0.8152,  0.5958],
          [-0.4950, -0.7507, -0.2569,  ..., -0.3001,  0.0082,  1.3554],
          ...,
          [ 0.7512, -1.1572, -1.3979,  ..., -0.8164, -0.4833,  0.6999],
          [ 1.0026,  0.6994,  0.5976,  ...,  0.8389, -1.5561, -0.5754],
          [ 0.1015,  0.8575, -0.2490,  ...,  0.9199, -0.7382,  1.1388]],

         [[-0.0467,  1.4865, -1.3524,  ...,  0.2768,  1.2314, -1.3469],
          [ 2.1830, -0.9034,  1.0596,  ...,  1.7396,  0.2898, -0.3759],
          [-1.1247,  0.3101, -0.0416,  ...,  1.2717, -0.9376,  0.1824],
          ...,
          [ 0.5462,  0.5649, -0.7327,  ..., -0.1197, -0.7022, -1.4968],
          [ 1.7633,  0.9342,  0.4078,  ...,  1.3010, -0.7228,  1.0101],
          [ 0.3016, -0.9188,  0.5750,  ..., -0.0647,  1.3696,  0.8703]]]])}, 'action': tensor([[6],
        [2],
        [6],
        [1]]), 'reward': tensor([[0.1842],
        [0.3674],
        [0.0044],
        [0.9744]]), 'done': tensor([False, False, False, False])}

As we can see, the structure-processing function (here, `stack`) has to be implemented entirely by hand. This code is not very clear, and because the supported types are hard-coded, supporting more data types (such as integers) requires modifying the function specifically for each one.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

B = 4  # number of samples that are generated and then stacked into one batch


def get_item():
    """Create one random sample as a nested dict.

    The sample holds an ``obs`` sub-dict (a 12-element ``scalar`` vector and
    a 3x32x32 ``image``), a single-element integer ``action`` in [0, 10),
    a single-element ``reward`` in [0, 1), and a Python bool ``done`` flag.
    """
    observation = dict(
        scalar=torch.randn(12),
        image=torch.randn(3, 32, 32),
    )
    sample = dict(
        obs=observation,
        action=torch.randint(0, 10, size=(1,)),
        reward=torch.rand(1),
        done=False,
    )
    return sample


data = [get_item() for _sample_index in range(B)]
# execute `stack` op
# wrap each nested dict as a treetensor Tensor so the whole tree can be stacked at once
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate: every leaf gains a leading batch dimension of size B
print(stacked_data)
assert stacked_data.obs.scalar.shape == (B, 12)
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertions should pass as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
<Tensor 0x7f500410eb20>
├── 'action' --> tensor([[6],
│                        [7],
│                        [1],
│                        [9]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7f4f9d398ee0>
│   ├── 'image' --> tensor([[[[-1.2692e+00, -3.6622e-01, -7.3840e-01,  ...,  1.9273e-01,
│   │                          -5.0934e-01,  2.0247e-01],
│   │                         [ 5.1316e-01, -1.4380e+00,  8.7029e-01,  ..., -5.8476e-01,
│   │                          -4.2869e-01, -2.7239e+00],
│   │                         [-2.1976e-01, -1.2909e+00, -1.5434e+00,  ..., -2.6732e-01,
│   │                          -9.7129e-01,  1.1480e-01],
│   │                         ...,
│   │                         [ 3.3290e-01,  1.4550e+00,  1.9952e-01,  ..., -1.1732e+00,
│   │                           1.9837e+00,  1.0909e+00],
│   │                         [-6.8405e-01,  4.1167e-01, -5.3117e-01,  ..., -6.2145e-01,
│   │                           1.4966e+00, -1.6404e+00],
│   │                         [-4.9550e-01,  1.7281e+00, -5.3675e-01,  ..., -1.1654e+00,
│   │                          -7.8275e-02,  3.7119e-01]],
│   │               
│   │                        [[ 1.3455e+00, -7.2077e-01, -8.2507e-02,  ...,  2.6581e-01,
│   │                          -9.8652e-01,  1.3191e+00],
│   │                         [ 5.2628e-01, -6.6392e-01, -1.1221e+00,  ...,  1.3940e+00,
│   │                          -6.6631e-01,  1.3138e+00],
│   │                         [ 1.1016e+00, -4.3340e-01,  3.6715e-01,  ...,  1.0000e+00,
│   │                          -4.9327e-01,  1.5104e-01],
│   │                         ...,
│   │                         [ 5.1366e-01, -7.5680e-02,  4.3060e-01,  ..., -4.1284e-01,
│   │                           9.0706e-01,  8.0087e-01],
│   │                         [ 1.0083e+00,  2.8025e-01,  1.3199e+00,  ...,  1.8280e+00,
│   │                           4.1746e+00, -1.6393e-01],
│   │                         [-1.1768e+00, -6.6227e-01,  1.0277e+00,  ...,  4.4549e-02,
│   │                          -1.6616e-01, -1.0124e+00]],
│   │               
│   │                        [[-1.6166e-01, -1.4278e-01, -6.4356e-02,  ..., -6.1412e-01,
│   │                          -4.9879e-01, -5.6753e-01],
│   │                         [-6.2176e-02,  8.4296e-01,  2.1192e-01,  ...,  9.0633e-01,
│   │                           2.6497e-01, -6.9571e-01],
│   │                         [-1.2853e+00, -8.3053e-01, -1.2962e+00,  ...,  3.5739e-01,
│   │                          -1.5596e+00, -1.9680e+00],
│   │                         ...,
│   │                         [ 1.1446e+00, -8.6353e-01,  1.7932e-01,  ..., -3.7121e-01,
│   │                          -7.2157e-01, -1.1217e+00],
│   │                         [ 7.3079e-01,  8.3081e-01,  2.2869e+00,  ..., -2.4642e-01,
│   │                           2.2952e-01,  4.3093e-01],
│   │                         [ 4.4519e-01,  1.2219e+00,  4.9790e-01,  ...,  2.1361e-01,
│   │                          -8.6046e-01,  9.2222e-01]]],
│   │               
│   │               
│   │                       [[[-1.1623e+00,  4.6029e-01, -5.9470e-01,  ...,  3.2578e-01,
│   │                          -2.4708e+00,  6.4365e-01],
│   │                         [-7.4677e-01,  1.2760e+00,  6.3892e-01,  ..., -1.1104e+00,
│   │                           8.4852e-01, -3.0662e-02],
│   │                         [-2.1731e+00, -3.2645e-01,  9.1738e-01,  ...,  6.7163e-01,
│   │                           6.6023e-01,  4.5575e-01],
│   │                         ...,
│   │                         [ 1.2953e+00,  4.2413e-01, -1.2185e-01,  ..., -1.5103e+00,
│   │                          -4.5888e-02, -1.6670e+00],
│   │                         [-3.9992e-01, -2.1961e-01, -1.6008e+00,  ..., -8.6469e-01,
│   │                           6.4168e-01,  2.4480e-02],
│   │                         [-3.8710e-01, -6.3891e-01, -1.0644e+00,  ...,  6.4404e-01,
│   │                           3.7341e-01, -2.7270e-01]],
│   │               
│   │                        [[-8.0999e-01, -1.9910e-01, -5.7344e-01,  ..., -1.9861e-01,
│   │                          -7.9934e-01,  2.5309e-01],
│   │                         [-1.3631e-01, -6.3120e-01, -1.1128e+00,  ...,  1.0277e+00,
│   │                           1.5828e+00, -5.9765e-01],
│   │                         [ 1.9970e+00, -7.9221e-01,  1.1281e+00,  ...,  1.5955e-01,
│   │                          -3.7999e-01, -1.1448e+00],
│   │                         ...,
│   │                         [ 3.5317e-01, -4.8090e-02, -1.1155e+00,  ..., -7.0478e-01,
│   │                           4.1332e-01, -7.7929e-01],
│   │                         [-6.6900e-01,  2.3488e+00,  1.0264e+00,  ...,  3.0631e-01,
│   │                          -1.5448e+00, -1.1410e+00],
│   │                         [ 1.1558e+00, -6.0560e-01, -1.1032e+00,  ..., -1.4310e+00,
│   │                          -1.5016e+00,  1.3021e+00]],
│   │               
│   │                        [[ 4.8918e-01,  1.0193e+00,  1.1744e+00,  ...,  1.2039e+00,
│   │                          -2.4318e+00, -1.9988e-01],
│   │                         [ 7.2306e-01,  5.0835e-01,  1.3305e-01,  ..., -1.0545e+00,
│   │                           7.5123e-01,  3.7912e-01],
│   │                         [ 5.8082e-01,  7.0993e-01, -8.0798e-01,  ..., -1.6195e+00,
│   │                          -5.0154e-01, -5.0263e-03],
│   │                         ...,
│   │                         [ 1.8146e+00, -1.6835e-01, -2.8061e-01,  ..., -1.5520e-01,
│   │                          -5.4571e-01,  5.6566e-02],
│   │                         [-3.0293e-01, -3.4937e-01, -9.6995e-01,  ...,  1.0733e+00,
│   │                           2.0882e+00,  1.1180e+00],
│   │                         [ 2.2188e-01,  1.8296e+00,  9.1848e-01,  ..., -8.4241e-01,
│   │                           6.4728e-01,  6.0259e-02]]],
│   │               
│   │               
│   │                       [[[-2.3750e-01,  3.2818e-01,  5.3622e-01,  ..., -7.4065e-01,
│   │                          -1.4016e-01, -1.3092e+00],
│   │                         [-1.0824e+00, -5.0342e-01,  3.5214e-01,  ...,  1.3930e+00,
│   │                           2.1095e-01,  4.4865e-02],
│   │                         [ 1.1067e+00, -4.8788e-02, -4.3917e-01,  ...,  7.3184e-01,
│   │                           2.4479e-01,  3.9616e-01],
│   │                         ...,
│   │                         [ 4.3837e-01, -6.5767e-01, -1.2888e-01,  ..., -2.6374e-01,
│   │                          -2.7671e+00, -9.5575e-01],
│   │                         [-2.4965e+00,  1.0269e-01, -1.1650e+00,  ...,  8.4358e-01,
│   │                           1.0494e+00,  4.7932e-01],
│   │                         [ 8.2518e-01, -7.2377e-01, -4.7852e-01,  ..., -7.9251e-01,
│   │                           7.2410e-01, -4.1895e-01]],
│   │               
│   │                        [[ 8.7116e-01,  7.0249e-01,  8.2566e-02,  ..., -2.7794e+00,
│   │                           1.6488e-01, -2.3619e+00],
│   │                         [-8.2394e-01,  7.5954e-01, -1.0429e+00,  ..., -7.6347e-01,
│   │                          -5.1096e-01,  2.8637e-01],
│   │                         [-9.1372e-01, -2.0921e+00, -7.6063e-01,  ...,  1.1659e+00,
│   │                          -1.9836e-03,  2.3760e-01],
│   │                         ...,
│   │                         [-1.4566e+00,  2.2453e-02,  1.5795e+00,  ...,  4.8395e-01,
│   │                          -8.3417e-01, -1.6380e+00],
│   │                         [-8.5328e-01,  4.2822e-01, -9.5144e-01,  ..., -8.2777e-01,
│   │                           4.7990e-01, -1.5050e-01],
│   │                         [ 6.6055e-01,  2.1597e-02,  1.0470e+00,  ..., -2.3445e-01,
│   │                          -2.4408e+00,  5.1953e-01]],
│   │               
│   │                        [[ 1.0166e+00,  1.5594e+00, -3.0586e-01,  ...,  9.2226e-01,
│   │                           3.2886e-01,  2.6476e-01],
│   │                         [-1.9432e+00, -2.2786e-01,  1.3511e+00,  ...,  1.4441e+00,
│   │                           7.4018e-01, -6.4395e-01],
│   │                         [ 7.6252e-02,  6.0436e-01,  1.2924e+00,  ..., -1.4633e+00,
│   │                           1.5242e+00, -1.6537e+00],
│   │                         ...,
│   │                         [ 1.3360e+00, -2.4210e-01, -1.3163e-01,  ...,  5.3742e-01,
│   │                          -1.8975e-01,  7.7575e-01],
│   │                         [ 3.3887e-01, -1.4851e-01,  1.5986e+00,  ..., -3.5859e-01,
│   │                          -2.5503e-01, -7.9668e-01],
│   │                         [-1.1391e+00, -1.0063e+00, -8.1631e-01,  ...,  8.4646e-01,
│   │                           7.0147e-01, -1.1986e+00]]],
│   │               
│   │               
│   │                       [[[ 6.1937e-01, -1.6306e+00, -3.2112e-01,  ..., -1.7948e-01,
│   │                           1.1583e+00,  3.3740e-01],
│   │                         [-5.8402e-01, -1.9505e-01,  4.3594e-01,  ...,  4.4074e-01,
│   │                          -5.8311e-01,  9.7460e-01],
│   │                         [-4.5605e-01,  1.0135e-01,  1.2391e+00,  ..., -8.6551e-01,
│   │                           7.0250e-01, -1.1408e+00],
│   │                         ...,
│   │                         [-3.3248e-01, -1.7820e+00, -1.4215e+00,  ..., -4.5832e-01,
│   │                          -3.9562e-01, -1.9363e+00],
│   │                         [-1.8826e+00,  2.7027e-01, -9.7251e-01,  ..., -4.6080e-01,
│   │                           9.8424e-01, -7.2023e-01],
│   │                         [ 9.9932e-02,  3.2561e-01,  1.0375e-01,  ..., -2.4991e-01,
│   │                           6.4828e-01, -1.0039e+00]],
│   │               
│   │                        [[-3.0340e-01,  9.4991e-02, -2.5039e-01,  ..., -9.9865e-01,
│   │                           7.0590e-01,  6.8293e-01],
│   │                         [-9.7423e-01,  2.5702e-01, -1.1515e+00,  ...,  2.0653e-01,
│   │                          -8.7691e-01, -9.5434e-01],
│   │                         [ 1.4993e+00,  3.1006e-01,  1.2097e+00,  ...,  7.1011e-01,
│   │                          -3.3426e-01,  1.1680e+00],
│   │                         ...,
│   │                         [-2.3830e-01,  1.3778e+00,  7.0230e-01,  ...,  1.4697e-01,
│   │                           5.3213e-01, -5.2457e-01],
│   │                         [-1.1461e+00,  1.9016e+00,  1.9572e+00,  ...,  5.9289e-01,
│   │                          -8.9165e-01, -1.0406e+00],
│   │                         [ 1.1991e+00,  1.2000e+00,  5.3297e-01,  ...,  1.6936e+00,
│   │                           1.9165e+00, -2.5502e-01]],
│   │               
│   │                        [[ 8.1804e-01,  4.8194e-01,  3.6908e-01,  ..., -4.2221e-01,
│   │                           1.0013e+00,  6.8501e-02],
│   │                         [-7.3687e-01,  9.7740e-01, -6.9388e-01,  ..., -1.0019e-01,
│   │                           2.6619e-01,  1.5235e+00],
│   │                         [ 4.0515e-02,  6.9092e-01, -3.1769e-01,  ...,  1.4087e+00,
│   │                          -8.6673e-01, -4.4229e-01],
│   │                         ...,
│   │                         [ 2.0953e+00,  1.4011e-01, -1.2230e+00,  ..., -1.3723e+00,
│   │                          -4.8912e-01,  3.2164e-01],
│   │                         [-9.7725e-01, -8.9874e-01, -1.0128e+00,  ...,  6.0906e-01,
│   │                          -4.5124e-01, -3.5709e-01],
│   │                         [ 9.6041e-02, -1.1079e+00, -4.9511e-01,  ...,  5.2546e-01,
│   │                          -1.3343e+00, -1.4234e+00]]]])
│   └── 'scalar' --> tensor([[-1.6646, -0.8516,  0.5725, -1.1030, -0.8802, -0.0249,  0.9284, -0.0486,
│                             -1.2981,  0.9735, -0.1460, -0.5829],
│                            [ 1.4272, -0.7911, -0.3393, -1.2050,  0.5979, -0.6951,  0.0240,  0.8651,
│                              0.0206, -0.5671,  1.3159,  0.5801],
│                            [-0.4131,  0.4651, -0.6594, -0.7940,  1.1277, -1.3491, -0.2353,  1.1156,
│                             -0.0998,  1.2126,  2.0984,  2.1589],
│                            [-0.6850, -1.0653, -1.2417,  1.1818,  0.9535, -0.4631,  0.1623, -0.1625,
│                              0.1475,  1.0661, -0.0755, -0.9905]])
└── 'reward' --> tensor([[0.8974],
                         [0.6996],
                         [0.1059],
                         [0.1139]])

This code looks much simpler and clearer.