Source code for lightrft.models.loss
"""
Loss functions used across LightRFT models.
This module implements a comprehensive collection of loss functions for reinforcement learning
from human feedback (RLHF) and related training paradigms:
**Policy Optimization Losses:**
- PolicyLoss: Multi-purpose policy loss supporting PPO, CPGD (via use_cpg_loss), and
high-entropy token filtering for efficient training (the use_dapo flag is reserved for
DAPO-style decoupled clipping, which is not yet implemented).
- ValueLoss: Value function loss for PPO with optional value clipping.
**Reward Model Losses:**
- GPTLMLoss: Next-token prediction loss for generative reward model training.
- LogSigmoidLoss: Log-sigmoid pairwise loss for scalar reward model training.
- LogExpLoss: Log-exp pairwise loss for scalar reward model training.
- HPSLoss: Human Preference Score loss for scalar reward model training.
- PairWiseLoss: Generic pairwise preference loss for reward models.
- PRMLoss: Process Reward Model loss for token-level reward prediction.
**Preference Learning Losses:**
- DPOLoss: Direct Preference Optimization loss for aligning language models with preferences.
- KTOLoss: Kahneman-Tversky Optimization loss for uneven sampling scenarios.
- VanillaKTOLoss: Simplified KTO loss for even sampling scenarios.
**Knowledge Distillation:**
- KDLoss: Knowledge Distillation loss for transferring knowledge from teacher to student models.
All loss functions are designed to work seamlessly with the LightRFT training framework,
supporting distributed training, mixed precision, and various optimization strategies.
"""
from typing import Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from .utils import masked_mean
class GPTLMLoss(nn.Module):
    """
    Next-token prediction loss used for generative reward model training.

    Wraps :class:`torch.nn.CrossEntropyLoss` and ignores label positions equal to
    ``IGNORE_INDEX`` (``-100``, the Hugging Face convention).

    :ivar int IGNORE_INDEX: Label value excluded from the loss (``-100``).
    :ivar torch.nn.CrossEntropyLoss loss: Cross-entropy criterion built with
        ``ignore_index=IGNORE_INDEX``.
    """

    def __init__(self):
        super().__init__()
        self.IGNORE_INDEX = -100
        self.loss = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Compute the mean next-token cross-entropy.

        The prediction at position ``t`` is scored against the label at ``t + 1``:
        the last logit and the first label are dropped before flattening.

        :param logits: Model output logits, shape ``(..., seq_len, vocab_size)``.
        :type logits: torch.Tensor
        :param labels: Token ids aligned with ``logits``, shape ``(..., seq_len)``.
            Positions equal to ``IGNORE_INDEX`` do not contribute.
        :type labels: torch.Tensor
        :returns: Scalar mean cross-entropy over non-ignored target positions.
        :rtype: torch.Tensor
        """
        vocab_size = logits.size(-1)
        # Position t predicts token t + 1.
        predictions = logits[..., :-1, :].contiguous().view(-1, vocab_size)
        targets = labels[..., 1:].contiguous().view(-1)
        return self.loss(predictions, targets)
class PolicyLoss(nn.Module):
    """
    Multi-purpose policy loss supporting PPO, CPGD and high-entropy token filtering.

    The loss computes a clipped policy-gradient objective per token and averages it over
    the tokens selected by the combined action / entropy mask.

    **Supported Algorithms:**

    - **PPO (Proximal Policy Optimization)**: Default mode using the standard clipped
      surrogate objective ``-min(ratio * advantages, clipped_ratio * advantages)`` with
      ``ratio = exp(log_probs - old_log_probs)``.
    - **CPGD (Clipped Policy Gradient Optimization with Policy Drift)**: Enabled via
      ``use_cpg_loss=True``. Clips the log-probability asymmetrically depending on the
      sign of the advantage. See: https://arxiv.org/abs/2505.12504
    - **High-Entropy Token Filtering**: Enabled by passing an ``entropy_mask`` to
      :meth:`forward`, so that only high-entropy ("forking") tokens contribute to the
      loss. Based on: https://arxiv.org/abs/2506.01939. Note that ``forward`` does not
      build this mask itself; ``high_entropy_token_ratio`` is only stored on the module
      for callers that construct the mask.

    :param clip_eps: Clipping epsilon for the policy update. Typical values range from
        0.1 to 0.3. Default: 0.2
    :type clip_eps: float
    :param use_dapo: Reserved flag for DAPO (Decoupled Clip and Dynamic sAmpling Policy
        Optimization); not read by this class. Default: False
    :type use_dapo: bool
    :param use_cpg_loss: If True, use the CPGD clipped log-probability loss instead of
        PPO ratio clipping. Default: False
    :type use_cpg_loss: bool
    :param high_entropy_token_ratio: Fraction of highest-entropy tokens intended to be
        kept (e.g. 0.2 for the top 20%). Stored only; the mask must be supplied to
        :meth:`forward` as ``entropy_mask``. Default: 0.0
    :type high_entropy_token_ratio: float

    **Loss Computation:**

    1. **Mask**: ``action_mask`` and ``entropy_mask`` are combined (element-wise product
       when both are given) into the final mask.
    2. **PPO Mode** (``use_cpg_loss=False``):
       ``loss = -min(ratio * A, clamp(ratio, 1 - clip_eps, 1 + clip_eps) * A)``
    3. **CPGD Mode** (``use_cpg_loss=True``): ``log_probs`` is clamped from above at
       ``old_log_probs + log(1 + clip_eps)`` where ``A > 0`` and from below at
       ``old_log_probs + log(1 - clip_eps)`` otherwise; ``loss = -clipped_log_probs * A``.
    4. **Reduction**: ``masked_mean(loss, final_mask, dim=-1).mean()``.

    **Example Usage:**

    .. code-block:: python

        # Standard PPO loss
        policy_loss = PolicyLoss(clip_eps=0.2)
        loss = policy_loss(log_probs, old_log_probs, advantages, action_mask)

        # CPGD loss
        policy_loss = PolicyLoss(clip_eps=0.2, use_cpg_loss=True)
        loss = policy_loss(log_probs, old_log_probs, advantages, action_mask)

        # PPO with high-entropy token filtering (mask built by the caller)
        policy_loss = PolicyLoss(clip_eps=0.2, high_entropy_token_ratio=0.2)
        loss = policy_loss(log_probs, old_log_probs, advantages, action_mask, entropy_mask)

    **References:**

    - PPO: https://arxiv.org/abs/1707.06347
    - CPGD: https://arxiv.org/abs/2505.12504
    - High-Entropy Token Filtering: https://arxiv.org/abs/2506.01939
    """

    def __init__(
        self,
        clip_eps: float = 0.2,
        use_dapo: bool = False,
        use_cpg_loss: bool = False,
        high_entropy_token_ratio: float = 0.0,
    ) -> None:
        super().__init__()
        self.clip_eps = clip_eps
        # Reserved: not read by forward().
        self.use_dapo = use_dapo
        self.use_cpg_loss = use_cpg_loss
        # Stored for callers that build the entropy mask; forward() only consumes the mask.
        self.high_entropy_token_ratio = high_entropy_token_ratio

    def forward(
        self,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        advantages: torch.Tensor,
        action_mask: Optional[torch.Tensor] = None,
        entropy_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute the policy loss with optional masking and algorithm-specific clipping.

        :param log_probs: Log probabilities of actions under the current policy.
            Shape: ``(batch_size, num_actions)``
        :type log_probs: torch.Tensor
        :param old_log_probs: Log probabilities of actions under the old policy.
            Shape: ``(batch_size, num_actions)``
        :type old_log_probs: torch.Tensor
        :param advantages: Advantage estimates per action. Shape: ``(batch_size, num_actions)``
        :type advantages: torch.Tensor
        :param action_mask: Binary mask of valid action tokens (1 = valid, 0 = padding).
            If None, it is passed to ``masked_mean`` as-is. Shape: ``(batch_size, num_actions)``
        :type action_mask: Optional[torch.Tensor]
        :param entropy_mask: Optional binary mask of high-entropy tokens to keep.
            Shape: ``(batch_size, num_actions)``
        :type entropy_mask: Optional[torch.Tensor]
        :returns: Scalar loss, ``masked_mean(loss, final_mask, dim=-1).mean()``.
        :rtype: torch.Tensor

        **Masking Strategy:**

        - Both masks given: ``final_mask = entropy_mask * action_mask``
        - Only ``entropy_mask`` given: ``final_mask = entropy_mask``
        - Otherwise: ``final_mask = action_mask``
        """
        if entropy_mask is None:
            final_mask = action_mask
        elif action_mask is None:
            final_mask = entropy_mask
        else:
            # Re-apply action_mask so padding can never leak into the loss, even when the
            # supplied entropy_mask was not derived from action_mask. For 0/1 masks that
            # already respect action_mask this product leaves them unchanged.
            final_mask = entropy_mask * action_mask

        if self.use_cpg_loss:
            # CPGD: clamp log-probs in log space around the old policy, asymmetrically by
            # the sign of the advantage.
            upper = torch.log(torch.tensor(1 + self.clip_eps)) + old_log_probs
            lower = torch.log(torch.tensor(1 - self.clip_eps)) + old_log_probs
            clipped_log_probs = torch.where(
                advantages > 0,
                torch.clamp(log_probs, max=upper),
                torch.clamp(log_probs, min=lower),
            )
            loss = -clipped_log_probs * advantages
        else:
            # PPO clipped surrogate objective.
            ratio = (log_probs - old_log_probs).exp()
            surr1 = ratio * advantages
            surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
            loss = -torch.min(surr1, surr2)

        return masked_mean(loss, final_mask, dim=-1).mean()
