"""Utilities for ``lightrft.utils.utils``: tokenizer setup, dataset blending,
profiling helpers, and distributed collectives."""

import os
import sys
from typing import Any, Dict, List, Optional, Tuple, Union

from datasets import interleave_datasets, load_dataset, load_from_disk, Dataset, DatasetDict
from transformers import AutoTokenizer, AutoProcessor, PreTrainedTokenizer, PreTrainedModel, ProcessorMixin
import torch
import torch.distributed as dist


def get_tokenizer(
    pretrain: str, model: PreTrainedModel, padding_side: str = "left", use_fast: bool = True
) -> PreTrainedTokenizer:
    """
    Build a tokenizer for a causal language model and keep the model config in sync.

    :param pretrain: Path or name of the pretrained tokenizer.
    :type pretrain: str
    :param model: Model whose ``config.pad_token_id`` is synchronized when a
        pad token has to be substituted.
    :type model: transformers.PreTrainedModel
    :param padding_side: Side to pad on ('left' or 'right'). Defaults to 'left'
        so that generation-time batching is efficient: padding on the left keeps
        the real tokens adjacent to the position where new tokens are generated.
    :type padding_side: str
    :param use_fast: Whether to prefer the fast (Rust) tokenizer implementation.
    :type use_fast: bool
    :return: Configured tokenizer instance.
    :rtype: transformers.PreTrainedTokenizer
    """
    tok = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast)
    tok.padding_side = padding_side

    # NOTE: When enable vLLM, do not resize_token_embeddings, or the vocab size
    # will mismatch with vLLM.
    # https://github.com/facebookresearch/llama-recipes/pull/196
    if tok.pad_token is None:
        # No dedicated pad token: reuse EOS and mirror the id onto the model config.
        tok.pad_token = tok.eos_token
        tok.pad_token_id = tok.eos_token_id
        model.config.pad_token_id = tok.pad_token_id

    return tok


def get_tokenizer_processor_vl(
    pretrain: str,
    model: PreTrainedModel,
    padding_side: str = "left",
    use_fast: bool = True
) -> Tuple[PreTrainedTokenizer, ProcessorMixin]:
    """
    Build the tokenizer/processor pair for a vision-language checkpoint.

    :param pretrain: Path or name of the pretrained model.
    :type pretrain: str
    :param model: Model whose ``config.pad_token_id`` is synchronized when a
        pad token has to be substituted.
    :type model: transformers.PreTrainedModel
    :param padding_side: Side to pad on ('left' or 'right'). Defaults to 'left'
        so that generation-time batching is efficient: padding on the left keeps
        the real tokens adjacent to the position where new tokens are generated.
    :type padding_side: str
    :param use_fast: Whether to prefer the fast (Rust) implementations.
    :type use_fast: bool
    :return: Tuple of (tokenizer, processor).
    :rtype: Tuple[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]
    """
    tok = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast)
    proc = AutoProcessor.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast)

    tok.padding_side = padding_side
    # NOTE: When enable vLLM, do not resize_token_embeddings, or the vocab size
    # will mismatch with vLLM.
    # https://github.com/facebookresearch/llama-recipes/pull/196
    if tok.pad_token is None:
        # No dedicated pad token: reuse EOS and mirror the id onto the model config.
        tok.pad_token = tok.eos_token
        tok.pad_token_id = tok.eos_token_id
        model.config.pad_token_id = tok.pad_token_id

    return tok, proc


def _load_single_dataset(dataset: str, data_dir: Optional[Any], strategy: Optional[Any]) -> Union[Dataset, DatasetDict]:
    """Load one dataset from a local script, a text file, a ``save_to_disk`` dir, or the hub."""
    dataset_basename = os.path.basename(dataset)
    ext = os.path.splitext(dataset)[-1]
    # local python script
    if ext == ".py" or (os.path.isdir(dataset) and os.path.exists(os.path.join(dataset, f"{dataset_basename}.py"))):
        data = load_dataset(dataset, trust_remote_code=True)
        if strategy:
            strategy.print(f"loaded {dataset} with python script")
    # local text file
    elif ext in [".json", ".jsonl", ".csv"]:
        ext = ext.lower().strip(".")
        if ext == "jsonl":
            ext = "json"
        data = load_dataset(ext, data_files=dataset)
        if strategy:
            strategy.print(f"loaded {dataset} with data_files={dataset}")
    # local dataset saved with `datasets.Dataset.save_to_disk`
    elif os.path.isdir(dataset):
        try:
            data = load_from_disk(dataset)
            if strategy:
                strategy.print(f"loaded {dataset} from disk")
        except Exception as e:
            if strategy:
                strategy.print(f"failed to load {dataset} from disk: {e}")
            data = load_dataset(dataset, data_dir=data_dir)
            if strategy:
                strategy.print(f"loaded {dataset} from files")
    # remote/local folder or common file
    else:
        data = load_dataset(dataset, data_dir=data_dir)
        if strategy:
            strategy.print(f"loaded {dataset} from files")
    return data


