Stack Structured Data

When tensors are organized in a tree structure, they often need to be stacked together, just as torch.stack() does for plain tensors.

Stack With Native PyTorch API

Here is a typical implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

# Number of items to generate and stack; the validation below expects it as
# the size of the new leading dimension.
B = 4


def get_item():
    """Create one randomly generated, nested sample.

    Returns:
        dict: A dictionary with the keys ``'obs'`` (a nested dict holding a
        ``'scalar'`` tensor of shape ``(12,)`` and an ``'image'`` tensor of
        shape ``(3, 32, 32)``), ``'action'`` (an integer tensor of shape
        ``(1,)`` drawn from ``[0, 10)``), ``'reward'`` (a tensor of shape
        ``(1,)`` drawn from ``[0, 1)``) and ``'done'`` (the Python bool
        ``False``).
    """
    # The random draws are made in the same order as the keys appear in the
    # returned structure: scalar, image, action, reward.
    observation = {
        'scalar': torch.randn(12),
        'image': torch.randn(3, 32, 32),
    }
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)
    return {
        'obs': observation,
        'action': action,
        'reward': reward,
        'done': False,
    }


# Build a list of B items by calling get_item() once per item.
data = [get_item() for _ in range(B)]


# execute `stack` op
def stack(data, dim):
    """Recursively stack a list of identically structured items.

    The type of the first item decides how the whole list is handled:

    * ``torch.Tensor`` -- the list is passed to ``torch.stack`` with ``dim``.
    * ``dict`` -- every key of the first item is stacked recursively, and a
      dict with the same keys is returned.
    * ``bool`` -- the list is turned into a 1-D ``torch.BoolTensor``
      (``dim`` is not used in this case).

    Args:
        data: Non-empty list of items sharing the same structure.
        dim: Dimension handed to ``torch.stack`` for tensor leaves.

    Returns:
        A tensor, or a (possibly nested) dict of tensors.

    Raises:
        TypeError: If the first item is not a tensor, dict or bool.
    """
    first = data[0]

    if isinstance(first, torch.Tensor):
        return torch.stack(data, dim)

    if isinstance(first, dict):
        stacked = {}
        for key in first.keys():
            # Collect the values stored under this key across all items and
            # stack them as a unit.
            stacked[key] = stack([sample[key] for sample in data], dim)
        return stacked

    if isinstance(first, bool):
        return torch.BoolTensor(data)

    raise TypeError("not support elem type: {}".format(type(first)))