class ValueLoss(nn.Module):
    """
    Value function loss for PPO with optional value clipping.

    :param clip_eps: If not None, new value predictions are clipped to
        ``old_values +/- clip_eps`` and the larger of the clipped and unclipped squared
        errors is used per element. If None, plain squared error is used.
    :type clip_eps: float
    """

    def __init__(self, clip_eps: float = None) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(
        self,
        values: torch.Tensor,
        old_values: torch.Tensor,
        returns: torch.Tensor,
        action_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute the (optionally clipped) PPO value loss.

        :param values: Current value predictions.
        :type values: torch.Tensor
        :param old_values: Value predictions recorded at rollout time (clipping anchor).
        :type old_values: torch.Tensor
        :param returns: Target returns (e.g. GAE returns).
        :type returns: torch.Tensor
        :param action_mask: Optional mask forwarded to ``masked_mean`` (1 = valid, 0 = ignore).
        :type action_mask: Optional[torch.Tensor]
        :return: ``0.5 * masked_mean(err, action_mask, dim=-1).mean()``.
        :rtype: torch.Tensor
        """
        unclipped_sq_err = (values - returns) ** 2
        if self.clip_eps is None:
            per_element = unclipped_sq_err
        else:
            # Limit how far the new prediction may move from the old one.
            step = (values - old_values).clamp(-self.clip_eps, self.clip_eps)
            clipped_sq_err = (old_values + step - returns) ** 2
            # Pessimistic choice: the larger of the two errors.
            per_element = torch.max(clipped_sq_err, unclipped_sq_err)
        reduced = masked_mean(per_element, action_mask, dim=-1).mean()
        return 0.5 * reduced
class PairWiseLoss(nn.Module):
    """
    Pairwise (Bradley-Terry style) ranking loss for reward models.

    Penalizes ``-logsigmoid(chosen - reject [- margin])`` averaged over all elements.
    """

    def forward(
        self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor, margin: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Compute the pairwise ranking loss.

        :param chosen_reward: Reward scores for chosen/preferred samples.
        :type chosen_reward: torch.Tensor
        :param reject_reward: Reward scores for rejected samples.
        :type reject_reward: torch.Tensor
        :param margin: Optional margin subtracted from the reward gap.
        :type margin: Optional[torch.Tensor]
        :return: Mean negative log-sigmoid loss.
        :rtype: torch.Tensor
        """
        gap = chosen_reward - reject_reward
        if margin is not None:
            gap = gap - margin
        return F.logsigmoid(gap).neg().mean()
class DPOLoss(nn.Module):
    """
    Direct Preference Optimization loss (with optional conservative label smoothing or IPO).

    :param beta: Temperature scaling the implicit reward.
    :type beta: float
    :param label_smoothing: cDPO smoothing weight; 0 gives the original DPO objective.
    :type label_smoothing: float
    :param ipo: If True, use the IPO squared objective instead of the log-sigmoid one.
    :type ipo: bool
    """

    def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None:
        super().__init__()
        self.beta = beta
        self.label_smoothing = label_smoothing
        self.ipo = ipo

    def forward(
        self,
        policy_chosen_logps: torch.Tensor,
        policy_rejected_logps: torch.Tensor,
        reference_chosen_logps: torch.Tensor,
        reference_rejected_logps: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute the DPO loss and the detached implicit rewards.

        :param policy_chosen_logps: Policy log probabilities of chosen samples.
        :type policy_chosen_logps: torch.Tensor
        :param policy_rejected_logps: Policy log probabilities of rejected samples.
        :type policy_rejected_logps: torch.Tensor
        :param reference_chosen_logps: Reference-model log probabilities of chosen samples.
        :type reference_chosen_logps: torch.Tensor
        :param reference_rejected_logps: Reference-model log probabilities of rejected samples.
        :type reference_rejected_logps: torch.Tensor
        :return: Tuple of (mean loss, chosen_rewards, rejected_rewards); rewards are
            ``beta * (policy - reference)`` and detached.
        :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        """
        policy_margin = policy_chosen_logps - policy_rejected_logps
        reference_margin = reference_chosen_logps - reference_rejected_logps
        logits = policy_margin - reference_margin

        if self.ipo:
            # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
            per_pair = (logits - 1 / (2 * self.beta)) ** 2
        else:
            # Eq. 3 of https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 recovers
            # Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf
            smoothing = self.label_smoothing
            scaled = self.beta * logits
            per_pair = -F.logsigmoid(scaled) * (1 - smoothing) - F.logsigmoid(-scaled) * smoothing

        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
        return per_pair.mean(), chosen_rewards, rejected_rewards
# Adapted from https://github.com/ContextualAI/HALOs/blob/ca9b7e3eeea220c0944ad8095d641da33f907a7e/trainers.py#L742
class VanillaKTOLoss(nn.Module):
    """
    KTO loss for even sampling (one rejected sample per chosen sample).

    :param beta: Temperature scaling the log-ratios.
    :type beta: float
    """

    def __init__(self, beta: float) -> None:
        super().__init__()
        self.beta = beta

    def forward(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """
        Compute the vanilla KTO loss.

        The reference point for each side is the clamped (>= 0) batch-mean log-ratio of
        the opposite side; these KL estimates are not detached.

        :param policy_chosen_logps: Policy log probabilities of chosen samples.
        :type policy_chosen_logps: torch.FloatTensor
        :param policy_rejected_logps: Policy log probabilities of rejected samples.
        :type policy_rejected_logps: torch.FloatTensor
        :param reference_chosen_logps: Reference-model log probabilities of chosen samples.
        :type reference_chosen_logps: torch.FloatTensor
        :param reference_rejected_logps: Reference-model log probabilities of rejected samples.
        :type reference_rejected_logps: torch.FloatTensor
        :return: Tuple of (mean loss, chosen_rewards, rejected_rewards); rewards are detached.
        :rtype: Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]
        """
        chosen_logratios = policy_chosen_logps - reference_chosen_logps
        rejected_logratios = policy_rejected_logps - reference_rejected_logps

        chosen_KL = chosen_logratios.mean().clamp(min=0)
        rejected_KL = rejected_logratios.mean().clamp(min=0)

        desirable_losses = 1 - torch.sigmoid(self.beta * (chosen_logratios - rejected_KL))
        undesirable_losses = 1 - torch.sigmoid(self.beta * (chosen_KL - rejected_logratios))
        losses = torch.cat((desirable_losses, undesirable_losses), 0).mean()

        chosen_rewards = self.beta * chosen_logratios.detach()
        rejected_rewards = self.beta * rejected_logratios.detach()
        return losses, chosen_rewards, rejected_rewards
# Adapted from https://github.com/ContextualAI/HALOs/blob/ca9b7e3eeea220c0944ad8095d641da33f907a7e/trainers.py#L770
class KTOLoss(nn.Module):
    """
    KTO loss for uneven sampling, with a KL reference point averaged across ranks.

    :param beta: Temperature scaling the log-ratios.
    :type beta: float
    :param desirable_weight: Weight applied to chosen (desirable) per-sample losses.
    :type desirable_weight: float
    :param undesirable_weight: Weight applied to rejected (undesirable) per-sample losses.
    :type undesirable_weight: float
    :param world_size: Number of ranks the summed KL estimate is divided by.
    :type world_size: int
    :param device: Device used for the empty fallback tensors.
    :type device: torch.device
    """

    def __init__(
        self, beta: float, desirable_weight: float, undesirable_weight: float, world_size: int, device: torch.device
    ) -> None:
        super().__init__()
        self.beta = beta
        self.world_size = world_size
        self.device = device
        self.desirable_weight = desirable_weight
        self.undesirable_weight = undesirable_weight

    def forward(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        policy_KL_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        reference_KL_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """
        Compute the KTO loss for unevenly sampled chosen/rejected examples.

        :param policy_chosen_logps: Policy log probabilities of chosen samples (may be empty).
        :type policy_chosen_logps: torch.FloatTensor
        :param policy_rejected_logps: Policy log probabilities of rejected samples (may be empty).
        :type policy_rejected_logps: torch.FloatTensor
        :param policy_KL_logps: Policy log probabilities of KL-estimation samples.
        :type policy_KL_logps: torch.FloatTensor
        :param reference_chosen_logps: Reference-model log probabilities of chosen samples.
        :type reference_chosen_logps: torch.FloatTensor
        :param reference_rejected_logps: Reference-model log probabilities of rejected samples.
        :type reference_rejected_logps: torch.FloatTensor
        :param reference_KL_logps: Reference-model log probabilities of KL-estimation samples.
        :type reference_KL_logps: torch.FloatTensor
        :return: Tuple of (mean weighted loss, chosen_rewards, rejected_rewards, KL).
        :rtype: Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]
        """
        # KL is detached: it acts as a constant reference point, so no gradient flows through it.
        KL = (policy_KL_logps - reference_KL_logps).mean().detach()
        if dist.is_available() and dist.is_initialized():
            # Sum the per-rank estimates, then average over ranks.
            dist.all_reduce(KL, op=dist.ReduceOp.SUM)
            KL = KL / self.world_size
        # Without a process group (single-process use) the local estimate is used as-is.
        KL = KL.clamp(min=0)

        if policy_chosen_logps.shape[0] != 0:
            chosen_logratios = policy_chosen_logps - reference_chosen_logps
            chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - KL))
            chosen_rewards = self.beta * chosen_logratios.detach()
        else:
            # Keep the policy dtype; a dtype mismatch breaks a later all_gather.
            chosen_losses = torch.empty(0, dtype=policy_rejected_logps.dtype, device=self.device)
            chosen_rewards = torch.empty(0, dtype=policy_rejected_logps.dtype, device=self.device)

        if policy_rejected_logps.shape[0] != 0:
            rejected_logratios = policy_rejected_logps - reference_rejected_logps
            rejected_losses = 1 - F.sigmoid(self.beta * (KL - rejected_logratios))
            rejected_rewards = self.beta * rejected_logratios.detach()
        else:
            # Keep the policy dtype; a dtype mismatch breaks a later all_gather.
            rejected_losses = torch.empty(0, dtype=policy_chosen_logps.dtype, device=self.device)
            rejected_rewards = torch.empty(0, dtype=policy_chosen_logps.dtype, device=self.device)

        losses = torch.cat(
            (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0
        ).mean()
        return losses, chosen_rewards, rejected_rewards, KL
# Adapted from https://github.com/microsoft/LMOps/blob/main/minillm/finetune.py#L166
class KDLoss(nn.Module):
    """
    Language model knowledge distillation loss (forward KL up to a constant).

    Computes ``-sum_v p_teacher(v) * log p_student(v)`` per position and averages it over
    positions whose label is not ``IGNORE_INDEX``.

    :ivar int IGNORE_INDEX: Label value marking positions excluded from the average (``-100``).
    """

    def __init__(self):
        super().__init__()
        self.IGNORE_INDEX = -100

    def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        """
        Compute the knowledge distillation loss.

        No shifting is applied: ``logits``, ``teacher_logits`` and ``label`` must already be
        aligned position-wise.

        :param logits: Student model logits, shape ``(..., vocab_size)``.
        :type logits: torch.Tensor
        :param teacher_logits: Teacher model logits, same shape as ``logits``.
        :type teacher_logits: torch.Tensor
        :param label: Labels aligned with the positions of ``logits``; positions equal to
            ``IGNORE_INDEX`` are excluded.
        :type label: torch.Tensor
        :return: Scalar KD loss. Returns 0 (instead of NaN) when every position is ignored.
        :rtype: torch.Tensor
        """
        teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
        # Zero out terms where the student logit is +/-inf to avoid inf/NaN products.
        inf_mask = torch.isinf(logits)
        logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
        prod_probs = torch.masked_fill(teacher_probs * logprobs, inf_mask, 0)
        x = torch.sum(prod_probs, dim=-1).view(-1)
        mask = (label != self.IGNORE_INDEX).int().view(-1)
        # Clamp the count so a micro-batch with no valid positions yields 0, not NaN.
        num_valid = torch.sum(mask, dim=0).clamp(min=1)
        distil_loss = -torch.sum(x * mask, dim=0) / num_valid
        return distil_loss
class PRMLoss(nn.Module):
    """
    Process Reward Model loss computed at placeholder-token positions.

    Only positions where ``inputs == placeholder_token_id`` contribute. With
    ``reward_token_ids`` set, logits are restricted to those vocabulary entries.

    :param placeholder_token_id: Token id marking positions that carry a reward label.
    :type placeholder_token_id: int
    :param reward_token_ids: Optional vocabulary ids of the reward tokens. Required (with
        exactly two entries) for soft labels.
    :type reward_token_ids: Optional[list[int]]
    """

    def __init__(self, placeholder_token_id: int, reward_token_ids: Optional[list[int]] = None):
        super().__init__()
        self.IGNORE_INDEX = -100
        self.loss = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)
        self.placeholder_token_id = placeholder_token_id
        self.reward_token_ids = reward_token_ids

    def _remap_reward_labels(self, labels: torch.Tensor) -> torch.Tensor:
        """Map each reward token id in ``labels`` to its index in ``reward_token_ids``."""
        remapped = labels.clone()
        # Compare against the ORIGINAL labels so an id equal to an earlier index
        # (e.g. ids [1, 0]) is not remapped twice. Iterating in reverse lets the first
        # occurrence of a duplicated id win.
        for index, token in reversed(list(enumerate(self.reward_token_ids))):
            remapped[labels == token] = index
        return remapped

    def forward(self,
                inputs: torch.Tensor,
                logits: torch.Tensor,
                labels: torch.Tensor,
                *,
                return_acc: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Compute the process reward model loss.

        :param inputs: Input token ids, used to locate placeholder tokens.
        :type inputs: torch.Tensor
        :param logits: Model output logits, shape ``inputs.shape + (vocab_size,)``.
        :type logits: torch.Tensor
        :param labels: Hard labels (integer token ids) or soft labels (floating point, the
            probability of the first reward token), aligned with ``inputs``.
        :type labels: torch.Tensor
        :param return_acc: If True, also return argmax accuracy.
        :type return_acc: bool
        :return: Loss, or ``(loss, accuracy)`` when ``return_acc`` is True.
        :rtype: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
        """
        placeholder_mask = inputs == self.placeholder_token_id
        logits = logits[placeholder_mask]
        labels = labels[placeholder_mask]

        is_soft = labels.is_floating_point()
        if is_soft:
            assert self.reward_token_ids is not None and len(self.reward_token_ids) == 2, \
                "reward_token_ids should have 2 tokens for soft labels"
            logits = logits[..., self.reward_token_ids]
            positive_labels = labels.to(logits.dtype)
            negative_labels = 1 - positive_labels
            # NOTE(review): CrossEntropyLoss does not apply ignore_index to probability
            # targets, so -100 soft labels are not masked here - confirm callers never pass them.
            labels = torch.stack([positive_labels, negative_labels], dim=-1)
        elif self.reward_token_ids is not None:
            # Hard labels: restrict to reward tokens (otherwise the whole vocab is trained).
            logits = logits[..., self.reward_token_ids]
            labels = self._remap_reward_labels(labels)

        loss = self.loss(logits, labels)
        if not return_acc:
            return loss

        if is_soft:
            labels = labels.argmax(dim=-1)
        acc = (logits.argmax(dim=-1) == labels).float().mean()
        return loss, acc
