"""Source code for lightrft.trainer.ppo_trainer."""
import os
import sys
import os.path
from abc import ABC
from typing import Any, Callable, Dict, List, Optional, Union
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm
from lightrft.models import ActorLanguage, GPTLMLoss, PolicyLoss, ValueLoss
from lightrft.models.utils import masked_mean, unpacking_samples, compute_approx_kl
from lightrft.utils.distributed_sampler import DistributedSampler
from lightrft.trainer import AdaptiveKLController, Experience, FixedKLController, NaiveExperienceMaker, NaiveReplayBuffer # noqa
class PPOTrainer(ABC):
    """
    Trainer for Proximal Policy Optimization (PPO) algorithm.

    Orchestrates the RLHF loop: rollout generation via an experience maker,
    replay-buffer storage, and alternating actor/critic optimization steps,
    with optional KL control, PTX (pre-training) loss, and EMA weights.

    :param strategy: The training strategy to use.
    :type strategy: Strategy
    :param actor: The actor model in the PPO algorithm.
    :type actor: ActorLanguage
    :param critic: The critic model in the PPO algorithm.
    :type critic: nn.Module
    :param reward_model: The reward model for calculating rewards in the RLHF setup.
    :type reward_model: nn.Module
    :param initial_model: The initial model for reference logits to limit actor updates in RLHF.
    :type initial_model: ActorLanguage
    :param ema_model: The exponential moving average model for stable training.
    :type ema_model: ActorLanguage
    :param actor_optim: The optimizer for the actor model.
    :type actor_optim: Optimizer
    :param critic_optim: The optimizer for the critic model.
    :type critic_optim: Optimizer
    :param actor_scheduler: The learning rate scheduler for the actor.
    :type actor_scheduler: Scheduler
    :param critic_scheduler: The learning rate scheduler for the critic.
    :type critic_scheduler: Scheduler
    :param ema_beta: EMA decay rate for model stability, defaults to 0.992.
    :type ema_beta: float
    :param init_kl_coef: Initial coefficient for KL divergence, defaults to 0.001.
    :type init_kl_coef: float
    :param kl_target: Target value for KL divergence, defaults to None.
    :type kl_target: float, optional
    :param kl_horizon: Horizon for KL annealing, defaults to 10000.
    :type kl_horizon: int
    :param ptx_coef: Coefficient for supervised loss from pre-trained data, defaults to 0.
    :type ptx_coef: float
    :param micro_train_batch_size: Micro-batch size for actor training, defaults to 8.
    :type micro_train_batch_size: int
    :param buffer_limit: Maximum size of the replay buffer, defaults to 0.
    :type buffer_limit: int
    :param buffer_cpu_offload: If True, offloads replay buffer to CPU, defaults to True.
    :type buffer_cpu_offload: bool
    :param eps_clip: Clipping coefficient for policy loss, defaults to 0.2.
    :type eps_clip: float
    :param value_clip: Clipping coefficient for value function loss, defaults to 0.2.
    :type value_clip: float
    :param micro_rollout_batch_size: Micro-batch size for generating rollouts, defaults to 8.
    :type micro_rollout_batch_size: int
    :param gradient_checkpointing: If True, enables gradient checkpointing, defaults to False.
    :type gradient_checkpointing: bool
    :param max_epochs: Number of epochs to train, defaults to 1.
    :type max_epochs: int
    :param max_norm: Maximum gradient norm for gradient clipping, defaults to 1.0.
    :type max_norm: float
    :param tokenizer: Tokenizer for input data, defaults to None.
    :type tokenizer: Callable, optional
    :param prompt_max_len: Maximum length for prompts, defaults to 128.
    :type prompt_max_len: int
    :param dataloader_pin_memory: If True, pins memory in the data loader, defaults to True.
    :type dataloader_pin_memory: bool
    :param remote_rm_url: URL for remote reward model API, defaults to None.
    :type remote_rm_url: str, optional
    :param reward_fn: Custom reward function for computing rewards, defaults to None.
    :type reward_fn: Callable, optional
    :param save_hf_ckpt: Whether to save huggingface-format model weight, defaults to False.
    :type save_hf_ckpt: bool
    :param disable_ds_ckpt: Whether not to save deepspeed-format model weight (used for training recovery).
    :type disable_ds_ckpt: bool
    :param generate_kwargs: Additional arguments for model generation.
    :type generate_kwargs: dict
    """
