grad
Documentation
class treetensor.torch.Tensor(data, *args, constraint=None, **kwargs)

property grad
Return the gradient data of the whole tree, as a tree with the same structure as the original.
Examples:
>>> import torch
>>> import treetensor.torch as ttorch
>>> tt = ttorch.randn({
...     'a': (2, 3),
...     'b': {'x': (3, 4)},
... })
>>> tt.requires_grad_(True)
>>> tt
<Tensor 0x7feec3bcce80>
├── a --> tensor([[-1.4375,  0.0988,  1.2198],
│                 [-0.7627, -0.8797, -0.9299]], requires_grad=True)
└── b --> <Tensor 0x7feec3bccdd8>
    └── x --> tensor([[ 0.2149, -0.5839, -0.6049, -0.9151],
                      [ 1.5381, -1.4386,  0.1831,  0.2018],
                      [-0.0725, -0.9062, -2.6212,  0.5929]], requires_grad=True)
>>> mq = tt.mean() ** 2
>>> mq.backward()
>>> tt.grad
<Tensor 0x7feec3c0fa90>
├── a --> tensor([[-0.0438, -0.0438, -0.0438],
│                 [-0.0438, -0.0438, -0.0438]])
└── b --> <Tensor 0x7feec3c0f9e8>
    └── x --> tensor([[-0.0438, -0.0438, -0.0438, -0.0438],
                      [-0.0438, -0.0438, -0.0438, -0.0438],
                      [-0.0438, -0.0438, -0.0438, -0.0438]])
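Every entry of tt.grad in the example is the same number, and a short hand calculation suggests why. The sketch below assumes that tt.mean() averages over all 18 leaf elements of the tree (6 in a, 12 in b.x), which is consistent with the printed output but is not stated explicitly above. Under that assumption, the derivative of mean ** 2 with respect to any single element is 2 * mean / N. The code uses only the standard library and the rounded values copied from the example, so it checks the arithmetic rather than calling treetensor.

```python
# Values copied from the printed tree in the example (rounded to 4 decimals).
a = [
    [-1.4375, 0.0988, 1.2198],
    [-0.7627, -0.8797, -0.9299],
]
b_x = [
    [0.2149, -0.5839, -0.6049, -0.9151],
    [1.5381, -1.4386, 0.1831, 0.2018],
    [-0.0725, -0.9062, -2.6212, 0.5929],
]

# Flatten both leaves into one list; assumption: the tree-wide mean
# is taken over every element of every leaf.
values = [v for row in a + b_x for v in row]
n = len(values)
mean = sum(values) / n

# d(mean ** 2) / d(x_i) = 2 * mean * d(mean) / d(x_i) = 2 * mean / n
expected_grad = 2 * mean / n

print(n, round(expected_grad, 4))
```

The result agrees with the -0.0438 shown for every element of tt.grad, up to the rounding of the printed inputs.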
-
property