def _select_train_split(
    data: Union[Dataset, DatasetDict],
    train_split: str,
    dataset: str,
    max_count: int,
    strategy: Optional[Any],
) -> Dataset:
    """
    Pick the training split from ``data``, capped at ``max_count`` rows.

    The ``isinstance`` check comes first so ``train_split in data`` is only
    evaluated against a ``DatasetDict`` (dict-like split lookup); against a
    bare ``Dataset``, ``in`` would scan every row and always be False.

    :raises ValueError: If a ``DatasetDict`` has no splits at all.
    :raises TypeError: If ``data`` is neither ``Dataset`` nor ``DatasetDict``.
    """
    if isinstance(data, DatasetDict):
        if train_split and train_split in data:
            split = data[train_split]
        else:
            # Requested split absent: intelligently fall back to the first
            # available split so datasets without a 'train' split still work.
            available_splits = list(data.keys())
            if not available_splits:
                raise ValueError(f"DatasetDict loaded from {dataset} is empty and has no splits.")
            split_name = available_splits[0]
            split = data[split_name]
            if strategy:
                strategy.print(
                    f"WARN: '{train_split}' split not found or not provided. Using the first split: '{split_name}'"
                )
    elif isinstance(data, Dataset):
        # Already a single Dataset: use it directly.
        split = data
    else:
        raise TypeError(f"Loaded data from {dataset} is of an unexpected type: {type(data)}")
    return split.select(range(min(max_count, len(split))))


def _select_eval_split(
    data: Union[Dataset, DatasetDict],
    eval_split: str,
    train_data: Dataset,
    max_count: int,
    strategy: Optional[Any],
) -> Dataset:
    """Pick the eval split, falling back to 3% of ``train_data`` when absent."""
    if isinstance(data, DatasetDict) and eval_split and eval_split in data:
        split = data[eval_split]
        return split.select(range(min(max_count, len(split))))
    # Fallback: carve a small fraction out of the (already capped) train data.
    if strategy:
        strategy.print(f"WARN: '{eval_split}' split not found. Using 3% of the training data for evaluation.")
    return train_data.select(range(min(max_count, int(len(train_data) * 0.03))))


def blending_datasets(
    datasets: str,
    probabilities: str,
    strategy: Optional[Any] = None,
    seed: int = 42,
    max_count: int = 5000000,
    return_eval: bool = True,
    stopping_strategy: str = "first_exhausted",
    train_split: str = "train",
    eval_split: str = "test",
) -> Union[Dataset, Tuple[Dataset, Dataset]]:
    """
    Load and blend multiple datasets with specified sampling probabilities.

    Supports various dataset formats including local files (.json, .jsonl, .csv),
    HuggingFace datasets, and datasets saved with ``save_to_disk``. Each entry
    may carry an optional ``@subdir`` suffix selecting a ``data_dir``.

    :param datasets: Comma-separated dataset paths or names (e.g., 'path1,path2').
    :type datasets: str
    :param probabilities: Comma-separated sampling probabilities (e.g., '0.5,0.5').
    :type probabilities: str
    :param strategy: Optional training strategy for distributed logging.
    :type strategy: Optional[Any]
    :param seed: Random seed for reproducible interleaving.
    :type seed: int
    :param max_count: Maximum number of samples to load per dataset.
    :type max_count: int
    :param return_eval: Whether to return evaluation dataset.
    :type return_eval: bool
    :param stopping_strategy: How to handle datasets of different sizes
        ('first_exhausted' or 'all_exhausted').
    :type stopping_strategy: str
    :param train_split: Name of the training split.
    :type train_split: str
    :param eval_split: Name of the evaluation split.
    :type eval_split: str
    :return: Training dataset, or tuple of (train_dataset, eval_dataset) if return_eval=True.
    :rtype: Union[Dataset, Tuple[Dataset, Dataset]]
    :raises ValueError: If the number of probabilities does not match the
        number of datasets.
    """
    dataset_names = datasets.split(",")
    probs = list(map(float, probabilities.split(",")))
    # Explicit error instead of `assert`: asserts are stripped under `python -O`.
    if len(probs) != len(dataset_names):
        raise ValueError(
            f"Expected one probability per dataset, got {len(dataset_names)} datasets "
            f"and {len(probs)} probabilities."
        )

    train_data_list = []
    eval_data_list = []
    for entry in dataset_names:
        entry = entry.strip()
        if strategy:
            strategy.print(f"dataset: {entry}")

        # Optional "path@subdir" syntax selects a data_dir inside the dataset.
        data_dir = entry.split("@")[1].strip() if "@" in entry else None
        name = entry.split("@")[0].strip()

        data = _load_single_dataset(name, data_dir, strategy)

        train_data = _select_train_split(data, train_split, name, max_count, strategy)
        train_data_list.append(train_data)

        if return_eval:
            eval_data_list.append(_select_eval_split(data, eval_split, train_data, max_count, strategy))

    # merge datasets
    if strategy and strategy.is_rank_0():
        print(f"Blending {len(train_data_list)} training datasets...")
        print(train_data_list)

    train_dataset = interleave_datasets(
        train_data_list,
        probabilities=probs,
        seed=seed,
        stopping_strategy=stopping_strategy,
    )
    if return_eval:
        if strategy and strategy.is_rank_0():
            print(f"Blending {len(eval_data_list)} evaluation datasets...")
            print(eval_data_list)
        eval_dataset = interleave_datasets(
            eval_data_list,
            probabilities=probs,
            seed=seed,
            stopping_strategy=stopping_strategy,
        )
        return train_dataset, eval_dataset
    else:
        return train_dataset


