chunk

Documentation

treetensor.torch.chunk(input, chunks, *args, **kwargs)

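Splits a tensor into a specified number of chunks. Each chunk is a view of the input tensor. When input is a tree tensor, every leaf tensor is chunked and the result is a tuple of tree tensors, where the i-th tree holds the i-th chunk of each leaf.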

Examples:

>>> import torch
>>> import treetensor.torch as ttorch
>>> t = torch.randint(100, (4, 5))
>>> t
tensor([[54, 97, 12, 48, 62],
        [92, 87, 28, 53, 54],
        [65, 82, 40, 26, 61],
        [75, 43, 86, 99,  7]])
>>> ttorch.chunk(t, 2)
(tensor([[54, 97, 12, 48, 62],
        [92, 87, 28, 53, 54]]), tensor([[65, 82, 40, 26, 61],
        [75, 43, 86, 99,  7]]))

>>> tt = ttorch.randint(100, {
...     'a': (4, 5),
...     'b': {'x': (2, 3, 4)},
... })
>>> tt
<Tensor 0x7f667e2fb358>
├── a --> tensor([[80,  2, 15, 45, 48],
│                 [38, 89, 34, 10, 34],
│                 [18, 99, 33, 38, 20],
│                 [43, 21, 35, 43, 37]])
└── b --> <Tensor 0x7f667e2fb278>
    └── x --> tensor([[[19, 17, 39, 68],
                       [41, 69, 33, 89],
                       [31, 88, 39, 14]],

                      [[27, 81, 84, 35],
                       [29, 65, 17, 72],
                       [53, 50, 75,  0]]])
>>> ttorch.chunk(tt, 2)
(<Tensor 0x7f667e9b7eb8>
├── a --> tensor([[80,  2, 15, 45, 48],
│                 [38, 89, 34, 10, 34]])
└── b --> <Tensor 0x7f667e2e7cf8>
    └── x --> tensor([[[19, 17, 39, 68],
                       [41, 69, 33, 89],
                       [31, 88, 39, 14]]])
, <Tensor 0x7f66f176dac8>
├── a --> tensor([[18, 99, 33, 38, 20],
│                 [43, 21, 35, 43, 37]])
└── b --> <Tensor 0x7f668030ba58>
    └── x --> tensor([[[27, 81, 84, 35],
                       [29, 65, 17, 72],
                       [53, 50, 75,  0]]])
)
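Conceptually, the tree version chunks every leaf and then regroups the pieces, so the i-th returned tree holds the i-th chunk of each leaf. The sketch below illustrates that idea with nested dicts and plain Python lists standing in for tree tensors and tensors. The helpers chunk_list and tree_chunk are hypothetical (not part of treetensor), and the sketch assumes every leaf yields the same number of chunks.

```python
from typing import Any, Dict, List, Tuple


def chunk_list(seq: List[Any], chunks: int) -> List[List[Any]]:
    """Split a list into pieces of ceil(len / chunks) items, like torch.chunk on dim 0."""
    if chunks < 1:
        raise ValueError("chunks must be at least 1")
    step = max(1, (len(seq) + chunks - 1) // chunks)
    return [seq[i:i + step] for i in range(0, len(seq), step)]


def tree_chunk(tree: Dict[str, Any], chunks: int) -> Tuple[Dict[str, Any], ...]:
    """Chunk every leaf, then regroup so that result[i] holds the i-th chunk of each leaf."""
    # Recurse into sub-trees; chunk leaves directly.
    per_key = {
        key: tree_chunk(value, chunks) if isinstance(value, dict) else chunk_list(value, chunks)
        for key, value in tree.items()
    }
    # Assumption of this sketch: every leaf yields the same number of chunks.
    counts = {len(parts) for parts in per_key.values()}
    if len(counts) != 1:
        raise ValueError("leaves must produce the same number of chunks")
    total = counts.pop()
    return tuple({key: parts[i] for key, parts in per_key.items()} for i in range(total))


# Plain lists stand in for the leaf tensors 'a' and 'b.x' of the example above.
first, second = tree_chunk({'a': [1, 2, 3, 4], 'b': {'x': [5, 6]}}, 2)
print(first)   # {'a': [1, 2], 'b': {'x': [5]}}
print(second)  # {'a': [3, 4], 'b': {'x': [6]}}
```

As in the example above, 'a' (4 rows) and 'b.x' (2 entries along the first dimension) are each split in two, and the halves are regrouped into two trees.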

Torch Version Related

This documentation is based on torch.chunk in torch v1.9.0+cu102. The arrangement of its arguments depends on the version of PyTorch you have installed.

If some of the arguments listed here do not work as expected, check your PyTorch version with the following command and consult the documentation for that version.

python -c 'import torch;print(torch.__version__)'
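If you prefer to check the version from within a Python script without importing torch, the standard-library importlib.metadata module can read the installed package's version. This is a minimal sketch that assumes the package metadata is registered under the name torch (as it is for pip installs); installed_torch_version is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_torch_version() -> Optional[str]:
    """Return the installed torch version string, or None if torch is not installed."""
    try:
        return version("torch")
    except PackageNotFoundError:
        # No distribution named "torch" is installed in this environment.
        return None


if __name__ == "__main__":
    print(installed_torch_version())
```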

The arguments and keyword arguments supported in torch v1.9.0+cu102 are listed below.

Description From Torch v1.9.0+cu102

torch.chunk(input, chunks, dim=0) → List of Tensors

Attempts to split a tensor into the specified number of chunks. Each chunk is a view of the input tensor.

Note

This function may return fewer than the specified number of chunks!

See also

torch.tensor_split(): a function that always returns exactly the specified number of chunks.
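For comparison with torch.tensor_split: when it is given an integer number of sections n, PyTorch documents that the first size % n sections receive one extra element and the remaining sections receive size // n elements. The pure-Python sketch below computes only those section lengths; tensor_split_sizes is a hypothetical helper that works on an integer length rather than a tensor.

```python
from typing import List


def tensor_split_sizes(size: int, sections: int) -> List[int]:
    """Section lengths torch.tensor_split uses for an integer number of sections."""
    if sections < 1:
        raise ValueError("sections must be at least 1")
    base, extra = divmod(size, sections)
    # The first `extra` sections each absorb one element of the remainder.
    return [base + 1] * extra + [base] * (sections - extra)


# 13 elements into 6 sections: always exactly 6 sections.
print(tensor_split_sizes(13, 6))  # [3, 2, 2, 2, 2, 2]
```

Unlike chunk, this rule never produces fewer sections than requested; if size < sections, the trailing sections simply have length 0.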

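If the tensor size along the given dimension dim is divisible by chunks, all returned chunks will be the same size. If it is not divisible by chunks, all returned chunks will be the same size except possibly the last one, which may be smaller. Because every chunk except the last has ceil(size / chunks) elements, the chunks may cover the whole dimension before the requested count is reached; in that case this function returns fewer than the specified number of chunks (for example, splitting 13 elements into 6 chunks yields 5 chunks of sizes 3, 3, 3, 3 and 1).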
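A minimal pure-Python sketch of the chunk size rule, assuming each chunk except the last has ceil(size / chunks) elements, which matches the examples below. chunk_sizes is a hypothetical helper, not a PyTorch function; it takes an integer length instead of a tensor and only handles non-empty dimensions.

```python
from typing import List


def chunk_sizes(size: int, chunks: int) -> List[int]:
    """Chunk lengths along one dimension, following torch.chunk's documented behavior."""
    if chunks < 1:
        raise ValueError("chunks must be at least 1")
    if size < 1:
        raise ValueError("this sketch only handles non-empty dimensions")
    step = (size + chunks - 1) // chunks  # integer ceil(size / chunks)
    # As many full chunks of `step` elements as fit, then a shorter remainder, if any.
    full, rest = divmod(size, step)
    return [step] * full + ([rest] if rest else [])


print(chunk_sizes(11, 6))  # [2, 2, 2, 2, 2, 1]
print(chunk_sizes(13, 6))  # [3, 3, 3, 3, 1] -> only 5 chunks
```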

Arguments:

input (Tensor): the tensor to split
chunks (int): number of chunks to return
dim (int): dimension along which to split the tensor
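To illustrate the dim argument, the sketch below splits a 2-D nested list either along rows (dim=0) or along columns (dim=1), using the same ceil(length / chunks) chunk length as torch.chunk. chunk_2d is a hypothetical helper that works on plain lists, so unlike torch.chunk it returns copies rather than views.

```python
from typing import List

Matrix = List[List[int]]


def chunk_2d(matrix: Matrix, chunks: int, dim: int = 0) -> List[Matrix]:
    """Split a rectangular list of rows along dim 0 (rows) or dim 1 (columns)."""
    if chunks < 1:
        raise ValueError("chunks must be at least 1")
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for a 2-D matrix")
    if not matrix or not matrix[0]:
        raise ValueError("matrix must be non-empty")
    length = len(matrix) if dim == 0 else len(matrix[0])
    step = (length + chunks - 1) // chunks  # integer ceil(length / chunks)
    starts = range(0, length, step)
    if dim == 0:
        return [matrix[s:s + step] for s in starts]
    # dim == 1: slice every row over the same column window.
    return [[row[s:s + step] for row in matrix] for s in starts]


m = [[1, 2, 3, 4, 5],
     [6, 7, 8, 9, 10]]
print(chunk_2d(m, 2, dim=1))  # [[[1, 2, 3], [6, 7, 8]], [[4, 5], [9, 10]]]
```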

Example:

>>> torch.arange(11).chunk(6)
(tensor([0, 1]),
 tensor([2, 3]),
 tensor([4, 5]),
 tensor([6, 7]),
 tensor([8, 9]),
 tensor([10]))
>>> torch.arange(12).chunk(6)
(tensor([0, 1]),
 tensor([2, 3]),
 tensor([4, 5]),
 tensor([6, 7]),
 tensor([8, 9]),
 tensor([10, 11]))
>>> torch.arange(13).chunk(6)
(tensor([0, 1, 2]),
 tensor([3, 4, 5]),
 tensor([6, 7, 8]),
 tensor([ 9, 10, 11]),
 tensor([12]))