Overview
Implementation of grad_ignore_norm (Related Link)
Unlike grad_clip_norm, grad_ignore_norm discards (zeroes) the gradients whenever their total norm exceeds the specified threshold, instead of rescaling the norm down to the threshold.
import torch
from math import inf  # torch._six was removed in recent PyTorch releases; math.inf is equivalent here
from typing import Union, Iterable
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
def grad_ignore_norm_(parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
Collect the parameters whose gradient is not None into a list.
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
Convert max_norm and norm_type to float.
max_norm = float(max_norm)
norm_type = float(norm_type)
device = parameters[0].grad.device
The infinity norm of the gradient is the maximum absolute value over all gradient entries: $$\mathrm{total\_norm}_{\infty} = \max_{\theta\in\Theta}\max_{\theta_i} |\mathrm{grad}(\theta_i)|$$
if norm_type == inf:
norms = [p.grad.detach().abs().max().to(device) for p in parameters]
total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
The p-norm of the gradient (note that the p-norm of the per-parameter norms equals the p-norm taken over all entries at once): $$\begin{split}\mathrm{total\_norm} &= \Big(\sum_{\theta\in\Theta}\big(\big(\sum_{\theta_i}|\mathrm{grad}(\theta_i)|^p\big)^{\frac{1}{p}}\big)^p\Big)^{\frac{1}{p}}\\&=\Big(\sum_{\theta\in\Theta}\sum_{\theta_i}|\mathrm{grad}(\theta_i)|^p\Big)^{\frac{1}{p}}\end{split}$$
else:
total_norm = torch.norm(
torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type
)
The clip coefficient (the 1e-6 avoids division by zero): $$\mathrm{clip\_coef} = \frac{\mathrm{max\_norm}}{\mathrm{total\_norm} + 10^{-6}}$$
clip_coef = max_norm / (total_norm + 1e-6)
If total_norm > max_norm (i.e. clip_coef < 1), all the gradients are set to zero instead of being rescaled.
if clip_coef < 1:
for p in parameters:
p.grad.zero_()
return total_norm
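As a usage reference, here is a minimal training-step sketch (the linear model, SGD optimizer and random data below are hypothetical placeholders, not part of the implementation above): grad_ignore_norm_ is called between the backward pass and the optimizer step, in the same place where torch.nn.utils.clip_grad_norm_ would normally be called.
# A hypothetical toy model and optimizer, for illustration only.
model = torch.nn.Linear(32, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data, target = torch.randn(8, 32), torch.randn(8, 4)
# Forward and backward pass.
loss = torch.nn.functional.mse_loss(model(data), target)
loss.backward()
# Ignore (zero) this batch's gradients if their total 2-norm exceeds 1.0.
total_norm = grad_ignore_norm_(model.parameters(), max_norm=1.0)
# For batches whose gradients were zeroed, this update is effectively a no-op.
optimizer.step()
optimizer.zero_grad()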
Overview
Implementation of grad_ignore_value (Related Link)
Unlike grad_clip_value, grad_ignore_value discards (zeroes) all the gradients whenever any single gradient entry exceeds the specified threshold, instead of clamping the entries to the threshold.
def grad_ignore_value_(parameters: _tensor_or_tensors, clip_value: float) -> None:
Collect the parameters whose gradient is not None into a list.
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
Convert clip_value to float.
clip_value = float(clip_value)
flag = False
Check whether any gradient entry's absolute value reaches the clip_value.
for p in parameters:
val = p.grad.data.abs().max()
if val >= clip_value:
flag = True
break
If any gradient entry reaches the clip_value, set all the gradients to zero.
if flag:
for p in parameters:
p.grad.data.zero_()
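To make the contrast with the built-in clipping concrete, here is a minimal sketch (the tensors are hypothetical, chosen only for illustration) that applies torch.nn.utils.clip_grad_value_ and grad_ignore_value_ to identical gradients:
# Two leaf tensors with identical, manually assigned gradients.
w_clip = torch.zeros(3, requires_grad=True)
w_ignore = torch.zeros(3, requires_grad=True)
w_clip.grad = torch.tensor([0.2, -0.9, 0.4])
w_ignore.grad = torch.tensor([0.2, -0.9, 0.4])
# clip_grad_value_ clamps every entry into [-0.5, 0.5].
torch.nn.utils.clip_grad_value_([w_clip], clip_value=0.5)
# grad_ignore_value_ discards the whole gradient because |-0.9| >= 0.5.
grad_ignore_value_([w_ignore], clip_value=0.5)
print(w_clip.grad)    # tensor([ 0.2000, -0.5000,  0.4000])
print(w_ignore.grad)  # tensor([0., 0., 0.])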
Overview
Test function of grad_ignore_norm_.
def test_grad_ignore_norm_():
Batch size is 4, action dimension is 32.
B, N = 4, 32
Generate logit and label.
logit = torch.randn(B, N).requires_grad_(True)
label = torch.randn(B, N)
Define criterion and compute loss.
criterion = torch.nn.MSELoss()
output = criterion(logit, label)
Backpropagate the loss to compute the gradients.
output.backward()
Manually overwrite part of the gradient so that its total norm exceeds the threshold.
logit.grad[0] = 0.5
Apply grad_ignore_norm_ with max_norm = 0.5 and norm_type = 2.
grad_ignore_norm_(logit, 0.5, 2)
Assert that all gradients are set to zero.
assert isinstance(logit.grad, torch.Tensor)
for g in logit.grad:
assert (g == 0).all()
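For completeness, a small complementary check (not part of the original test, using hypothetical tensors) showing that a sub-threshold gradient is left untouched, and that the returned value equals the plain 2-norm of all gradient entries, as noted in the p-norm formula above:
# Two hypothetical parameter tensors with small, manually assigned gradients.
w1 = torch.zeros(4, requires_grad=True)
w2 = torch.zeros(2, requires_grad=True)
w1.grad = torch.full((4,), 0.1)
w2.grad = torch.full((2,), 0.1)
total_norm = grad_ignore_norm_([w1, w2], max_norm=0.5)
# The returned norm equals the 2-norm of all gradient entries taken together.
flat_norm = torch.cat([w1.grad.flatten(), w2.grad.flatten()]).norm(2)
assert torch.allclose(total_norm, flat_norm)
# Since total_norm (about 0.245) is below max_norm, the gradients are untouched.
assert torch.allclose(w1.grad, torch.full((4,), 0.1))
assert torch.allclose(w2.grad, torch.full((2,), 0.1))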
Overview
Test function of grad_ignore_value_.
def test_grad_ignore_value_():
Batch size is 4, action dimension is 32.
B, N = 4, 32
Set clip_value to 0.5.
clip_value = 0.5
Generate logit and label.
logit = torch.randn(B, N).requires_grad_(True)
label = torch.randn(B, N)
Define criterion and compute loss.
criterion = torch.nn.MSELoss()
output = criterion(logit, label)
Backpropagate the loss to compute the gradients.
output.backward()
Manually overwrite part of the gradient with values that exceed the threshold.
logit.grad[0] = 0.6
Apply grad_ignore_value_ with the given clip_value.
grad_ignore_value_(logit, clip_value)
Assert that all gradients are set to zero.
assert isinstance(logit.grad, torch.Tensor)
for g in logit.grad:
assert (g == 0).all()
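Similarly, a small complementary sketch (not part of the original test, using a hypothetical tensor) showing that grad_ignore_value_ leaves the gradient untouched as long as every entry stays strictly below the clip_value:
# A hypothetical gradient whose largest absolute entry (0.4) is below clip_value = 0.5.
w = torch.zeros(3, requires_grad=True)
w.grad = torch.tensor([0.1, -0.4, 0.3])
grad_ignore_value_(w, 0.5)
assert torch.allclose(w.grad, torch.tensor([0.1, -0.4, 0.3]))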
If you have any questions or advice about this documentation, you can raise an issue on GitHub (https://github.com/opendilab/PPOxFamily) or email us (opendilab@pjlab.org.cn).