grad


class treetensor.torch.Tensor(data, *args, **kwargs)
property grad

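Return the gradients of the whole tree: a tree with the same structure as the tensor, in which each leaf holds the grad of the corresponding leaf tensor.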
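To make "the whole tree" concrete, here is a rough plain-PyTorch analogue that uses nested dicts in place of a treetensor. It is a sketch under assumptions, not how treetensor implements the property: the helpers `tree_leaves` and `tree_grad` are hypothetical, and the reduction mirrors the example below, whose printed gradient is consistent with `tt.mean()` averaging over all 18 elements of both leaves.

```python
import torch

def tree_leaves(tree):
    """Hypothetical helper: yield every leaf tensor of a nested dict."""
    if isinstance(tree, dict):
        for value in tree.values():
            yield from tree_leaves(value)
    else:
        yield tree

def tree_grad(tree):
    """Hypothetical helper: mirror the dict structure, replacing each
    leaf tensor with its .grad attribute (None if no gradient exists)."""
    if isinstance(tree, dict):
        return {key: tree_grad(value) for key, value in tree.items()}
    return tree.grad

# Same leaf shapes as the treetensor example, stored as plain nested dicts.
tree = {
    'a': torch.randn(2, 3, requires_grad=True),
    'b': {'x': torch.randn(3, 4, requires_grad=True)},
}

# Mean over every element of every leaf (6 + 12 = 18 elements), squared.
mean = torch.cat([t.flatten() for t in tree_leaves(tree)]).mean()
(mean ** 2).backward()

grads = tree_grad(tree)  # same nesting as `tree`: {'a': ..., 'b': {'x': ...}}
# Every element should receive the same gradient, 2 * mean / 18.
expected = 2 * mean.detach() / 18
```

Walking the structure recursively keeps the result's nesting identical to the input, which is the shape of output shown for `tt.grad` in the example.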

Examples:

>>> import torch
>>> import treetensor.torch as ttorch
>>> tt = ttorch.randn({
...     'a': (2, 3),
...     'b': {'x': (3, 4)},
... })
>>> tt.requires_grad_(True)
>>> tt
<Tensor 0x7feec3bcce80>
├── a --> tensor([[-1.4375,  0.0988,  1.2198],
│                 [-0.7627, -0.8797, -0.9299]], requires_grad=True)
└── b --> <Tensor 0x7feec3bccdd8>
    └── x --> tensor([[ 0.2149, -0.5839, -0.6049, -0.9151],
                      [ 1.5381, -1.4386,  0.1831,  0.2018],
                      [-0.0725, -0.9062, -2.6212,  0.5929]], requires_grad=True)
>>> mq = tt.mean() ** 2
>>> mq.backward()
>>> tt.grad
<Tensor 0x7feec3c0fa90>
├── a --> tensor([[-0.0438, -0.0438, -0.0438],
│                 [-0.0438, -0.0438, -0.0438]])
└── b --> <Tensor 0x7feec3c0f9e8>
    └── x --> tensor([[-0.0438, -0.0438, -0.0438, -0.0438],
                      [-0.0438, -0.0438, -0.0438, -0.0438],
                      [-0.0438, -0.0438, -0.0438, -0.0438]])