lightrft.strategy.fsdp.fsdpv2¶
FSDP (Fully Sharded Data Parallel) strategy module for Hugging Face models.
This module provides implementations for distributed training using PyTorch’s FSDP. It includes utilities for model wrapping, optimization, checkpointing, and state management in a distributed training environment. The module implements the FSDP v2 strategy, with special handling for model sharding, mixed-precision training, and optimizer state management.
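A typical training flow (illustrative sketch; `model`, `dataloader`, and `scheduler` are assumed to be created elsewhere, and all keyword values are placeholders):
>>> strategy = FSDPV2Strategy(seed=42, micro_train_batch_size=1, train_batch_size=8, bf16=True)
>>> model = strategy.prepare_model(model)
>>> optimizer = strategy.create_optimizer(model, lr=1e-4, weight_decay=0.01)
>>> for batch in dataloader:
...     loss = model(**batch).loss
...     strategy.backward(loss, model, optimizer)
...     strategy.optimizer_step(optimizer, model, scheduler)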
FSDPV2Strategy¶
- class lightrft.strategy.fsdp.fsdpv2.FSDPV2Strategy(seed: int = 42, max_norm: float = 0.0, micro_train_batch_size: int = 1, train_batch_size: int = 1, bf16: bool = True, args=None)[source]¶
The strategy for training with PyTorch’s Fully Sharded Data Parallel V2.
This strategy implements model sharding using PyTorch’s FSDP to enable training of large models across multiple GPUs with memory efficiency.
- Parameters:
seed (int) – Random seed for reproducibility.
max_norm (float) – Maximum gradient norm for gradient clipping. If 0.0, no clipping is performed.
micro_train_batch_size (int) – Batch size for a single training step.
train_batch_size (int) – Total batch size for training.
bf16 (bool) – Whether to use bfloat16 precision.
args (object) – Additional arguments for the strategy.
- __init__(seed: int = 42, max_norm: float = 0.0, micro_train_batch_size: int = 1, train_batch_size: int = 1, bf16: bool = True, args=None) None[source]¶
Initialize the FSDP V2 strategy.
- Parameters:
seed (int) – Random seed for reproducibility.
max_norm (float) – Maximum gradient norm for gradient clipping. If 0.0, no clipping is performed.
micro_train_batch_size (int) – Batch size for a single training step.
train_batch_size (int) – Total batch size for training.
bf16 (bool) – Whether to use bfloat16 precision.
args (object) – Additional arguments for the strategy.
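Example (illustrative; argument values are arbitrary):
>>> strategy = FSDPV2Strategy(seed=42, max_norm=1.0, micro_train_batch_size=1, train_batch_size=8, bf16=True)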
- backward(loss: torch.Tensor, model: torch.nn.Module, optimizer: torch.optim.Optimizer, **kwargs) None[source]¶
Perform the backward pass for the given loss.
- Parameters:
loss (torch.Tensor) – The loss tensor
model (torch.nn.Module) – The model
optimizer (torch.optim.Optimizer) – The optimizer
kwargs – Additional arguments
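Example (illustrative; `loss` is assumed to come from a forward pass on a micro-batch):
>>> strategy.backward(loss, model, optimizer)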
- create_optimizer(model, **kwargs) torch.optim.Optimizer[source]¶
Create an optimizer for the model with proper parameter grouping.
Groups parameters by dtype, DTensor shard size, and weight decay, which prevents errors during gradient clipping and optimizer steps when parameter groups are heterogeneous.
- Parameters:
model (torch.nn.Module) – The model for which to create the optimizer
kwargs – Additional arguments for the optimizer, including weight_decay
- Returns:
The created optimizer
- Return type:
torch.optim.Optimizer
Example:
>>> optimizer = strategy.create_optimizer(model, lr=1e-4, weight_decay=0.01)
- init_model_context(meta_init=False)[source]¶
Context manager for model initialization. Large models can be initialized on the meta device to avoid materializing their weights up front.
- Parameters:
meta_init (bool) – Whether to initialize the model on the meta device.
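Example (illustrative sketch; `build_model` is a hypothetical constructor for your model):
>>> with strategy.init_model_context(meta_init=True):
...     model = build_model()  # weights are materialized later, e.g. when loading a checkpoint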
- load_ckpt(model, load_dir, optimizer=None, scheduler=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, **kwargs)[source]¶
Load model checkpoints in a distributed environment.
This method loads sharded model weights from disk for each distributed process. It handles the proper loading of FSDP-sharded state dictionaries.
- Parameters:
model (torch.nn.Module) – The model to load weights into, typically an FSDP-wrapped model
load_dir (str) – Directory containing the checkpoints
optimizer (torch.optim.Optimizer, optional) – The optimizer to load state into
scheduler (torch.optim.lr_scheduler, optional) – The learning rate scheduler to load state into
load_module_strict (bool, default=True) – Whether to strictly enforce that the keys in the checkpoint state dict match the model’s state dict
load_optimizer_states (bool, default=True) – Whether to load optimizer states
load_lr_scheduler_states (bool, default=True) – Whether to load learning rate scheduler states
- Returns:
A tuple of (load_dir, client_states) where load_dir is the directory from which the checkpoint was loaded and client_states contains additional saved state
- Return type:
tuple
Example:
>>> load_dir, client_states = strategy.load_ckpt(model, "checkpoints/step_1000")
- maybe_load_optimizer(optimizer, device=torch.cuda.current_device)[source]¶
Load FSDP optimizer states back to GPU if adam_offload is enabled.
- Parameters:
optimizer (torch.optim.Optimizer) – The optimizer to potentially load
device (torch.device) – The device to load the optimizer to
- Returns:
The loaded optimizer if adam_offload is enabled, otherwise the original optimizer
- Return type:
torch.optim.Optimizer
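Example (illustrative; has an effect only when adam_offload is enabled in args):
>>> optimizer = strategy.maybe_load_optimizer(optimizer, device=torch.cuda.current_device())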
- maybe_offload_optimizer(optimizer)[source]¶
Offload FSDP optimizer states to CPU if adam_offload is enabled.
- Parameters:
optimizer (torch.optim.Optimizer) – The optimizer to potentially offload
- Returns:
The offloaded optimizer if adam_offload is enabled, otherwise the original optimizer
- Return type:
torch.optim.Optimizer
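Example (illustrative; a typical pattern is to offload before a memory-heavy phase such as generation, then reload with maybe_load_optimizer before the next training step):
>>> optimizer = strategy.maybe_offload_optimizer(optimizer)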
- optimizer_step(optimizer: torch.optim.Optimizer, model: torch.nn.Module, scheduler, name='model', **kwargs) None[source]¶
Perform an optimization step.
Handles gradient accumulation by only stepping the optimizer and scheduler after the specified number of accumulation steps.
- Parameters:
optimizer (torch.optim.Optimizer) – The optimizer
model (torch.nn.Module) – The model
scheduler – The learning rate scheduler
name (str) – Name identifier for the model
kwargs – Additional arguments
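Example (illustrative; called once per micro-batch, with the actual optimizer and scheduler step taken once per accumulation cycle):
>>> strategy.backward(loss, model, optimizer)
>>> strategy.optimizer_step(optimizer, model, scheduler, name='model')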
- prepare_model(model, is_training=False, shard_size=-1, reshard_after_forward=True) List[torch.nn.Module | Tuple[torch.nn.Module, torch.optim.Optimizer]] | torch.nn.Module | Tuple[torch.nn.Module, torch.optim.Optimizer][source]¶
Prepares a model for FSDP training.
- Parameters:
model (torch.nn.Module or None) – The model to prepare.
is_training (bool) – Whether the model is being prepared for training.
shard_size (int) – Sharding group size; -1 shards across all ranks.
reshard_after_forward (bool) – Whether to reshard parameters after the forward pass.
- Returns:
The prepared model wrapped with FSDP; per the signature annotation, this may also be a (model, optimizer) pair or a list of such entries.
- Return type:
torch.nn.Module or Tuple[torch.nn.Module, torch.optim.Optimizer]
Example:
>>> prepared_model = strategy.prepare_model(model)
- save_ckpt(model, save_dir, tag=None, max_num=3, max_mem=1000, client_state={}, save_latest=True, optimizer=None, scheduler=None)[source]¶
Save model checkpoints in a distributed environment with automatic rotation.
This method saves the sharded model weights to disk and manages the number and total size of checkpoints by removing older ones when necessary. It ensures proper synchronization between distributed processes.
- Parameters:
model (torch.nn.Module) – The model to save, typically an FSDP-wrapped model
save_dir (str) – Directory where checkpoints will be saved
tag (str, optional) – Subdirectory name for this specific checkpoint
max_num (int, default=3) – Maximum number of checkpoints to keep
max_mem (int, default=1000) – Maximum disk space in GB for all checkpoints combined
client_state (dict, default={}) – Additional state to save (not currently used)
save_latest (bool, default=True) – Whether to save a copy as the latest checkpoint (not currently used)
optimizer (torch.optim.Optimizer, optional) – The optimizer whose state to save
scheduler (torch.optim.lr_scheduler, optional) – The learning rate scheduler whose state to save
Example:
>>> strategy.save_ckpt(model, "checkpoints", tag="step_1000", max_num=5)
- save_model(*args, **kwargs) None[source]¶
Save the model, its configuration, and tokenizer.
This method handles gathering and saving the full model parameters in a distributed setting. Only the rank 0 process writes the model to disk.
- Parameters:
model (torch.nn.Module) – The model to save
tokenizer – The tokenizer to save
output_dir (str) – Directory to save the model to
kwargs – Additional arguments to pass to model.save_pretrained
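Example (illustrative; `tokenizer` is assumed to be the Hugging Face tokenizer used with the model, and the output path is a placeholder):
>>> strategy.save_model(model, tokenizer, "output/final_model")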
is_mp_optimizer¶
- lightrft.strategy.fsdp.fsdpv2.is_mp_optimizer(optim)[source]¶
Check if an optimizer is an instance of FSDPadaptOptimizer.
This function determines whether the provided optimizer is a model parallel optimizer specifically designed for FSDP (Fully Sharded Data Parallel).
- Parameters:
optim (torch.optim.Optimizer or similar) – The optimizer to check
- Returns:
True if the optimizer is an instance of FSDPadaptOptimizer, False otherwise
- Return type:
bool
Example:
>>> optimizer = FSDPadaptOptimizer(model_parameters, lr=0.01)
>>> is_mp = is_mp_optimizer(optimizer)
>>> print(is_mp)
True
ModelOptimPair¶
- lightrft.strategy.fsdp.fsdpv2.ModelOptimPair¶
alias of Tuple[Module, Optimizer]
ModelOrModelOptimPair¶
- lightrft.strategy.fsdp.fsdpv2.ModelOrModelOptimPair¶
alias of Union[Module, Tuple[Module, Optimizer]]
manual_transformer_cls_names_to_wrap¶
- lightrft.strategy.fsdp.fsdpv2.manual_transformer_cls_names_to_wrap = ['Embedding', 'Qwen2VLDecoderLayer', 'Qwen2VLVisionBlock', 'Qwen2_5_VLVisionBlock', 'Qwen2_5_VLDecoderLayer', 'Qwen2DecoderLayer', 'LlamaDecoderLayer', 'DeepseekDecoderLayer']¶
Default list of transformer module class names that are manually wrapped as individual FSDP units when preparing a model.