# Source code for treetensor.torch.funcs.operation

import torch
from hbutils.reflection import post_process
from treevalue import TreeValue

from .base import doc_from_base, func_treelize, auto_tensor

# Public API of this module: shape/combination operations that are
# treelized so they accept both plain ``torch.Tensor`` values and
# tree-structured tensors.
__all__ = [
    'cat', 'split', 'chunk', 'stack',
    'reshape', 'where', 'squeeze', 'unsqueeze',
    'index_select',
]


@doc_from_base()
@func_treelize(subside=True)
def cat(tensors, *args, **kwargs):
    """
    Concatenate the given sequence of tensors in the given dimension.

    All tensors must either have the same shape (except in the
    concatenating dimension) or be empty.  When tree tensors are passed,
    the sequence is subsided and ``torch.cat`` is applied leaf-by-leaf,
    producing a tree with the same structure.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> t1 = torch.randint(10, 30, (2, 3))
        >>> t2 = torch.randint(30, 50, (2, 3))
        >>> ttorch.cat((t1, t2))           # plain tensors: same as torch.cat
        >>> tt1 = ttorch.Tensor({'a': t1, 'b': {'x': t2}})
        >>> tt2 = ttorch.Tensor({'a': t2, 'b': {'x': t1}})
        >>> ttorch.cat((tt1, tt2))         # concatenated per leaf, dim 0
        >>> ttorch.cat((tt1, tt2), dim=1)  # extra args forwarded to torch.cat
    """
    return torch.cat(tensors, *args, **kwargs)
# noinspection PyShadowingNames
@doc_from_base()
@post_process(auto_tensor)
@func_treelize(return_type=TreeValue, rise=True)
def split(tensor, split_size_or_sections, *args, **kwargs):
    """
    Split the tensor into chunks; each chunk is a view of the original
    tensor.

    For a tree tensor, every leaf is split with ``torch.split`` and the
    per-leaf tuples are risen into a tuple of tree tensors, so the result
    mirrors the plain ``torch.split`` return shape.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> t1 = torch.randint(100, (6, 2))
        >>> ttorch.split(t1, (1, 2, 3))    # tuple of 3 tensors, rows 1/2/3
        >>> tt1 = ttorch.randint(100, {
        ...     'a': (6, 2),
        ...     'b': {'x': (6, 2, 3)},
        ... })
        >>> ttorch.split(tt1, (1, 2, 3))   # tuple of 3 tree tensors
    """
    chunks = torch.split(tensor, split_size_or_sections, *args, **kwargs)
    return chunks
# noinspection PyShadowingBuiltins
@doc_from_base()
@post_process(auto_tensor)
@func_treelize(return_type=TreeValue, rise=True)
def chunk(input, chunks, *args, **kwargs):
    """
    Split a tensor into a specific number of chunks; each chunk is a view
    of the input tensor.

    For a tree tensor, ``torch.chunk`` is applied to every leaf and the
    per-leaf tuples are risen into a tuple of tree tensors.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> t = torch.randint(100, (4, 5))
        >>> ttorch.chunk(t, 2)     # tuple of 2 tensors, 2 rows each
        >>> tt = ttorch.randint(100, {
        ...     'a': (4, 5),
        ...     'b': {'x': (2, 3, 4)},
        ... })
        >>> ttorch.chunk(tt, 2)    # tuple of 2 tree tensors
    """
    pieces = torch.chunk(input, chunks, *args, **kwargs)
    return pieces
