lightrft.trainer.ppo_trainer_vl¶
- class lightrft.trainer.ppo_trainer_vl.PPOTrainerVL(strategy, actor: ActorVL, critic: torch.nn.Module, reward_model: torch.nn.Module, initial_model: ActorVL, ema_model: ActorVL, actor_optim: torch.optim.Optimizer, critic_optim: torch.optim.Optimizer, actor_scheduler, critic_scheduler, ema_beta: float = 0.992, init_kl_coef: float = 0.001, kl_target: float = None, kl_horizon: int = 10000, ptx_coef: float = 0, micro_train_batch_size: int = 8, buffer_limit: int = 0, buffer_cpu_offload: bool = True, eps_clip: float = 0.2, value_clip: float = 0.2, micro_rollout_batch_size: int = 8, gradient_checkpointing: bool = False, max_epochs: int = 1, max_norm: float = 1.0, tokenizer: Callable[[Any], dict] | None = None, processor: Callable[[Any], dict] | None = None, prompt_max_len: int = 128, dataloader_pin_memory: bool = True, remote_rm_url: str = None, reward_fn: Callable[[List[torch.Tensor]], torch.Tensor] = None, reward_fn_label_map: dict = None, reward_recipe: dict = None, save_hf_ckpt: bool = False, disable_ds_ckpt: bool = False, **generate_kwargs)[source]¶
Bases: ABC
Trainer for the Proximal Policy Optimization (PPO) algorithm for Vision-Language Models.
- Parameters:
strategy (Strategy) – The training strategy to use.
actor (ActorVL) – The actor model in the PPO algorithm.
critic (nn.Module) – The critic model in the PPO algorithm.
reward_model (nn.Module) – The reward model for calculating rewards in the RLHF setup.
initial_model (ActorVL) – The initial model for reference logits to limit actor updates in RLHF.
ema_model (ActorVL) – The exponential moving average model for stable training.
actor_optim (Optimizer) – The optimizer for the actor model.
critic_optim (Optimizer) – The optimizer for the critic model.
actor_scheduler (Scheduler) – The learning rate scheduler for the actor.
critic_scheduler (Scheduler) – The learning rate scheduler for the critic.
ema_beta (float) – EMA decay rate for model stability, defaults to 0.992.
init_kl_coef (float) – Initial coefficient for KL divergence, defaults to 0.001.
kl_target (float, optional) – Target value for KL divergence, defaults to None.
kl_horizon (int) – Horizon for KL annealing, defaults to 10000.
ptx_coef (float) – Coefficient for supervised loss from pre-trained data, defaults to 0.
micro_train_batch_size (int) – Micro-batch size for actor training, defaults to 8.
buffer_limit (int) – Maximum size of the replay buffer, defaults to 0.
buffer_cpu_offload (bool) – If True, offloads replay buffer to CPU, defaults to True.
eps_clip (float) – Clipping coefficient for policy loss, defaults to 0.2.
value_clip (float) – Clipping coefficient for value function loss, defaults to 0.2.
micro_rollout_batch_size (int) – Micro-batch size for generating rollouts, defaults to 8.
gradient_checkpointing (bool) – If True, enables gradient checkpointing, defaults to False.
max_epochs (int) – Number of epochs to train, defaults to 1.
max_norm (float) – Maximum gradient norm for gradient clipping, defaults to 1.0.
tokenizer (Callable, optional) – Tokenizer for input data, defaults to None.
processor (Callable, optional) – Processor for multimodal input data, defaults to None.
prompt_max_len (int) – Maximum length for prompts, defaults to 128.
dataloader_pin_memory (bool) – If True, pins memory in the data loader, defaults to True.
remote_rm_url (str, optional) – URL for remote reward model API, defaults to None.
reward_fn (Callable, optional) – Custom reward function for computing rewards, defaults to None.
reward_fn_label_map (dict, optional) – Label mapping for reward function, defaults to None.
reward_recipe (dict, optional) – Recipe configuration for reward computation, defaults to None.
save_hf_ckpt (bool) – Whether to save model weights in Hugging Face format, defaults to False.
disable_ds_ckpt (bool) – If True, skips saving DeepSpeed-format checkpoints (which are used for training recovery).
generate_kwargs (dict) – Additional arguments for model generation.
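The `ema_model` is typically refreshed after each optimizer step using the decay rate `ema_beta`. A minimal sketch of that update rule, in plain Python over flat parameter lists (the trainer itself operates on `torch` parameter tensors; `ema_update` is a hypothetical helper, not part of the lightrft API):

```python
def ema_update(ema_params, model_params, beta=0.992):
    """In-place EMA update: ema <- beta * ema + (1 - beta) * model."""
    for i, (e, p) in enumerate(zip(ema_params, model_params)):
        ema_params[i] = beta * e + (1.0 - beta) * p
    return ema_params

# After n updates toward a constant target of 1.0, the EMA reaches 1 - beta**n,
# i.e. it tracks the live model with a smooth lag controlled by beta.
ema = [0.0]
for _ in range(10):
    ema_update(ema, [1.0], beta=0.992)
```

The high default `ema_beta = 0.992` means the EMA weights move slowly, which is what makes them a more stable snapshot for evaluation or final export.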
- evaluate(eval_dataloader, global_step)[source]¶
Evaluate the model on the evaluation dataset.
- Parameters:
eval_dataloader (DataLoader) – DataLoader for evaluation data.
global_step (int) – Current global step for logging.
- Returns:
Dictionary of evaluation metrics.
- Return type:
dict
- fit(args, prompts_dataloader, pretrain_dataloader, eval_dataloader=None, consumed_samples=0, num_update_steps_per_episodes=1) → None[source]¶
Main training loop for PPO.
- Parameters:
args (Namespace) – Training arguments.
prompts_dataloader (DataLoader) – DataLoader for prompt data.
pretrain_dataloader (DataLoader) – DataLoader for pre-training data.
eval_dataloader (DataLoader, optional) – DataLoader for evaluation data, defaults to None.
consumed_samples (int) – Number of samples already consumed, defaults to 0.
num_update_steps_per_episodes (int) – Number of update steps per episode, defaults to 1.
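When `kl_target` is set, the KL penalty coefficient (seeded by `init_kl_coef`) is commonly annealed toward the target over `kl_horizon` steps during the training loop, in the style of the adaptive controller from the original PPO-for-RLHF work. A sketch under that assumption; the class and method names here are hypothetical, not lightrft's actual controller:

```python
class AdaptiveKLController:
    """Anneal a KL coefficient toward a target KL divergence.

    If the measured KL exceeds the target, the coefficient grows;
    if it is below, the coefficient shrinks. `horizon` sets the speed.
    """

    def __init__(self, init_kl_coef=0.001, target=0.01, horizon=10000):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        # Proportional error, clamped to [-0.2, 0.2] to avoid large jumps.
        error = max(min(current_kl / self.target - 1.0, 0.2), -0.2)
        self.value *= 1.0 + error * n_steps / self.horizon

ctl = AdaptiveKLController(init_kl_coef=0.001, target=0.01, horizon=10000)
ctl.update(current_kl=0.05, n_steps=100)  # KL above target -> coefficient grows
```

With `kl_target=None` (the default), a fixed coefficient of `init_kl_coef` would be used throughout.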
- ppo_train(global_steps=0)[source]¶
PPO training loop over the replay buffer.
NOTE: This method is not used directly in the main trainer, as it’s overridden by external classes (e.g., lightrft/trainer/spmd_ppo_trainer.py).
- Parameters:
global_steps (int) – Current global step count, defaults to 0.
- Returns:
Dictionary of averaged training statistics.
- Return type:
dict
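Since `ppo_train` returns statistics averaged over the replay-buffer micro-batches, the reduction presumably amounts to a key-wise mean of the per-step dicts produced by `training_step`. A sketch with a hypothetical helper (not the trainer's actual code):

```python
def average_status(status_list):
    """Average a list of per-step metric dicts key-wise."""
    if not status_list:
        return {}
    totals = dict(status_list[0])
    for status in status_list[1:]:
        for key, value in status.items():
            totals[key] += value
    return {key: value / len(status_list) for key, value in totals.items()}

stats = average_status([
    {"policy_loss": 0.4, "kl": 0.02},
    {"policy_loss": 0.2, "kl": 0.04},
])
```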
- save_logs_and_checkpoints(args, global_step, step_bar, logs_dict={}, client_states={}, episode=0)[source]¶
Save logs to wandb/tensorboard and save model checkpoints.
- Parameters:
args (Namespace) – Training arguments.
global_step (int) – Current global step.
step_bar (tqdm) – Progress bar object.
logs_dict (dict) – Dictionary of metrics to log. Should contain both rollout statistics (rollout_reward, rollout_response_length, etc.) from the inference/generation phase and training statistics (policy_loss, critic_loss, kl, etc.) from the optimization phase. Defaults to {}.
client_states (dict) – Client state for checkpoint recovery, defaults to {}.
episode (int) – Current episode number, defaults to 0.
- training_step(experience: ExperienceVL, global_steps, entropy_mask: torch.Tensor | None = None) → Dict[str, float][source]¶
Single training step combining actor and critic updates.
- Parameters:
experience (ExperienceVL) – Experience batch from replay buffer.
global_steps (int) – Current global step count.
entropy_mask (Optional[torch.Tensor]) – Optional mask for high-entropy tokens.
- Returns:
Dictionary of training statistics.
- Return type:
Dict[str, float]
- training_step_actor(experience: ExperienceVL, entropy_mask: torch.Tensor | None = None) → Dict[str, float][source]¶
Actor training step.
- Parameters:
experience (ExperienceVL) – Experience batch from replay buffer.
entropy_mask (Optional[torch.Tensor]) – Optional mask selecting high-entropy tokens for the loss, defaults to None.
- Returns:
Dictionary of actor training statistics.
- Return type:
Dict[str, float]
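The actor step presumably optimizes the standard PPO clipped surrogate that the constructor's `eps_clip = 0.2` suggests. A minimal scalar sketch in plain Python (the trainer works on token-level `torch` tensors, with `entropy_mask` zeroing out the positions excluded from the loss; this standalone function is illustrative only):

```python
import math

def clipped_policy_loss(log_prob, old_log_prob, advantage, eps_clip=0.2):
    """PPO clipped surrogate for a single token (scalar sketch).

    Takes the minimum of the unclipped and clipped objectives, so large
    policy-ratio moves cannot increase the objective past the clip range.
    """
    ratio = math.exp(log_prob - old_log_prob)
    surr1 = ratio * advantage
    surr2 = max(min(ratio, 1.0 + eps_clip), 1.0 - eps_clip) * advantage
    return -min(surr1, surr2)  # negated: we minimize the loss

# A ratio well above 1 + eps_clip is clipped, capping the effective update.
loss = clipped_policy_loss(log_prob=0.0, old_log_prob=-1.0, advantage=1.0)
```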
- training_step_critic(experience: ExperienceVL) Dict[str, float][source]¶
Critic training step.
- Parameters:
experience (ExperienceVL) – Experience batch from replay buffer.
- Returns:
Dictionary of critic training statistics.
- Return type:
Dict[str, float]
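The critic step likely uses the clipped value loss implied by `value_clip = 0.2`: the new value prediction is clipped to stay within `value_clip` of the rollout-time value, and the larger of the two squared errors is taken. A scalar sketch with hypothetical names, shown for a single sample rather than a masked token batch:

```python
def clipped_value_loss(value, old_value, returns, value_clip=0.2):
    """Clipped value-function loss for a single sample (scalar sketch)."""
    # Clamp the new prediction to within value_clip of the old prediction.
    clipped = old_value + max(min(value - old_value, value_clip), -value_clip)
    loss_unclipped = (value - returns) ** 2
    loss_clipped = (clipped - returns) ** 2
    # Pessimistic max: the critic cannot profit from moving outside the clip range.
    return 0.5 * max(loss_unclipped, loss_clipped)

loss = clipped_value_loss(value=1.5, old_value=1.0, returns=1.1)
```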