# Source code for lightrft.strategy.fsdp.fsdp_utils
"""
Gradient scaling and optimization utilities for deep learning.
This module provides tools for gradient handling, norm computation, and optimization in PyTorch.
It includes dynamic gradient scaling for mixed precision training, gradient norm computation,
and base classes for optimizers.
"""
import math
from abc import ABC, abstractmethod
from typing import Dict, Optional, Union
import torch
import torch.distributed as dist
from torch import Tensor
from torch.optim import Optimizer
from lightrft.utils import get_current_device
try:
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
APEX_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
print("The torch implementation for cal_l2norm is slower than apex. Please note this!")
APEX_AVAILABLE = False
inf = math.inf
def is_meta_initialized(model) -> bool:
    """
    Check if a PyTorch model's parameters are meta-initialized.

    Meta-initialized models contain parameters on a 'meta' device, which are placeholders
    that don't allocate actual memory. These are useful for model initialization without
    memory overhead, commonly used in model parallelism and large model initialization.
    For more information on meta device and meta tensors, see:
    https://docs.pytorch.org/docs/stable/meta.html

    :param model: The PyTorch module to check.
    :type model: torch.nn.Module
    :raises TypeError: if ``model`` is not an instance of :class:`torch.nn.Module`.
    :returns: True if any parameter in the model is on a meta device, False otherwise.
    :rtype: bool

    Example::

        >>> import torch
        >>> model = torch.nn.Linear(10, 1)
        >>> is_meta_initialized(model)  # False for regular model
        False
        >>> with torch.device('meta'):
        ...     meta_model = torch.nn.Linear(10, 1)
        >>> is_meta_initialized(meta_model)  # True for meta model
        True
    """
    # Enforce the documented contract: a clear TypeError beats an obscure
    # AttributeError from calling .parameters() on an arbitrary object.
    if not isinstance(model, torch.nn.Module):
        raise TypeError(f"Expected torch.nn.Module, got {type(model).__name__}")
    # A single meta-device parameter is enough to flag the whole model.
    return any(
        hasattr(param, "device") and param.device.type == "meta"
        for param in model.parameters()
    )
def multi_tensor_l2norm_torch(tensor_list, per_tensor):
    """
    Pure-PyTorch fallback for computing the L2 norm over a list of tensors.

    Each tensor is cast to float32 for numerical stability, its individual
    L2 norm is taken, and the overall norm is the L2 norm of those
    per-tensor norms (mathematically the norm of the concatenation).

    :param tensor_list: Tensors to reduce.
    :type tensor_list: list[torch.Tensor]
    :param per_tensor: If True, also return the per-tensor norms; otherwise
        the second return value is an empty tensor.
    :type per_tensor: bool
    :return: Tuple of (overall L2 norm with shape ``[1]``, per-tensor norms)
    :rtype: tuple[torch.Tensor, torch.Tensor]

    Example::

        >>> tensors = [torch.randn(3, 3), torch.randn(2, 2)]
        >>> overall, per = multi_tensor_l2norm_torch(tensors, True)
        >>> overall.shape, per.shape
        (torch.Size([1]), torch.Size([2]))
    """
    # Cast to float32 before reducing, then norm-of-norms for the total.
    individual_norms = torch.stack([t.float().norm(p=2) for t in tensor_list])
    total_norm = individual_norms.norm(p=2).unsqueeze(0)
    if per_tensor:
        return total_norm, individual_norms
    return total_norm, torch.Tensor([]).to(individual_norms.device)
