HOME

GitHub bilibili twitter
View code on GitHub

``torch.nn.utils.clip_grad_norm 的 PyTorch 版实现。

概述
torch.nn.utils.clip_grad_norm 的 PyTorch 版实现。Related Link

import torch
from torch._six import inf
from typing import Union, Iterable

_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]


def clip_grad_norm_(parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:

将可训练参数的非空梯度保存到列表 grads 中。

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]

将 max_norm 和 norm_type 转换为 float 类型。

    max_norm = float(max_norm)
    norm_type = float(norm_type)
    device = grads[0].device

梯度的最大范数(max norm):$$\mathrm{total\_norm}^{\infty} = \max_{\theta_i\in \Theta} |\mathrm{grad}(\theta_i)|$$

    if norm_type == inf:
        norms = [g.detach().abs().max().to(device) for g in grads]
        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))

梯度的 p-范数(p-norm):$$\begin{split}\mathrm{total\_norm} &= (\sum_{\theta\in\Theta}((\sum_{\theta_i}\mathrm{grad}(\theta_i)^p)^\frac{1}{p})^p)^\frac{1}{p}\\&=(\sum_{\theta\in\Theta}(\sum_{\theta_i}\mathrm{grad}(\theta_i)^p))^\frac{1}{p}\end{split}$$

    else:
        total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)

裁减系数(1e-6用于避免分母为零):$$\mathrm{clip\_coef} = \frac{\mathrm{max\_norm}}{\mathrm{total\_norm}}$$

    clip_coef = max_norm / (total_norm + 1e-6)

将系数的最大值固定为1

    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)

如果 total_norm < max_norm,torch.clamp 操作将 clip_coef 的最大值固定为1,所以 clip_coef_clamped = 1,这样 total_norm 将保持不变。
如果 total_norm > max_norm,将对原来的梯度进行裁减,使得裁减后的梯度对应的 total_norm 的大小为 max_norm:$$\begin{split}\mathrm{total\_norm'}&=(\sum_{\theta\in\Theta}(\sum_{\theta}(\mathrm{grad}(\theta_i)\cdot\frac{\mathrm{max\_norm}}{\mathrm{total\_norm}})^p))^{\frac{1}{p}}\\&=\frac{(\sum_{\theta\in\Theta}(\sum_{\theta}\mathrm{grad}(\theta_i)^p))^{\frac{1}{p}}}{\mathrm{total\_norm}}\cdot\mathrm{max\_norm}\\&=\mathrm{max\_norm}\end{split}$$

    for g in grads:
        g.detach().mul_(clip_coef_clamped.to(g.device))
    return total_norm

概述
梯度正则化的测试函数。

def test_clip_grad_norm_():

设置相关超参数:batch size=4, action=32

    B, N = 4, 32

从随机分布中生成测试数据(类似回归任务):logit,label,实践中,logit 一般是网络的输出,并要求可以会传梯度。

    logit = torch.randn(B, N).requires_grad_(True)
    label = torch.randn(B, N)

定义损失函数并计算具体的数值。

    criterion = torch.nn.MSELoss()
    output = criterion(logit, label)

损失函数数值张量执行反向传播,计算相应的网络参数的梯度。

    output.backward()

根据梯度的 total_norm 对梯度进行裁减:
如果其 total_norm 超过 max_norm,则裁减并使得裁减后的梯度对应的 total_norm 的大小为 max_norm,否则不裁减。

    clip_grad_norm(logit, 0.5, 2)

测试裁减后的 total_norm 的大小是否在预期范围内。

    assert isinstance(logit.grad, torch.Tensor)
    grads = logit.grad
    total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2) for g in grads]), 2)
    assert total_norm < 0.5

如果读者关于本文档有任何问题和建议,可以在 GitHub 提 issue 或是直接发邮件给我们 (opendilab@pjlab.org.cn) 。