class LogSigmoidLoss(nn.Module):
    """
    Log-sigmoid pairwise preference loss for scalar reward models.

    Pushes the chosen reward above the rejected reward, optionally by a margin.
    """

    def forward(
        self,
        chosen_reward: torch.Tensor,
        reject_reward: torch.Tensor,
        margin: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute the log-sigmoid pairwise loss.

        :param chosen_reward: Predicted reward for the preferred (chosen) sample.
        :type chosen_reward: torch.Tensor
        :param reject_reward: Predicted reward for the rejected sample.
        :type reject_reward: torch.Tensor
        :param margin: Optional margin; if given the objective is
            ``logsigmoid(chosen - reject - margin)``. Broadcasts over batch dimensions.
        :type margin: Optional[torch.Tensor]
        :returns: Mean of ``-logsigmoid(...)`` over all elements.
        :rtype: torch.Tensor
        """
        reward_gap = chosen_reward - reject_reward
        if margin is not None:
            reward_gap = reward_gap - margin
        per_pair_loss = -F.logsigmoid(reward_gap)
        return per_pair_loss.mean()
class LogExpLoss(nn.Module):
    """
    Log-exp (softplus) pairwise loss for scalar reward model training.

    Computes ``log(1 + exp(reject - chosen))`` averaged over the batch.
    See: https://arxiv.org/abs/2204.05862
    """

    def forward(
        self,
        chosen_reward: torch.Tensor,
        reject_reward: torch.Tensor,
        margin: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute the log-exp pairwise loss.

        :param chosen_reward: Predicted reward for the preferred (chosen) sample.
        :type chosen_reward: torch.Tensor
        :param reject_reward: Predicted reward for the rejected sample.
        :type reject_reward: torch.Tensor
        :param margin: Unused; included for API compatibility with :class:`PairWiseLoss`.
        :type margin: Optional[torch.Tensor]
        :returns: Mean ``log(1 + exp(reject - chosen))`` over the batch.
        :rtype: torch.Tensor
        """
        # softplus(x) == log(1 + exp(x)) but does not overflow to inf (and NaN gradients)
        # for large reward gaps, unlike the literal torch.log(1 + torch.exp(x)).
        return F.softplus(reject_reward - chosen_reward).mean()
class HPSLoss(nn.Module):
    """
    Human Preference Score (HPS) loss for scalar reward model training.

    Treats preference prediction as binary classification. The chosen and rejected
    rewards are concatenated as logits, and cross-entropy is applied with the chosen
    sample (class 0) as the target.
    Paper: https://arxiv.org/abs/2303.14420
    """

    def forward(
        self,
        chosen_reward: torch.Tensor,
        reject_reward: torch.Tensor,
        margin: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute the HPS loss.

        :param chosen_reward: Predicted reward for the preferred sample, shape
            ``(batch, 1)`` or ``(batch,)``.
        :type chosen_reward: torch.Tensor
        :param reject_reward: Predicted reward for the rejected sample, same shape as
            ``chosen_reward``.
        :type reject_reward: torch.Tensor
        :param margin: Unused; included for API compatibility with :class:`PairWiseLoss`.
        :type margin: Optional[torch.Tensor]
        :returns: Mean cross-entropy loss over the batch.
        :rtype: torch.Tensor
        """
        # Accept 1-D (batch,) rewards like the other pairwise losses do; without this the
        # concatenation yields a flat (2 * batch,) tensor and cross_entropy fails.
        if chosen_reward.dim() <= 1:
            chosen_reward = chosen_reward.reshape(-1, 1)
        if reject_reward.dim() <= 1:
            reject_reward = reject_reward.reshape(-1, 1)
        logits = torch.cat([chosen_reward, reject_reward], dim=-1)
        # Class 0 (the chosen reward column) is always the correct class.
        labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
        return F.cross_entropy(logits, labels)