# Stack all items along a new leading dimension (dim=0).
stacked_data = stack(data, dim=0)
# validate
print(stacked_data)
# Each tensor leaf keeps its per-item shape and gains a leading dimension of size B.
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
# The Python bool `done` flags end up in a 1-D bool tensor with one entry per item.
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertion statements pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
{'obs': {'scalar': tensor([[ 0.5784,  0.3535,  0.4976, -0.8075,  0.7813, -1.2106,  0.3359,  0.2939,
          1.6414, -0.0313,  1.1735,  0.3287],
        [-1.6757, -1.3516, -0.7728,  0.3899, -0.9174,  0.3017, -1.3204,  0.6177,
         -1.0680, -0.6462, -0.2077, -0.1928],
        [ 0.1499, -0.9025,  0.7933,  1.4425,  0.6130,  0.9109, -0.8047,  0.0022,
          0.6294,  0.2999,  1.2497, -0.1656],
        [ 1.5044,  0.5726,  0.3693, -0.3974,  0.1145, -1.4203,  0.2160,  0.2163,
          0.6900,  0.5886, -0.2948, -0.3692]]), 'image': tensor([[[[ 1.6523e+00, -1.7520e+00, -6.3636e-01,  ...,  1.3818e-01,
            2.9619e-01, -2.4713e-01],
          [ 1.4577e+00, -1.2294e-01, -1.7670e+00,  ..., -1.0706e+00,
            1.0098e-01,  4.3042e-01],
          [-4.9565e-02,  9.7667e-01, -3.6495e-01,  ...,  9.5388e-01,
           -5.8718e-01,  6.2887e-01],
          ...,
          [-5.5253e-01,  1.3820e+00, -1.9481e+00,  ..., -3.3860e-01,
           -9.9484e-01, -6.5657e-01],
          [-1.7435e+00, -5.9881e-01, -5.9248e-02,  ..., -5.3646e-01,
           -2.2727e-01,  9.4532e-01],
          [ 2.8682e+00,  1.2511e+00, -6.9985e-01,  ...,  1.5079e+00,
           -4.4694e-01,  6.7518e-02]],

         [[-1.2851e+00, -6.1747e-01, -5.5490e-01,  ..., -8.0671e-01,
           -2.6908e+00,  1.1359e-01],
          [ 5.1843e-01, -3.1802e-01, -2.1712e+00,  ..., -8.5754e-02,
            1.3466e+00,  5.3193e-02],
          [-1.5905e-01, -9.4360e-01,  2.3216e+00,  ...,  1.0666e+00,
            1.5072e+00, -8.1482e-01],
          ...,
          [-2.0394e+00, -1.3302e+00, -1.0421e+00,  ..., -7.5486e-01,
            6.0526e-01,  1.8305e-02],
          [-2.9571e-01, -1.1280e+00, -9.4381e-01,  ...,  5.6464e-01,
            1.4815e+00, -2.3876e-01],
          [-1.1972e+00,  5.5697e-01,  7.7642e-01,  ..., -2.2847e-01,
            1.3118e+00,  6.8881e-02]],

         [[ 2.9727e+00,  3.8465e-01, -2.3062e+00,  ..., -1.6097e+00,
           -1.1750e+00,  2.5572e-01],
          [ 5.1409e-01, -7.7778e-02,  1.1075e+00,  ..., -1.3984e+00,
           -1.1662e+00,  1.7398e+00],
          [ 2.9467e-01, -4.2621e-01,  3.4111e-01,  ...,  5.8067e-02,
            9.6240e-02, -3.7722e-01],
          ...,
          [-4.9899e-03, -1.4938e+00,  2.7438e+00,  ..., -1.8318e+00,
           -1.3049e+00,  2.7904e-01],
          [-3.1831e-01, -7.9892e-01, -2.3295e-01,  ..., -1.5867e+00,
            1.1413e+00, -1.4698e-01],
          [-1.1445e+00,  1.2240e+00,  1.3163e-01,  ..., -3.0017e-01,
           -1.3561e-01, -1.0129e-01]]],


        [[[ 9.6571e-01, -8.9200e-04, -7.1341e-01,  ...,  4.8646e-01,
            1.0451e+00,  2.9915e+00],
          [ 9.7016e-02,  2.2937e+00, -1.2632e+00,  ..., -2.1610e+00,
           -2.1801e-01,  1.9183e+00],
          [ 2.3081e+00, -6.4152e-01,  1.0609e+00,  ..., -1.0208e+00,
           -1.0675e+00, -1.6837e+00],
          ...,
          [ 7.5154e-01, -1.9395e-01,  2.6245e-01,  ..., -1.0729e+00,
           -7.7819e-01, -1.2370e+00],
          [-5.8781e-01,  1.0393e+00, -1.0454e+00,  ..., -1.7417e+00,
            1.7743e+00,  1.7797e-01],
          [-1.2443e-01,  2.7471e-01, -5.6317e-01,  ..., -1.7718e+00,
            8.5358e-01,  6.5662e-01]],

         [[-1.7653e+00,  1.5323e+00, -3.5561e-01,  ...,  1.7633e+00,
           -1.2588e-01,  5.1816e-01],
          [ 2.8463e-01,  1.7686e+00, -8.6031e-01,  ...,  1.0575e-01,
            1.8197e+00, -1.1922e-01],
          [-3.7177e-01, -1.0421e+00, -4.6851e-01,  ..., -9.0616e-01,
            8.3851e-01,  1.9800e+00],
          ...,
          [-1.3283e+00,  2.2063e+00,  1.8353e-01,  ..., -8.2135e-01,
           -1.7690e-01,  5.2786e-01],
          [ 7.8178e-01,  6.3860e-02, -1.1851e+00,  ...,  1.5998e-01,
           -3.8092e-01,  1.0326e+00],
          [-9.6547e-01, -2.1089e+00,  2.8610e-01,  ..., -1.3771e-01,
           -4.3200e-01, -1.2824e+00]],

         [[-1.5353e+00, -8.2541e-01,  4.7273e-01,  ...,  8.8199e-01,
           -1.2140e-01, -2.0026e+00],
          [-1.3996e+00,  8.8113e-01,  4.3590e-01,  ..., -8.5575e-01,
           -1.5004e+00, -1.8567e+00],
          [-7.1134e-01,  3.9540e-02,  3.4035e-01,  ...,  7.3982e-01,
           -1.5314e+00,  4.3028e-01],
          ...,
          [ 1.8231e+00,  1.1007e+00, -6.4464e-01,  ...,  1.9537e-01,
            5.4914e-01,  4.6801e-01],
          [ 1.1215e+00,  8.3756e-01,  5.6014e-01,  ...,  4.0299e-02,
           -1.7927e+00, -3.3347e-01],
          [-4.9385e-01, -5.3484e-01, -5.1721e-01,  ..., -1.2701e+00,
           -5.2076e-01,  7.2898e-01]]],


        [[[-1.0970e+00, -8.3543e-01,  7.1867e-02,  ...,  8.5818e-02,
           -4.9403e-01,  9.5811e-01],
          [-9.0140e-01, -1.8565e+00,  1.7330e-01,  ...,  6.1784e-01,
            1.2642e+00,  4.4792e-01],
          [ 5.1121e-01, -3.3872e-01,  6.3232e-01,  ...,  2.3091e-01,
            1.8662e+00,  5.8821e-01],
          ...,
          [-4.8641e-01,  6.4698e-01, -1.2678e+00,  ...,  5.2124e-01,
            6.6244e-01,  1.4152e+00],
          [ 1.0591e+00, -1.4838e+00, -2.0090e+00,  ...,  1.0824e+00,
           -9.0392e-01, -2.0171e-01],
          [-1.0590e+00,  2.7747e-01,  9.5979e-01,  ..., -2.6590e+00,
           -7.9162e-01,  1.3431e-02]],

         [[ 5.7201e-01,  2.3254e-01, -3.2918e-01,  ...,  3.2619e-01,
            9.4716e-02,  1.2581e+00],
          [ 5.7841e-01,  1.9476e+00, -1.3422e+00,  ...,  8.2653e-01,
           -2.2005e+00,  8.1949e-02],
          [ 2.0757e+00,  4.0436e-01,  3.1606e-01,  ...,  1.2785e-01,
           -4.3243e-02,  1.1752e+00],
          ...,
          [ 2.0156e-01,  7.7485e-01,  1.4912e+00,  ...,  1.4990e+00,
           -9.3711e-01,  1.2386e-01],
          [ 6.1004e-01,  1.4886e+00, -8.1336e-01,  ..., -2.1933e+00,
            1.3010e+00, -4.4808e-01],
          [-1.3140e+00, -4.7222e-01,  9.7576e-01,  ...,  1.5509e+00,
           -2.0406e-01,  6.7849e-01]],

         [[-7.6999e-01,  1.0107e+00,  6.7580e-01,  ...,  3.2558e-01,
            1.3498e-02, -7.0631e-01],
          [-6.9439e-01,  1.6513e+00, -1.2722e+00,  ..., -2.7790e-01,
            4.5541e-01,  4.6749e-01],
          [ 9.1847e-01, -9.4091e-02, -1.7262e-01,  ...,  4.1319e-02,
            3.3527e-01,  8.2720e-01],
          ...,
          [ 1.2652e+00,  5.6968e-01,  1.4446e-01,  ..., -1.4031e+00,
            1.1444e+00, -4.3358e-01],
          [ 6.9904e-01,  1.6406e-01, -4.4530e-01,  ..., -3.6353e-01,
           -3.2857e-01,  7.4968e-01],
          [ 2.7647e-01, -1.1737e+00, -1.3222e+00,  ...,  7.4433e-01,
            2.1960e-02,  1.9756e-01]]],


        [[[-2.9970e-01,  2.0543e-01, -4.3324e-01,  ..., -3.2860e-01,
            7.6650e-01, -2.2200e+00],
          [ 1.1938e+00, -4.8465e-01, -3.6081e-01,  ...,  9.7356e-02,
           -2.6552e+00, -6.3208e-01],
          [-9.1986e-01, -2.0949e+00, -2.0733e+00,  ..., -3.6779e-01,
           -5.5521e-01, -9.9669e-01],
          ...,
          [-4.7520e-01,  8.0063e-01, -1.6425e+00,  ...,  9.0609e-01,
            3.6046e-01,  9.8889e-02],
          [ 5.4142e-01,  3.7725e-01, -3.3522e-01,  ..., -1.3885e-01,
           -7.3564e-01,  2.5317e-01],
          [ 2.8832e-01,  6.6266e-01,  9.9882e-01,  ..., -7.5650e-01,
            2.9723e-01, -1.0744e-01]],

         [[-2.6866e-03,  3.1956e-01,  1.4829e+00,  ..., -8.7006e-01,
           -4.7296e-01, -5.4673e-01],
          [-9.8001e-01,  5.4135e-01,  1.0210e+00,  ...,  2.1763e+00,
           -3.9719e-01, -5.6080e-01],
          [ 6.6867e-01,  2.0534e+00, -1.7178e+00,  ...,  6.8412e-01,
            6.7001e-01, -2.4218e-02],
          ...,
          [ 3.9550e-01,  7.9910e-01, -1.1495e+00,  ..., -3.2503e-01,
           -4.0351e-01,  8.5208e-01],
          [ 1.3896e-02,  7.7222e-01, -4.8240e-01,  ..., -1.1722e+00,
           -1.4013e-03, -4.2187e-01],
          [ 7.0360e-01,  8.9505e-01, -1.1098e+00,  ..., -1.3128e+00,
            1.5266e+00,  4.2579e-01]],

         [[ 2.3705e-01,  8.1205e-01, -9.5006e-01,  ...,  2.8167e-01,
           -6.1279e-01, -9.8695e-02],
          [-4.0517e-01,  6.1329e-01,  1.7191e+00,  ...,  1.7165e-01,
            1.3566e-01,  5.7121e-01],
          [-4.7057e-01, -2.0436e-01,  3.1502e-01,  ...,  2.6031e-01,
           -1.3072e-02,  5.3095e-02],
          ...,
          [ 7.2936e-01, -1.5534e+00,  3.6800e-01,  ...,  6.5748e-01,
           -4.6215e-01,  2.4150e-01],
          [ 5.9573e-02,  8.5819e-01, -5.1194e-01,  ...,  1.4387e+00,
           -7.6777e-01, -4.7266e-01],
          [ 1.1863e-01,  2.1461e-01,  5.2001e-01,  ..., -6.3462e-01,
           -2.2936e-01, -6.8790e-01]]]])}, 'action': tensor([[6],
        [2],
        [8],
        [7]]), 'reward': tensor([[0.2958],
        [0.4639],
        [0.3224],
        [0.5207]]), 'done': tensor([False, False, False, False])}