def calc_l2_norm(grads):
    """
    Calculate the L2 norm of a list of gradients.

    Dispatches to APEX's fused multi-tensor kernel when available (faster on
    GPU), otherwise falls back to the pure-PyTorch implementation.

    :param grads: List of gradient tensors.
    :type grads: list[torch.Tensor]
    :return: L2 norm of the gradients (``0.0`` for an empty list).
    :rtype: torch.Tensor or float

    Example::

        >>> grads = [torch.randn(10, 10), torch.randn(5, 5)]
        >>> norm = calc_l2_norm(grads)
    """
    # An empty gradient list short-circuits to a plain float zero.
    if len(grads) == 0:
        return 0.0
    if APEX_AVAILABLE:
        # The overflow buffer is required by the fused kernel's signature;
        # we do not inspect it here.
        overflow_buf = torch.tensor([0], device=get_current_device(), dtype=torch.int32)
        l2_norm, _ = multi_tensor_applier(
            amp_C.multi_tensor_l2norm,
            overflow_buf,
            [grads],
            False,  # no per-parameter norm
        )
    else:
        l2_norm, _ = multi_tensor_l2norm_torch(grads, False)
    return l2_norm
def calc_lp(grads, norm_type):
    """
    Calculate the sum of p-th powers of gradient norms.

    For each tensor the p-norm is raised back to the p-th power and the
    results are summed, so the caller obtains ``sum(|g|_p ** p)`` — take the
    ``1/p``-th root afterwards to get the true Lp norm.

    :param grads: List of gradient tensors.
    :type grads: list[torch.Tensor]
    :param norm_type: The p in the Lp norm.
    :type norm_type: float
    :return: Accumulated p-th powers of the per-tensor norms
        (``0.0`` for an empty list).
    :rtype: torch.Tensor or float

    Example::

        >>> grads = [torch.randn(3, 3), torch.randn(2, 2)]
        >>> l1_acc = calc_lp(grads, 1.0)
        >>> l2_acc = calc_lp(grads, 2.0)
    """
    # sum() with a 0.0 start mirrors the original accumulator semantics.
    return sum((torch.norm(grad, norm_type) ** norm_type for grad in grads), 0.0)
def get_norm(grads, norm_type, enable_cuda_kernels):
    """
    Compute the norm of gradients for the requested norm type.

    Dispatches to the appropriate implementation: the infinity norm is a
    simple max over absolute values, the L2 norm can use the optimized CUDA
    path, and everything else goes through the generic Lp accumulation.

    :param grads: List of gradient tensors.
    :type grads: list[torch.Tensor]
    :param norm_type: Type of norm to compute (2.0, inf, etc.).
    :type norm_type: float
    :param enable_cuda_kernels: Whether the CUDA-optimized L2 path may be used.
    :type enable_cuda_kernels: bool
    :return: Norm of the gradients (for finite p, the p-th power of the norm).
    :rtype: float or torch.Tensor

    Example::

        >>> grads = [torch.randn(3, 3), torch.randn(2, 2)]
        >>> l2 = get_norm(grads, 2.0, True)
        >>> linf = get_norm(grads, float('inf'), True)
    """
    if norm_type == math.inf:
        # Infinity norm: the largest absolute entry over all gradients.
        return max(g.data.abs().max() for g in grads)
    if norm_type == 2.0 and enable_cuda_kernels:
        # Squared L2 via the fused kernel path; callers expect the power form.
        return calc_l2_norm(grads) ** norm_type
    return calc_lp(grads, norm_type)
def reduce_grads(gradients, parameters):
    """
    Prepare gradients for norm computation in distributed training.

    Pairs each gradient with its parameter (the pairing bounds the iteration;
    the parameter itself is not otherwise used here) and converts the
    gradient data to float32 for numerically stable norm calculations.

    :param gradients: List of gradient tensors.
    :type gradients: list[torch.Tensor]
    :param parameters: List of parameter tensors, aligned with ``gradients``.
    :type parameters: list[torch.Tensor]
    :return: Float32 copies of the gradient data.
    :rtype: list[torch.Tensor]

    Example::

        >>> grads = [torch.randn(3, 3), torch.randn(2, 2)]
        >>> params = [torch.randn(3, 3), torch.randn(2, 2)]
        >>> processed = reduce_grads(grads, params)
    """
    # Process all ranks for the FSDP parameter group.
    return [grad.data.float() for grad, _param in zip(gradients, parameters)]
