requires_grad

class treetensor.torch.Tensor(data, *args, constraint=None, **kwargs)
property requires_grad

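Return whether each tensor in the current tree requires gradients. The result is an Object tree with the same structure as the tensor tree, whose leaves are the bool requires_grad flags of the corresponding tensors.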
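Conceptually, the property reads the requires_grad attribute of every leaf tensor and keeps the nesting of the tree. The sketch below reproduces that behaviour with plain PyTorch on a nested dict, as an illustration only. The helpers tree_requires_grad and tree_set_requires_grad are hypothetical and not part of treetensor, and treetensor itself returns an Object tree rather than a dict.

```python
import torch


def tree_requires_grad(tree):
    """Hypothetical helper: map each leaf tensor to its requires_grad flag, keeping the nesting."""
    if isinstance(tree, dict):
        return {key: tree_requires_grad(value) for key, value in tree.items()}
    return tree.requires_grad  # leaf: a torch.Tensor


def tree_set_requires_grad(tree, flag):
    """Hypothetical helper: call requires_grad_(flag) on every leaf tensor in place."""
    if isinstance(tree, dict):
        for value in tree.values():
            tree_set_requires_grad(value, flag)
    else:
        tree.requires_grad_(flag)


# A plain nested dict of tensors standing in for the tree tensor `tt`.
tree = {
    'a': torch.randn(2, 3),
    'b': {'x': torch.randn(3, 4)},
}

tree_set_requires_grad(tree, True)   # like tt.requires_grad_(True)
print(tree_requires_grad(tree))      # {'a': True, 'b': {'x': True}}

tree['a'].requires_grad_(False)      # freeze only leaf 'a', like tt.a.requires_grad_(False)
print(tree_requires_grad(tree))      # {'a': False, 'b': {'x': True}}
```

Because torch.randn creates leaf tensors, toggling requires_grad on them is allowed; the same call on a non-leaf tensor produced by an operation would raise an error when switching the flag off.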

Examples:

>>> import torch
>>> import treetensor.torch as ttorch
>>> tt = ttorch.randn({
...     'a': (2, 3),
...     'b': {'x': (3, 4)},
... })
>>> _ = tt.requires_grad_(True)  # in-place; the returned tree is discarded
>>> tt.requires_grad
<Object 0x7feec3c229e8>
├── a --> True
└── b --> <Object 0x7feec3c22940>
    └── x --> True

>>> _ = tt.a.requires_grad_(False)  # freeze only the tensor at key 'a'
>>> tt.requires_grad
<Object 0x7feec3c0fa58>
├── a --> False
└── b --> <Object 0x7feec3c0f5f8>
    └── x --> True
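After freezing some leaves, it can help to list which ones are still trainable. The sketch below flattens a flag tree into dotted paths. It works on a plain nested dict of booleans that mirrors the second tree printed above; converting the actual Object result into such a dict is assumed and not shown, and paths_with_grad is a hypothetical helper, not a treetensor function.

```python
def paths_with_grad(flags, prefix=()):
    """Hypothetical helper: yield dotted paths of leaves whose flag is True."""
    for key, value in flags.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            # Descend into a subtree, extending the path.
            yield from paths_with_grad(value, path)
        elif value:
            yield '.'.join(path)


# Plain-dict stand-in for the flag tree shown above after freezing 'a'.
flags = {'a': False, 'b': {'x': True}}

print(list(paths_with_grad(flags)))  # ['b.x']
```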