Source code for lightrft.trainer.ppo_trainer_vl
import os
import sys
import os.path
from abc import ABC
from typing import Any, Callable, Dict, List, Optional
import torch
import math
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm
from lightrft.models import ActorVL, GPTLMLoss, PolicyLoss, ValueLoss
from lightrft.models.actor_modality import ActorModality, get_supported_parameters
from lightrft.models.utils import masked_mean, unpacking_samples, compute_approx_kl
from lightrft.utils.distributed_sampler import DistributedSampler
from lightrft.trainer import AdaptiveKLController, ExperienceVL, FixedKLController, NaiveExperienceMakerVL, NaiveReplayBufferVL # noqa
[docs]class PPOTrainerVL(ABC):
"""
Trainer for Proximal Policy Optimization (PPO) algorithm for Vision-Language Models.
:param strategy: The training strategy to use.
:type strategy: Strategy
:param actor: The actor model in the PPO algorithm.
:type actor: ActorVL
:param critic: The critic model in the PPO algorithm.
:type critic: nn.Module
:param reward_model: The reward model for calculating rewards in the RLHF setup.
:type reward_model: nn.Module
:param initial_model: The initial model for reference logits to limit actor updates in RLHF.
:type initial_model: ActorVL
:param ema_model: The exponential moving average model for stable training.
:type ema_model: ActorVL
:param actor_optim: The optimizer for the actor model.
:type actor_optim: Optimizer
:param critic_optim: The optimizer for the critic model.
:type critic_optim: Optimizer
:param actor_scheduler: The learning rate scheduler for the actor.
:type actor_scheduler: Scheduler
:param critic_scheduler: The learning rate scheduler for the critic.
:type critic_scheduler: Scheduler
:param ema_beta: EMA decay rate for model stability, defaults to 0.992.
:type ema_beta: float
:param init_kl_coef: Initial coefficient for KL divergence, defaults to 0.001.
:type init_kl_coef: float
:param kl_target: Target value for KL divergence, defaults to None.
:type kl_target: float, optional
:param kl_horizon: Horizon for KL annealing, defaults to 10000.
:type kl_horizon: int
:param ptx_coef: Coefficient for supervised loss from pre-trained data, defaults to 0.
:type ptx_coef: float
:param micro_train_batch_size: Micro-batch size for actor training, defaults to 8.
:type micro_train_batch_size: int
:param buffer_limit: Maximum size of the replay buffer, defaults to 0.
:type buffer_limit: int
:param buffer_cpu_offload: If True, offloads replay buffer to CPU, defaults to True.
:type buffer_cpu_offload: bool
:param eps_clip: Clipping coefficient for policy loss, defaults to 0.2.
:type eps_clip: float
:param value_clip: Clipping coefficient for value function loss, defaults to 0.2.
:type value_clip: float
:param micro_rollout_batch_size: Micro-batch size for generating rollouts, defaults to 8.
:type micro_rollout_batch_size: int
:param gradient_checkpointing: If True, enables gradient checkpointing, defaults to False.
:type gradient_checkpointing: bool
:param max_epochs: Number of epochs to train, defaults to 1.
:type max_epochs: int
:param max_norm: Maximum gradient norm for gradient clipping, defaults to 1.0.
:type max_norm: float
:param tokenizer: Tokenizer for input data, defaults to None.
:type tokenizer: Callable, optional
:param processor: Processor for multimodal input data, defaults to None.
:type processor: Callable, optional
:param prompt_max_len: Maximum length for prompts, defaults to 128.
:type prompt_max_len: int
:param dataloader_pin_memory: If True, pins memory in the data loader, defaults to True.
:type dataloader_pin_memory: bool
:param remote_rm_url: URL for remote reward model API, defaults to None.
:type remote_rm_url: str, optional
:param reward_fn: Custom reward function for computing rewards, defaults to None.
:type reward_fn: Callable, optional
:param reward_fn_label_map: Label mapping for reward function, defaults to None.
:type reward_fn_label_map: dict, optional
:param reward_recipe: Recipe configuration for reward computation, defaults to None.
:type reward_recipe: dict, optional
:param save_hf_ckpt: Whether to also save the model weights in Hugging Face format, defaults to False.
:type save_hf_ckpt: bool
:param disable_ds_ckpt: If True, skip saving the DeepSpeed-format checkpoint (the one used for training recovery), defaults to False.
:type disable_ds_ckpt: bool
:param generate_kwargs: Additional arguments for model generation.
:type generate_kwargs: dict
"""
def __init__(
    self,
    strategy,
    actor: ActorVL,
    critic: nn.Module,
    reward_model: nn.Module,
    initial_model: ActorVL,
    ema_model: ActorVL,
    actor_optim: Optimizer,
    critic_optim: Optimizer,
    actor_scheduler,
    critic_scheduler,
    ema_beta: float = 0.992,
    init_kl_coef: float = 0.001,
    kl_target: float = None,
    kl_horizon: int = 10000,
    ptx_coef: float = 0,
    micro_train_batch_size: int = 8,
    buffer_limit: int = 0,
    buffer_cpu_offload: bool = True,
    eps_clip: float = 0.2,
    value_clip: float = 0.2,
    micro_rollout_batch_size: int = 8,
    gradient_checkpointing: bool = False,
    max_epochs: int = 1,
    max_norm: float = 1.0,
    tokenizer: Optional[Callable[[Any], dict]] = None,
    processor: Optional[Callable[[Any], dict]] = None,
    prompt_max_len: int = 128,
    dataloader_pin_memory: bool = True,
    remote_rm_url: str = None,
    reward_fn: Callable[[List[torch.Tensor]], torch.Tensor] = None,
    reward_fn_label_map: dict = None,
    reward_recipe: dict = None,
    save_hf_ckpt: bool = False,
    disable_ds_ckpt: bool = False,
    **generate_kwargs,
) -> None:
    """
    Store the models, optimizers and hyper-parameters, then build the loss functions,
    the KL controller, the experience maker, the replay buffer and (on rank 0 only)
    the optional wandb / TensorBoard loggers.

    The parameters are described in the class docstring.

    :raises AssertionError: If ``reward_model`` is a list with more than one entry and
        no ``reward_fn`` is given to combine the rewards.
    """
    # Several reward models can only be combined through an explicit reward_fn.
    assert (
        not isinstance(reward_model, list) or len(reward_model) == 1 or reward_fn is not None
    ), "reward_fn must be specified if using multiple reward models"
    ABC.__init__(self)

    self.strategy = strategy
    self.args = strategy.args
    self.save_hf_ckpt = save_hf_ckpt

    # Debug breadcrumb: report which file/line constructed the trainer.
    current_filename = os.path.basename(__file__)
    current_lineno = sys._getframe().f_lineno
    self.strategy.print(f"[{current_filename}:{current_lineno}]")

    self.disable_ds_ckpt = disable_ds_ckpt
    self.micro_rollout_batch_size = micro_rollout_batch_size
    self.max_epochs = max_epochs
    self.tokenizer = tokenizer
    self.processor = processor
    self.generate_kwargs = generate_kwargs
    self.dataloader_pin_memory = dataloader_pin_memory
    self.max_norm = max_norm
    self.ptx_coef = ptx_coef
    self.micro_train_batch_size = micro_train_batch_size
    self.kl_target = kl_target
    self.prompt_max_len = prompt_max_len
    self.ema_beta = ema_beta
    self.gradient_checkpointing = gradient_checkpointing
    self.reward_fn = reward_fn
    self.reward_fn_label_map = reward_fn_label_map
    self.reward_recipe = reward_recipe

    self.actor = actor
    self.critic = critic
    self.reward_model = reward_model
    self.remote_rm_url = remote_rm_url
    self.initial_model = initial_model
    self.ema_model = ema_model
    self.actor_optim = actor_optim
    self.critic_optim = critic_optim
    self.actor_scheduler = actor_scheduler
    self.critic_scheduler = critic_scheduler

    # Cache the forward-pass parameters the actor supports, based on its modality.
    # Actors that do not expose a `modality` attribute are treated as vision-language
    # models for backward compatibility. The fallback is resolved lazily so actors that
    # do define `modality` behave exactly as before.
    actor_modality = getattr(self.actor, "modality", None)
    if actor_modality is None:
        actor_modality = ActorModality.VISION_LANGUAGE
    self._actor_supported_params = get_supported_parameters(actor_modality)

    self.actor_loss_fn = PolicyLoss(eps_clip, use_cpg_loss=self.args.use_cpg_loss)
    self.critic_loss_fn = ValueLoss(value_clip)
    self.ptx_loss_fn = GPTLMLoss()

    # The actor is only trained once global_steps exceeds this value (see training_step).
    self.freezing_actor_steps = getattr(self.args, "freezing_actor_steps", -1)
    # The auxiliary loss is only added when its coefficient is meaningfully non-zero.
    self.aux_loss = self.args.aux_loss_coef > 1e-8

    # A KL target selects the adaptive controller; otherwise the coefficient stays fixed.
    if self.kl_target:
        self.kl_ctl = AdaptiveKLController(init_kl_coef, kl_target, kl_horizon)
    else:
        self.kl_ctl = FixedKLController(init_kl_coef)

    self.experience_maker = NaiveExperienceMakerVL(
        actor,
        critic,
        reward_model,
        initial_model,
        tokenizer,
        processor,
        prompt_max_len,
        self.kl_ctl,
        strategy,
        remote_rm_url,
        reward_fn,
    )
    packing_samples = getattr(self.args, "packing_samples", False)
    self.replay_buffer = NaiveReplayBufferVL(
        micro_train_batch_size, buffer_limit, buffer_cpu_offload, packing_samples
    )

    # Logging backends: both stay None unless enabled and running on rank 0.
    self._wandb = None
    self._tensorboard = None
    self.eval_step_counter = 0  # Independent counter for eval X-axis
    self.wandb_log_counter = 0  # Global counter for unique wandb system steps

    if self.strategy.args.use_wandb and self.strategy.is_rank_0():
        import wandb

        self._wandb = wandb
        if not wandb.api.api_key:
            # `use_wandb` doubles as the API key when no key is configured yet.
            wandb.login(key=strategy.args.use_wandb)
        wandb.init(
            entity=strategy.args.wandb_org,
            project=strategy.args.wandb_project,
            group=strategy.args.wandb_group,
            name=strategy.args.wandb_run_name,
            config=strategy.args.__dict__,
            reinit=True,
        )
        # Define custom metrics to allow different X-axes:
        # rollout/* and train/* are plotted against their own global_step metric.
        wandb.define_metric("rollout/global_step")
        wandb.define_metric("rollout/*", step_metric="rollout/global_step")
        wandb.define_metric("train/global_step")
        wandb.define_metric("train/*", step_metric="train/global_step")
        # eval/* uses its own counter so sparse evaluations still plot sequentially.
        wandb.define_metric("eval/global_step")
        wandb.define_metric("eval/*", step_metric="eval/global_step")

    # TensorBoard is only used as a fallback when wandb is not active.
    if self.strategy.args.use_tensorboard and self._wandb is None and self.strategy.is_rank_0():
        from torch.utils.tensorboard import SummaryWriter

        os.makedirs(self.strategy.args.use_tensorboard, exist_ok=True)
        log_dir = os.path.join(self.strategy.args.use_tensorboard, strategy.args.wandb_run_name)
        self._tensorboard = SummaryWriter(log_dir=log_dir)
