Stack Structured Data

When tensors are organized into tree structures (such as nested dicts), they often need to be stacked together, just as `torch.stack()` stacks plain tensors in PyTorch.

Stack With Native PyTorch API

Here is a typical implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

# Batch size: the number of samples generated and then stacked together.
B = 4


def get_item():
    """Create one synthetic, randomly generated sample.

    The sample is a nested dict laid out as::

        {
            'obs': {'scalar': Tensor(12,), 'image': Tensor(3, 32, 32)},
            'action': integer Tensor(1,) drawn from [0, 10),
            'reward': float Tensor(1,) drawn from [0, 1),
            'done': False,
        }
    """
    observation = dict(
        scalar=torch.randn(12),        # standard-normal feature vector
        image=torch.randn(3, 32, 32),  # standard-normal 3x32x32 array
    )
    return dict(
        obs=observation,
        action=torch.randint(0, 10, size=(1,)),
        reward=torch.rand(1),
        done=False,
    )


# A batch of B independent random samples (a list of nested dicts).
data = [get_item() for _sample_idx in range(B)]


# execute `stack` op
def stack(data, dim):
    """Recursively stack a list of identically-structured samples.

    ``data[0]`` serves as the structural template for dispatching:

    * ``torch.Tensor`` leaves are combined with ``torch.stack(data, dim)``.
    * ``dict`` nodes are stacked key by key, using the keys of ``data[0]``;
      every other element must provide those keys.
    * Python ``bool`` leaves become a 1-D ``torch.bool`` tensor of length
      ``len(data)``; ``dim`` is not used for this case.

    Args:
        data: non-empty list of samples (tensors, bools, or dicts thereof).
        dim: dimension along which tensor leaves are stacked.

    Returns:
        A structure mirroring ``data[0]`` with every leaf stacked.

    Raises:
        IndexError: if ``data`` is empty.
        TypeError: if a leaf has an unsupported type.
    """
    elem = data[0]  # template element used to decide how to stack
    if isinstance(elem, torch.Tensor):
        return torch.stack(data, dim)
    elif isinstance(elem, dict):
        return {k: stack([item[k] for item in data], dim) for k in elem}
    elif isinstance(elem, bool):
        # ``torch.tensor`` with an explicit dtype replaces the legacy
        # ``torch.BoolTensor`` type constructor; same 1-D bool result.
        return torch.tensor(data, dtype=torch.bool)
    else:
        raise TypeError("not support elem type: {}".format(type(elem)))


