Source code for lightrft.utils.remote_rm_utils
import base64
import io
import time
import requests
import torch
from PIL import Image
from typing import Any, Dict, List, Optional, Union
from .logging_utils import init_logger
logger = init_logger(__name__)
[docs]def request_api_wrapper(url: str,
data: Dict[str, Any],
score_key: str = "rewards",
try_max_times: int = 5) -> Union[float, List[float]]:
"""
Synchronous request API wrapper for reward model scoring.
This function makes HTTP POST requests to a reward model API endpoint and
handles retries with exponential backoff for failed requests.
:param url: The API endpoint URL to send requests to
:type url: str
:param data: The request payload data as a dictionary
:type data: Dict[str, Any]
:param score_key: The key in the response JSON that contains the reward scores
:type score_key: str
:param try_max_times: Maximum number of retry attempts for failed requests
:type try_max_times: int
:return: Reward scores extracted from the API response, either as a single float
or a list of floats depending on the API response structure
:rtype: Union[float, List[float]]
:raises Exception: When all retry attempts fail after the maximum number of tries
Example::
score = request_api_wrapper(
url="http://localhost:8000/score",
data={"text": "Hello world"},
score_key="rewards",
try_max_times=5
)
"""
headers = {
"Content-Type": "application/json",
}
for _ in range(try_max_times):
try:
response = requests.post(url=url, json=data, headers=headers, timeout=3000)
response.raise_for_status() # Raise an HTTPError for bad responses
response = response.json()
assert score_key in response, f"{score_key} not in {response}"
return response.get(score_key)
except requests.RequestException as e:
logger.info(f"Request error, please check: {e}")
logger.info(f"Request error data: {data}")
except Exception as e:
logger.info(f"Unexpected error, please check: {e}")
time.sleep(1)
raise Exception(f"Request error for {try_max_times} times, returning None. Please check the API server.")
[docs]def remote_rm_fn(
api_url: str,
queries: List[str],
prompts: List[str],
labels: Optional[List[Any]] = None,
references: Optional[List[str]] = None,
raw_images: Optional[List[Optional[Union[Image.Image, List[Image.Image]]]]] = None,
score_key: str = "rewards"
) -> torch.Tensor:
"""
Remote reward model API function for scoring text and image inputs.
This function prepares data and sends requests to a remote reward model API,
supporting both text-only and multimodal (text + image) scoring scenarios.
:param api_url: Reward model API endpoint URL
:type api_url: str
:param queries: List of query strings with response templates
:type queries: List[str]
:param prompts: List of prompt strings for context
:type prompts: List[str]
:param labels: Optional list of labels for supervised scoring (currently unused)
:type labels: Optional[List[Any]]
:param references: Optional list of reference responses for comparison scoring
:type references: Optional[List[str]]
:param raw_images: Optional list of PIL Image objects or lists of PIL Image objects
for multimodal scoring. Each element can be None, a single image,
or a list of images.
:type raw_images: Optional[List[Optional[Union[Image.Image, List[Image.Image]]]]]
:param score_key: Key in the API response that contains the reward scores
:type score_key: str
:return: Tensor of reward scores for all input samples
:rtype: torch.Tensor
:raises Exception: When API requests fail after maximum retry attempts
Example::
# Text-only scoring
scores = remote_rm_fn(
api_url="http://localhost:8000/score",
queries=["What is 2+2?"],
prompts=["Calculate the following:"],
score_key="rewards"
)
# Multimodal scoring with images
scores = remote_rm_fn(
api_url="http://localhost:8000/score",
queries=["Describe this image"],
prompts=["Please analyze the image:"],
raw_images=[Image.open("image.jpg")],
score_key="rewards"
)
"""
data = {"queries": queries, "prompts": prompts}
if references is not None:
data["references"] = references
if raw_images is not None:
# print(f"=================raw_images: {raw_images}")
# Convert PIL images to bytes then to base64 strings
base64_images = []
for imgs in raw_images:
base64_imgs = []
if imgs is None:
base64_images.append(None)
continue
with io.BytesIO() as buf:
if isinstance(imgs, list):
for img in imgs:
if img.mode == "RGBA":
img = img.convert("RGB") # to RGB
img.save(buf, format='JPEG')
base64_imgs.append(base64.b64encode(buf.getvalue()).decode('utf-8'))
base64_images.append(base64_imgs)
else:
if imgs.mode == "RGBA":
imgs = imgs.convert("RGB") # to RGB
imgs.save(buf, format='JPEG')
base64_imgs.append(base64.b64encode(buf.getvalue()).decode('utf-8'))
base64_images.append(base64_imgs)
data["images"] = base64_images
scores = request_api_wrapper(url=api_url, data=data, score_key=score_key)
return torch.tensor(scores)