@doc_from_base()
@func_treelize(subside=True)
def stack(tensors, *args, **kwargs):
    """
    Concatenate a sequence of tensors along a new dimension.

    All tensors need the same shape.  When tree tensors are passed, the
    sequence is subsided and ``torch.stack`` is applied leaf-by-leaf,
    producing a tree with the same structure.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> t1 = torch.randint(10, 30, (2, 3))
        >>> t2 = torch.randint(30, 50, (2, 3))
        >>> ttorch.stack((t1, t2))         # shape (2, 2, 3)
        >>> tt1 = ttorch.randint(10, 30, {
        ...     'a': (2, 3),
        ...     'b': {'x': (3, 4)},
        ... })
        >>> tt2 = ttorch.randint(30, 50, {
        ...     'a': (2, 3),
        ...     'b': {'x': (3, 4)},
        ... })
        >>> ttorch.stack((tt1, tt2))        # stacked per leaf, new dim 0
        >>> ttorch.stack((tt1, tt2), dim=1) # extra args forwarded
    """
    return torch.stack(tensors, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def reshape(input, shape):
    """
    Return a tensor with the same data and number of elements as
    ``input`` but with the specified shape.  When possible, the returned
    tensor is a view of ``input``.

    .. note::
        If ``shape`` is a single tuple, every leaf of a tree tensor is
        reshaped to that same shape, so all leaves must be compatible
        with it.  Alternatively, pass a tree of tuples to give each leaf
        its own target shape.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.reshape(torch.tensor([[1, 2], [3, 4]]), (-1, ))
        tensor([1, 2, 3, 4])
        >>> ttorch.reshape(ttorch.tensor({
        ...     'a': [[1, 2], [3, 4]],
        ...     'b': {'x': [[2], [3], [5], [7], [11], [13]]},
        ... }), {'a': (4, ), 'b': {'x': (3, 2)}})  # per-leaf shapes
    """
    reshaped = torch.reshape(input, shape)
    return reshaped
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def squeeze(input, *args, **kwargs):
    """
    Return a tensor with all the dimensions of ``input`` of size 1
    removed.  Applied leaf-by-leaf for tree tensors.

    Extra positional and keyword arguments (e.g. ``dim``) are forwarded
    to :func:`torch.squeeze`.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> t1 = torch.randint(100, (2, 1, 2, 1, 2))
        >>> ttorch.squeeze(t1).shape
        torch.Size([2, 2, 2])
        >>> tt1 = ttorch.randint(100, {
        ...     'a': (2, 1, 2, 1, 2),
        ...     'b': {'x': (2, 1, 1, 3)},
        ... })
        >>> ttorch.squeeze(tt1).shape    # each leaf squeezed independently
    """
    # BUG FIX: the original forwarded ``*kwargs`` (unpacking only the dict
    # keys as extra positionals) instead of ``**kwargs``, which broke any
    # keyword call such as ``squeeze(t, dim=1)``.
    return torch.squeeze(input, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def unsqueeze(input, dim):
    """
    Return a new tensor with a dimension of size one inserted at the
    specified position.  Applied leaf-by-leaf for tree tensors.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> t1 = torch.randint(100, (100, ))
        >>> ttorch.unsqueeze(t1, 0).shape
        torch.Size([1, 100])
        >>> tt1 = ttorch.randint(100, {
        ...     'a': (2, 2, 2),
        ...     'b': {'x': (2, 3)},
        ... })
        >>> ttorch.unsqueeze(tt1, 1).shape   # new axis in every leaf
    """
    expanded = torch.unsqueeze(input, dim)
    return expanded
@doc_from_base()
@func_treelize()
def where(condition, x, y):
    """
    Return a tree of tensors of elements selected from either ``x`` or
    ``y`` depending on ``condition`` (element-wise, per leaf).

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.where(
        ...     torch.tensor([[True, False], [False, True]]),
        ...     torch.tensor([[2, 8], [16, 4]]),
        ...     torch.tensor([[3, 11], [5, 7]]),
        ... )
        tensor([[ 2, 11],
                [ 5,  4]])
        >>> tt1 = ttorch.randint(1, 99, {'a': (2, 3), 'b': {'x': (3, 2, 4)}})
        >>> ttorch.where(tt1 % 2 == 1, tt1, 0)  # keep odd values, zero the rest
    """
    selected = torch.where(condition, x, y)
    return selected
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def index_select(input, dim, index, *args, **kwargs):
    """
    Return a new tensor which indexes the ``input`` tensor along
    dimension ``dim`` using the entries in ``index``, a LongTensor.

    .. note::
        To select different indices per leaf of a tree tensor, pass a
        tree tensor as ``index``; each leaf of ``input`` is then indexed
        with the matching leaf of ``index``.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> t = torch.randn(3, 4)
        >>> ttorch.index_select(t, 1, torch.tensor([1, 2]))   # columns 1 and 2
        >>> tt = ttorch.randn({
        ...     'a': (3, 4),
        ...     'b': {'x': (5, 6)},
        ... })
        >>> ttorch.index_select(tt, 1, torch.tensor([1, 2]))  # same index per leaf
        >>> ttorch.index_select(tt, 1, ttorch.tensor({
        ...     'a': [1, 2],
        ...     'b': {'x': [1, 3, 5]},
        ... }))                                               # per-leaf indices
    """
    picked = torch.index_select(input, dim, index, *args, **kwargs)
    return picked