Shortcuts

Source code for lightrft.utils.processor

from typing import Any, Callable, Dict, List
import torch
from tqdm import tqdm


def reward_normalization(objs: List[Dict[str, Any]]) -> None:
    """
    Normalize reward values across a list of objects using z-score normalization.

    Each ``obj["reward"]`` is replaced, in place, by ``(r - mean) / std``.
    ``std`` is the unbiased sample standard deviation (the default of
    ``torch.Tensor.std``). This puts rewards on a consistent scale, which
    helps stabilize training.

    Degenerate inputs are handled explicitly instead of producing NaN:

    * an empty list is left untouched;
    * with a single object, or when every reward is identical, the standard
      deviation is undefined or zero, so rewards are only mean-centered
      (every reward becomes ``0.0``).

    :param objs: List of dictionaries, each containing a 'reward' key whose
        value is convertible with ``float()``.
    :type objs: List[Dict[str, Any]]
    :return: None (modifies objs in-place).
    :rtype: None
    """
    if not objs:
        return

    # Using float32 for efficiency; sufficient precision for reward normalization
    rewards = torch.tensor([float(obj["reward"]) for obj in objs], dtype=torch.float32)
    normalized = rewards - rewards.mean()

    # With one element, std() is NaN (n - 1 == 0 degrees of freedom).
    # With identical rewards, it is 0.
    # Dividing in either case would turn every reward into NaN, so only
    # rescale when the spread is strictly positive.
    if rewards.numel() > 1:
        std = rewards.std()
        if std > 0:
            normalized = normalized / std

    for obj, value in zip(objs, normalized.tolist()):
        obj["reward"] = value
