# Source code for lightrft.trainer.kl_controller
#
# KL-penalty coefficient controllers used during PPO training.

import numpy as np


class AdaptiveKLController:
    """
    Adaptive KL controller for PPO training.

    Implements the adaptive KL penalty coefficient adjustment described in
    "Fine-Tuning Language Models from Human Preferences"
    (https://arxiv.org/pdf/1909.08593.pdf).

    The coefficient is nudged up when the observed KL divergence exceeds the
    target and down when it falls below, keeping training stable while
    preventing the policy from drifting too far from the reference model.

    :param init_kl_coef: Initial KL penalty coefficient.
    :param target: Desired KL divergence value.
    :param horizon: Number of steps over which a full proportional
        correction would be applied.
    """

    def __init__(self, init_kl_coef: float, target: float, horizon: int) -> None:
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current: float, n_steps: int) -> None:
        """
        Update the KL coefficient using the adaptive algorithm.

        :param current: Current KL divergence value.
        :type current: float
        :param n_steps: Number of training steps taken since the last update.
        :type n_steps: int
        """
        # Relative deviation from the target, clipped to +/-20% so a single
        # update cannot swing the coefficient too violently.
        error = np.clip(current / self.target - 1, -0.2, 0.2)
        # Scale the correction by the fraction of the horizon covered.
        self.value *= 1 + error * n_steps / self.horizon
class FixedKLController:
    """
    Fixed KL controller that maintains a constant KL penalty coefficient.

    Unlike :class:`AdaptiveKLController`, this keeps the KL coefficient
    fixed throughout training, providing consistent regularization strength.

    :param kl_coef: The constant KL penalty coefficient.
    """

    def __init__(self, kl_coef: float) -> None:
        self.value = kl_coef

    def update(self, current: float, n_steps: int) -> None:
        """
        Update controller state — a no-op for the fixed controller.

        :param current: Current KL divergence value (unused).
        :type current: float
        :param n_steps: Number of training steps (unused).
        :type n_steps: int
        """
        # The coefficient never changes; both arguments are ignored.
        return None