def get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor:
    """
    Convert a norm value to a tensor and optionally move it to the device.

    Scalar floats are wrapped into a 1-element tensor; tensor inputs pass
    through unchanged (unless moved to the current accelerator device).

    :param norm: Norm value as float or tensor.
    :type norm: Union[float, torch.Tensor]
    :param move_to_cuda: Whether to move the result to the current device.
    :type move_to_cuda: bool
    :return: Norm as a tensor on the appropriate device.
    :rtype: torch.Tensor

    Example::

        >>> norm_tensor = get_tensor_norm(2.5, True)
        >>> print(norm_tensor.device)  # cuda:0 (if CUDA available)
    """
    as_tensor = torch.Tensor([norm]) if isinstance(norm, float) else norm
    if move_to_cuda:
        as_tensor = as_tensor.to(get_current_device())
    return as_tensor
def compute_norm(gradients, parameters, norm_type=2):
    """
    Compute the gradient norm across the distributed environment.

    The local (per-rank) norm contribution is computed, converted to a CUDA
    tensor, and summed over all ranks with an all-reduce. For finite p the
    result is the p-th power of the norm, so callers need
    ``total_norm ** (1 / norm_type)`` before use.

    :param gradients: The gradient values.
    :type gradients: list[torch.Tensor]
    :param parameters: The parameters each gradient corresponds to.
    :type parameters: list[torch.Tensor]
    :param norm_type: Type of the used p-norm. Can be ``'inf'`` for infinity norm.
    :type norm_type: float or int
    :return: Total norm of the parameters; ``-1`` on infinity, ``-2`` on NaN.
    :rtype: float

    Example::

        >>> grads = [p.grad for p in model.parameters() if p.grad is not None]
        >>> params = [p for p in model.parameters() if p.grad is not None]
        >>> total_norm = compute_norm(grads, params)
    """
    # CUDA kernels are only usable when gradients live off-CPU.
    use_cuda_kernels = gradients[0].device.type != "cpu"
    norm_type = float(norm_type)
    fsdp_grads = reduce_grads(gradients, parameters)
    local_norm = get_norm(fsdp_grads, norm_type, use_cuda_kernels)
    # Normalize the result to a tensor so the all-reduce below is valid.
    local_norm = get_tensor_norm(local_norm, use_cuda_kernels)
    # CPU-computed norms must still be moved to the accelerator for NCCL.
    if not use_cuda_kernels:
        local_norm = local_norm.to(get_current_device())
    total_norm = local_norm
    # Sum across all model-parallel GPUs.
    dist.all_reduce(total_norm, op=dist.ReduceOp.SUM)
    if torch.is_tensor(total_norm):
        total_norm = total_norm.item()
    # Sentinel values signal overflow (-1) / NaN (-2) to the caller.
    if total_norm in (float("inf"), -float("inf")):
        total_norm = -1
    if math.isnan(total_norm):
        total_norm = -2
    return total_norm