def __init__(
    self,
    strategy,
    actor: ActorLanguage,
    critic: nn.Module,
    reward_model: Union[nn.Module, List[nn.Module]],
    initial_model: ActorLanguage,
    ema_model: ActorLanguage,
    actor_optim: Optimizer,
    critic_optim: Optimizer,
    actor_scheduler,
    critic_scheduler,
    ema_beta: float = 0.992,
    init_kl_coef: float = 0.001,
    kl_target: Optional[float] = None,
    kl_horizon: int = 10000,
    ptx_coef: float = 0,
    micro_train_batch_size: int = 8,
    buffer_limit: int = 0,
    buffer_cpu_offload: bool = True,
    eps_clip: float = 0.2,
    value_clip: float = 0.2,
    micro_rollout_batch_size: int = 8,
    gradient_checkpointing: bool = False,
    max_epochs: int = 1,
    max_norm: float = 1.0,
    tokenizer: Optional[Callable[[Any], dict]] = None,
    prompt_max_len: int = 128,
    dataloader_pin_memory: bool = True,
    remote_rm_url: Optional[str] = None,
    reward_fn: Optional[Callable[[List[torch.Tensor]], torch.Tensor]] = None,
    save_hf_ckpt: bool = False,
    disable_ds_ckpt: bool = False,
    **generate_kwargs,
) -> None:
    """Build the PPO trainer: losses, KL controller, experience maker, replay buffer, loggers."""
    assert (
        not isinstance(reward_model, list) or len(reward_model) == 1 or reward_fn is not None
    ), "reward_fn must be specified if using multiple reward models"
    ABC.__init__(self)
    # FIX: `self.strategy` must be assigned before it is used. The original
    # code called `self.strategy.print(...)` first, which raised
    # AttributeError on every construction.
    self.strategy = strategy
    # Get current filename and line number for debugging
    current_filename = os.path.basename(__file__)
    current_lineno = sys._getframe().f_lineno
    self.strategy.print(f"[{current_filename}:{current_lineno}]")
    self.args = strategy.args
    self.save_hf_ckpt = save_hf_ckpt
    self.disable_ds_ckpt = disable_ds_ckpt
    self.micro_rollout_batch_size = micro_rollout_batch_size
    self.max_epochs = max_epochs
    self.tokenizer = tokenizer
    self.generate_kwargs = generate_kwargs
    self.dataloader_pin_memory = dataloader_pin_memory
    self.max_norm = max_norm
    self.ptx_coef = ptx_coef
    self.micro_train_batch_size = micro_train_batch_size
    self.kl_target = kl_target
    self.prompt_max_len = prompt_max_len
    self.ema_beta = ema_beta
    self.gradient_checkpointing = gradient_checkpointing
    self.reward_fn = reward_fn
    self.actor = actor
    self.critic = critic
    self.reward_model = reward_model
    self.remote_rm_url = remote_rm_url
    self.initial_model = initial_model
    self.ema_model = ema_model
    self.actor_optim = actor_optim
    self.critic_optim = critic_optim
    self.actor_scheduler = actor_scheduler
    self.critic_scheduler = critic_scheduler
    self.actor_loss_fn = PolicyLoss(eps_clip)
    self.critic_loss_fn = ValueLoss(value_clip)
    self.ptx_loss_fn = GPTLMLoss()
    # Actor updates are skipped until this many global steps have elapsed.
    self.freezing_actor_steps = getattr(self.args, "freezing_actor_steps", -1)
    # Mixtral 8x7b auxiliary loss (enabled when the coefficient is non-negligible)
    self.aux_loss = self.args.aux_loss_coef > 1e-8
    # Adaptive KL controller when a target is given; otherwise fixed coefficient.
    if self.kl_target:
        self.kl_ctl = AdaptiveKLController(init_kl_coef, kl_target, kl_horizon)
    else:
        self.kl_ctl = FixedKLController(init_kl_coef)
    self.experience_maker = NaiveExperienceMaker(
        actor,
        critic,
        reward_model,
        initial_model,
        tokenizer,
        prompt_max_len,
        self.kl_ctl,
        strategy,
        remote_rm_url,
        reward_fn,
    )
    packing_samples = getattr(self.args, "packing_samples", False)
    self.replay_buffer = NaiveReplayBuffer(
        micro_train_batch_size, buffer_limit, buffer_cpu_offload, packing_samples
    )
    # Initialize wandb/tensorboard for logging (rank 0 only)
    self._wandb = None
    self._tensorboard = None
    if self.strategy.args.use_wandb and self.strategy.is_rank_0():
        import wandb
        self._wandb = wandb
        # `use_wandb` doubles as the API key when no cached login exists.
        if not wandb.api.api_key:
            wandb.login(key=strategy.args.use_wandb)
        wandb.init(
            entity=strategy.args.wandb_org,
            project=strategy.args.wandb_project,
            group=strategy.args.wandb_group,
            name=strategy.args.wandb_run_name,
            config=strategy.args.__dict__,
            reinit=True,
        )
        # Define separate metric namespaces for clarity:
        # - rollout/*: Metrics from experience generation phase
        # - train/*: Metrics from policy optimization phase
        # - eval/*: Metrics from evaluation phase
        wandb.define_metric("rollout/global_step")
        wandb.define_metric("rollout/*", step_metric="rollout/global_step", step_sync=True)
        wandb.define_metric("train/global_step")
        wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True)
        wandb.define_metric("eval/epoch")
        wandb.define_metric("eval/*", step_metric="eval/epoch", step_sync=True)
    # Initialize TensorBoard writer only if wandb is not in use
    # (`use_tensorboard` holds the log root directory).
    if self.strategy.args.use_tensorboard and self._wandb is None and self.strategy.is_rank_0():
        from torch.utils.tensorboard import SummaryWriter
        os.makedirs(self.strategy.args.use_tensorboard, exist_ok=True)
        log_dir = os.path.join(self.strategy.args.use_tensorboard, strategy.args.wandb_run_name)
        self._tensorboard = SummaryWriter(log_dir=log_dir)
