lightrft.strategy.strategy_base

A module for implementing training strategies in deep learning, particularly for RLVR and RLHF.

This module provides base classes and utilities for different training strategies like DeepSpeed and FSDP. It handles distributed training setup, model/optimizer preparation, checkpointing, and inference engine management.

EngineStatus

enum lightrft.strategy.strategy_base.EngineStatus(value)[source]

Enum class for inference engine status.

Variables:
  • SLEEPED – Engine is in sleep mode

  • WAKEUP – Engine is awake and ready

Valid values are as follows:

SLEEPED = 0
WAKEUP = 1
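
A minimal usage sketch (the status value is normally managed internally by the strategy):

    from lightrft.strategy.strategy_base import EngineStatus

    status = EngineStatus.SLEEPED
    if status is EngineStatus.SLEEPED:
        # the engine must be woken (wakeup_inference_engine) before generation
        print("engine is asleep")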

StrategyBase

class lightrft.strategy.strategy_base.StrategyBase(seed: int, max_norm: float, micro_train_batch_size: int, train_batch_size: int, args: Any | None = None)[source]

Base class for training strategies (DeepSpeed and FSDP).

Provides common functionality for distributed training setup, model preparation, optimization, checkpointing, and inference engine management.

Parameters:
  • seed (int) – Random seed for reproducibility

  • max_norm (float) – Maximum gradient norm for clipping

  • micro_train_batch_size (int) – Micro-batch size processed in each forward/backward pass

  • train_batch_size (int) – Global batch size per optimizer step, accumulated from micro-batches across processes

  • args (Any) – Additional configuration arguments

__init__(seed: int, max_norm: float, micro_train_batch_size: int, train_batch_size: int, args: Any | None = None) None[source]

Initialize strategy with common parameters.

Parameters:
  • seed (int) – Random seed for reproducibility

  • max_norm (float) – Maximum gradient norm for clipping

  • micro_train_batch_size (int) – Micro-batch size processed in each forward/backward pass

  • train_batch_size (int) – Global batch size per optimizer step, accumulated from micro-batches across processes

  • args (Any, usually argparse.Namespace) – Additional configuration arguments

classmethod _build_multimodal_inputs(all_prompts, all_images, images_num, all_videos, videos_num)[source]

Build multimodal inputs for inference engine (vLLM/SGLang).

This function supports two input formats for images and videos to accommodate different data preprocessing approaches:

Format 1 - Nested List (multi-image/video per prompt already grouped):

    all_images = [[img1_a, img1_b], [img2_a], [img3_a, img3_b, img3_c]]
    images_num = [2, 1, 3]
    # all_images[i] is directly used as the image list for prompt i

Format 2 - Flattened List (all images/videos in a single flat list):

    all_images = [img1_a, img1_b, img2_a, img3_a, img3_b, img3_c]
    images_num = [2, 1, 3]
    # images are sliced based on images_num: [0:2], [2:3], [3:6]

Parameters:
  • all_prompts – List of text prompts

  • all_images – Images in nested or flattened format, or None

  • images_num – Number of images per prompt

  • all_videos – Videos in nested or flattened format, or None

  • videos_num – Number of videos per prompt

Returns:

List of dicts with ‘prompt’ and optional ‘multi_modal_data’ keys
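
A sketch of the two equivalent call styles (the img_* names are placeholders for whatever image objects the configured engine accepts; this classmethod is normally invoked internally by gather_and_generate):

    from lightrft.strategy.strategy_base import StrategyBase

    # placeholders standing in for real image objects
    img1_a, img1_b, img2_a, img3_a, img3_b, img3_c = [object() for _ in range(6)]
    prompts = ["p1", "p2", "p3"]

    nested = [[img1_a, img1_b], [img2_a], [img3_a, img3_b, img3_c]]  # Format 1
    flat = [img1_a, img1_b, img2_a, img3_a, img3_b, img3_c]          # Format 2

    # both calls should yield the same list of
    # {"prompt": ..., "multi_modal_data": {"image": [...]}} dicts
    a = StrategyBase._build_multimodal_inputs(prompts, nested, [2, 1, 3], None, None)
    b = StrategyBase._build_multimodal_inputs(prompts, flat, [2, 1, 3], None, None)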

all_gather(data: torch.Tensor | Dict[str, torch.Tensor]) torch.Tensor | Dict[str, torch.Tensor][source]

Gather data from all distributed processes.

Parameters:

data (Union[torch.Tensor, dict]) – Data to be gathered, can be a tensor or dictionary of tensors

Returns:

Gathered data concatenated from all processes

Return type:

Union[torch.Tensor, dict]
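
For example (a sketch; strategy is a concrete StrategyBase instance with distributed training already initialized):

    import torch

    # per-rank tensor of rollout rewards (hypothetical values)
    rewards = torch.randn(8, device="cuda")
    all_rewards = strategy.all_gather(rewards)  # shape: (world_size * 8,)

    # dictionaries are gathered key-wise
    gathered = strategy.all_gather({"rewards": rewards})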

all_reduce(data: torch.Tensor | Dict[str, torch.Tensor], op: str = 'mean') torch.Tensor | Dict[str, torch.Tensor] | float | int[source]

Perform all-reduce operation across distributed processes.

Parameters:
  • data (Union[torch.Tensor, Dict[str, torch.Tensor]]) – Data to be reduced, can be a tensor or dictionary of tensors

  • op (str) – Reduction operation (‘mean’, ‘max’, ‘sum’)

Returns:

Reduced data in the same format as input

Return type:

Union[torch.Tensor, Dict[str, torch.Tensor], float, int]

Raises:

AssertionError – If op is not one of ‘mean’, ‘max’, ‘sum’
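
For example, averaging metrics across ranks (same assumptions as the all_gather sketch above):

    import torch

    local_loss = torch.tensor(0.42, device="cuda")
    mean_loss = strategy.all_reduce(local_loss, op="mean")

    # dicts are reduced key-wise; 'sum' and 'max' are also supported
    stats = strategy.all_reduce({"loss": local_loss}, op="sum")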

abstract backward(loss: torch.Tensor, model: torch.nn.Module, optimizer: torch.optim.Optimizer, **kwargs) None[source]

Perform backward pass.

Parameters:
  • loss (torch.Tensor) – The loss to backpropagate

  • model (nn.Module) – The model

  • optimizer (optim.Optimizer) – The optimizer

  • kwargs – Additional arguments

abstract create_optimizer(model: torch.nn.Module, **kwargs) torch.optim.Optimizer[source]

Create optimizer for the model.

Parameters:
  • model (nn.Module) – The model to optimize

  • kwargs – Additional optimizer arguments

Returns:

The created optimizer

Return type:

optim.Optimizer

engine_generate_local(sampling_params: Any, prompt_token_ids: List[List[int]] | List[int] | None = None, multi_modal_inputs: List[Dict[str, Any]] | None = None) List[EasyDict][source]

Perform text or multimodal generation using different inference engines based on the input mode.

Parameters:
  • sampling_params – Parameters used for controlling the generation process (e.g., temperature, top_k).

  • prompt_token_ids – List of text token IDs.

  • multi_modal_inputs – A list of dictionaries representing multimodal inputs. Each dictionary should contain the raw text under the “prompt” key, and additional modalities (such as images) under the “multi_modal_data” key. Example:

        multi_modal_inputs = [{
            "prompt": ...,
            "multi_modal_data": {
                "image": [...],
                "video": [...],
            },
        }]

Returns:

A list of generated outputs in EasyDict format, produced by the selected inference engine.

Raises:
  • ValueError – If both prompt_token_ids and multi_modal_inputs are None.

  • ValueError – If both prompt_token_ids and multi_modal_inputs are not None.
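
A sketch showing the mutually exclusive input modes (token_ids, img, and strategy are placeholders; SamplingParams assumes a vLLM backend and would differ for SGLang):

    from vllm import SamplingParams

    params = SamplingParams(temperature=0.8, max_tokens=256)

    # text-only: pass token ids
    outs = strategy.engine_generate_local(params, prompt_token_ids=token_ids)

    # multimodal: pass prompt/multi_modal_data dicts instead
    mm = [{"prompt": "Describe the image.", "multi_modal_data": {"image": [img]}}]
    outs = strategy.engine_generate_local(params, multi_modal_inputs=mm)

    # passing both (or neither) raises ValueError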

gather_and_generate(sampling_params, all_prompt_token_ids=None, all_prompts=None, all_images=None, sleep_engine=True, images_num=None, all_videos=None, videos_num=None)[source]

Gather prompts across distributed ranks and perform text/multimodal generation.

This method coordinates distributed generation by:

  1. Gathering prompts from all ranks within a vLLM tensor parallel group

  2. Performing batched generation using the inference engine

  3. Splitting generated outputs and returning each rank’s portion

  4. Optionally putting the inference engine to sleep to conserve memory

For multimodal inputs, flexible input formats are supported:

  • One prompt with one image

  • One prompt with multiple images

  • One prompt with one or more videos (no images)

  • Mixed image and video inputs

Parameters:
  • sampling_params (Any) – Parameters controlling generation (e.g., temperature, top_k, max_tokens)

  • all_prompt_token_ids (Optional[List[List[int]]]) – Token IDs for text-only prompts, defaults to None

  • all_prompts (Optional[List[str]]) – Raw text prompts for multimodal generation, defaults to None

  • all_images (Optional[List]) – Images corresponding to prompts for VLM generation, defaults to None

  • sleep_engine (bool) – Whether to sleep the inference engine after generation, defaults to True

  • images_num (Optional[List[int]]) – Number of images per prompt (for multi-image scenarios), defaults to None

  • all_videos (Optional[List]) – Videos corresponding to prompts for video generation, defaults to None

  • videos_num (Optional[List[int]]) – Number of videos per prompt, defaults to None

Returns:

List of generation outputs for the current rank, each containing prompt_token_ids and output_token_ids

Return type:

List[EasyDict]

Raises:

NotImplementedError – If the inference engine is not initialized
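
A typical rollout call (a sketch; every rank passes only its local prompts, and params is built as in engine_generate_local):

    outputs = strategy.gather_and_generate(
        params,
        all_prompt_token_ids=local_token_ids,  # this rank's prompts only
        sleep_engine=True,                     # release engine memory afterwards
    )
    for out in outputs:
        # each EasyDict carries this rank's prompt and output token ids
        print(len(out.prompt_token_ids), len(out.output_token_ids))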

get_rank() int[source]

Get current process rank.

Returns:

Current process rank

Return type:

int

init_model_context()[source]

Context manager for model initialization.

By default this is a no-op; it is only meaningful for DeepSpeed. Reports memory usage when the context exits.

classmethod is_rank_0() bool[source]

Check if current process is rank 0.

Returns:

True if current process is rank 0

Return type:

bool

abstract load_ckpt(model: torch.nn.Module, load_dir: str, tag: str | None = None, load_module_strict: bool = True, load_optimizer_states: bool = True, load_lr_scheduler_states: bool = True, load_module_only: bool = False) Any[source]

Load training checkpoint with various options.

Parameters:
  • model – The model to load checkpoint into

  • load_dir (str) – Directory containing the checkpoint

  • tag – Optional specific checkpoint tag to load

  • load_module_strict (bool) – Whether to use strict loading for module states, defaults to True

  • load_optimizer_states (bool) – Whether to load optimizer states, defaults to True

  • load_lr_scheduler_states (bool) – Whether to load learning rate scheduler states, defaults to True

  • load_module_only (bool) – Whether to load only the module states, defaults to False

maybe_load_optimizer(optimizer, device=torch.cuda.current_device)[source]

Placeholder for FSDP optimizer loading functionality.

Parameters:
  • optimizer (torch.optim.Optimizer) – The optimizer to potentially load

  • device (torch.device) – Target device for loading

maybe_offload_optimizer(optimizer)[source]

Placeholder for FSDP optimizer offloading functionality.

Parameters:

optimizer (torch.optim.Optimizer) – The optimizer to potentially offload

maybe_sleep_inference_engine()[source]

Put the inference engine to sleep if enabled and available.

Sleeps the engine to conserve memory when not in use. Only supports vLLM and SGLang engines. After sleeping, synchronizes and clears the cache.

Raises:

ValueError – If the inference engine type is not supported

abstract optimizer_step(optimizer: torch.optim.Optimizer, model: torch.nn.Module, scheduler: Any, name: str = 'model', **kwargs) None[source]

Take optimizer step.

Parameters:
  • optimizer (optim.Optimizer) – The optimizer

  • model (nn.Module) – The model

  • scheduler – The learning rate scheduler

  • name (str) – Name for logging purposes

  • kwargs – Additional arguments
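
Concrete strategies implement create_optimizer(), backward(), and optimizer_step(), so a training loop can stay strategy-agnostic. A minimal sketch (model, dataloader, scheduler, and compute_loss are placeholders):

    optimizer = strategy.create_optimizer(model)
    for batch in dataloader:
        loss = compute_loss(model, batch)          # hypothetical loss function
        strategy.backward(loss, model, optimizer)  # strategy-specific backward
        strategy.optimizer_step(optimizer, model, scheduler, name="actor")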

prepare(*models_or_model_optim_pairs: torch.nn.Module | Tuple[torch.nn.Module, torch.optim.Optimizer], is_rlhf=False) List[torch.nn.Module | Tuple[torch.nn.Module, torch.optim.Optimizer]] | torch.nn.Module | Tuple[torch.nn.Module, torch.optim.Optimizer][source]

Prepare models and optimizers for training.

Parameters:
  • models_or_model_optim_pairs – Models or (model, optimizer) pairs to prepare

  • is_rlhf (bool) – Whether preparing for RLHF training

Returns:

Prepared models/optimizers

Return type:

Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]
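
A sketch of both call shapes (the models and optimizers are placeholders):

    # a single (model, optimizer) pair comes back as a pair
    actor, actor_optim = strategy.prepare((actor, actor_optim))

    # multiple entries come back as a list in the same order
    (critic, critic_optim), initial_model = strategy.prepare(
        (critic, critic_optim), initial_model, is_rlhf=True
    )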

prepare_models_and_optimizers(actor, critic, reward_models, initial_model, args, max_steps)[source]

Prepare models, optimizers and schedulers for training.

Parameters:
  • actor (nn.Module) – Actor model

  • critic (nn.Module) – Critic model

  • reward_models (nn.Module) – Reward models

  • initial_model (nn.Module) – Initial model for reference

  • args (argparse.Namespace) – Training arguments

  • max_steps (int) – Maximum training steps

Returns:

Tuple of prepared models, optimizers, and schedulers

Return type:

tuple

classmethod print(*msg)[source]

Print messages with timestamp, but only on rank 0.

Parameters:

msg – Messages to print

classmethod report_memory(prefix='')[source]

Report GPU memory usage statistics.

Parameters:

prefix (str) – Prefix string for the memory report

abstract save_ckpt(model: torch.nn.Module, save_dir: str, tag: str | None = None, max_num: int = 3, max_mem: int = 1000, client_state: Dict[str, Any] | None = None, save_latest: bool = True) None[source]

Save training checkpoint with additional metadata.

Parameters:
  • model – The model to save

  • save_dir (str) – Directory to save the checkpoint

  • tag – Optional tag for the checkpoint

  • max_num (int) – Maximum number of checkpoints to keep, defaults to 3

  • max_mem (int) – Maximum memory in MB for checkpoints, defaults to 1000

  • client_state (dict) – Additional state to save, defaults to None

  • save_latest (bool) – Whether to save as latest checkpoint, defaults to True
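
A save/resume sketch pairing save_ckpt() with load_ckpt() (model and step are placeholders):

    # keep at most 3 checkpoints, tagged by global step
    strategy.save_ckpt(model, "ckpts/actor", tag=f"step_{step}",
                       max_num=3, client_state={"global_step": step})

    # resume later from the same tag
    strategy.load_ckpt(model, "ckpts/actor", tag=f"step_{step}")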

set_seed(seed: int) None[source]

Set random seeds for reproducibility.

Parameters:

seed (int) – The random seed to use
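
A base implementation typically seeds every RNG in play; a minimal sketch of the idea (not necessarily the exact body):

    import random

    import numpy as np
    import torch

    def set_seed(seed: int) -> None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)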

setup_dataloader(replay_buffer, batch_size: int, pin_memory: bool = False, shuffle: bool = True, collate_fn: Callable | None = None, drop_last: bool = True, sampler: Any | None = None, consumed_samples: int = 0) torch.utils.data.DataLoader[source]

Set up data loader for training.

Parameters:
  • replay_buffer – Dataset/replay buffer

  • batch_size (int) – Batch size

  • pin_memory (bool) – Whether to pin memory

  • shuffle (bool) – Whether to shuffle data

  • collate_fn (Optional[Callable]) – Function to collate samples

  • drop_last (bool) – Whether to drop last incomplete batch

  • sampler – Custom sampler

  • consumed_samples (int) – Number of samples already consumed

Returns:

Configured DataLoader

Return type:

DataLoader
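
For example (a sketch; my_collate is a placeholder, and the micro_train_batch_size attribute is assumed to be stored by __init__):

    loader = strategy.setup_dataloader(
        replay_buffer,
        batch_size=strategy.micro_train_batch_size,
        pin_memory=True,
        shuffle=True,
        collate_fn=my_collate,
    )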

setup_distributed(timeout: timedelta | None = None, num_gpu_per_node: int = 8) None[source]

Initialize distributed training environment.

Parameters:
  • timeout (timedelta, optional) – Maximum time to wait for initialization

  • num_gpu_per_node (int) – Number of GPUs per node

Raises:
  • RuntimeError – If required environment variables are missing

  • ValueError – If unsupported engine type is specified
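
The required variables are assumed to be the standard torch.distributed ones that launchers such as torchrun export:

    import os
    from datetime import timedelta

    # torchrun normally sets these before the script starts
    for var in ("RANK", "WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT"):
        assert var in os.environ, f"missing {var}"

    strategy.setup_distributed(timeout=timedelta(minutes=30), num_gpu_per_node=8)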

setup_inference_engine(args, engine_type='vllm', actor=None)[source]

Initialize and setup the inference engine.

Parameters:
  • args (argparse.Namespace) – Configuration arguments

  • engine_type (str) – Type of inference engine (‘vllm’ or ‘sglang’)

  • actor (torch.nn.Module) – The actor module; if provided, it will be used to update the engine weights

Returns:

Initialized inference engine

Return type:

object

Raises:

ValueError – If engine_type is not supported
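
For example (a sketch; args is the parsed training namespace and actor an already-prepared model):

    engine = strategy.setup_inference_engine(args, engine_type="vllm", actor=actor)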

classmethod sync_and_clear_cache()[source]

Synchronize CUDA operations and clear the cache.

Performs three operations:

  1. CUDA synchronization

  2. Distributed barrier

  3. CUDA cache clearing
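
These steps correspond to standard PyTorch calls; a sketch of the likely sequence:

    import torch
    import torch.distributed as dist

    torch.cuda.synchronize()  # 1. wait for in-flight CUDA work
    dist.barrier()            # 2. align all ranks
    torch.cuda.empty_cache()  # 3. return cached blocks to the allocator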

unwrap_model(model) torch.nn.Module[source]

Unwrap model from strategy-specific wrappers.

Parameters:

model (nn.Module) – Model to unwrap

Returns:

Unwrapped model

Return type:

nn.Module

update_engine_weights(actor)[source]

Update the weights of the inference engine from the actor model.

Parameters:

actor – The actor model whose weights will be copied

wakeup_inference_engine()[source]

Wake up the inference engine from sleep state.

To avoid OOM, this method:

  1. Synchronizes and clears the cache

  2. Wakes up the engine

Raises:

ValueError – If the inference engine type is not supported
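
A typical generate-then-train cycle pairing wakeup with sleep (a sketch; params and ids are placeholders):

    # generation phase: make sure the engine holds its GPU memory
    strategy.wakeup_inference_engine()
    outputs = strategy.gather_and_generate(params, all_prompt_token_ids=ids,
                                           sleep_engine=False)

    # training phase: hand memory back to the trainer
    strategy.maybe_sleep_inference_engine()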

is_actor

lightrft.strategy.strategy_base.is_actor(model)[source]

Check if a model is an actor model.

Parameters:

model – The model to check

Returns:

True if the model is an actor, False otherwise

Return type:

bool