Source code for treetensor.torch.funcs.matrix

import torch

from .base import doc_from_base, func_treelize

__all__ = [
    'dot', 'matmul', 'mm',
]


# noinspection PyShadowingBuiltins
[docs]@doc_from_base() @func_treelize() def dot(input, other, *args, **kwargs): """ In ``treetensor``, you can get the dot product of 2 tree tensors with :func:`treetensor.torch.dot`. Examples:: >>> import torch >>> import treetensor.torch as ttorch >>> ttorch.dot(torch.tensor([1, 2]), torch.tensor([2, 3])) tensor(8) >>> ttorch.dot( ... ttorch.tensor({ ... 'a': [1, 2, 3], ... 'b': {'x': [3, 4]}, ... }), ... ttorch.tensor({ ... 'a': [5, 6, 7], ... 'b': {'x': [1, 2]}, ... }) ... ) <Tensor 0x7feac55bde50> ├── a --> tensor(38) └── b --> <Tensor 0x7feac55c9250> └── x --> tensor(11) """ return torch.dot(input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
[docs]@doc_from_base() @func_treelize() def matmul(input, other, *args, **kwargs): """ In ``treetensor``, you can create a matrix product with :func:`treetensor.torch.matmul`. Examples:: >>> import torch >>> import treetensor.torch as ttorch >>> ttorch.matmul( ... torch.tensor([[1, 2], [3, 4]]), ... torch.tensor([[5, 6], [7, 2]]), ... ) tensor([[19, 10], [43, 26]]) >>> ttorch.matmul( ... ttorch.tensor({ ... 'a': [[1, 2], [3, 4]], ... 'b': {'x': [3, 4, 5, 6]}, ... }), ... ttorch.tensor({ ... 'a': [[5, 6], [7, 2]], ... 'b': {'x': [4, 3, 2, 1]}, ... }), ... ) <Tensor 0x7f2e74883f40> ├── a --> tensor([[19, 10], │ [43, 26]]) └── b --> <Tensor 0x7f2e74886430> └── x --> tensor(40) """ return torch.matmul(input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
[docs]@doc_from_base() @func_treelize() def mm(input, mat2, *args, **kwargs): """ In ``treetensor``, you can create a matrix multiplication with :func:`treetensor.torch.mm`. Examples:: >>> import torch >>> import treetensor.torch as ttorch >>> ttorch.mm( ... torch.tensor([[1, 2], [3, 4]]), ... torch.tensor([[5, 6], [7, 2]]), ... ) tensor([[19, 10], [43, 26]]) >>> ttorch.mm( ... ttorch.tensor({ ... 'a': [[1, 2], [3, 4]], ... 'b': {'x': [[3, 4, 5], [6, 7, 8]]}, ... }), ... ttorch.tensor({ ... 'a': [[5, 6], [7, 2]], ... 'b': {'x': [[6, 5], [4, 3], [2, 1]]}, ... }), ... ) <Tensor 0x7f2e7489f340> ├── a --> tensor([[19, 10], │ [43, 26]]) └── b --> <Tensor 0x7f2e74896e50> └── x --> tensor([[44, 32], [80, 59]]) """ return torch.mm(input, mat2, *args, **kwargs)