lightrft.trainer.utils

This module provides utilities for statistical computations commonly used in reinforcement learning and machine learning workflows. It includes functions for computing clipping fractions and classes for tracking running statistics of data streams.

The main components are:

  • compute_clip_fraction: calculates the fraction of tensor elements that fall outside specified bounds

  • RunningMoments: maintains running mean and standard deviation statistics for streaming data

  • fire_sampling: FIRE sampling (a high-temperature first token followed by regular-temperature generation)

  • get_cpgd_advantages_returns: computes advantages and returns for the CPGD algorithm

  • vllm_ge_0130: version-checking utility for vLLM compatibility

These utilities are particularly useful in RL algorithms like PPO where clipping statistics and normalization are important for training stability and monitoring.

class lightrft.trainer.utils.RunningMoments[source]

Bases: object

Calculate the running mean and standard deviation of a data stream.

This class implements Welford’s online algorithm for computing running statistics, allowing efficient computation of mean and standard deviation as new data arrives without storing all historical data. This is particularly useful for normalizing inputs in reinforcement learning or for monitoring training statistics.

The implementation uses a parallel algorithm to combine statistics from new batches with existing running statistics, ensuring numerical stability even with large amounts of data.

Adapted from https://github.com/alibaba/ROLL
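
The parallel-combine step can be sketched in plain Python as follows. This is a simplified sketch of the underlying idea (Welford/Chan-style batch merging), not the module's actual implementation; the real class operates on torch tensors and its std convention may differ.

```python
import math

class RunningMomentsSketch:
    """Minimal sketch of Welford-style running moments with batch merging."""

    def __init__(self):
        self.count = 0    # number of samples seen so far
        self.mean = 0.0   # running mean
        self.m2 = 0.0     # running sum of squared deviations from the mean

    def update(self, xs):
        # statistics of the incoming batch
        n = len(xs)
        batch_mean = sum(xs) / n
        batch_m2 = sum((x - batch_mean) ** 2 for x in xs)
        # parallel combine: merge the batch statistics into the running ones
        delta = batch_mean - self.mean
        total = self.count + n
        self.mean += delta * n / total
        self.m2 += batch_m2 + delta ** 2 * self.count * n / total
        self.count = total
        # return the batch's own mean/std, mirroring RunningMoments.update
        return batch_mean, math.sqrt(batch_m2 / n)

    @property
    def std(self):
        # population standard deviation over everything seen so far
        return math.sqrt(self.m2 / self.count)
```

Because only `count`, `mean`, and `m2` are stored, no historical data is retained, yet the merged statistics match those computed over the full concatenated stream.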

Example:

.. code-block:: python

    >>> moments = RunningMoments()
    >>> batch1 = torch.randn(100)
    >>> mean1, std1 = moments.update(batch1)
    >>> batch2 = torch.randn(100)
    >>> mean2, std2 = moments.update(batch2)
    >>> print(f"Running mean: {moments.mean}, Running std: {moments.std}")

update(xs: torch.Tensor) → Tuple[float, float]

Update running statistics with a new batch of data.

This method uses Welford’s online algorithm combined with a parallel algorithm to efficiently update the running mean, variance, and standard deviation with a new batch of data. The algorithm is numerically stable and doesn’t require storing all previous data points.

Parameters:

xs (torch.Tensor) – The new tensor of data to incorporate into the running statistics.

Returns:

A tuple of (mean, std) for the current batch xs.

Return type:

Tuple[float, float]

Example:

.. code-block:: python

    >>> moments = RunningMoments()
    >>> new_data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
    >>> batch_mean, batch_std = moments.update(new_data)
    >>> print(f"Batch mean: {batch_mean}, Batch std: {batch_std}")

lightrft.trainer.utils.compute_clip_fraction(values: torch.Tensor, clip_max: float, clip_min: float) → torch.Tensor[source]

Compute the fraction of elements in a tensor that are clipped.

This function calculates what proportion of the input tensor’s elements fall outside the specified clipping bounds [clip_min, clip_max]. This is commonly used in reinforcement learning to monitor how often policy updates are being clipped, which can indicate training stability.

Parameters:
  • values (torch.Tensor) – The input tensor to analyze for clipping.

  • clip_max (float) – The maximum value for clipping bounds.

  • clip_min (float) – The minimum value for clipping bounds.

Returns:

A tensor of shape (batch_size,) where each element is the fraction of clipped values in the corresponding row of the input tensor.

Return type:

torch.Tensor

Example:

.. code-block:: python

    >>> values = torch.tensor([[1.0, 2.0, 3.0], [0.5, 1.5, 2.5]])
    >>> clip_fraction = compute_clip_fraction(values, clip_max=2.0, clip_min=1.0)
    >>> print(clip_fraction)  # Should show fraction of values outside [1.0, 2.0]
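
A minimal sketch consistent with the description above, assuming strict bounds and a per-row fraction; this is an illustration, not necessarily the library's exact implementation:

```python
import torch

def clip_fraction_sketch(values: torch.Tensor,
                         clip_max: float,
                         clip_min: float) -> torch.Tensor:
    # an element counts as "clipped" if it lies strictly outside [clip_min, clip_max]
    clipped = (values < clip_min) | (values > clip_max)
    # average over the last dimension -> one fraction per row
    return clipped.float().mean(dim=-1)
```

For the example input above, the first row has one element outside [1.0, 2.0] (3.0) and the second row has two (0.5 and 2.5), giving tensor([0.3333, 0.6667]).
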
lightrft.trainer.utils.fire_sampling(all_prompt_token_ids: List[List[int]], generate_fn: Callable, engine_type: str = 'vllm', first_token_temperature: float = 10.0, temperature: float = 1.0, first_token_top_k: int = -1, first_token_top_p: float = 1.0, is_multimodal: bool = False, all_prompts: List[str] | None = None, all_images: List | None = None, all_images_num: List[int] | None = None, all_videos: List | None = None, all_videos_num: List[int] | None = None, sampling_params: dict | object | None = None) → List[source]

FIRE sampling (Flaming-hot Initiation with Regular Execution).

FIRE sampling paper link: https://arxiv.org/abs/2410.21236

According to the paper, FIRE sampling:

1. Samples the FIRST token at a very high temperature (“flaming-hot initiation”).

2. Proceeds with the regular temperature for the remaining tokens.

3. IMPORTANT: top_k, top_p, min_p, and all other sampling parameters remain THE SAME for both the first token and the remaining tokens; only the temperature changes.

This implementation follows the paper’s design: only the temperature is modified between the first token and the remaining tokens, keeping all other sampling parameters identical.

Parameters:
  • all_prompt_token_ids (List[List[int]]) – List of tokenized prompts

  • generate_fn (Callable) – Function to call for generation (with pre-configured parameters)

  • engine_type (str) – Backend type (“vllm” or “sglang”)

  • first_token_temperature (float) – Temperature for first token generation (default: 10.0)

  • temperature (float) – Temperature for remaining tokens

  • first_token_top_k (int) – DEPRECATED - kept for backward compatibility, will be ignored

  • first_token_top_p (float) – DEPRECATED - kept for backward compatibility, will be ignored

  • is_multimodal (bool) – Whether this is multimodal generation

  • all_prompts (Optional[List[str]]) – Text prompts (for multimodal)

  • all_images (Optional[List]) – Images (for multimodal)

  • all_images_num (Optional[List[int]]) – Number of images per prompt

  • all_videos (Optional[List]) – Videos (for multimodal)

  • all_videos_num (Optional[List[int]]) – Number of videos per prompt

  • sampling_params (Optional[Union[dict, object]]) – Original sampling parameters

Returns:

List of generation outputs

Return type:

List
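
The two-phase control flow can be sketched as follows. This is a simplified sketch using a generic generate_fn that takes token lists plus temperature and max_tokens keywords; the real function additionally handles multimodal inputs, backend differences (vLLM vs. SGLang), and the original sampling parameters.

```python
def fire_sampling_sketch(all_prompt_token_ids, generate_fn,
                         first_token_temperature=10.0, temperature=1.0):
    # Phase 1 ("flaming-hot initiation"): sample ONLY the first token at a
    # very high temperature; all other sampling parameters are unchanged.
    first_tokens = generate_fn(all_prompt_token_ids,
                               temperature=first_token_temperature,
                               max_tokens=1)
    # Phase 2 ("regular execution"): append the hot first token to each
    # prompt and continue generating at the regular temperature.
    extended = [p + t for p, t in zip(all_prompt_token_ids, first_tokens)]
    return generate_fn(extended, temperature=temperature)
```

The key design point from the paper is visible here: both calls share every sampling parameter except temperature.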

lightrft.trainer.utils.get_cpgd_advantages_returns(reward: torch.Tensor, action_mask: torch.Tensor, weight_factor: str = 'STD_weight', epsilon: float = 1e-06) → Tuple[torch.Tensor, torch.Tensor][source]

Aggregate token-level rewards into episode-level scores, normalize them group-wise, and then broadcast the normalized scores back to the token dimension to obtain both the advantages and the returns that are required by the CPGD (Clipped Policy Gradient Optimization with Policy Drift) algorithm.

Parameters:
  • reward (torch.Tensor) – Tensor of shape (num_actions, seq_len) containing token-level rewards produced by the reward model. Each row corresponds to one sampled response (action trajectory).

  • action_mask (torch.Tensor) – Tensor of the same shape as reward. Elements belonging to the generated response tokens are 1; padding / prompt tokens are 0. The mask is used so that only response tokens contribute to the final advantages / returns.

  • weight_factor (str) –

    Determines how the per-sample scalar scores are normalized:

    • "STD_weight": z-score normalization: score_i = (score_i − mean) / (std + ε)

    • "clip_filter_like_weight": a simplified version of the Clip-Filter weight used in early RLHF repos: score_i = (score_i − mean) * clamp(num_actions / nz, max=3)

    • any other value: mean-centering only: score_i = score_i − mean

    Defaults to “STD_weight”.

  • epsilon (float) – Small constant added to the denominator to avoid division by zero, defaults to 1e-6.

Returns:

A tuple of (advantages, returns):

  • advantages: normalized per-token advantages, shape (num_actions, seq_len).

  • returns: identical to advantages in CPGD; returned separately for API symmetry.

Return type:

Tuple[torch.Tensor, torch.Tensor]

Note

Both advantages and returns are masked so that non-response tokens are always zero. The function performs no gradient-tracking operations and is intended to be called outside the optimization graph.
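
The "STD_weight" path can be sketched as follows. This is a simplified sketch that assumes the per-trajectory score is a masked sum of token rewards; the actual aggregation and normalization details may differ.

```python
import torch

def cpgd_advantages_sketch(reward, action_mask, epsilon=1e-6):
    # 1) aggregate token-level rewards into one scalar score per trajectory
    scores = (reward * action_mask).sum(dim=-1)
    # 2) "STD_weight": z-score normalization across the group
    scores = (scores - scores.mean()) / (scores.std() + epsilon)
    # 3) broadcast back to the token dimension; multiplying by the mask
    #    keeps non-response tokens at exactly zero
    advantages = scores.unsqueeze(-1) * action_mask
    return advantages, advantages  # returns == advantages in CPGD
```

Note how masking both the aggregation (step 1) and the broadcast (step 3) guarantees that prompt/padding positions contribute nothing and receive zero advantage.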

lightrft.trainer.utils.vllm_ge_0130()[source]

Check if vLLM version is greater than or equal to 0.13.0.

Starting from vLLM 0.13.0, the truncate_prompt_tokens parameter must not exceed max_model_len, which requires additional validation logic.

Returns:

True if vLLM version >= 0.13.0, False otherwise

Return type:

bool
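
The check can be sketched as a simple version-tuple comparison. This is an illustration only; the module may parse the installed version differently (e.g. via packaging.version), and the naive parser below ignores pre-release suffixes.

```python
def version_tuple(v: str):
    # naive parse of "major.minor.patch"; strips non-digit characters
    # from each component, so pre-release tags are handled only crudely
    parts = []
    for piece in v.split(".")[:3]:
        digits = "".join(ch for ch in piece if ch.isdigit())
        parts.append(int(digits) if digits else 0)
    return tuple(parts)

def vllm_ge_0130_sketch(installed_version: str) -> bool:
    # True when the installed vLLM is at least 0.13.0
    return version_tuple(installed_version) >= (0, 13, 0)
```

Callers can then gate the extra truncate_prompt_tokens validation on this boolean.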