Source code for lightrft.datasets.image_reward_db
import os
import copy
import json
import random
import glob
from typing import List, Dict, Any, Tuple
from itertools import combinations
from collections import defaultdict
from loguru import logger
from .utils import BaseDataHandler
[docs]class ImageRewardDBHandler(BaseDataHandler):
"""
Data Handler for ImageRewardDB dataset.
Paper: https://arxiv.org/abs/2304.05977
Dataset Repo: https://huggingface.co/datasets/zai-org/ImageRewardDB
"""
task_type = "text-to-image"
[docs] def load_data(self, path: str) -> List[Dict[str, Any]]:
"""
Load ImageRewardDB shards and build preference pairs.
This method scans the given dataset root for ImageRewardDB JSON shards and
aggregates image entries by ``prompt_id``. For each prompt group, it
constructs all unordered pairs of images and determines the preferred
image based on the ``rank`` field (smaller is better, i.e., ``1`` is best).
:param path: Path to the dataset root directory of ImageRewardDB.
:type path: str
:return: List of preference pair dictionaries.
:rtype: List[Dict[str, Any]]
**Example:**
.. code-block:: python
data = handler.load_data("path/to/ImageRewardDB")
"""
# Locate all JSON shard files
# Expected layout examples: train_01/train_01.json, train_02/train_02.json, ...
search_pattern = os.path.join(path, "**", "*.json")
json_files = glob.glob(search_pattern, recursive=True)
if not json_files:
print(f"No JSON files found under: {path}. Please verify the dataset path.")
return
print(f"Found {len(json_files)} JSON files. Starting to load data...")
# Aggregate entries by prompt_id
# Structure: { "prompt_id_1": [img_info1, img_info2, ...], ... }
grouped_data = defaultdict(list)
for json_path in json_files:
try:
with open(json_path, 'r', encoding='utf-8') as f:
data = json.load(f)
for item in data:
grouped_data[item['prompt_id']].append(item)
except Exception as e:
print(f"Error reading file {json_path}: {e}")
print(f"Aggregated {len(grouped_data)} unique prompt groups. Generating pairs...")
# Construct pairs
preference_pairs = []
skipped_files_count = 0
for pid, items in grouped_data.items():
# Filter out items with missing or empty image files
valid_items = []
dataset_root = os.path.dirname(os.path.dirname(path))
for item in items:
full_img_path = os.path.join(dataset_root, item['image_path'])
# import ipdb; ipdb.set_trace()
if os.path.exists(full_img_path) and os.path.getsize(full_img_path) > 0:
valid_items.append(item)
else:
# print(f"[Warning] Skipping missing or empty file: {full_img_path}")
skipped_files_count += 1
items = valid_items
# If a prompt group has fewer than 2 images, skip
if len(items) < 2:
continue
# Use combinations to generate all unordered pairs (e.g., AB and BA appear once)
for item_a, item_b in combinations(items, 2):
rank_a = item_a['rank']
rank_b = item_b['rank']
# Defensive: ensure rank fields are present
if rank_a is None or rank_b is None:
continue
# Decision rule: smaller rank value is better (1 is best)
if rank_a == rank_b:
continue
if rank_a < rank_b:
chosen = item_a
reject = item_b
else:
chosen = item_b
reject = item_a
# Build pair entry
pair_entry = {
"prompt_id": pid,
"prompt": chosen['prompt'],
"classification": chosen.get('classification', 'Unknown'),
"data_root": dataset_root,
# Chosen Image Info
"chosen_img": chosen['image_path'],
"rank_chosen": chosen['rank'],
"overall_rating_chosen": chosen.get('overall_rating'),
# Rejected Image Info
"rejected_img": reject['image_path'],
"rank_rejected": reject['rank'],
"overall_rating_rejected": reject.get('overall_rating'),
}
preference_pairs.append(pair_entry)
logger.info(f"Loaded {len(preference_pairs)} samples from ImageRewardDB.")
return preference_pairs
[docs] def get_media_info(self, item: Dict[str, Any]) -> Dict[str, Dict[str, str]]:
"""
Extract path info for chosen and rejected images.
:param item: A data item from load_data
:type item: Dict[str, Any]
:return: Dict containing local paths for 'preferred_image' and 'rejected_image'
:rtype: Dict[str, Dict[str, str]]
**Example:**
.. code-block:: python
info = handler.get_media_info(item)
"""
data_root = item['data_root']
# Build full local paths
full_path1 = os.path.join(data_root, item['chosen_img'])
full_path2 = os.path.join(data_root, item['rejected_img'])
# Make sure files exist
if not os.path.exists(full_path1) or not os.path.exists(full_path2):
return None
else:
return {
'preferred_image': {
'image_local_path': full_path1
},
'rejected_image': {
'image_local_path': full_path2
}
}
[docs] def parse_item(self, item: Dict[str, Any], media_content: Dict[str, Any],
config: Dict[str, Any]) -> Tuple[List[Dict], List[Dict], Dict]:
"""
Parse a single ImageRewardDB item into message pairs for ranking.
:param item: Raw data item from ImageRewardDB dataset.
:type item: Dict[str, Any]
:param media_content: Loaded media content with 'preferred_image' and 'rejected_image' keys.
:type media_content: Dict[str, Any]
:param config: Configuration dict with task instructions and max_pixels
:type config: Dict[str, Any]
:return: A tuple of (messages0, messages1, metadata)
:rtype: Tuple[List[Dict], List[Dict], Dict]
**Example:**
.. code-block:: python
msg0, msg1, other = handler.parse_item(item, media_content, config)
"""
# Get loaded visual content
preferred_image = media_content['preferred_image']
rejected_image = media_content['rejected_image']
if not all([preferred_image, rejected_image]):
raise ValueError("Missing visual content for 'preferred_image' or 'rejected_image'.")
# Get generation prompt from data item
prompt_text = item["prompt"]
if not prompt_text:
raise ValueError(f"Missing generation prompt in item: {item}")
# Get system prompts from config
task_instruction_template = config["task_instruction"]
task_instruction = task_instruction_template.format(prompt=prompt_text)
# Get max_pixels from config
max_pixels = config["max_pixels"]
# Random pick from "A" or "B" to avoid positional bias
preference = random.choice(["A", "B"])
if preference == "A": # "A" means image0 is preferred
image0, image1 = preferred_image, rejected_image
else:
image0, image1 = rejected_image, preferred_image
# Build messages
messages0 = [
{
"role": "system",
"content": copy.deepcopy(task_instruction)
},
{
"role": "user",
"content": [{
"type": "image",
"image": image0,
"max_pixels": max_pixels
} # to save memory
]
}
]
messages1 = [{
"role": "system",
"content": copy.deepcopy(task_instruction)
}, {
"role": "user",
"content": [{
"type": "image",
"image": image1,
"max_pixels": max_pixels
}]
}]
other = {
"preference": preference, # used for reward head labeling
"task_type": self.task_type,
"source": item["source"],
"prompt_id": item["prompt_id"],
"prompt": prompt_text,
"chosen_img": item["chosen_img"],
"rejected_img": item["rejected_img"],
"rank_chosen": item["rank_chosen"],
"rank_rejected": item["rank_rejected"],
"overall_rating_chosen": item["overall_rating_chosen"],
"overall_rating_rejected": item["overall_rating_rejected"],
"classification": item["classification"],
}
return messages0, messages1, other