lightrft.trainer.spmd_ppo_trainer¶
SPMD (Single Program Multiple Data) PPO Trainer for distributed reinforcement learning.
This module extends the base PPOTrainer with SPMD capabilities, enabling efficient distributed training across multiple devices. It provides specialized implementations for both text-only language models and vision-language models with optimized tensor parallelism and distributed inference using vLLM.
The module includes:

- SPMDPPOTrainerBase: Base class with core SPMD functionality
- SPMDPPOTrainer: Implementation for Large Language Models (LLMs)
- SPMDPPOTrainerVL: Implementation for Vision-Language Models (VLMs)
Key features:

- FastExperienceMaker for improved throughput during experience collection
- Optimized memory management and communication patterns
- Support for both text-only and multi-modal reinforcement learning
- Efficient distributed training across multiple devices and nodes
SPMDPPOTrainerBase¶
- class lightrft.trainer.spmd_ppo_trainer.SPMDPPOTrainerBase(*args, loss_agg_mode: str = 'seq-mean-token-mean', use_gspo: bool = False, VLM: bool = False, **kwargs)[source]¶
PPO Trainer implementation optimized for Single Program Multiple Data (SPMD) execution.
This trainer extends the base PPOTrainer with specialized handling for tensor parallelism and distributed inference using vLLM. It includes optimizations for experience collection and training across multiple devices.
The base class provides core functionality for SPMD training, including:

- FastExperienceMaker integration for improved throughput
- Tensor parallelism support with the vLLM engine
- Optimized memory management during training
- Support for both text-only and vision-language models
Note
Performance: This implementation uses FastExperienceMaker for improved throughput during experience collection compared to the standard implementation.
Important
Requirements: Requires a tensor parallelism configuration with engine_tp_size > 0.
- __init__(*args, loss_agg_mode: str = 'seq-mean-token-mean', use_gspo: bool = False, VLM: bool = False, **kwargs)[source]¶
Initialize the SPMD PPO Trainer base class.
Sets up the distributed training environment, creates the experience maker, and configures the policy loss function for SPMD execution.
- Parameters:
args (tuple) – Positional arguments passed to the parent PPOTrainer, including strategy, actor, critic, reward_model, initial_model, etc.
loss_agg_mode (str) – Mode for aggregating policy losses, e.g. "seq-mean-token-mean" or another supported mode
use_gspo (bool) – Whether to enable GSPO (Group Sequence Policy Optimization) mode
VLM (bool) – Whether to use Vision-Language Model mode (True) or Language Model mode (False)
kwargs (dict) – Keyword arguments for configuration including packing_samples, processor, and other parameters.
- Raises:
AssertionError – If engine_tp_size is not properly configured (must be > 0)
Example:
trainer_base = SPMDPPOTrainerBase(
    strategy,
    actor_model,
    critic_model,
    reward_model,
    initial_model,
    ema_model,
    actor_optim,
    critic_optim,
    actor_scheduler,
    critic_scheduler,
    loss_agg_mode="seq-mean-token-mean",
    VLM=False,
    packing_samples=True,
)
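The default "seq-mean-token-mean" aggregation mode can be sketched as follows: average the per-token losses within each sequence, then average those per-sequence means across the batch, so every sequence contributes equally regardless of its length. This is a minimal illustration of the aggregation rule, not the trainer's actual implementation:

```python
def seq_mean_token_mean(token_losses):
    """Aggregate per-token losses: mean over tokens within each
    sequence, then mean over the sequences in the batch."""
    per_seq_means = [sum(seq) / len(seq) for seq in token_losses]
    return sum(per_seq_means) / len(per_seq_means)

# Two sequences of different lengths contribute equally:
# each sequence's token-mean is 0.3, so the result is ~0.3.
batch = [[0.2, 0.4], [0.3, 0.3, 0.3]]
loss = seq_mean_token_mean(batch)
```

Under this rule a long sequence cannot dominate the batch loss, which is the usual motivation for sequence-level mean aggregation.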
- ppo_train(global_steps=0)[source]¶
Execute a full PPO training iteration with SPMD optimizations.
This method processes the replay buffer data, trains the actor and critic models for multiple epochs, and updates the inference engine weights. It includes optimized memory management and distributed training coordination.
The training process includes:

1. Data preprocessing for distributed execution
2. Multi-epoch training with experience batching
3. Loss computation and optimization
4. Memory cleanup and weight synchronization
- Parameters:
global_steps (int) – Current global step counter for logging and scheduling
- Returns:
Dictionary of training metrics averaged across all training steps
- Return type:
Dict[str, float]
Example:
metrics = trainer.ppo_train(global_steps=100)
print(f"Policy loss: {metrics['policy_loss']}")
print(f"Critic loss: {metrics['critic_loss']}")
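The return value is described above as metrics averaged across all training steps. A plausible sketch of that averaging over per-step metric dicts (the function name and the assumption that every step reports the same keys are illustrative, not taken from the source):

```python
def average_metrics(step_metrics):
    """Average a list of per-step metric dicts into one dict,
    assuming every step reports the same metric keys."""
    totals = {}
    for metrics in step_metrics:
        for key, value in metrics.items():
            totals[key] = totals.get(key, 0.0) + value
    return {key: total / len(step_metrics) for key, total in totals.items()}

steps = [
    {"policy_loss": 0.5, "critic_loss": 0.8},
    {"policy_loss": 0.3, "critic_loss": 0.6},
]
avg = average_metrics(steps)  # policy_loss ~ 0.4, critic_loss ~ 0.7
```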
SPMDPPOTrainer¶
- class lightrft.trainer.spmd_ppo_trainer.SPMDPPOTrainer(*args, **kwargs)[source]¶
PPOTrainer for SPMD on Large Language Models and Multi-modal Large Language Models.
This class combines the SPMD (Single Program Multiple Data) base functionality with the standard PPOTrainer for efficient distributed training of large language models (LLMs) and multi-modal large language models (MLLMs). It supports training across multiple devices and nodes with optimized communication patterns for both text-only and multi-modal reinforcement learning scenarios.
The trainer provides:

- Distributed PPO training with tensor parallelism
- Efficient experience collection using FastExperienceMaker
- Memory-optimized training loops
- Support for various loss aggregation modes
- Integration with the vLLM inference engine
Example:
trainer = SPMDPPOTrainer(
    strategy=my_strategy,
    actor=actor_model,
    critic=critic_model,
    reward_model=reward_model,
    initial_model=reference_model,
    ema_model=ema_model,
    actor_optim=actor_optimizer,
    critic_optim=critic_optimizer,
    actor_scheduler=actor_scheduler,
    critic_scheduler=critic_scheduler,
    tokenizer=tokenizer,
    # Additional PPO parameters
    max_epochs=5,
    micro_train_batch_size=16,
)

# Train for multiple iterations
for step in range(training_steps):
    trainer.make_experience()
    metrics = trainer.ppo_train(step)
- __init__(*args, **kwargs)[source]¶
Initialize the SPMD PPO Trainer for language models.
Creates a trainer instance optimized for distributed training of language models using SPMD execution patterns. Inherits from both SPMDPPOTrainerBase and PPOTrainer to combine SPMD optimizations with standard PPO functionality.
- Parameters:
args (tuple) – Positional arguments passed to the parent PPOTrainer including strategy, actor, critic, reward_model, initial_model, ema_model, actor_optim, critic_optim, actor_scheduler, critic_scheduler.
kwargs (dict) – Keyword arguments for configuration including training hyperparameters like max_epochs, micro_train_batch_size, eps_clip, value_clip, etc.
Example:
trainer = SPMDPPOTrainer(
    strategy,
    actor_model,
    critic_model,
    reward_model,
    reference_model,
    ema_model,
    actor_optimizer,
    critic_optimizer,
    actor_scheduler,
    critic_scheduler,
    tokenizer=my_tokenizer,
    loss_agg_mode="seq-mean-token-mean",
    packing_samples=True,
    max_epochs=5,
    micro_train_batch_size=16,
)
- ppo_train(global_steps=0)¶
Execute a full PPO training iteration with SPMD optimizations.
This method processes the replay buffer data, trains the actor and critic models for multiple epochs, and updates the inference engine weights. It includes optimized memory management and distributed training coordination.
The training process includes:

1. Data preprocessing for distributed execution
2. Multi-epoch training with experience batching
3. Loss computation and optimization
4. Memory cleanup and weight synchronization
- Parameters:
global_steps (int) – Current global step counter for logging and scheduling
- Returns:
Dictionary of training metrics averaged across all training steps
- Return type:
Dict[str, float]
Example:
metrics = trainer.ppo_train(global_steps=100)
print(f"Policy loss: {metrics['policy_loss']}")
print(f"Critic loss: {metrics['critic_loss']}")
SPMDPPOTrainerVL¶
- class lightrft.trainer.spmd_ppo_trainer.SPMDPPOTrainerVL(*args, **kwargs)[source]¶
PPOTrainer for SPMD with Vision-Language Models (VLM).
This class combines the SPMD base functionality with the VLM-specific PPOTrainer for efficient distributed training of vision-language models. It extends the standard VLM training capabilities with SPMD optimizations for better performance across multiple devices.
Key features for VLM training:

- Multi-modal experience collection and processing
- Vision-language specific batch creation
- Processor integration for image and text handling
- Optimized memory management for large multi-modal models
Example:
trainer = SPMDPPOTrainerVL(
    strategy=my_strategy,
    actor=actor_model,
    critic=critic_model,
    reward_model=reward_model,
    initial_model=reference_model,
    ema_model=ema_model,
    actor_optim=actor_optimizer,
    critic_optim=critic_optimizer,
    actor_scheduler=actor_scheduler,
    critic_scheduler=critic_scheduler,
    tokenizer=tokenizer,
    processor=image_processor,  # Required for VLM
    # Additional PPO parameters
    max_epochs=5,
    micro_train_batch_size=16,
)

# Train for multiple iterations
for step in range(training_steps):
    trainer.make_experience()
    metrics = trainer.ppo_train(step)
- __init__(*args, **kwargs)[source]¶
Initialize the SPMD PPO Trainer for vision-language models.
Creates a trainer instance specifically designed for distributed training of vision-language models using SPMD execution patterns. Requires a processor for handling multi-modal inputs.
- Parameters:
args (tuple) – Positional arguments passed to the parent PPOTrainerVL including strategy, actor, critic, reward_model, initial_model, ema_model, actor_optim, critic_optim, actor_scheduler, critic_scheduler.
kwargs (dict) – Keyword arguments for configuration; must include 'processor' for image processing, along with other training parameters.
- Raises:
AssertionError – If processor is not provided or is None.
Example:
trainer = SPMDPPOTrainerVL(
    strategy,
    vlm_actor,
    vlm_critic,
    vlm_reward_model,
    vlm_reference,
    vlm_ema_model,
    actor_optimizer,
    critic_optimizer,
    actor_scheduler,
    critic_scheduler,
    tokenizer=my_tokenizer,
    processor=my_image_processor,  # Required!
    loss_agg_mode="seq-mean-token-mean",
    max_epochs=5,
    micro_train_batch_size=8,
)
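The processor's role, combining raw text and images into a single batch for the model, can be illustrated with a minimal stand-in. FakeProcessor below is purely illustrative (its tokenization and key names are assumptions, not the trainer's API); a real setup would pass a Hugging Face-style multi-modal processor:

```python
class FakeProcessor:
    """Illustrative stand-in for a multi-modal processor: turns raw
    text and images into one batch dict, the role the required
    `processor` plays for SPMDPPOTrainerVL."""

    def __call__(self, text, images):
        # Toy tokenization: one "id" per whitespace-separated word.
        input_ids = [[hash(w) % 1000 for w in t.split()] for t in text]
        return {"input_ids": input_ids, "pixel_values": images}

processor = FakeProcessor()
batch = processor(
    text=["describe the image"],
    images=[[[0.1, 0.2], [0.3, 0.4]]],  # toy 2x2 "image"
)
# batch carries both modalities, keyed for the model:
# input_ids from the text, pixel_values from the images.
```

This is why the constructor asserts that a processor is provided: without it, image inputs could not be turned into model-ready tensors alongside the tokenized text.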
- ppo_train(global_steps=0)¶
Execute a full PPO training iteration with SPMD optimizations.
This method processes the replay buffer data, trains the actor and critic models for multiple epochs, and updates the inference engine weights. It includes optimized memory management and distributed training coordination.
The training process includes:

1. Data preprocessing for distributed execution
2. Multi-epoch training with experience batching
3. Loss computation and optimization
4. Memory cleanup and weight synchronization
- Parameters:
global_steps (int) – Current global step counter for logging and scheduling
- Returns:
Dictionary of training metrics averaged across all training steps
- Return type:
Dict[str, float]
Example:
metrics = trainer.ppo_train(global_steps=100)
print(f"Policy loss: {metrics['policy_loss']}")
print(f"Critic loss: {metrics['critic_loss']}")