"""Source code for lightrft.datasets.grm_dataset."""

import random
import torch
from torch.utils.data import Dataset
from loguru import logger
from typing import List, Dict, Any, Tuple, Optional
from transformers import AutoTokenizer, AutoProcessor

from .omnirewardbench import OmniRewardBenchT2IGRMHandler
from .imagegen_cot_reward import ImageGenCoTRewardGRMHandler
from .hpdv3 import HPDv3GRMHandler
from .utils import zero_pad_sequences, load_multimodal_content, find_subsequence


class GRMDataset(Dataset):
    """Dataset for Generative Reward Model (GRM) training.

    GRMDataset supports multiple data sources through pluggable data handlers
    and covers both understanding tasks (image-to-text, video-to-text) and
    generation tasks (text-to-image, text-to-video).

    :param dataset_paths: List of dataset file paths or directories, in the
        format ``source:path`` where the handler is determined by the source
        keyword such as hpdv3, imagegen-cot-reward, or omnirewardbench.
    :type dataset_paths: List[str]
    :param processor: Multimodal processor used for tokenization and visual
        processing.
    :type processor: transformers.AutoProcessor
    :param tokenizer: Tokenizer used for text tokenization (provides
        ``eos_token`` and ``pad_token_id`` attributes).
    :type tokenizer: transformers.AutoTokenizer
    :param strategy: Optional data loading strategy.
    :type strategy: Any
    :param max_length: Maximum sequence length for tokenization/truncation.
    :type max_length: int
    :param config: Additional configuration options. Supported keys include:

        - ``task_instruction`` (str): Instruction for the evaluation task.
        - ``system_prompt_template`` (str): Template for the system prompt
          with a ``{prompt}`` placeholder.
    :type config: Dict[str, Any]
    :param is_training: Whether the dataset is used for training (returns
        labels) or evaluation (no labels returned).
    :type is_training: bool

    **Example:**

    .. code-block:: python

        dataset = GRMDataset(
            ['imagegen-cot-reward-5k:/data/imagegen-cot-reward-5k/train.json'],
            processor=proc,
            tokenizer=tok,
            max_length=4096,
            is_training=True,
        )
    """

    def __init__(
        self,
        dataset_paths: List[str],
        processor: "AutoProcessor",
        tokenizer: "AutoTokenizer",
        strategy=None,
        max_length: int = 4096,
        config: Optional[Dict[str, Any]] = None,
        is_training: bool = True,
    ):
        super().__init__()
        self.processor = processor
        self.tokenizer = tokenizer
        self.strategy = strategy
        self.max_length = max_length
        self.config = config if config else {}
        self.is_training = is_training
        # Handlers read the training/eval switch from the config dict.
        self.config["is_training"] = is_training
        self.media_content_loader = load_multimodal_content

        # Select the vision-info preprocessor that matches the processor
        # family; each family ships its own `process_vision_info` helper.
        processor_cls_name = self.processor.__class__.__name__.lower()
        if "qwen" in processor_cls_name:
            from qwen_vl_utils import process_vision_info
            self.process_vision_info = process_vision_info
        elif "keye" in processor_cls_name:
            from keye_vl_utils import process_vision_info
            self.process_vision_info = process_vision_info
        else:
            raise NotImplementedError(f"Processor type {self.processor.__class__.__name__} not supported yet.")

        self.handlers = {
            "imagegen-cot-reward-5k": ImageGenCoTRewardGRMHandler(),
            "omnirewardbench-t2i": OmniRewardBenchT2IGRMHandler(),
            "hpdv3": HPDv3GRMHandler(),
        }

        # Load data from all specified dataset paths.
        # Each entry must be in the format "source:path",
        # e.g. "rapidata-t2v:/path/to/file.parquet".
        self.data = []
        for entry in dataset_paths:
            try:
                source, path = entry.split(":", 1)
            except ValueError:
                raise ValueError(f"Dataset path '{entry}' is not in the expected format 'source:path'.")
            if source not in self.handlers:
                raise NotImplementedError(f"The data handler for source {source} is not implemented.")
            handler = self.handlers[source]
            try:
                loaded_items = handler.load_data(path)
                # Tag each record with its source so __getitem__ can route
                # it back to the right handler.
                for record in loaded_items:
                    record["source"] = source
                self.data.extend(loaded_items)
            except Exception as e:
                # Best-effort loading: a broken dataset file is logged and
                # skipped rather than aborting the whole run.
                logger.error(f"Failed to load data {path} (source: {source}): {e}")

        logger.info(f"Loaded {len(self.data)} items in total, sources: {list(dataset_paths)}")
        random.shuffle(self.data)

    def __len__(self) -> int:
        """
        Get the total number of items in the dataset.

        :return: Total number of items
        :rtype: int
        """
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor], Dict[str, Any]]:
        """
        Get a single item from the dataset by index.

        :param idx: Index of the item to retrieve
        :type idx: int
        :return: A tuple of (tokens, labels, metadata). tokens is a dictionary
            of tensors, labels is a tensor (or None for evaluation), and
            metadata is a dictionary.
        :rtype: Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor], Dict[str, Any]]
        :raises RuntimeError: If the item's media content cannot be loaded.

        **Example:**

        .. code-block:: python

            tokens, labels, meta = dataset[0]
        """
        item = self.data[idx]
        source = item["source"]
        handler = self.handlers[source]

        # Get paths for all media content, then load everything at once.
        media_info = handler.get_media_info(item)
        loaded_content = self.media_content_loader(media_info)
        if loaded_content is None:
            raise RuntimeError(f"Failed to load media content: {media_info}")

        # Pass the loaded content dict to parse_item.
        messages, other = handler.parse_item(item, loaded_content, self.config)

        # Tokenize the message.
        if self.is_training:
            input_token, labels = self._tokenize_msg_for_training(messages)
            return input_token, labels, other
        else:
            input_token = self._tokenize_msg_for_eval(messages)
            return input_token, None, other

    def _tokenize_msg_for_training(self, messages: List[Dict]) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """
        Tokenize messages for training, including labels for the assistant's
        response.

        :param messages: List of message dictionaries following the OpenAI format
        :type messages: List[Dict]
        :return: A tuple of (tokenized_input, labels)
        :rtype: Tuple[Dict[str, torch.Tensor], torch.Tensor]
        :raises RuntimeError: If the assistant marker cannot be located in the
            tokenized sequence.
        """
        input_text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
        # Ensure the sequence terminates with EOS so the model learns to stop.
        if not input_text.endswith(self.tokenizer.eos_token):
            input_text += " " + self.tokenizer.eos_token

        image_inputs, video_inputs, video_kwargs = self.process_vision_info(
            messages,
            return_video_kwargs=True,
        )
        tokenized = self.processor(
            text=[input_text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            padding_side="left",
            truncation=False,
            return_tensors="pt",
            add_special_tokens=False,
            **video_kwargs,
        )
        input_ids = tokenized["input_ids"][0]

        # Find the prompt->response boundary by locating the assistant
        # header tokens in the encoded sequence.
        assistant_marker = "<|im_start|>assistant"  # For Qwen2.5-VL
        marker_ids = self.tokenizer(assistant_marker, add_special_tokens=False).input_ids
        response_start = find_subsequence(input_ids.tolist(), marker_ids)
        if response_start == -1:
            raise RuntimeError("Could not find '<|im_start|>assistant' token to determine response boundary.")

        # Create labels: only compute loss on the response part. Everything
        # before the assistant marker is masked; the marker tokens themselves
        # remain supervised.
        labels = input_ids.clone().unsqueeze(0)
        labels[:, :response_start] = -100

        # Fix eos alignment: force the last input token to the EOS id.
        # NOTE(review): labels were cloned before this overwrite, so the last
        # label may differ from the corrected input id — confirm intended.
        tokenized["input_ids"][0][-1] = self.tokenizer.eos_token_id
        tokenized["attention_mask"][0][-1] = True
        return tokenized, labels

    def _tokenize_msg_for_eval(self, messages: List[Dict]) -> Dict[str, torch.Tensor]:
        """
        Tokenize messages for evaluation (prompt-only).

        :param messages: List of message dictionaries following the OpenAI format
        :type messages: List[Dict]
        :return: Tokenized input dictionary
        :rtype: Dict[str, torch.Tensor]
        """
        # Remove the last assistant response if present; evaluation generates
        # the response instead of scoring a given one.
        if messages and messages[-1]['role'] == 'assistant':
            messages = messages[:-1]

        prompt_only_text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs, video_kwargs = self.process_vision_info(messages, return_video_kwargs=True)
        input_token = self.processor(
            text=[prompt_only_text],
            images=image_inputs,
            videos=video_inputs,
            max_length=self.max_length,
            padding=True,
            padding_side="left",
            truncation=True,
            return_tensors="pt",
            add_special_tokens=False,
            **video_kwargs,
        )
        return input_token

    def collate_fn(self, batch: List[Tuple]) -> Optional[Tuple]:
        """
        Collate a batch of items into a single batch for model processing.

        :param batch: A list of items returned by __getitem__
        :type batch: List[Tuple]
        :return: A tuple containing batched input_ids, attention_mask,
            pixel_values, grid_sizes, labels, and extras, or None if the
            batch is empty after filtering.
        :rtype: Optional[Tuple]

        **Example:**

        .. code-block:: python

            batch = dataset.collate_fn([dataset[i] for i in range(4)])
        """
        # Drop failed items; an all-None batch collates to None.
        batch = [b for b in batch if b is not None]
        if not batch:
            return None

        input_ids_list, input_masks_list = [], []
        input_img_pixels, input_img_grid = [], []
        input_video_pixels, input_video_grid = [], []
        labels_list = []
        extras_list = []

        for input_token, labels, extra in batch:
            extras_list.append(extra)
            # --- Get text ---
            input_ids_list.append(input_token['input_ids'])
            input_masks_list.append(input_token['attention_mask'])
            # --- Get labels ---
            if labels is not None:
                labels_list.append(labels)
            # --- Get visuals ---
            if 'pixel_values' in input_token:
                input_img_pixels.append(input_token['pixel_values'])
                input_img_grid.append(input_token['image_grid_thw'])
            if 'pixel_values_videos' in input_token:
                input_video_pixels.append(input_token['pixel_values_videos'])
                input_video_grid.append(input_token['video_grid_thw'])

        padding_side = "left"
        input_ids = zero_pad_sequences(input_ids_list, side=padding_side, value=self.tokenizer.pad_token_id)
        input_masks = zero_pad_sequences(input_masks_list, side=padding_side)
        if labels_list:
            # -100 keeps padded positions ignored by the loss.
            labels_list = zero_pad_sequences(labels_list, side=padding_side, value=-100)
        else:
            labels_list = None

        return (
            # Text inputs
            input_ids,
            input_masks,
            # Image inputs
            torch.cat(input_img_pixels, dim=0) if input_img_pixels else None,
            torch.cat(input_img_grid, dim=0) if input_img_grid else None,
            # Video inputs
            torch.cat(input_video_pixels, dim=0) if input_video_pixels else None,
            torch.cat(input_video_grid, dim=0) if input_video_grid else None,
            # Labels
            labels_list,
            # Extras
            extras_list,
        )