Stack Structured Data

When tensors are organized into tree structures, they often need to be stacked together, just as `torch.stack()` does for plain tensors in PyTorch.

Stack With Native PyTorch API

Here is a typical implementation using the native PyTorch API.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

B = 4  # number of samples generated and then stacked (the batch size)


def get_item():
    """Create one random sample with a fixed nested structure.

    Returns:
        A dict with keys ``'obs'`` (a dict holding a 12-element ``'scalar'``
        tensor and a ``(3, 32, 32)`` ``'image'`` tensor), ``'action'``
        (1-element integer tensor drawn from ``[0, 10)``), ``'reward'``
        (1-element float tensor drawn from ``[0, 1)``) and ``'done'``
        (the plain Python bool ``False``).
    """
    observation = dict(
        scalar=torch.randn(12),
        image=torch.randn(3, 32, 32),
    )
    return dict(
        obs=observation,
        action=torch.randint(0, 10, size=(1,)),
        reward=torch.rand(1),
        done=False,
    )


data = [get_item() for _sample in range(B)]


# Define a recursive `stack` op for nested dicts of tensors / bools.
def stack(data, dim):
    """Stack a list of identically structured items into one batched item.

    Tensor leaves are combined with ``torch.stack`` along ``dim``. Dicts are
    processed key by key, using the keys of the first item. Plain ``bool``
    leaves are collected into a 1-D ``torch.bool`` tensor.

    Args:
        data: Non-empty list of items that all share the same structure.
        dim: Dimension along which tensor leaves are stacked. It is not used
            for ``bool`` leaves, which always produce a 1-D tensor.

    Returns:
        A tensor, or a dict mirroring the structure of ``data[0]``.

    Raises:
        ValueError: If ``data`` is empty.
        TypeError: If a leaf has an unsupported type.
    """
    if not data:
        # Fail with a clear message instead of an IndexError on data[0].
        raise ValueError("cannot stack an empty list")
    elem = data[0]
    if isinstance(elem, torch.Tensor):
        return torch.stack(data, dim)
    elif isinstance(elem, dict):
        return {k: stack([item[k] for item in data], dim) for k in elem}
    elif isinstance(elem, bool):
        # torch.tensor is the recommended factory; legacy type constructors
        # such as torch.BoolTensor are discouraged.
        return torch.tensor(data, dtype=torch.bool)
    else:
        raise TypeError("not support elem type: {}".format(type(elem)))


stacked_data = stack(data, dim=0)
# Validate the batched shape (and dtype) of every leaf in the result.
print(stacked_data)
assert stacked_data['obs']['scalar'].shape == (B, 12)
assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
assert stacked_data['action'].shape == (B, 1)
assert stacked_data['reward'].shape == (B, 1)
assert stacked_data['done'].shape == (B,)
assert stacked_data['done'].dtype == torch.bool

