"""
SPMD (Single Program Multiple Data) PPO Trainer for distributed reinforcement learning.

This module extends the base PPOTrainer with SPMD capabilities, enabling efficient
distributed training across multiple devices. It provides specialized implementations
for both text-only language models and vision-language models with optimized
tensor parallelism and distributed inference using vLLM.

The module includes:

- SPMDPPOTrainerBase: Base class with core SPMD functionality
- SPMDPPOTrainer: Implementation for Large Language Models (LLMs)
- SPMDPPOTrainerVL: Implementation for Vision-Language Models (VLMs)

Key features:

- FastExperienceMaker for improved throughput during experience collection
- Optimized memory management and communication patterns
- Support for both text-only and multi-modal reinforcement learning
- Efficient distributed training across multiple devices and nodes
"""
import time
import torch
from tqdm import tqdm
from lightrft.trainer import PPOTrainer, PPOTrainerVL
from lightrft.trainer.fast_exp_maker import FastExperienceMaker
from lightrft.utils.trajectory_saver import create_trajectory_saver
from lightrft.trainer.replay_buffer import make_experience_batch
from lightrft.trainer.replay_buffer_vl import make_experience_batch as make_experience_batch_vl
from lightrft.models.utils import create_high_entropy_mask
from lightrft.utils import init_logger
logger = init_logger(__name__)
class SPMDPPOTrainerBase:
"""
PPO Trainer implementation optimized for Single Program Multiple Data (SPMD) execution.
This trainer extends the base PPOTrainer with specialized handling for tensor parallelism
and distributed inference using vLLM. It includes optimizations for experience collection
and training across multiple devices.
The base class provides core functionality for SPMD training including:
- FastExperienceMaker integration for improved throughput
- Tensor parallelism support with vLLM engine
- Optimized memory management during training
- Support for both text-only and vision-language models
.. note:: Performance
This implementation uses FastExperienceMaker for improved throughput during
experience collection compared to the standard implementation.
.. important:: Requirements
Requires tensor parallelism configuration with engine_tp_size > 0.
"""
    def __init__(
self,
*args,
loss_agg_mode: str = "seq-mean-token-mean",
use_gspo: bool = False,
VLM: bool = False,
**kwargs,
):
"""
Initialize the SPMD PPO Trainer base class.
Sets up the distributed training environment, creates the experience maker,
and configures the policy loss function for SPMD execution.
:param args: Positional arguments passed to the parent PPOTrainer, including strategy, actor, critic,
reward_model, initial_model, etc.
:type args: tuple
        :param loss_agg_mode: Policy-loss aggregation mode, e.g. "seq-mean-token-mean" (default) or another supported mode
:type loss_agg_mode: str
:param use_gspo: Whether to enable GSPO (Group Sequence Policy Optimization) mode
:type use_gspo: bool
:param VLM: Whether to use Vision-Language Model mode (True) or Language Model mode (False)
:type VLM: bool
:param kwargs: Keyword arguments for configuration including packing_samples, processor, and other parameters.
:type kwargs: dict
:raises AssertionError: If engine_tp_size is not properly configured (must be > 0)
Example::
trainer_base = SPMDPPOTrainerBase(
strategy,
actor_model,
critic_model,
reward_model,
initial_model,
ema_model,
actor_optim,
critic_optim,
actor_scheduler,
critic_scheduler,
loss_agg_mode="seq-mean-token-mean",
VLM=False,
packing_samples=True
)
"""
self.VLM = VLM # otherwise it's LLM
self.packing_samples = kwargs.pop("packing_samples", False)
self.print_replay_buffer_stats = kwargs.pop("print_replay_buffer_stats", False)
        # Note: the child classes call PPOTrainer/PPOTrainerVL.__init__ before this base __init__,
        # so attributes such as self.args, self.strategy and self.actor already exist at this point.
assert self.args.engine_tp_size > 0, "engine_tp_size should be larger than 0"
self.vllm_mp_group = self.strategy.engine_mp_group
self.vllm_engine = self.strategy.inference_engine
torch.distributed.barrier()
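        # Barrier keeps all ranks in lockstep before the experience maker is constructed.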
        # TODO: we pass a fixed list of concrete parameters here; this may break in future versions.
# Create experience maker with appropriate parameters
processor = kwargs.pop("processor", None)
self.experience_maker = FastExperienceMaker(
self.actor,
self.critic,
self.reward_model,
self.initial_model,
self.tokenizer,
self.prompt_max_len,
self.kl_ctl,
self.strategy,
self.remote_rm_url,
self.reward_fn,
self.reward_fn_label_map,
self.reward_recipe,
packing_samples=self.packing_samples,
processor=processor,
)
# Extract high_entropy_token_ratio for entropy-based token filtering
self.high_entropy_token_ratio = kwargs.pop("high_entropy_token_ratio", 0.0)
        # Assemble the policy-loss configuration for the selected mode (GSPO adds extra options)
policy_loss_kwargs = {"loss_agg_mode": loss_agg_mode, "use_gspo": use_gspo}
if use_gspo:
policy_loss_kwargs.update({
"normalize_advantages": kwargs.get("normalize_advantages", True),
"use_sequence_rewards": kwargs.get("use_sequence_rewards", True)
})
self.use_gspo = use_gspo
# Initialize trajectory saver if enabled
self.trajectory_saver = create_trajectory_saver(self.args, self.tokenizer)
# Validate num_trajectories_to_save parameter if trajectory saving is enabled
if self.trajectory_saver is not None:
if not hasattr(self.args, 'num_trajectories_to_save'):
raise ValueError(
"num_trajectories_to_save must be provided in args when trajectory saving is enabled. "
"Please add --num_trajectories_to_save <value> to your command line arguments."
)
self.num_trajectories_to_save = self.args.num_trajectories_to_save
else:
self.num_trajectories_to_save = None
self.dataloader_pin_memory = False
if torch.distributed.get_rank() == 0:
print(self.args, flush=True)
    def ppo_train(self, global_steps=0):  # This SPMD-specific override is the ppo_train() actually used by these trainers
"""
Execute a full PPO training iteration with SPMD optimizations.
This method processes the replay buffer data, trains the actor and critic models
for multiple epochs, and updates the inference engine weights. It includes
optimized memory management and distributed training coordination.
The training process includes:
1. Data preprocessing for distributed execution
2. Multi-epoch training with experience batching
3. Loss computation and optimization
4. Memory cleanup and weight synchronization
:param global_steps: Current global step counter for logging and scheduling
:type global_steps: int
:return: Dictionary of training metrics averaged across all training steps
:rtype: Dict[str, float]
Example::
metrics = trainer.ppo_train(global_steps=100)
print(f"Policy loss: {metrics['policy_loss']}")
print(f"Critic loss: {metrics['critic_loss']}")
"""
torch.cuda.synchronize()
train_begin = time.time()
torch.cuda.empty_cache()
self.strategy.maybe_load_optimizer(self.actor_optim)
all_items = self.strategy.sp_data_processor.preprocess(self.replay_buffer.items)
device = torch.cuda.current_device()
status_list = []
status_mean = {}
for epoch in range(self.max_epochs):
pbar = tqdm(
range(0, len(all_items), self.micro_train_batch_size),
desc=f"Train epoch [{epoch + 1}/{self.max_epochs}]",
disable=not self.strategy.is_rank_0(),
)
for i in pbar:
items = all_items[i:i + self.micro_train_batch_size]
if self.VLM:
experience = make_experience_batch_vl(items, packing_samples=self.packing_samples)
else:
experience = make_experience_batch(items, packing_samples=self.packing_samples)
experience.to_device(device)
# ======================================================================================
# Validate data BEFORE calling training_step to prevent execution path divergence
# If validation is done inside training_step, different ranks may follow different code paths
# (some return early, others continue), causing deadlock in collective communication ops.
# Step 1: Each rank validates its local data
should_skip_local = False
if self.VLM and hasattr(self, '_validate_qwen_vl_tensors'):
# Call the same validation logic used in training_step_actor
sequences = experience.sequences
pixel_values = experience.pixel_values
# Validate before any forward pass
is_valid = self._validate_qwen_vl_tensors(
sequences, pixel_values, context="pre_training_validation"
)
should_skip_local = not is_valid
# Step 2: Synchronize skip decision across all ranks via all_reduce
# This ensures all ranks agree on whether to skip, preventing execution divergence
skip_flag = torch.tensor([1.0 if should_skip_local else 0.0], device=device)
torch.distributed.all_reduce(skip_flag, op=torch.distributed.ReduceOp.MAX)
# Step 3: Collectively skip if ANY rank detected invalid data
if skip_flag.item() > 0:
if self.strategy.is_rank_0():
pbar.set_description(f"Train epoch [{epoch + 1}/{self.max_epochs}] (skipping invalid batch)")
continue # All ranks skip together - no deadlock
# ======================================================================================
# Create entropy_mask if high_entropy_token_ratio > 0 and action_entropy is available
entropy_mask = None
if hasattr(experience, 'action_entropy') and experience.action_entropy is not None:
if self.high_entropy_token_ratio > 0.0:
entropy_mask = create_high_entropy_mask(
experience.action_entropy, experience.action_mask, self.high_entropy_token_ratio
)
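                # entropy_mask (when built) flags the highest-entropy response tokens according to
                # high_entropy_token_ratio; it is forwarded to training_step so that entropy-based
                # token filtering can be applied there.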
# Call training_step which will handle both GSPO and standard modes
status = self.training_step(experience, global_steps, entropy_mask=entropy_mask)
                # Data-parallel reduction: all-reduce the metrics across ranks.
                # KL is weighted by response length before the reduction and re-normalized
                # afterwards, so the reported value is a length-weighted mean.
if "kl" in status:
status["kl"] *= status["response_length"]
status = self.strategy.all_reduce(status)
status["kl"] /= status["response_length"]
# Training epoch progress bar: show per-batch metrics for detailed monitoring
short_status = {}
if "policy_loss" in status:
short_status = {
"pg": status["policy_loss"], # policy gradient loss
"rm": status["reward"], # per-batch reward (instantaneous)
"ret": status["return"], # per-batch return (instantaneous)
"glen": status["response_length"], # per-batch response length
"tlen": status["total_length"], # per-batch total length
"kl": status["kl"], # KL divergence
"act_lr": status["actor_lr"], # actor learning rate
}
if "critic_loss" in status:
short_status["cri"] = status["critic_loss"]
short_status["vals"] = status["values"]
short_status["cri_lr"] = status["critic_lr"]
if "ptx_loss" in status:
short_status["ptx"] = status["ptx_loss"]
status_list.append(status)
pbar.set_postfix(short_status)
if status_list:
status_mean = status_list[0]
for m in status_list[1:]:
for k, v in m.items():
status_mean[k] += v
for k in status_mean.keys():
status_mean[k] /= len(status_list)
# ========== Aggregate step-level reward metrics from replay buffer ==========
# NOTE: These metrics are aggregated from ALL experiences in the current step's
# replay buffer (e.g., 640 experiences if rollout_batch_size=128, n_samples=5).
# They represent the TRUE statistics of the rollout phase, NOT the training phase
# micro-batch averages which are less representative.
#
# Naming convention:
# - "*_mean" suffix: mean across all experiences in this step
# - "step_*" prefix: clarifies this is per-step aggregation, not per-episode
if self.replay_buffer.items:
all_rewards = []
all_format_rewards = []
all_accuracy_rewards = []
all_model_rewards = []
all_rule_rewards = []
all_advantages = []
all_returns = []
all_response_lengths = []
for item in self.replay_buffer.items:
# Collect rewards
if hasattr(item, 'info') and item.info is not None and 'reward' in item.info:
all_rewards.append(item.info['reward'])
# Collect detailed reward metrics from info dict
if hasattr(item, 'info') and item.info is not None and 'reward_metrics' in item.info:
reward_metrics = item.info['reward_metrics']
if 'format_reward' in reward_metrics:
all_format_rewards.append(reward_metrics['format_reward'])
if 'accuracy_reward' in reward_metrics:
all_accuracy_rewards.append(reward_metrics['accuracy_reward'])
if 'model_reward' in reward_metrics:
all_model_rewards.append(reward_metrics['model_reward'])
if 'rule_reward' in reward_metrics:
all_rule_rewards.append(reward_metrics['rule_reward'])
# Collect advantages and returns
if hasattr(item, 'advantages') and item.advantages is not None:
all_advantages.append(item.advantages)
if hasattr(item, 'returns') and item.returns is not None:
all_returns.append(item.returns)
if hasattr(item, 'info') and item.info is not None and 'response_length' in item.info:
all_response_lengths.append(item.info['response_length'])
# Compute statistics
# [TENSOR-FIX] Handle both tensor lists and scalar lists for all reward types
if all_rewards:
# Handle both tensor lists (from batched rewards) and scalar lists
if isinstance(all_rewards[0], torch.Tensor):
rewards_tensor = torch.cat([t.to(device).float() for t in all_rewards])
else:
rewards_tensor = torch.tensor(all_rewards, dtype=torch.float32, device=device)
# Use "step_*" prefix to clarify this is per-step aggregation, not per-episode
status_mean["step_reward_mean"] = rewards_tensor.mean().item()
status_mean["step_reward_std"] = rewards_tensor.std().item()
status_mean["step_reward_max"] = rewards_tensor.max().item()
status_mean["step_reward_min"] = rewards_tensor.min().item()
if all_format_rewards:
# [TENSOR-FIX] Handle both tensor lists and scalar lists
if isinstance(all_format_rewards[0], torch.Tensor):
format_tensor = torch.cat([t.to(device).float() for t in all_format_rewards])
else:
format_tensor = torch.tensor(all_format_rewards, dtype=torch.float32, device=device)
status_mean["format_reward_mean"] = format_tensor.mean().item()
status_mean["format_reward_std"] = format_tensor.std().item()
if all_accuracy_rewards:
# [TENSOR-FIX] Handle both tensor lists and scalar lists
if isinstance(all_accuracy_rewards[0], torch.Tensor):
accuracy_tensor = torch.cat([t.to(device).float() for t in all_accuracy_rewards])
else:
accuracy_tensor = torch.tensor(all_accuracy_rewards, dtype=torch.float32, device=device)
status_mean["accuracy_reward_mean"] = accuracy_tensor.mean().item()
status_mean["accuracy_reward_std"] = accuracy_tensor.std().item()
if all_model_rewards:
# [TENSOR-FIX] Handle both tensor lists and scalar lists
if isinstance(all_model_rewards[0], torch.Tensor):
model_tensor = torch.cat([t.to(device).float() for t in all_model_rewards])
else:
model_tensor = torch.tensor(all_model_rewards, dtype=torch.float32, device=device)
if model_tensor.abs().sum() > 0: # Only log if model rewards are non-zero
status_mean["model_reward_mean"] = model_tensor.mean().item()
self.strategy.print(f" model_reward_mean: {status_mean['model_reward_mean']}")
if all_rule_rewards:
# [TENSOR-FIX] Handle both tensor lists and scalar lists
if isinstance(all_rule_rewards[0], torch.Tensor):
rule_tensor = torch.cat([t.to(device).float() for t in all_rule_rewards])
else:
rule_tensor = torch.tensor(all_rule_rewards, dtype=torch.float32, device=device)
if rule_tensor.abs().sum() > 0: # Only log if rule rewards are non-zero
status_mean["rule_reward_mean"] = rule_tensor.mean().item()
self.strategy.print(f"rule_reward_mean: {status_mean['rule_reward_mean']}")
# For advantages, returns, and lengths, they are already lists of tensors,
# so torch.cat() is the correct function to use.
if all_advantages:
advantages_tensor = torch.cat(all_advantages)
status_mean["advantages_mean"] = advantages_tensor.mean().item()
status_mean["advantages_std"] = advantages_tensor.std().item()
status_mean["advantages_max"] = advantages_tensor.max().item()
status_mean["advantages_min"] = advantages_tensor.min().item()
if all_returns:
returns_tensor = torch.cat(all_returns)
status_mean["returns_mean"] = returns_tensor.mean().item()
status_mean["returns_std"] = returns_tensor.std().item()
if all_response_lengths:
# [TENSOR-FIX] Handle both tensor lists and scalar lists
if isinstance(all_response_lengths[0], torch.Tensor):
lengths_tensor = torch.cat([t.to(device).float() for t in all_response_lengths])
else:
lengths_tensor = torch.tensor(all_response_lengths, dtype=torch.float32, device=device)
status_mean["response_length_mean"] = lengths_tensor.float().mean().item()
status_mean["response_length_std"] = lengths_tensor.float().std().item()
# Print detailed reward breakdown (only on rank 0)
if self.print_replay_buffer_stats and self.strategy.is_rank_0():
self.strategy.print("\n" + "=" * 60)
self.strategy.print("📊 Detailed Step Statistics")
self.strategy.print("=" * 60)
if all_rewards:
self.strategy.print(
f"🎁 Total Reward: {status_mean['step_reward_mean']:.4f} ± {status_mean['step_reward_std']:.4f} " # noqa
f"(min={status_mean['step_reward_min']:.4f}, max={status_mean['step_reward_max']:.4f})"
)
if all_format_rewards:
self.strategy.print(
f"📝 Format Reward: {status_mean['format_reward_mean']:.4f} ± {status_mean['format_reward_std']:.4f}" # noqa
)
if all_accuracy_rewards:
self.strategy.print(
f"✅ Accuracy Reward: {status_mean['accuracy_reward_mean']:.4f} ± {status_mean['accuracy_reward_std']:.4f}" # noqa
)
if all_advantages:
self.strategy.print(
f"📈 Advantages: {status_mean['advantages_mean']:.4f} ± {status_mean['advantages_std']:.4f} " # noqa
f"(min={status_mean['advantages_min']:.4f}, max={status_mean['advantages_max']:.4f})"
)
if all_returns:
self.strategy.print(
f"💰 Returns: {status_mean['returns_mean']:.4f} ± {status_mean['returns_std']:.4f}"
)
if all_response_lengths:
self.strategy.print(
f"📏 Response Length: {status_mean['response_length_mean']:.1f} ± {status_mean['response_length_std']:.1f} tokens" # noqa
)
self.strategy.print("=" * 60 + "\n")
torch.cuda.empty_cache()
self.strategy.maybe_offload_optimizer(self.actor_optim)
torch.cuda.synchronize()
torch.cuda.empty_cache()
self.strategy.print(f"PPO Train TIMECOST {time.time() - train_begin}")
self.strategy.report_memory("after train, opt offloaded, before update weights")
self.strategy.print(torch.cuda.memory_summary())
self.strategy.update_engine_weights(self.actor)
# Save trajectories at the end of ppo_train, BEFORE replay buffer is cleared
# This ensures we have data to save when trajectory saving is enabled
if global_steps % self.args.save_steps == 0:
self.save_trajectories(global_steps)
return status_mean
def save_trajectories(self, global_step: int):
"""
Save experience trajectories if trajectory saving is enabled.
This method is called during checkpoint saving to store sample trajectories
for debugging and analysis purposes. If trajectory analysis is enabled,
it also logs statistics to wandb.
:param global_step: Current global training step
:type global_step: int
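        Example::

            # Minimal sketch: assumes trajectory saving was enabled via args and the replay
            # buffer still holds the experiences from the current rollout.
            trainer.save_trajectories(global_step=100)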
"""
if self.trajectory_saver is not None and self.replay_buffer.items:
# Check if trajectory analysis is enabled
output_path, stats = self.trajectory_saver.save_trajectories(
experiences=self.replay_buffer.items,
step=global_step,
num_samples=self.num_trajectories_to_save,
prefix="trajectories",
compute_stats=self.args.trajectory_analysis
)
# Log statistics to wandb if available
if stats and self.args.trajectory_analysis and hasattr(self, 'strategy') and self.strategy.is_rank_0():
# Try to get wandb from strategy or parent class
if hasattr(self.strategy, 'args') and self.strategy.args.use_wandb:
try:
import wandb
if wandb.run is not None:
# Prefix with train/ for consistency
wandb_stats = {f"train/{k}": v for k, v in stats.items()}
wandb_stats["train/global_step"] = global_step
wandb.log(wandb_stats, step=global_step)
except (ImportError, AttributeError):
pass
class SPMDPPOTrainer(SPMDPPOTrainerBase, PPOTrainer):
"""
PPOTrainer for SPMD on Large Language Models and Multi-modal Large Language Models.
This class combines the SPMD (Single Program Multiple Data) base functionality with the
standard PPOTrainer for efficient distributed training of large language models (LLMs)
and multi-modal large language models (MLLMs). It supports training across multiple
devices and nodes with optimized communication patterns for both text-only and
multi-modal reinforcement learning scenarios.
The trainer provides:
- Distributed PPO training with tensor parallelism
- Efficient experience collection using FastExperienceMaker
- Memory-optimized training loops
- Support for various loss aggregation modes
- Integration with vLLM inference engine
Example::
trainer = SPMDPPOTrainer(
strategy=my_strategy,
actor=actor_model,
critic=critic_model,
reward_model=reward_model,
initial_model=reference_model,
ema_model=ema_model,
actor_optim=actor_optimizer,
critic_optim=critic_optimizer,
actor_scheduler=actor_scheduler,
critic_scheduler=critic_scheduler,
tokenizer=tokenizer,
# Additional PPO parameters
max_epochs=5,
micro_train_batch_size=16
)
# Train for multiple iterations
for step in range(training_steps):
trainer.make_experience()
metrics = trainer.ppo_train(step)
"""
    def __init__(
self,
*args,
**kwargs,
):
"""
Initialize the SPMD PPO Trainer for language models.
Creates a trainer instance optimized for distributed training of language models
using SPMD execution patterns. Inherits from both SPMDPPOTrainerBase and PPOTrainer
to combine SPMD optimizations with standard PPO functionality.
:param args: Positional arguments passed to the parent PPOTrainer including strategy, actor, critic,
reward_model, initial_model, ema_model, actor_optim, critic_optim, actor_scheduler,
critic_scheduler.
:type args: tuple
:param kwargs: Keyword arguments for configuration including training hyperparameters like max_epochs,
micro_train_batch_size, eps_clip, value_clip, etc.
:type kwargs: dict
Example::
trainer = SPMDPPOTrainer(
strategy,
actor_model,
critic_model,
reward_model,
reference_model,
ema_model,
actor_optimizer,
critic_optimizer,
actor_scheduler,
critic_scheduler,
tokenizer=my_tokenizer,
loss_agg_mode="seq-mean-token-mean",
packing_samples=True,
max_epochs=5,
micro_train_batch_size=16
)
"""
# First initialize the PPOTrainer parent
PPOTrainer.__init__(self, *args, **kwargs)
# Then initialize our base class
SPMDPPOTrainerBase.__init__(self, *args, VLM=False, **kwargs)
class SPMDPPOTrainerVL(SPMDPPOTrainerBase, PPOTrainerVL):
"""
PPOTrainer for SPMD with Vision-Language Models (VLM).
This class combines the SPMD base functionality with the VLM-specific PPOTrainer
for efficient distributed training of vision-language models. It extends the standard
VLM training capabilities with SPMD optimizations for better performance across
multiple devices.
Key features for VLM training:
- Multi-modal experience collection and processing
- Vision-language specific batch creation
- Processor integration for image and text handling
- Optimized memory management for large multi-modal models
Example::
trainer = SPMDPPOTrainerVL(
strategy=my_strategy,
actor=actor_model,
critic=critic_model,
reward_model=reward_model,
initial_model=reference_model,
ema_model=ema_model,
actor_optim=actor_optimizer,
critic_optim=critic_optimizer,
actor_scheduler=actor_scheduler,
critic_scheduler=critic_scheduler,
tokenizer=tokenizer,
processor=image_processor, # Required for VLM
# Additional PPO parameters
max_epochs=5,
micro_train_batch_size=16
)
# Train for multiple iterations
for step in range(training_steps):
trainer.make_experience()
metrics = trainer.ppo_train(step)
"""
    def __init__(
self,
*args,
**kwargs,
):
"""
Initialize the SPMD PPO Trainer for vision-language models.
Creates a trainer instance specifically designed for distributed training of
vision-language models using SPMD execution patterns. Requires a processor
for handling multi-modal inputs.
:param args: Positional arguments passed to the parent PPOTrainerVL including strategy, actor, critic,
reward_model, initial_model, ema_model, actor_optim, critic_optim, actor_scheduler,
critic_scheduler.
:type args: tuple
:param kwargs: Keyword arguments for configuration, must include 'processor' for image processing along
with other training parameters.
:type kwargs: dict
:raises AssertionError: If processor is not provided or is None.
Example::
trainer = SPMDPPOTrainerVL(
strategy,
vlm_actor,
vlm_critic,
vlm_reward_model,
vlm_reference,
vlm_ema_model,
actor_optimizer,
critic_optimizer,
actor_scheduler,
critic_scheduler,
tokenizer=my_tokenizer,
processor=my_image_processor, # Required!
loss_agg_mode="seq-mean-token-mean",
max_epochs=5,
micro_train_batch_size=8
)
"""
# First initialize the PPOTrainerVL parent
PPOTrainerVL.__init__(self, *args, **kwargs)
# Then initialize our base class
assert "processor" in kwargs and kwargs["processor"] is not None, "processor is required for SPMDPPOTrainerVL"
SPMDPPOTrainerBase.__init__(self, *args, VLM=True, **kwargs)