class BaseGradScaler(ABC):
    """
    Abstract base class for gradient scalers.

    Gradient scalers prevent gradient underflow in low-precision training by
    scaling the loss up before backpropagation; subclasses decide how the
    scale evolves over time by implementing :meth:`update`.

    :param initial_scale: The initial loss scale; must be positive.
    :type initial_scale: float

    Example::

        >>> class MyGradScaler(BaseGradScaler):
        ...     def update(self, overflow: bool) -> None:
        ...         pass  # custom update logic
    """

    def __init__(self, initial_scale: float):
        assert initial_scale > 0
        # Keep the scale as a 1-element float32 tensor on the current device.
        self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float32)

    @property
    def scale(self) -> Tensor:
        """
        The current loss scale.

        :return: Current loss scale.
        :rtype: torch.Tensor
        """
        return self._scale

    @property
    def inv_scale(self) -> Tensor:
        """
        The inverse of the loss scale, used to unscale gradients after
        backpropagation.

        :return: Inverse of the current loss scale.
        :rtype: torch.Tensor
        """
        # Reciprocal in double precision, then cast back to float32.
        return self._scale.double().reciprocal().float()

    def state_dict(self) -> Dict:
        """
        Return the scaler state as a dict (key: ``'scale'``).

        :return: State dictionary containing the scale.
        :rtype: Dict
        """
        return {"scale": self.scale}

    def load_state_dict(self, state_dict: Dict) -> None:
        """
        Restore the scaler state from a dict produced by :meth:`state_dict`.

        :param state_dict: The states of the gradient scaler.
        :type state_dict: Dict
        """
        self._scale = state_dict["scale"]

    @abstractmethod
    def update(self, overflow: bool) -> None:
        """
        Update the loss scale; must be implemented by subclasses.

        :param overflow: Whether overflow occurred this step.
        :type overflow: bool
        """