The output should look like the following, and all of the assertions pass.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
{'obs': {'scalar': tensor([[-0.9113, -0.7209, -0.4847, -1.4918,  0.8335, -0.2019,  0.2615,  0.3216,
         -0.2142, -0.5677,  0.7933,  0.8666],
        [ 0.9760, -1.5132, -1.1468,  0.6075,  0.4654,  1.8258,  0.5017,  0.2355,
          0.5415,  2.3308, -0.6925,  0.5418],
        [-0.3199, -0.5272,  1.7276,  0.7962, -0.6678,  0.4108, -0.4179, -0.9046,
          0.2216, -1.0984, -1.9052, -0.7313],
        [-0.1924, -0.2541, -0.7395,  1.1319,  0.8176,  0.7279,  0.6416, -0.5540,
          1.4819,  0.7817,  1.4280, -0.8363]]), 'image': tensor([[[[-1.1964e+00, -8.6079e-01,  4.7092e-02,  ...,  1.2303e-01,
           -6.3324e-01, -9.6039e-01],
          [ 3.5278e-01, -1.0073e+00, -1.3316e+00,  ..., -4.4948e-01,
           -1.0528e+00, -6.2638e-01],
          [ 4.9949e-01,  1.2163e+00,  7.0224e-01,  ..., -1.2757e+00,
            6.0162e-02, -1.7536e-01],
          ...,
          [ 1.1617e+00,  1.0584e+00,  1.2937e+00,  ..., -9.5239e-05,
            1.1444e+00,  5.0917e-01],
          [ 2.1623e-01, -1.7004e+00, -5.1624e-02,  ...,  1.7221e-01,
            3.3747e+00, -6.5288e-01],
          [-4.4038e-01, -8.8284e-01, -8.6973e-01,  ...,  1.1058e+00,
            1.1596e-01, -1.6266e+00]],

         [[ 2.5965e-01, -2.0223e-01,  4.2455e-01,  ..., -1.2973e+00,
           -1.4218e+00, -4.4277e-01],
          [ 3.5503e-01, -2.8316e-01, -1.0177e-01,  ..., -1.4600e+00,
            2.3569e+00,  2.8268e-01],
          [-1.8721e-01,  4.2144e-02, -7.2326e-01,  ..., -1.1010e-01,
           -7.7564e-01, -6.7578e-01],
          ...,
          [-1.2661e+00, -1.8040e+00,  1.7180e+00,  ..., -6.5114e-02,
           -4.1953e-01, -1.2971e+00],
          [ 2.1322e+00, -1.5557e+00, -4.0394e-01,  ..., -4.5846e-01,
           -4.2880e-01,  1.1495e+00],
          [ 4.2218e-01,  1.4657e+00,  8.1626e-01,  ..., -6.0209e-01,
            4.2050e-01, -9.0197e-01]],

         [[-1.2321e+00,  1.2804e+00,  1.4743e+00,  ...,  1.2236e+00,
           -8.3963e-01, -7.4460e-01],
          [ 8.2249e-02, -8.3473e-01,  2.1426e-01,  ...,  7.8123e-01,
            1.6367e-01,  3.6201e-01],
          [ 1.7081e+00,  3.3161e-01, -1.3525e+00,  ...,  1.4539e+00,
           -1.0296e+00,  9.3551e-01],
          ...,
          [-2.1335e-01, -4.2061e-01, -1.7966e+00,  ...,  1.7179e-01,
            2.4941e+00, -1.5568e+00],
          [ 1.3207e+00,  1.7287e+00,  1.5876e+00,  ...,  1.0109e-01,
           -3.7381e-01,  2.5621e+00],
          [ 1.5598e-01,  5.6712e-01,  8.1285e-01,  ..., -1.1041e+00,
           -1.2515e+00, -1.4027e+00]]],


        [[[ 2.3766e-01, -2.8881e-01,  9.8879e-01,  ..., -6.5365e-01,
           -2.2780e-02,  3.3694e-01],
          [-1.0389e+00,  3.6671e-01,  2.2884e-01,  ...,  3.9363e-01,
            5.4940e-01, -2.3018e-01],
          [ 1.9703e-01, -2.2683e-01, -5.4376e-01,  ...,  3.8348e-01,
            1.8677e+00,  2.1625e-01],
          ...,
          [ 5.7010e-01,  1.0903e+00,  7.0785e-01,  ...,  2.5870e-01,
           -1.1420e+00, -7.2567e-01],
          [-1.1939e-01,  1.5118e+00, -1.1263e+00,  ...,  7.3601e-01,
            4.5913e-01, -1.3921e+00],
          [-1.0915e+00, -2.0509e-01, -5.3563e-01,  ..., -1.3924e-01,
            7.8354e-03,  1.6672e+00]],

         [[ 1.0841e+00,  1.3654e+00, -6.9407e-01,  ...,  1.9960e-01,
            1.9760e+00, -1.1395e+00],
          [-1.7497e+00,  8.9221e-02, -4.0902e-01,  ..., -9.4724e-01,
           -8.2431e-01,  2.5489e-01],
          [ 2.8065e-01,  6.7473e-01,  1.3273e+00,  ..., -2.5227e-01,
           -7.0834e-01,  9.8351e-01],
          ...,
          [ 5.4765e-01, -2.0284e-01,  7.6671e-01,  ...,  4.3138e-01,
            1.3686e+00,  1.9292e-01],
          [ 3.8161e-01, -6.0902e-01,  4.4155e-01,  ..., -3.6514e-01,
            8.5676e-02, -1.4406e+00],
          [-6.1829e-01,  2.3909e-01,  1.7460e+00,  ...,  1.5342e+00,
           -1.0532e+00,  5.9932e-01]],

         [[ 7.1244e-01,  1.2411e-01, -3.3333e-01,  ...,  3.5249e-01,
            9.9047e-01, -1.4281e+00],
          [-1.9010e-01, -7.0332e-02,  4.6186e-01,  ...,  8.4931e-01,
           -1.9607e+00,  7.1134e-01],
          [ 9.2714e-01, -6.8384e-01,  7.2820e-02,  ...,  1.8362e-01,
            3.3311e-01, -6.6974e-01],
          ...,
          [ 5.0792e-01,  4.0025e-01,  8.7194e-02,  ...,  5.7357e-01,
            4.3979e-01,  5.8732e-01],
          [-1.0336e+00, -3.9545e-01, -7.0326e-01,  ...,  1.2167e+00,
           -2.0329e+00, -1.2869e+00],
          [-6.4341e-01, -1.1833e+00, -5.1797e-01,  ...,  3.4462e-01,
            1.5828e+00,  1.9972e-02]]],


        [[[-1.0255e+00,  1.3340e+00, -7.3676e-01,  ...,  1.3319e+00,
            3.5739e-02,  2.2997e-01],
          [-8.7920e-01,  2.4783e-01, -4.8060e-01,  ..., -1.2910e+00,
            3.1002e-02,  1.4098e+00],
          [-8.5380e-02,  1.1558e+00, -1.6116e-01,  ..., -5.0077e-02,
           -4.0487e-01, -2.3421e-01],
          ...,
          [ 6.6562e-01,  4.8378e-01, -1.2141e-01,  ..., -1.3205e+00,
            1.0615e+00,  7.1050e-01],
          [-9.1371e-01, -4.8785e-01, -3.5342e-01,  ..., -2.3588e+00,
            1.7215e+00,  7.0626e-01],
          [ 8.7267e-01,  1.1375e+00, -7.9470e-01,  ..., -8.1569e-01,
            8.5635e-01,  3.1755e-02]],

         [[ 3.5999e-01,  7.7980e-01, -3.7833e-01,  ..., -3.6177e-01,
            1.0416e+00, -1.0043e-01],
          [ 1.0319e+00,  7.0434e-01, -8.3656e-02,  ..., -1.9342e+00,
           -6.5364e-01, -9.5450e-01],
          [ 8.0914e-02, -1.4190e+00,  1.4827e+00,  ...,  1.3368e+00,
           -1.7807e+00, -5.6781e-01],
          ...,
          [ 3.3535e-01,  8.3233e-01,  8.4417e-01,  ...,  8.8743e-01,
           -1.0978e+00,  1.6088e+00],
          [-1.5287e+00,  1.3854e+00, -1.0972e+00,  ..., -9.2300e-01,
            4.0743e-03, -2.3488e-01],
          [ 6.5995e-01, -1.1610e+00,  1.5864e-01,  ..., -1.1585e-01,
            3.8544e-01,  6.6391e-01]],

         [[ 5.7360e-01,  8.9739e-01, -1.9272e+00,  ...,  1.7101e+00,
           -4.2272e-01,  1.0572e+00],
          [ 1.1105e+00,  5.2765e-02,  4.5948e-01,  ...,  2.1639e-01,
            8.0196e-01, -1.0610e-01],
          [ 1.4433e-01,  1.2607e+00,  8.5110e-01,  ..., -3.2086e-01,
            9.8531e-01, -6.4746e-03],
          ...,
          [ 9.5911e-02, -2.1081e-01,  2.5407e-01,  ...,  2.3643e+00,
            5.4184e-01, -2.1505e+00],
          [-7.7973e-02,  1.5132e+00, -3.2005e-01,  ..., -7.8942e-01,
            1.0684e+00,  4.2773e-01],
          [ 1.7395e+00,  1.0566e+00,  1.9927e-01,  ...,  1.1217e+00,
           -1.8321e+00, -5.0340e-01]]],


        [[[-3.9652e-01,  3.4130e-01, -1.4673e+00,  ..., -6.9828e-01,
            2.3274e-01, -1.1805e+00],
          [ 1.0989e+00, -1.2300e+00, -1.6241e+00,  ..., -1.2917e-01,
            2.4630e-01,  7.0535e-01],
          [-2.0685e-01, -2.0882e-01,  2.4048e-01,  ...,  2.3500e-01,
            6.5637e-01,  8.8955e-01],
          ...,
          [-2.0183e-01,  3.3561e-01,  2.2136e+00,  ..., -1.2366e+00,
           -8.9954e-02,  2.7329e-01],
          [-3.9722e-01,  1.1728e+00,  1.4809e-01,  ..., -1.0439e+00,
            5.0509e-01,  2.4112e+00],
          [-6.3980e-01,  1.8956e+00, -7.0277e-01,  ...,  5.2495e-01,
           -1.5189e+00, -5.7212e-01]],

         [[-1.1504e+00,  9.6436e-01, -1.0619e+00,  ..., -1.4124e+00,
           -5.8424e-01,  1.1320e+00],
          [ 1.0901e+00,  1.8096e-02, -9.6452e-01,  ..., -6.6738e-01,
           -1.0322e+00, -1.6333e+00],
          [ 1.2873e+00, -1.2441e+00,  7.2931e-01,  ..., -1.2518e+00,
            5.1202e-01, -4.7914e-02],
          ...,
          [ 5.1793e-01, -6.1198e-01, -3.7510e-02,  ..., -1.2568e-02,
            1.2235e+00, -4.7580e-01],
          [-4.6653e-02,  9.8209e-02,  2.1920e-01,  ..., -1.2246e+00,
           -4.6448e-01, -7.8586e-01],
          [ 1.0001e+00, -3.5796e-01,  3.9165e-01,  ..., -1.2324e-01,
           -1.2457e+00,  7.6203e-01]],

         [[-6.5394e-01,  9.6436e-01, -2.8315e+00,  ...,  9.5748e-03,
           -4.4815e-01, -1.1895e+00],
          [ 1.1951e+00, -5.6323e-01,  5.8942e-01,  ..., -4.0217e-01,
           -8.6341e-01,  7.5368e-01],
          [ 2.7925e-01, -5.6985e-01,  3.8201e-02,  ...,  4.4162e-01,
           -8.9407e-01,  2.2242e+00],
          ...,
          [-1.0422e+00,  1.4222e-01,  5.2648e-01,  ...,  1.2441e+00,
           -2.6578e-01, -2.6100e+00],
          [-1.5528e+00,  1.1261e+00, -3.3701e-01,  ...,  1.0143e+00,
            7.5371e-01, -4.5958e-01],
          [ 2.5475e+00, -7.3839e-01,  6.9958e-01,  ...,  6.7692e-01,
           -4.0688e-01, -1.5574e-01]]]])}, 'action': tensor([[1],
        [7],
        [9],
        [9]]), 'reward': tensor([[0.4491],
        [0.0669],
        [0.6433],
        [0.7731]]), 'done': tensor([False, False, False, False])}

