Shortcuts

Source code for lightrft.strategy.utils.statistic

"""
Utilities for analyzing and visualizing the length distribution of generated outputs.

This module provides functionality for collecting, analyzing, and visualizing the length
distribution of generated outputs from language models. It includes tools for gathering
output lengths across distributed processes, computing statistics like percentiles,
and creating visualizations using matplotlib and TensorBoard.

The main components are:
- GenLenAnalyser: A class for continuous monitoring and visualization of generation lengths
- Helper functions for collecting and analyzing output lengths in distributed environments
"""

from typing import List, Dict, Any, Optional
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter


class GenLenAnalyser:
    """
    Analyzer for tracking and visualizing the length distribution of generated outputs.

    Each call to :meth:`collect` gathers the output lengths from every process in
    ``engine_dp_group``, stores them per step in ``hist_data`` and returns summary
    statistics. When a plot directory is configured, rank 0 additionally writes a
    TensorBoard histogram and a PNG that overlays the histograms of all recorded steps.

    :param engine_dp_group: The distributed process group used to gather lengths
    :type engine_dp_group: torch.distributed.ProcessGroup
    :param plot_every: Collect (and plot) only on steps divisible by this value. With 0,
        statistics are computed on every step and no plots or TensorBoard logs are written
    :type plot_every: int
    :param percentiles: Percentiles to compute for the length distribution,
        defaults to ``[50, 80]`` when None
    :type percentiles: list or None
    :param plot_out_dir: Directory to save plots and TensorBoard logs, if None no plots are saved
    :type plot_out_dir: str or None

    Example::

        >>> import torch.distributed as dist
        >>> # Initialize distributed process group
        >>> analyzer = GenLenAnalyser(
        ...     engine_dp_group=dist.group.WORLD,
        ...     plot_every=10,
        ...     percentiles=[25, 50, 75, 90],
        ...     plot_out_dir="./output_analysis"
        ... )
        >>> # Use during training loop
        >>> stats = analyzer.collect(generation_outputs, cur_step=100, is_rank_0=True)
    """

    def __init__(
        self,
        engine_dp_group: dist.ProcessGroup,
        plot_every: int = 2,
        percentiles: Optional[List[int]] = None,
        plot_out_dir: Optional[str] = None
    ) -> None:
        self.engine_dp_group = engine_dp_group
        self.plot_out_dir = plot_out_dir
        # A None sentinel avoids sharing one mutable default list between instances.
        self.percentiles = percentiles if percentiles is not None else [50, 80]
        self.plot_every = plot_every
        # Maps step -> list of gathered output lengths for that step.
        self.hist_data = {}
        # Always defined so collect() can test it instead of hitting AttributeError
        # when plotting is disabled.
        self.tb_writer = None

        if plot_every > 0 and plot_out_dir is not None:
            os.makedirs(plot_out_dir, exist_ok=True)
            self.tb_writer = SummaryWriter(log_dir=plot_out_dir)
            print(f"GenLenAnalyser is initialized and will log to {plot_out_dir} every {self.plot_every}")

    def collect(self, gen_outputs: List[Dict[str, Any]], cur_step: int,
                is_rank_0: bool) -> Optional[Dict[str, Any]]:
        """
        Collect and analyze generation length data at the current step.

        Steps that are not a multiple of ``plot_every`` are skipped (when ``plot_every > 0``).
        Otherwise the local lengths are gathered across the process group, recorded in
        ``hist_data`` and summarised. If plotting is enabled and ``is_rank_0`` is True,
        a TensorBoard histogram for this step and a PNG overlaying all recorded steps
        are written as well.

        :param gen_outputs: List of generation outputs to analyze, each containing 'output_token_ids'
        :type gen_outputs: List[Dict[str, Any]]
        :param cur_step: Current training/generation step
        :type cur_step: int
        :param is_rank_0: Whether the current process is the main process (rank 0)
        :type is_rank_0: bool
        :return: Dictionary containing length statistics or None if collection is skipped
        :rtype: dict or None

        Example::

            >>> gen_outputs = [
            ...     {"output_token_ids": [1, 2, 3, 4, 5]},
            ...     {"output_token_ids": [1, 2, 3]}
            ... ]
            >>> stats = analyzer.collect(gen_outputs, cur_step=50, is_rank_0=True)
            >>> if stats:
            ...     print(f"Mean length: {stats['mean_length']}")
        """
        if self.plot_every > 0 and cur_step % self.plot_every != 0:
            return None

        local_out_lens = collect_local_output_lengths(gen_outputs)
        glb_output_lens = gather_all_lengths(local_out_lens, self.engine_dp_group)
        self.hist_data[cur_step] = glb_output_lens

        # tb_writer only exists when plot_every > 0 and plot_out_dir is set.
        if self.tb_writer is not None and is_rank_0:
            # Only the current step is new; earlier steps were logged when collected.
            self.tb_writer.add_histogram(
                "VLLM GenerateOutputLength Distribution",
                np.asarray(glb_output_lens, dtype="int"),
                cur_step,
                bins="auto",
                max_bins=50,
            )
            self._save_histogram_plot(cur_step)

        return analyze_output_lengths(glb_output_lens, self.percentiles)

    def _save_histogram_plot(self, cur_step: int) -> None:
        """Save a PNG overlaying the length histograms of every recorded step."""
        plot_path = os.path.join(self.plot_out_dir, f"gen_len_step_{cur_step}.png")
        fig = plt.figure(figsize=(10, 6))
        try:
            # Colormaps expect floats in [0, 1]; integer inputs index the lookup
            # table directly and would yield almost identical colors.
            colors = plt.cm.viridis(np.linspace(0.0, 1.0, len(self.hist_data)))
            for (step, data), color in zip(self.hist_data.items(), colors):
                plt.hist(data, label=str(step), color=color)
            plt.xlabel("GenerateOutputLength")
            plt.ylabel("Frequency")
            plt.title("VLLM GenerateOutputLength Distribution")
            plt.grid(True, alpha=0.3)
            # The legend must be built after the labelled histograms exist.
            plt.legend()
            plt.savefig(plot_path, bbox_inches="tight")
        finally:
            # Release the figure; pyplot otherwise keeps every figure alive.
            plt.close(fig)