def fit(
    self,
    args,
    prompts_dataloader,
    pretrain_dataloader,
    eval_dataloader=None,
    consumed_samples=0,
    num_update_steps_per_episodes=1,
) -> None:
    """
    Main training loop for PPO.

    For every prompt batch the loop generates experiences into the replay buffer,
    aggregates rollout statistics, runs :meth:`ppo_train`, clears the buffer, updates
    the KL controller and passes the merged rollout + training metrics to
    ``save_logs_and_checkpoints``.

    :param args: Training arguments.
    :type args: Namespace
    :param prompts_dataloader: DataLoader for prompt data.
    :type prompts_dataloader: DataLoader
    :param pretrain_dataloader: DataLoader for pre-training data.
    :type pretrain_dataloader: DataLoader
    :param eval_dataloader: DataLoader for evaluation data, defaults to None.
    :type eval_dataloader: DataLoader, optional
    :param consumed_samples: Number of samples already consumed, defaults to 0.
    :type consumed_samples: int
    :param num_update_steps_per_episodes: Number of update steps per episode, defaults to 1.
    :type num_update_steps_per_episodes: int
    """
    # Calculate samples per rollout and per training iteration
    samples_per_rollout = args.rollout_batch_size * args.n_samples_per_prompt
    samples_per_train = args.train_batch_size * args.n_samples_per_prompt

    # Print training mode information
    if args.train_batch_size < args.rollout_batch_size:
        updates_per_rollout = samples_per_rollout / samples_per_train
        self.strategy.print(
            f"\n{'=' * 80}\n"
            f"HIGH FREQUENCY UPDATE MODE: train_batch_size ({args.train_batch_size}) < rollout_batch_size ({args.rollout_batch_size})\n"  # noqa
            f"{'=' * 80}\n"
            f"Behavior:\n"
            f" - Each rollout generates {samples_per_rollout} samples.\n"
            f" - Each rollout will trigger {updates_per_rollout:.2f} optimizer updates.\n"
            f" - Total updates will be HIGHER than standard mode for the same amount of data.\n"
            f"{'=' * 80}\n"
        )
    elif args.train_batch_size > args.rollout_batch_size:
        self.strategy.print(
            f"\n{'=' * 80}\n"
            f"ACCUMULATION MODE: train_batch_size ({args.train_batch_size}) > rollout_batch_size ({args.rollout_batch_size})\n"  # noqa
            f"{'=' * 80}\n"
            f"Behavior:\n"
            f" - Multiple rollouts needed for one update.\n"
            f"{'=' * 80}\n"
        )

    # Number of rollouts per episode = "samples planned for the episode" / "samples per rollout".
    # Numerator: num_update_steps * train_batch_size; denominator: rollout_batch_size * n_samples.
    # Halving train_batch_size doubles num_update_steps, so the rollout count stays constant
    # while only the number of inner optimizer updates changes.
    num_rollouts_per_episodes = (
        num_update_steps_per_episodes * args.train_batch_size // args.max_epochs // args.rollout_batch_size //
        args.n_samples_per_prompt
    )
    # Safeguard: integer division may round a small fractional value down to 0.
    if num_rollouts_per_episodes == 0:
        val = (num_update_steps_per_episodes *
               args.train_batch_size) / (args.max_epochs * args.rollout_batch_size * args.n_samples_per_prompt)
        num_rollouts_per_episodes = math.ceil(val)
        if num_rollouts_per_episodes == 0:
            self.strategy.print("[WARNING] Calculated num_rollouts_per_episodes is 0. Forcing to 1.")
            num_rollouts_per_episodes = 1

    # Resolve the "-1" sentinels for eval and save steps
    if args.eval_steps == -1:
        args.eval_steps = num_rollouts_per_episodes  # Evaluate once per epoch
    if args.save_steps == -1:
        args.save_steps = float("inf")  # Do not save checkpoint

    self.prompts_dataloader = prompts_dataloader
    self.pretrain_dataloader = pretrain_dataloader
    self.eval_dataloader = eval_dataloader  # Save for evaluation

    # Restore step and start_episode from the number of already-consumed samples
    steps = consumed_samples // args.rollout_batch_size + 1
    start_episode = consumed_samples // args.rollout_batch_size // num_rollouts_per_episodes
    consumed_samples = consumed_samples % (num_rollouts_per_episodes * args.rollout_batch_size)

    for episode in range(start_episode, args.num_episodes):
        if isinstance(self.prompts_dataloader.sampler, DistributedSampler):
            # Only the resumed episode skips already-consumed samples.
            self.prompts_dataloader.sampler.set_epoch(
                episode, consumed_samples=0 if episode > start_episode else consumed_samples
            )
        pbar = tqdm(
            range(self.prompts_dataloader.__len__()),
            desc=f"Episode [{episode + 1}/{args.num_episodes}]",
            disable=not self.strategy.is_rank_0(),
        )

        for batch in self.prompts_dataloader:
            # Compatible with both image-only (4 fields) and video (5 fields) dataloaders
            if len(batch) == 5:
                rand_prompts, rand_images, rand_videos, rand_references, rand_labels = batch
            else:
                rand_prompts, rand_images, rand_references, rand_labels = batch
                rand_videos = None

            # TODO: Remove debug print
            self.strategy.print(
                f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n "  # noqa
            )

            experiences = self.experience_maker.make_experience_list(
                rand_prompts,
                rand_images,
                all_videos=rand_videos,
                all_references=rand_references,
                all_labels=rand_labels,
                **self.generate_kwargs
            )
            for i, experience in enumerate(experiences):
                if i == 0:
                    # Log one decoded sample per batch as a sanity check.
                    output = self.tokenizer.batch_decode(
                        experience.sequences[0].unsqueeze(0), skip_special_tokens=True
                    )
                    self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output)
                    self.strategy.print(
                        f"collect phase: rand_prompts:\n {rand_prompts[0:2]}\n , rand_images:{rand_images[0:2]}\n , rand_references:{rand_references[0:2]}\n, rand_labels:{rand_labels[0:2]}\n "  # noqa
                    )
                self.replay_buffer.append(experience)

            self.strategy.report_memory('after replay_buffer ready')

            # Metrics of the rollout/collection phase, gathered before the buffer is cleared.
            rollout_status = self._collect_rollout_status()

            # TODO: Check normalization behavior
            if self.args.advantage_estimator != "group_norm":
                self.replay_buffer.normalize("advantages", self.strategy)

            self.strategy.report_memory('before train')
            status = self.ppo_train(steps)
            self.strategy.report_memory('before clear buffer')
            self.replay_buffer.clear()
            self.strategy.report_memory('after train')

            if "kl" in status:
                self.kl_ctl.update(status["kl"], args.rollout_batch_size * args.n_samples_per_prompt)

            # The episode progress bar shows ROLLOUT statistics (not training statistics).
            pbar.set_postfix(rollout_status)

            # Log both rollout and training statistics; training keys win on collision.
            client_states = {"consumed_samples": steps * args.rollout_batch_size}
            logs_dict_combined = {**rollout_status, **status}
            self.save_logs_and_checkpoints(args, steps, pbar, logs_dict_combined, client_states, episode=episode)

            pbar.update()
            steps = steps + 1

        # Release the per-episode progress bar instead of leaving it open.
        pbar.close()

    if self._wandb is not None and self.strategy.is_rank_0():
        self._wandb.finish()
    if self._tensorboard is not None and self.strategy.is_rank_0():
        self._tensorboard.close()