def fit(
    self,
    args,
    prompts_dataloader,
    pretrain_dataloader,
    consumed_samples=0,
    num_update_steps_per_episodes=1,
) -> None:
    """
    Main training loop for PPO.

    :param args: Training arguments.
    :type args: Namespace
    :param prompts_dataloader: DataLoader for prompt data.
    :type prompts_dataloader: DataLoader
    :param pretrain_dataloader: DataLoader for pre-training data.
    :type pretrain_dataloader: DataLoader
    :param consumed_samples: Number of samples already consumed, defaults to 0.
    :type consumed_samples: int
    :param num_update_steps_per_episodes: Number of update steps per episode, defaults to 1.
    :type num_update_steps_per_episodes: int
    """
    # Number of rollout batches processed per episode (used to map
    # consumed_samples back onto (episode, step) for checkpoint recovery).
    num_rollouts_per_episodes = (
        num_update_steps_per_episodes * args.train_batch_size // args.max_epochs // args.rollout_batch_size //
        args.n_samples_per_prompt
    )
    # Get eval and save steps
    if args.eval_steps == -1:
        args.eval_steps = num_rollouts_per_episodes  # Evaluate once per epoch
    if args.save_steps == -1:
        args.save_steps = float("inf")  # Do not save checkpoint
    self.prompts_dataloader = prompts_dataloader
    self.pretrain_dataloader = pretrain_dataloader
    # Restore step and start_episode from previously consumed samples
    steps = consumed_samples // args.rollout_batch_size + 1
    start_episode = consumed_samples // args.rollout_batch_size // num_rollouts_per_episodes
    consumed_samples = consumed_samples % (num_rollouts_per_episodes * args.rollout_batch_size)
    for episode in range(start_episode, args.num_episodes):
        if isinstance(self.prompts_dataloader.sampler, DistributedSampler):
            # Skip already-consumed samples only in the resumed episode.
            self.prompts_dataloader.sampler.set_epoch(
                episode, consumed_samples=0 if episode > start_episode else consumed_samples
            )
        pbar = tqdm(
            # FIX(idiom): use len(...) rather than calling __len__() directly.
            range(len(self.prompts_dataloader)),
            desc=f"Episode [{episode + 1}/{args.num_episodes}]",
            disable=not self.strategy.is_rank_0(),
        )
        for rand_prompts, labels in self.prompts_dataloader:
            for i, experience in enumerate(
                self.experience_maker.make_experience_list(rand_prompts, all_labels=labels, **self.generate_kwargs)
            ):
                if i == 0:
                    # Decode one sample for a qualitative sanity check in the logs.
                    output = self.tokenizer.batch_decode(
                        experience.sequences[0].unsqueeze(0), skip_special_tokens=True
                    )
                    self.strategy.print(output)
                self.replay_buffer.append(experience)
            self.strategy.report_memory('after replay_buffer ready')
            # group_norm already normalizes advantages per group.
            if self.args.advantage_estimator != "group_norm":
                self.replay_buffer.normalize("advantages", self.strategy)
            self.strategy.report_memory('before train')
            status = self.ppo_train(steps)
            self.strategy.report_memory('before clear buffer')
            self.replay_buffer.clear()
            self.strategy.report_memory('after train')
            if "kl" in status:
                self.kl_ctl.update(status["kl"], args.rollout_batch_size * args.n_samples_per_prompt)
            pbar.set_postfix(status)
            # Logs/checkpoints
            client_states = {"consumed_samples": steps * args.rollout_batch_size}
            self.save_logs_and_checkpoints(args, steps, pbar, status, client_states)
            pbar.update()
            steps = steps + 1
    if self._wandb is not None and self.strategy.is_rank_0():
        self._wandb.finish()
    if self._tensorboard is not None and self.strategy.is_rank_0():
        self._tensorboard.close()
