max

Documentation

treetensor.torch.max(input, *args, reduce=None, **kwargs)

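In treetensor, this function returns the maximum value over a whole tree of tensors. With reduce=False it returns a tree holding the maximum of each leaf tensor instead, and with extra arguments such as dim it applies torch.max to every leaf separately, as the examples below show.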

Example:

>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.max(torch.tensor([1.0, 2.0, 1.5]))  # the same as torch.max
tensor(2.)

>>> ttorch.max(ttorch.tensor({
...     'a': [1.0, 2.0, 1.5],
...     'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }))
tensor(2.5000)

>>> ttorch.max(ttorch.tensor({
...     'a': [1.0, 2.0, 1.5],
...     'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }), reduce=False)
<Tensor 0x7fd45b52d940>
├── a --> tensor(2.)
└── b --> <Tensor 0x7fd45b52d908>
    └── x --> tensor(2.5000)

>>> ttorch.max(ttorch.tensor({
...     'a': [1.0, 2.0, 1.5],
...     'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }), dim=0)
torch.return_types.max(
values=<Tensor 0x7fd45b5345f8>
├── a --> tensor(2.)
└── b --> <Tensor 0x7fd45b5345c0>
    └── x --> tensor([1.8000, 2.5000])
,
indices=<Tensor 0x7fd45b5346d8>
├── a --> tensor(1)
└── b --> <Tensor 0x7fd45b5346a0>
    └── x --> tensor([0, 1])
)
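
The whole-tree and per-leaf behaviour shown above can be mimicked in plain Python. The sketch below assumes a tree is a nested dict whose leaves are numbers or nested lists of numbers; tree_max and its helpers are hypothetical and are not how treetensor is implemented. It also uses reduce=True as its default, whereas the real function defaults to reduce=None.

```python
from typing import Any, Dict, List

# Hypothetical representation: a "tree" is a dict whose values are either
# sub-trees (dicts) or leaves (numbers or nested lists of numbers).
Tree = Dict[str, Any]


def _flatten_leaf(leaf: Any) -> List[float]:
    """Flatten a nested list of numbers into a flat list."""
    if isinstance(leaf, list):
        flat: List[float] = []
        for item in leaf:
            flat.extend(_flatten_leaf(item))
        return flat
    return [leaf]


def _tree_values(tree: Tree) -> List[float]:
    """Collect every scalar stored in a tree of per-leaf results."""
    values: List[float] = []
    for value in tree.values():
        if isinstance(value, dict):
            values.extend(_tree_values(value))
        else:
            values.append(value)
    return values


def tree_max(tree: Tree, reduce: bool = True):
    """Max of each leaf (reduce=False) or of the whole tree (reduce=True)."""
    # Step 1: take the max inside every leaf, keeping the tree structure.
    per_leaf = {
        key: tree_max(value, reduce=False) if isinstance(value, dict)
        else max(_flatten_leaf(value))
        for key, value in tree.items()
    }
    if not reduce:
        return per_leaf
    # Step 2: reduce the per-leaf maxima to one value for the whole tree.
    return max(_tree_values(per_leaf))
```

With the tree from the examples, tree_max returns 2.5, and tree_max(..., reduce=False) returns {'a': 2.0, 'b': {'x': 2.5}}. Because max is associative, taking the max of the per-leaf maxima gives the same result as a single max over every element.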

Torch Version Related

This documentation is based on torch.max in torch v2.4.1+cu121. The arrangement of its arguments depends on the version of PyTorch you have installed.

If some arguments listed here do not work properly, check your PyTorch version with the following command and consult the documentation for that version.

python -c 'import torch;print(torch.__version__)'

The arguments and keyword arguments supported in torch v2.4.1+cu121 are listed below.

Description From Torch v2.4.1+cu121

torch.max(input) → Tensor

Returns the maximum value of all elements in the input tensor.

Warning

This function produces deterministic (sub)gradients, unlike max(dim=0).
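
The warning concerns how the gradient is assigned when several elements tie for the maximum. The plain-Python sketch below contrasts two rules; both helpers are hypothetical. max_grad_even splits the gradient evenly among all tied maxima (to our understanding this is the rule recent PyTorch versions use for the full reduction, but treat that as an assumption), while max_grad_single sends the whole gradient to one selected index, which is the kind of rule an index-returning reduction relies on.

```python
from typing import List


def max_grad_even(xs: List[float], grad: float = 1.0) -> List[float]:
    """Split `grad` evenly among every position that holds the maximum."""
    top = max(xs)
    ties = {i for i, x in enumerate(xs) if x == top}
    share = grad / len(ties)
    return [share if i in ties else 0.0 for i in range(len(xs))]


def max_grad_single(xs: List[float], grad: float = 1.0) -> List[float]:
    """Send all of `grad` to one selected position (here the first maximum)."""
    chosen = xs.index(max(xs))
    return [grad if i == chosen else 0.0 for i in range(len(xs))]
```

For [1.0, 3.0, 3.0] the even rule gives [0.0, 0.5, 0.5] no matter which tie a kernel happens to visit first, while the single-index rule gives [0.0, 1.0, 0.0] only because it fixes the choice to the first maximum.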

Args:

input (Tensor): the input tensor.

Example:

>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.6763,  0.7445, -2.2369]])
>>> torch.max(a)
tensor(0.7445)

torch.max(input, dim, keepdim=False, *, out=None)

Returns a namedtuple (values, indices), where values is the maximum value of each row of the input tensor in the given dimension dim, and indices is the index location of each maximum value found (argmax).

If keepdim is True, the output tensors are of the same size as input except in the dimension dim where they are of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensors having 1 fewer dimension than input.
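
The effect of keepdim on the output shape can be written as a small shape calculation. reduced_shape is a hypothetical helper operating on shape tuples only; it assumes an input with at least one dimension and normalises negative dim values the way PyTorch does.

```python
from typing import Tuple


def reduced_shape(shape: Tuple[int, ...], dim: int,
                  keepdim: bool = False) -> Tuple[int, ...]:
    """Shape of the (values, indices) outputs of a max over `dim`."""
    ndim = len(shape)
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-d shape")
    dim %= ndim  # normalise negative dims, e.g. -1 -> ndim - 1
    if keepdim:
        # The reduced dimension is kept with size 1.
        return shape[:dim] + (1,) + shape[dim + 1:]
    # Otherwise the reduced dimension is squeezed away.
    return shape[:dim] + shape[dim + 1:]
```

For the 4x4 example below, reducing over dim 1 gives shape (4,) by default and (4, 1) with keepdim=True.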

Note

If there are multiple maximal values in a reduced row then the indices of the first maximal value are returned.
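
The tie-breaking rule in this note can be illustrated for the dim=1 case on a 2-D list. max_over_rows is a hypothetical plain-Python helper, not part of PyTorch; it scans each row and uses a strict comparison, so the earliest maximal index is kept when values tie.

```python
from typing import List, Tuple


def max_over_rows(matrix: List[List[float]]) -> Tuple[List[float], List[int]]:
    """Row-wise (dim=1) max of a 2-D list, returning (values, indices)."""
    values: List[float] = []
    indices: List[int] = []
    for row in matrix:
        best = 0
        for j in range(1, len(row)):
            # Strict '>' keeps the earliest index when values tie.
            if row[j] > row[best]:
                best = j
        values.append(row[best])
        indices.append(best)
    return values, indices
```

Applied to the 4x4 matrix in the example below, it yields the same values [0.8475, 1.1949, 1.5717, 1.0036] and indices [3, 0, 0, 1] as torch.max(a, 1), and for the row [2.0, 5.0, 5.0] it reports index 1.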

Args:

input (Tensor): the input tensor.
dim (int): the dimension to reduce.
keepdim (bool): whether the output tensor has dim retained or not. Default: False.

Keyword args:

out (tuple, optional): the result tuple of two output tensors (max, max_indices).

Example:

>>> a = torch.randn(4, 4)
>>> a
tensor([[-1.2360, -0.2942, -0.1222,  0.8475],
        [ 1.1949, -1.1127, -2.2379, -0.6702],
        [ 1.5717, -0.9207,  0.1297, -1.8768],
        [-0.6172,  1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1)
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))

torch.max(input, other, *, out=None) → Tensor

See torch.maximum().
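
In this form torch.max compares two tensors element by element. The plain-Python sketch below shows the idea for two equal-length lists; maximum here is a hypothetical helper, it does not broadcast, and its NaN handling (a NaN in either operand yields NaN) follows our reading of the torch.maximum documentation rather than any tested PyTorch behaviour.

```python
import math
from typing import List


def maximum(xs: List[float], ys: List[float]) -> List[float]:
    """Element-wise max of two equal-length lists, propagating NaN."""
    if len(xs) != len(ys):
        raise ValueError("this sketch does not broadcast; lengths must match")
    out: List[float] = []
    for x, y in zip(xs, ys):
        if math.isnan(x) or math.isnan(y):
            # Python's built-in max() would not propagate NaN reliably.
            out.append(math.nan)
        else:
            out.append(x if x >= y else y)
    return out
```

The explicit NaN check matters because Python's built-in max(1.0, float('nan')) returns 1.0, whereas max(float('nan'), 1.0) returns nan.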