def _collect_rollout_status(self) -> Dict[str, float]:
    """
    Aggregate reward and response-length statistics over the current replay buffer.

    :return: Dictionary with the ``rollout_*`` metrics that could be computed; empty
        when the replay buffer holds no items.
    :rtype: Dict[str, float]
    """
    rollout_status = {}
    items = self.replay_buffer.items
    if not items:
        return rollout_status

    all_rewards = []
    all_format_rewards = []
    all_accuracy_rewards = []
    all_response_lengths = []
    for item in items:
        info = getattr(item, 'info', None)
        if info is None:
            continue
        if 'reward' in info:
            all_rewards.append(info['reward'])
        # 'reward_metrics' may be absent or explicitly None; both mean "no sub-metrics".
        reward_metrics = info['reward_metrics'] if 'reward_metrics' in info else None
        if reward_metrics is not None:
            if 'format_reward' in reward_metrics:
                all_format_rewards.append(reward_metrics['format_reward'])
            if 'accuracy_reward' in reward_metrics:
                all_accuracy_rewards.append(reward_metrics['accuracy_reward'])
        if 'response_length' in info:
            all_response_lengths.append(info['response_length'])

    device = torch.cuda.current_device()

    if all_rewards:
        rewards_tensor = self._rollout_values_to_tensor(all_rewards, device)
        rollout_status["rollout_reward"] = rewards_tensor.mean().item()
        # The (unbiased) std of a single sample is NaN; report 0.0 instead.
        rollout_status["rollout_reward_std"] = (
            rewards_tensor.std().item() if rewards_tensor.numel() > 1 else 0.0
        )

    if all_format_rewards:
        mean_format_reward = self._rollout_values_to_tensor(all_format_rewards, device).mean().item()
        # Only display if mean is significantly non-zero
        if abs(mean_format_reward) > 1e-6:
            rollout_status["rollout_format_reward"] = mean_format_reward

    if all_accuracy_rewards:
        mean_accuracy_reward = self._rollout_values_to_tensor(all_accuracy_rewards, device).mean().item()
        # Only display if mean is significantly non-zero
        if abs(mean_accuracy_reward) > 1e-6:
            rollout_status["rollout_accuracy_reward"] = mean_accuracy_reward

    if all_response_lengths:
        lengths_tensor = self._rollout_values_to_tensor(all_response_lengths, device)
        rollout_status["rollout_response_length"] = lengths_tensor.mean().item()

    return rollout_status