# Stack the B per-sample dicts along a new leading dimension.
stacked_data = stack(data, dim=0)
# validate: every leaf ends up with a leading batch dimension of size B.
print(stacked_data)
assert stacked_data['obs']['scalar'].shape == (B, 12)
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertions pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
{'obs': {'scalar': tensor([[-0.8321,  0.4115, -1.2981, -0.6643,  0.0807, -0.9570,  2.4863,  0.1631,
         -1.0269, -0.4991, -2.1319,  1.2083],
        [ 1.5919, -1.4477, -0.9831, -0.8956,  0.0234,  1.0293,  0.1199,  0.8425,
          1.2851,  1.1866,  0.3653, -1.0179],
        [-1.6906,  0.9263, -0.0904, -1.7171, -0.3145,  0.0919, -0.8474, -0.2647,
          0.6702,  0.1404, -0.7951, -0.3285],
        [-0.7176,  0.9064,  0.2999, -0.7348, -0.1395,  0.3248,  0.0108,  0.4680,
          0.0825,  0.9976,  0.8506, -0.2151]]), 'image': tensor([[[[ 4.2852e-01, -7.5208e-01, -1.7388e-01,  ...,  6.5257e-01,
           -1.1350e-01, -1.1175e+00],
          [ 3.4163e-01,  8.7786e-01,  7.7874e-01,  ..., -1.6007e-01,
            3.9022e-01,  1.3513e+00],
          [ 2.1908e+00, -9.6296e-01,  3.4991e-01,  ...,  6.3264e-01,
            5.6023e-01,  8.2719e-02],
          ...,
          [-2.9324e-01,  9.0977e-01,  1.2530e+00,  ...,  1.1444e+00,
           -9.5810e-01,  1.0529e-01],
          [-2.3462e-01, -1.1177e+00,  1.1494e-01,  ..., -1.5784e-01,
           -4.0094e-01,  1.0268e+00],
          [-1.9087e+00,  1.0450e+00, -1.3892e+00,  ...,  3.3922e-01,
            2.6151e-01, -9.0820e-01]],

         [[-1.8531e+00, -1.2567e+00,  2.6250e-01,  ...,  1.1829e+00,
            4.4227e-01,  1.0393e-01],
          [-7.5756e-01, -4.2016e-01, -1.7783e-01,  ...,  6.8571e-01,
            7.7436e-02,  6.4996e-01],
          [-2.1739e+00, -9.4435e-01, -8.6653e-01,  ...,  1.5597e-02,
           -1.5937e-02, -1.2477e+00],
          ...,
          [ 6.4178e-02,  2.4286e+00,  5.0346e-01,  ...,  6.2248e-01,
           -6.8733e-01,  3.5479e-01],
          [ 2.0025e+00,  8.2035e-01, -2.3170e-01,  ...,  1.9981e-01,
            1.0842e+00,  1.2622e+00],
          [ 9.0386e-01, -2.1911e-01, -7.3026e-01,  ...,  5.9536e-01,
           -3.1632e-01, -7.8107e-01]],

         [[-5.0562e-02, -1.0324e+00,  3.6389e-01,  ..., -4.1243e-01,
            1.0697e-01, -1.5743e+00],
          [ 1.3207e+00, -1.9610e-01, -1.0565e+00,  ..., -5.7704e-01,
            5.4279e-02, -1.2365e+00],
          [ 1.2610e+00,  1.9393e-01,  1.3241e+00,  ...,  3.4693e-01,
            2.7117e-01,  4.9549e-01],
          ...,
          [-5.2695e-01, -6.6137e-01,  9.3481e-03,  ...,  5.9376e-01,
           -1.8476e+00, -2.4352e-01],
          [-1.3726e+00,  5.0851e-01, -5.5054e-02,  ..., -8.5862e-01,
            3.6279e-01, -5.9956e-01],
          [-1.2422e+00,  5.1579e-01, -5.0417e-02,  ..., -2.4008e-01,
           -5.4265e-01,  1.2616e-01]]],


        [[[-2.2100e+00, -4.9205e-01,  1.1229e-01,  ..., -1.2827e+00,
           -1.7008e+00, -1.1416e+00],
          [-1.2711e+00, -7.1648e-01,  3.3727e-01,  ...,  3.3699e-01,
           -4.0348e-01, -2.8878e-01],
          [-8.7005e-01,  7.5652e-01,  6.6588e-02,  ..., -3.0309e-01,
            7.7675e-01,  2.5652e-01],
          ...,
          [ 3.7656e-01,  4.6825e-01,  9.9774e-01,  ...,  6.9514e-02,
           -5.9286e-01, -8.5538e-01],
          [-1.0126e-01,  1.7465e+00, -5.9243e-01,  ..., -1.3387e+00,
           -7.8379e-01,  9.0370e-01],
          [ 1.0035e+00,  9.5334e-01, -4.3017e-02,  ...,  8.1674e-01,
           -9.8918e-01,  5.9236e-01]],

         [[ 1.4258e+00, -8.9211e-01,  4.1404e-01,  ...,  2.5181e-01,
            1.2596e+00, -5.7541e-01],
          [-1.6198e+00, -2.0796e+00, -2.7016e-01,  ..., -1.2391e+00,
            1.1010e+00,  7.9630e-02],
          [ 2.4068e-01,  1.8113e+00,  1.4147e+00,  ..., -8.1587e-01,
            3.4670e-02,  2.1777e-01],
          ...,
          [ 7.5705e-01,  3.8163e-01,  8.8424e-01,  ..., -8.4945e-01,
           -1.3719e+00, -9.7024e-01],
          [ 4.0816e-01,  5.6532e-01, -1.1005e+00,  ..., -4.1002e-01,
            5.9238e-01, -1.0132e+00],
          [ 1.6248e+00,  3.8377e-01,  9.0332e-02,  ...,  1.2838e+00,
            7.9996e-01,  3.8081e-01]],

         [[-4.6070e-01, -4.4864e-01,  6.8577e-01,  ..., -4.0362e-01,
           -1.0136e+00,  6.4471e-02],
          [-5.4373e-01, -1.3905e+00, -2.7576e+00,  ..., -6.4313e-01,
            9.0840e-01,  4.1336e-01],
          [-5.4031e-01, -2.2778e+00,  5.1013e-01,  ..., -2.9793e-02,
            1.1146e+00, -1.1620e-01],
          ...,
          [-2.8573e-01, -1.2984e+00,  1.3267e+00,  ..., -6.3259e-01,
            7.5652e-01,  6.8424e-01],
          [-3.9018e-01, -1.2319e-01, -1.8936e+00,  ..., -4.2255e-01,
            2.6317e-02, -1.6355e+00],
          [-1.4471e+00,  1.1978e+00,  5.2484e-01,  ..., -4.8446e-01,
           -1.2648e+00,  6.2655e-01]]],


        [[[-7.1309e-01,  1.3992e+00,  4.2175e-01,  ...,  1.0609e+00,
           -1.4525e-01, -4.0711e-01],
          [-1.1026e+00, -4.7054e-02,  2.4239e+00,  ...,  4.8280e-01,
           -1.1162e+00, -1.4086e+00],
          [ 1.1731e+00,  5.9791e-01,  6.7690e-01,  ..., -1.5229e+00,
            1.3113e+00,  2.4117e+00],
          ...,
          [ 1.5302e+00,  9.7605e-01,  1.3727e+00,  ...,  1.1335e-01,
            3.3282e-01,  1.0516e+00],
          [-1.6011e+00,  1.3640e+00,  6.6849e-01,  ...,  1.0708e+00,
           -5.8364e-01, -9.4140e-01],
          [ 5.6715e-01, -1.7530e+00, -4.3532e-01,  ...,  7.4842e-01,
           -6.2953e-01,  4.4398e-01]],

         [[ 2.6444e-01, -2.4894e-01, -5.3038e-01,  ..., -1.0138e-01,
            8.7250e-01, -1.1218e+00],
          [-1.1219e+00,  7.4967e-01,  7.2936e-01,  ...,  3.1409e-01,
           -4.2283e-01,  8.2313e-01],
          [-2.0530e+00, -5.2646e-02,  1.7570e+00,  ..., -3.7291e-01,
            1.6394e+00, -1.6216e+00],
          ...,
          [-3.6495e-01, -2.5900e-01, -2.2151e-01,  ...,  6.3007e-01,
           -6.7897e-01,  1.0782e+00],
          [-3.4637e-01,  3.8246e-01,  8.4671e-01,  ...,  2.1174e-01,
           -1.1715e+00, -2.5063e-01],
          [ 5.3622e-01,  6.8732e-02, -8.9642e-01,  ...,  6.1801e-01,
           -6.0200e-01, -8.8174e-01]],

         [[-9.0628e-02,  6.7713e-01,  1.5489e+00,  ..., -1.9710e-01,
           -7.1502e-01, -9.5739e-01],
          [ 1.8930e+00, -4.6703e-01,  1.4281e+00,  ...,  2.2534e+00,
            1.7150e+00, -6.5481e-01],
          [ 1.3556e-01, -7.4351e-01, -7.8192e-01,  ..., -6.5805e-01,
            1.7676e+00, -1.9225e+00],
          ...,
          [-5.1577e-01,  5.1639e-01, -5.6279e-01,  ..., -1.8252e+00,
            2.7366e-01, -2.0543e-02],
          [ 1.0252e-01, -6.9031e-01,  3.7258e-01,  ...,  6.9040e-02,
           -4.6241e-01,  1.0692e+00],
          [-6.6968e-01, -7.8919e-01,  4.1917e-01,  ...,  2.4317e+00,
            1.6968e+00, -8.1849e-01]]],


        [[[ 2.0016e-03, -5.7886e-01, -1.7169e+00,  ...,  2.5055e-01,
            8.2737e-01,  8.6523e-01],
          [-4.0587e-01,  7.9177e-01,  1.4863e+00,  ...,  1.3820e+00,
           -5.1700e-01,  4.2729e-01],
          [-1.1575e+00, -4.4472e-01, -1.1722e+00,  ...,  5.9480e-01,
           -5.1535e-01,  1.9522e-01],
          ...,
          [ 7.8547e-01,  1.6541e+00, -2.1588e+00,  ...,  1.2489e+00,
            3.4968e-01,  9.7536e-01],
          [ 8.8159e-01, -1.2106e+00,  7.5257e-01,  ...,  1.2380e-01,
            1.5472e+00,  1.6789e-02],
          [ 3.3039e+00, -1.9642e+00,  8.9381e-01,  ..., -8.5707e-01,
            4.9600e-01, -2.4343e-01]],

         [[ 3.4797e-01, -5.1898e-01, -1.0933e+00,  ..., -1.8727e-02,
            2.0163e-01, -6.8754e-02],
          [ 7.9409e-01, -2.8152e-01,  2.1057e+00,  ...,  4.3352e-01,
            3.9443e-01, -4.0992e-01],
          [-2.7853e+00,  6.2406e-02, -5.2967e-01,  ...,  3.0205e-01,
            1.3831e+00,  2.1162e+00],
          ...,
          [ 2.6771e-02,  1.2483e-01, -1.1640e+00,  ..., -2.1621e+00,
           -3.3267e-02, -1.5374e+00],
          [-7.0803e-01,  3.2913e-01, -1.1831e-01,  ...,  5.0756e-02,
           -9.2405e-01,  1.8659e+00],
          [ 5.3362e-01,  2.3305e-01,  8.1592e-01,  ..., -4.2848e-01,
           -2.0241e+00,  1.0748e+00]],

         [[-1.3563e-01,  2.9222e-01,  9.8270e-02,  ..., -1.5085e+00,
            1.1467e+00,  1.1489e-01],
          [ 4.8113e-01, -1.3435e+00,  1.8291e+00,  ...,  5.2886e-01,
           -4.7839e-01,  5.6480e-01],
          [-1.1525e+00,  6.8927e-01,  4.2303e-01,  ...,  1.5632e-01,
            1.3267e+00,  1.2292e+00],
          ...,
          [ 1.1167e+00,  1.3048e+00,  4.4516e-01,  ...,  8.7841e-01,
           -1.0668e+00,  7.5817e-01],
          [ 1.6006e+00,  3.6132e-01, -1.5933e+00,  ...,  1.3984e-01,
            2.1368e+00,  9.9310e-01],
          [-4.7729e-01,  1.2652e+00, -1.3762e+00,  ...,  3.0158e-01,
           -6.7458e-01,  2.0324e-01]]]])}, 'action': tensor([[5],
        [7],
        [9],
        [2]]), 'reward': tensor([[0.8089],
        [0.1110],
        [0.9295],
        [0.0974]]), 'done': tensor([False, False, False, False])}

