lightrft.strategy.utils.statistic
Utilities for analyzing and visualizing the length distribution of generated outputs.
This module provides functionality for collecting, analyzing, and visualizing the length distribution of generated outputs from language models. It includes tools for gathering output lengths across distributed processes, computing statistics like percentiles, and creating visualizations using matplotlib and TensorBoard.
The main components are:
- GenLenAnalyser: a class for continuous monitoring and visualization of generation lengths
- Helper functions for collecting and analyzing output lengths in distributed environments
GenLenAnalyser
- class lightrft.strategy.utils.statistic.GenLenAnalyser(engine_dp_group: torch.distributed.ProcessGroup, plot_every: int = 2, percentiles: List[int] = [50, 80], plot_out_dir: str | None = None)
Analyzer for tracking and visualizing the length distribution of generated outputs.
This class tracks the lengths of generated outputs over time, computes summary statistics from them, and can visualize the distributions using matplotlib and TensorBoard. It is designed for distributed training environments and supports continuous monitoring with a configurable plotting interval.
- Parameters:
engine_dp_group (torch.distributed.ProcessGroup) – The distributed process group for communication
plot_every (int) – How often to plot the distribution, in steps. Set to 0 to disable plotting.
percentiles (List[int]) – Percentiles to compute for the length distribution
plot_out_dir (str or None) – Directory in which to save plots and TensorBoard logs. If None, no plots are saved.
Example:
>>> import torch.distributed as dist
>>> from lightrft.strategy.utils.statistic import GenLenAnalyser
>>> # Initialize the distributed process group first
>>> analyzer = GenLenAnalyser(
...     engine_dp_group=dist.group.WORLD,
...     plot_every=10,
...     percentiles=[25, 50, 75, 90],
...     plot_out_dir="./output_analysis"
... )
>>> # Use during the training loop
>>> stats = analyzer.collect(generation_outputs, cur_step=100, is_rank_0=True)
- collect(gen_outputs: List[Dict[str, Any]], cur_step: int, is_rank_0: bool) → Dict[str, Any] | None
Collect and analyze generation length data at the current step.
This method gathers output lengths from all processes, computes statistics, and creates visualizations when the plotting conditions are met. Collection happens only at the interval set by the plot_every parameter; on other steps the method returns None.
- Parameters:
gen_outputs (List[Dict[str, Any]]) – List of generation outputs to analyze, each containing ‘output_token_ids’
cur_step (int) – Current training/generation step
is_rank_0 (bool) – Whether the current process is the main process (rank 0)
- Returns:
Dictionary containing length statistics, or None if collection is skipped
- Return type:
dict or None
Example:
>>> gen_outputs = [
...     {"output_token_ids": [1, 2, 3, 4, 5]},
...     {"output_token_ids": [1, 2, 3]}
... ]
>>> stats = analyzer.collect(gen_outputs, cur_step=50, is_rank_0=True)
>>> if stats:
...     print(f"Mean length: {stats['mean_length']}")
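The interval gating described for collect can be modeled without torch or matplotlib. The sketch below is hypothetical and not the library's code: it assumes a step is collected when plot_every > 0 and cur_step is a multiple of plot_every, and that only rank 0 returns statistics. The names should_collect and LengthTrackerSketch, and the max_length key, are illustrative; the cross-process gather and the plotting are omitted.

```python
from typing import Any, Dict, List, Optional


def should_collect(cur_step: int, plot_every: int) -> bool:
    """Assumed gating rule: plot_every <= 0 disables collection entirely;
    otherwise collect on every step that is a multiple of plot_every."""
    if plot_every <= 0:
        return False
    return cur_step % plot_every == 0


class LengthTrackerSketch:
    """Hypothetical stand-in for GenLenAnalyser's bookkeeping (no gather, no plots)."""

    def __init__(self, plot_every: int = 2) -> None:
        self.plot_every = plot_every
        # Per-step statistics, kept so trends can be inspected over time.
        self.history: Dict[int, Dict[str, float]] = {}

    def collect(
        self, gen_outputs: List[Dict[str, Any]], cur_step: int, is_rank_0: bool
    ) -> Optional[Dict[str, float]]:
        if not should_collect(cur_step, self.plot_every):
            return None  # collection skipped on this step
        lengths = [len(o["output_token_ids"]) for o in gen_outputs]
        # The real class would gather lengths across engine_dp_group here, on
        # every rank, before rank 0 analyzes them; this sketch uses local data only.
        if not is_rank_0 or not lengths:
            return None
        stats = {
            "mean_length": sum(lengths) / len(lengths),
            "max_length": float(max(lengths)),
        }
        self.history[cur_step] = stats
        return stats
```

With plot_every=2, step 1 returns None and step 2 returns statistics. In a real distributed run every rank must still enter the gather on collected steps, even though only rank 0 keeps the result; a collective that some ranks skip will hang.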
analyse_output_lengths
- lightrft.strategy.utils.statistic.analyse_output_lengths(gen_outputs: List[Dict[str, Any]], engine_dp_group: torch.distributed.ProcessGroup, percentiles: List[int] = [50, 80], plot_out_dir: str | None = None, prefix: str = '') → Dict[str, Any]
Analyze the length distribution of generated outputs.
This is a convenience function that collects local output lengths, gathers them across all processes, and computes statistics. It provides a one-time analysis without the continuous monitoring capabilities of GenLenAnalyser.
- Parameters:
gen_outputs (list) – List of generation outputs to analyze, each containing ‘output_token_ids’
engine_dp_group (torch.distributed.ProcessGroup) – The distributed process group for communication
percentiles (list) – List of percentiles to compute for the length distribution
plot_out_dir (str or None) – Directory in which to save plots. If None, no plots are saved.
prefix (str) – Prefix for plot filenames
- Returns:
Dictionary containing length statistics
- Return type:
dict
Example:
>>> import torch.distributed as dist
>>> from lightrft.strategy.utils.statistic import analyse_output_lengths
>>> gen_outputs = [
...     {"output_token_ids": [1, 2, 3, 4, 5, 6]},
...     {"output_token_ids": [1, 2, 3]}
... ]
>>> stats = analyse_output_lengths(
...     gen_outputs,
...     engine_dp_group=dist.group.WORLD,
...     percentiles=[25, 50, 75]
... )
>>> print(f"Median length: {stats['median_length']}")
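The one-shot pipeline (collect local lengths, gather them, compute statistics) can be illustrated in plain Python. In this hedged sketch the gather across processes is simulated by concatenating per-rank lists; the function name analyse_once and the min_length/max_length keys are hypothetical, and this is not the library implementation.

```python
import statistics
from typing import Any, Dict, List


def analyse_once(per_rank_outputs: List[List[Dict[str, Any]]]) -> Dict[str, float]:
    """Hypothetical one-shot analysis over outputs from several simulated ranks."""
    # Step 1: each simulated rank measures its own outputs.
    per_rank_lengths = [
        [len(o["output_token_ids"]) for o in outputs] for outputs in per_rank_outputs
    ]
    # Step 2: stand-in for the distributed gather; flatten the per-rank lists.
    all_lengths = [n for lengths in per_rank_lengths for n in lengths]
    # Step 3: summary statistics over the combined list.
    return {
        "min_length": min(all_lengths),
        "max_length": max(all_lengths),
        "mean_length": statistics.mean(all_lengths),
        "median_length": statistics.median(all_lengths),
    }
```

The point of gathering before computing is that percentiles and medians cannot be combined from per-rank summaries; they need the full list of lengths.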
collect_local_output_lengths
- lightrft.strategy.utils.statistic.collect_local_output_lengths(outputs: List[Dict[str, Any]]) → List[int]
Collect the lengths of generated outputs from the local process.
This function extracts the length of each output by counting the tokens in the ‘output_token_ids’ field of each output dictionary.
- Parameters:
outputs (list) – List of generation outputs, each containing ‘output_token_ids’
- Returns:
List of output lengths, one per output, in the same order as outputs
- Return type:
list
Example:
>>> outputs = [
...     {"output_token_ids": [1, 2, 3, 4, 5]},
...     {"output_token_ids": [10, 20]},
...     {"output_token_ids": [100, 200, 300]}
... ]
>>> lengths = collect_local_output_lengths(outputs)
>>> print(lengths)
[5, 2, 3]
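Given the description above, local collection amounts to measuring each token list. A minimal sketch, assuming each output is a plain dict whose output_token_ids value is a sized sequence; the name local_lengths_sketch is illustrative, not part of the library.

```python
from typing import Any, Dict, List, Sequence


def local_lengths_sketch(outputs: Sequence[Dict[str, Any]]) -> List[int]:
    """Return one length per output, in input order (hypothetical re-implementation)."""
    # Each length is simply the number of generated token ids.
    return [len(output["output_token_ids"]) for output in outputs]


print(local_lengths_sketch([
    {"output_token_ids": [1, 2, 3, 4, 5]},
    {"output_token_ids": [10, 20]},
    {"output_token_ids": [100, 200, 300]},
]))  # prints [5, 2, 3]
```

In this sketch a missing output_token_ids key raises KeyError; whether the library tolerates missing keys is not stated above.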
gather_all_lengths
- lightrft.strategy.utils.statistic.gather_all_lengths(local_lengths: List[int], group: torch.distributed.ProcessGroup) → List[int]
Gather output lengths from all processes in the distributed group.
This function uses PyTorch’s distributed communication to collect length data from all processes in the specified group, enabling global analysis of generation length distributions across the entire distributed system.
- Parameters:
local_lengths (list) – List of output lengths from the local process
group (torch.distributed.ProcessGroup) – The distributed process group for communication
- Returns:
Combined list of output lengths from all processes
- Return type:
list
Example:
>>> # Assumes the distributed environment is already initialized
>>> import torch.distributed as dist
>>> local_lengths = [5, 3, 7]
>>> all_lengths = gather_all_lengths(local_lengths, dist.group.WORLD)
>>> # all_lengths now contains the lengths from every process in the group
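One plausible way to implement this gather is torch.distributed.all_gather_object, which collects an arbitrary picklable object from every rank. The sketch below assumes that primitive (the documentation above does not say which one the library uses) and falls back to the local list when torch is missing or no process group is initialized, so it also runs in a single process. The name gather_lengths_sketch is hypothetical.

```python
from typing import Any, List, Optional

try:
    import torch.distributed as dist
except ImportError:  # torch not installed: only the single-process fallback is usable
    dist = None


def gather_lengths_sketch(local_lengths: List[int], group: Optional[Any] = None) -> List[int]:
    """Concatenate every rank's lengths, ordered by rank (hypothetical helper)."""
    if dist is None or not dist.is_available() or not dist.is_initialized():
        # No process group: the "global" data is just the local data.
        return list(local_lengths)
    world_size = dist.get_world_size(group=group)
    gathered: List[Any] = [None] * world_size
    # Collective call: every rank in `group` must reach this line.
    dist.all_gather_object(gathered, list(local_lengths), group=group)
    return [n for rank_lengths in gathered for n in rank_lengths]
```

all_gather_object handles lists of different sizes on each rank because it pickles them, which suits variable batch sizes. With an NCCL process group, PyTorch requires the current CUDA device to be set before this call; for very large lists, an all_gather of padded integer tensors avoids the pickling overhead.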
analyze_output_lengths
- lightrft.strategy.utils.statistic.analyze_output_lengths(all_lengths: List[int], percentiles: List[int]) → Dict[str, Any]
Analyze the distribution of output lengths and compute statistics.
This function computes summary statistics for the length distribution: basic statistics (min, max, mean, median) plus the user-specified percentiles. These are useful for monitoring generation behavior, for example to track how output lengths change over the course of training.
- Parameters:
all_lengths (List[int]) – List of output lengths from all processes
percentiles (List[int]) – List of percentiles to compute (e.g., [25, 50, 75, 90])
- Returns:
Dictionary containing statistics about the length distribution
- Return type:
Dict[str, Any]
Example:
>>> lengths = [10, 15, 20, 25, 30, 35, 40]
>>> stats = analyze_output_lengths(lengths, percentiles=[25, 50, 75])
>>> print(f"Mean: {stats['mean_length']}")
>>> print(f"75th percentile: {stats['percentiles'][75]}")
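The statistics listed above can be computed with the standard library. This sketch assumes linear interpolation between closest ranks for percentiles (the same convention as NumPy's default for numpy.percentile); the library may use a different method. The function names and the min_length/max_length keys are hypothetical, while mean_length, median_length and the percentiles mapping follow the key names used in this page's examples.

```python
import math
import statistics
from typing import Any, Dict, List


def percentile_linear(sorted_vals: List[int], p: float) -> float:
    """p-th percentile (0-100) by linear interpolation between closest ranks."""
    pos = (len(sorted_vals) - 1) * p / 100.0
    lo, hi = math.floor(pos), math.ceil(pos)
    frac = pos - lo
    return sorted_vals[lo] + (sorted_vals[hi] - sorted_vals[lo]) * frac


def length_stats_sketch(all_lengths: List[int], percentiles: List[int]) -> Dict[str, Any]:
    """Hypothetical summary of a length distribution."""
    if not all_lengths:
        raise ValueError("need at least one length")
    vals = sorted(all_lengths)
    return {
        "min_length": vals[0],
        "max_length": vals[-1],
        "mean_length": statistics.mean(vals),
        "median_length": statistics.median(vals),
        # Keyed by the requested percentile, e.g. stats["percentiles"][75].
        "percentiles": {p: percentile_linear(vals, p) for p in percentiles},
    }
```

For example, length_stats_sketch([10, 15, 20, 25, 30, 35, 40], [25, 50, 75]) gives percentiles {25: 17.5, 50: 25.0, 75: 32.5}, a mean of 25 and a median of 25.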