shape

Documentation

class treetensor.torch.Tensor(data, *args, constraint=None, **kwargs)
property shape

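Get the size of each tensor in the tree. The result is a tree with the same structure as the original, and each leaf holds the torch.Size of the corresponding tensor.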

Example:

>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.tensor({
...     'a': [[1, 11], [2, 22], [3, 33]],
...     'b': {'x': [[4, 5], [6, 7]]},
... }).shape
<Size 0x7ff363bbbd68>
├── a --> torch.Size([3, 2])
└── b --> <Size 0x7ff363bbbcf8>
    └── x --> torch.Size([2, 2])
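The mapping shown in the example can be illustrated with a pure-Python sketch. It walks a nested dict, replaces every leaf with that leaf's size, and keeps the surrounding structure. The helpers `tree_shape` and `leaf_shape` are hypothetical. They stand in nested lists for tensors and tuples for torch.Size, and they are not treetensor's actual implementation.

```python
from typing import Any, Dict, List, Tuple, Union

# Illustrative model only (not treetensor): a tree is a nested dict whose
# leaves are rectangular nested lists standing in for tensors.
Tree = Union[Dict[str, "Tree"], List[Any]]


def leaf_shape(data: Any) -> Tuple[int, ...]:
    """Return the size of a rectangular nested list, analogous to torch.Size."""
    dims = []
    while isinstance(data, list):
        dims.append(len(data))
        # Descend into the first element; an empty list ends the walk.
        data = data[0] if data else None
    return tuple(dims)


def tree_shape(tree: Tree) -> Any:
    """Replace every leaf with its size while keeping the dict structure."""
    if isinstance(tree, dict):
        return {key: tree_shape(value) for key, value in tree.items()}
    return leaf_shape(tree)


data = {
    'a': [[1, 11], [2, 22], [3, 33]],
    'b': {'x': [[4, 5], [6, 7]]},
}
print(tree_shape(data))  # {'a': (3, 2), 'b': {'x': (2, 2)}}
```

The nested result has the same layout as the `<Size ...>` tree printed above, where `a` maps to a size of (3, 2) and `b.x` maps to (2, 2).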