@staticmethod
def _rollout_values_to_tensor(values, device) -> torch.Tensor:
    """
    Convert a non-empty list of per-item metrics into one flat float32 tensor on ``device``.

    ``torch.tensor`` cannot convert a list of tensors, so tensor entries are flattened and
    concatenated (flattening also makes 0-dim tensors acceptable to ``torch.cat``);
    plain Python scalars are converted directly.
    """
    if isinstance(values[0], torch.Tensor):
        return torch.cat([t.to(device).float().reshape(-1) for t in values])
    return torch.tensor(values, dtype=torch.float32, device=device)
def ppo_train(self, global_steps=0):
    """
    PPO training loop over the replay buffer.

    NOTE: This method is not used directly in the main trainer,
    as it's overridden by external classes (e.g., lightrft/trainer/spmd_ppo_trainer.py).

    :param global_steps: Current global step count, defaults to 0.
    :type global_steps: int
    :return: Dictionary of training statistics, each averaged over the training steps
        that reported it; empty if no step was executed.
    :rtype: dict
    """
    torch.cuda.empty_cache()
    # Replay buffer may be empty at first, we should rebuild at each training
    dataloader = DataLoader(
        self.replay_buffer,
        batch_size=self.replay_buffer.sample_batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=self.dataloader_pin_memory,
        collate_fn=self.replay_buffer.collate_fn,
    )
    device = torch.cuda.current_device()

    status_list = []
    for epoch in range(self.max_epochs):
        pbar = tqdm(
            dataloader,
            desc=f"Train epoch [{epoch + 1}/{self.max_epochs}]",
            disable=not self.strategy.is_rank_0(),
        )
        for experience in pbar:
            experience.to_device(device)
            status = self.training_step(experience, global_steps)

            # For DP: weight KL by response length before reducing, then undo the
            # weighting, which yields a length-weighted mean of KL.
            if "kl" in status:
                status["kl"] *= status["response_length"]
                status = self.strategy.all_reduce(status)
                status["kl"] /= status["response_length"]

            # Abbreviated core metrics keep the progress bar concise.
            short_status = {}
            if "policy_loss" in status:
                short_status.update({
                    "pg": status.get("policy_loss"),
                    "rm": status.get("reward"),
                    "ret": status.get("return"),
                    "glen": status.get("response_length"),
                    "tlen": status.get("total_length"),
                    "kl": status.get("kl"),
                    "act_lr": status.get("actor_lr"),
                })
            if "critic_loss" in status:
                short_status.update({
                    "cri": status.get("critic_loss"),
                    "vals": status.get("values"),
                    "cri_lr": status.get("critic_lr"),
                })
            if "ptx_loss" in status:
                short_status["ptx"] = status.get("ptx_loss")
            # Namespaced keys ("group/name") are shown under their last path component.
            for k, v in status.items():
                if "/" in k:
                    short_status[k.split('/')[-1]] = v

            status_list.append(status)
            pbar.set_postfix(short_status)

    # Average every metric over the steps that actually reported it. Steps may report
    # different key sets (e.g. the actor step returns {} when it bails out), so sums and
    # counts are tracked per key rather than accumulating into the first status dict.
    totals = {}
    counts = {}
    for step_status in status_list:
        for k, v in step_status.items():
            totals[k] = totals[k] + v if k in totals else v
            counts[k] = counts.get(k, 0) + 1
    status_mean = {k: totals[k] / counts[k] for k in totals}

    torch.cuda.empty_cache()
    return status_mean
