lightrft.trainer.ppo_trainer_vl

class lightrft.trainer.ppo_trainer_vl.PPOTrainerVL(strategy, actor: ActorVL, critic: torch.nn.Module, reward_model: torch.nn.Module, initial_model: ActorVL, ema_model: ActorVL, actor_optim: torch.optim.Optimizer, critic_optim: torch.optim.Optimizer, actor_scheduler, critic_scheduler, ema_beta: float = 0.992, init_kl_coef: float = 0.001, kl_target: float = None, kl_horizon: int = 10000, ptx_coef: float = 0, micro_train_batch_size: int = 8, buffer_limit: int = 0, buffer_cpu_offload: bool = True, eps_clip: float = 0.2, value_clip: float = 0.2, micro_rollout_batch_size: int = 8, gradient_checkpointing: bool = False, max_epochs: int = 1, max_norm: float = 1.0, tokenizer: Callable[[Any], dict] | None = None, processor: Callable[[Any], dict] | None = None, prompt_max_len: int = 128, dataloader_pin_memory: bool = True, remote_rm_url: str = None, reward_fn: Callable[[List[torch.Tensor]], torch.Tensor] = None, reward_fn_label_map: dict = None, reward_recipe: dict = None, save_hf_ckpt: bool = False, disable_ds_ckpt: bool = False, **generate_kwargs)[source]

Bases: ABC

Trainer for Proximal Policy Optimization (PPO) algorithm for Vision-Language Models.

Parameters:
  • strategy (Strategy) – The training strategy to use.

  • actor (ActorVL) – The actor model in the PPO algorithm.

  • critic (nn.Module) – The critic model in the PPO algorithm.

  • reward_model (nn.Module) – The reward model for calculating rewards in the RLHF setup.

  • initial_model (ActorVL) – The initial model for reference logits to limit actor updates in RLHF.

  • ema_model (ActorVL) – The exponential moving average model for stable training.

  • actor_optim (Optimizer) – The optimizer for the actor model.

  • critic_optim (Optimizer) – The optimizer for the critic model.

  • actor_scheduler (Scheduler) – The learning rate scheduler for the actor.

  • critic_scheduler (Scheduler) – The learning rate scheduler for the critic.

  • ema_beta (float) – EMA decay rate for model stability, defaults to 0.992.

  • init_kl_coef (float) – Initial coefficient for KL divergence, defaults to 0.001.

  • kl_target (float, optional) – Target value for KL divergence, defaults to None.

  • kl_horizon (int) – Horizon for KL annealing, defaults to 10000.

  • ptx_coef (float) – Coefficient for supervised loss from pre-trained data, defaults to 0.

  • micro_train_batch_size (int) – Micro-batch size for actor training, defaults to 8.

  • buffer_limit (int) – Maximum size of the replay buffer, defaults to 0.

  • buffer_cpu_offload (bool) – If True, offloads replay buffer to CPU, defaults to True.

  • eps_clip (float) – Clipping coefficient for policy loss, defaults to 0.2.

  • value_clip (float) – Clipping coefficient for value function loss, defaults to 0.2.

  • micro_rollout_batch_size (int) – Micro-batch size for generating rollouts, defaults to 8.

  • gradient_checkpointing (bool) – If True, enables gradient checkpointing, defaults to False.

  • max_epochs (int) – Number of epochs to train, defaults to 1.

  • max_norm (float) – Maximum gradient norm for gradient clipping, defaults to 1.0.

  • tokenizer (Callable, optional) – Tokenizer for input data, defaults to None.

  • processor (Callable, optional) – Processor for multimodal input data, defaults to None.

  • prompt_max_len (int) – Maximum length for prompts, defaults to 128.

  • dataloader_pin_memory (bool) – If True, pins memory in the data loader, defaults to True.

  • remote_rm_url (str, optional) – URL for remote reward model API, defaults to None.

  • reward_fn (Callable, optional) – Custom reward function for computing rewards, defaults to None.

  • reward_fn_label_map (dict, optional) – Label mapping for reward function, defaults to None.

  • reward_recipe (dict, optional) – Recipe configuration for reward computation, defaults to None.

  • save_hf_ckpt (bool) – Whether to save HuggingFace-format model weights, defaults to False.

  • disable_ds_ckpt (bool) – If True, skips saving DeepSpeed-format model weights (used for training recovery), defaults to False.

  • generate_kwargs (dict) – Additional arguments for model generation.
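The KL-related parameters (init_kl_coef, kl_target, kl_horizon) typically drive an adaptive KL coefficient in the style of Ziegler et al.'s RLHF setup. The sketch below is a minimal pure-Python illustration of that standard mechanism, not this class's actual internals; the class name and update rule are assumptions based on common PPO/RLHF implementations:

```python
class AdaptiveKLController:
    """Illustrative adaptive KL coefficient: starts at init_kl_coef and
    drifts toward keeping the measured KL near kl_target over kl_horizon
    steps (Ziegler et al.-style proportional update)."""

    def __init__(self, init_kl_coef: float, kl_target: float, kl_horizon: int):
        self.value = init_kl_coef
        self.target = kl_target
        self.horizon = kl_horizon

    def update(self, current_kl: float, n_steps: int) -> float:
        # Proportional error between measured and target KL,
        # clipped to [-0.2, 0.2] for stability.
        proportional_error = max(-0.2, min(0.2, current_kl / self.target - 1.0))
        mult = 1.0 + proportional_error * n_steps / self.horizon
        self.value *= mult
        return self.value
```

When kl_target is None (the default here), implementations typically fall back to a fixed coefficient equal to init_kl_coef.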

evaluate(eval_dataloader, global_step)[source]

Evaluate the model on the evaluation dataset.

Parameters:
  • eval_dataloader (DataLoader) – DataLoader for evaluation data.

  • global_step (int) – Current global step for logging.

Returns:

Dictionary of evaluation metrics.

Return type:

dict

fit(args, prompts_dataloader, pretrain_dataloader, eval_dataloader=None, consumed_samples=0, num_update_steps_per_episodes=1) None[source]

Main training loop for PPO.

Parameters:
  • args (Namespace) – Training arguments.

  • prompts_dataloader (DataLoader) – DataLoader for prompt data.

  • pretrain_dataloader (DataLoader) – DataLoader for pre-training data.

  • eval_dataloader (DataLoader, optional) – DataLoader for evaluation data, defaults to None.

  • consumed_samples (int) – Number of samples already consumed, defaults to 0.

  • num_update_steps_per_episodes (int) – Number of update steps per episode, defaults to 1.
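At a high level, fit iterates episodes over the prompt dataloader, generates rollouts, and runs PPO updates over the collected experience. The skeleton below is a schematic sketch of that outer loop; all names and the exact ordering are illustrative stand-ins, not the trainer's real internals:

```python
def fit_sketch(prompts, num_episodes: int, rollout_batch_size: int,
               consumed_samples: int = 0):
    """Schematic PPO outer loop (hypothetical stand-in, not real code).

    Resumes step counting from consumed_samples, mirroring how the
    trainer supports restarting mid-training."""
    global_step = consumed_samples // rollout_batch_size
    for episode in range(num_episodes):
        for start in range(0, len(prompts), rollout_batch_size):
            batch = prompts[start:start + rollout_batch_size]
            # 1. Generate rollouts for the prompt batch (make experience).
            # 2. Push experiences into the replay buffer.
            # 3. Run PPO updates over the buffer (ppo_train).
            # 4. Clear the buffer; log metrics and save checkpoints.
            consumed_samples += len(batch)
            global_step += 1
    return consumed_samples, global_step
```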

ppo_train(global_steps=0)[source]

PPO training loop over the replay buffer.

NOTE: This method is not used directly in the main trainer, as it’s overridden by external classes (e.g., lightrft/trainer/spmd_ppo_trainer.py).

Parameters:

global_steps (int) – Current global step count, defaults to 0.

Returns:

Dictionary of averaged training statistics.

Return type:

dict
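The returned dictionary averages statistics over all micro-batch updates performed during the buffer pass. A key-by-key averaging of per-step stat dicts can be sketched as follows (an assumption about the aggregation, since the exact reduction is implementation-defined):

```python
def average_stats(step_stats: list) -> dict:
    """Average a list of per-step metric dicts key-by-key, as a
    replay-buffer training loop typically does before logging.
    Keys missing from some steps are averaged over the steps
    that reported them."""
    totals = {}
    counts = {}
    for stats in step_stats:
        for key, value in stats.items():
            totals[key] = totals.get(key, 0.0) + value
            counts[key] = counts.get(key, 0) + 1
    return {key: totals[key] / counts[key] for key in totals}
```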

save_logs_and_checkpoints(args, global_step, step_bar, logs_dict={}, client_states={}, episode=0)[source]

Save logs to wandb/tensorboard and save model checkpoints.

Parameters:
  • args (Namespace) – Training arguments.

  • global_step (int) – Current global step.

  • step_bar (tqdm) – Progress bar object.

  • logs_dict (dict) – Dictionary of metrics to log, defaults to {}. Should contain both:

    • Rollout statistics (rollout_reward, rollout_response_length, etc.) from the inference/generation phase.

    • Training statistics (policy_loss, critic_loss, kl, etc.) from the optimization phase.

  • client_states (dict) – Client state for checkpoint recovery, defaults to {}.

  • episode (int) – Current episode number, defaults to 0.

training_step(experience: ExperienceVL, global_steps, entropy_mask: torch.Tensor | None = None) Dict[str, float][source]

Single training step combining actor and critic updates.

Parameters:
  • experience (ExperienceVL) – Experience batch from replay buffer.

  • global_steps (int) – Current global step count.

  • entropy_mask (Optional[torch.Tensor]) – Optional mask for high-entropy tokens.

Returns:

Dictionary of training statistics.

Return type:

Dict[str, float]

training_step_actor(experience: ExperienceVL, entropy_mask: torch.Tensor | None = None) Dict[str, float][source]

Actor training step.

Parameters:
  • experience (ExperienceVL) – Experience batch from replay buffer.

  • entropy_mask (Optional[torch.Tensor]) – Optional mask for high-entropy tokens.

Returns:

Dictionary of actor training statistics.

Return type:

Dict[str, float]
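The actor step minimizes the standard PPO clipped surrogate objective governed by eps_clip. The single-token, pure-Python sketch below illustrates that standard formulation (with the 1/2 conventions and batching of the class's actual loss left out; this is not its exact code):

```python
import math

def ppo_policy_loss(log_prob: float, old_log_prob: float,
                    advantage: float, eps_clip: float = 0.2) -> float:
    """Clipped PPO surrogate loss for one token, negated so that
    minimizing the loss maximizes the surrogate objective."""
    # Importance ratio between current and rollout-time policies.
    ratio = math.exp(log_prob - old_log_prob)
    surr1 = ratio * advantage
    # Clip the ratio to [1 - eps_clip, 1 + eps_clip].
    surr2 = max(1.0 - eps_clip, min(1.0 + eps_clip, ratio)) * advantage
    return -min(surr1, surr2)
```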

training_step_critic(experience: ExperienceVL) Dict[str, float][source]

Critic training step.

Parameters:

experience (ExperienceVL) – Experience batch from replay buffer.

Returns:

Dictionary of critic training statistics.

Return type:

Dict[str, float]
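The critic step is governed by value_clip, which limits how far the value prediction may move from its rollout-time estimate. A single-token sketch of the standard clipped value loss follows (the 0.5 factor and max-of-clipped/unclipped form are conventions from common PPO implementations, assumed rather than taken from this class):

```python
def ppo_value_loss(value: float, old_value: float,
                   returns: float, value_clip: float = 0.2) -> float:
    """Clipped value-function loss for one token: the max of the
    unclipped and clipped squared errors, so large value updates
    are penalized in whichever direction they occur."""
    # Constrain the new value to within value_clip of the old estimate.
    clipped = old_value + max(-value_clip, min(value_clip, value - old_value))
    return 0.5 * max((value - returns) ** 2, (clipped - returns) ** 2)
```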