lightrft.strategy.strategy_base¶
A module for implementing training strategies in deep learning, particularly for RLVR and RLHF.
This module provides base classes and utilities for different training strategies like DeepSpeed and FSDP. It handles distributed training setup, model/optimizer preparation, checkpointing, and inference engine management.
EngineStatus¶
StrategyBase¶
- class lightrft.strategy.strategy_base.StrategyBase(seed: int, max_norm: float, micro_train_batch_size: int, train_batch_size: int, args: Any | None = None)[source]¶
Base class for training strategies (DeepSpeed and FSDP).
Provides common functionality for distributed training setup, model preparation, optimization, checkpointing, and inference engine management.
- Parameters:
seed (int) – Random seed for reproducibility
max_norm (float) – Maximum gradient norm for clipping
micro_train_batch_size (int) – Batch size for each micro training step (one forward/backward pass)
train_batch_size (int) – Total (global) batch size per optimizer step
args (Any) – Additional configuration arguments
- __init__(seed: int, max_norm: float, micro_train_batch_size: int, train_batch_size: int, args: Any | None = None) None[source]¶
Initialize strategy with common parameters.
- Parameters:
seed (int) – Random seed for reproducibility
max_norm (float) – Maximum gradient norm for clipping
micro_train_batch_size (int) – Batch size for each micro training step (one forward/backward pass)
train_batch_size (int) – Total (global) batch size per optimizer step
args (Any (usually argparse.Namespace)) – Additional configuration arguments
- classmethod _build_multimodal_inputs(all_prompts, all_images, images_num, all_videos, videos_num)[source]¶
Build multimodal inputs for inference engine (vLLM/SGLang).
This function supports two input formats for images and videos to accommodate different data preprocessing approaches:
- Format 1 - Nested list (multiple images/videos per prompt, already grouped):
all_images = [[img1_a, img1_b], [img2_a], [img3_a, img3_b, img3_c]]
images_num = [2, 1, 3]
-> all_images[i] is used directly as the image list for prompt i
- Format 2 - Flattened list (all images/videos in a single flat list):
all_images = [img1_a, img1_b, img2_a, img3_a, img3_b, img3_c]
images_num = [2, 1, 3]
-> images are sliced according to images_num: [0:2], [2:3], [3:6]
- Parameters:
all_prompts – List of text prompts
all_images – Images in nested or flattened format, or None
images_num – Number of images per prompt
all_videos – Videos in nested or flattened format, or None
videos_num – Number of videos per prompt
- Returns:
List of dicts with ‘prompt’ and optional ‘multi_modal_data’ keys
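The two formats above reduce to one normalization step: group the flat list by counts unless it is already nested. A minimal sketch of that logic (function and variable names are illustrative, not the actual implementation):

```python
# Hypothetical sketch of the format handling in _build_multimodal_inputs.
from typing import Any, Dict, List, Optional


def group_media(media: Optional[List[Any]], counts: List[int]) -> Optional[List[List[Any]]]:
    """Normalize media into one list per prompt.

    Format 1 (nested): media[i] is already the list for prompt i.
    Format 2 (flat): media is a single flat list, sliced by counts.
    """
    if media is None:
        return None
    if all(isinstance(m, list) for m in media):       # Format 1: already grouped
        return media
    grouped, start = [], 0                            # Format 2: slice by counts
    for n in counts:
        grouped.append(media[start:start + n])
        start += n
    return grouped


def build_inputs(prompts: List[str], images=None, images_num=None) -> List[Dict[str, Any]]:
    grouped = group_media(images, images_num or [])
    out = []
    for i, prompt in enumerate(prompts):
        item: Dict[str, Any] = {"prompt": prompt}
        if grouped is not None and grouped[i]:
            item["multi_modal_data"] = {"image": grouped[i]}
        out.append(item)
    return out
```

Note the sketch distinguishes the formats by whether every element is itself a list; media whose raw items are lists (e.g. nested pixel arrays) would need an explicit format flag instead.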
- all_gather(data: torch.Tensor | Dict[str, torch.Tensor]) torch.Tensor | Dict[str, torch.Tensor][source]¶
Gather data from all distributed processes.
- Parameters:
data (Union[torch.Tensor, dict]) – Data to be gathered, can be a tensor or dictionary of tensors
- Returns:
Gathered data concatenated from all processes
- Return type:
Union[torch.Tensor, dict]
- all_reduce(data: torch.Tensor | Dict[str, torch.Tensor], op: str = 'mean') torch.Tensor | Dict[str, torch.Tensor] | float | int[source]¶
Perform all-reduce operation across distributed processes.
- Parameters:
data (Union[torch.Tensor, Dict[str, torch.Tensor]]) – Data to be reduced, can be a tensor or dictionary of tensors
op (str) – Reduction operation (‘mean’, ‘max’, ‘sum’)
- Returns:
Reduced data in the same format as input
- Return type:
Union[torch.Tensor, Dict[str, torch.Tensor], float, int]
- Raises:
AssertionError – If op is not one of ‘mean’, ‘max’, ‘sum’
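The dispatch described above (tensor vs. per-key dict reduction, with three ops) can be illustrated without torch.distributed by reducing a list of per-rank values; this single-process mock is an assumption-laden stand-in, not the real collective:

```python
# Single-process mock of the all_reduce dispatch; real reductions go through
# torch.distributed collectives.
from typing import Dict, List, Union

Number = Union[float, int]


def mock_all_reduce(per_rank: List[Union[Number, Dict[str, Number]]],
                    op: str = "mean") -> Union[Number, Dict[str, Number]]:
    assert op in ("mean", "max", "sum"), f"unsupported op: {op}"
    first = per_rank[0]
    if isinstance(first, dict):                       # reduce each key independently
        return {k: mock_all_reduce([r[k] for r in per_rank], op) for k in first}
    if op == "mean":
        return sum(per_rank) / len(per_rank)
    if op == "sum":
        return sum(per_rank)
    return max(per_rank)
```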
- abstract backward(loss: torch.Tensor, model: torch.nn.Module, optimizer: torch.optim.Optimizer, **kwargs) None[source]¶
Perform backward pass.
- Parameters:
loss (torch.Tensor) – The loss to backpropagate
model (nn.Module) – The model
optimizer (optim.Optimizer) – The optimizer
kwargs – Additional arguments
- abstract create_optimizer(model: torch.nn.Module, **kwargs) torch.optim.Optimizer[source]¶
Create optimizer for the model.
- Parameters:
model (nn.Module) – The model to optimize
kwargs – Additional optimizer arguments
- Returns:
The created optimizer
- Return type:
optim.Optimizer
- engine_generate_local(sampling_params: Any, prompt_token_ids: List[List[int]] | List[int] | None = None, multi_modal_inputs: List[Dict[str, Any]] | None = None) List[EasyDict][source]¶
Perform text or multimodal generation using different inference engines based on the input mode.
- Parameters:
sampling_params – Parameters used for controlling the generation process (e.g., temperature, top_k).
prompt_token_ids – List of text token IDs.
multi_modal_inputs –
A list of dictionaries representing multimodal inputs. Each dictionary contains the raw text under the "prompt" key and additional modalities (such as images or videos) under the "multi_modal_data" key. Example:
multi_modal_inputs = [{
    "prompt": […],
    "multi_modal_data": {"image": […], "video": […]}
}]
- Returns:
A list of generated outputs in EasyDict format, produced by the selected inference engine.
- Raises:
ValueError – If both prompt_token_ids and multi_modal_inputs are None.
ValueError – If both prompt_token_ids and multi_modal_inputs are provided; exactly one must be given.
- gather_and_generate(sampling_params, all_prompt_token_ids=None, all_prompts=None, all_images=None, sleep_engine=True, images_num=None, all_videos=None, videos_num=None)[source]¶
Gather prompts across distributed ranks and perform text/multimodal generation.
This method coordinates distributed generation by:
1. Gathering prompts from all ranks within a vLLM tensor parallel group
2. Performing batched generation using the inference engine
3. Splitting the generated outputs and returning each rank's portion
4. Optionally putting the inference engine to sleep to conserve memory
For multimodal inputs, flexible formats are supported:
- One prompt with one image
- One prompt with multiple images
- One prompt with video(s) only (no images)
- One prompt with one or more videos
- Mixed image and video inputs
- Parameters:
sampling_params (Any) – Parameters controlling generation (e.g., temperature, top_k, max_tokens)
all_prompt_token_ids (Optional[List[List[int]]]) – Token IDs for text-only prompts, defaults to None
all_prompts (Optional[List[str]]) – Raw text prompts for multimodal generation, defaults to None
all_images (Optional[List]) – Images corresponding to prompts for VLM generation, defaults to None
sleep_engine (bool) – Whether to sleep the inference engine after generation, defaults to True
images_num (Optional[List[int]]) – Number of images per prompt (for multi-image scenarios), defaults to None
all_videos (Optional[List]) – Videos corresponding to prompts for video generation, defaults to None
videos_num (Optional[List[int]]) – Number of videos per prompt, defaults to None
- Returns:
List of generation outputs for the current rank, each containing prompt_token_ids and output_token_ids
- Return type:
List[EasyDict]
- Raises:
NotImplementedError – If inference engine is not initialized
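The gather/split choreography in steps 1–3 can be sketched in isolation. This mock simulates the per-rank bookkeeping with plain lists; the actual gather and generation go through the distributed group and the inference engine:

```python
# Illustrative sketch of gather -> batched generate -> split-per-rank.
from typing import Callable, List


def gather_and_split(rank_prompts: List[List[str]],
                     generate: Callable[[List[str]], List[str]],
                     my_rank: int) -> List[str]:
    counts = [len(p) for p in rank_prompts]                   # prompts per rank
    flat = [p for prompts in rank_prompts for p in prompts]   # 1. gather
    outputs = generate(flat)                                  # 2. batched generation
    start = sum(counts[:my_rank])                             # 3. this rank's slice
    return outputs[start:start + counts[my_rank]]
```

Batching across ranks before calling the engine is what makes the engine see one large request instead of many small ones, which is why the split-back step must track each rank's prompt count.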
- init_model_context()[source]¶
Context manager for model initialization.
Does nothing by default; currently only the DeepSpeed strategy makes use of it. Reports memory usage after completion.
- classmethod is_rank_0() bool[source]¶
Check if current process is rank 0.
- Returns:
True if current process is rank 0
- Return type:
bool
- abstract load_ckpt(model: torch.nn.Module, load_dir: str, tag: str | None = None, load_module_strict: bool = True, load_optimizer_states: bool = True, load_lr_scheduler_states: bool = True, load_module_only: bool = False) Any[source]¶
Load training checkpoint with various options.
- Parameters:
model – The model to load checkpoint into
load_dir (str) – Directory containing the checkpoint
tag – Optional specific checkpoint tag to load
load_module_strict (bool) – Whether to use strict loading for module states, defaults to True
load_optimizer_states (bool) – Whether to load optimizer states, defaults to True
load_lr_scheduler_states (bool) – Whether to load learning rate scheduler states, defaults to True
load_module_only (bool) – Whether to load only the module states, defaults to False
- maybe_load_optimizer(optimizer, device=torch.cuda.current_device)[source]¶
Placeholder for FSDP optimizer loading functionality.
- Parameters:
optimizer (torch.optim.Optimizer) – The optimizer to potentially load
device (torch.device) – Target device for loading
- maybe_offload_optimizer(optimizer)[source]¶
Placeholder for FSDP optimizer offloading functionality.
- Parameters:
optimizer (torch.optim.Optimizer) – The optimizer to potentially offload
- maybe_sleep_inference_engine()[source]¶
Put the inference engine to sleep if enabled and available.
Sleeps the engine to conserve memory when not in use. Only supports vLLM and SGLang engines. After sleeping, synchronizes and clears the cache.
- Raises:
ValueError – If the inference engine type is not supported
- abstract optimizer_step(optimizer: torch.optim.Optimizer, model: torch.nn.Module, scheduler: Any, name: str = 'model', **kwargs) None[source]¶
Take optimizer step.
- Parameters:
optimizer (optim.Optimizer) – The optimizer
model (nn.Module) – The model
scheduler – The learning rate scheduler
name (str) – Name for logging purposes
kwargs – Additional arguments
- prepare(*models_or_model_optim_pairs: torch.nn.Module | Tuple[torch.nn.Module, torch.optim.Optimizer], is_rlhf=False) List[torch.nn.Module | Tuple[torch.nn.Module, torch.optim.Optimizer]] | torch.nn.Module | Tuple[torch.nn.Module, torch.optim.Optimizer][source]¶
Prepare models and optimizers for training.
- Parameters:
models_or_model_optim_pairs – Models or (model, optimizer) pairs to prepare
is_rlhf (bool) – Whether preparing for RLHF training
- Returns:
Prepared models/optimizers
- Return type:
Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]
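The calling convention of prepare() (a mix of bare models and (model, optimizer) pairs in, the same shapes out, unwrapped when a single argument is given) can be sketched as follows; the wrapping itself is strategy-specific and is mocked here with strings:

```python
# Hypothetical sketch of prepare()'s argument normalization; the real method
# wraps models with DeepSpeed/FSDP engines instead of tagging strings.
from typing import Any, List, Tuple, Union


def prepare(*items: Union[Any, Tuple[Any, Any]]):
    def wrap(item):
        if isinstance(item, tuple):                   # (model, optimizer) pair
            model, optim = item
            return (f"wrapped:{model}", optim)
        return f"wrapped:{item}"                      # bare model
    out: List[Any] = [wrap(i) for i in items]
    return out[0] if len(out) == 1 else out           # unwrap single argument
```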
- prepare_models_and_optimizers(actor, critic, reward_models, initial_model, args, max_steps)[source]¶
Prepare models, optimizers and schedulers for training.
- Parameters:
actor (nn.Module) – Actor model
critic (nn.Module) – Critic model
reward_models (nn.Module) – Reward models
initial_model (nn.Module) – Initial model for reference
args (argparse.Namespace) – Training arguments
max_steps (int) – Maximum training steps
- Returns:
Tuple of prepared models, optimizers, and schedulers
- Return type:
tuple
- classmethod print(*msg)[source]¶
Print messages with timestamp, but only on rank 0.
- Parameters:
msg – Messages to print
- classmethod report_memory(prefix='')[source]¶
Report GPU memory usage statistics.
- Parameters:
prefix (str) – Prefix string for the memory report
- abstract save_ckpt(model: torch.nn.Module, save_dir: str, tag: str | None = None, max_num: int = 3, max_mem: int = 1000, client_state: Dict[str, Any] | None = None, save_latest: bool = True) None[source]¶
Save training checkpoint with additional metadata.
- Parameters:
model – The model to save
save_dir (str) – Directory to save the checkpoint
tag – Optional tag for the checkpoint
max_num (int) – Maximum number of checkpoints to keep, defaults to 3
max_mem (int) – Maximum memory in MB for checkpoints, defaults to 1000
client_state (dict) – Additional state to save, defaults to None
save_latest (bool) – Whether to save as latest checkpoint, defaults to True
- set_seed(seed: int) None[source]¶
Set random seeds for reproducibility.
- Parameters:
seed (int) – The random seed to use
- setup_dataloader(replay_buffer, batch_size: int, pin_memory: bool = False, shuffle: bool = True, collate_fn: Callable | None = None, drop_last: bool = True, sampler: Any | None = None, consumed_samples: int = 0) torch.utils.data.DataLoader[source]¶
Set up data loader for training.
- Parameters:
replay_buffer – Dataset/replay buffer
batch_size (int) – Batch size
pin_memory (bool) – Whether to pin memory
shuffle (bool) – Whether to shuffle data
collate_fn (Optional[Callable]) – Function to collate samples
drop_last (bool) – Whether to drop last incomplete batch
sampler – Custom sampler
consumed_samples (int) – Number of samples already consumed
- Returns:
Configured DataLoader
- Return type:
DataLoader
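The consumed_samples parameter exists for resumption: the first consumed_samples items are skipped so training continues where it left off. A pure-Python sketch of that behaviour together with drop_last batching (the real method wires this into torch.utils.data.DataLoader and a sampler):

```python
# Illustrative sketch of consumed_samples + drop_last semantics.
from typing import Any, List


def batches(data: List[Any], batch_size: int, drop_last: bool = True,
            consumed_samples: int = 0) -> List[List[Any]]:
    remaining = data[consumed_samples:]               # skip already-seen samples
    out = [remaining[i:i + batch_size]
           for i in range(0, len(remaining), batch_size)]
    if drop_last and out and len(out[-1]) < batch_size:
        out.pop()                                     # drop incomplete final batch
    return out
```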
- setup_distributed(timeout: timedelta | None = None, num_gpu_per_node: int = 8) None[source]¶
Initialize distributed training environment.
- Parameters:
timeout (timedelta, optional) – Maximum time to wait for initialization
num_gpu_per_node (int) – Number of GPUs per node
- Raises:
RuntimeError – If required environment variables are missing
ValueError – If unsupported engine type is specified
- setup_inference_engine(args, engine_type='vllm', actor=None)[source]¶
Initialize and setup the inference engine.
- Parameters:
args (argparse.Namespace) – Configuration arguments
engine_type (str) – Type of inference engine (‘vllm’ or ‘sglang’)
actor (torch.nn.Module) – The actor module, if passed, will be used to update engine weights
- Returns:
Initialized inference engine
- Return type:
object
- Raises:
ValueError – If engine_type is not supported
- classmethod sync_and_clear_cache()[source]¶
Synchronize CUDA operations and clear the cache.
Performs three operations:
1. CUDA synchronization
2. Distributed barrier
3. CUDA cache clearing
- unwrap_model(model) torch.nn.Module[source]¶
Unwrap model from strategy-specific wrappers.
- Parameters:
model (nn.Module) – Model to unwrap
- Returns:
Unwrapped model
- Return type:
nn.Module