def analyse_output_lengths(
    gen_outputs: List[Dict[str, Any]],
    engine_dp_group: dist.ProcessGroup,
    percentiles: Optional[List[int]] = None,
    plot_out_dir: Optional[str] = None,
    prefix: str = ""
) -> Dict[str, Any]:
    """
    Analyze the length distribution of generated outputs.

    This is a convenience function that collects local output lengths, gathers them
    across all processes, and computes statistics. It provides a one-time analysis
    without the continuous monitoring capabilities of GenLenAnalyser.

    :param gen_outputs: List of generation outputs to analyze, each containing 'output_token_ids'
    :type gen_outputs: list
    :param engine_dp_group: The distributed process group for communication
    :type engine_dp_group: torch.distributed.ProcessGroup
    :param percentiles: List of percentiles to compute for the length distribution,
        defaults to ``[50, 80]`` when None
    :type percentiles: list or None
    :param plot_out_dir: Accepted for backward compatibility; this function does not
        write any plots and the value is not used
    :type plot_out_dir: str or None
    :param prefix: Accepted for backward compatibility; not used
    :type prefix: str
    :return: Dictionary containing length statistics
    :rtype: dict

    Example::

        >>> gen_outputs = [
        ...     {"output_token_ids": [1, 2, 3, 4, 5, 6]},
        ...     {"output_token_ids": [1, 2, 3]}
        ... ]
        >>> stats = analyse_output_lengths(
        ...     gen_outputs,
        ...     engine_dp_group=dist.group.WORLD,
        ...     percentiles=[25, 50, 75]
        ... )
        >>> print(f"Median length: {stats['median_length']}")
    """
    # A None sentinel avoids a mutable default list shared between calls.
    if percentiles is None:
        percentiles = [50, 80]

    local_out_lens = collect_local_output_lengths(gen_outputs)
    glb_output_lens = gather_all_lengths(local_out_lens, engine_dp_group)
    # analyze_output_lengths accepts exactly (lengths, percentiles); forwarding
    # plot_out_dir as a third positional argument raised TypeError on every call.
    return analyze_output_lengths(glb_output_lens, percentiles)