[docs] def training_step(self,
experience: ExperienceVL,
global_steps,
entropy_mask: Optional[torch.Tensor] = None) -> Dict[str, float]:
"""
Single training step combining actor and critic updates.
:param experience: Experience batch from replay buffer.
:type experience: ExperienceVL
:param global_steps: Current global step count.
:type global_steps: int
:param entropy_mask: Optional mask for high-entropy tokens.
:type entropy_mask: Optional[torch.Tensor]
:return: Dictionary of training statistics.
:rtype: Dict[str, float]
"""
status = {}
if global_steps > self.freezing_actor_steps:
status = self.training_step_actor(experience, entropy_mask=entropy_mask)
if self.critic is not None:
status.update(self.training_step_critic(experience))
return status
def _validate_qwen_vl_tensors(
self, sequences: torch.Tensor, pixel_values: Optional[torch.Tensor], context: str = "training"
) -> bool:
"""
Validates the consistency between image tokens in sequences and pixel_values features.
:param sequences: Token sequence tensor.
:type sequences: torch.Tensor
:param pixel_values: Processed pixel values tensor.
:type pixel_values: Optional[torch.Tensor]
:param context: A string indicating where the validation is called from (e.g., "actor_rl", "actor_ptx").
:type context: str
:return: True if data is consistent, False otherwise.
:rtype: bool
"""
if pixel_values is None or pixel_values.numel() == 0:
# This is a text-only batch, no validation needed.
return True
config = self.strategy.unwrap_model(self.actor.model).config
image_token_id = getattr(config, "image_token_id", None)
if image_token_id is None:
# Model does not use special image tokens.
return True
num_tokens = (sequences == image_token_id).sum().item()
num_patches = pixel_values.shape[0] // 4
if num_tokens != num_patches:
self.strategy.print(
f"[CRITICAL WARNING] Skipping batch in '{context}'. "
f"Image features and image tokens do not match: tokens: {num_tokens}, features: {num_patches}. "
"This batch will be discarded to prevent a crash."
)
return False
return True
def training_step_actor(
    self,
    experience: ExperienceVL,
    entropy_mask: Optional[torch.Tensor] = None,
) -> Dict[str, float]:
    """
    Actor training step: policy loss (plus optional KL and auxiliary losses), optional
    PTX loss on the pre-training data, optimizer step and EMA update.

    :param experience: Experience batch from replay buffer.
    :type experience: ExperienceVL
    :param entropy_mask: Optional mask for high-entropy tokens, forwarded to the policy loss.
    :type entropy_mask: Optional[torch.Tensor]
    :return: Dictionary of actor training statistics; empty if the batch fails the
        image-token / pixel-value consistency check.
    :rtype: Dict[str, float]
    """
    self.actor.train()

    # Stays None when KL loss is disabled or the experience carries no reference
    # log-probs; the KL branch below then falls back to a zero KL term.
    base_action_log_probs = None

    pixel_values = experience.pixel_values
    image_grid_thws = experience.image_grid_thws
    pixel_values_videos = getattr(experience, "pixel_values_videos", None)
    video_grid_thws = getattr(experience, "video_grid_thws", None)

    # TODO: This is a bad indicator to say that data is packed... not supported
    if isinstance(experience.sequences, list):
        # Packed samples: concatenate everything into a single row.
        sequences = torch.cat(experience.sequences, dim=0).unsqueeze(0)
        old_action_log_probs = torch.cat(experience.action_log_probs, dim=0).unsqueeze(0)
        advantages = torch.cat(experience.advantages, dim=0).unsqueeze(0)
        num_actions = [v.numel() for v in experience.advantages]
        packed_seq_lens = [s.numel() for s in experience.sequences]
        # Each packed sequence gets its own 1-based id in the attention mask.
        attention_mask = torch.cat([torch.full_like(s, i + 1) for i, s in enumerate(experience.sequences)],
                                   dim=0).unsqueeze(0)
        if self.args.use_kl_loss and experience.base_action_log_probs is not None:
            base_action_log_probs = torch.cat(experience.base_action_log_probs, dim=0).unsqueeze(0)
    else:
        sequences = experience.sequences
        old_action_log_probs = experience.action_log_probs
        advantages = experience.advantages
        num_actions = experience.action_mask.size(1)
        packed_seq_lens = None
        attention_mask = experience.attention_mask
        if self.args.use_kl_loss and experience.base_action_log_probs is not None:
            base_action_log_probs = experience.base_action_log_probs

    if advantages is not None:
        # Log max advantage before clipping for debugging (optional)
        max_adv = advantages.max().item()
        if max_adv > 10.0:
            self.strategy.print(f"[Warning] Huge advantage detected: {max_adv}")
        advantages = torch.clamp(advantages, min=-10.0, max=10.0)

    # [DEFENSIVE CHECK] Validate RL data before actor forward pass.
    # NOTE: This validation is primarily done in spmd_ppo_trainer.py BEFORE calling training_step
    # so all ranks make the same skip decision. This check remains as a safety fallback;
    # if it triggers, it indicates a bug in the pre-validation logic.
    if not self._validate_qwen_vl_tensors(sequences, pixel_values, context="actor_rl_update"):
        self.strategy.print(
            "[CRITICAL ERROR] Validation failed inside training_step_actor. "
            "This should have been caught by pre-validation in spmd_ppo_trainer.py!"
        )
        return {}  # Emergency fallback - should not normally execute

    # Only pass the multimodal inputs that the actor's modality supports.
    candidate_params = {
        "pixel_values": pixel_values,
        "image_grid_thw": image_grid_thws,
        "pixel_values_videos": pixel_values_videos,
        "video_grid_thw": video_grid_thws,
    }
    actor_kwargs = {key: value for key, value in candidate_params.items() if key in self._actor_supported_params}

    action_log_probs, output = self.actor(
        sequences,
        num_actions,
        attention_mask=attention_mask,
        return_output=True,
        packed_seq_lens=packed_seq_lens,
        **actor_kwargs
    )
    # NOTE: action_log_probs is deliberately NOT multiplied by the action mask here;
    # zeroing masked positions is incorrect in log-space.

    actor_loss = self.actor_loss_fn(
        action_log_probs,
        old_action_log_probs,
        advantages,
        action_mask=experience.action_mask,
        entropy_mask=entropy_mask,
    )

    if self.args.use_kl_loss:
        if self.initial_model is not None and base_action_log_probs is not None:
            # TODO(pu): Text-only action mask for KL calculation
            kl = compute_approx_kl(
                action_log_probs,
                base_action_log_probs,
                experience.action_mask,
                kl_estimator=self.args.kl_estimator,
            )
            # NOTE: per-token clamping of kl (0..20) was tried and prevented convergence.
        else:
            # No reference model or no reference log-probs: the KL term is zero.
            kl = torch.zeros_like(action_log_probs, dtype=action_log_probs.dtype, device=action_log_probs.device)

        # Same lookup as in __init__, so a missing `packing_samples` arg means "not packed".
        if not getattr(self.args, "packing_samples", False):
            kl_mean = masked_mean(kl, experience.action_mask, dim=-1)
        else:
            # Split the packed row back into per-sample tensors and average each one.
            # torch.stack (unlike torch.tensor) keeps the result attached to the autograd graph.
            kl = unpacking_samples(kl, num_actions)
            kl_mean = torch.stack([each_kl.mean() for each_kl in kl])

        kl_loss = kl_mean.mean()
        experience.info["kl"] = kl_loss.item()
    else:
        kl_loss = 0

    # Mixtral auxiliary loss
    aux_loss = output.aux_loss if self.aux_loss else 0

    loss = actor_loss + aux_loss * self.args.aux_loss_coef + kl_loss * self.kl_ctl.value
    if torch.isnan(loss) or torch.isinf(loss):
        # NOTE(review): despite the message, the update is not skipped -- backward still
        # runs below. Skipping on a single rank could desynchronize collective ops, so the
        # behavior is left unchanged; confirm the intended handling.
        self.strategy.print("[CRITICAL ERROR] Actor loss is NaN or Inf at step. Skipping update.")
        self.strategy.print(f" Actor Loss: {actor_loss.item()}")
        self.strategy.print(f" KL Loss: {kl_loss.item() if isinstance(kl_loss, torch.Tensor) else kl_loss}")
    self.strategy.backward(loss, self.actor, self.actor_optim)

    # PTX loss for supervised fine-tuning
    if self.pretrain_dataloader is not None:
        data = next(self.pretrain_dataloader)
        ptx_device = torch.cuda.current_device()
        inputs = data[1].squeeze(1).to(ptx_device)
        attention_mask = data[2].squeeze(1).to(ptx_device)
        # Padding positions are excluded from the language-modeling loss.
        label = torch.where(
            attention_mask.bool(),
            inputs,
            self.ptx_loss_fn.IGNORE_INDEX,
        )
        pixel_values = data[3].to(ptx_device)
        image_grid_thws = data[4].to(ptx_device)

        output = self.actor(
            inputs,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thws,
            return_output=True
        )
        ptx_log_probs = output["logits"]
        ptx_loss = self.ptx_loss_fn(ptx_log_probs, label)

        # Mixtral auxiliary loss
        aux_loss = output.aux_loss if self.aux_loss else 0
        loss = ptx_loss + aux_loss * self.args.aux_loss_coef
        self.strategy.backward(self.ptx_coef * loss, self.actor, self.actor_optim)

    self.strategy.optimizer_step(self.actor_optim, self.actor, self.actor_scheduler, name="actor")
    if self.ema_model is not None:
        self.strategy.moving_average(self.actor, self.ema_model, self.ema_beta, "cuda")

    # Status
    status = {"policy_loss": actor_loss.item(), "actor_lr": self.actor_scheduler.get_last_lr()[0]}
    if self.pretrain_dataloader is not None:
        status["ptx_loss"] = ptx_loss.item()

    # Add ratio and loss component statistics from PolicyLoss for diagnosis
    if hasattr(self.actor_loss_fn, 'get_last_stats'):
        status.update(self.actor_loss_fn.get_last_stats())

    # Copy experience.info into status, reducing each supported value type to a scalar.
    # All metrics stay in status for internal use (KL weighting, progress bar).
    for k, v in experience.info.items():
        if k == "kl":
            if isinstance(v, torch.Tensor):
                # Assumes 'v' holds one KL value per batch item: weight by response length.
                lengths = experience.info["response_length"]
                status[k] = ((v * lengths).sum() / lengths.sum()).item()
            else:
                # Already a scalar float
                status[k] = v
            continue

        # Nested dictionaries like 'reward_metrics' are flattened to "<key>/<sub_key>".
        if isinstance(v, dict):
            for sub_k, sub_v in v.items():
                log_key = f"{k}/{sub_k}"
                if isinstance(sub_v, torch.Tensor):
                    status[log_key] = sub_v.mean().item()
                elif isinstance(sub_v, list) and sub_v and isinstance(sub_v[0], (int, float)):
                    status[log_key] = sum(sub_v) / len(sub_v)
                elif isinstance(sub_v, (int, float)):
                    status[log_key] = sub_v
            continue

        if isinstance(v, torch.Tensor):
            status[k] = v.float().mean().item()
        elif isinstance(v, list):
            # Only numeric lists can be averaged; lists of strings or dicts are skipped.
            if v and isinstance(v[0], (int, float)):
                status[k] = sum(v) / len(v)
        elif isinstance(v, (int, float)):
            status[k] = v

    return status
