Stack Structured Data

When tensors are organized in a tree structure, they often need to be stacked together, just as torch.stack() does for plain tensors.

Stack With Native PyTorch API

Here is a common implementation that uses only the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

# Batch size: the number of samples that are generated and then stacked.
B = 4


def get_item():
    """Build one randomly generated sample as a nested dict.

    The sample holds an ``'obs'`` sub-dict (a 12-element ``'scalar'`` tensor
    and a ``3 x 32 x 32`` ``'image'`` tensor), a one-element integer
    ``'action'`` drawn from ``[0, 10)``, a one-element ``'reward'`` drawn
    from ``[0, 1)``, and a ``'done'`` flag that is always ``False``.
    """
    # Draw the values in the same order as the keys are laid out, so that a
    # seeded run produces the same sample.
    scalar = torch.randn(12)
    image = torch.randn(3, 32, 32)
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)

    observation = {'scalar': scalar, 'image': image}
    return {'obs': observation, 'action': action, 'reward': reward, 'done': False}


# Generate B independent samples; each one is a nested dict built by get_item().
data = [get_item() for _ in range(B)]


# execute `stack` op
def stack(data, dim):
    """Recursively stack a list of identically structured samples along ``dim``.

    The type of the first element of ``data`` decides how the list is handled:

    * ``torch.Tensor`` -- the list is combined with ``torch.stack(data, dim)``.
    * ``dict`` -- each key of the first element is stacked recursively and a
      dict with the same keys is returned.
    * ``bool`` -- the list is converted with ``torch.BoolTensor`` (``dim`` is
      not used in this case).

    :param data: Non-empty list of samples sharing the same structure.
    :param dim: Dimension passed to ``torch.stack`` for tensor leaves.
    :return: The stacked tensor, or a dict of stacked values.
    :raises TypeError: If the first element is of any other type.
    """
    first = data[0]

    if isinstance(first, torch.Tensor):
        return torch.stack(data, dim)

    if isinstance(first, dict):
        stacked = {}
        for key in first.keys():
            # Gather the value stored under ``key`` in every sample, then recurse.
            column = [sample[key] for sample in data]
            stacked[key] = stack(column, dim)
        return stacked

    if isinstance(first, bool):
        return torch.BoolTensor(data)

    raise TypeError("not support elem type: {}".format(type(first)))