def ppo_train(self, global_steps=0):
    """
    PPO training loop over the replay buffer.

    :param global_steps: Current global step count, defaults to 0.
    :type global_steps: int
    :return: Dictionary of averaged training statistics.
    :rtype: dict
    """
    torch.cuda.empty_cache()
    # Replay buffer may be empty at first, we should rebuild at each training
    dataloader = DataLoader(
        self.replay_buffer,
        batch_size=self.replay_buffer.sample_batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=self.dataloader_pin_memory,
        collate_fn=self.replay_buffer.collate_fn,
    )
    device = torch.cuda.current_device()
    status_list = []
    status_mean = {}
    for epoch in range(self.max_epochs):
        pbar = tqdm(
            dataloader,
            desc=f"Train epoch [{epoch + 1}/{self.max_epochs}]",
            disable=not self.strategy.is_rank_0(),
        )
        for experience in pbar:
            experience.to_device(device)
            status = self.training_step(experience, global_steps)
            # For DP: weighted mean for KL (weight by response length across ranks)
            if "kl" in status:
                status["kl"] *= status["response_length"]
                status = self.strategy.all_reduce(status)
                status["kl"] /= status["response_length"]
            short_status = {}
            if "policy_loss" in status:
                short_status = {
                    "pg": status["policy_loss"],
                    "rm": status["reward"],
                    "ret": status["return"],
                    "glen": status["response_length"],
                    "tlen": status["total_length"],
                    "kl": status["kl"],
                    "act_lr": status["actor_lr"],
                }
            if "critic_loss" in status:
                short_status["cri"] = status["critic_loss"]
                short_status["vals"] = status["values"]
                short_status["cri_lr"] = status["critic_lr"]
            if "ptx_loss" in status:
                short_status["ptx"] = status["ptx_loss"]
            status_list.append(status)
            pbar.set_postfix(short_status)
    if status_list:
        # FIX: accumulate into a fresh dict. The original aliased
        # status_list[0] and mutated it in place while averaging.
        status_mean = dict(status_list[0])
        for m in status_list[1:]:
            for k, v in m.items():
                status_mean[k] += v
        for k in status_mean.keys():
            status_mean[k] /= len(status_list)
    torch.cuda.empty_cache()
    return status_mean
