lightrft.strategy.fsdp.fsdpv2

FSDP (Fully Sharded Data Parallel) strategy module for Hugging Face models.

This module provides implementations for distributed training using PyTorch’s FSDP. It includes utilities for model wrapping, optimization, checkpointing, and state management in a distributed training environment. The module supports FSDP v2 strategy, with special handling for model sharding, mixed precision training, and optimizer state management.

FSDPV2Strategy

class lightrft.strategy.fsdp.fsdpv2.FSDPV2Strategy(seed: int = 42, max_norm: float = 0.0, micro_train_batch_size: int = 1, train_batch_size: int = 1, bf16: bool = True, args=None)[source]

The strategy for training with PyTorch’s Fully Sharded Data Parallel V2.

This strategy implements model sharding using PyTorch’s FSDP to enable training of large models across multiple GPUs with memory efficiency.

Parameters:
  • seed (int) – Random seed for reproducibility.

  • max_norm (float) – Maximum gradient norm for gradient clipping. If 0.0, no clipping is performed.

  • micro_train_batch_size (int) – Batch size for a single training step.

  • train_batch_size (int) – Total batch size for training.

  • bf16 (bool) – Whether to use bfloat16 precision.

  • args (object) – Additional arguments for the strategy.
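A typical construction (the argument values below are illustrative):

Example:

>>> strategy = FSDPV2Strategy(seed=42, max_norm=1.0, micro_train_batch_size=2, train_batch_size=64, bf16=True)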

__init__(seed: int = 42, max_norm: float = 0.0, micro_train_batch_size: int = 1, train_batch_size: int = 1, bf16: bool = True, args=None) → None[source]

Initialize the FSDP V2 strategy.

Parameters:
  • seed (int) – Random seed for reproducibility

  • max_norm (float) – Maximum gradient norm for gradient clipping. If 0.0, no clipping is performed

  • micro_train_batch_size (int) – Batch size for a single training step

  • train_batch_size (int) – Total batch size for training

  • bf16 (bool) – Whether to use bfloat16 precision

  • args (object) – Additional arguments for the strategy

backward(loss: torch.Tensor, model: torch.nn.Module, optimizer: torch.optim.Optimizer, **kwargs) → None[source]

Perform backward pass for the loss.

Parameters:
  • loss (torch.Tensor) – The loss tensor

  • model (torch.nn.Module) – The model

  • optimizer (torch.optim.Optimizer) – The optimizer

  • kwargs – Additional arguments
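Typical use inside a training loop (loss, model, and optimizer are assumed to exist):

Example:

>>> loss = model(**batch).loss
>>> strategy.backward(loss, model, optimizer)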

create_optimizer(model, **kwargs) → torch.optim.Optimizer[source]

Create an optimizer for the model with proper parameter grouping.

Groups parameters by dtype, dtensor shard size, and weight decay to avoid errors during gradient clipping and optimization steps.

Parameters:
  • model (torch.nn.Module) – The model for which to create the optimizer

  • kwargs – Additional arguments for the optimizer, including weight_decay

Returns:

The created optimizer

Return type:

torch.optim.Optimizer

Example:

>>> optimizer = strategy.create_optimizer(model, lr=1e-4, weight_decay=0.01)
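The grouping described above can be sketched as follows. This is a simplified stand-in: the real method also keys on each parameter's DTensor shard size, and the `no_decay_keywords` heuristic below is an illustrative assumption, not lifted from the source.

```python
from collections import defaultdict

def group_parameters(named_params, weight_decay=0.01, no_decay_keywords=("bias", "norm")):
    """Group parameters so every optimizer group is homogeneous.

    Sketch only: keys each group by (dtype, weight decay); the actual
    strategy additionally groups by DTensor shard size.
    """
    groups = defaultdict(list)
    for name, param in named_params:
        # Biases and normalization weights are conventionally excluded from decay.
        decay = 0.0 if any(k in name.lower() for k in no_decay_keywords) else weight_decay
        groups[(str(param.dtype), decay)].append(param)
    return [{"params": ps, "weight_decay": wd} for (_, wd), ps in groups.items()]
```

Each returned group can then be passed to an optimizer such as torch.optim.AdamW as a parameter group.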
init_model_context(meta_init=False)[source]

Context manager for model initialization. Large models can be initialized on the meta device to avoid materializing weights up front.

Parameters:

meta_init (bool) – Whether to initialize the model on the meta device
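Typical use for a large model (AutoModelForCausalLM and config are illustrative):

Example:

>>> with strategy.init_model_context(meta_init=True):
...     model = AutoModelForCausalLM.from_config(config)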

load_ckpt(model, load_dir, optimizer=None, scheduler=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, **kwargs)[source]

Load model checkpoints in a distributed environment.

This method loads sharded model weights from disk for each distributed process. It handles the proper loading of FSDP-sharded state dictionaries.

Parameters:
  • model (torch.nn.Module) – The model to load weights into, typically an FSDP-wrapped model

  • load_dir (str) – Directory containing the checkpoints

  • optimizer (torch.optim.Optimizer, optional) – The optimizer to load state into

  • scheduler (torch.optim.lr_scheduler, optional) – The learning rate scheduler to load state into

  • load_module_strict (bool, default=True) – Whether to strictly enforce that the keys in the checkpoint match the model's state dict

  • load_optimizer_states (bool, default=True) – Whether to load optimizer states

  • load_lr_scheduler_states (bool, default=True) – Whether to load learning rate scheduler states

Returns:

A tuple of (load_dir, client_states) where load_dir is the directory from which the checkpoint was loaded and client_states contains additional saved state

Return type:

tuple

Example:

>>> load_dir, client_states = trainer.load_ckpt(model, "checkpoints/step_1000")
maybe_load_optimizer(optimizer, device=torch.cuda.current_device())[source]

Load FSDP optimizer states back to GPU if adam_offload is enabled.

Parameters:
  • optimizer (torch.optim.Optimizer) – The optimizer to potentially load

  • device (torch.device) – The device to load the optimizer to

Returns:

The loaded optimizer if adam_offload is enabled, otherwise the original optimizer

Return type:

torch.optim.Optimizer

maybe_offload_optimizer(optimizer)[source]

Offload FSDP optimizer states to CPU if adam_offload is enabled.

Parameters:

optimizer (torch.optim.Optimizer) – The optimizer to potentially offload

Returns:

The offloaded optimizer if adam_offload is enabled, otherwise the original optimizer

Return type:

torch.optim.Optimizer
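A typical pairing around a generation/rollout phase (illustrative; both calls are no-ops unless adam_offload is enabled):

Example:

>>> optimizer = strategy.maybe_offload_optimizer(optimizer)
>>> # ... run generation / rollout ...
>>> optimizer = strategy.maybe_load_optimizer(optimizer)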

optimizer_step(optimizer: torch.optim.Optimizer, model: torch.nn.Module, scheduler, name='model', **kwargs) → None[source]

Perform an optimization step.

Handles gradient accumulation by only stepping the optimizer and scheduler after the specified number of accumulation steps.

Parameters:
  • optimizer (torch.optim.Optimizer) – The optimizer

  • model (torch.nn.Module) – The model

  • scheduler – The learning rate scheduler

  • name (str) – Name identifier for the model

  • kwargs – Additional arguments
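The accumulation bookkeeping can be sketched as a small counter. This is a simplified stand-in for illustration only; the real method additionally applies gradient clipping (when max_norm > 0) and steps the learning rate scheduler.

```python
class AccumulationCounter:
    """Step the optimizer only once every `accumulation_steps` micro-batches."""

    def __init__(self, accumulation_steps: int):
        self.accumulation_steps = accumulation_steps
        self.micro_step = 0

    def should_step(self) -> bool:
        # Count one micro-batch; step only on each accumulation boundary.
        self.micro_step += 1
        return self.micro_step % self.accumulation_steps == 0
```

For example, with train_batch_size=64, micro_train_batch_size=2, and 8 data-parallel ranks, the optimizer steps once every 64 / (2 × 8) = 4 micro-batches per rank.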

prepare_model(model, is_training=False, shard_size=-1, reshard_after_forward=True) → List[torch.nn.Module | Tuple[torch.nn.Module, torch.optim.Optimizer]] | torch.nn.Module | Tuple[torch.nn.Module, torch.optim.Optimizer][source]

Prepares a model for FSDP training.

Parameters:
  • model (torch.nn.Module or None) – The model to prepare.

  • is_training (bool) – Whether the model is being prepared for training.

  • shard_size (int) – Shard group size for FSDP; -1 shards across all ranks.

  • reshard_after_forward (bool) – Whether to reshard parameters after the forward pass to reduce memory usage.

Returns:

The prepared model wrapped with FSDP, or a (model, optimizer) pair, matching the input.

Return type:

torch.nn.Module or Tuple[torch.nn.Module, torch.optim.Optimizer]

Example:

>>> prepared_model = strategy.prepare_model(model)
save_ckpt(model, save_dir, tag=None, max_num=3, max_mem=1000, client_state={}, save_latest=True, optimizer=None, scheduler=None)[source]

Save model checkpoints in a distributed environment with automatic rotation.

This method saves the sharded model weights to disk and manages the number and total size of checkpoints by removing older ones when necessary. It ensures proper synchronization between distributed processes.

Parameters:
  • model (torch.nn.Module) – The model to save, typically an FSDP-wrapped model

  • save_dir (str) – Directory where checkpoints will be saved

  • tag (str, optional) – Subdirectory name for this specific checkpoint

  • max_num (int, default=3) – Maximum number of checkpoints to keep

  • max_mem (int, default=1000) – Maximum total disk space in GB for all checkpoints

  • client_state (dict, default={}) – Additional state to save (not currently used)

  • save_latest (bool, default=True) – Whether to save a copy as the latest checkpoint (not currently used)

  • optimizer (torch.optim.Optimizer, optional) – The optimizer whose state to save

  • scheduler (torch.optim.lr_scheduler, optional) – The learning rate scheduler whose state to save

Example:

>>> trainer.save_ckpt(model, "checkpoints", tag="step_1000", max_num=5)
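The rotation by checkpoint count can be sketched as below. This is a simplified illustration: the real method also enforces the max_mem size budget and synchronizes distributed ranks before deleting anything.

```python
import os
import shutil

def rotate_checkpoints(save_dir: str, max_num: int = 3) -> None:
    """Delete the oldest checkpoint subdirectories so at most max_num remain."""
    ckpts = sorted(
        (d for d in os.listdir(save_dir) if os.path.isdir(os.path.join(save_dir, d))),
        key=lambda d: os.path.getmtime(os.path.join(save_dir, d)),
    )
    # Oldest first: keep removing until we are within the budget.
    while len(ckpts) > max_num:
        shutil.rmtree(os.path.join(save_dir, ckpts.pop(0)))
```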
save_model(*args, **kwargs) → None[source]

Save the model, its configuration, and tokenizer.

This method handles gathering and saving the full model parameters in a distributed setting. Only rank 0 process saves the model to disk.

Parameters:
  • model (torch.nn.Module) – The model to save

  • tokenizer – The tokenizer to save

  • output_dir (str) – Directory to save the model to

  • kwargs – Additional arguments to pass to model.save_pretrained
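Example (the output path is illustrative):

>>> strategy.save_model(model, tokenizer, "output/final_model")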

unwrap_model(model) → torch.nn.Module[source]

Unwraps the model from any wrapper modules.

Parameters:

model (torch.nn.Module) – The model to unwrap.

Returns:

The unwrapped model.

Return type:

torch.nn.Module

Example:

>>> unwrapped_model = strategy.unwrap_model(wrapped_model)

is_mp_optimizer

lightrft.strategy.fsdp.fsdpv2.is_mp_optimizer(optim)[source]

Check if an optimizer is an instance of FSDPadaptOptimizer.

This function determines whether the provided optimizer is a model parallel optimizer specifically designed for FSDP (Fully Sharded Data Parallel).

Parameters:

optim (torch.optim.Optimizer or similar) – The optimizer to check

Returns:

True if the optimizer is an instance of FSDPadaptOptimizer, False otherwise

Return type:

bool

Example:

>>> optimizer = FSDPadaptOptimizer(model_parameters, lr=0.01)
>>> is_mp = is_mp_optimizer(optimizer)
>>> print(is_mp)
True

ModelOptimPair

lightrft.strategy.fsdp.fsdpv2.ModelOptimPair

alias of Tuple[Module, Optimizer]

ModelOrModelOptimPair

lightrft.strategy.fsdp.fsdpv2.ModelOrModelOptimPair

alias of Union[Module, Tuple[Module, Optimizer]]

manual_transformer_cls_names_to_wrap

lightrft.strategy.fsdp.fsdpv2.manual_transformer_cls_names_to_wrap = ['Embedding', 'Qwen2VLDecoderLayer', 'Qwen2VLVisionBlock', 'Qwen2_5_VLVisionBlock', 'Qwen2_5_VLDecoderLayer', 'Qwen2DecoderLayer', 'LlamaDecoderLayer', 'DeepseekDecoderLayer']

List of transformer-layer class names that are manually wrapped as individual FSDP units when building the wrapping policy.
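These class names feed the FSDP auto-wrap policy. A minimal sketch of a name-based predicate is shown below; this is an illustration only, since the actual policy receives module instances and compares their classes rather than strings.

```python
manual_transformer_cls_names_to_wrap = [
    "Embedding", "Qwen2VLDecoderLayer", "Qwen2VLVisionBlock",
    "Qwen2_5_VLVisionBlock", "Qwen2_5_VLDecoderLayer",
    "Qwen2DecoderLayer", "LlamaDecoderLayer", "DeepseekDecoderLayer",
]

def make_wrap_predicate(cls_names):
    """Return a predicate deciding whether a module becomes its own FSDP unit."""
    names = set(cls_names)
    return lambda module: type(module).__name__ in names
```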