lightrft.models.loss

Loss functions used across LightRFT models.

This module implements a comprehensive collection of loss functions for reinforcement learning from human feedback (RLHF) and related training paradigms:

Policy Optimization Losses:

  • PolicyLoss: Multi-purpose policy loss supporting PPO, CPGD (via use_cpg_loss), DAPO-style decoupled clipping, and high-entropy token filtering for efficient training.

  • ValueLoss: Value function loss for PPO with optional value clipping.

Reward Model Losses:

  • GPTLMLoss: Next-token prediction loss for generative reward model training.

  • LogSigmoidLoss: Log-sigmoid pairwise loss for scalar reward model training.

  • LogExpLoss: Log-exp pairwise loss for scalar reward model training.

  • HPSLoss: Human Preference Score loss for scalar reward model training.

  • PairWiseLoss: Generic pairwise preference loss for reward models.

  • PRMLoss: Process Reward Model loss for token-level reward prediction.

Preference Learning Losses:

  • DPOLoss: Direct Preference Optimization loss for aligning language models with preferences.

  • KTOLoss: Kahneman-Tversky Optimization loss for uneven sampling scenarios.

  • VanillaKTOLoss: Simplified KTO loss for even sampling scenarios.

Knowledge Distillation:

  • KDLoss: Knowledge Distillation loss for transferring knowledge from teacher to student models.

All loss functions are designed to work seamlessly with the LightRFT training framework, supporting distributed training, mixed precision, and various optimization strategies.

class lightrft.models.loss.DPOLoss(*args: Any, **kwargs: Any)[source]

Bases: Module

DPO Loss

forward(policy_chosen_logps: torch.Tensor, policy_rejected_logps: torch.Tensor, reference_chosen_logps: torch.Tensor, reference_rejected_logps: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor][source]

Compute DPO (Direct Preference Optimization) loss.

Parameters:
  • policy_chosen_logps (torch.Tensor) – Log probabilities under policy for chosen samples.

  • policy_rejected_logps (torch.Tensor) – Log probabilities under policy for rejected samples.

  • reference_chosen_logps (torch.Tensor) – Log probabilities under reference model for chosen samples.

  • reference_rejected_logps (torch.Tensor) – Log probabilities under reference model for rejected samples.

Returns:

Tuple of (loss, chosen_rewards, rejected_rewards).

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
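
To make the relationship between the four inputs and the three outputs concrete, here is a minimal scalar sketch of the standard DPO objective in plain Python. The temperature `beta` and the helper names are assumptions for illustration; this page does not list the constructor arguments of DPOLoss, and the library version operates on batched tensors and may include options not shown here.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             reference_chosen_logp, reference_rejected_logp,
             beta=0.1):
    """Scalar DPO loss for one preference pair.

    `beta` is a hypothetical temperature parameter (assumption).
    Returns (loss, chosen_reward, rejected_reward).
    """
    # How much more the policy prefers chosen over rejected ...
    pi_logratio = policy_chosen_logp - policy_rejected_logp
    # ... compared with how much the reference model does.
    ref_logratio = reference_chosen_logp - reference_rejected_logp
    loss = -log_sigmoid(beta * (pi_logratio - ref_logratio))
    # Implicit rewards: scaled log-ratio between policy and reference.
    chosen_reward = beta * (policy_chosen_logp - reference_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - reference_rejected_logp)
    return loss, chosen_reward, rejected_reward


loss, chosen_r, rejected_r = dpo_loss(-1.0, -2.0, -1.5, -1.5, beta=0.1)
print(loss, chosen_r, rejected_r)
```

When the policy and the reference assign identical log probabilities, the loss equals log(2) and both implicit rewards are zero, which is a convenient sanity check at the start of training.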

class lightrft.models.loss.GPTLMLoss(*args: Any, **kwargs: Any)[source]

Bases: Module

GPT Language Model loss for next-token prediction. Used for generative reward model training.

Variables:
  • IGNORE_INDEX (int) – Label index to ignore when computing the cross-entropy (default: -100), matching Hugging Face conventions.

  • loss (torch.nn.CrossEntropyLoss) – Underlying cross-entropy criterion configured to ignore IGNORE_INDEX.

forward(logits: torch.Tensor, labels: torch.Tensor) torch.Tensor[source]

Compute next-token prediction loss.

Uses the common shifting scheme: shift_logits = logits[..., :-1, :] and shift_labels = labels[..., 1:].

Parameters:
  • logits (torch.Tensor) – Model output logits.

  • labels (torch.Tensor) – Token ids aligned with logits. Tokens to be ignored should be set to IGNORE_INDEX (default -100).

Returns:

Scalar mean cross-entropy loss.

Return type:

torch.Tensor

Shape logits:

(..., seq_len, vocab_size)

Shape labels:

(..., seq_len)
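
The shifting scheme and the role of IGNORE_INDEX can be illustrated with a small dependency-free sketch for a single sequence. The function names are hypothetical; the library itself uses torch.nn.CrossEntropyLoss on tensors, so this only mirrors the arithmetic.

```python
import math

IGNORE_INDEX = -100  # same default as documented above


def log_softmax(row):
    """Stable log-softmax over one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]


def next_token_loss(logits, labels):
    """Mean cross-entropy for one sequence.

    logits: list of seq_len rows, each with vocab_size floats.
    labels: list of seq_len token ids (IGNORE_INDEX to skip a position).
    """
    shift_logits = logits[:-1]  # position t predicts ...
    shift_labels = labels[1:]   # ... the token at position t + 1
    nll = [
        -log_softmax(row)[target]
        for row, target in zip(shift_logits, shift_labels)
        if target != IGNORE_INDEX
    ]
    # With no supervised positions the mean is undefined.
    return sum(nll) / len(nll) if nll else float("nan")


logits = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
labels = [IGNORE_INDEX, 1, IGNORE_INDEX]
print(next_token_loss(logits, labels))
```

Note that the first label is never used as a target and the last row of logits never produces a prediction, which is why both are dropped by the shift.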

class lightrft.models.loss.HPSLoss(*args: Any, **kwargs: Any)[source]

Bases: Module

Human Preference Score (HPS) Loss for scalar reward model training. Implements the cross-entropy loss over the logits formed by concatenating the chosen and rejected rewards. The core idea is to treat preference prediction as a binary classification task.

Paper: https://arxiv.org/abs/2303.14420

forward(chosen_reward: torch.Tensor, reject_reward: torch.Tensor, margin: torch.Tensor | None = None) torch.Tensor[source]

Compute HPS loss.

Parameters:
  • chosen_reward (torch.Tensor) – Predicted reward for the preferred (chosen) sample.

  • reject_reward (torch.Tensor) – Predicted reward for the rejected sample.

  • margin (Optional[torch.Tensor]) – Unused; included for API compatibility with PairWiseLoss.

Returns:

Mean cross-entropy loss over the batch.

Return type:

torch.Tensor
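
As a rough illustration of the binary-classification view, the sketch below treats each pair (chosen, rejected) as two logits and computes the cross-entropy against the chosen class. That the target class is the chosen sample (index 0) is an assumption, and the function name is hypothetical.

```python
import math


def hps_loss(chosen_rewards, reject_rewards):
    """Mean 2-way cross-entropy with the chosen sample as the target class.

    Inputs are equal-length, non-empty lists of floats.
    """
    losses = []
    for c, r in zip(chosen_rewards, reject_rewards):
        m = max(c, r)
        # log of the softmax normaliser over the two logits [c, r]
        log_z = m + math.log(math.exp(c - m) + math.exp(r - m))
        losses.append(log_z - c)  # -log softmax([c, r])[0]
    return sum(losses) / len(losses)


print(hps_loss([2.0, 0.0], [0.0, 0.0]))
```

Under this assumption the per-pair loss simplifies algebraically to log(1 + exp(reject - chosen)), so the loss depends only on the reward difference.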

class lightrft.models.loss.KDLoss(*args: Any, **kwargs: Any)[source]

Bases: Module

Language Model Knowledge Distillation Loss

forward(logits: torch.Tensor, teacher_logits: torch.Tensor, label: torch.Tensor) torch.Tensor[source]

Compute knowledge distillation loss.

Parameters:
  • logits (torch.Tensor) – Student model logits.

  • teacher_logits (torch.Tensor) – Teacher model logits (detached).

  • label (torch.Tensor) – Ground truth labels (tokens to ignore set to IGNORE_INDEX).

Returns:

Scalar KD loss.

Return type:

torch.Tensor
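
This page does not state the exact formula used by KDLoss. One common token-level formulation, shown here purely as an assumption, is the cross-entropy between the teacher and student distributions, averaged over positions whose label is not IGNORE_INDEX. All names in the sketch are hypothetical.

```python
import math

IGNORE_INDEX = -100  # assumed to match the convention used by GPTLMLoss


def log_softmax(row):
    """Stable log-softmax over one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]


def kd_loss(student_logits, teacher_logits, labels):
    """Teacher-to-student cross-entropy, averaged over non-ignored tokens.

    student_logits, teacher_logits: lists of rows (one row per token).
    labels: one token id per row; IGNORE_INDEX marks positions to skip.
    """
    total, count = 0.0, 0
    for s_row, t_row, label in zip(student_logits, teacher_logits, labels):
        if label == IGNORE_INDEX:
            continue
        log_q = log_softmax(s_row)                      # student log-probs
        p = [math.exp(v) for v in log_softmax(t_row)]   # teacher probs
        total += -sum(pi * lq for pi, lq in zip(p, log_q))
        count += 1
    return total / count if count else 0.0


student = [[0.0, 0.0], [1.0, 2.0]]
teacher = [[0.0, 0.0], [5.0, 1.0]]
print(kd_loss(student, teacher, [3, IGNORE_INDEX]))
```

In this formulation the labels only decide which positions count; the training signal itself comes from the teacher distribution.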

class lightrft.models.loss.KTOLoss(*args: Any, **kwargs: Any)[source]

Bases: Module

KTO loss for uneven sampling

forward(policy_chosen_logps: torch.FloatTensor, policy_rejected_logps: torch.FloatTensor, policy_KL_logps: torch.FloatTensor, reference_chosen_logps: torch.FloatTensor, reference_rejected_logps: torch.FloatTensor, reference_KL_logps: torch.FloatTensor) Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor][source]

Compute KTO loss for unevenly sampled chosen/rejected pairs with distributed KL estimation.

Parameters:
  • policy_chosen_logps (torch.FloatTensor) – Log probabilities under policy for chosen samples.

  • policy_rejected_logps (torch.FloatTensor) – Log probabilities under policy for rejected samples.

  • policy_KL_logps (torch.FloatTensor) – Log probabilities under policy for KL estimation samples.

  • reference_chosen_logps (torch.FloatTensor) – Log probabilities under reference model for chosen samples.

  • reference_rejected_logps (torch.FloatTensor) – Log probabilities under reference model for rejected samples.

  • reference_KL_logps (torch.FloatTensor) – Log probabilities under reference model for KL estimation samples.

Returns:

Tuple of (losses, chosen_rewards, rejected_rewards, KL).

Return type:

Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]

class lightrft.models.loss.LogExpLoss(*args: Any, **kwargs: Any)[source]

Bases: Module

Log-exp (softplus) pairwise loss for scalar reward model training.

This loss corresponds to log(1 + exp(reject - chosen)) averaged over the batch. See: https://arxiv.org/abs/2204.05862

forward(chosen_reward: torch.Tensor, reject_reward: torch.Tensor, margin: torch.Tensor | None = None) torch.Tensor[source]

Compute log-exp pairwise loss.

Parameters:
  • chosen_reward (torch.Tensor) – Predicted reward for the preferred (chosen) sample.

  • reject_reward (torch.Tensor) – Predicted reward for the rejected sample.

  • margin (Optional[torch.Tensor]) – Unused; included for API compatibility with PairWiseLoss.

Returns:

Mean log(1 + exp(reject - chosen)) over the batch.

Return type:

torch.Tensor
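
The expression log(1 + exp(x)) overflows if evaluated naively for large x. The following sketch shows a numerically stable way to compute the same quantity in plain Python; the function names are hypothetical and the library version works on tensors.

```python
import math


def softplus(x: float) -> float:
    """Stable log(1 + exp(x)): never exponentiates a positive number."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def log_exp_loss(chosen_rewards, reject_rewards):
    """Mean of log(1 + exp(reject - chosen)) over non-empty lists of pairs."""
    losses = [softplus(r - c) for c, r in zip(chosen_rewards, reject_rewards)]
    return sum(losses) / len(losses)


print(log_exp_loss([1.0, 3.0], [1.0, 0.0]))
print(softplus(1000.0))  # finite, whereas math.exp(1000.0) would overflow
```

The loss approaches zero as the chosen reward exceeds the rejected reward by a wide gap and grows roughly linearly when the ordering is reversed.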

class lightrft.models.loss.LogSigmoidLoss(*args: Any, **kwargs: Any)[source]

Bases: Module

Pairwise preference loss for scalar reward models using the log-sigmoid objective.

Encourages the chosen sample to have a higher reward than the rejected sample. Optionally supports a non-negative margin.

forward(chosen_reward: torch.Tensor, reject_reward: torch.Tensor, margin: torch.Tensor | None = None) torch.Tensor[source]

Compute log-sigmoid pairwise loss.

Parameters:
  • chosen_reward (torch.Tensor) – Predicted reward for the preferred (chosen) sample.

  • reject_reward (torch.Tensor) – Predicted reward for the rejected sample.

  • margin (Optional[torch.Tensor]) – Optional non-negative margin. If provided, the objective becomes logsigmoid(chosen - reject - margin). Supports broadcasting across batch dimensions.

Returns:

Mean negative log-sigmoid loss over the batch.

Return type:

torch.Tensor
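
The effect of the optional margin can be seen in a short scalar sketch. It follows the objective stated above, negated and averaged; the handling of a scalar margin stands in for tensor broadcasting and, like the function names, is an assumption.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))


def log_sigmoid_loss(chosen_rewards, reject_rewards, margin=None):
    """Mean of -logsigmoid(chosen - reject - margin) over non-empty lists.

    margin may be None, a single number, or one value per pair.
    """
    n = len(chosen_rewards)
    if margin is None:
        margins = [0.0] * n
    elif isinstance(margin, (int, float)):
        margins = [float(margin)] * n  # broadcast a scalar margin
    else:
        margins = list(margin)
    losses = [
        -log_sigmoid(c - r - m)
        for c, r, m in zip(chosen_rewards, reject_rewards, margins)
    ]
    return sum(losses) / n


print(log_sigmoid_loss([1.0], [0.0]))
print(log_sigmoid_loss([1.0], [0.0], margin=1.0))
```

A positive margin raises the loss for a given reward gap, so the model is pushed to separate chosen and rejected rewards by at least that amount before the loss becomes small.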

class lightrft.models.loss.PRMLoss(*args: Any, **kwargs: Any)[source]

Bases: Module

Process Reward Model Loss

forward(inputs: torch.Tensor, logits: torch.Tensor, labels: torch.Tensor, *, return_acc: bool = False) torch.Tensor | Tuple[torch.Tensor, torch.Tensor][source]

Compute process reward model loss.

Parameters:
  • inputs (torch.Tensor) – Input token IDs (used to locate placeholder tokens).

  • logits (torch.Tensor) – Model output logits.

  • labels (torch.Tensor) – Target labels (hard or soft labels for reward tokens).

  • return_acc (bool) – If True, also return accuracy.

Returns:

Loss tensor or tuple of (loss, accuracy) if return_acc=True.

Return type:

Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

class lightrft.models.loss.PairWiseLoss(*args: Any, **kwargs: Any)[source]

Bases: Module

Pairwise Loss for Reward Model

forward(chosen_reward: torch.Tensor, reject_reward: torch.Tensor, margin: torch.Tensor = None) torch.Tensor[source]

Compute pairwise ranking loss.

Parameters:
  • chosen_reward (torch.Tensor) – Reward scores for chosen/preferred samples.

  • reject_reward (torch.Tensor) – Reward scores for rejected samples.

  • margin (Optional[torch.Tensor]) – Optional margin value to enforce separation.

Returns:

Mean negative log-sigmoid loss.

Return type:

torch.Tensor

class lightrft.models.loss.PolicyLoss(*args: Any, **kwargs: Any)[source]

Bases: Module

Multi-purpose policy loss function supporting multiple reinforcement learning algorithms.

This class implements a unified policy loss that can be configured for different policy optimization algorithms, currently PPO and CPGD, with optional high-entropy token filtering. The loss function computes clipped policy gradients and applies masking so that only selected tokens contribute to the average.

Supported Algorithms:

  • PPO (Proximal Policy Optimization): Default mode using standard clipped surrogate objective. The loss is computed as -min(ratio * advantages, clipped_ratio * advantages) where ratio = exp(log_probs - old_log_probs) and clipping is applied to prevent large policy updates.

  • Clipped Policy Gradient Optimization with Policy Drift (CPGD): Enabled via use_cpg_loss=True. Uses asymmetric clipping bounds for positive and negative advantages, providing better stability for constrained policy optimization. See: https://arxiv.org/abs/2505.12504

  • High-Entropy Token Filtering: Enabled via high_entropy_token_ratio > 0 or by providing an entropy_mask in the forward pass. This feature allows training only on high-entropy tokens (forking tokens that determine reasoning directions), significantly improving training efficiency. Based on: https://arxiv.org/abs/2506.01939

Parameters:
  • clip_eps (float) – Clipping epsilon for PPO-style policy updates. Determines the maximum allowed ratio between new and old policy probabilities. Typical values range from 0.1 to 0.3. Default: 0.2

  • use_dapo (bool) – Flag for DAPO (Decoupled Clip and Dynamic sAmpling Policy Optimization). Currently reserved for future implementation. Default: False

  • use_cpg_loss (bool) – If True, uses CPGD-style clipped policy gradient loss with asymmetric clipping bounds. When False, uses standard PPO clipping. Default: False

  • high_entropy_token_ratio (float) – Ratio of high-entropy tokens to keep for training (e.g., 0.2 means top 20% highest entropy tokens). When > 0, enables high-entropy token filtering. Set to 0.0 to disable. Default: 0.0
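
To illustrate what keeping the top fraction of high-entropy tokens means, here is a minimal sketch over a flat list of per-token entropies. The function name, the rounding rule (ceiling, at least one token), and ranking over the whole list rather than per sequence are all assumptions; this page does not specify how the library builds the mask.

```python
import math


def high_entropy_mask(entropies, action_mask, ratio):
    """Keep the top `ratio` fraction of valid tokens, ranked by entropy.

    entropies: per-token entropy values.
    action_mask: 1 for valid tokens, 0 for padding.
    ratio: fraction of valid tokens to keep; 0.0 disables filtering.
    """
    valid = [i for i, m in enumerate(action_mask) if m == 1]
    if ratio <= 0.0 or not valid:
        return list(action_mask)  # filtering disabled: keep all valid tokens
    k = max(1, math.ceil(ratio * len(valid)))
    ranked = sorted(valid, key=lambda i: entropies[i], reverse=True)
    keep = set(ranked[:k])
    # Padding positions are never in `valid`, so they stay masked out.
    return [1 if i in keep else 0 for i in range(len(action_mask))]


entropies = [0.1, 2.0, 0.5, 3.0, 1.0]
action_mask = [1, 1, 1, 0, 1]
print(high_entropy_mask(entropies, action_mask, ratio=0.5))
```

In the example the token with entropy 3.0 is padding and is therefore excluded even though it has the highest entropy.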

Loss Computation:

The loss is computed as follows:

  1. Mask Application: Combines action_mask (valid tokens) with entropy_mask (high-entropy tokens) to create a final mask for loss computation.

  2. PPO Mode (default, use_cpg_loss=False):

    • Computes policy ratio: ratio = exp(log_probs - old_log_probs)

    • Clips ratio: clipped_ratio = clamp(ratio, 1 - clip_eps, 1 + clip_eps)

    • Loss: -min(ratio * advantages, clipped_ratio * advantages)

  3. CPGD Mode (use_cpg_loss=True):

    • Uses asymmetric clipping: upper bound log(1 + clip_eps) for positive advantages, lower bound log(1 - clip_eps) for negative advantages

    • Loss: -clipped_log_probs * advantages

  4. Masking: The computed loss is masked using final_mask and averaged only over valid, high-entropy tokens (if enabled).
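
The per-token arithmetic of the PPO and CPGD modes (steps 2 and 3) can be sketched with scalars as follows. The function names are hypothetical. For CPGD, applying the log-space bound relative to old_log_probs is one reading of the description above and should be treated as an assumption.

```python
import math


def ppo_token_loss(log_prob, old_log_prob, advantage, clip_eps=0.2):
    """Clipped surrogate loss for a single token (PPO mode)."""
    ratio = math.exp(log_prob - old_log_prob)
    clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    return -min(ratio * advantage, clipped_ratio * advantage)


def cpgd_token_loss(log_prob, old_log_prob, advantage, clip_eps=0.2):
    """Clipped policy-gradient loss for a single token (CPGD mode).

    Assumption: the bound is applied to log_prob relative to old_log_prob.
    """
    if advantage > 0:
        # Cap how far the log-probability may rise above the old policy.
        clipped = min(log_prob, old_log_prob + math.log(1.0 + clip_eps))
    else:
        # Cap how far the log-probability may fall below the old policy.
        clipped = max(log_prob, old_log_prob + math.log(1.0 - clip_eps))
    return -clipped * advantage


new_lp, old_lp = math.log(1.5), 0.0  # ratio of 1.5, outside [0.8, 1.2]
print(ppo_token_loss(new_lp, old_lp, advantage=1.0))
print(ppo_token_loss(new_lp, old_lp, advantage=-1.0))
print(cpgd_token_loss(new_lp, old_lp, advantage=1.0))
```

With a positive advantage the clipped term is selected, so the token stops contributing gradient once the ratio leaves the trust region. With a negative advantage the unclipped term is selected, so the penalty keeps growing.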

Example Usage:

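from lightrft.models.loss import PolicyLoss

# log_probs, old_log_probs, advantages, action_mask, and entropy_mask are
# tensors of shape (batch_size, num_actions), as described under forward().

# Standard PPO loss
policy_loss = PolicyLoss(clip_eps=0.2)
loss = policy_loss(log_probs, old_log_probs, advantages, action_mask)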

# CPGD loss
policy_loss = PolicyLoss(clip_eps=0.2, use_cpg_loss=True)
loss = policy_loss(log_probs, old_log_probs, advantages, action_mask)

# PPO with high-entropy token filtering (top 20%)
policy_loss = PolicyLoss(clip_eps=0.2, high_entropy_token_ratio=0.2)
loss = policy_loss(log_probs, old_log_probs, advantages, action_mask, entropy_mask)

forward(log_probs: torch.Tensor, old_log_probs: torch.Tensor, advantages: torch.Tensor, action_mask: torch.Tensor | None = None, entropy_mask: torch.Tensor | None = None) torch.Tensor[source]

Compute policy loss with optional masking and algorithm-specific clipping.

This method computes the policy loss based on the configured algorithm (PPO or CPGD) and applies masking for valid tokens and optionally high-entropy tokens.

Parameters:
  • log_probs (torch.Tensor) – Log probabilities of actions under the current policy. Shape: (batch_size, num_actions)

  • old_log_probs (torch.Tensor) – Log probabilities of actions under the old/reference policy. Shape: (batch_size, num_actions)

  • advantages (torch.Tensor) – Advantage estimates for each action. Positive values indicate better-than-average actions. Shape: (batch_size, num_actions)

  • action_mask (Optional[torch.Tensor]) – Binary mask indicating valid action tokens (1 for valid, 0 for padding). If None, all tokens are considered valid. Shape: (batch_size, num_actions)

  • entropy_mask (Optional[torch.Tensor]) – Binary mask for high-entropy tokens to keep for training. If provided, overrides the instance-level entropy_mask. Shape: (batch_size, num_actions)

Returns:

Scalar policy loss averaged over valid (and optionally high-entropy) tokens.

Return type:

torch.Tensor

Masking Strategy:

The final mask is computed as:

  • If entropy_mask is provided: final_mask = entropy_mask (Note: entropy_mask is already created considering action_mask in create_high_entropy_mask, so padding positions are already excluded)

  • Else: final_mask = action_mask

Only tokens where final_mask == 1 contribute to the loss computation.
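
The mask selection and the final averaging can be sketched as below for a flat list of per-token losses. The function names are hypothetical, and returning 0.0 when no token is selected is an assumption made to avoid dividing by zero.

```python
def select_final_mask(num_tokens, action_mask=None, entropy_mask=None):
    """Pick the mask that decides which tokens contribute to the loss."""
    if entropy_mask is not None:
        # Assumed to already exclude padding, as noted above.
        return list(entropy_mask)
    if action_mask is not None:
        return list(action_mask)
    return [1] * num_tokens  # no masks given: every token is valid


def masked_mean(token_losses, final_mask):
    """Average the losses of tokens where final_mask == 1."""
    kept = [loss for loss, m in zip(token_losses, final_mask) if m == 1]
    return sum(kept) / len(kept) if kept else 0.0


token_losses = [1.0, 3.0, 5.0, 7.0]
action_mask = [1, 1, 1, 0]
entropy_mask = [0, 1, 0, 0]

print(masked_mean(token_losses, select_final_mask(4, action_mask)))
print(masked_mean(token_losses, select_final_mask(4, action_mask, entropy_mask)))
```

Because the average is taken over the selected tokens only, enabling entropy filtering changes the denominator as well as the numerator.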

Algorithm Details:

  • PPO: Uses symmetric clipping [1 - clip_eps, 1 + clip_eps] on the policy ratio.

  • CPGD: Uses asymmetric clipping with log-space bounds for better stability.

class lightrft.models.loss.ValueLoss(*args: Any, **kwargs: Any)[source]

Bases: Module

Value Loss for PPO

forward(values: torch.Tensor, old_values: torch.Tensor, returns: torch.Tensor, action_mask: torch.Tensor | None = None) torch.Tensor[source]

Compute PPO value function loss with optional clipping.

Parameters:
  • values (torch.Tensor) – Current value predictions.

  • old_values (torch.Tensor) – Value predictions from old policy (for clipping).

  • returns (torch.Tensor) – Target return values (e.g., GAE returns).

  • action_mask (Optional[torch.Tensor]) – Optional mask for valid timesteps (1 = valid, 0 = ignore).

Returns:

Scalar value loss (0.5 * MSE).

Return type:

torch.Tensor
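
A common form of the clipped PPO value loss is sketched below with plain Python lists. The parameter name `clip_eps` and the choice to disable clipping when it is None are assumptions, since this page does not list the constructor arguments of ValueLoss.

```python
def value_loss(values, old_values, returns, action_mask=None, clip_eps=None):
    """0.5 * masked mean of the (optionally clipped) squared error.

    clip_eps is a hypothetical parameter; None disables value clipping.
    """
    per_token = []
    for v, v_old, ret in zip(values, old_values, returns):
        unclipped = (v - ret) ** 2
        if clip_eps is None:
            per_token.append(unclipped)
            continue
        # Limit how far the new prediction may move from the old one.
        delta = min(max(v - v_old, -clip_eps), clip_eps)
        clipped = (v_old + delta - ret) ** 2
        # Take the pessimistic (larger) of the two errors.
        per_token.append(max(unclipped, clipped))
    mask = action_mask if action_mask is not None else [1] * len(per_token)
    kept = [loss for loss, m in zip(per_token, mask) if m == 1]
    mean = sum(kept) / len(kept) if kept else 0.0
    return 0.5 * mean


print(value_loss([1.0], [0.0], [2.0]))
print(value_loss([1.0], [0.0], [2.0], clip_eps=0.2))
```

Taking the maximum of the clipped and unclipped errors means clipping can only increase the loss, which discourages large jumps in the value prediction between updates.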

class lightrft.models.loss.VanillaKTOLoss(*args: Any, **kwargs: Any)[source]

Bases: Module

KTO loss for even sampling

forward(policy_chosen_logps: torch.FloatTensor, policy_rejected_logps: torch.FloatTensor, reference_chosen_logps: torch.FloatTensor, reference_rejected_logps: torch.FloatTensor) Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor][source]

Compute vanilla KTO loss for evenly sampled chosen/rejected pairs.

Parameters:
  • policy_chosen_logps (torch.FloatTensor) – Log probabilities under policy for chosen samples.

  • policy_rejected_logps (torch.FloatTensor) – Log probabilities under policy for rejected samples.

  • reference_chosen_logps (torch.FloatTensor) – Log probabilities under reference model for chosen samples.

  • reference_rejected_logps (torch.FloatTensor) – Log probabilities under reference model for rejected samples.

Returns:

Tuple of (losses, chosen_rewards, rejected_rewards).

Return type:

Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]