def training_step(self,
                  experience: Experience,
                  global_steps,
                  entropy_mask: Optional[torch.Tensor] = None) -> Dict[str, float]:
    """
    Single training step combining actor and critic updates.

    The actor is only updated once the freezing window has elapsed; while the
    actor is frozen, the critic is not updated either and an empty dict is
    returned.

    :param experience: Experience batch from replay buffer.
    :type experience: Experience
    :param global_steps: Current global step count.
    :type global_steps: int
    :param entropy_mask: Optional mask for high-entropy tokens.
    :type entropy_mask: Optional[torch.Tensor]
    :return: Dictionary of training statistics.
    :rtype: Dict[str, float]
    """
    # Guard clause: actor still frozen — nothing to train this step.
    if global_steps <= self.freezing_actor_steps:
        return {}
    stats = self.training_step_actor(experience, entropy_mask=entropy_mask)
    if self.critic is not None:
        stats.update(self.training_step_critic(experience))
    return stats
def training_step_actor(self,
                        experience: Experience,
                        entropy_mask: Optional[torch.Tensor] = None) -> Dict[str, float]:
    """
    Actor training step: policy loss (+ optional KL / aux / PTX losses),
    backward, optimizer step and EMA update.

    :param experience: Experience batch from replay buffer.
    :type experience: Experience
    :param entropy_mask: Optional mask selecting tokens for the policy loss.
    :type entropy_mask: Optional[torch.Tensor]
    :return: Dictionary of actor training statistics.
    :rtype: Dict[str, float]
    """
    self.actor.train()
    # FIX: initialize so the KL branch below can safely test availability.
    # The original only bound this name when experience.base_action_log_probs
    # was not None, and could hit a NameError when use_kl_loss was set with
    # an initial_model but no base log-probs on the experience.
    base_action_log_probs = None
    # TODO: This is a bad indicator to say that data is packed... not supported
    if isinstance(experience.sequences, list):
        # Packed samples: concatenate per-sample tensors into one row and
        # encode sample boundaries via per-sample attention-mask ids.
        sequences = torch.cat(experience.sequences, dim=0).unsqueeze(0)
        old_action_log_probs = torch.cat(experience.action_log_probs, dim=0).unsqueeze(0)
        advantages = torch.cat(experience.advantages, dim=0).unsqueeze(0)
        num_actions = [v.numel() for v in experience.advantages]
        packed_seq_lens = [s.numel() for s in experience.sequences]
        attention_mask = torch.cat([torch.full_like(s, i + 1) for i, s in enumerate(experience.sequences)],
                                   dim=0).unsqueeze(0)
        if self.args.use_kl_loss and experience.base_action_log_probs is not None:
            base_action_log_probs = torch.cat(experience.base_action_log_probs, dim=0).unsqueeze(0)
    else:
        sequences = experience.sequences
        old_action_log_probs = experience.action_log_probs
        advantages = experience.advantages
        num_actions = experience.action_mask.size(1)
        packed_seq_lens = None
        attention_mask = experience.attention_mask
        if self.args.use_kl_loss and experience.base_action_log_probs is not None:
            base_action_log_probs = experience.base_action_log_probs
    # Actor forward pass
    action_log_probs, output = self.actor(
        sequences,
        num_actions,
        attention_mask=attention_mask,
        return_output=True,
        packed_seq_lens=packed_seq_lens,
    )
    # Clipped-surrogate policy loss
    actor_loss = self.actor_loss_fn(
        action_log_probs,
        old_action_log_probs,
        advantages,
        action_mask=experience.action_mask,
        entropy_mask=entropy_mask,
    )
    if self.args.use_kl_loss:
        # KL penalty against the reference model; zero when no reference
        # log-probs are available.
        if self.initial_model is not None and base_action_log_probs is not None:
            kl = compute_approx_kl(
                action_log_probs,
                base_action_log_probs,
                experience.action_mask,
                kl_estimator=self.args.kl_estimator,
            )
        else:
            kl = torch.zeros_like(action_log_probs, dtype=action_log_probs.dtype, device=action_log_probs.device)
        if not self.args.packing_samples:
            kl_mean = masked_mean(kl, experience.action_mask, dim=-1)
            # Not supported for packed samples
        else:
            # Convert tensor into list of tensors for easier manipulation within dataset
            kl = unpacking_samples(kl, num_actions)
            kl_mean = torch.tensor([each_kl.mean() for each_kl in kl], device=action_log_probs.device)
        kl_loss = kl_mean.mean()
        experience.info["kl"] = kl_loss.item()
    else:
        kl_loss = 0
    # Mixtral auxiliary loss
    if self.aux_loss:
        aux_loss = output.aux_loss
    else:
        aux_loss = 0
    loss = actor_loss + aux_loss * self.args.aux_loss_coef + kl_loss * self.kl_ctl.value
    self.strategy.backward(loss, self.actor, self.actor_optim)
    # PTX loss: supervised LM loss on pre-training data to retain base abilities
    if self.pretrain_dataloader is not None:
        data = next(self.pretrain_dataloader)
        inputs = data[1].squeeze(1).to(torch.cuda.current_device())
        attention_mask = data[2].squeeze(1).to(torch.cuda.current_device())
        # Ignore padded positions in the LM loss.
        label = torch.where(
            attention_mask.bool(),
            inputs,
            self.ptx_loss_fn.IGNORE_INDEX,
        )
        output = self.actor(inputs, attention_mask=attention_mask, return_output=True)
        ptx_log_probs = output["logits"]
        # Loss function
        ptx_loss = self.ptx_loss_fn(ptx_log_probs, label)
        # Mixtral auxiliary loss
        if self.aux_loss:
            aux_loss = output.aux_loss
        else:
            aux_loss = 0
        loss = ptx_loss + aux_loss * self.args.aux_loss_coef
        self.strategy.backward(self.ptx_coef * loss, self.actor, self.actor_optim)
    self.strategy.optimizer_step(self.actor_optim, self.actor, self.actor_scheduler, name="actor")
    if self.ema_model:
        self.strategy.moving_average(self.actor, self.ema_model, self.ema_beta, "cuda")
    # Status
    status = {"policy_loss": actor_loss.item(), "actor_lr": self.actor_scheduler.get_last_lr()[0]}
    if self.pretrain_dataloader is not None:
        status["ptx_loss"] = ptx_loss.item()
    # Add ratio and loss component statistics from PolicyLoss for diagnosis
    if hasattr(self.actor_loss_fn, 'get_last_stats'):
        policy_stats = self.actor_loss_fn.get_last_stats()
        status.update(policy_stats)
    for k, v in experience.info.items():
        if k == "kl":
            # Length-weighted mean so longer responses count proportionally.
            status[k] = ((v * experience.info["response_length"]).sum() /
                         experience.info["response_length"].sum()).item()
        else:
            status[k] = v.mean().item()
    return status
