cat

Documentation

treetensor.torch.cat(tensors, *args, **kwargs)

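Concatenates the given sequence of tensors along the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty. When tree tensors are passed, the concatenation is applied to the tensors found at each corresponding leaf.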
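ttorch.cat follows torch.cat's shape rule for every group of tensors it joins. The sketch below uses plain torch.cat, so it runs without treetensor installed, to show that sizes may differ only along the concatenating dimension and that a 1-D empty tensor is skipped. The variable names are illustrative only.

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(4, 3)   # differs from `a` only in dim 0
c = torch.ones(2, 4)   # differs from `a` only in dim 1

# Sizes may differ along the concatenating dimension (dim 0 here).
rows = torch.cat((a, b), dim=0)
print(rows.shape)  # torch.Size([6, 3])

# Along dim 1, the row counts must match instead.
cols = torch.cat((a, c), dim=1)
print(cols.shape)  # torch.Size([2, 7])

# A mismatch outside the concatenating dimension is rejected.
rejected = False
try:
    torch.cat((a, c), dim=0)
except RuntimeError:
    rejected = True
print(rejected)  # True

# An empty 1-D tensor is skipped regardless of the other shapes.
skipped = torch.cat((a, torch.empty(0)), dim=0)
print(skipped.shape)  # torch.Size([2, 3])
```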

Examples:

>>> import torch
>>> import treetensor.torch as ttorch
>>> t1 = torch.randint(10, 30, (2, 3))
>>> t1
tensor([[21, 29, 17],
        [16, 11, 16]])
>>> t2 = torch.randint(30, 50, (2, 3))
>>> t2
tensor([[46, 46, 46],
        [30, 47, 36]])
>>> t3 = torch.randint(50, 70, (2, 3))
>>> t3
tensor([[51, 65, 65],
        [54, 67, 57]])
>>> ttorch.cat((t1, t2, t3))
tensor([[21, 29, 17],
        [16, 11, 16],
        [46, 46, 46],
        [30, 47, 36],
        [51, 65, 65],
        [54, 67, 57]])
>>> tt1 = ttorch.Tensor({
...    'a': t1,
...    'b': {'x': t2, 'y': t3},
... })
>>> tt1
<Tensor 0x7fed579acf60>
├── a --> tensor([[21, 29, 17],
│                 [16, 11, 16]])
└── b --> <Tensor 0x7fed579acf28>
    ├── x --> tensor([[46, 46, 46],
    │                 [30, 47, 36]])
    └── y --> tensor([[51, 65, 65],
                      [54, 67, 57]])
>>> tt2 = ttorch.Tensor({
...    'a': t2,
...    'b': {'x': t3, 'y': t1},
... })
>>> tt2
<Tensor 0x7fed579d62e8>
├── a --> tensor([[46, 46, 46],
│                 [30, 47, 36]])
└── b --> <Tensor 0x7fed579d62b0>
    ├── x --> tensor([[51, 65, 65],
    │                 [54, 67, 57]])
    └── y --> tensor([[21, 29, 17],
                      [16, 11, 16]])
>>> tt3 = ttorch.Tensor({
...    'a': t3,
...    'b': {'x': t1, 'y': t2},
... })
>>> tt3
<Tensor 0x7fed579d66a0>
├── a --> tensor([[51, 65, 65],
│                 [54, 67, 57]])
└── b --> <Tensor 0x7fed579d65f8>
    ├── x --> tensor([[21, 29, 17],
    │                 [16, 11, 16]])
    └── y --> tensor([[46, 46, 46],
                      [30, 47, 36]])
>>> ttorch.cat((tt1, tt2, tt3))
<Tensor 0x7fed579d6ac8>
├── a --> tensor([[21, 29, 17],
│                 [16, 11, 16],
│                 [46, 46, 46],
│                 [30, 47, 36],
│                 [51, 65, 65],
│                 [54, 67, 57]])
└── b --> <Tensor 0x7fed579d6a90>
    ├── x --> tensor([[46, 46, 46],
    │                 [30, 47, 36],
    │                 [51, 65, 65],
    │                 [54, 67, 57],
    │                 [21, 29, 17],
    │                 [16, 11, 16]])
    └── y --> tensor([[51, 65, 65],
                      [54, 67, 57],
                      [21, 29, 17],
                      [16, 11, 16],
                      [46, 46, 46],
                      [30, 47, 36]])
>>> ttorch.cat((tt1, tt2, tt3), dim=1)
<Tensor 0x7fed579644a8>
├── a --> tensor([[21, 29, 17, 46, 46, 46, 51, 65, 65],
│                 [16, 11, 16, 30, 47, 36, 54, 67, 57]])
└── b --> <Tensor 0x7fed57964438>
    ├── x --> tensor([[46, 46, 46, 51, 65, 65, 21, 29, 17],
    │                 [30, 47, 36, 54, 67, 57, 16, 11, 16]])
    └── y --> tensor([[51, 65, 65, 21, 29, 17, 46, 46, 46],
                      [54, 67, 57, 16, 11, 16, 30, 47, 36]])
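Conceptually, the tree results above are what you get by calling torch.cat separately on the tensors at the same path in each tree. The sketch below models a tree as nested Python dicts and uses a hypothetical helper, tree_cat, to make that leaf-wise behavior explicit. It is an illustration under that assumption, not treetensor's actual implementation.

```python
import torch


def tree_cat(trees, dim=0):
    """Hypothetical helper: concatenate matching leaves of nested dicts.

    Assumes every tree has the same keys at every level and tensor leaves.
    """
    first = trees[0]
    if isinstance(first, dict):
        # Recurse per key, gathering the subtree at that key from every tree.
        return {key: tree_cat([t[key] for t in trees], dim=dim) for key in first}
    # Leaf level: an ordinary torch.cat over the tensors at this path.
    return torch.cat(trees, dim=dim)


t1 = torch.tensor([[21, 29, 17], [16, 11, 16]])
t2 = torch.tensor([[46, 46, 46], [30, 47, 36]])
t3 = torch.tensor([[51, 65, 65], [54, 67, 57]])

tt1 = {'a': t1, 'b': {'x': t2, 'y': t3}}
tt2 = {'a': t2, 'b': {'x': t3, 'y': t1}}
tt3 = {'a': t3, 'b': {'x': t1, 'y': t2}}

out = tree_cat([tt1, tt2, tt3], dim=1)
print(out['b']['x'])
# tensor([[46, 46, 46, 51, 65, 65, 21, 29, 17],
#         [30, 47, 36, 54, 67, 57, 16, 11, 16]])
```

The leaf order in each result follows the order of the input trees, which is why b.x starts with t2's values.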

Torch Version Related

This documentation is based on torch.cat in torch v2.4.1+cu121. The arrangement of its arguments depends on the version of PyTorch you have installed.

If some of the arguments listed here do not work properly, check your PyTorch version with the following command and consult the documentation for that version.

python -c 'import torch;print(torch.__version__)'

The arguments and keyword arguments supported in torch v2.4.1+cu121 are listed below.

Description From Torch v2.4.1+cu121

torch.cat(tensors, dim=0, *, out=None) → Tensor

Concatenates the given sequence of tensors along the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.

torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk().
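To make the inverse relationship concrete: splitting a tensor with torch.split() or torch.chunk() and then concatenating the pieces along the same dimension reproduces the original tensor. A small sketch with an arbitrary arange input:

```python
import torch

x = torch.arange(12).reshape(4, 3)

# Split into row blocks of sizes 1 and 3, then rejoin along dim 0.
pieces = torch.split(x, [1, 3], dim=0)
rebuilt_rows = torch.cat(pieces, dim=0)
print([p.shape for p in pieces])  # [torch.Size([1, 3]), torch.Size([3, 3])]
print(torch.equal(rebuilt_rows, x))  # True

# Chunk into 3 column groups, then rejoin along dim 1.
chunks = torch.chunk(x, 3, dim=1)
rebuilt_cols = torch.cat(chunks, dim=1)
print(torch.equal(rebuilt_cols, x))  # True
```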

torch.cat() can be best understood via examples.

Args:
tensors (sequence of Tensors): any Python sequence of tensors of the same type.

Non-empty tensors provided must have the same shape, except in the cat dimension.

dim (int, optional): the dimension over which the tensors are concatenated.

Keyword args:

out (Tensor, optional): the output tensor.
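The dim argument also accepts negative values counted back from the last dimension, and out writes the result into a preallocated tensor of the right shape. A brief sketch; the shapes are arbitrary:

```python
import torch

x = torch.randn(2, 3)

# dim=-1 refers to the last dimension, so for 2-D inputs it matches dim=1.
last = torch.cat((x, x), dim=-1)
print(last.shape)  # torch.Size([2, 6])

# Preallocate the result: 2 + 2 rows, 3 columns, same dtype as x.
buf = torch.empty(4, 3)
result = torch.cat((x, x), dim=0, out=buf)

# The data was written into buf, which already had the required shape.
print(torch.equal(buf[2:], x))  # True
```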

Example:

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])