# Stack the B samples along a new leading dimension (dim=0).
stacked_data = stack(data, dim=0)
# validate
print(stacked_data)
# Every tensor leaf gains a leading dimension of size B.
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
# The Python bool flags are collected into a 1-D boolean tensor of length B.
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertion statements should pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
{'obs': {'scalar': tensor([[-6.1698e-01, -3.5911e-02, -8.4162e-01, -1.0243e+00, -3.0843e-01,
          4.9491e-01,  2.4833e-03,  6.2664e-01,  1.6798e+00, -1.2389e+00,
         -5.3085e-02, -1.7674e+00],
        [-2.1730e-01,  3.4035e-02,  9.1718e-02,  3.2348e-01, -4.6191e-01,
         -6.4095e-01, -2.0785e+00,  7.3053e-01,  6.4138e-01,  4.4964e-01,
         -6.0121e-01,  1.3203e+00],
        [ 7.4373e-02,  2.0674e+00,  2.2354e+00,  3.9442e-02, -4.0130e-01,
          4.7402e-01,  1.3394e+00, -2.4323e-01, -7.0623e-02,  3.2360e-01,
         -5.9731e-01, -1.8476e+00],
        [-1.0765e+00,  1.4032e+00, -7.4266e-01,  4.2113e-01,  1.0456e+00,
         -2.7504e+00, -3.9172e-01, -1.4856e+00,  1.3192e+00, -3.1867e-01,
         -1.1507e+00, -1.2567e+00]]), 'image': tensor([[[[ 6.6362e-02, -3.0005e+00,  3.5918e-01,  ...,  3.9558e-01,
            1.2657e+00, -4.8891e-01],
          [-6.4898e-01, -5.4758e-02,  3.9456e-01,  ..., -1.3991e+00,
            9.0560e-01, -2.4295e-01],
          [ 1.6880e+00, -1.1322e+00, -3.8597e-01,  ...,  9.0479e-01,
            6.8429e-01,  7.2944e-02],
          ...,
          [ 5.8127e-01,  9.3033e-02,  2.2347e-01,  ...,  5.0106e-01,
           -4.4185e-02,  9.6664e-01],
          [ 2.6107e+00,  4.0964e-01, -5.8772e-01,  ..., -3.2662e-01,
            1.3588e+00, -2.3649e-01],
          [-1.5033e+00, -6.2906e-02,  9.6726e-01,  ..., -2.7258e-01,
            2.6832e-01,  6.2866e-01]],

         [[ 1.5890e+00,  1.4966e+00,  2.2690e-02,  ..., -4.2269e-01,
           -5.0476e-01, -2.0276e+00],
          [-1.4468e-01,  1.8337e+00,  3.7504e-02,  ...,  7.1109e-01,
           -3.3176e-01, -9.2647e-01],
          [ 5.3248e-01, -5.5462e-01,  2.5544e-01,  ...,  5.5749e-01,
           -6.7063e-01, -4.8085e-01],
          ...,
          [ 1.7303e+00,  1.4727e+00,  2.1973e+00,  ..., -1.1379e+00,
            9.3738e-01, -9.2444e-01],
          [-1.0749e-01, -6.2012e-03,  3.0809e-01,  ...,  1.6250e-01,
           -1.0119e+00, -5.2481e-01],
          [ 3.1633e-01, -8.9269e-01,  2.7168e-01,  ..., -1.3890e+00,
            1.6825e+00,  2.8444e+00]],

         [[ 1.3046e+00,  1.1157e+00, -1.5650e+00,  ..., -1.9687e+00,
            6.3964e-01,  1.7618e+00],
          [ 9.0319e-01,  1.6140e+00,  2.1127e-01,  ..., -2.9324e-01,
            3.6012e-01, -5.6701e-01],
          [-4.5534e-01, -7.7589e-01,  6.1885e-01,  ..., -1.4376e+00,
           -3.3148e-01,  2.4537e-02],
          ...,
          [ 2.8042e-01, -7.2499e-01, -2.1023e+00,  ...,  7.4090e-01,
           -1.7050e-01,  2.5075e-01],
          [ 5.5005e-01, -2.3024e-01,  6.4926e-01,  ..., -7.6544e-01,
           -1.2140e+00,  1.6773e-01],
          [-5.6928e-01, -7.4548e-01,  4.0451e-01,  ...,  1.6896e+00,
            5.1373e-02, -1.8615e-02]]],


        [[[ 2.4119e+00,  3.4386e-01,  7.9097e-01,  ...,  9.0631e-01,
            3.2383e-01, -7.8016e-01],
          [ 3.2924e-01, -9.5325e-01,  3.2414e-02,  ...,  8.4106e-02,
            6.7480e-01,  8.9660e-01],
          [ 3.7338e-01,  2.8451e-01, -4.3049e-01,  ...,  7.0320e-01,
           -8.3111e-01,  4.3689e-01],
          ...,
          [ 2.8916e-02,  1.2719e+00, -1.8477e+00,  ...,  3.5349e-01,
            5.0706e-02,  1.0641e+00],
          [ 7.0002e-01,  3.3319e-01, -4.6249e-01,  ..., -1.6252e-01,
            3.5560e-01,  6.6109e-01],
          [ 7.3627e-01,  1.6854e-01,  1.5095e+00,  ..., -1.5956e+00,
            1.2041e+00,  1.6485e-02]],

         [[-1.9621e-02, -1.2196e+00,  7.0793e-01,  ...,  1.1383e-01,
            1.4532e+00,  2.0531e-02],
          [-2.7061e+00, -1.7879e+00,  6.8884e-01,  ..., -1.0964e+00,
            1.0705e-01, -7.2859e-01],
          [ 1.6323e+00,  1.4051e-01,  8.7622e-02,  ...,  1.1498e+00,
            9.9908e-01,  4.6945e-01],
          ...,
          [ 7.1348e-03,  1.4718e+00, -1.3173e-01,  ...,  1.8705e-01,
           -1.7517e+00,  5.0507e-01],
          [ 4.7941e-01, -2.1836e+00,  1.5491e+00,  ...,  1.5823e-01,
            5.7929e-01, -7.6027e-01],
          [ 1.2470e+00,  5.2481e-01, -2.0458e-02,  ..., -2.7197e-01,
            1.0157e+00, -1.0088e+00]],

         [[-3.7606e-01,  8.3414e-01, -6.3632e-01,  ...,  5.6328e-01,
           -6.0192e-01,  8.6260e-01],
          [ 3.3924e-01, -6.9107e-01, -1.6546e-01,  ..., -3.8606e-01,
           -4.9872e-01,  5.7626e-01],
          [-1.4864e+00, -1.3924e+00,  5.9965e-01,  ..., -1.3341e+00,
            7.7526e-01,  3.6241e-02],
          ...,
          [-4.2669e-01, -7.7961e-03, -8.2440e-01,  ...,  1.9363e+00,
            4.2572e-01, -1.0316e+00],
          [ 6.2787e-01, -5.5919e-01,  2.7254e-02,  ...,  2.9808e-01,
           -2.6985e-01,  5.9688e-01],
          [ 7.8022e-02, -9.5589e-01,  2.1464e+00,  ..., -1.5262e+00,
           -6.6489e-01, -6.2458e-01]]],


        [[[-6.0042e-01, -4.8636e-01,  1.4185e-01,  ...,  1.4224e+00,
            1.8828e-01, -1.2927e-01],
          [-6.2224e-01, -7.6153e-01, -2.2951e-01,  ..., -1.0307e+00,
            2.5702e-02,  4.0326e-02],
          [ 7.0173e-01, -3.9784e-01, -8.5880e-01,  ...,  1.2130e+00,
           -4.8769e-01,  9.1310e-01],
          ...,
          [ 4.8425e-01, -2.5909e-01, -3.4612e-01,  ..., -5.2100e-01,
            8.2617e-01, -6.5741e-01],
          [-2.8116e-02, -1.3599e-01,  2.9853e+00,  ..., -7.3271e-01,
            1.0578e-01, -1.3634e-02],
          [ 1.2401e+00,  1.0661e+00, -4.5835e-01,  ...,  9.3402e-01,
            2.9095e-01,  1.1106e+00]],

         [[-1.6004e+00,  4.2991e-01, -2.4186e+00,  ..., -2.8888e-01,
            3.7330e-01, -1.0260e-01],
          [-2.8052e-01, -8.7444e-01,  2.0940e+00,  ...,  2.4803e-01,
           -2.4667e-01, -1.0779e+00],
          [ 4.5746e-01,  1.0822e+00, -1.1474e+00,  ...,  5.4813e-02,
           -1.4129e-01, -5.7182e-01],
          ...,
          [-8.4463e-01, -5.4933e-01, -1.4725e-03,  ...,  1.1621e+00,
           -1.7123e+00,  6.9042e-01],
          [-2.1078e-01,  6.4637e-01, -8.4442e-02,  ...,  2.0503e-01,
            6.4542e-01, -8.1534e-01],
          [-2.9478e-01,  3.8252e-01,  1.1832e+00,  ...,  4.0255e-01,
           -2.1030e-01,  1.1138e+00]],

         [[ 5.3282e-01, -2.9339e-01,  1.0403e+00,  ..., -1.0782e+00,
           -1.0416e+00,  1.0955e+00],
          [-1.4453e+00, -2.1396e+00, -6.5487e-01,  ..., -1.5739e+00,
           -6.2571e-02, -9.2199e-04],
          [-6.3822e-01, -1.8271e-01, -9.8571e-01,  ..., -1.4647e-01,
           -8.6111e-01,  5.2302e-01],
          ...,
          [-6.3278e-01,  1.5656e-01, -6.4245e-01,  ..., -5.2832e-01,
            4.3505e-01,  2.8483e-01],
          [-9.5317e-01,  6.4241e-01, -2.5192e+00,  ..., -2.5656e-01,
            2.7737e-01, -8.6197e-01],
          [ 1.4951e-01,  5.6822e-01, -8.0139e-01,  ...,  4.1179e-01,
           -3.3907e-01,  9.9490e-01]]],


        [[[ 2.4886e+00,  4.9544e-01,  1.0890e+00,  ...,  5.1683e-01,
           -7.2424e-01, -5.3941e-01],
          [-6.2733e-01, -1.2096e+00, -1.1121e+00,  ...,  1.4382e-01,
           -7.7808e-01,  1.8056e+00],
          [ 1.1223e-01,  4.9913e-01,  2.1522e-01,  ...,  3.8553e-01,
           -1.6511e+00,  9.2808e-01],
          ...,
          [-5.3661e-01,  1.8866e+00,  1.6355e+00,  ..., -1.5545e+00,
           -4.8706e-01,  6.2231e-01],
          [ 4.4359e-01,  6.1325e-01, -5.7754e-01,  ..., -1.4999e+00,
           -2.8484e-02,  1.2325e+00],
          [ 2.6683e-02,  3.1068e-01,  1.2345e-01,  ..., -5.5037e-02,
            8.0377e-01,  1.0676e+00]],

         [[ 2.4074e-01,  9.3555e-01,  2.0545e+00,  ..., -1.7194e+00,
           -8.8702e-01, -6.5643e-02],
          [-1.5555e+00,  1.2204e-01, -4.3912e-01,  ...,  4.1040e-01,
           -8.5497e-01, -1.8267e+00],
          [-6.3392e-01, -6.0568e-01,  7.7994e-01,  ..., -5.3061e-01,
           -5.3385e-01, -1.1861e+00],
          ...,
          [ 2.5485e-01, -6.1928e-01, -4.4060e-01,  ..., -1.8528e+00,
            1.6112e+00, -1.5155e+00],
          [-7.8599e-02,  9.2520e-01, -8.9229e-01,  ...,  4.1619e-02,
           -2.6995e-01,  1.1528e-01],
          [-3.7908e-01,  1.4178e+00, -4.7899e-01,  ...,  2.5579e-01,
            3.3984e-01,  4.0058e-01]],

         [[-1.2469e-01, -1.0694e+00, -1.3426e+00,  ...,  8.8821e-01,
            5.6551e-01,  7.6151e-01],
          [ 8.2287e-01, -1.5267e+00, -8.2538e-01,  ..., -7.5549e-01,
           -5.9792e-01, -8.3010e-01],
          [ 1.7147e+00,  8.5787e-01,  9.1239e-01,  ..., -1.5758e-01,
            7.0117e-01, -8.9143e-01],
          ...,
          [-9.0077e-01,  4.0928e-01,  8.2077e-01,  ...,  2.0920e+00,
           -5.8724e-01,  1.9528e+00],
          [-8.5304e-01,  1.7340e+00,  1.1009e+00,  ...,  5.1343e-01,
           -2.5099e-01,  2.0193e-01],
          [-1.1386e-01,  1.4287e+00,  2.3030e-01,  ..., -5.3843e-02,
           -6.2613e-01, -8.8357e-01]]]])}, 'action': tensor([[6],
        [5],
        [8],
        [7]]), 'reward': tensor([[0.0762],
        [0.0556],
        [0.5972],
        [0.0720]]), 'done': tensor([False, False, False, False])}