def training_step_critic(self, experience: Experience) -> Dict[str, float]:
    """
    Critic training step: clipped value loss, backward and optimizer step.

    :param experience: Experience batch from replay buffer.
    :type experience: Experience
    :return: Dictionary of critic training statistics.
    :rtype: Dict[str, float]
    """
    self.critic.train()
    # TODO: This is a bad indicator to say that data is packed... not supported
    is_packed = isinstance(experience.sequences, list)
    if is_packed:
        # Flatten packed samples into a single row; sample boundaries are
        # encoded as distinct ids in the attention mask.
        seqs = torch.cat(experience.sequences, dim=0).unsqueeze(0)
        old_values = torch.cat(experience.values, dim=0).unsqueeze(0)
        returns = torch.cat(experience.returns, dim=0).unsqueeze(0)
        num_actions = [adv.numel() for adv in experience.advantages]
        packed_seq_lens = [seq.numel() for seq in experience.sequences]
        mask_parts = [torch.full_like(seq, idx + 1) for idx, seq in enumerate(experience.sequences)]
        attention_mask = torch.cat(mask_parts, dim=0).unsqueeze(0)
    else:
        seqs = experience.sequences
        old_values = experience.values
        returns = experience.returns
        num_actions = experience.action_mask.size(1)
        packed_seq_lens = None
        attention_mask = experience.attention_mask
    # Critic forward pass
    value_preds, output = self.critic(
        seqs,
        num_actions=num_actions,
        attention_mask=attention_mask,
        return_output=True,
        packed_seq_lens=packed_seq_lens,
    )
    # Clipped value loss against old values and returns
    critic_loss = self.critic_loss_fn(
        value_preds,
        old_values,
        returns,
        action_mask=experience.action_mask,
    )
    # Mixtral auxiliary loss
    aux_loss = output.aux_loss if self.aux_loss else 0
    total_loss = critic_loss + aux_loss * self.args.aux_loss_coef
    self.strategy.backward(total_loss, self.critic, self.critic_optim)
    self.strategy.optimizer_step(self.critic_optim, self.critic, self.critic_scheduler, name="critic")
    # Status
    return {
        "critic_loss": critic_loss.item(),
        "values": masked_mean(value_preds, experience.action_mask).item(),
        "critic_lr": self.critic_scheduler.get_last_lr()[0],
    }
