cat
Documentation
treetensor.torch.cat(tensors, *args, **kwargs)

Concatenates the given sequence of tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.

Examples:
>>> import torch
>>> import treetensor.torch as ttorch
>>> t1 = torch.randint(10, 30, (2, 3))
>>> t1
tensor([[21, 29, 17],
        [16, 11, 16]])
>>> t2 = torch.randint(30, 50, (2, 3))
>>> t2
tensor([[46, 46, 46],
        [30, 47, 36]])
>>> t3 = torch.randint(50, 70, (2, 3))
>>> t3
tensor([[51, 65, 65],
        [54, 67, 57]])
>>> ttorch.cat((t1, t2, t3))
tensor([[21, 29, 17],
        [16, 11, 16],
        [46, 46, 46],
        [30, 47, 36],
        [51, 65, 65],
        [54, 67, 57]])
>>> tt1 = ttorch.Tensor({
...     'a': t1,
...     'b': {'x': t2, 'y': t3},
... })
>>> tt1
<Tensor 0x7fed579acf60>
├── a --> tensor([[21, 29, 17],
│                 [16, 11, 16]])
└── b --> <Tensor 0x7fed579acf28>
    ├── x --> tensor([[46, 46, 46],
    │                 [30, 47, 36]])
    └── y --> tensor([[51, 65, 65],
                      [54, 67, 57]])
>>> tt2 = ttorch.Tensor({
...     'a': t2,
...     'b': {'x': t3, 'y': t1},
... })
>>> tt2
<Tensor 0x7fed579d62e8>
├── a --> tensor([[46, 46, 46],
│                 [30, 47, 36]])
└── b --> <Tensor 0x7fed579d62b0>
    ├── x --> tensor([[51, 65, 65],
    │                 [54, 67, 57]])
    └── y --> tensor([[21, 29, 17],
                      [16, 11, 16]])
>>> tt3 = ttorch.Tensor({
...     'a': t3,
...     'b': {'x': t1, 'y': t2},
... })
>>> tt3
<Tensor 0x7fed579d66a0>
├── a --> tensor([[51, 65, 65],
│                 [54, 67, 57]])
└── b --> <Tensor 0x7fed579d65f8>
    ├── x --> tensor([[21, 29, 17],
    │                 [16, 11, 16]])
    └── y --> tensor([[46, 46, 46],
                      [30, 47, 36]])
>>> ttorch.cat((tt1, tt2, tt3))
<Tensor 0x7fed579d6ac8>
├── a --> tensor([[21, 29, 17],
│                 [16, 11, 16],
│                 [46, 46, 46],
│                 [30, 47, 36],
│                 [51, 65, 65],
│                 [54, 67, 57]])
└── b --> <Tensor 0x7fed579d6a90>
    ├── x --> tensor([[46, 46, 46],
    │                 [30, 47, 36],
    │                 [51, 65, 65],
    │                 [54, 67, 57],
    │                 [21, 29, 17],
    │                 [16, 11, 16]])
    └── y --> tensor([[51, 65, 65],
                      [54, 67, 57],
                      [21, 29, 17],
                      [16, 11, 16],
                      [46, 46, 46],
                      [30, 47, 36]])
>>> ttorch.cat((tt1, tt2, tt3), dim=1)
<Tensor 0x7fed579644a8>
├── a --> tensor([[21, 29, 17, 46, 46, 46, 51, 65, 65],
│                 [16, 11, 16, 30, 47, 36, 54, 67, 57]])
└── b --> <Tensor 0x7fed57964438>
    ├── x --> tensor([[46, 46, 46, 51, 65, 65, 21, 29, 17],
    │                 [30, 47, 36, 54, 67, 57, 16, 11, 16]])
    └── y --> tensor([[51, 65, 65, 21, 29, 17, 46, 46, 46],
                      [54, 67, 57, 16, 11, 16, 30, 47, 36]])
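In the examples above, ttorch.cat behaves as if torch.cat were applied separately at every key path that the input trees share, which is why b.x in the result holds t2, t3 and t1 in that order. The pure-Python sketch below mimics that behaviour, using nested dicts as trees and nested lists as 2-D tensors. The helpers cat2d and tree_cat are hypothetical illustrations, not part of treetensor's API, and the real implementation may differ (for example in how mismatched tree structures are reported).

```python
from typing import Any, Callable, List, Sequence

Matrix = List[List[int]]


def cat2d(mats: Sequence[Matrix], dim: int = 0) -> Matrix:
    """Concatenate 2-D nested lists along dim 0 (rows) or dim 1 (columns)."""
    if dim == 0:
        # Every row must have the same width; rows are stacked in input order.
        widths = {len(row) for m in mats for row in m}
        if len(widths) > 1:
            raise ValueError("dim=0 needs the same number of columns everywhere")
        return [list(row) for m in mats for row in m]
    if dim == 1:
        # Every matrix must have the same number of rows; row i of the result
        # joins row i of each input side by side.
        n_rows = len(mats[0])
        if any(len(m) != n_rows for m in mats):
            raise ValueError("dim=1 needs the same number of rows everywhere")
        return [[x for m in mats for x in m[i]] for i in range(n_rows)]
    raise ValueError("this sketch only supports dim 0 or 1")


def tree_cat(trees: Sequence[Any], dim: int = 0,
             leaf_cat: Callable[[Sequence[Any], int], Any] = cat2d) -> Any:
    """Concatenate the leaves found at the same key path in every tree."""
    first = trees[0]
    if isinstance(first, dict):
        # All trees must share the same keys at this level.
        if any(not isinstance(t, dict) or t.keys() != first.keys() for t in trees):
            raise ValueError("all trees must have the same structure")
        return {k: tree_cat([t[k] for t in trees], dim, leaf_cat) for k in first}
    # Leaf level: hand the aligned leaves to the ordinary concatenation.
    return leaf_cat(trees, dim)


# The values from the example above.
t1 = [[21, 29, 17], [16, 11, 16]]
t2 = [[46, 46, 46], [30, 47, 36]]
t3 = [[51, 65, 65], [54, 67, 57]]
tt1 = {'a': t1, 'b': {'x': t2, 'y': t3}}
tt2 = {'a': t2, 'b': {'x': t3, 'y': t1}}
tt3 = {'a': t3, 'b': {'x': t1, 'y': t2}}

out = tree_cat([tt1, tt2, tt3], dim=1)
print(out['b']['x'])
# [[46, 46, 46, 51, 65, 65, 21, 29, 17], [30, 47, 36, 54, 67, 57, 16, 11, 16]]
```

Passing the leaf function as a parameter keeps the tree walk independent of what a "tensor" is, which mirrors the idea that the tree version only adds structure on top of the plain operation.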
Torch Version Related
This documentation is based on torch.cat in torch v2.4.1+cu121. The arrangement of its arguments depends on the version of PyTorch you have installed.
If some of the arguments listed here do not work as described, check your PyTorch version with the following command and consult the documentation for that version.
    python -c 'import torch;print(torch.__version__)'
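If you would rather perform this check from Python (for example inside a test suite), the standard-library importlib.metadata.version reads the installed distribution's version string without importing torch. It normally matches torch.__version__, but treat that as an assumption on unusual builds. The names DOCUMENTED, parse_release and installed_torch_release are hypothetical, and the sketch compares only the numeric release part, ignoring local tags such as +cu121.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

# Hypothetical constant: the release this page documents (torch v2.4.1+cu121).
DOCUMENTED = (2, 4, 1)


def parse_release(ver: str) -> Tuple[int, ...]:
    """Return the leading numeric release part, e.g. '2.4.1+cu121' -> (2, 4, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unrecognised version string: {ver!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def installed_torch_release() -> Optional[Tuple[int, ...]]:
    """Return the installed torch release tuple, or None if torch is not installed."""
    try:
        return parse_release(version("torch"))
    except PackageNotFoundError:
        return None


release = installed_torch_release()
if release is None:
    print("torch is not installed")
elif release != DOCUMENTED:
    print(f"installed torch {release} differs from the documented {DOCUMENTED}")
else:
    print("installed torch matches the documented release")
```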
The arguments and keyword arguments supported in torch v2.4.1+cu121 are listed below.
Description From Torch v2.4.1+cu121
torch.cat(tensors, dim=0, *, out=None) → Tensor

Concatenates the given sequence of tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.

torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk().

torch.cat() is best understood through examples.

Args:
    tensors (sequence of Tensors): any Python sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.
    dim (int, optional): the dimension over which the tensors are concatenated. Default: 0.

Keyword args:
    out (Tensor, optional): the output tensor.
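The shape rule for tensors and dim can be made concrete by computing the output shape without creating any tensors. cat_shape is a hypothetical helper, not a PyTorch function. It assumes that "empty" means a 1-D tensor of shape (0,), which is how recent PyTorch releases interpret the rule, and it accepts negative dim values the way PyTorch dimension indexing does.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def cat_shape(shapes: Sequence[Shape], dim: int = 0) -> Shape:
    """Return the shape torch.cat would produce for inputs of the given shapes.

    Assumption: an "empty" input is the 1-D shape (0,); such inputs are skipped.
    """
    kept = [tuple(s) for s in shapes if tuple(s) != (0,)]
    if not kept:
        return (0,)
    ndim = len(kept[0])
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for {ndim}-D tensors")
    dim %= ndim  # normalise negative dims, e.g. -1 -> ndim - 1
    for s in kept[1:]:
        # Same rank, and equal sizes in every dimension except dim.
        if len(s) != ndim or any(a != b for i, (a, b) in enumerate(zip(s, kept[0])) if i != dim):
            raise ValueError(f"shape {s} does not match {kept[0]} outside dim {dim}")
    # Sizes add up along dim; every other dimension is unchanged.
    out = list(kept[0])
    out[dim] = sum(s[dim] for s in kept)
    return tuple(out)


print(cat_shape([(2, 3)] * 3, 0))  # (6, 3), like torch.cat((x, x, x), 0) below
print(cat_shape([(2, 3)] * 3, 1))  # (2, 9), like torch.cat((x, x, x), 1) below
```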
Example:
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])
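The description calls torch.cat() an inverse of torch.split() and torch.chunk(). The pure-Python sketch below mimics both splitting functions along dimension 0 on plain lists and checks that concatenating the pieces restores the input. The helpers split, chunk and cat are hypothetical stand-ins intended to follow the documented PyTorch semantics (an integer split size yields equal pieces with a possibly smaller last one; chunk uses pieces of ceil(n / chunks) and may return fewer pieces than requested), not the real implementations.

```python
import math
from typing import List, Sequence, Union


def split(seq: Sequence[int], split_size_or_sections: Union[int, Sequence[int]]) -> List[List[int]]:
    """Mimic torch.split along dim 0 on a plain list."""
    if isinstance(split_size_or_sections, int):
        size = split_size_or_sections
        if size <= 0:
            raise ValueError("split size must be positive")
        # Equal pieces of `size`; the last piece is shorter if len(seq) is not divisible.
        return [list(seq[i:i + size]) for i in range(0, len(seq), size)]
    sections = list(split_size_or_sections)
    if sum(sections) != len(seq):
        raise ValueError("section sizes must add up to len(seq)")
    pieces, start = [], 0
    for size in sections:
        pieces.append(list(seq[start:start + size]))
        start += size
    return pieces


def chunk(seq: Sequence[int], chunks: int) -> List[List[int]]:
    """Mimic torch.chunk along dim 0: pieces of ceil(len / chunks), possibly fewer of them."""
    if chunks <= 0:
        raise ValueError("chunks must be positive")
    size = max(1, math.ceil(len(seq) / chunks))
    return split(seq, size)


def cat(pieces: Sequence[Sequence[int]]) -> List[int]:
    """Concatenate 1-D pieces along dim 0, in order."""
    return [x for piece in pieces for x in piece]


data = list(range(7))
print(split(data, 3))            # [[0, 1, 2], [3, 4, 5], [6]]
print(split(data, [2, 5]))       # [[0, 1], [2, 3, 4, 5, 6]]
print(chunk(list(range(6)), 4))  # [[0, 1], [2, 3], [4, 5]]: only 3 pieces
# Concatenating the pieces in order gives back the original sequence.
assert cat(split(data, 3)) == data
assert cat(split(data, [2, 5])) == data
assert cat(chunk(data, 4)) == data
```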