# Source code for lightrft.strategy.fsdp.fsdp_optimizer

"""
FSDP Optimizer Module for PyTorch.

This module provides optimizers and utilities for working with PyTorch's Fully Sharded Data Parallel (FSDP)
training. It includes an adapted optimizer for FSDP that handles gradient scaling, clipping, and state
management, as well as utility functions for offloading and loading optimizer states.

The main components include:
- FSDPadaptOptimizer: A wrapper optimizer that handles mixed precision training with FSDP
- Utility functions for optimizer state management and memory optimization
- Support for gradient scaling, clipping, and overflow detection
- Efficient FP16/FP32 parameter conversion and synchronization

Example::

    import torch
    from torch.optim import AdamW
    from lightrft.strategy.fsdp.fsdp_optimizer import FSDPadaptOptimizer

    # Create base optimizer
    base_optimizer = AdamW(model.parameters(), lr=1e-4)

    # Wrap with FSDP adapter
    fsdp_optimizer = FSDPadaptOptimizer(base_optimizer)

    # Training loop
    for batch in dataloader:
        loss = model(batch)
        fsdp_optimizer.backward(loss)
        success = fsdp_optimizer.step()
        if not success:
            print("Gradient overflow detected, skipping step")
"""

import torch
import torch.distributed as dist
from torch.optim import Optimizer

try:
    from torch.distributed.tensor import DTensor

    DTENSOR_SUPPORTED = True
except (ModuleNotFoundError, ImportError):
    DTENSOR_SUPPORTED = False

from lightrft.utils import get_current_device

from .fsdp_utils import BaseOptimizer, DynamicGradScaler