def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict=None, client_states=None):
    """
    Save logs to wandb/tensorboard and save model checkpoints.

    :param args: Training arguments.
    :type args: Namespace
    :param global_step: Current global step.
    :type global_step: int
    :param step_bar: Progress bar object.
    :type step_bar: tqdm
    :param logs_dict: Dictionary of metrics to log, defaults to None (treated as empty).
    :type logs_dict: dict, optional
    :param client_states: Client state for checkpoint recovery, defaults to None (treated as empty).
    :type client_states: dict, optional
    """
    # FIX: avoid mutable default arguments ({}), a classic Python pitfall;
    # use None sentinels instead (backward compatible for all callers).
    logs_dict = logs_dict if logs_dict is not None else {}
    client_states = client_states if client_states is not None else {}
    if global_step % args.logging_steps == 0:
        # Define which metrics should be excluded from train/ logs to avoid duplication
        # These metrics are already logged in the rollout/ namespace
        ROLLOUT_ONLY_METRICS = {
            'reward',  # Already logged as rollout/reward
            'response_length',  # Already logged as rollout/response_length
            'total_length',  # Rollout-specific metric
            'num_actions',  # Rollout-specific metric
            'return',  # Rollout-specific metric (computed from rewards)
        }
        # Also exclude reward_metrics sub-keys (format_reward, accuracy_reward)
        ROLLOUT_ONLY_METRIC_PREFIXES = {'reward_metrics/'}
        # Separate rollout and training metrics for clarity
        rollout_metrics = {}
        train_metrics = {}
        for k, v in logs_dict.items():
            if k.startswith('rollout_'):
                # Remove 'rollout_' prefix and log under rollout/ namespace
                clean_key = k.replace('rollout_', '', 1)
                rollout_metrics[clean_key] = v
            elif k in ROLLOUT_ONLY_METRICS:
                # Skip metrics that are already in rollout/ namespace
                continue
            elif any(k.startswith(prefix) for prefix in ROLLOUT_ONLY_METRIC_PREFIXES):
                # Skip reward_metrics/* sub-keys
                continue
            else:
                # Training-specific metrics go under train/ namespace
                train_metrics[k] = v
        # Wandb logging
        if self._wandb is not None and self.strategy.is_rank_0():
            # Log rollout metrics with rollout/ prefix
            if rollout_metrics:
                rollout_logs = {f"rollout/{k}": v for k, v in rollout_metrics.items()}
                rollout_logs["rollout/global_step"] = global_step
                self._wandb.log(rollout_logs)
            # Log training metrics with train/ prefix
            if train_metrics:
                train_logs = {f"train/{k}": v for k, v in train_metrics.items()}
                train_logs["train/global_step"] = global_step
                self._wandb.log(train_logs)
            # Log performance stats
            if self.experience_maker.perf_stats is not None:
                perf_logs = {f"perf/experience_maker/{k}": v for k, v in self.experience_maker.perf_stats.items()}
                self._wandb.log(perf_logs)
        # TensorBoard logging
        elif self._tensorboard is not None and self.strategy.is_rank_0():
            for k, v in rollout_metrics.items():
                self._tensorboard.add_scalar(f"rollout/{k}", v, global_step)
            for k, v in train_metrics.items():
                self._tensorboard.add_scalar(f"train/{k}", v, global_step)
            if self.experience_maker.perf_stats is not None:
                for k, v in self.experience_maker.perf_stats.items():
                    self._tensorboard.add_scalar(f"perf/experience_maker/{k}", v, global_step)
    # TODO: Add evaluation mechanism for PPO
    if global_step % args.eval_steps == 0:
        # self.evaluate(self.eval_dataloader, global_step)
        pass
    # Save checkpoint
    # TODO: Save best model on dev, use loss/perplexity/others on whole dev dataset as metric
    if global_step % args.save_steps == 0:
        tag = f"global_step{global_step}"
        self._save_checkpoint(args, tag, client_states)