We can see that the function that processes the structure (here, the stack function) has to be implemented entirely by hand. This code is not very clear, and because the supported types are hard-coded, supporting additional data types (such as integers) requires modifying the function each time.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown in the code below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

# Number of items to generate and stack; the validation below expects it as
# the size of the new leading dimension.
B = 4


def get_item():
    """Create one randomly generated, nested sample.

    Returns:
        dict: A dictionary with the keys ``'obs'`` (a nested dict holding a
        ``'scalar'`` tensor of shape ``(12,)`` and an ``'image'`` tensor of
        shape ``(3, 32, 32)``), ``'action'`` (an integer tensor of shape
        ``(1,)`` drawn from ``[0, 10)``), ``'reward'`` (a tensor of shape
        ``(1,)`` drawn from ``[0, 1)``) and ``'done'`` (the Python bool
        ``False``).
    """
    # The random draws are made in the same order as the keys appear in the
    # returned structure: scalar, image, action, reward.
    observation = {
        'scalar': torch.randn(12),
        'image': torch.randn(3, 32, 32),
    }
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)
    return {
        'obs': observation,
        'action': action,
        'reward': reward,
        'done': False,
    }


# Build a list of B items by calling get_item() once per item.
data = [get_item() for _ in range(B)]
# execute `stack` op
# Pass each nested dict through ttorch.tensor, then stack the whole list with
# a single ttorch.stack call; no hand-written recursion is needed here.
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate
print(stacked_data)
# Nested entries are reached by attribute access (e.g. `.obs.image`) rather than dict indexing.
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
# The `done` entry is expected to be a 1-D bool tensor with one entry per item.
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertion statements pass here as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
<Tensor 0x7fc009614f40>
├── 'action' --> tensor([[9],
│                        [0],
│                        [0],
│                        [0]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7fc009614f10>
│   ├── 'image' --> tensor([[[[ 1.0066, -0.5125, -0.5100,  ...,  0.4304, -0.4791,  0.0037],
│   │                         [ 0.9816,  0.9596,  0.6719,  ..., -0.5245, -0.0722, -1.8649],
│   │                         [ 0.4676, -1.0411, -0.9968,  ..., -2.6556, -1.3733,  0.5220],
│   │                         ...,
│   │                         [ 0.5454, -1.3960, -1.0521,  ..., -0.5556,  0.2616,  0.3796],
│   │                         [-0.7684,  2.0678,  0.0217,  ..., -0.9004, -1.7169, -0.3276],
│   │                         [-0.0768,  0.3314,  1.0701,  ...,  1.8504, -0.3538,  0.2596]],
│   │               
│   │                        [[-0.9392, -0.7079,  0.0103,  ..., -0.4460, -1.4312,  0.2705],
│   │                         [-2.5275, -0.8551,  1.8839,  ...,  0.7568, -1.5573, -1.4480],
│   │                         [ 2.8549,  1.2432,  0.5100,  ...,  0.0081, -0.9087, -1.4696],
│   │                         ...,
│   │                         [ 1.1787, -0.6722,  1.6035,  ...,  0.7070,  0.8882, -0.8686],
│   │                         [-0.2725,  1.3458, -0.1880,  ...,  0.0574, -0.8437,  0.6458],
│   │                         [-1.6703, -0.0460, -1.4201,  ..., -1.2151, -0.4388,  1.7011]],
│   │               
│   │                        [[ 0.1337,  0.3359, -1.8712,  ..., -0.7771,  1.0711,  1.8957],
│   │                         [-1.7527,  0.6338,  0.9952,  ..., -0.2820,  0.5544, -0.6044],
│   │                         [ 0.0077,  0.0945, -0.0825,  ..., -0.4621,  0.2942,  1.0229],
│   │                         ...,
│   │                         [-0.2110,  0.4341,  0.3970,  ...,  1.2205,  1.2036, -1.0865],
│   │                         [-1.7070,  2.1940, -0.2095,  ...,  1.0198,  1.3933, -0.5540],
│   │                         [ 2.5871,  0.4120, -1.6745,  ...,  0.6403,  1.0392, -1.3896]]],
│   │               
│   │               
│   │                       [[[ 0.9347, -0.0240,  0.6249,  ...,  1.2845,  0.8269,  0.4282],
│   │                         [-1.1517, -0.4547,  0.0419,  ..., -1.2576, -0.0424,  0.9613],
│   │                         [ 0.2276, -0.2302,  0.5161,  ..., -1.6687,  0.2716, -0.9414],
│   │                         ...,
│   │                         [ 1.1351, -1.3335,  1.6448,  ..., -0.9331,  0.7740, -0.3198],
│   │                         [ 0.2640, -0.2563, -0.8956,  ..., -0.4128,  0.8902,  0.5108],
│   │                         [-0.7482, -0.8230,  0.5216,  ...,  1.2161, -0.1435,  0.8289]],
│   │               
│   │                        [[ 0.5484, -0.0790, -0.7419,  ..., -1.5692,  1.3066, -0.5309],
│   │                         [-0.6011, -1.1791,  0.2000,  ...,  0.8918, -1.1714,  0.9579],
│   │                         [-0.1215, -0.3698, -0.6632,  ..., -0.1972, -0.2141, -0.8800],
│   │                         ...,
│   │                         [-0.3785, -0.8487,  0.1326,  ...,  0.3809, -1.4617, -0.3041],
│   │                         [-0.0449, -0.9626,  0.9752,  ..., -0.0051,  0.0315,  1.3118],
│   │                         [ 1.9892,  0.4154,  0.3769,  ...,  0.1489,  0.6702,  0.3442]],
│   │               
│   │                        [[-0.4406,  0.3815, -1.0188,  ..., -1.2785,  0.7905,  0.3470],
│   │                         [ 0.2139,  1.1612,  0.9128,  ...,  1.0073, -0.7710,  0.2179],
│   │                         [-1.2217, -0.7307, -0.9411,  ...,  0.3325, -0.3286, -0.0920],
│   │                         ...,
│   │                         [-0.3463, -0.0945,  0.9432,  ..., -1.6828,  2.1865,  0.1395],
│   │                         [ 0.4374, -0.1431,  1.3116,  ...,  0.7757, -0.7253, -0.9562],
│   │                         [-1.2028,  0.9317, -1.2930,  ...,  2.0751, -0.1609, -1.0734]]],
│   │               
│   │               
│   │                       [[[ 0.5703, -1.2568, -0.0638,  ...,  0.4189,  0.9797,  0.8111],
│   │                         [ 0.1262, -0.7507, -1.3518,  ..., -0.6599,  0.1899, -0.4528],
│   │                         [ 0.1792,  0.1125, -0.1614,  ..., -1.1983, -0.3885,  1.2360],
│   │                         ...,
│   │                         [ 0.6688,  0.4607,  1.1723,  ...,  0.1955,  1.3250,  0.3252],
│   │                         [-1.1828, -1.1331, -1.3031,  ...,  1.3129, -1.0821, -0.5326],
│   │                         [-0.8788, -0.5775,  0.1613,  ..., -0.2965,  0.4179,  0.0110]],
│   │               
│   │                        [[-0.2690,  1.6157, -1.3517,  ..., -1.5291,  0.1210,  0.0737],
│   │                         [ 1.4261,  0.1859,  2.2158,  ...,  1.4676,  1.1399,  0.8509],
│   │                         [ 1.7635, -0.0086,  0.6643,  ..., -0.0423, -1.4192, -0.9809],
│   │                         ...,
│   │                         [-1.0882, -1.6168,  1.0177,  ..., -0.1138,  0.2225, -0.2775],
│   │                         [-0.2816,  1.3698, -0.2802,  ..., -0.1700,  0.3518,  0.2016],
│   │                         [ 0.3858,  0.1409,  0.1205,  ...,  0.6933, -1.8943, -0.0080]],
│   │               
│   │                        [[ 0.2672,  0.9045, -0.6331,  ...,  0.3863,  0.1183, -0.8072],
│   │                         [-0.2371, -0.2344, -0.7020,  ..., -1.4168,  0.4153,  1.7615],
│   │                         [ 0.5611,  1.4971,  0.5270,  ..., -1.3174,  1.3251,  0.1133],
│   │                         ...,
│   │                         [-1.1819,  0.2242, -1.4139,  ..., -0.0732, -1.3229, -0.8533],
│   │                         [ 0.8031,  0.0191,  2.2125,  ..., -1.9870, -0.5304, -2.1601],
│   │                         [-0.9146, -0.7404, -0.3249,  ..., -0.5452, -0.0414, -1.6604]]],
│   │               
│   │               
│   │                       [[[ 0.8214,  0.7324,  1.7469,  ...,  0.8984, -0.3821,  0.8144],
│   │                         [ 0.1516,  0.6113, -1.9657,  ...,  0.5959, -0.2202,  0.0907],
│   │                         [-1.5093,  1.3312,  0.5243,  ...,  0.1106,  0.2880, -1.5461],
│   │                         ...,
│   │                         [-1.0750, -0.3997, -0.8149,  ..., -0.9195, -0.1075, -1.2298],
│   │                         [ 0.5945, -0.0188,  0.0934,  ...,  0.2349,  1.3371, -0.6677],
│   │                         [ 0.4770, -0.7522,  1.5829,  ...,  0.3556, -1.6438, -0.9779]],
│   │               
│   │                        [[ 1.8211,  0.0571,  0.6973,  ..., -1.2604,  1.1725, -0.3503],
│   │                         [-0.1182,  0.1670, -0.6894,  ...,  1.3592,  0.5219, -1.1902],
│   │                         [ 1.5181, -0.3473,  0.9742,  ..., -0.0555, -0.7948,  0.0249],
│   │                         ...,
│   │                         [-0.9457, -0.9585, -0.4152,  ...,  0.5383,  1.1772,  0.3163],
│   │                         [ 0.5345, -0.4548, -0.8079,  ...,  0.9607,  1.2593,  1.7718],
│   │                         [-0.7634, -0.8629,  0.4670,  ...,  0.3997, -0.3634,  0.2452]],
│   │               
│   │                        [[-0.1457,  0.6575, -0.7636,  ..., -1.6490, -0.3365, -1.8475],
│   │                         [ 1.1464,  1.1932,  0.1570,  ..., -0.4292,  1.5352, -0.1729],
│   │                         [ 0.8268,  0.2154,  0.0238,  ..., -0.1977, -0.6478,  1.6909],
│   │                         ...,
│   │                         [-0.9885, -0.2291,  1.0967,  ...,  0.4283, -0.7276, -1.4722],
│   │                         [ 0.2845,  0.2774, -0.9401,  ..., -1.2472, -0.6350,  0.7520],
│   │                         [-0.1413, -0.3974,  0.8726,  ..., -0.3979, -1.1564, -0.1461]]]])
│   └── 'scalar' --> tensor([[ 1.4567, -1.0372,  1.4316, -1.7959, -0.3340, -0.0713, -0.5631,  0.6710,
│                             -0.9753, -1.0369,  1.2262, -1.3162],
│                            [-0.4686, -1.2069,  0.9817, -0.8874, -0.3138, -0.4678,  0.0639, -0.3879,
│                              2.3316, -1.0118, -0.0311,  0.6872],
│                            [ 1.1566, -0.0564, -2.0215, -0.3011, -0.5755,  0.5510, -0.3258,  0.0251,
│                              0.7771, -0.4071, -1.2602, -1.2215],
│                            [ 1.3694, -0.0052, -0.0065, -0.6910,  1.2948,  0.6237, -0.2545, -2.0739,
│                             -0.9283,  0.7941, -1.4230,  0.0979]])
└── 'reward' --> tensor([[0.7627],
                         [0.9121],
                         [0.9081],
                         [0.3071]])

This code looks much simpler and clearer.