class DynamicGradScaler(BaseGradScaler):
    """
    Gradient scaler with a dynamic loss scale.

    The scale grows by ``growth_factor`` after ``growth_interval`` consecutive
    overflow-free steps and shrinks by ``backoff_factor`` once ``hysteresis``
    overflows have accumulated, clamped to ``[min_scale, max_scale]``.

    :param initial_scale: The initial loss scale.
    :type initial_scale: float
    :param growth_factor: Multiplier applied when increasing the scale.
    :type growth_factor: float
    :param backoff_factor: Multiplier applied when decreasing the scale.
    :type backoff_factor: float
    :param growth_interval: Overflow-free steps required before growing.
    :type growth_interval: int
    :param min_scale: Lower bound for the loss scale (None disables).
    :type min_scale: Optional[float]
    :param max_scale: Upper bound for the loss scale (None disables).
    :type max_scale: Optional[float]
    :param hysteresis: Overflow count required before backing off.
    :type hysteresis: int
    :param dtype: The training data type (affects sanity checks only).
    :type dtype: torch.dtype

    Example::

        >>> scaler = DynamicGradScaler(initial_scale=2**16, growth_factor=2.0)
        >>> scaled_loss = loss * scaler.scale
        >>> scaled_loss.backward()
        >>> scaler.update(check_overflow(model.parameters()))
    """

    def __init__(  # pylint: disable=R0917
        self,
        initial_scale: float = 2**16,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        min_scale: Optional[float] = 1,
        max_scale: Optional[float] = 2**24,
        hysteresis: int = 2,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__(initial_scale)
        # Bounds are stored as 1-element float32 tensors, or None if unset
        # (note: falsy values such as 0 also disable the bound, matching the
        # original truthiness-based behavior).
        if min_scale:
            self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float32)
        else:
            self._min_scale = None
        if max_scale:
            self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float32)
        else:
            self._max_scale = None
        self._growth_factor = growth_factor
        self._backoff_factor = backoff_factor
        self._growth_interval = growth_interval
        self._hysteresis = hysteresis
        self._dtype = dtype
        # Step counters for growth progress and overflow hysteresis.
        self._growth_step = 0
        self._hysteresis_step = 0
        self._sanity_checks()

    def _sanity_checks(self) -> None:
        """
        Validate the constructor arguments.

        Asserts that bounds and factors are in sane ranges for the chosen
        dtype, and prints warnings for configurations that only make sense
        for float16 training.
        """
        assert self._dtype in [torch.float16, torch.bfloat16, torch.float32]
        if self._min_scale is not None:
            min_scale = self._min_scale.item()
            assert min_scale > 0, "The minimum gradient scale cannot be zero or negative"
            if self._dtype != torch.float16 and min_scale != 1.0:
                print(f"Detect you use {self._dtype}, but min_scale: {min_scale} != 1.0")
        if self._max_scale:
            max_scale = self._max_scale.item()
            assert max_scale > 0, "The maximum gradient scale cannot be zero or negative"
            if self._dtype != torch.float16 and max_scale != 1.0:
                print(f"Detect you use {self._dtype}, but max_scale: {max_scale} != 1.0")
        if self._dtype == torch.float16:
            # float16 genuinely needs a dynamic scale, so strict bounds apply.
            assert self._growth_factor > 1.0, "The growth factor cannot be equal or smaller than 1"
            assert 0 < self._backoff_factor < 1.0, "The backoff factor must be between 0 and 1"
        else:
            # Other dtypes usually run with a fixed scale of 1; warn otherwise.
            assert self._growth_factor >= 1.0, "The growth factor cannot be smaller than 1"
            assert 0 < self._backoff_factor <= 1.0, "The backoff factor must be between 0 and 1"
            if self._growth_factor != 1.0:
                print(f"Detect you use {self._dtype}, but growth_factor: {self._growth_factor} != 1.0")
            if self._backoff_factor != 1.0:
                print(f"Detect you use {self._dtype}, but backoff_factor: {self._backoff_factor} != 1.0")
        assert self._hysteresis >= 0, "The hysteresis cannot be negative"

    def update(self, overflow: bool) -> None:
        """
        Update the loss scale based on whether overflow occurred.

        An overflow resets the growth counter and advances the hysteresis
        counter, backing off the scale once the hysteresis threshold is hit.
        A clean step advances the growth counter and grows the scale after a
        full overflow-free interval.

        :param overflow: Whether overflow occurred this step.
        :type overflow: bool

        Example::

            >>> scaler = DynamicGradScaler()
            >>> scaler.update(False)  # no overflow, advance growth counter
            >>> scaler.update(True)   # overflow, may decrease scale
        """
        if not overflow:
            self._growth_step += 1
            if self._growth_step == self._growth_interval:
                # A full interval without overflow: reset counters and grow.
                self._growth_step = 0
                self._hysteresis_step = 0
                self._grow_scale()
                print(
                    f"No overflow for consecutive {self._growth_interval} steps, "
                    f"the loss scale is adjusted to {self.scale.item()}",
                )
            return
        # Overflow path: reset growth progress and count toward hysteresis.
        self._hysteresis_step += 1
        self._growth_step = 0
        if self._hysteresis_step >= self._hysteresis:
            self._backoff_scale()
            print(f"Overflow occurs, the loss scale is adjusted to {self.scale.item()}")

    def _backoff_scale(self) -> None:
        """
        Shrink the loss scale by the backoff factor, clamped below by the
        minimum scale when one is set.
        """
        self._scale = self._scale * self._backoff_factor
        if self._min_scale:
            self._scale = torch.max(self._scale, self._min_scale)

    def _grow_scale(self) -> None:
        """
        Grow the loss scale by the growth factor, clamped above by the
        maximum scale when one is set.
        """
        self._scale = self._scale * self._growth_factor
        if self._max_scale:
            self._scale = torch.min(self._scale, self._max_scale)

    def state_dict(self):
        """
        Return the scaler state for checkpointing.

        :return: Dict with keys ``'_scale'``, ``'_growth_step'``,
            ``'_hysteresis_step'``.
        :rtype: dict
        """
        return {
            "_scale": self._scale.item(),
            "_growth_step": self._growth_step,
            "_hysteresis_step": self._hysteresis_step,
        }

    def load_state_dict(self, state_dict):
        """
        Restore the scaler state from a dict produced by :meth:`state_dict`.

        :param state_dict: The states of the gradient scaler.
        :type state_dict: dict
        """
        # fill_ writes the saved value in place, preserving device and dtype.
        self._scale = self._scale.fill_(state_dict["_scale"])
        self._growth_step = state_dict["_growth_step"]
        self._hysteresis_step = state_dict["_hysteresis_step"]
