# Source code for lightrft.datasets.utils
"""
Utility functions for dataset processing.

Parts of this file are adapted from Open-Reasoner-Zero:
https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero
"""

from abc import ABC, abstractmethod

import re
import io
from PIL import Image
from typing import Any, Dict, List, Tuple, Union

import torch
import torch.nn.functional as F


def find_subsequence(lst: List[int], sub: List[int]) -> int:
    """Locate the first occurrence of ``sub`` within ``lst``.

    Used to find a marker token sequence (e.g. the assistant-start
    marker) inside a token-id list so that prompt and response tokens
    can be separated for label masking.

    Implements the Knuth-Morris-Pratt algorithm: O(n + m) time,
    O(m) extra space.

    :param lst: Sequence to search (e.g., list of token ids).
    :type lst: List[int]
    :param sub: Subsequence (pattern) to find.
    :type sub: List[int]
    :returns: Index of first occurrence or -1 if not found.
    :rtype: int
    """
    if not sub:
        return 0  # an empty pattern matches at position 0
    n, m = len(lst), len(sub)
    if m > n:
        return -1

    # Failure table: fail[k] is the length of the longest proper prefix
    # of sub[:k + 1] that is also a suffix of it.
    fail = [0] * m
    width = 0
    for pos in range(1, m):
        while width and sub[pos] != sub[width]:
            width = fail[width - 1]
        if sub[pos] == sub[width]:
            width += 1
        fail[pos] = width

    # Scan lst once, advancing the pattern cursor via the failure table
    # instead of restarting after a mismatch.
    matched = 0
    for pos, token in enumerate(lst):
        while matched and token != sub[matched]:
            matched = fail[matched - 1]
        if token == sub[matched]:
            matched += 1
        if matched == m:
            return pos - m + 1
    return -1
def extract_answer(text: str) -> Union[str, None]:
    """
    Extract the content inside <answer>...</answer> from a given text.

    :param text: The input text containing the <answer> tags.
    :type text: str
    :return: The extracted string inside the <answer> tags, or None if not found.
    :rtype: Union[str, None]

    Example::

        >>> text = "The result is <answer>Image 1 is better</answer> based on the evaluation."
        >>> answer = extract_answer(text)
        >>> print(answer)  # Output: Image 1 is better
    """
    # DOTALL lets the answer span multiple lines; non-greedy so only the
    # first tag pair is captured.
    found = re.search(r"<answer>(.*?)</answer>", text, flags=re.DOTALL)
    return found.group(1).strip() if found else None
def zero_pad_sequences(sequences, side: str = "left", value=0) -> torch.Tensor:
    """
    Pad a list of 1D/2D tensors on the last dimension and stack them.

    :param sequences: Iterable of torch.Tensor objects. Each tensor's last
                      dimension is treated as the sequence length to be padded.
                      Scalar (0-dim) tensors are treated as length-1 sequences.
    :type sequences: Iterable[torch.Tensor]
    :param side: Side to apply padding, either "left" or "right"
    :type side: str
    :param value: Padding value
    :type value: int | float
    :return: Stacked tensor with shape (N, ...) where sequences are padded to equal length
    :rtype: torch.Tensor
    :raises ValueError: If ``sequences`` is empty or ``side`` is invalid.

    Example::

        >>> seqs = [torch.tensor([1,2,3]), torch.tensor([4,5])]
        >>> zero_pad_sequences(seqs, side="left", value=0)
        tensor([[1, 2, 3],
                [0, 4, 5]])
    """
    sequences = list(sequences)
    if len(sequences) == 0:
        raise ValueError("sequences must contain at least one tensor")
    if side not in ("left", "right"):
        raise ValueError("side must be either 'left' or 'right'")
    # Promote 0-dim tensors to length-1 BEFORE measuring lengths:
    # tensor.size(-1) raises IndexError on a scalar tensor, so the
    # original post-hoc unsqueeze inside the pad loop was unreachable.
    sequences = [seq.unsqueeze(0) if seq.dim() == 0 else seq for seq in sequences]
    # Target length is the longest last dimension among the inputs.
    max_len = max(seq.size(-1) for seq in sequences)
    padded = []
    for seq in sequences:
        pad_len = max_len - seq.size(-1)
        if pad_len:
            padding = (pad_len, 0) if side == "left" else (0, pad_len)
            seq = F.pad(seq, padding, value=value)
        padded.append(seq)
    return torch.stack(padded, dim=0)
def exist_and_not_none(d, key):
    """
    Check if a key exists in dictionary and its value is not None.

    :param d: Dictionary to check.
    :type d: dict
    :param key: Key to look for.
    :type key: Any
    :return: True if key exists and value is not None.
    :rtype: bool
    """
    # Idiom fix: use "is not None" (PEP 8) rather than "not ... is None".
    return key in d and d[key] is not None
