lightrft.strategy.fsdp.fsdp_optimizer¶
FSDP Optimizer Module for PyTorch.
This module provides optimizers and utilities for working with PyTorch’s Fully Sharded Data Parallel (FSDP) training. It includes an adapted optimizer for FSDP that handles gradient scaling, clipping, and state management, as well as utility functions for offloading and loading optimizer states.
The main components include:
- FSDPadaptOptimizer: a wrapper optimizer that handles mixed precision training with FSDP
- Utility functions for optimizer state management and memory optimization
- Support for gradient scaling, clipping, and overflow detection
- Efficient FP16/FP32 parameter conversion and synchronization
Example:
import torch
from torch.optim import AdamW

from lightrft.strategy.fsdp.fsdp_optimizer import FSDPadaptOptimizer

# Create base optimizer
base_optimizer = AdamW(model.parameters(), lr=1e-4)

# Wrap with FSDP adapter
fsdp_optimizer = FSDPadaptOptimizer(base_optimizer)

# Training loop
for batch in dataloader:
    loss = model(batch)
    fsdp_optimizer.backward(loss)
    success = fsdp_optimizer.step()
    if not success:
        print("Gradient overflow detected, skipping step")
DTENSOR_SUPPORTED¶
- lightrft.strategy.fsdp.fsdp_optimizer.DTENSOR_SUPPORTED = True¶
Module-level boolean flag indicating whether DTensor support is available in the installed PyTorch build.
FSDPadaptOptimizer¶
- class lightrft.strategy.fsdp.fsdp_optimizer.FSDPadaptOptimizer(*args: Any, **kwargs: Any)[source]¶
Optimizer wrapper for PyTorch FSDP (Fully Sharded Data Parallel).
This optimizer handles the necessary components for mixed precision training with FSDP:
Gradient scaling for numerical stability in mixed precision training
Gradient clipping and unscaling to prevent gradient explosion
State dictionary management for checkpointing and model saving
Efficient FP16/FP32 parameter conversion and synchronization
Overflow detection and recovery mechanisms
The optimizer maintains separate FP16 and FP32 parameter groups where FP16 parameters share memory space with the model’s FlatParam, while FP32 parameters are used for the actual optimization step to maintain numerical precision.
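The master-weight pattern described above can be sketched generically. This is an illustration of the concept only, not the wrapper's internal code; all names below are hypothetical:

```python
import torch

# Illustrative only: the FP16 tensor stands in for the model's shared
# FlatParam storage, while an FP32 master copy receives the precise update.
fp16_param = torch.randn(4, dtype=torch.float16)
fp32_param = fp16_param.float()          # FP32 master copy

grad = torch.full_like(fp32_param, 0.1)  # pretend (already unscaled) gradient
lr = 0.5
fp32_param -= lr * grad                  # optimization step in full precision
fp16_param.copy_(fp32_param)             # sync result back into FP16 storage
```

Performing the update in FP32 avoids the rounding drift that repeated small FP16 updates would accumulate.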
- Parameters:
optimizer (torch.optim.Optimizer) – The base optimizer to wrap (e.g., AdamW, SGD)
Example:
import torch
from torch.optim import AdamW

base_optimizer = AdamW(model.parameters(), lr=1e-4)
fsdp_optimizer = FSDPadaptOptimizer(base_optimizer)

# Training step
loss = model(batch)
fsdp_optimizer.backward(loss)
success = fsdp_optimizer.step()
- backward(loss, retain_graph=False)[source]¶
Perform backward pass with loss scaling for mixed precision training.
The loss is scaled to prevent gradient underflow in FP16 computations.
- Parameters:
loss (torch.Tensor) – The loss tensor to backpropagate
retain_graph (bool) – If True, the computation graph will be retained for multiple backward passes
Example:
loss = criterion(outputs, targets)
optimizer.backward(loss)
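The effect of loss scaling on the backward pass can be seen in a minimal standalone sketch (generic pattern, not the wrapper's internals):

```python
import torch

# Scaling the loss by S scales every gradient by S, lifting tiny FP16
# gradients above the underflow threshold; the scale is divided back out
# before the optimizer step so the update itself is unchanged.
scale = 128.0
x = torch.tensor([2.0], requires_grad=True)
loss = (3.0 * x).sum()
(loss * scale).backward()        # d(scale * 3x)/dx = scale * 3

unscaled_grad = x.grad / scale   # what the step would actually use
```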
- clip_grad_norm(model, max_norm)[source]¶
Set gradient clipping norm (actual clipping is performed in the step() method).
- Parameters:
model – The model whose gradients will be clipped (unused in current implementation)
max_norm (float) – Maximum norm value for gradient clipping
- Note:
The actual gradient clipping is performed internally in the step() method using the _unscale_and_clip_grads method.
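What "unscale and clip" amounts to can be sketched generically. The following is an illustration under common conventions, not the library's _unscale_and_clip_grads implementation:

```python
import torch

loss_scale = 1024.0
max_norm = 1.0
scaled_grads = [torch.tensor([3.0, 4.0]) * loss_scale]  # gradients still carry the loss scale

# 1. Remove the loss scale
grads = [g / loss_scale for g in scaled_grads]
# 2. Compute the global norm across all gradients
total_norm = torch.norm(torch.stack([torch.norm(g) for g in grads]))
# 3. Shrink uniformly if the norm exceeds max_norm
clip_coef = min(1.0, max_norm / (total_norm.item() + 1e-6))
clipped = [g * clip_coef for g in grads]
```

Unscaling must happen before the norm is computed, otherwise the clipping threshold would be compared against scaled values.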
- load_state_dict(states)[source]¶
Load a complete state dictionary from checkpoint.
This method restores the optimizer to its exact previous state, including gradient scaler, optimizer states, and both FP32 and FP16 parameter values.
- Parameters:
states (dict) – The state dictionary to load
- Raises:
AssertionError – If required state components are missing or parameter counts are inconsistent
Example:
# Load optimizer state
checkpoint = torch.load('checkpoint.pt')
optimizer.load_state_dict(checkpoint['optimizer'])
- property loss_scale¶
Get the current loss scale value used for gradient scaling.
- Returns:
The current loss scale tensor
- Return type:
torch.Tensor
- state_dict()[source]¶
Get the complete state dictionary for checkpointing.
The state dictionary includes:
- Gradient scaler state for loss scaling
- Base optimizer states (momentum, etc.)
- FP32 parameter weights for precise restoration
- Returns:
A dictionary containing all optimizer states and parameters
- Return type:
dict
Example:
# Save optimizer state
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epoch': epoch,
}
torch.save(checkpoint, 'checkpoint.pt')
- step()[source]¶
Perform a single optimization step with overflow detection and gradient processing.
This method orchestrates the complete optimization process:
1. Computes gradient norms for overflow detection
2. Updates the gradient scaler based on overflow status
3. Transfers gradients from FP16 to FP32 parameters
4. Unscales and clips gradients as needed
5. Performs the optimization step on FP32 parameters
6. Copies updated FP32 parameters back to FP16 parameters
- Returns:
True if the optimization step was successful, False if overflow occurred
- Return type:
bool
Example:
success = optimizer.step()
if not success:
    print("Gradient overflow detected, step skipped")
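Steps 1 and 2 of the process above hinge on a dynamic loss scaler. A minimal sketch of that policy follows; the class is hypothetical and the constants are common defaults for this technique, not values confirmed by the library:

```python
# Hypothetical dynamic loss scaler: shrink the scale on overflow (and skip
# the step), grow it again after a sustained run of clean steps.
class DynamicLossScaler:
    def __init__(self, init_scale=2.0 ** 16, growth_interval=2000, backoff=0.5):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.backoff = backoff
        self._good_steps = 0

    def update(self, found_overflow):
        if found_overflow:
            self.scale *= self.backoff   # overflow: back off immediately
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps >= self.growth_interval:
                self.scale *= 2.0        # clean streak: probe a larger scale
                self._good_steps = 0
```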
offload_fsdp_optimizer¶
- lightrft.strategy.fsdp.fsdp_optimizer.offload_fsdp_optimizer(optimizer)¶
Offload optimizer states from GPU to CPU memory to reduce GPU memory usage.
This function moves all tensor-based optimizer states (such as momentum buffers, variance estimates, etc.) from GPU to CPU memory. This is useful for reducing GPU memory pressure during training, especially when using large models or when GPU memory is limited.
- Parameters:
optimizer (torch.optim.Optimizer) – The optimizer whose states should be offloaded to CPU
Example:
# Offload optimizer states to save GPU memory
offload_fsdp_optimizer(optimizer)

# Later, load back when needed
load_fsdp_optimizer(optimizer)
- Note:
After offloading, you should call load_fsdp_optimizer before the next optimization step to ensure states are available on the correct device.
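In spirit, offloading walks the optimizer's state dict and moves every tensor to CPU. A generic sketch under that assumption (the function name is hypothetical, and this is not the library's exact implementation):

```python
import torch

def offload_states(optimizer):
    """Move all tensor-valued optimizer states (momentum buffers, etc.) to CPU."""
    for state in optimizer.state.values():
        for key, value in state.items():
            if torch.is_tensor(value):
                state[key] = value.to("cpu", non_blocking=True)

# Demonstration on a tiny AdamW instance
param = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.AdamW([param], lr=1e-3)
param.sum().backward()
opt.step()                 # populates exp_avg / exp_avg_sq state tensors
offload_states(opt)
```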
load_fsdp_optimizer¶
- lightrft.strategy.fsdp.fsdp_optimizer.load_fsdp_optimizer(optimizer, device_id=torch.cuda.current_device)¶
Load optimizer states from CPU back to the specified GPU device.
This function moves all tensor-based optimizer states from CPU memory back to the specified GPU device. This is typically used after offload_fsdp_optimizer to restore states for the next optimization step.
- Parameters:
optimizer (torch.optim.Optimizer) – The optimizer whose states should be loaded to GPU
device_id (int or torch.device) – The device ID to load states to (default: current CUDA device)
Example:
# Load optimizer states back to GPU before optimization
load_fsdp_optimizer(optimizer, device_id=0)

# Or use current device
load_fsdp_optimizer(optimizer)
- Note:
This function automatically determines the current device using get_current_device() to ensure compatibility with distributed training setups.
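The loading direction mirrors the offload sketch: iterate the state dict and move tensors onto the target device. Again a sketch under assumptions, with a hypothetical function name:

```python
import torch

def load_states(optimizer, device="cpu"):
    """Move all tensor-valued optimizer states onto the given device."""
    for state in optimizer.state.values():
        for key, value in state.items():
            if torch.is_tensor(value):
                state[key] = value.to(device)

param = torch.nn.Parameter(torch.randn(2))
opt = torch.optim.AdamW([param], lr=1e-3)
param.sum().backward()
opt.step()
target = "cuda" if torch.cuda.is_available() else "cpu"
load_states(opt, target)   # states now live on the compute device
```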