class FSDPadaptOptimizer(BaseOptimizer):
    """
    Optimizer wrapper for PyTorch FSDP (Fully Sharded Data Parallel).

    This optimizer handles the necessary components for mixed precision training with FSDP:

    - Gradient scaling for numerical stability in mixed precision training
    - Gradient clipping and unscaling to prevent gradient explosion
    - State dictionary management for checkpointing and model saving
    - Efficient FP16/FP32 parameter conversion and synchronization
    - Overflow detection and recovery mechanisms

    The optimizer maintains separate FP16 and FP32 parameter groups where FP16 parameters
    share memory space with the model's FlatParam, while FP32 parameters are used for the
    actual optimization step to maintain numerical precision.

    :param optimizer: The base optimizer to wrap (e.g., AdamW, SGD)
    :type optimizer: torch.optim.Optimizer

    Example::

        import torch
        from torch.optim import AdamW

        base_optimizer = AdamW(model.parameters(), lr=1e-4)
        fsdp_optimizer = FSDPadaptOptimizer(base_optimizer)

        # Training step
        loss = model(batch)
        fsdp_optimizer.backward(loss)
        success = fsdp_optimizer.step()
    """

    def __init__(
        self,
        optimizer: Optimizer,
    ):
        """
        Initialize the FSDP adapted optimizer.

        :param optimizer: The base optimizer to wrap
        :type optimizer: torch.optim.Optimizer
        """
        super().__init__(optim=optimizer)

        # Gradient scaler for mixed precision training.
        # NOTE(review): all factors are 1.0, so the scale is effectively fixed at 1.0;
        # the scaler is still used for overflow bookkeeping in step().
        self.grad_scaler = DynamicGradScaler(initial_scale=1.0, growth_factor=1.0, backoff_factor=1.0, max_scale=1.0)

        # gradient clipping threshold (applied in _unscale_and_clip_grads)
        self._clip_grad_norm = 1.0

        # padding data used for computing a norm when a rank holds no parameters/gradients
        self.padding_grad = torch.zeros([32], dtype=torch.bfloat16, device=get_current_device())
        self.padding_tensor = torch.zeros([32], dtype=torch.bfloat16, device=get_current_device())

        # fp16 and fp32 parameter groups
        # fp16 shares memory space with model.FlatParam, fp32 shares memory space with optim.param_group
        self._fp16_param_groups = dict()
        self._fp32_param_tensor_groups = dict()

        # initialize fp16 and fp32 parameter groups
        for group_idx, param_group in enumerate(self.optim.param_groups):
            group_params = param_group["params"]

            # store reference to fp16 FlatParam storage
            self._fp16_param_groups[group_idx] = group_params

            # create fp32 copies of parameters for optimization
            fp32_tensor_param = [param.data.float() for param in group_params]
            self._fp32_param_tensor_groups[group_idx] = fp32_tensor_param

            # replace optimizer parameter group with fp32 copies
            param_group["params"] = fp32_tensor_param

    @property
    def loss_scale(self):
        """
        Get the current loss scale value used for gradient scaling.

        :return: The current loss scale tensor
        :rtype: torch.Tensor
        """
        return self.grad_scaler.scale

    def backward(self, loss, retain_graph=False):
        """
        Perform backward pass with loss scaling for mixed precision training.

        The loss is scaled to prevent gradient underflow in FP16 computations.

        :param loss: The loss tensor to backpropagate
        :type loss: torch.Tensor
        :param retain_graph: If True, the computation graph will be retained
            for multiple backward passes
        :type retain_graph: bool

        Example::

            loss = criterion(outputs, targets)
            optimizer.backward(loss)
        """
        loss = self.loss_scale * (loss.float())
        loss.backward(retain_graph=retain_graph)

    def _compute_norm_with_fsdp_flatten(self, group_id, norm_type=2):
        """
        Compute the gradient norm for a parameter group with FSDP flattened parameters.

        This method handles the computation of gradient norms across distributed
        processes, taking into account FSDP's parameter flattening and sharding.
        It supports both regular tensors and DTensor for distributed computation.

        :param group_id: The parameter group ID to compute norm for
        :type group_id: int
        :param norm_type: The type of norm to compute (default: 2 for L2 norm)
        :type norm_type: int
        :return: The computed gradient norm, or -1 if overflow is detected
        :rtype: torch.Tensor
        """
        params = [p for p in self._fp16_param_groups[group_id] if p.untyped_storage().size() != 0]
        gradients = [p.grad for p in params if p.untyped_storage().size() != 0]

        # use padding tensors if no valid parameters/gradients found on this rank
        # (wrapped in lists so the zip below sees whole tensors; the resulting norm is 0)
        if len(params) <= 0 or len(gradients) <= 0:
            gradients = [self.padding_grad]
            params = [self.padding_tensor]

        # compute individual gradient norms (adapted from DeepSpeed)
        grad_norms = []
        for g, p in zip(gradients, params):
            grad_norms.append(g.double().norm(2))

        # compute total norm across all parameters
        if len(grad_norms) == 0:
            # FIX https://github.com/microsoft/DeepSpeed/issues/3564
            # handle edge case when no gradients are available
            total_norm_cuda = torch.tensor(0, dtype=gradients[0].dtype).to(get_current_device()).double()
        else:
            total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2))

        # if grad_norm is a DTensor, the stack/pow/sum above behaves like:
        #   stack_out_full = stack_out.full_tensor()
        #   pow_out_full = torch.pow(stack_out_full, 2)
        #   torch.sum(pow_out_full)
        # handle DTensor case for distributed computation
        if DTENSOR_SUPPORTED and isinstance(total_norm_cuda, DTensor):
            # 20250422(sdx): when DTensor is enabled, the output of torch.pow is REPLICATED
            # across the FSDP group and is already the sum of squares, so we should not do
            # a reduce-sum again.
            total_norm_cuda = total_norm_cuda.full_tensor()
        else:
            # reduce across processes in Zero3 group
            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM)

        total_norm = total_norm_cuda ** (1.0 / norm_type)

        # check for overflow conditions
        norm_is_inf = total_norm.isinf()
        norm_is_nan = total_norm.isnan()
        inf_or_nan = norm_is_nan.logical_or(norm_is_inf)

        # return -1 if overflow detected, otherwise return the norm
        err = torch.tensor(-1.0, device=get_current_device(), dtype=torch.float)
        total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm

        return total_norm

    def zero_grad(self):
        """
        Set gradients of all FP16 parameters to None.

        This method clears gradients from the FP16 parameter groups that are
        used for the forward pass and gradient computation.
        """
        for _, param_group in self._fp16_param_groups.items():
            for param in param_group:
                param.grad = None

    def step(self):
        """
        Perform a single optimization step with overflow detection and gradient processing.

        This method orchestrates the complete optimization process:

        1. Computes gradient norms for overflow detection
        2. Updates the gradient scaler based on overflow status
        3. Transfers gradients from FP16 to FP32 parameters
        4. Unscales and clips gradients as needed
        5. Performs the optimization step on FP32 parameters
        6. Copies updated FP32 parameters back to FP16 parameters

        :return: True if the optimization step was successful, False if overflow occurred
        :rtype: bool

        Example::

            success = optimizer.step()
            if not success:
                print("Gradient overflow detected, step skipped")
        """
        # compute gradient norms for overflow detection
        found_inf = False
        norm_groups = []
        for group_idx in range(len(self.param_groups)):
            norm_group = self._compute_norm_with_fsdp_flatten(group_idx)
            if norm_group == -1:
                found_inf = True
            norm_groups.append(norm_group)

        # update gradient scaler and handle overflow
        loss_scale = float(self.loss_scale.item())  # backup current scale
        self.grad_scaler.update(found_inf)
        if found_inf:
            print("Overflow occurs, please check it.", flush=True)
            self.zero_grad()
            return False

        # transfer gradients from fp16 to fp32 parameters
        for group_idx in range(len(self.param_groups)):
            if len(self._fp32_param_tensor_groups[group_idx]) <= 0:
                continue
            dtype = self._fp32_param_tensor_groups[group_idx][0].dtype
            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
            grad_fp32 = [p.grad.to(dtype) for p in fp16_params]
            device = self._fp32_param_tensor_groups[group_idx][0].device
            nonzero_fp32 = [p for p in self._fp32_param_tensor_groups[group_idx] if p.untyped_storage().size() != 0]
            for p, g in zip(nonzero_fp32, grad_fp32):
                p.grad = g.to(device)

        # compute global gradient norm and unscale/clip gradients
        scaled_global_grad_norm = torch.linalg.norm(torch.stack(norm_groups))  # pylint: disable=E1102
        self._unscale_and_clip_grads(scaled_global_grad_norm, loss_scale)

        # perform optimization step
        self.optim.step()
        self.zero_grad()

        # copy updated fp32 parameters back to fp16 parameters
        for group_idx in range(len(self._fp16_param_groups)):
            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
            fp32_tensor_params = [
                p for p in self._fp32_param_tensor_groups[group_idx] if p.untyped_storage().size() != 0
            ]
            # release fp32 gradients
            for fp32_param in fp32_tensor_params:
                fp32_param.grad = None
            # update fp16 parameters with fp32 values
            for p, q in zip(fp16_params, fp32_tensor_params):
                p.data.copy_(q)

        return True

    def clip_grad_norm(self, model, max_norm):
        """
        Set gradient clipping norm (actual clipping is performed in the step() method).

        :param model: The model whose gradients will be clipped (unused in current implementation)
        :param max_norm: Maximum norm value for gradient clipping
        :type max_norm: float

        Note:
            The actual gradient clipping is performed internally in the step() method
            using the _unscale_and_clip_grads method.
        """
        # actual clipping is conducted in the step() method
        pass

    #########################
    # utils from hybridzero #
    #########################
    def _unscale_and_clip_grads(self, total_norm, loss_scale):
        """
        Unscale and clip gradients based on the total norm and loss scale.

        This method combines gradient unscaling (to reverse the loss scaling) with
        gradient clipping to prevent gradient explosion. The combined scale factor
        accounts for both operations.

        :param total_norm: The total gradient norm across all parameters
        :type total_norm: torch.Tensor
        :param loss_scale: The current loss scale used for gradient scaling
        :type loss_scale: float
        """
        # compute combined scale factor for this group
        combined_scale = loss_scale

        if self._clip_grad_norm > 0.0:
            # compute clipping factor (norm is scaled by loss_scale)
            clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
            clip = torch.clamp(clip, min=1.0)
            combined_scale = clip * loss_scale

        # apply combined unscaling and clipping to all fp32 parameters
        for _, param in self._fp32_param_tensor_groups.items():
            for p in param:
                if p.untyped_storage().size() != 0:
                    p.grad.data.mul_(1.0 / combined_scale)

    def state_dict(self):
        """
        Get the complete state dictionary for checkpointing.

        The state dictionary includes:

        - Gradient scaler state for loss scaling
        - Base optimizer states (momentum, etc.)
        - FP32 parameter weights for precise restoration

        :return: A dictionary containing all optimizer states and parameters
        :rtype: dict

        Example::

            # Save optimizer state
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch
            }
            torch.save(checkpoint, 'checkpoint.pt')
        """
        states = {}

        # save gradient scaler state
        grad_scaler = self.grad_scaler.state_dict()
        states["grad_scaler"] = grad_scaler

        # save base optimizer states
        optim_states = self.optim.state_dict()
        states["base_optim_states"] = optim_states

        # save fp32 parameter weights
        flat_fp32_weights = {}
        for group_idx, param in self._fp32_param_tensor_groups.items():
            flat_fp32_weights[group_idx] = param
        states["flat_fp32_weights"] = flat_fp32_weights

        return states

    def load_state_dict(self, states):
        """
        Load a complete state dictionary from checkpoint.

        This method restores the optimizer to its exact previous state, including
        gradient scaler, optimizer states, and both FP32 and FP16 parameter values.

        :param states: The state dictionary to load
        :type states: dict
        :raises AssertionError: If required state components are missing or
            parameter counts are inconsistent

        Example::

            # Load optimizer state
            checkpoint = torch.load('checkpoint.pt')
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        assert "grad_scaler" in states, "Not found grad_scaler state!"
        grad_scaler = states["grad_scaler"]
        self.grad_scaler.load_state_dict(grad_scaler)

        # load base optimizer states
        optim_states = states["base_optim_states"]
        self.optim.load_state_dict(optim_states)

        # load fp32 optimizer weights
        flat_fp32_weights = states["flat_fp32_weights"]
        assert set(flat_fp32_weights.keys()) == set(self._fp32_param_tensor_groups)
        for group_idx, param in flat_fp32_weights.items():
            self_param = self._fp32_param_tensor_groups[group_idx]
            assert len(self_param
                       ) == len(param), f"The number of flat tensor is inconsistent, {len(self_param)} != {len(param)}"
            for p, q in zip(self_param, param):
                p.data.copy_(q.data)

        # synchronize fp16 model weights with loaded fp32 weights
        for group_idx, param in flat_fp32_weights.items():
            fp16_param = self._fp16_param_groups[group_idx]
            fp32_param = self._fp32_param_tensor_groups[group_idx]
            for p, q in zip(fp16_param, fp32_param):
                p.data.copy_(q.data)
@torch.no_grad()
def offload_fsdp_optimizer(optimizer):
    """
    Offload optimizer states from GPU to CPU memory to reduce GPU memory usage.

    This function moves all tensor-based optimizer states (such as momentum buffers,
    variance estimates, etc.) from GPU to CPU memory. This is useful for reducing
    GPU memory pressure during training, especially when using large models or
    when GPU memory is limited.

    :param optimizer: The optimizer whose states should be offloaded to CPU
    :type optimizer: torch.optim.Optimizer

    Example::

        # Offload optimizer states to save GPU memory
        offload_fsdp_optimizer(optimizer)

        # Later, load back when needed
        load_fsdp_optimizer(optimizer)

    Note:
        After offloading, you should call load_fsdp_optimizer before the next
        optimization step to ensure states are available on the correct device.
    """
    if not optimizer.state:
        return
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            state = optimizer.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    # NOTE(review): non_blocking copies to pageable CPU memory may not be
                    # synchronized with the producing stream — assumed safe here because
                    # the states are not read until load_fsdp_optimizer; confirm.
                    state[key] = value.to("cpu", non_blocking=True)
    torch.cuda.empty_cache()


@torch.no_grad()
def load_fsdp_optimizer(optimizer, device_id=None):
    """
    Load optimizer states from CPU back to the specified GPU device.

    This function moves all tensor-based optimizer states from CPU memory back to
    the specified GPU device. This is typically used after offload_fsdp_optimizer
    to restore states for the next optimization step.

    :param optimizer: The optimizer whose states should be loaded to GPU
    :type optimizer: torch.optim.Optimizer
    :param device_id: The device to load states to. If None (default), the current
        device is resolved via get_current_device() for distributed compatibility.
        (Previously the default was ``torch.cuda.current_device()``, which was
        evaluated once at import time — failing on CPU-only hosts — and then
        ignored by the body; the parameter is now honored when supplied.)
    :type device_id: int or torch.device or None

    Example::

        # Load optimizer states back to GPU before optimization
        load_fsdp_optimizer(optimizer, device_id=0)

        # Or use current device
        load_fsdp_optimizer(optimizer)
    """
    if not optimizer.state:
        return
    torch.cuda.empty_cache()
    if device_id is None:
        # Use get_current_device() instead of torch.cuda.current_device()
        # for distributed compatibility
        device_id = get_current_device()
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            state = optimizer.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device_id, non_blocking=True)