As shown above, the structure-processing function (here, `stack`) has to be implemented in full by hand. Such code is hard to read, and because the supported types are hard-coded, supporting additional data types (such as integers) requires modifying the function itself.

Stack With TreeTensor API

The same workflow can be implemented with the treetensor API, as shown below.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

import treetensor.torch as ttorch

B = 4  # how many samples are generated and then stacked (batch size)


def get_item():
    """Build a single random sample as a nested dict.

    Returns:
        A dict with ``'obs'`` (containing a 12-element ``'scalar'`` tensor
        and a ``(3, 32, 32)`` ``'image'`` tensor), ``'action'`` (1-element
        integer tensor in ``[0, 10)``), ``'reward'`` (1-element float tensor
        in ``[0, 1)``) and ``'done'`` (the Python bool ``False``).
    """
    # Draw the random values in the same order as the dict layout below.
    scalar = torch.randn(12)
    image = torch.randn(3, 32, 32)
    action = torch.randint(0, 10, size=(1,))
    reward = torch.rand(1)
    return {
        'obs': {'scalar': scalar, 'image': image},
        'action': action,
        'reward': reward,
        'done': False,
    }


data = [get_item() for _sample in range(B)]
# Convert each nested dict into a tree tensor, then stack them along dim 0.
data = [ttorch.tensor(d) for d in data]
stacked_data = ttorch.stack(data, dim=0)
# Validate the batched shape (and dtype) of every leaf in the result.
print(stacked_data)
assert stacked_data.obs.scalar.shape == (B, 12)
assert stacked_data.obs.image.shape == (B, 3, 32, 32)
assert stacked_data.action.shape == (B, 1)
assert stacked_data.reward.shape == (B, 1)
assert stacked_data.done.shape == (B,)
assert stacked_data.done.dtype == torch.bool