def convert_token_to_id(token: str, tokenizer: PreTrainedTokenizer) -> int:
    """
    Convert a string token to its corresponding token ID.

    :param token: Token string to convert.
    :type token: str
    :param tokenizer: Tokenizer instance to use for conversion.
    :type tokenizer: transformers.PreTrainedTokenizer
    :return: Token ID.
    :rtype: int
    :raises ValueError: If token is not a string or encodes to multiple IDs.
    """
    if not isinstance(token, str):
        raise ValueError(f"token should be a string, got {type(token).__name__}")
    token_ids = tokenizer.encode(token, add_special_tokens=False)
    # Raise ValueError explicitly instead of using `assert`: asserts are
    # stripped under `python -O`, and the documented contract is ValueError.
    if len(token_ids) != 1:
        raise ValueError(f"Token '{token}' encodes to {len(token_ids)} IDs, expected 1")
    return token_ids[0]





def get_current_device(num_device_per_node: int = 8) -> torch.device:
    """
    Returns the current CUDA device.

    When ``torch.distributed`` is initialized, the device index is derived
    from the global rank modulo ``num_device_per_node``; otherwise the
    current CUDA device index is returned unchanged.

    :param num_device_per_node: Number of devices per node for distributed training
    :type num_device_per_node: int
    :return: Current CUDA device
    :rtype: torch.device

    Example::

        >>> device = get_current_device()
        >>> model = model.to(device)
    """
    if torch.distributed.is_initialized():
        local_index = torch.distributed.get_rank() % num_device_per_node
        return torch.device(f"cuda:{local_index}")
    # NOTE(review): this branch returns the integer index from
    # torch.cuda.current_device() rather than a torch.device; kept as-is to
    # preserve the original behavior for existing callers.
    return torch.cuda.current_device()
def get_torch_profiler(
    output_file: str, warmup: int = 1, active: int = 1, repeat: int = 1
) -> Union[torch.profiler.profile, "DummyProfile"]:
    """
    Creates and returns a PyTorch profiler configured for distributed training.

    Rank 0 gets a full profiler recording CPU and CUDA activity (with shapes,
    stacks and memory); every other rank gets a no-op ``DummyProfile`` so the
    same ``with``/``step`` code runs everywhere.

    For more details on PyTorch Profiler, see:
    https://docs.pytorch.org/docs/stable/profiler.html

    :param output_file: Path where profiling results will be saved (only for rank 0)
    :type output_file: str
    :param warmup: Number of steps to wait before profiling starts
    :type warmup: int
    :param active: Number of steps with active profiling
    :type active: int
    :param repeat: Number of times to repeat the profiling cycle
    :type repeat: int
    :return: A PyTorch profiler object or a dummy profiler
    :rtype: torch.profiler.profile or DummyProfile

    Example::

        >>> with get_torch_profiler("./profiler_output", warmup=5, active=10) as prof:
        >>>     for step in range(100):
        >>>         train_step()
        >>>         prof.step()
    """
    from torch.profiler import ProfilerActivity

    # Non-rank-0 processes get a no-op stand-in.
    if torch.distributed.get_rank() != 0:
        return DummyProfile()

    return torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=1, warmup=warmup, active=active, repeat=repeat),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(output_file),
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        with_stack=True,
        profile_memory=True,
    )
