squeeze

Documentation

treetensor.torch.squeeze(input, *args, **kwargs)

Returns a tensor with all the dimensions of input of size 1 removed.

Examples:

>>> import torch
>>> import treetensor.torch as ttorch
>>> t1 = torch.randint(100, (2, 1, 2, 1, 2))
>>> t1.shape
torch.Size([2, 1, 2, 1, 2])
>>> ttorch.squeeze(t1).shape
torch.Size([2, 2, 2])

>>> tt1 = ttorch.randint(100, {
...     'a': (2, 1, 2, 1, 2),
...     'b': {'x': (2, 1, 1, 3)},
... })
>>> tt1.shape
<Size 0x7fa4c1b05410>
├── a --> torch.Size([2, 1, 2, 1, 2])
└── b --> <Size 0x7fa4c1b05510>
    └── x --> torch.Size([2, 1, 1, 3])
>>> ttorch.squeeze(tt1).shape
<Size 0x7fa4c1b9f3d0>
├── a --> torch.Size([2, 2, 2])
└── b --> <Size 0x7fa4c1afe710>
    └── x --> torch.Size([2, 3])

Torch Version Related

This documentation is based on torch.squeeze in torch v2.0.1+cu117. The arrangement of its arguments depends on the version of PyTorch you have installed.

If some of the arguments listed here do not work properly, check your PyTorch version with the following command and consult the documentation for that version.

python -c 'import torch;print(torch.__version__)'

The arguments and keyword arguments supported in torch v2.0.1+cu117 are listed below.

Description From Torch v2.0.1+cu117

torch.squeeze(input, dim=None, *, out=None) → Tensor

Returns a tensor with all the dimensions of input of size 1 removed.

For example, if input is of shape: \((A \times 1 \times B \times C \times 1 \times D)\) then the out tensor will be of shape: \((A \times B \times C \times D)\).

When dim is given, a squeeze operation is done only in the given dimension. If input is of shape: \((A \times 1 \times B)\), squeeze(input, 0) leaves the tensor unchanged, but squeeze(input, 1) will squeeze the tensor to the shape \((A \times B)\).

Note

The returned tensor shares the storage with the input tensor, so changing the contents of one will change the contents of the other.
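
The storage sharing described in this note can be checked directly. The following is a minimal sketch that uses only torch.squeeze, Tensor.data_ptr() and Tensor.clone(); the variable names are illustrative and not part of either library.

```python
import torch

x = torch.zeros(2, 1, 3)
y = torch.squeeze(x)  # shape (2, 3), shares storage with x

# A write through the squeezed tensor is visible in the original tensor.
y[0, 0] = 7.0
shares_storage = x.data_ptr() == y.data_ptr()
value_in_x = x[0, 0, 0].item()

# An independent result requires an explicit copy.
z = torch.squeeze(x).clone()
z[1, 2] = -1.0
value_after_clone = x[1, 0, 2].item()  # x is unaffected by the write to z
```

If the squeezed result will be modified in place and the original must stay intact, cloning first, as done for z above, is one way to avoid the aliasing.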

Warning

If the tensor has a batch dimension of size 1, then squeeze(input) will also remove the batch dimension, which can lead to unexpected errors.
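
The pitfall in this warning can be illustrated with a batch of size 1. The sketch below assumes a hypothetical (batch, channel, height, width) layout, which is chosen only for illustration, and shows that passing dim restricts the squeeze to the intended axis.

```python
import torch

# Hypothetical batch holding a single one-channel image:
# (batch, channel, height, width).
batch = torch.zeros(1, 1, 4, 4)

# Without dim, both the batch axis and the channel axis are removed.
all_squeezed = torch.squeeze(batch)

# With dim=1, only the channel axis is removed and the batch axis is kept.
channel_squeezed = torch.squeeze(batch, 1)

print(all_squeezed.shape)      # torch.Size([4, 4])
print(channel_squeezed.shape)  # torch.Size([1, 4, 4])
```

Code that later indexes or concatenates along the batch axis would see a different number of dimensions in the first case, which is where the unexpected errors tend to come from.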

Args:

input (Tensor): the input tensor.

dim (int, optional): if given, the input will be squeezed only in this dimension.

Keyword args:

out (Tensor, optional): the output tensor.

Example:

>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])