# Source code for lightrft.datasets.imagegen_cot_reward
import os
import json
from typing import List, Dict, Any, Tuple
from loguru import logger
from .utils import BaseDataHandler
class ImageGenCoTRewardGRMHandler(BaseDataHandler):
    """
    Data handler for ImageGen-CoT-Reward-5K dataset. For Text-to-Image generation task.

    Paper: https://arxiv.org/pdf/2505.03318
    Dataset Repo: https://huggingface.co/datasets/CodeGoat24/ImageGen-CoT-Reward-5K
    """
    task_type = "text-to-image"

    def load_data(self, path: str) -> List[Dict[str, Any]]:
        """
        Load samples from a JSON file and tag each one with the file's directory.

        Every item gets a ``'data_root'`` entry so that the image paths stored
        in the item can later be joined against it.

        :param path: Path to the dataset JSON file
        :type path: str
        :return: List of samples with 'data_root' attached
        :rtype: List[Dict[str, Any]]

        **Example:**

        .. code-block:: python

            handler = ImageGenCoTRewardGRMHandler()
            data = handler.load_data("path/to/ImageGen-CoT-Reward.json")
        """
        with open(path, 'rb') as f:
            raw_data = json.load(f)

        # os.path.dirname() returns "" for a bare filename such as "data.json";
        # an empty data_root would be rejected by get_media_info, so fall back
        # to the current directory in that case.
        data_root = os.path.dirname(path) or os.curdir
        for item in raw_data:
            item['data_root'] = data_root

        logger.info(f"Loaded {len(raw_data)} samples from {path}")
        return raw_data

    def get_media_info(self, item: Dict[str, Any]) -> Dict[str, Dict[str, str]]:
        """
        Build the local paths of the two images referenced by an item.

        The first two entries of ``item['images']`` are joined onto
        ``item['data_root']``.

        :param item: A data item from load_data
        :type item: Dict[str, Any]
        :return: Dict containing local paths for 'image0' and 'image1'
        :rtype: Dict[str, Dict[str, str]]
        :raises ValueError: If 'data_root' is absent or empty

        **Example:**

        .. code-block:: python

            info = handler.get_media_info(item)
        """
        # .get() so that an absent key reaches the descriptive ValueError below
        # instead of surfacing as a bare KeyError.
        data_root = item.get('data_root')
        if not data_root:
            raise ValueError("Missing 'data_root' in item. Cannot resolve image paths.")

        images = item['images']
        return {
            'image0': {
                'image_local_path': os.path.join(data_root, images[0])
            },
            'image1': {
                'image_local_path': os.path.join(data_root, images[1])
            },
        }

    def parse_item(
        self,
        item: Dict[str, Any],
        media_content: Dict[str, Any],
        config: Dict[str, Any] | None,
    ) -> Tuple[List[Dict], Dict]:
        """
        Parse a single ImageGen-CoT-Reward item into a message list plus metadata.

        The first conversation turn is used as the system prompt and the last
        turn as the assistant response; the two images are inserted as two user
        turns labelled "**Image 1:**" and "**Image 2:**".

        :param item: Raw data item from ImageGen-CoT-Reward dataset.
        :type item: Dict[str, Any]
        :param media_content: Loaded image content under 'image0' and 'image1'
        :type media_content: Dict[str, Any]
        :param config: Optional configuration; when it contains 'max_pixels' the
            value is attached to both image entries, otherwise the key is left out
        :type config: Dict[str, Any] | None
        :return: A tuple of (messages, metadata)
        :rtype: Tuple[List[Dict], Dict]
        :raises ValueError: If either image is absent or falsy

        **Example:**

        .. code-block:: python

            messages, other = handler.parse_item(item, media_content, config)
        """
        # .get() so that an absent key reaches the descriptive ValueError below.
        image0 = media_content.get('image0')
        image1 = media_content.get('image1')
        if not all([image0, image1]):
            raise ValueError("Missing visual content for 'image0' or 'image1'.")

        # Get conversations from data item
        conversations = item["conversations"]
        system_prompt = conversations[0]['value']
        response = conversations[-1]['value']

        # The signature allows config=None, so only forward max_pixels when the
        # caller actually supplied it (an explicit None value is still forwarded).
        if config is not None and "max_pixels" in config:
            pixel_limit = {"max_pixels": config["max_pixels"]}
        else:
            pixel_limit = {}

        def _image_turn(label: str, image: Any) -> Dict[str, Any]:
            """Build one user turn made of a text label followed by an image."""
            image_part = {"type": "image", "image": image}
            image_part.update(pixel_limit)
            return {
                "role": "user",
                "content": [{
                    "type": "text",
                    "text": label
                }, image_part]
            }

        # Build messages
        messages = [
            {
                "role": "system",
                "content": system_prompt
            },
            _image_turn("**Image 1:**", image0),
            _image_turn("**Image 2:**", image1),
            {
                "role": "assistant",
                "content": [{
                    "type": "text",
                    "text": response
                }]
            },
        ]

        other = {
            "source": item['source'],
            "task_type": self.task_type,
            "data_item": item,
            "system_prompt": system_prompt,
            "response": response,
        }
        return messages, other