class DummyProfile:
    """
    No-op stand-in that mimics the PyTorch profiler interface.

    Used on non-rank-0 processes where profiling is not wanted, so the same
    context-manager/``step`` code can run on every rank without branching.

    Example::

        >>> prof = DummyProfile() if rank != 0 else torch.profiler.profile(...)
        >>> with prof:
        >>>     # code to be profiled
    """

    def __init__(self, *args, **kwargs) -> None:
        """Accept and ignore any constructor arguments."""

    def __enter__(self) -> "DummyProfile":
        """Enter the context manager; returns self."""
        return self

    def __exit__(self, a: Any, b: Any, c: Any) -> None:
        """Exit the context manager; does nothing."""

    def start(self) -> None:
        """No-op counterpart of ``profiler.start``."""

    def stop(self) -> None:
        """No-op counterpart of ``profiler.stop``."""

    def step(self) -> None:
        """No-op counterpart of ``profiler.step``."""
def ensure_video_input_available() -> None:
    """
    Ensure ``VideoInput`` is available from ``transformers.image_utils``.

    Transformers < 4.52.0 defines ``VideoInput`` in
    ``transformers.image_utils``; >= 4.52.0 moved it to
    ``transformers.video_utils`` and no longer exports it from the old
    location. This helper imports it from whichever module provides it,
    falling back to a dummy placeholder class if neither does, and patches
    the symbol back onto ``transformers.image_utils`` so that

        >>> ensure_video_input_available()
        >>> from transformers.image_utils import VideoInput

    works regardless of the installed Transformers version.
    """
    try:
        from transformers.image_utils import VideoInput
    except ImportError:
        try:
            from transformers.video_utils import VideoInput
        except ImportError:

            class VideoInput:
                """
                Placeholder class for VideoInput when transformers doesn't provide it.

                Compatibility shim for Transformers versions that export
                VideoInput from neither transformers.image_utils nor
                transformers.video_utils.
                """

                pass

    import transformers

    # Patch the symbol back so downstream `from transformers.image_utils
    # import VideoInput` succeeds on any version.
    transformers.image_utils.VideoInput = VideoInput
    sys.modules["transformers.image_utils"].VideoInput = VideoInput


def all_gather_and_flatten(data: Any, group: Optional[dist.ProcessGroup] = None) -> List[Any]:
    """
    Gather data from all processes and flatten the result into a single list.

    :param data: The data to gather from the current process.
    :type data: Any
    :param group: The process group to work on. If None, the default process group is used.
    :type group: ProcessGroup, optional
    :returns: A flattened list containing data from all processes.
    :rtype: List[Any]
    """
    if not dist.is_initialized():
        # Single-process fallback: behave like a one-rank gather.
        return data if isinstance(data, list) else [data]

    world_size = dist.get_world_size(group=group)
    per_rank: List[Any] = [None] * world_size
    dist.all_gather_object(per_rank, data, group=group)

    # Splice list contributions in; append scalar contributions as-is.
    flat: List[Any] = []
    for chunk in per_rank:
        if isinstance(chunk, list):
            flat.extend(chunk)
        else:
            flat.append(chunk)
    return flat


def all_reduce_dict(
    metrics_dict: Dict[str, float], op: str = "sum", group: Optional[dist.ProcessGroup] = None
) -> Dict[str, float]:
    """
    Perform all-reduce operation on a dictionary of metrics.

    Values are packed into a single float64 tensor so one collective call
    suffices for the whole dictionary.

    :param metrics_dict: Dictionary of metrics to be reduced.
    :type metrics_dict: Dict[str, float]
    :param op: Reduction operation ('sum', 'max', 'min', 'mean').
    :type op: str
    :param group: The process group to work on. If None, the default process group is used.
    :type group: ProcessGroup, optional
    :returns: Reduced dictionary of metrics.
    :rtype: Dict[str, float]
    """
    if not dist.is_initialized():
        return metrics_dict

    keys = sorted(metrics_dict.keys())
    # Use the current device if available, otherwise CPU.
    device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu")
    packed = torch.tensor([metrics_dict[k] for k in keys], device=device, dtype=torch.float64)

    reduce_ops = {
        "sum": dist.ReduceOp.SUM,
        "max": dist.ReduceOp.MAX,
        "min": dist.ReduceOp.MIN,
        "mean": dist.ReduceOp.SUM,  # mean = sum, then divide by world size
    }
    dist.all_reduce(packed, op=reduce_ops[op.lower()], group=group)
    if op.lower() == "mean":
        packed /= dist.get_world_size(group=group)

    return dict(zip(keys, packed.tolist()))