def _save_checkpoint(self, args, tag, client_states):
    """
    Save model checkpoint to disk.

    Writes a deepspeed-format checkpoint for actor (and critic, if present)
    unless disabled, and optionally a huggingface-format copy of the actor.

    :param args: Training arguments.
    :type args: Namespace
    :param tag: Checkpoint tag (e.g., "global_step1000").
    :type tag: str
    :param client_states: Client state for checkpoint recovery.
    :type client_states: dict
    """
    if not self.disable_ds_ckpt:
        actor_dir = os.path.join(args.ckpt_path, "_actor")
        self.strategy.save_ckpt(
            self.actor.model,
            actor_dir,
            tag,
            args.max_ckpt_num,
            args.max_ckpt_mem,
            client_states,
        )
        if self.critic is not None:
            critic_dir = os.path.join(args.ckpt_path, "_critic")
            self.strategy.save_ckpt(self.critic, critic_dir, tag, args.max_ckpt_num, args.max_ckpt_mem)
    if self.save_hf_ckpt:
        # Get current filename and line number for debugging
        current_filename = os.path.basename(__file__)
        current_lineno = sys._getframe().f_lineno
        self.strategy.print(f"[{current_filename}:{current_lineno}] self.save_hf_ckpt: {self.save_hf_ckpt}")
        save_path = os.path.join(args.ckpt_path, f"{tag}_hf")
        self.strategy.save_model(self.actor, self.tokenizer, save_path)