As we can see, the structure-processing function (here, `stack`) has to be implemented entirely by hand. This code is not very readable, and because every supported type is hard-coded, supporting additional data types (such as integers) requires special modifications to the function.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

# Batch size: the number of samples generated and then stacked together.
B = 4


def get_item():
    """Create one synthetic, randomly generated sample.

    The sample is a nested dict laid out as::

        {
            'obs': {'scalar': Tensor(12,), 'image': Tensor(3, 32, 32)},
            'action': integer Tensor(1,) drawn from [0, 10),
            'reward': float Tensor(1,) drawn from [0, 1),
            'done': False,
        }
    """
    observation = dict(
        scalar=torch.randn(12),        # standard-normal feature vector
        image=torch.randn(3, 32, 32),  # standard-normal 3x32x32 array
    )
    return dict(
        obs=observation,
        action=torch.randint(0, 10, size=(1,)),
        reward=torch.rand(1),
        done=False,
    )


# A batch of B independent random samples (plain nested dicts).
data = [get_item() for _ in range(B)]
# execute `stack` op: wrap each nested dict with ttorch.tensor, then let
# ttorch.stack combine the resulting trees leaf by leaf.
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate: every leaf ends up with a leading batch dimension of size B.
print(stacked_data)
assert stacked_data.obs.scalar.shape == (B, 12)
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertions pass as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
<Tensor 0x7fafefaed190>
├── 'action' --> tensor([[3],
│                        [3],
│                        [3],
│                        [7]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7fafefaed2b0>
│   ├── 'image' --> tensor([[[[ 1.3920, -0.8865,  1.6668,  ...,  0.2620, -2.9271, -0.2037],
│   │                         [ 0.5440, -0.2689, -0.5091,  ...,  1.2014, -1.1526, -0.4832],
│   │                         [-2.2782, -0.0907, -1.4654,  ...,  0.6948,  1.1884,  0.2286],
│   │                         ...,
│   │                         [ 2.6919,  0.4603, -0.0149,  ...,  0.9791,  0.8234, -1.3533],
│   │                         [-1.4266, -0.9778, -1.1365,  ..., -0.6470,  2.0225, -1.4291],
│   │                         [-0.0479,  0.6626, -0.7104,  ..., -0.3527,  0.9521,  0.2552]],
│   │               
│   │                        [[ 0.7801, -1.0889, -0.3401,  ...,  0.4524, -0.7254,  0.5578],
│   │                         [-0.3366,  2.1877, -0.6251,  ..., -0.8137,  0.3517,  0.9949],
│   │                         [ 1.0531, -0.3080,  1.0723,  ..., -0.4369, -1.0756,  0.3074],
│   │                         ...,
│   │                         [ 2.2210, -1.2851, -0.0850,  ...,  0.1368,  0.7156, -0.1134],
│   │                         [ 1.5946, -0.5895,  1.3384,  ..., -0.9277, -0.9059, -0.6265],
│   │                         [-0.3297, -0.6275,  0.7793,  ..., -0.6945,  0.3103, -1.2492]],
│   │               
│   │                        [[-0.2230,  0.0489,  0.0393,  ...,  0.2449,  0.1952,  0.7578],
│   │                         [ 0.1796, -0.5396,  0.8207,  ...,  0.7355,  0.9064,  0.6463],
│   │                         [-0.4505,  0.3988,  1.2850,  ...,  1.1552,  2.1312, -0.2536],
│   │                         ...,
│   │                         [-0.1777, -0.2812,  0.7973,  ..., -0.6302, -1.1515, -1.4901],
│   │                         [-0.0111, -0.2463,  0.8124,  ..., -1.8334,  0.3945,  0.8603],
│   │                         [-1.0702,  0.8764, -1.3327,  ..., -0.9637,  0.3337, -0.7143]]],
│   │               
│   │               
│   │                       [[[ 1.3995,  0.0549,  0.4937,  ..., -0.8713, -0.4656,  0.0158],
│   │                         [-0.4180, -0.4157, -0.4975,  ...,  1.3634, -1.2806, -0.1712],
│   │                         [ 1.0339,  0.0878, -1.7462,  ...,  0.3000,  0.9082,  1.1016],
│   │                         ...,
│   │                         [-1.4643,  0.6138,  0.2582,  ..., -0.4585,  1.6768,  1.1729],
│   │                         [-0.0869,  0.6846, -0.1677,  ...,  0.7271,  1.5834, -0.3425],
│   │                         [ 0.4731,  2.0381, -0.2106,  ...,  1.0306, -0.5247, -2.0794]],
│   │               
│   │                        [[-0.1023,  1.0517, -0.0158,  ...,  0.9636,  1.1394, -0.3940],
│   │                         [ 1.5903,  1.7189,  2.3240,  ...,  0.0594, -0.5286, -0.2816],
│   │                         [ 0.4443,  0.1537, -0.4325,  ...,  0.1363,  0.8305,  0.7353],
│   │                         ...,
│   │                         [-0.9877,  0.7616, -0.1951,  ...,  1.6501, -0.7489, -0.6057],
│   │                         [ 0.0143,  0.9955,  0.1296,  ...,  1.8576,  0.1749, -0.8439],
│   │                         [ 0.3082,  1.4621, -0.4100,  ..., -1.2227,  0.0718,  0.9295]],
│   │               
│   │                        [[ 0.7757, -0.5048, -0.6325,  ...,  0.6356,  0.5520,  0.8157],
│   │                         [-0.0715, -1.4107,  0.6195,  ..., -2.1127,  1.8454, -1.3855],
│   │                         [ 0.5594,  1.9252, -0.9331,  ...,  0.3402,  1.2449, -1.2256],
│   │                         ...,
│   │                         [ 0.5670, -0.8053,  1.0296,  ...,  0.1248,  0.0180,  0.2733],
│   │                         [-1.8360, -0.0119, -0.3739,  ...,  0.2331, -0.3593, -0.4119],
│   │                         [ 0.3682,  0.2333,  1.3177,  ..., -0.1743,  1.9016, -0.2039]]],
│   │               
│   │               
│   │                       [[[-0.2261,  1.0991, -0.0520,  ..., -0.9408, -1.6285,  0.3938],
│   │                         [ 0.3988, -0.5101,  0.4242,  ..., -1.2440, -0.7914,  0.3440],
│   │                         [ 0.7235, -0.3799, -0.3964,  ..., -1.3564,  0.5100, -0.1313],
│   │                         ...,
│   │                         [ 0.3838,  0.9837,  0.4385,  ..., -0.4829, -1.1187,  0.4702],
│   │                         [-0.0780, -0.5538, -0.6747,  ..., -0.6217,  0.2664,  0.1519],
│   │                         [ 1.3075, -0.9424, -0.7075,  ..., -1.0966, -1.1101, -0.4213]],
│   │               
│   │                        [[ 0.0664, -0.4343,  0.1032,  ..., -1.0542, -0.1192, -1.1612],
│   │                         [-1.0548, -0.3236,  0.6410,  ..., -1.1824,  0.3827, -0.4164],
│   │                         [ 0.2300,  0.6026, -0.1242,  ..., -0.0978, -1.5669,  0.4418],
│   │                         ...,
│   │                         [ 0.1801,  0.3477,  0.2565,  ..., -0.9460,  2.3692, -0.6397],
│   │                         [ 0.5199,  1.0521,  0.0844,  ..., -0.6008,  0.2239, -0.8769],
│   │                         [ 0.8122, -1.5200,  1.6122,  ..., -2.1596,  0.2823, -1.7134]],
│   │               
│   │                        [[-0.1015,  0.3743, -0.0067,  ...,  1.1639,  0.6132,  0.2452],
│   │                         [-0.3569, -0.5615,  0.3039,  ..., -0.9395, -0.0672, -0.0299],
│   │                         [ 0.3492, -0.5049, -0.0379,  ..., -0.7586,  1.0934,  0.1140],
│   │                         ...,
│   │                         [ 0.3273,  0.0063, -1.6045,  ..., -0.2299, -0.4463, -0.3471],
│   │                         [ 0.5456, -0.8327,  1.1277,  ..., -0.6565,  1.3670,  1.1586],
│   │                         [-0.6851, -1.4420,  1.4105,  ..., -1.2554, -0.4006,  0.6563]]],
│   │               
│   │               
│   │                       [[[ 0.1247, -0.1976, -0.1188,  ...,  0.0451, -1.4528, -0.9902],
│   │                         [ 0.8898, -1.9705, -0.3585,  ..., -0.4237, -0.2914, -0.1517],
│   │                         [-0.2534, -1.0738, -0.4903,  ..., -0.4245,  0.6223,  0.1729],
│   │                         ...,
│   │                         [ 0.9523, -1.3623, -0.1290,  ...,  0.8980, -0.2984,  1.1681],
│   │                         [ 0.7917, -1.6211, -1.2120,  ...,  0.3611, -0.2688, -0.4129],
│   │                         [-1.0221,  0.5387,  0.6960,  ..., -1.0993, -0.8395, -1.8821]],
│   │               
│   │                        [[-0.7791,  0.0362,  0.2613,  ..., -0.0601, -0.6577, -0.6591],
│   │                         [ 0.4978,  0.6533, -0.1684,  ...,  0.4970,  0.2426,  1.5893],
│   │                         [ 0.8312,  1.5447,  1.0367,  ..., -1.3944, -1.6834, -0.2324],
│   │                         ...,
│   │                         [-0.6982, -0.6749, -0.1744,  ...,  0.3756,  0.3420, -0.7598],
│   │                         [ 0.3477, -0.1874,  0.1860,  ...,  1.1443,  1.1687, -0.5945],
│   │                         [ 1.1584,  0.0895,  2.3563,  ..., -0.5935, -1.2246, -0.5154]],
│   │               
│   │                        [[-0.2248,  0.4800,  0.8997,  ...,  1.2539,  0.7265, -2.7110],
│   │                         [-0.9289,  1.4271, -1.1098,  ...,  1.2831,  1.2386,  0.7930],
│   │                         [-1.5400, -1.6965,  0.8993,  ..., -0.0736, -0.7558,  2.8332],
│   │                         ...,
│   │                         [ 0.8784,  1.4417,  1.0031,  ..., -0.1161, -0.1153, -0.0988],
│   │                         [-1.3206,  0.6437, -0.4315,  ..., -1.3464, -1.4482,  2.2152],
│   │                         [-0.4652,  0.1442, -0.0260,  ...,  0.1855,  2.4188, -0.3384]]]])
│   └── 'scalar' --> tensor([[-1.9775,  1.2561, -1.8104, -0.0423, -1.6107,  0.0803, -1.2937, -0.7758,
│                             -1.4195,  1.1695, -0.3743,  1.9428],
│                            [-0.8542,  0.8703, -1.7483,  0.2176, -1.1914,  0.6103, -0.3354, -0.5292,
│                             -0.4301, -0.5459, -3.4956,  0.4404],
│                            [ 1.5109, -0.9939,  0.4760,  0.4251, -0.0542, -0.5597,  1.9502,  1.0798,
│                              0.0862, -0.8159,  1.1634,  0.2715],
│                            [-1.4832, -1.2423, -0.0416,  1.1550, -0.2121,  0.5468,  0.6032, -1.4731,
│                              0.3059, -0.8361, -0.3137,  0.7426]])
└── 'reward' --> tensor([[0.9880],
                         [0.1117],
                         [0.4755],
                         [0.4468]])

This code looks much simpler and clearer.