"""Proximal Policy Optimization (PPO) trainer: ``lightrft.trainer.ppo_trainer``."""

import os
import sys
import os.path
from abc import ABC
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm

from lightrft.models import ActorLanguage, GPTLMLoss, PolicyLoss, ValueLoss
from lightrft.models.utils import masked_mean, unpacking_samples, compute_approx_kl
from lightrft.utils.distributed_sampler import DistributedSampler
from lightrft.trainer import AdaptiveKLController, Experience, FixedKLController, NaiveExperienceMaker, NaiveReplayBuffer  # noqa


class PPOTrainer(ABC):
    """
    Trainer for Proximal Policy Optimization (PPO) algorithm.

    :param strategy: The training strategy to use (provides ``args``, backward/optimizer steps,
        distributed helpers, printing and checkpointing).
    :type strategy: Strategy
    :param actor: The actor model in the PPO algorithm.
    :type actor: ActorLanguage
    :param critic: The critic model in the PPO algorithm, or None to skip critic updates.
    :type critic: nn.Module
    :param reward_model: The reward model(s) for calculating rewards in the RLHF setup.
        When a list with more than one model is given, ``reward_fn`` must also be provided.
    :type reward_model: Union[nn.Module, List[nn.Module]]
    :param initial_model: The initial (reference) model used to limit actor updates via KL.
    :type initial_model: ActorLanguage
    :param ema_model: The exponential moving average copy of the actor, or None to disable EMA.
    :type ema_model: ActorLanguage
    :param actor_optim: The optimizer for the actor model.
    :type actor_optim: Optimizer
    :param critic_optim: The optimizer for the critic model.
    :type critic_optim: Optimizer
    :param actor_scheduler: The learning rate scheduler for the actor.
    :type actor_scheduler: Scheduler
    :param critic_scheduler: The learning rate scheduler for the critic.
    :type critic_scheduler: Scheduler
    :param ema_beta: EMA decay rate, defaults to 0.992.
    :type ema_beta: float
    :param init_kl_coef: Initial coefficient for KL divergence, defaults to 0.001.
    :type init_kl_coef: float
    :param kl_target: Target value for KL divergence; when set, an adaptive KL controller is used,
        otherwise a fixed one. Defaults to None.
    :type kl_target: float, optional
    :param kl_horizon: Horizon for adaptive KL control, defaults to 10000.
    :type kl_horizon: int
    :param ptx_coef: Coefficient for the supervised loss on pre-training data, defaults to 0.
    :type ptx_coef: float
    :param micro_train_batch_size: Micro-batch size for training, defaults to 8.
    :type micro_train_batch_size: int
    :param buffer_limit: Maximum size of the replay buffer, defaults to 0.
    :type buffer_limit: int
    :param buffer_cpu_offload: If True, offloads replay buffer to CPU, defaults to True.
    :type buffer_cpu_offload: bool
    :param eps_clip: Clipping coefficient for policy loss, defaults to 0.2.
    :type eps_clip: float
    :param value_clip: Clipping coefficient for value function loss, defaults to 0.2.
    :type value_clip: float
    :param micro_rollout_batch_size: Micro-batch size for generating rollouts, defaults to 8.
    :type micro_rollout_batch_size: int
    :param gradient_checkpointing: If True, enables gradient checkpointing, defaults to False.
    :type gradient_checkpointing: bool
    :param max_epochs: Number of passes over the replay buffer per rollout, defaults to 1.
    :type max_epochs: int
    :param max_norm: Maximum gradient norm for gradient clipping, defaults to 1.0.
    :type max_norm: float
    :param tokenizer: Tokenizer for input data, defaults to None.
    :type tokenizer: Callable, optional
    :param prompt_max_len: Maximum length for prompts, defaults to 128.
    :type prompt_max_len: int
    :param dataloader_pin_memory: If True, pins memory in the replay-buffer data loader, defaults to True.
    :type dataloader_pin_memory: bool
    :param remote_rm_url: URL for remote reward model API, defaults to None.
    :type remote_rm_url: str, optional
    :param reward_fn: Custom reward function for combining rewards, defaults to None.
    :type reward_fn: Callable, optional
    :param save_hf_ckpt: Whether to also save a HuggingFace-format model, defaults to False.
    :type save_hf_ckpt: bool
    :param disable_ds_ckpt: If True, skip saving DeepSpeed-format checkpoints (used for training
        recovery), defaults to False.
    :type disable_ds_ckpt: bool
    :param generate_kwargs: Additional keyword arguments forwarded to
        ``experience_maker.make_experience_list`` for generation.
    :type generate_kwargs: dict
    """

    # Metrics already reported under the ``rollout/`` namespace; excluded from ``train/``
    # to avoid logging them twice.
    _ROLLOUT_ONLY_METRICS = frozenset(
        {
            "reward",  # Already logged as rollout/reward
            "response_length",  # Already logged as rollout/response_length
            "total_length",  # Rollout-specific metric
            "num_actions",  # Rollout-specific metric
            "return",  # Rollout-specific metric (computed from rewards)
        }
    )
    # Key prefixes excluded from ``train/`` for the same reason
    # (reward_metrics/* sub-keys such as format_reward, accuracy_reward).
    _ROLLOUT_ONLY_METRIC_PREFIXES = ("reward_metrics/",)

    # (progress-bar label, status key) pairs shown by ``ppo_train``; missing keys are skipped.
    _SHORT_STATUS_KEYS = (
        ("pg", "policy_loss"),
        ("rm", "reward"),
        ("ret", "return"),
        ("glen", "response_length"),
        ("tlen", "total_length"),
        ("kl", "kl"),
        ("act_lr", "actor_lr"),
        ("cri", "critic_loss"),
        ("vals", "values"),
        ("cri_lr", "critic_lr"),
        ("ptx", "ptx_loss"),
    )

    def __init__(
        self,
        strategy,
        actor: ActorLanguage,
        critic: nn.Module,
        reward_model: Union[nn.Module, List[nn.Module]],
        initial_model: ActorLanguage,
        ema_model: ActorLanguage,
        actor_optim: Optimizer,
        critic_optim: Optimizer,
        actor_scheduler,
        critic_scheduler,
        ema_beta: float = 0.992,
        init_kl_coef: float = 0.001,
        kl_target: Optional[float] = None,
        kl_horizon: int = 10000,
        ptx_coef: float = 0,
        micro_train_batch_size: int = 8,
        buffer_limit: int = 0,
        buffer_cpu_offload: bool = True,
        eps_clip: float = 0.2,
        value_clip: float = 0.2,
        micro_rollout_batch_size: int = 8,
        gradient_checkpointing: bool = False,
        max_epochs: int = 1,
        max_norm: float = 1.0,
        tokenizer: Optional[Callable[[Any], dict]] = None,
        prompt_max_len: int = 128,
        dataloader_pin_memory: bool = True,
        remote_rm_url: Optional[str] = None,
        reward_fn: Optional[Callable[[List[torch.Tensor]], torch.Tensor]] = None,
        save_hf_ckpt: bool = False,
        disable_ds_ckpt: bool = False,
        **generate_kwargs,
    ) -> None:
        assert (
            not isinstance(reward_model, list) or len(reward_model) == 1 or reward_fn is not None
        ), "reward_fn must be specified if using multiple reward models"

        ABC.__init__(self)

        # Store the strategy first: the debug print below goes through ``self.strategy``.
        self.strategy = strategy
        self.args = strategy.args

        # Get current filename and line number for debugging
        current_filename = os.path.basename(__file__)
        current_lineno = sys._getframe().f_lineno
        self.strategy.print(f"[{current_filename}:{current_lineno}]")

        self.save_hf_ckpt = save_hf_ckpt
        self.disable_ds_ckpt = disable_ds_ckpt
        self.micro_rollout_batch_size = micro_rollout_batch_size
        self.max_epochs = max_epochs
        self.tokenizer = tokenizer
        self.generate_kwargs = generate_kwargs
        self.dataloader_pin_memory = dataloader_pin_memory
        self.max_norm = max_norm
        self.ptx_coef = ptx_coef
        self.micro_train_batch_size = micro_train_batch_size
        self.kl_target = kl_target
        self.prompt_max_len = prompt_max_len
        self.ema_beta = ema_beta
        self.gradient_checkpointing = gradient_checkpointing
        self.reward_fn = reward_fn

        self.actor = actor
        self.critic = critic
        self.reward_model = reward_model
        self.remote_rm_url = remote_rm_url
        self.initial_model = initial_model
        self.ema_model = ema_model
        self.actor_optim = actor_optim
        self.critic_optim = critic_optim
        self.actor_scheduler = actor_scheduler
        self.critic_scheduler = critic_scheduler

        self.actor_loss_fn = PolicyLoss(eps_clip)
        self.critic_loss_fn = ValueLoss(value_clip)
        self.ptx_loss_fn = GPTLMLoss()

        # Actor updates are skipped while the global step is <= this value (-1: never skipped).
        self.freezing_actor_steps = getattr(self.args, "freezing_actor_steps", -1)

        # Mixtral 8x7b auxiliary loss
        self.aux_loss = self.args.aux_loss_coef > 1e-8

        if self.kl_target:
            self.kl_ctl = AdaptiveKLController(init_kl_coef, kl_target, kl_horizon)
        else:
            self.kl_ctl = FixedKLController(init_kl_coef)

        self.experience_maker = NaiveExperienceMaker(
            actor,
            critic,
            reward_model,
            initial_model,
            tokenizer,
            prompt_max_len,
            self.kl_ctl,
            strategy,
            remote_rm_url,
            reward_fn,
        )
        packing_samples = getattr(self.args, "packing_samples", False)
        self.replay_buffer = NaiveReplayBuffer(
            micro_train_batch_size, buffer_limit, buffer_cpu_offload, packing_samples
        )

        # Set by ``fit``; initialized here so ``training_step`` can also be called outside ``fit``.
        self.prompts_dataloader = None
        self.pretrain_dataloader = None

        # Initialize wandb/tensorboard for logging
        self._wandb = None
        self._tensorboard = None
        if self.strategy.args.use_wandb and self.strategy.is_rank_0():
            import wandb

            self._wandb = wandb
            if not wandb.api.api_key:
                wandb.login(key=strategy.args.use_wandb)
            wandb.init(
                entity=strategy.args.wandb_org,
                project=strategy.args.wandb_project,
                group=strategy.args.wandb_group,
                name=strategy.args.wandb_run_name,
                config=strategy.args.__dict__,
                reinit=True,
            )

            # Define separate metric namespaces for clarity:
            # - rollout/*: Metrics from experience generation phase
            # - train/*: Metrics from policy optimization phase
            # - eval/*: Metrics from evaluation phase
            wandb.define_metric("rollout/global_step")
            wandb.define_metric("rollout/*", step_metric="rollout/global_step", step_sync=True)
            wandb.define_metric("train/global_step")
            wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True)
            wandb.define_metric("eval/epoch")
            wandb.define_metric("eval/*", step_metric="eval/epoch", step_sync=True)

        # Initialize TensorBoard writer if wandb is not available
        if self.strategy.args.use_tensorboard and self._wandb is None and self.strategy.is_rank_0():
            from torch.utils.tensorboard import SummaryWriter

            os.makedirs(self.strategy.args.use_tensorboard, exist_ok=True)
            log_dir = os.path.join(self.strategy.args.use_tensorboard, strategy.args.wandb_run_name)
            self._tensorboard = SummaryWriter(log_dir=log_dir)

    def fit(
        self,
        args,
        prompts_dataloader,
        pretrain_dataloader,
        consumed_samples=0,
        num_update_steps_per_episodes=1,
    ) -> None:
        """
        Main training loop for PPO.

        For every prompt batch: generate experiences, fill the replay buffer, run ``ppo_train``,
        clear the buffer, update the KL controller and log/checkpoint.

        :param args: Training arguments.
        :type args: Namespace
        :param prompts_dataloader: DataLoader yielding ``(prompts, labels)`` batches.
        :type prompts_dataloader: DataLoader
        :param pretrain_dataloader: Source of pre-training batches for the PTX loss, or None to
            disable it. ``training_step_actor`` pulls batches with ``next()``, so this must be an
            iterator (e.g. ``itertools.cycle(iter(loader))``), not a bare DataLoader.
        :type pretrain_dataloader: Iterator, optional
        :param consumed_samples: Number of samples already consumed (for resuming), defaults to 0.
        :type consumed_samples: int
        :param num_update_steps_per_episodes: Number of update steps per episode, defaults to 1.
        :type num_update_steps_per_episodes: int
        """
        # Clamp to >= 1: it is used as a divisor/modulus below and as the default eval interval.
        num_rollouts_per_episodes = max(
            1,
            num_update_steps_per_episodes
            * args.train_batch_size
            // args.max_epochs
            // args.rollout_batch_size
            // args.n_samples_per_prompt,
        )

        # Get eval and save steps
        if args.eval_steps == -1:
            args.eval_steps = num_rollouts_per_episodes  # Evaluate once per epoch
        if args.save_steps == -1:
            args.save_steps = float("inf")  # Do not save checkpoint

        self.prompts_dataloader = prompts_dataloader
        self.pretrain_dataloader = pretrain_dataloader

        # Restore step and start_episode
        steps = consumed_samples // args.rollout_batch_size + 1
        start_episode = consumed_samples // args.rollout_batch_size // num_rollouts_per_episodes
        consumed_samples = consumed_samples % (num_rollouts_per_episodes * args.rollout_batch_size)

        for episode in range(start_episode, args.num_episodes):
            if isinstance(self.prompts_dataloader.sampler, DistributedSampler):
                # Only the resumed episode skips already-consumed samples.
                self.prompts_dataloader.sampler.set_epoch(
                    episode, consumed_samples=0 if episode > start_episode else consumed_samples
                )
            pbar = tqdm(
                range(len(self.prompts_dataloader)),
                desc=f"Episode [{episode + 1}/{args.num_episodes}]",
                disable=not self.strategy.is_rank_0(),
            )

            for rand_prompts, labels in self.prompts_dataloader:
                experiences = self.experience_maker.make_experience_list(
                    rand_prompts, all_labels=labels, **self.generate_kwargs
                )
                for i, experience in enumerate(experiences):
                    if i == 0:
                        # Print one decoded sample per prompt batch for inspection.
                        output = self.tokenizer.batch_decode(
                            experience.sequences[0].unsqueeze(0), skip_special_tokens=True
                        )
                        self.strategy.print(output)
                    self.replay_buffer.append(experience)
                self.strategy.report_memory('after replay_buffer ready')

                # group_norm advantages are presumably normalized upstream; skip buffer-wide normalization.
                if self.args.advantage_estimator != "group_norm":
                    self.replay_buffer.normalize("advantages", self.strategy)

                self.strategy.report_memory('before train')
                status = self.ppo_train(steps)
                self.strategy.report_memory('before clear buffer')
                self.replay_buffer.clear()
                self.strategy.report_memory('after train')

                if "kl" in status:
                    self.kl_ctl.update(status["kl"], args.rollout_batch_size * args.n_samples_per_prompt)
                pbar.set_postfix(status)

                # Logs/checkpoints
                client_states = {"consumed_samples": steps * args.rollout_batch_size}
                self.save_logs_and_checkpoints(args, steps, pbar, status, client_states)

                pbar.update()
                steps = steps + 1

            pbar.close()

        if self._wandb is not None and self.strategy.is_rank_0():
            self._wandb.finish()
        if self._tensorboard is not None and self.strategy.is_rank_0():
            self._tensorboard.close()

    def ppo_train(self, global_steps=0):
        """
        PPO training loop over the replay buffer.

        :param global_steps: Current global step count, defaults to 0.
        :type global_steps: int
        :return: Per-key mean of the (cross-rank reduced) status dicts of every training step,
            or an empty dict if no step ran.
        :rtype: dict
        """
        torch.cuda.empty_cache()
        # Replay buffer may be empty at first, we should rebuild at each training
        dataloader = DataLoader(
            self.replay_buffer,
            batch_size=self.replay_buffer.sample_batch_size,
            shuffle=True,
            drop_last=True,
            pin_memory=self.dataloader_pin_memory,
            collate_fn=self.replay_buffer.collate_fn,
        )
        device = torch.cuda.current_device()

        status_list = []
        for epoch in range(self.max_epochs):
            pbar = tqdm(
                dataloader,
                desc=f"Train epoch [{epoch + 1}/{self.max_epochs}]",
                disable=not self.strategy.is_rank_0(),
            )
            for experience in pbar:
                experience.to_device(device)
                status = self.training_step(experience, global_steps)

                # For DP: reduce every step's metrics across ranks. KL is pre-multiplied by
                # response_length so that, assuming all_reduce averages, the result is a
                # length-weighted mean.
                has_kl = "kl" in status
                if has_kl:
                    status["kl"] *= status["response_length"]
                status = self.strategy.all_reduce(status)
                if has_kl:
                    status["kl"] /= status["response_length"]

                short_status = {
                    label: status[key] for label, key in self._SHORT_STATUS_KEYS if key in status
                }
                status_list.append(status)
                pbar.set_postfix(short_status)

        # Per-key mean; robust to steps that report different key sets.
        totals: Dict[str, float] = {}
        counts: Dict[str, int] = {}
        for step_status in status_list:
            for k, v in step_status.items():
                totals[k] = totals.get(k, 0.0) + v
                counts[k] = counts.get(k, 0) + 1
        status_mean = {k: totals[k] / counts[k] for k in totals}

        torch.cuda.empty_cache()
        return status_mean

    def training_step(
        self, experience: Experience, global_steps, entropy_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, float]:
        """
        Single training step combining actor and critic updates.

        The actor update is skipped while ``global_steps <= self.freezing_actor_steps``; the
        critic update runs whenever a critic is configured.

        :param experience: Experience batch from replay buffer.
        :type experience: Experience
        :param global_steps: Current global step count.
        :type global_steps: int
        :param entropy_mask: Optional token mask forwarded to the policy loss.
        :type entropy_mask: Optional[torch.Tensor]
        :return: Dictionary of training statistics.
        :rtype: Dict[str, float]
        """
        status = {}
        if global_steps > self.freezing_actor_steps:
            status = self.training_step_actor(experience, entropy_mask=entropy_mask)
        if self.critic is not None:
            status.update(self.training_step_critic(experience))
        return status

    @staticmethod
    def _concat_packed(tensors: List[torch.Tensor]) -> torch.Tensor:
        """Concatenate per-sample tensors along dim 0 and add a leading batch dim of size 1."""
        return torch.cat(tensors, dim=0).unsqueeze(0)

    def _model_inputs(self, experience: Experience):
        """
        Build the forward-pass inputs shared by the actor and critic steps.

        When ``experience.sequences`` is a list (packed samples), the per-sample tensors are
        concatenated into one row, and the attention mask holds the 1-based sample index for
        every token of that sample. Otherwise the batch tensors are used unchanged.

        :param experience: Experience batch from replay buffer.
        :type experience: Experience
        :return: ``(sequences, num_actions, packed_seq_lens, attention_mask)``.
        :rtype: tuple
        """
        # TODO: This is a bad indicator to say that data is packed... not supported
        if isinstance(experience.sequences, list):
            sequences = self._concat_packed(experience.sequences)
            num_actions = [v.numel() for v in experience.advantages]
            packed_seq_lens = [s.numel() for s in experience.sequences]
            attention_mask = self._concat_packed(
                [torch.full_like(s, i + 1) for i, s in enumerate(experience.sequences)]
            )
        else:
            sequences = experience.sequences
            num_actions = experience.action_mask.size(1)
            packed_seq_lens = None
            attention_mask = experience.attention_mask
        return sequences, num_actions, packed_seq_lens, attention_mask

    def _ptx_backward(self) -> torch.Tensor:
        """
        Run one supervised (PTX) forward/backward pass on the next pre-training batch.

        The backpropagated loss is ``ptx_coef * (ptx_loss + aux_loss * aux_loss_coef)``.

        :return: The PTX loss before ``ptx_coef`` scaling and the auxiliary term.
        :rtype: torch.Tensor
        """
        data = next(self.pretrain_dataloader)
        device = torch.cuda.current_device()
        # NOTE(review): batches are indexed as data[1] = input ids, data[2] = attention mask — confirm
        # against the pre-training dataset's collate function.
        inputs = data[1].squeeze(1).to(device)
        attention_mask = data[2].squeeze(1).to(device)
        # Padding positions are ignored by the LM loss.
        label = torch.where(
            attention_mask.bool(),
            inputs,
            self.ptx_loss_fn.IGNORE_INDEX,
        )

        output = self.actor(inputs, attention_mask=attention_mask, return_output=True)
        ptx_logits = output["logits"]

        # Loss function
        ptx_loss = self.ptx_loss_fn(ptx_logits, label)
        # Mixtral auxiliary loss
        aux_loss = output.aux_loss if self.aux_loss else 0
        loss = ptx_loss + aux_loss * self.args.aux_loss_coef
        self.strategy.backward(self.ptx_coef * loss, self.actor, self.actor_optim)
        return ptx_loss

    def training_step_actor(
        self, experience: Experience, entropy_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, float]:
        """
        Actor training step: policy loss (+ optional KL and auxiliary losses), optional PTX loss,
        optimizer step and optional EMA update.

        :param experience: Experience batch from replay buffer.
        :type experience: Experience
        :param entropy_mask: Optional token mask forwarded to the policy loss.
        :type entropy_mask: Optional[torch.Tensor]
        :return: Dictionary of actor training statistics.
        :rtype: Dict[str, float]
        """
        self.actor.train()

        sequences, num_actions, packed_seq_lens, attention_mask = self._model_inputs(experience)
        packed = isinstance(experience.sequences, list)
        if packed:
            old_action_log_probs = self._concat_packed(experience.action_log_probs)
            advantages = self._concat_packed(experience.advantages)
        else:
            old_action_log_probs = experience.action_log_probs
            advantages = experience.advantages

        # Always bound, so the KL branch below can never hit an UnboundLocalError.
        base_action_log_probs = None
        if self.args.use_kl_loss and experience.base_action_log_probs is not None:
            if packed:
                base_action_log_probs = self._concat_packed(experience.base_action_log_probs)
            else:
                base_action_log_probs = experience.base_action_log_probs

        # Actor loss
        action_log_probs, output = self.actor(
            sequences,
            num_actions,
            attention_mask=attention_mask,
            return_output=True,
            packed_seq_lens=packed_seq_lens,
        )

        # Loss function
        actor_loss = self.actor_loss_fn(
            action_log_probs,
            old_action_log_probs,
            advantages,
            action_mask=experience.action_mask,
            entropy_mask=entropy_mask,
        )

        if self.args.use_kl_loss:
            if self.initial_model is not None and base_action_log_probs is not None:
                kl = compute_approx_kl(
                    action_log_probs,
                    base_action_log_probs,
                    experience.action_mask,
                    kl_estimator=self.args.kl_estimator,
                )
            else:
                # No reference log-probs available: the KL term contributes zero.
                kl = torch.zeros_like(action_log_probs)

            if not self.args.packing_samples:
                kl_mean = masked_mean(kl, experience.action_mask, dim=-1)
            else:
                # Split the packed row back into per-sample KL tensors.
                kl = unpacking_samples(kl, num_actions)
                # torch.stack (unlike torch.tensor) keeps the autograd graph, so the KL loss
                # still produces gradients for packed samples.
                kl_mean = torch.stack([each_kl.mean() for each_kl in kl])

            kl_loss = kl_mean.mean()
            experience.info["kl"] = kl_loss.item()
        else:
            kl_loss = 0

        # Mixtral auxiliary loss
        aux_loss = output.aux_loss if self.aux_loss else 0
        loss = actor_loss + aux_loss * self.args.aux_loss_coef + kl_loss * self.kl_ctl.value
        self.strategy.backward(loss, self.actor, self.actor_optim)

        # PTX loss (gradients accumulate with the policy loss before the optimizer step)
        ptx_loss = None
        if self.pretrain_dataloader is not None:
            ptx_loss = self._ptx_backward()

        self.strategy.optimizer_step(self.actor_optim, self.actor, self.actor_scheduler, name="actor")
        if self.ema_model is not None:
            self.strategy.moving_average(self.actor, self.ema_model, self.ema_beta, "cuda")

        # Status
        status = {"policy_loss": actor_loss.item(), "actor_lr": self.actor_scheduler.get_last_lr()[0]}
        if ptx_loss is not None:
            status["ptx_loss"] = ptx_loss.item()

        # Add ratio and loss component statistics from PolicyLoss for diagnosis
        if hasattr(self.actor_loss_fn, 'get_last_stats'):
            status.update(self.actor_loss_fn.get_last_stats())

        for k, v in experience.info.items():
            if k == "kl":
                # Mean weighted by each sample's response_length.
                response_length = experience.info["response_length"]
                status[k] = ((v * response_length).sum() / response_length.sum()).item()
            else:
                status[k] = v.mean().item()
        return status

    def training_step_critic(self, experience: Experience) -> Dict[str, float]:
        """
        Critic training step: clipped value loss (+ optional auxiliary loss) and optimizer step.

        :param experience: Experience batch from replay buffer.
        :type experience: Experience
        :return: Dictionary of critic training statistics.
        :rtype: Dict[str, float]
        """
        self.critic.train()

        sequences, num_actions, packed_seq_lens, attention_mask = self._model_inputs(experience)
        if isinstance(experience.sequences, list):
            old_values = self._concat_packed(experience.values)
            returns = self._concat_packed(experience.returns)
        else:
            old_values = experience.values
            returns = experience.returns

        # Critic loss
        values, output = self.critic(
            sequences,
            num_actions=num_actions,
            attention_mask=attention_mask,
            return_output=True,
            packed_seq_lens=packed_seq_lens,
        )
        # Loss function
        critic_loss = self.critic_loss_fn(
            values,
            old_values,
            returns,
            action_mask=experience.action_mask,
        )
        # Mixtral auxiliary loss
        aux_loss = output.aux_loss if self.aux_loss else 0
        loss = critic_loss + aux_loss * self.args.aux_loss_coef
        self.strategy.backward(loss, self.critic, self.critic_optim)
        self.strategy.optimizer_step(self.critic_optim, self.critic, self.critic_scheduler, name="critic")

        # Status
        status = {
            "critic_loss": critic_loss.item(),
            "values": masked_mean(values, experience.action_mask).item(),
            "critic_lr": self.critic_scheduler.get_last_lr()[0],
        }
        return status

    @classmethod
    def _split_metrics(cls, logs_dict):
        """
        Split ``logs_dict`` into rollout and training metrics.

        Keys starting with ``rollout_`` go to the rollout dict with that prefix removed; keys in
        ``_ROLLOUT_ONLY_METRICS`` or starting with a ``_ROLLOUT_ONLY_METRIC_PREFIXES`` entry are
        dropped; everything else goes to the training dict.

        :param logs_dict: Metrics to split.
        :type logs_dict: dict
        :return: ``(rollout_metrics, train_metrics)``.
        :rtype: tuple
        """
        rollout_prefix = "rollout_"
        rollout_metrics = {}
        train_metrics = {}
        for k, v in logs_dict.items():
            if k.startswith(rollout_prefix):
                rollout_metrics[k[len(rollout_prefix):]] = v
            elif k in cls._ROLLOUT_ONLY_METRICS or k.startswith(cls._ROLLOUT_ONLY_METRIC_PREFIXES):
                # Already logged in the rollout/ namespace
                continue
            else:
                train_metrics[k] = v
        return rollout_metrics, train_metrics

    def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict=None, client_states=None):
        """
        Save logs to wandb/tensorboard and save model checkpoints.

        Logging happens every ``args.logging_steps`` steps, checkpointing every
        ``args.save_steps`` steps.

        :param args: Training arguments.
        :type args: Namespace
        :param global_step: Current global step.
        :type global_step: int
        :param step_bar: Progress bar object (currently unused by this method).
        :type step_bar: tqdm
        :param logs_dict: Dictionary of metrics to log, defaults to an empty dict.
        :type logs_dict: dict, optional
        :param client_states: Client state for checkpoint recovery, defaults to an empty dict.
        :type client_states: dict, optional
        """
        logs_dict = {} if logs_dict is None else logs_dict
        client_states = {} if client_states is None else client_states

        if global_step % args.logging_steps == 0:
            # Separate rollout and training metrics for clarity
            rollout_metrics, train_metrics = self._split_metrics(logs_dict)

            # Wandb logging
            if self._wandb is not None and self.strategy.is_rank_0():
                # Log rollout metrics with rollout/ prefix
                if rollout_metrics:
                    rollout_logs = {f"rollout/{k}": v for k, v in rollout_metrics.items()}
                    rollout_logs["rollout/global_step"] = global_step
                    self._wandb.log(rollout_logs)

                # Log training metrics with train/ prefix
                if train_metrics:
                    train_logs = {f"train/{k}": v for k, v in train_metrics.items()}
                    train_logs["train/global_step"] = global_step
                    self._wandb.log(train_logs)

                # Log performance stats
                if self.experience_maker.perf_stats is not None:
                    perf_logs = {
                        f"perf/experience_maker/{k}": v for k, v in self.experience_maker.perf_stats.items()
                    }
                    self._wandb.log(perf_logs)

            # TensorBoard logging
            elif self._tensorboard is not None and self.strategy.is_rank_0():
                for k, v in rollout_metrics.items():
                    self._tensorboard.add_scalar(f"rollout/{k}", v, global_step)
                for k, v in train_metrics.items():
                    self._tensorboard.add_scalar(f"train/{k}", v, global_step)
                if self.experience_maker.perf_stats is not None:
                    for k, v in self.experience_maker.perf_stats.items():
                        self._tensorboard.add_scalar(f"perf/experience_maker/{k}", v, global_step)

        # TODO: Add evaluation mechanism for PPO (run every ``args.eval_steps`` steps).

        # Save checkpoint
        # TODO: Save best model on dev, use loss/perplexity/others on whole dev dataset as metric
        if global_step % args.save_steps == 0:
            tag = f"global_step{global_step}"
            self._save_checkpoint(args, tag, client_states)

    def _save_checkpoint(self, args, tag, client_states):
        """
        Save model checkpoint to disk.

        Writes DeepSpeed-format checkpoints for the actor (with ``client_states``) and, if present,
        the critic unless ``disable_ds_ckpt`` is set. Also writes a HuggingFace-format actor to
        ``<ckpt_path>/<tag>_hf`` when ``save_hf_ckpt`` is set.

        :param args: Training arguments.
        :type args: Namespace
        :param tag: Checkpoint tag (e.g., "global_step1000").
        :type tag: str
        :param client_states: Client state for checkpoint recovery.
        :type client_states: dict
        """
        if not self.disable_ds_ckpt:
            self.strategy.save_ckpt(
                self.actor.model,
                os.path.join(args.ckpt_path, "_actor"),
                tag,
                args.max_ckpt_num,
                args.max_ckpt_mem,
                client_states,
            )
            if self.critic is not None:
                self.strategy.save_ckpt(
                    self.critic, os.path.join(args.ckpt_path, "_critic"), tag, args.max_ckpt_num, args.max_ckpt_mem
                )

        if self.save_hf_ckpt:
            # Get current filename and line number for debugging
            current_filename = os.path.basename(__file__)
            current_lineno = sys._getframe().f_lineno
            self.strategy.print(f"[{current_filename}:{current_lineno}] self.save_hf_ckpt: {self.save_hf_ckpt}")
            save_path = os.path.join(args.ckpt_path, f"{tag}_hf")
            self.strategy.save_model(self.actor, self.tokenizer, save_path)