def training_step_critic(self, experience: ExperienceVL) -> Dict[str, float]:
    """
    Run one optimization step of the critic on a single experience batch.

    The critic is put in train mode, the multimodal inputs, sequences and attention
    mask are placed on the current CUDA device, the critic loss (plus the optional
    auxiliary loss) is back-propagated through ``self.strategy`` and the critic
    optimizer/scheduler are stepped.

    :param experience: Experience batch from replay buffer.
    :type experience: ExperienceVL
    :return: Dictionary with ``critic_loss``, ``values`` and ``critic_lr``.
    :rtype: Dict[str, float]
    """
    self.critic.train()

    # Index of the CUDA device every input tensor has to live on.
    cuda_index = torch.cuda.current_device()

    def to_current_gpu(t):
        """
        Return ``t`` on the current CUDA device with a contiguous memory layout.

        :param t: Tensor to place, or None.
        :type t: torch.Tensor or None
        :return: The placed tensor, or None when ``t`` is None.
        :rtype: torch.Tensor or None
        """
        if t is None:
            return None
        already_placed = t.device.type == 'cuda' and t.device.index == cuda_index
        if not already_placed:
            t = t.to(cuda_index)
        # The original comment says contiguity is required by Triton kernels.
        return t if t.is_contiguous() else t.contiguous()

    # Multimodal inputs; the video fields are optional attributes of the experience.
    pixel_values = to_current_gpu(experience.pixel_values)
    image_grid_thws = to_current_gpu(experience.image_grid_thws)
    pixel_values_videos = to_current_gpu(getattr(experience, "pixel_values_videos", None))
    video_grid_thws = to_current_gpu(getattr(experience, "video_grid_thws", None))

    # TODO: This is a bad indicator to say that data is packed...
    is_packed = isinstance(experience.sequences, list)
    if not is_packed:
        sequences = experience.sequences
        old_values = experience.values
        returns = experience.returns
        num_actions = experience.action_mask.size(1)
        packed_seq_lens = None
        attention_mask = experience.attention_mask
    else:
        chunks = experience.sequences
        sequences = torch.cat(chunks, dim=0).unsqueeze(0)
        old_values = torch.cat(experience.values, dim=0).unsqueeze(0)
        returns = torch.cat(experience.returns, dim=0).unsqueeze(0)
        num_actions = [adv.numel() for adv in experience.advantages]
        packed_seq_lens = [seq.numel() for seq in chunks]
        # Each packed sample is tagged with its 1-based position in the pack.
        segment_ids = [torch.full_like(seq, idx) for idx, seq in enumerate(chunks, start=1)]
        attention_mask = torch.cat(segment_ids, dim=0).unsqueeze(0)

    sequences = to_current_gpu(sequences)
    attention_mask = to_current_gpu(attention_mask)

    # Critic forward pass
    values, output = self.critic(
        sequences,
        num_actions=num_actions,
        attention_mask=attention_mask,
        pixel_values=pixel_values,
        image_grid_thw=image_grid_thws,
        pixel_values_videos=pixel_values_videos,
        video_grid_thw=video_grid_thws,
        return_output=True,
        packed_seq_lens=packed_seq_lens,
    )

    critic_loss = self.critic_loss_fn(values, old_values, returns, action_mask=experience.action_mask)

    # Mixtral auxiliary loss (only read from the model output when enabled)
    aux_loss = output.aux_loss if self.aux_loss else 0
    loss = critic_loss + aux_loss * self.args.aux_loss_coef

    self.strategy.backward(loss, self.critic, self.critic_optim)
    self.strategy.optimizer_step(self.critic_optim, self.critic, self.critic_scheduler, name="critic")

    return {
        "critic_loss": critic_loss.item(),
        "values": masked_mean(values, experience.action_mask).item(),
        "critic_lr": self.critic_scheduler.get_last_lr()[0],
    }
