
lightrft.strategy.utils.statistic

Utilities for analyzing and visualizing the length distribution of generated outputs.

This module provides functionality for collecting, analyzing, and visualizing the length distribution of generated outputs from language models. It includes tools for gathering output lengths across distributed processes, computing statistics like percentiles, and creating visualizations using matplotlib and TensorBoard.

The main components are:

  • GenLenAnalyser: a class for continuous monitoring and visualization of generation lengths

  • Helper functions for collecting and analyzing output lengths in distributed environments

GenLenAnalyser

class lightrft.strategy.utils.statistic.GenLenAnalyser(engine_dp_group: torch.distributed.ProcessGroup, plot_every: int = 2, percentiles: List[int] = [50, 80], plot_out_dir: str | None = None)

Analyzer for tracking and visualizing the length distribution of generated outputs.

This class tracks the lengths of generated outputs over the course of training, computes summary statistics (such as the mean, median, and configured percentiles), and can visualize the distributions with matplotlib and TensorBoard. It is designed for distributed training environments and supports continuous monitoring with a configurable plotting interval.

Parameters:
  • engine_dp_group (torch.distributed.ProcessGroup) – The distributed process group for communication

  • plot_every (int) – Plotting interval in steps; set to 0 to disable plotting

  • percentiles (list) – List of percentiles to compute for the length distribution

  • plot_out_dir (str or None) – Directory in which to save plots and TensorBoard logs; if None, no plots are saved

Example:

>>> import torch.distributed as dist
>>> # Assumes the default process group is already initialized,
>>> # e.g. with dist.init_process_group(backend="nccl")
>>> analyzer = GenLenAnalyser(
...     engine_dp_group=dist.group.WORLD,
...     plot_every=10,
...     percentiles=[25, 50, 75, 90],
...     plot_out_dir="./output_analysis"
... )
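>>> # Inside the training loop; generation_outputs is a list of dicts,
>>> # each containing 'output_token_ids'
>>> stats = analyzer.collect(generation_outputs, cur_step=100, is_rank_0=(dist.get_rank() == 0))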

collect(gen_outputs: List[Dict[str, Any]], cur_step: int, is_rank_0: bool) → Dict[str, Any] | None

Collect and analyze generation length data at the current step.

This method gathers output lengths from all processes, computes statistics, and, when the conditions are met, creates visualizations. Collection happens only at the step interval set by the plot_every parameter; on other steps the method returns None.

Parameters:
  • gen_outputs (List[Dict[str, Any]]) – List of generation outputs to analyze, each containing ‘output_token_ids’

  • cur_step (int) – Current training/generation step

  • is_rank_0 (bool) – Whether the current process is the main process (rank 0)

Returns:

Dictionary of length statistics, or None if collection is skipped at this step

Return type:

dict or None

Example:

>>> gen_outputs = [
...     {"output_token_ids": [1, 2, 3, 4, 5]},
...     {"output_token_ids": [1, 2, 3]}
... ]
>>> stats = analyzer.collect(gen_outputs, cur_step=50, is_rank_0=True)
>>> if stats:
...     print(f"Mean length: {stats['mean_length']}")
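The step-interval behaviour of collect can be illustrated with a small stand-alone class. This is a minimal sketch under stated assumptions, not the LightRFT implementation: LengthMonitorSketch and its analyse_fn parameter are hypothetical names, a plot_every of 0 is assumed to skip collection entirely, and appending to history on rank 0 stands in for the real plotting and TensorBoard logging.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

Stats = Dict[str, Any]


class LengthMonitorSketch:
    """Hypothetical stand-in for GenLenAnalyser's step-gated collection."""

    def __init__(self, plot_every: int,
                 analyse_fn: Callable[[List[Dict[str, Any]]], Stats]):
        self.plot_every = plot_every
        # Assumption: analyse_fn performs the gather and the statistics.
        self.analyse_fn = analyse_fn
        self.history: List[Tuple[int, Stats]] = []

    def should_collect(self, cur_step: int) -> bool:
        # Assumption: 0 disables collection; otherwise only multiples
        # of plot_every are analysed.
        return self.plot_every > 0 and cur_step % self.plot_every == 0

    def collect(self, gen_outputs: List[Dict[str, Any]], cur_step: int,
                is_rank_0: bool) -> Optional[Stats]:
        if not self.should_collect(cur_step):
            return None  # skipped step, mirroring the documented None return
        # In a distributed run every rank must reach this call, because the
        # gather inside analyse_fn would be a collective operation.
        stats = self.analyse_fn(gen_outputs)
        if is_rank_0:
            # Stand-in for plotting / TensorBoard logging on the main process.
            self.history.append((cur_step, stats))
        return stats
```

Because cur_step is the same on every rank, all ranks agree on which steps are skipped, so none of them is left waiting in a collective that the others never enter.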

analyse_output_lengths

lightrft.strategy.utils.statistic.analyse_output_lengths(gen_outputs: List[Dict[str, Any]], engine_dp_group: torch.distributed.ProcessGroup, percentiles: List[int] = [50, 80], plot_out_dir: str | None = None, prefix: str = '') → Dict[str, Any]

Analyze the length distribution of generated outputs.

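This is a convenience function that collects local output lengths, gathers them across all processes, and computes statistics. It performs a one-time analysis without the continuous monitoring capabilities of GenLenAnalyser. It is distinct from analyze_output_lengths (spelled with a "z"), which only computes statistics on an already-gathered list of lengths.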

Parameters:
  • gen_outputs (list) – List of generation outputs to analyze, each containing ‘output_token_ids’

  • engine_dp_group (torch.distributed.ProcessGroup) – The distributed process group for communication

  • percentiles (list) – List of percentiles to compute for the length distribution

  • plot_out_dir (str or None) – Directory in which to save plots; if None, no plots are saved

  • prefix (str) – Prefix for plot filenames

Returns:

Dictionary containing length statistics

Return type:

dict

Example:

>>> gen_outputs = [
...     {"output_token_ids": [1, 2, 3, 4, 5, 6]},
...     {"output_token_ids": [1, 2, 3]}
... ]
>>> stats = analyse_output_lengths(
...     gen_outputs,
...     engine_dp_group=dist.group.WORLD,
...     percentiles=[25, 50, 75]
... )
>>> print(f"Median length: {stats['median_length']}")
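The three stages this function combines (collect local lengths, gather them, compute statistics) can be sketched as a plain pipeline. This is a hedged illustration only: analyse_pipeline_sketch and its gather_fn parameter are hypothetical, the default behaviour treats the run as a single process, and only count, mean, and median are computed, so percentiles, plotting, and prefix handling are left out.

```python
import statistics
from typing import Any, Callable, Dict, List, Optional


def analyse_pipeline_sketch(
    gen_outputs: List[Dict[str, Any]],
    gather_fn: Optional[Callable[[List[int]], List[int]]] = None,
) -> Dict[str, Any]:
    # Stage 1: local lengths = number of tokens in each output.
    local_lengths = [len(out["output_token_ids"]) for out in gen_outputs]
    # Stage 2: combine lengths across ranks. Without a gather function we
    # assume a single process, so the local lengths are already global.
    all_lengths = gather_fn(local_lengths) if gather_fn is not None else local_lengths
    # Stage 3: summary statistics over the combined list.
    return {
        "count": len(all_lengths),
        "mean_length": statistics.mean(all_lengths),
        "median_length": statistics.median(all_lengths),
    }
```

In a multi-process run, gather_fn could wrap the documented helper, for example lambda xs: gather_all_lengths(xs, dist.group.WORLD).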

collect_local_output_lengths

lightrft.strategy.utils.statistic.collect_local_output_lengths(outputs: List[Dict[str, Any]]) → List[int]

Collect the lengths of generated outputs from the local process.

This function extracts the length of each output by counting the tokens in the ‘output_token_ids’ field of each output dictionary.

Parameters:

outputs (list) – List of generation outputs, each containing ‘output_token_ids’

Returns:

List of output lengths, one per entry in outputs, in the same order

Return type:

list

Example:

>>> outputs = [
...     {"output_token_ids": [1, 2, 3, 4, 5]},
...     {"output_token_ids": [10, 20]},
...     {"output_token_ids": [100, 200, 300]}
... ]
>>> lengths = collect_local_output_lengths(outputs)
>>> print(lengths)  # [5, 2, 3]
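Going by the description above, length extraction reduces to counting the tokens in each output. A minimal sketch, with the hypothetical name collect_local_output_lengths_sketch (the library's own version may validate its inputs differently):

```python
from typing import Any, Dict, List


def collect_local_output_lengths_sketch(outputs: List[Dict[str, Any]]) -> List[int]:
    # One length per output, in input order. len() also works for a 1-D
    # tensor or array of ids. A missing 'output_token_ids' key raises
    # KeyError rather than being silently skipped.
    return [len(output["output_token_ids"]) for output in outputs]
```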

gather_all_lengths

lightrft.strategy.utils.statistic.gather_all_lengths(local_lengths: List[int], group: torch.distributed.ProcessGroup) → List[int]

Gather output lengths from all processes in the distributed group.

This function uses PyTorch’s distributed communication to collect length data from every process in the specified group, so that the length distribution can be analyzed globally rather than per rank. Because this is a collective operation, every process in the group must call it.

Parameters:
  • local_lengths (list) – List of output lengths from the local process

  • group (torch.distributed.ProcessGroup) – The distributed process group for communication

Returns:

Combined list of output lengths from all processes

Return type:

list

Example:

>>> # Assuming distributed environment is set up
>>> local_lengths = [5, 3, 7]
>>> all_lengths = gather_all_lengths(local_lengths, dist.group.WORLD)
>>> # all_lengths now contains lengths from all processes
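A common way to implement this kind of gather in PyTorch is torch.distributed.all_gather_object, which fills a pre-sized list with one picklable object per rank; whether LightRFT uses this exact call is an assumption. The sketch below, with the hypothetical name gather_all_lengths_sketch, takes the all-gather as an injectable callable so that the flattening logic can run without a process group.

```python
from typing import Any, Callable, List

# Signature of the injected all-gather: (output_slots, local_object) -> None
AllGatherFn = Callable[[List[Any], Any], None]


def gather_all_lengths_sketch(local_lengths: List[int], world_size: int,
                              all_gather_fn: AllGatherFn) -> List[int]:
    # One slot per rank; the all-gather writes each rank's list into its slot.
    gathered: List[Any] = [None] * world_size
    all_gather_fn(gathered, local_lengths)
    # Flatten in rank order: rank 0's lengths first, then rank 1's, and so on.
    return [length for rank_lengths in gathered for length in rank_lengths]
```

With a real group one would pass world_size=dist.get_world_size(group) and all_gather_fn=functools.partial(dist.all_gather_object, group=group).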

analyze_output_lengths

lightrft.strategy.utils.statistic.analyze_output_lengths(all_lengths: List[int], percentiles: List[int]) → Dict[str, Any]

Analyze the distribution of output lengths and compute statistics.

This function computes summary statistics of the length distribution: the minimum, maximum, mean, and median, plus each percentile requested in percentiles. These statistics are useful for monitoring generation behavior during training, for example to see whether outputs are growing longer over time.

Parameters:
  • all_lengths (List[int]) – List of output lengths from all processes

  • percentiles (List[int]) – List of percentiles to compute (e.g., [25, 50, 75, 90])

Returns:

Dictionary containing statistics about the length distribution

Return type:

Dict[str, Any]

Example:

>>> lengths = [10, 15, 20, 25, 30, 35, 40]
>>> stats = analyze_output_lengths(lengths, percentiles=[25, 50, 75])
>>> print(f"Mean: {stats['mean_length']}")
>>> print(f"75th percentile: {stats['percentiles'][75]}")
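These statistics can be reproduced in plain Python. This sketch assumes linear interpolation between the closest ranks for percentiles (the default method of numpy.percentile); the method LightRFT actually uses is not stated here. The keys mean_length, median_length, and percentiles follow the examples above, while count, min_length, max_length, and the function names are hypothetical.

```python
import math
import statistics
from typing import Any, Dict, List


def _percentile(sorted_vals: List[int], p: float) -> float:
    # Linear interpolation between the two closest ranks (numpy's default).
    k = (len(sorted_vals) - 1) * p / 100
    lo, hi = math.floor(k), math.ceil(k)
    return sorted_vals[lo] + (sorted_vals[hi] - sorted_vals[lo]) * (k - lo)


def analyze_output_lengths_sketch(all_lengths: List[int],
                                  percentiles: List[int]) -> Dict[str, Any]:
    if not all_lengths:
        raise ValueError("no lengths to analyze")
    ordered = sorted(all_lengths)
    return {
        "count": len(ordered),
        "min_length": ordered[0],
        "max_length": ordered[-1],
        "mean_length": statistics.mean(ordered),
        "median_length": statistics.median(ordered),
        "percentiles": {p: _percentile(ordered, p) for p in percentiles},
    }
```

For the example lengths [10, 15, 20, 25, 30, 35, 40], this sketch gives a mean and median of 25 and a 75th percentile of 32.5.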