We can see that the function for processing the structure (here, the function stack) has to be implemented entirely by hand. The resulting code is not very clear, and because the supported types are hard-coded, supporting more data types (such as integers) requires modifying the function itself.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown in the code below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

# Batch size: the number of samples that are generated and then stacked.
B = 4


def get_item():
    """Build one randomly generated sample as a nested dict.

    The sample holds an ``'obs'`` sub-dict (a 12-element ``'scalar'`` tensor
    and a ``3 x 32 x 32`` ``'image'`` tensor), a one-element integer
    ``'action'`` drawn from ``[0, 10)``, a one-element ``'reward'`` drawn
    from ``[0, 1)``, and a ``'done'`` flag that is always ``False``.
    """
    # Draw the values in the same order as the keys are laid out, so that a
    # seeded run produces the same sample.
    scalar = torch.randn(12)
    image = torch.randn(3, 32, 32)
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)

    observation = {'scalar': scalar, 'image': image}
    return {'obs': observation, 'action': action, 'reward': reward, 'done': False}


# Generate B independent samples; each one is a nested dict built by get_item().
data = [get_item() for _ in range(B)]
# execute `stack` op
# Each nested dict is first passed through ttorch.tensor; the resulting list
# is then handed to ttorch.stack in a single call.
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# validate
print(stacked_data)
# Nested fields are reached by attribute access here rather than by dict key.
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertion statements should pass here as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
<Tensor 0x7fd8fdb6e190>
├── 'action' --> tensor([[0],
│                        [9],
│                        [9],
│                        [6]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7fd8fdb6e2b0>
│   ├── 'image' --> tensor([[[[ 1.4840, -0.1060,  0.3117,  ...,  1.3222,  0.0554, -0.7188],
│   │                         [-0.3793,  1.4982, -0.9224,  ..., -0.1620,  0.2350,  0.8966],
│   │                         [ 0.2379, -0.3581,  2.6328,  ...,  0.5986,  0.9841, -0.0110],
│   │                         ...,
│   │                         [-1.4120,  0.2829, -0.1942,  ..., -0.5791,  0.9103, -0.0116],
│   │                         [ 1.3092, -0.0546, -0.3074,  ...,  1.8787,  2.1366, -0.0036],
│   │                         [-0.3800,  0.6551, -0.5650,  ...,  1.4795, -0.7699,  0.5425]],
│   │               
│   │                        [[ 0.4516, -1.4252, -0.6245,  ..., -0.3018,  1.1203, -0.8574],
│   │                         [ 1.5125, -1.3168, -0.2520,  ...,  0.4751, -0.3609,  1.0180],
│   │                         [-0.8998,  0.7666,  0.0407,  ..., -0.6801, -0.9916, -0.5293],
│   │                         ...,
│   │                         [ 0.5978, -0.4539,  0.1810,  ...,  1.9599, -2.1043,  0.6454],
│   │                         [ 0.6328,  0.4073, -1.1529,  ...,  0.7862, -0.2449,  0.9037],
│   │                         [ 1.7697,  1.6279,  1.7602,  ...,  0.3091, -0.3448, -0.8423]],
│   │               
│   │                        [[ 0.5599,  0.5356, -0.1288,  ...,  0.5111, -0.8478,  1.9932],
│   │                         [ 0.7565, -0.7621,  0.5595,  ...,  0.9300, -0.4512,  0.5594],
│   │                         [ 0.6076, -0.3534,  0.5591,  ...,  1.2942, -0.9395, -0.1464],
│   │                         ...,
│   │                         [ 0.5563,  0.0891,  0.3135,  ...,  0.8984, -1.4368, -0.4166],
│   │                         [ 0.0626,  0.0841,  0.8749,  ...,  0.8528,  1.5602,  0.5588],
│   │                         [ 0.1554, -1.5173,  0.6627,  ...,  0.8934,  0.3321, -1.2965]]],
│   │               
│   │               
│   │                       [[[-0.4311,  0.0182,  0.2486,  ...,  1.0453,  0.7432, -0.2991],
│   │                         [-0.0233,  0.6602, -1.1078,  ...,  0.0986,  0.8260, -0.2859],
│   │                         [-1.5236,  1.3943,  0.9016,  ..., -1.1093, -0.2815, -1.5809],
│   │                         ...,
│   │                         [ 0.4585,  1.5466,  0.0876,  ...,  0.2054,  0.2216,  1.1908],
│   │                         [ 0.6930,  2.2302, -0.5919,  ..., -1.1132, -0.3055,  1.2813],
│   │                         [ 0.3507, -0.6591, -0.0622,  ..., -1.3133, -1.3663,  0.7370]],
│   │               
│   │                        [[ 0.2953,  0.9746,  0.9612,  ...,  1.0568, -0.9349, -0.0516],
│   │                         [ 0.0717,  0.6003, -1.9908,  ...,  0.1633,  0.6292,  0.1560],
│   │                         [-1.2875, -0.3682,  0.5923,  ..., -1.1560,  0.1348,  1.3057],
│   │                         ...,
│   │                         [ 0.2032,  0.7076, -0.2002,  ..., -1.6046, -0.1308,  0.6110],
│   │                         [-0.1244,  0.6353, -0.4262,  ..., -0.5089, -0.1380,  0.1824],
│   │                         [ 0.3651,  0.3783, -1.3058,  ..., -1.2695, -0.2346,  0.6045]],
│   │               
│   │                        [[-0.9514,  1.0170, -0.4881,  ..., -0.4963,  0.6557, -1.2997],
│   │                         [ 0.6462, -1.6608,  0.9774,  ...,  0.4129,  0.2262, -0.1699],
│   │                         [-0.5496,  0.2942,  0.9288,  ..., -0.5061,  1.6713, -0.8267],
│   │                         ...,
│   │                         [-0.1019,  1.0044,  0.6310,  ..., -0.0954,  1.0305,  0.3730],
│   │                         [-0.7762, -1.7707, -0.0948,  ...,  0.3849, -0.9199, -0.0364],
│   │                         [ 1.1701, -1.0132,  0.1776,  ...,  1.0082, -1.4033,  0.9446]]],
│   │               
│   │               
│   │                       [[[-0.5742, -0.8322,  1.7716,  ..., -0.5219,  0.6629,  0.3064],
│   │                         [ 0.1943, -1.6901, -0.6748,  ..., -2.8692, -0.6480,  0.1384],
│   │                         [-1.1884, -0.8971,  1.9299,  ...,  0.2120,  1.1614,  0.1220],
│   │                         ...,
│   │                         [-1.8986,  0.3338,  0.5753,  ..., -1.4631,  0.3468,  1.3235],
│   │                         [ 0.2218,  2.5386, -1.6293,  ..., -0.8123,  1.8338,  0.1920],
│   │                         [ 0.7684,  0.0183, -0.2705,  ..., -2.6146, -0.0550, -0.0383]],
│   │               
│   │                        [[-0.6213,  0.3122,  0.0265,  ...,  0.4419, -0.0366, -0.3377],
│   │                         [ 0.4183,  0.0469, -1.1002,  ..., -0.6085, -0.5245,  0.2417],
│   │                         [-0.6207,  1.3336,  0.6474,  ...,  1.2527,  0.0943, -0.0990],
│   │                         ...,
│   │                         [-1.5437, -0.4618, -0.4815,  ...,  1.5831, -1.4805,  0.4847],
│   │                         [-1.0478, -1.0128,  0.8363,  ...,  1.2779, -0.6283, -0.8740],
│   │                         [ 0.1765,  1.0445, -0.5841,  ...,  1.2866,  0.2932, -1.7016]],
│   │               
│   │                        [[ 1.6078, -0.2085, -1.2276,  ..., -0.1012, -0.2058, -0.3348],
│   │                         [ 0.8424, -0.7256, -0.2131,  ..., -1.6127, -0.6664,  0.1647],
│   │                         [-0.7344,  1.0712,  2.4946,  ..., -0.7619,  1.9686,  0.8134],
│   │                         ...,
│   │                         [-1.4347, -0.1636,  0.1911,  ...,  1.2704,  0.4908, -0.3364],
│   │                         [-0.7793,  0.2824, -0.3574,  ..., -1.1836,  0.6829,  0.3524],
│   │                         [ 1.3773,  1.4154,  0.7471,  ...,  0.3017,  0.7461, -0.2283]]],
│   │               
│   │               
│   │                       [[[ 0.1233,  0.4324,  0.0315,  ..., -0.4253, -0.0553,  2.5347],
│   │                         [-0.7532, -1.7244, -0.5931,  ...,  0.3755,  0.7378,  0.5857],
│   │                         [-0.8302, -0.2902, -1.6511,  ...,  0.3430,  0.1865,  1.2404],
│   │                         ...,
│   │                         [ 1.0207, -0.0122, -0.0082,  ..., -1.4070, -0.6281, -0.8955],
│   │                         [ 0.1467, -0.0616,  2.1614,  ..., -0.3725,  0.1029,  2.7000],
│   │                         [ 0.2139,  1.1587, -0.4425,  ...,  0.3872,  0.0879, -0.7829]],
│   │               
│   │                        [[-0.6936,  1.1570,  0.8572,  ..., -0.9941,  0.0824,  0.1595],
│   │                         [-2.1201, -1.1531, -0.9221,  ..., -1.3543,  0.7456, -0.4005],
│   │                         [ 0.7869,  0.5606,  1.0926,  ..., -2.4980,  0.4588,  0.7356],
│   │                         ...,
│   │                         [ 0.6115,  1.4341, -1.5251,  ..., -0.8367,  0.6509,  1.3660],
│   │                         [-1.1145,  0.2086, -0.2628,  ..., -0.0346, -1.5826,  1.8712],
│   │                         [-0.0666,  0.7066, -1.3453,  ...,  1.3645, -1.1624, -1.0078]],
│   │               
│   │                        [[-0.8399,  0.0598,  0.8255,  ..., -0.9895,  1.4253,  1.2686],
│   │                         [ 0.6406,  1.4451,  0.2393,  ...,  0.9703,  0.6170,  0.9758],
│   │                         [-0.3409,  1.3924, -0.0298,  ..., -0.8605, -0.1923, -0.9129],
│   │                         ...,
│   │                         [ 0.1663,  0.6345,  1.1979,  ..., -0.4291, -0.9687, -0.4037],
│   │                         [-0.2610, -1.1427, -0.4009,  ..., -0.2829,  0.4579,  0.7632],
│   │                         [ 0.0179,  0.0868,  0.1022,  ...,  0.8387,  0.1201, -0.8609]]]])
│   └── 'scalar' --> tensor([[-0.4414, -0.6400, -1.8005, -0.1124, -0.4949,  0.3936,  1.4571, -1.4940,
│                              0.5947, -0.2943, -0.1531,  1.1663],
│                            [ 2.0924, -0.7180, -1.4545, -0.6236, -0.1886,  1.2371, -0.2248,  0.3237,
│                             -1.3629,  1.3315,  1.8779, -1.0242],
│                            [-0.5884,  0.0132, -2.1677,  0.5975,  0.6938, -0.0754, -0.6058,  1.5750,
│                             -0.0597, -1.4768, -0.4550,  0.8848],
│                            [ 0.4093,  0.6826,  0.2254,  0.9126, -1.3075,  1.5418,  0.2922, -0.1208,
│                              0.2112, -0.1918, -0.0325, -0.4335]])
└── 'reward' --> tensor([[0.1347],
                         [0.3628],
                         [0.5227],
                         [0.3399]])

This code looks much simpler and clearer.