def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict=None, client_states=None, episode=0):
    """
    Save logs to wandb/tensorboard, run periodic evaluation and save model checkpoints.

    Each of the three phases runs only when ``global_step`` is a multiple of the
    matching interval in ``args`` (``logging_steps``, ``eval_steps``, ``save_steps``).

    :param args: Training arguments.
    :type args: Namespace
    :param global_step: Current global step.
    :type global_step: int
    :param step_bar: Progress bar object (not used by this method).
    :type step_bar: tqdm
    :param logs_dict: Dictionary of metrics to log. Should contain both:
        - Rollout statistics (rollout_reward, rollout_response_length, etc.)
          from inference/generation phase
        - Training statistics (policy_loss, critic_loss, kl, etc.)
          from optimization phase
        ``None`` (the default) is treated as an empty dict.
    :type logs_dict: dict or None
    :param client_states: Client state for checkpoint recovery. ``None`` (the default)
        is treated as an empty dict.
    :type client_states: dict or None
    :param episode: Current episode number, defaults to 0.
    :type episode: int
    """
    # Build fresh containers per call: a literal ``{}`` default is one shared object
    # that any downstream mutation would leak into later calls.
    logs_dict = {} if logs_dict is None else logs_dict
    client_states = {} if client_states is None else client_states

    def strip_prefix(key, prefix):
        # Remove only the leading prefix; str.replace without a count would also
        # eat later occurrences (e.g. "eval_retrieval_x" -> "retrix").
        return key[len(prefix):] if key.startswith(prefix) else key

    # 1. LOGGING TRAIN & ROLLOUT METRICS
    if global_step % args.logging_steps == 0:
        # Metrics that are already logged in rollout/ namespace should not be duplicated in train/
        ROLLOUT_ONLY_METRICS = {'reward', 'response_length', 'total_length', 'num_actions', 'return'}
        ROLLOUT_ONLY_PREFIXES = ('reward_metrics/', )

        rollout_metrics = {}
        train_metrics = {}
        for k, v in logs_dict.items():
            if k.startswith('rollout_'):
                # Clean key: rollout_reward -> reward
                rollout_metrics[strip_prefix(k, 'rollout_')] = v
            elif k in ROLLOUT_ONLY_METRICS or k.startswith(ROLLOUT_ONLY_PREFIXES):
                continue
            else:
                # Everything else is considered a training metric
                train_metrics[k] = v

        # Wandb Logging
        if self._wandb is not None and self.strategy.is_rank_0():
            all_wandb_logs = {}

            # Add Rollout Metrics
            for k, v in rollout_metrics.items():
                all_wandb_logs[f"rollout/{k}"] = v
            all_wandb_logs["rollout/global_step"] = global_step
            all_wandb_logs["rollout/episode"] = episode

            # Add Train Metrics
            for k, v in train_metrics.items():
                all_wandb_logs[f"train/{k}"] = v
            all_wandb_logs["train/global_step"] = global_step
            all_wandb_logs["train/episode"] = episode

            # Performance Stats
            if self.experience_maker.perf_stats is not None:
                for k, v in self.experience_maker.perf_stats.items():
                    all_wandb_logs[f"perf/experience_maker/{k}"] = v

            # Commit Train/Rollout logs with unique system step
            if all_wandb_logs:
                self.wandb_log_counter += 1
                self._wandb.log(all_wandb_logs, step=self.wandb_log_counter, commit=True)

        # TensorBoard Logging
        elif self._tensorboard is not None and self.strategy.is_rank_0():
            for k, v in rollout_metrics.items():
                self._tensorboard.add_scalar(f"rollout/{k}", v, global_step)
            for k, v in train_metrics.items():
                self._tensorboard.add_scalar(f"train/{k}", v, global_step)

    # 2. EVALUATION
    if global_step % args.eval_steps == 0 and self.eval_dataloader is not None:
        # Evaluation is invoked on every rank; only rank 0 logs the result.
        raw_eval_metrics = self.evaluate(self.eval_dataloader, global_step)

        # Only log if we have results
        if raw_eval_metrics and self.strategy.is_rank_0():
            self.eval_step_counter += 1

            # Wandb Logging for Eval
            if self._wandb is not None:
                eval_logs = {}
                for k, v in raw_eval_metrics.items():
                    # Remove "eval_" prefix if present to avoid "eval/eval_reward"
                    eval_logs[f"eval/{strip_prefix(k, 'eval_')}"] = v

                # Custom X-axis for Eval
                eval_logs["eval/global_step"] = self.eval_step_counter
                # Reference to main training step
                eval_logs["eval/train_step"] = global_step
                eval_logs["eval/episode"] = episode

                # IMPORTANT:
                # Use wandb_log_counter to ensure eval has a unique system step
                # This prevents eval metrics from being overwritten by train metrics
                # The plots will still use eval/global_step as X-axis due to define_metric
                self.wandb_log_counter += 1
                self._wandb.log(eval_logs, step=self.wandb_log_counter, commit=True)

            # TensorBoard Logging for Eval
            elif self._tensorboard is not None:
                for k, v in raw_eval_metrics.items():
                    self._tensorboard.add_scalar(f"eval/{strip_prefix(k, 'eval_')}", v, global_step)

    # 3. CHECKPOINTING
    if global_step % args.save_steps == 0:
        tag = f"global_step{global_step}"
        self._save_checkpoint(args, tag, client_states)