The output should look like the following, and all of the assertions pass as well.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
<Tensor 0x7fe40196e190>
├── 'action' --> tensor([[6],
│                        [0],
│                        [0],
│                        [8]])
├── 'done' --> tensor([False, False, False, False])
├── 'obs' --> <Tensor 0x7fe4025b3b20>
│   ├── 'image' --> tensor([[[[ 0.0267,  0.8343,  2.5198,  ...,  1.0571, -0.4453, -0.5944],
│   │                         [ 0.8863,  0.3701,  1.3998,  ...,  2.3309, -0.7624, -0.2236],
│   │                         [-0.4457, -0.7192, -0.0828,  ..., -1.0354,  1.1340,  0.0211],
│   │                         ...,
│   │                         [-0.5816, -0.6829, -0.3241,  ...,  0.1864,  0.1250, -0.3425],
│   │                         [ 0.9339, -0.1827, -0.1843,  ...,  2.3085, -2.0940,  1.3535],
│   │                         [ 0.4749,  1.7333, -0.1295,  ..., -0.2963,  0.8657, -0.2915]],
│   │               
│   │                        [[ 2.2788,  2.5834,  0.5397,  ...,  1.0994, -0.3483, -0.7041],
│   │                         [-1.2579, -0.8342,  1.0182,  ..., -1.3065,  1.0446,  0.5431],
│   │                         [-0.3019, -0.7245, -0.1647,  ..., -1.2206,  0.2576,  0.3957],
│   │                         ...,
│   │                         [-0.8776, -1.1805,  0.6373,  ..., -0.3053, -0.8648,  0.0935],
│   │                         [ 0.1497,  0.1045, -1.3512,  ...,  0.9487, -0.3691, -1.0121],
│   │                         [ 2.6098,  0.6113, -0.7116,  ...,  1.3492,  0.6084,  0.6611]],
│   │               
│   │                        [[-0.1536, -1.4904,  0.9155,  ...,  1.9811,  0.7838, -2.1625],
│   │                         [-0.2147,  0.7841, -0.4452,  ...,  2.2239,  0.5407, -1.6761],
│   │                         [ 0.3650,  0.4774, -0.4216,  ...,  1.6431, -1.4251, -0.1586],
│   │                         ...,
│   │                         [ 1.6534, -0.3216,  1.2252,  ..., -1.7739, -1.1805,  0.5154],
│   │                         [ 0.9687,  0.8533,  1.2503,  ..., -0.2928, -0.8879,  0.3960],
│   │                         [-0.7562,  1.6897,  0.2523,  ...,  0.1908, -0.7237, -0.9633]]],
│   │               
│   │               
│   │                       [[[-0.0237,  0.0276,  2.2183,  ..., -0.4149,  0.9596,  0.0898],
│   │                         [-0.7752,  1.1594,  1.1570,  ..., -0.5257,  0.0557,  0.3701],
│   │                         [-1.8050, -1.0791,  0.4008,  ..., -2.3522, -2.0305, -1.6548],
│   │                         ...,
│   │                         [-1.6219, -1.3961,  0.8476,  ...,  1.0411,  0.5228, -0.7115],
│   │                         [-0.2220, -0.8346, -0.5224,  ..., -0.2866, -0.4520, -0.8397],
│   │                         [-1.5436,  0.2160,  0.4307,  ..., -0.0243,  0.2860,  1.5847]],
│   │               
│   │                        [[ 0.2585, -2.9524, -1.1559,  ...,  1.2932,  0.1052,  0.6047],
│   │                         [ 1.2901,  0.5247, -0.3650,  ..., -3.2959,  0.3932,  0.2023],
│   │                         [ 1.7493,  1.4975, -2.0130,  ..., -1.5628,  0.7498,  0.9718],
│   │                         ...,
│   │                         [-1.1317,  0.1539,  0.0803,  ..., -0.4512, -1.4542, -1.0835],
│   │                         [ 0.6582, -0.3268, -0.5817,  ...,  1.6970,  1.4607, -0.9747],
│   │                         [-0.8706,  0.1163, -0.6912,  ..., -1.6270,  0.5980, -2.5715]],
│   │               
│   │                        [[ 0.4119,  0.0897, -1.1856,  ..., -0.3508,  0.1616,  0.1913],
│   │                         [ 1.0237,  0.6625, -1.4668,  ...,  0.4864, -0.3385,  0.2232],
│   │                         [-2.0009,  0.1625, -1.6798,  ..., -0.1154, -0.5619, -0.3741],
│   │                         ...,
│   │                         [-1.0302,  0.8595,  0.8145,  ..., -1.4556,  0.6917, -0.5305],
│   │                         [-0.4360,  1.5964,  0.7274,  ...,  1.6975,  0.7693,  0.4317],
│   │                         [-0.4281, -0.2501, -0.6060,  ..., -0.7333, -1.8474,  0.8055]]],
│   │               
│   │               
│   │                       [[[ 0.5985,  0.3883, -1.0900,  ...,  0.8271, -1.1860,  0.1572],
│   │                         [-0.1344, -1.0403,  1.0422,  ...,  0.1132, -1.2074,  1.2836],
│   │                         [ 0.6032,  1.7911, -0.5705,  ..., -0.5540,  1.8938, -0.9266],
│   │                         ...,
│   │                         [ 0.0629, -1.8127,  0.3975,  ..., -0.8243,  1.0682,  2.2695],
│   │                         [-1.0177,  0.3524,  0.3970,  ...,  0.2359, -0.7051,  0.6873],
│   │                         [-0.7503, -2.0267,  0.3791,  ..., -0.1665, -1.0114, -0.7592]],
│   │               
│   │                        [[-0.2016, -0.6799, -0.4761,  ...,  1.0312,  0.7132,  0.4491],
│   │                         [ 0.2942, -0.5316, -0.0127,  ...,  0.5509,  0.1028,  0.1102],
│   │                         [-0.1310,  2.4959,  0.6950,  ..., -1.0916,  0.4613,  0.2002],
│   │                         ...,
│   │                         [-0.8046, -2.1256, -0.0986,  ...,  0.8290,  1.0535, -0.8581],
│   │                         [ 1.6637,  1.4343,  0.2300,  ..., -1.3360,  2.0963, -1.6006],
│   │                         [-1.4208,  0.9017,  1.7077,  ..., -1.2383, -2.1597,  0.3052]],
│   │               
│   │                        [[-1.4921, -1.3696, -2.5126,  ..., -1.3825, -0.6352,  0.3851],
│   │                         [-0.3665,  0.4416, -0.3694,  ..., -0.4565, -0.5017, -0.0906],
│   │                         [-0.2232, -0.4061,  0.4095,  ..., -0.1806, -1.3190,  0.5997],
│   │                         ...,
│   │                         [-0.8555,  0.4470, -0.0720,  ..., -0.6140, -0.1779, -0.0705],
│   │                         [-0.5370, -1.0031,  0.9889,  ...,  0.3644,  0.0998, -0.0149],
│   │                         [-1.4868, -0.6425,  1.4812,  ..., -0.9816, -0.7006, -0.7555]]],
│   │               
│   │               
│   │                       [[[ 0.8749,  1.2708,  0.9577,  ..., -0.3517, -0.6191, -0.0096],
│   │                         [-0.0848, -0.2048, -0.5411,  ...,  1.5998, -0.2143,  2.3013],
│   │                         [-1.6901, -1.0551, -0.5186,  ...,  0.3698,  0.6493, -1.0633],
│   │                         ...,
│   │                         [-0.8857, -1.3365,  0.7419,  ...,  0.2862, -1.1870, -1.7038],
│   │                         [-1.1026, -0.1183, -0.3383,  ..., -1.0253,  1.7018,  0.1913],
│   │                         [-0.0362, -0.3655, -0.5131,  ..., -0.6326,  1.8401, -1.0757]],
│   │               
│   │                        [[ 0.9615, -0.3033,  0.2580,  ..., -0.5370, -0.6287, -1.7773],
│   │                         [-0.5534,  0.9427,  1.6600,  ...,  0.6480,  0.3010, -0.8434],
│   │                         [ 1.5634, -0.0763,  1.5279,  ...,  0.7159, -0.2430, -0.7256],
│   │                         ...,
│   │                         [-0.5573,  0.5660,  1.0498,  ..., -0.7879, -0.6572,  0.2114],
│   │                         [-0.1946,  0.3268,  0.6696,  ..., -0.3016,  0.0972,  0.4719],
│   │                         [ 0.8164, -1.2675,  0.4952,  ..., -1.3345, -1.3787, -1.5843]],
│   │               
│   │                        [[-1.2139, -1.3951,  0.1942,  ..., -0.0086,  1.3501,  0.6939],
│   │                         [ 0.0388, -0.6861, -1.3002,  ..., -0.7628, -0.4953,  0.8375],
│   │                         [ 2.1851, -0.2986, -0.1220,  ..., -0.3593, -0.3614, -0.8529],
│   │                         ...,
│   │                         [-0.9956,  0.1536, -0.1458,  ..., -1.1383,  0.5041, -0.2082],
│   │                         [ 2.0207,  0.5529, -1.9429,  ..., -0.5377,  0.3415, -0.1027],
│   │                         [ 1.4176, -0.7118, -0.4360,  ..., -0.2984, -1.5083,  0.2817]]]])
│   └── 'scalar' --> tensor([[ 1.0591, -0.6506,  0.3507, -1.2608, -0.4514,  0.6761,  1.0317,  1.3853,
│                             -1.3662, -0.6998,  0.7108,  0.7072],
│                            [-0.6296,  0.5803, -0.1321,  1.0066, -0.3476, -0.8866, -0.4370, -0.9014,
│                             -0.4515, -0.1951,  0.7336, -0.3984],
│                            [-0.8335, -0.4403,  1.3306,  1.5303,  0.4417,  0.6500,  1.2325,  0.6306,
│                              1.3441,  0.1902, -0.1798, -0.7636],
│                            [ 0.0438,  0.4476,  0.7897, -1.0885,  0.7525,  2.5854,  0.0656, -1.8516,
│                             -1.4118,  1.2392,  0.0295, -0.6463]])
└── 'reward' --> tensor([[0.5136],
                         [0.9807],
                         [0.1835],
                         [0.7997]])

This version is much simpler and clearer: the tree structure is handled by treetensor, so no hand-written recursive function is needed.