def load_multimodal_content(media_info: Dict) -> Dict:
    """
    Load multimodal content (images, videos, audios, etc.) specified by ``media_info``.

    Keys in each entry can include:
        - 'image_local_path' | 'image_bytes'
        - 'video_local_path'
        - 'audio_local_path' | 'audio_bytes'

    :param media_info: Example:
        {'init_image': {'image_local_path': '/path/img.jpg'},
         'video': {'video_local_path': '/path/vid.mp4'},
         'audio': {'audio_local_path': '/path/audio.wav'}}
    :type media_info: Dict[str, Dict[str, Any]]
    :return: A dict mapping the same keys to loaded objects:
        - images (from path or bytes) are returned as PIL.Image.Image
        - videos are returned as the original local path (str)
        - audio paths are returned as the original local path (str)
        - audio bytes are wrapped in an io.BytesIO object
        Entries with no recognized key are omitted from the result.
    :rtype: Dict[str, Any]
    """
    loaded = {}
    for name, entry in media_info.items():
        # Key priority mirrors the branch order below: image path wins
        # over image bytes, then video, then audio path, then audio bytes.
        if "image_local_path" in entry:
            content = Image.open(entry["image_local_path"])
        elif "image_bytes" in entry:
            content = Image.open(io.BytesIO(entry["image_bytes"]))
        elif "video_local_path" in entry:
            content = entry["video_local_path"]  # pass the path through unchanged
        elif "audio_local_path" in entry:
            content = entry["audio_local_path"]  # pass the path through unchanged
        elif "audio_bytes" in entry:
            content = io.BytesIO(entry["audio_bytes"])
        else:
            continue  # unrecognized entry: leave it out of the result
        loaded[name] = content
    return loaded
def get_task_instructions(handler: Any, config: Dict[str, Any]) -> str:
    """
    Select task instruction based on task type from handler and config.

    :param handler: Data handler instance.
    :param config: Configuration dictionary which contains 'task_instruction'.
    :return: The selected task instruction.
    :raises ValueError: If the instruction mapping lacks the handler's task
        type, the handler has no ``task_type``, or the config value is
        neither a dict nor a str.
    """
    raw = config.get("task_instruction")
    # Plain string: used verbatim for every task type.
    if isinstance(raw, str):
        return raw
    if not isinstance(raw, dict):
        raise ValueError("task_instruction in config must be either a dict or a str.")
    # Dict form: keyed by the handler's task_type.
    if not hasattr(handler, "task_type"):
        raise ValueError(f"Handler {handler.__class__.__name__} does not specify a task_type.")
    prompt = raw.get(handler.task_type)
    if prompt is None:
        raise ValueError(f"Task instruction for {handler.task_type} not found.")
    return prompt
class BaseDataHandler(ABC):
    """
    Base class for data handlers.

    Concrete subclasses implement loading raw items, extracting media
    path info, and parsing items into the standard message format.
    """

    @abstractmethod
    def load_data(self, path: str) -> List[Dict[str, Any]]:
        """
        Load all data items from a data config file, e.g. a json file, or a parquet file.

        :param path: The path to load data from.
        :type path: str
        :return: A list of raw data items.
        :rtype: List[Dict[str, Any]]
        """
        raise NotImplementedError

    @abstractmethod
    def get_media_info(self, item: Dict[str, Any]) -> Dict[str, Dict[str, str]]:
        """
        Extract path info for all media info from the raw item.

        :param item: The raw data item.
        :type item: Dict[str, Any]
        :return: A dict where keys are logical names (e.g. 'init_image') and values are path dicts.
        :rtype: Dict[str, Dict[str, str]]

        Example::

            >>> item = {'init_image_path': '/path/img.jpg', 'video_path': '/path/vid.mp4'}
            >>> visual_info = get_media_info(item)
            >>> print(visual_info)
            {'init_image': {'image_local_path': '/path/img.jpg'},
             'video': {'video_local_path': '/path/vid.mp4'}}
        """
        raise NotImplementedError

    @abstractmethod
    def parse_item(
        self, item: Dict[str, Any], media_content: Dict[str, Any], config: Dict[str, Any]
    ) -> Union[Tuple[List[Dict], List[Dict], Dict], Tuple[List[Dict], Dict]]:
        """
        Parse the raw item and the loaded media_content into the standard format.

        :param item: The raw data item.
        :type item: Dict[str, Any]
        :param media_content: A dict containing loaded content (e.g. PIL Images, Video paths).
        :type media_content: Dict[str, Any]
        :param config: A dict of additional configuration options (e.g. prompt templates, max_pixels).
        :type config: Dict[str, Any]
        :return: A tuple containing message lists and a metadata dictionary.

            - For point-wise scoring data (e.g., Scalar Reward Model training/evaluation):
              Return (messages_chosen, messages_rejected, other)
            - For pair-wise ranking data (e.g., Generative Reward Model training/evaluation):
              Return (messages, other)

            The ``other`` dictionary contains metadata, and can optionally include:

            - "preference": (str) Indicates the ground truth preferred choice ("A", "B", or "C").
            - "task_type": (str) The type of task (e.g., "text-to-video").
            - "reward_rule_label": (str) A label used in RL to identify which reward function
              or reward model to apply to this specific sample when performing reinforcement
              fine-tuning.
        :rtype: Union[Tuple[List[Dict], List[Dict], Dict], Tuple[List[Dict], Dict]]
        """
        raise NotImplementedError