# Default reward prompt template for Conditional SFT.
# conditional_sft_processor substitutes "{reward}" with the reward formatted
# to two decimals and "{input}" with the original prompt text.
DEFAULT_REWARD_PROMPT = "{input} <rm_score>: {reward} "
def conditional_sft_processor(args: Any, objs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Process data for Conditional SFT by prepending reward information to inputs.

    Implements the Conditional SFT approach referenced at
    https://arxiv.org/abs/2308.12050.
    NOTE(review): this docstring previously attributed that URL to
    "Conditional Language Policy: A General Framework for Steerable
    Multi-Objective Finetuning". That title may belong to a different arXiv
    entry; confirm the citation.

    This technique conditions the model on reward scores during training,
    allowing it to generate outputs of varying quality based on the specified
    reward threshold.

    Each example's ``input`` is rewritten in place:

    * ``{reward}`` is replaced with the reward formatted to two decimals;
    * ``{input}`` is replaced with the original input.

    :param args: Configuration object. Reads two optional attributes:
        ``reward_template`` (a string containing ``{input}`` and ``{reward}``;
        ``DEFAULT_REWARD_PROMPT`` is used when it is missing or ``None``) and
        ``normalize_reward`` (treated as ``False`` when missing; when truthy,
        rewards are z-score normalized via ``reward_normalization`` first).
    :type args: Any
    :param objs: List of training examples with 'input', 'output', and 'reward' keys.
    :type objs: List[Dict[str, Any]]
    :return: The same ``objs`` list, with each 'input' rewritten (and
        'reward' normalized if requested).
    :rtype: List[Dict[str, Any]]
    :raises ValueError: If the reward template lacks ``{input}`` or ``{reward}``.
    """
    # getattr works for argparse.Namespace as well as plain attribute objects
    # (SimpleNamespace, dataclasses), which do not support the ``in`` operator.
    reward_template = getattr(args, "reward_template", None)
    if reward_template is None:
        reward_template = DEFAULT_REWARD_PROMPT

    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    for placeholder in ("{input}", "{reward}"):
        if placeholder not in reward_template:
            raise ValueError(
                f"reward_template must contain the {placeholder} placeholder, got: {reward_template!r}"
            )

    if getattr(args, "normalize_reward", False):
        reward_normalization(objs)

    for obj in tqdm(objs, desc="Conditional SFT process..."):
        prompt = obj["input"]
        reward_str = "{:.2f}".format(float(obj["reward"]))
        # Substitute {reward} before {input} so literal "{reward}" text inside a
        # prompt is never rewritten. str.replace (not str.format) leaves any
        # other braces in the template or prompt untouched.
        obj["input"] = reward_template.replace("{reward}", reward_str).replace("{input}", prompt)
    return objs
def rejection_sampling_processor(args: Any, objs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Keep only the best-scoring output for every distinct input (Rejection Sampling).

    Follows the rejection-sampling recipe from "Llama 2: Open Foundation and
    Fine-Tuned Chat Models" (https://arxiv.org/abs/2307.09288). Several
    candidate outputs exist per input, and all but the highest-reward one are
    discarded, yielding a higher-quality training set.

    Ties go to the candidate seen first: a later candidate replaces the
    current best only when its reward is strictly greater.

    :param args: Unused; accepted so every processor shares one signature.
    :type args: Any
    :param objs: Examples with 'input', 'output', and 'reward' keys.
    :type objs: List[Dict[str, Any]]
    :return: One ``{"input", "output", "reward"}`` dict per distinct input,
        in order of first appearance, with ``reward`` as a float.
    :rtype: List[Dict[str, Any]]
    """
    best: Dict[Any, Dict[str, Any]] = {}
    for sample in tqdm(objs, desc="Rejection Sampling process...."):
        prompt = sample["input"]
        completion = sample["output"]
        score = float(sample["reward"])

        incumbent = best.get(prompt)
        if incumbent is None:
            best[prompt] = {"output": completion, "reward": score}
        elif score > incumbent["reward"]:
            incumbent.update(output=completion, reward=score)

    return [
        {"input": prompt, "output": winner["output"], "reward": winner["reward"]}
        for prompt, winner in best.items()
    ]
def iterative_dpo_processor(args: Any, objs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Build one chosen/rejected preference pair per distinct input (Iterative DPO).

    Implements the Iterative DPO approach from "Online Iterative Reinforcement
    Learning from Human Feedback with General Preference Model"
    (https://github.com/RLHFlow/Online-RLHF). For every input, the
    highest-reward output becomes ``chosen`` and the lowest-reward output
    becomes ``rejected``, producing pairs for Direct Preference Optimization.

    Ties keep the earliest candidate, since updates require a strictly higher
    (or strictly lower) reward. If an input has a single candidate, or all its
    candidates share one reward, the resulting pair has ``chosen == rejected``.

    :param args: Unused; accepted so every processor shares one signature.
    :type args: Any
    :param objs: Examples with 'input', 'output', and 'reward' keys.
    :type objs: List[Dict[str, Any]]
    :return: Dicts with 'prompt', 'chosen', 'chosen_reward', 'rejected', and
        'rejected_reward', one per distinct input in order of first appearance.
    :rtype: List[Dict[str, Any]]
    """
    pairs: Dict[Any, Dict[str, Any]] = {}
    for sample in tqdm(objs, desc="Iterative DPO process...."):
        prompt = sample["input"]
        completion = sample["output"]
        score = float(sample["reward"])

        entry = pairs.get(prompt)
        if entry is None:
            # The first candidate seeds both ends of the pair.
            pairs[prompt] = {
                "chosen": completion,
                "chosen_reward": score,
                "rejected": completion,
                "rejected_reward": score,
            }
            continue

        # chosen_reward >= rejected_reward always holds, so a new maximum can
        # never also be a new minimum; one branch per sample is enough.
        if score > entry["chosen_reward"]:
            entry["chosen"], entry["chosen_reward"] = completion, score
        elif score < entry["rejected_reward"]:
            entry["rejected"], entry["rejected_reward"] = completion, score

    return [
        {
            "prompt": prompt,
            "chosen": entry["chosen"],
            "chosen_reward": entry["chosen_reward"],
            "rejected": entry["rejected"],
            "rejected_reward": entry["rejected_reward"],
        }
        for prompt, entry in pairs.items()
    ]
# Registry of available data processors, keyed by the short name accepted by
# get_processor. Every entry is called as ``processor(args, objs)`` and
# returns the processed list of examples.
PROCESSORS: Dict[str, Callable[[Any, List[Dict[str, Any]]], List[Dict[str, Any]]]] = {
    "rs": rejection_sampling_processor,  # Rejection Sampling
    "csft": conditional_sft_processor,  # Conditional SFT
    "iter_dpo": iterative_dpo_processor,  # Iterative DPO
}
def get_processor(name: str) -> Callable[[Any, List[Dict[str, Any]]], List[Dict[str, Any]]]:
    """
    Look up a data processor function by its registered name.

    :param name: Registered processor name: 'rs', 'csft', or 'iter_dpo'.
    :type name: str
    :return: The processor registered under ``name`` in ``PROCESSORS``.
    :rtype: Callable[[Any, List[Dict[str, Any]]], List[Dict[str, Any]]]
    :raises ValueError: If no processor is registered under ``name``.
    """
    processor = PROCESSORS.get(name)
    if processor is None:
        raise ValueError(f"Processor {name} does not exist.")
    return processor