class BaseOptimizer(Optimizer):
    """
    Thin wrapper around a PyTorch optimizer.

    Delegates the standard :class:`torch.optim.Optimizer` interface to the
    wrapped instance while providing hooks (``backward``, ``backward_by_grad``,
    ``clip_grad_norm``) that subclasses can override for distributed training,
    gradient scaling, or custom optimization strategies.

    :param optim: The PyTorch optimizer to wrap.
    :type optim: torch.optim.Optimizer

    Example::

        >>> inner = torch.optim.Adam(model.parameters(), lr=0.001)
        >>> optimizer = BaseOptimizer(inner)
        >>> optimizer.zero_grad()
        >>> loss.backward()
        >>> optimizer.step()
    """

    def __init__(self, optim: Optimizer):  # pylint: disable=W0231
        # Intentionally skips Optimizer.__init__; all state lives in `optim`.
        self.optim = optim

    @property
    def param_groups(self):
        """
        Parameter groups of the wrapped optimizer, each with its own
        settings (learning rate, weight decay, ...).

        :return: List of parameter groups.
        :rtype: list
        """
        return self.optim.param_groups

    @property
    def defaults(self):
        """
        Default hyperparameters of the wrapped optimizer.

        :return: Default parameters dictionary.
        :rtype: dict
        """
        return self.optim.defaults

    def add_param_group(self, *args, **kwargs):
        """
        Add a parameter group to the wrapped optimizer.

        :param args: Positional arguments forwarded to the wrapped optimizer.
        :param kwargs: Keyword arguments forwarded to the wrapped optimizer.
        :return: Result of the wrapped optimizer's ``add_param_group``.
        """
        return self.optim.add_param_group(*args, **kwargs)

    def step(self, *args, **kwargs):
        """
        Perform one optimization step on the wrapped optimizer.

        :param args: Positional arguments forwarded to the wrapped optimizer.
        :param kwargs: Keyword arguments forwarded to the wrapped optimizer.
        :return: Result of the wrapped optimizer's ``step``.
        """
        return self.optim.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        """
        Reset the gradients of all optimized tensors.

        :param args: Positional arguments forwarded to the wrapped optimizer.
        :param kwargs: Keyword arguments forwarded to the wrapped optimizer.
        """
        self.optim.zero_grad(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        """
        Load the wrapped optimizer's state (checkpoint restoration).

        :param args: Positional arguments forwarded to the wrapped optimizer.
        :param kwargs: Keyword arguments forwarded to the wrapped optimizer.
        """
        self.optim.load_state_dict(*args, **kwargs)

    def state_dict(self):
        """
        Return the wrapped optimizer's state for checkpointing.

        :return: The state of the optimizer.
        :rtype: dict
        """
        return self.optim.state_dict()

    def backward(self, loss):
        """
        Run backpropagation on the given loss.

        :param loss: The loss tensor to compute gradients for.
        :type loss: torch.Tensor
        """
        loss.backward()

    def backward_by_grad(self, tensor, grad):
        """
        Backpropagate the provided gradients through ``tensor``.

        Useful for gradient accumulation schemes or custom loss pipelines
        where the upstream gradient is computed externally.

        :param tensor: The tensor to compute gradients for.
        :type tensor: torch.Tensor
        :param grad: The gradients to backpropagate.
        :type grad: torch.Tensor
        """
        torch.autograd.backward(tensors=tensor, grad_tensors=grad)

    def clip_grad_norm(self):
        """
        Clip the gradient norm.

        Placeholder for subclasses; the base implementation does nothing.

        Example::

            >>> class MyOptimizer(BaseOptimizer):
            ...     def clip_grad_norm(self):
            ...         torch.nn.utils.clip_grad_norm_(
            ...             self.param_groups[0]['params'], max_norm=1.0)
        """