def _save_checkpoint(self, args, tag, client_states):
    """
    Write the checkpoints for the given tag through ``self.strategy``.

    Unless ``self.disable_ds_ckpt`` is set, the actor model is saved under
    ``<ckpt_path>/_actor`` (together with ``client_states``) and, when a critic exists,
    the critic is saved under ``<ckpt_path>/_critic``. When ``self.save_hf_ckpt`` is set,
    the actor and tokenizer are additionally saved to ``<ckpt_path>/<tag>_hf``.

    :param args: Training arguments (``ckpt_path``, ``max_ckpt_num``, ``max_ckpt_mem``).
    :type args: Namespace
    :param tag: Checkpoint tag (e.g., "global_step1000").
    :type tag: str
    :param client_states: Client state for checkpoint recovery.
    :type client_states: dict
    """
    if not self.disable_ds_ckpt:
        actor_dir = os.path.join(args.ckpt_path, "_actor")
        self.strategy.save_ckpt(
            self.actor.model, actor_dir, tag, args.max_ckpt_num, args.max_ckpt_mem, client_states
        )

        if self.critic is not None:
            critic_dir = os.path.join(args.ckpt_path, "_critic")
            # The critic checkpoint is written without client states.
            self.strategy.save_ckpt(self.critic, critic_dir, tag, args.max_ckpt_num, args.max_ckpt_mem)

    if self.save_hf_ckpt:
        hf_dir = os.path.join(args.ckpt_path, f"{tag}_hf")
        self.strategy.save_model(self.actor, self.tokenizer, hf_dir)
def evaluate(self, eval_dataloader, global_step):
    """
    Evaluate the model on evaluation dataset.

    Responses are produced by ``self.experience_maker.make_experience_list`` under
    ``torch.no_grad()``; the ``reward``, ``response_length`` and
    ``reward_metrics`` (``format_reward`` / ``accuracy_reward``) entries found in each
    experience's ``info`` dict are collected and averaged. The actor (and critic, if
    present) are switched to eval mode for the duration and put back into train mode
    afterwards, even when evaluation raises.

    :param eval_dataloader: DataLoader for evaluation data. Each batch is unpacked as
        ``(prompts, images, videos, references, labels)`` when it has five items and as
        ``(prompts, images, references, labels)`` otherwise.
    :type eval_dataloader: DataLoader
    :param global_step: Current global step for logging.
    :type global_step: int
    :return: Dictionary with a ``<name>_mean`` entry per collected metric plus
        ``num_samples``; an empty dict when ``eval_dataloader`` is None.
    :rtype: dict
    """
    if eval_dataloader is None:
        return {}

    self.strategy.print(f"\n{'=' * 60}")
    self.strategy.print(f"Starting evaluation at step {global_step}")
    self.strategy.print(f"{'=' * 60}")

    def extract_values(val):
        """Flatten a tensor, (nested) list/tuple or scalar into a list of Python floats."""
        if isinstance(val, torch.Tensor):
            # reshape (not view) also accepts non-contiguous tensors
            return [float(x) for x in val.detach().reshape(-1).cpu().tolist()]
        if isinstance(val, (list, tuple)):
            # Recurse so that lists holding tensors (including 0-dim ones) or a mix of
            # tensors and numbers end up as plain floats.
            flat = []
            for item in val:
                flat.extend(extract_values(item))
            return flat
        return [float(val)]

    all_rewards = []
    all_format_rewards = []
    all_accuracy_rewards = []
    all_response_lengths = []
    num_eval_batches = 0

    self.actor.eval()
    if self.critic is not None:
        self.critic.eval()

    try:
        with torch.no_grad():
            for batch in eval_dataloader:
                if len(batch) == 5:
                    eval_prompts, eval_images, eval_videos, eval_references, eval_labels = batch
                else:
                    eval_prompts, eval_images, eval_references, eval_labels = batch
                    eval_videos = None

                # Generate responses using experience maker (but don't train on them)
                # We reuse the experience maker but only for generation
                # TODO: simplify this logic
                experiences = self.experience_maker.make_experience_list(
                    eval_prompts, eval_images, eval_videos, eval_references, eval_labels, **self.generate_kwargs
                )
                for i, experience in enumerate(experiences):
                    if i == 0:
                        # Print one decoded sample and the head of the batch for inspection
                        output = self.tokenizer.batch_decode(
                            experience.sequences[0].unsqueeze(0), skip_special_tokens=True
                        )
                        self.strategy.print("eval phase: experience.sequences w skip_special_tokens: ", output)
                        self.strategy.print(
                            f"eval phase: eval_prompts:\n {eval_prompts[0:2]}\n , rand_images:{eval_images[0:2]}\n , eval_references:{eval_references[0:2]}\n, eval_labels:{eval_labels[0:2]}\n "  # noqa
                        )

                    if hasattr(experience, 'info') and experience.info:
                        info = experience.info
                        if 'reward' in info:
                            all_rewards.extend(extract_values(info['reward']))
                        if 'response_length' in info:
                            all_response_lengths.extend(extract_values(info['response_length']))
                        if 'reward_metrics' in info:
                            rm = info['reward_metrics']
                            if 'format_reward' in rm:
                                all_format_rewards.extend(extract_values(rm['format_reward']))
                            if 'accuracy_reward' in rm:
                                all_accuracy_rewards.extend(extract_values(rm['accuracy_reward']))

                num_eval_batches += 1
                if num_eval_batches >= len(eval_dataloader):
                    break

        # Compute statistics
        metrics = {}
        # Same device as before on CUDA hosts; fall back to CPU so the averaging does not
        # itself require a GPU.
        device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
        collected = (
            ("reward", all_rewards),
            ("format_reward", all_format_rewards),
            ("accuracy_reward", all_accuracy_rewards),
            ("response_length", all_response_lengths),
        )
        for name, values_list in collected:
            if not values_list:
                continue  # metric never reported -> no entry
            t = torch.tensor(values_list, dtype=torch.float32, device=device)
            metrics[f"{name}_mean"] = t.mean().item()
        metrics["num_samples"] = len(all_rewards)

        # Print results
        self.strategy.print(f"Evaluation Results (Step {global_step}):")
        for k, v in metrics.items():
            self.strategy.print(f" {k}: {v:.4f}")
        self.strategy.print(f"{'=' * 60}\n")
    finally:
        # Always hand the models back in train mode, even if evaluation failed
        self.actor.train()
        if self.critic is not None:
            self.critic.train()

    return metrics