def collect_local_output_lengths(outputs: List[Dict[str, Any]]) -> List[int]:
    """
    Return the token count of every generation output held by the local process.

    The length of an output is the number of entries in its 'output_token_ids'
    field (per the original note, that key is set in strategy.gather_and_generate).

    :param outputs: List of generation outputs, each containing 'output_token_ids'
    :type outputs: list
    :return: One length per output, in the same order as ``outputs``
    :rtype: list

    Example::

        >>> outputs = [
        ...     {"output_token_ids": [1, 2, 3, 4, 5]},
        ...     {"output_token_ids": [10, 20]},
        ...     {"output_token_ids": [100, 200, 300]}
        ... ]
        >>> lengths = collect_local_output_lengths(outputs)
        >>> print(lengths)  # [5, 2, 3]
    """
    return [len(item["output_token_ids"]) for item in outputs]
def gather_all_lengths(local_lengths: List[int], group: dist.ProcessGroup) -> List[int]:
    """
    Gather output lengths from all processes in the distributed group.

    The per-rank lists are exchanged with ``torch.distributed.all_gather_object``,
    so ranks may contribute different numbers of lengths (a plain tensor
    ``all_gather`` requires identically sized tensors on every rank). When
    torch.distributed is unavailable or no process group has been initialized,
    the local lengths are returned unchanged.

    :param local_lengths: List of output lengths from the local process
    :type local_lengths: list
    :param group: The distributed process group for communication
    :type group: torch.distributed.ProcessGroup
    :return: Combined list of output lengths from all processes, ordered by rank
    :rtype: list

    Example::

        >>> # Assuming distributed environment is set up
        >>> local_lengths = [5, 3, 7]
        >>> all_lengths = gather_all_lengths(local_lengths, dist.group.WORLD)
        >>> # all_lengths now contains lengths from all processes
    """
    local_lengths = [int(length) for length in local_lengths]

    # Single-process runs have nothing to gather from.
    if not (dist.is_available() and dist.is_initialized()):
        return local_lengths

    world_size = dist.get_world_size(group=group)
    per_rank_lengths = [None] * world_size
    dist.all_gather_object(per_rank_lengths, local_lengths, group=group)

    return [length for rank_lengths in per_rank_lengths for length in rank_lengths]
def analyze_output_lengths(all_lengths: List[int], percentiles: List[int]) -> Dict[str, Any]:
    """
    Analyze the distribution of output lengths and compute statistics.

    Computes the sample count, min, max, mean and median of ``all_lengths`` plus
    the requested percentiles (each rounded to the nearest integer). All values
    are returned as plain Python numbers. For an empty ``all_lengths`` the result
    has ``total_samples`` equal to 0 and every statistic set to zero instead of
    raising.

    :param all_lengths: List of output lengths from all processes
    :type all_lengths: List[int]
    :param percentiles: List of percentiles to compute (e.g., [25, 50, 75, 90])
    :type percentiles: List[int]
    :return: Dictionary with keys 'total_samples', 'min_length', 'max_length',
        'mean_length', 'median_length' and 'percentiles' (a dict keyed by percentile)
    :rtype: Dict[str, Any]

    Example::

        >>> lengths = [10, 15, 20, 25, 30, 35, 40]
        >>> stats = analyze_output_lengths(lengths, percentiles=[25, 50, 75])
        >>> print(f"Mean: {stats['mean_length']}")
        >>> print(f"75th percentile: {stats['percentiles'][75]}")
    """
    lengths = np.asarray(all_lengths)

    # NumPy reductions raise on zero-size arrays; report an explicit empty result.
    if lengths.size == 0:
        return {
            "total_samples": 0,
            "min_length": 0,
            "max_length": 0,
            "mean_length": 0.0,
            "median_length": 0.0,
            "percentiles": {p: 0 for p in percentiles},
        }

    percentile_values = np.percentile(lengths, percentiles)

    # .item() / float() convert NumPy scalars to native Python numbers so the
    # result can be serialized and logged without NumPy-specific handling.
    return {
        "total_samples": len(lengths),
        "min_length": np.min(lengths).item(),
        "max_length": np.max(lengths).item(),
        "mean_length": float(np.mean(lengths)),
        "median_length": float(np.median(lengths)),
        "percentiles": {p: round(float(v)) for p, v in zip(percentiles, percentile_values)},
    }