Source code for lightrft.trainer.fast_exp_maker
"""
FastExperienceMaker Module
This module provides an optimized experience maker for RLHF (Reinforcement Learning from Human Feedback)
that supports high-performance inference backends like VLLM and SGLang. It extends the base
NaiveExperienceMaker with enhanced features for multimodal data processing, reward computation,
and advantage estimation.
Key Features:
- VLLM/SGLang backend support for efficient text generation
- Multimodal (vision-language) data processing
- Multiple advantage estimation methods (GAE, RLOO, REINFORCE, Group Norm)
- Flexible reward model composition with custom reward functions
- Sample packing support for improved training efficiency
- Running reward normalization and advantage whitening
Classes:
MultimodalDataProcessor: Handles preprocessing of mixed text/image data
RewardComputationEngine: Manages reward model inference and aggregation
FastExperienceMaker: Main experience generation class
"""
import os
import time
import pathlib
import warnings
from typing import Callable, Dict, List, Tuple, Union, Optional
from dataclasses import dataclass
from copy import deepcopy
import torch
import numpy as np
from PIL import Image
from easydict import EasyDict
from vllm import SamplingParams
from lightrft.models.utils import (
compute_approx_kl,
compute_reward,
masked_mean,
unpacking_samples,
)
from lightrft.models.actor_modality import ActorModality, get_supported_parameters
from lightrft.trainer.experience_maker import (
Experience,
NaiveExperienceMaker,
Samples,
)
from lightrft.trainer.experience_maker_vl import (
ExperienceVL,
SamplesVL,
)
from lightrft.utils.remote_rm_utils import remote_rm_fn
from lightrft.utils import Timer, get_current_device
from .utils import RunningMoments, compute_clip_fraction, get_cpgd_advantages_returns, fire_sampling, vllm_ge_0130
from .advantage_calculator import get_advantage_calculator, normalize_advantages_cross_batch
from .image_utils import normalize_images, get_images_num
from .video_utils import normalize_videos, get_videos_num
# ============================================================================
# Data Structures
# ============================================================================
@dataclass
class _SamplesOutput:
"""
Lightweight dataclass for caching intermediate computation results during experience creation.
This structure serves as a unified container for all data flowing through the parallel
experience generation pipeline, including sequences, attention masks, multimodal inputs,
and model outputs (log probabilities, values, rewards).
Attributes:
sequences: Token ID sequences [batch_size, seq_len]
attention_mask: Attention mask for sequences
action_mask: Mask indicating which tokens are part of the generated response
num_actions: Number of action tokens per sequence
packed_seq_lens: Sequence lengths for packed samples (if packing enabled)
response_length: Length of generated responses
total_length: Total sequence length (prompt + response)
prompts: Original text prompts
labels: Optional labels for the samples
# Vision-Language Model (VLM) specific fields
pixel_values: Processed pixel values for images (Qwen-VL format)
pixel_values_videos: Processed pixel values for videos (Qwen-VL format)
image_grid_thw: Image grid dimensions [temporal, height, width]
video_grid_thw: Video grid dimensions [temporal, height, width]
raw_images: Original PIL images
references: Reference texts for evaluation
image_num: Number of images per sample
# Model inference outputs
action_log_probs: Log probabilities from actor model
base_action_log_probs: Log probabilities from initial/reference model
value: Value estimates from critic model
rewards: Reward scores from reward model(s)
kl: KL divergence between actor and reference policy
inputs_extra_kwargs: Additional model-specific inputs
prompt_and_output: Concatenated prompt+output text for reward models
"""
# Core sequence data
sequences: torch.Tensor
attention_mask: Optional[torch.Tensor]
action_mask: Optional[torch.Tensor]
num_actions: Union[list, torch.Tensor]
packed_seq_lens: Optional[Union[list, torch.Tensor]]
response_length: torch.Tensor
total_length: torch.Tensor
prompts: List[str]
labels: Optional[list]
# Vision-Language Model fields
pixel_values: Optional[torch.Tensor] = None
image_grid_thw: Optional[torch.Tensor] = None
pixel_values_videos: Optional[torch.Tensor] = None
video_grid_thw: Optional[torch.Tensor] = None
raw_images: Optional[list] = None
references: Optional[list] = None
image_num: Optional[List[int]] = None
video_num: Optional[List[int]] = None
# Model outputs
action_log_probs: Optional[torch.Tensor] = None
base_action_log_probs: Optional[torch.Tensor] = None
value: Optional[torch.Tensor] = None
rewards: Optional[torch.Tensor] = None
reward_metrics: Optional[Dict[str, torch.Tensor]] = None # Detailed reward metrics
kl: Optional[torch.Tensor] = None
action_entropy: Optional[torch.Tensor] = None # Entropy for high-entropy token filtering
inputs_extra_kwargs: Optional[dict] = None
prompt_and_output: Optional[List[str]] = None
# ============================================================================
# Helper Classes
# ============================================================================
class MultimodalDataProcessor:
"""
Handles preprocessing of mixed text-only and image-text multimodal data.
This processor separates text-only and multimodal samples, processes them through
appropriate pipelines (tokenizer vs. multimodal processor), then merges results
back in original order to maintain batch consistency.
Key responsibilities:
- Normalize image inputs (file paths, PIL images, bytes)
- Separate text-only and image-text samples
- Process each modality through appropriate pipeline
- Expand samples by n_samples_per_prompt factor
- Reconstruct original batch ordering
Args:
tokenizer: Tokenizer for text-only samples
processor: Multimodal processor for image-text samples
prompt_max_len: Maximum prompt length for truncation
"""
def __init__(self, tokenizer, processor, prompt_max_len: int):
"""
Initialize the multimodal data processor.
:param tokenizer: HuggingFace tokenizer for text processing
:type tokenizer: transformers.PreTrainedTokenizer
:param processor: Multimodal processor (e.g., Qwen-VL processor)
:type processor: Union[transformers.ProcessorMixin, Any]
:param prompt_max_len: Maximum allowed prompt length
:type prompt_max_len: int
"""
self.tokenizer = tokenizer
self.processor = processor
self.prompt_max_len = prompt_max_len
def process_multimodal_batch(
self,
all_prompts: List[str],
all_images: List,
all_references: Optional[List[str]],
images_num: List[int],
n_samples_per_prompt: int,
all_videos: Optional[List],
videos_num: Optional[List[int]],
) -> EasyDict:
"""
Process multimodal batch - following original implementation exactly.
This method is a direct port of the original process_multimodal_data to ensure
functional equivalence.
:param all_prompts: List of text prompts
:type all_prompts: List[str]
:param all_images: List of images (PIL.Image or None)
:type all_images: List[Union[List[PIL.Image.Image], None]]
:param all_references: Optional reference texts
:type all_references: Optional[List[str]]
:param images_num: Number of images per sample
:type images_num: List[int]
:param n_samples_per_prompt: Number of samples to generate per prompt
:type n_samples_per_prompt: int
:param all_videos: List of videos (List[str] or None)
:type all_videos: Optional[List[Union[List[str], None]]]
:param videos_num: Number of videos per sample
:type videos_num: Optional[List[int]]
:return: Dictionary containing processed data
:rtype: EasyDict
"""
N = n_samples_per_prompt
L = len(all_prompts)
# Ensure all_images and all_videos are iterable even if None
if all_images is None:
all_images = [None] * L
if all_videos is None:
all_videos = [None] * L
# ===== Stage 1: Separation =====
all_prompts_text, all_prompts_multimodal = [], []
all_images_valid = []
all_videos_valid = []
text_idx = []
for idx, (prompt, image, video) in enumerate(zip(all_prompts, all_images, all_videos)):
if image is None and video is None:
all_prompts_text.append(prompt)
text_idx.append(idx)
else:
all_prompts_multimodal.append(prompt)
all_images_valid.append(image)
all_videos_valid.append(video)
# ===== Stage 2: Expansion =====
all_prompts_text = sum([[p] * N for p in all_prompts_text], [])
all_prompts_multimodal = sum([[p] * N for p in all_prompts_multimodal], [])
all_images_valid = [img for img in all_images_valid for _ in range(N)]
all_videos_valid = [vid for vid in all_videos_valid for _ in range(N)]
all_images_num = sum([[num] * N for num in images_num], []) if images_num is not None else [0] * (L * N)
all_videos_num = sum([[num] * N for num in videos_num], []) if videos_num is not None else [0] * (L * N)
# ===== Stage 3-A: Text-only processing =====
if len(all_prompts_text) > 0:
inputs_text = self.tokenizer(
all_prompts_text,
max_length=self.prompt_max_len,
truncation=True,
add_special_tokens=False,
)
all_prompt_token_ids_text = inputs_text["input_ids"]
else:
all_prompt_token_ids_text = []
# Initialize multimodal variables for text-only compatibility
all_prompt_token_ids_multimodal = []
all_images_pixel_values_multimodal = None
all_videos_pixel_values_multimodal = None
all_images_grid_thw_multimodal = None
all_videos_grid_thw_multimodal = None
# ===== Stage 3-B: Multimodal processing =====
if len(all_prompts_multimodal) > 0:
assert self.processor is not None, "Processor required for multimodal data"
flat_images = []
for img_item in all_images_valid:
if isinstance(img_item, list):
flat_images.extend(img_item)
elif img_item is not None:
flat_images.append(img_item)
flat_videos = []
for vid_item in all_videos_valid:
if isinstance(vid_item, list):
flat_videos.extend(vid_item)
elif vid_item is not None:
flat_videos.append(vid_item)
processor_kwargs = {
"text": all_prompts_multimodal.copy(),
"add_special_tokens": False,
"max_length": self.prompt_max_len,
"truncation": True,
}
if flat_images:
processor_kwargs["images"] = flat_images
if flat_videos:
processor_kwargs["videos"] = flat_videos
inputs_multimodal = self.processor(**processor_kwargs)
all_prompt_token_ids_multimodal = inputs_multimodal["input_ids"]
all_images_pixel_values_multimodal = inputs_multimodal.get("pixel_values", None)
all_videos_pixel_values_multimodal = inputs_multimodal.get("pixel_values_videos", None)
all_images_grid_thw_multimodal = inputs_multimodal.get("image_grid_thw", None)
all_videos_grid_thw_multimodal = inputs_multimodal.get("video_grid_thw", None)
# ===== Stage 4: Merge back in original order =====
total_samples = L * N
all_prompts_out = [None] * total_samples
all_images_out = [None] * total_samples
all_videos_out = [None] * total_samples
all_prompt_token_ids_out = [None] * total_samples
all_images_grid_thw_list = [None] * total_samples
all_videos_grid_thw_list = [None] * total_samples
# 4-A: Fill text-only
text_ptr = 0
for orig_idx in text_idx:
for n in range(N):
gid = orig_idx * N + n
all_prompts_out[gid] = all_prompts_text[text_ptr]
all_prompt_token_ids_out[gid] = all_prompt_token_ids_text[text_ptr]
# Ensure (0, 3) shape for cat
all_images_grid_thw_list[gid] = torch.empty((0, 3), dtype=torch.long)
all_videos_grid_thw_list[gid] = torch.empty((0, 3), dtype=torch.long)
text_ptr += 1
# 4-B: Fill multimodal
multi_ptr = 0
image_grid_ptr = 0
video_grid_ptr = 0
for orig_idx in range(L):
if orig_idx in text_idx:
continue
for n in range(N):
gid = orig_idx * N + n
all_prompts_out[gid] = all_prompts_multimodal[multi_ptr]
all_images_out[gid] = all_images_valid[multi_ptr]
all_videos_out[gid] = all_videos_valid[multi_ptr]
all_prompt_token_ids_out[gid] = all_prompt_token_ids_multimodal[multi_ptr]
# Handle image_grid_thw: extract rows based on all_images_num
num_images = all_images_num[gid]
if num_images > 0 and all_images_grid_thw_multimodal is not None:
all_images_grid_thw_list[gid] = all_images_grid_thw_multimodal[image_grid_ptr:image_grid_ptr +
num_images]
image_grid_ptr += num_images
else:
all_images_grid_thw_list[gid] = torch.empty((0, 3), dtype=torch.long)
# Handle video_grid_thw: extract rows based on all_videos_num
num_videos = all_videos_num[gid]
if num_videos > 0 and all_videos_grid_thw_multimodal is not None:
all_videos_grid_thw_list[gid] = all_videos_grid_thw_multimodal[video_grid_ptr:video_grid_ptr +
num_videos]
video_grid_ptr += num_videos
else:
all_videos_grid_thw_list[gid] = torch.empty((0, 3), dtype=torch.long)
multi_ptr += 1
# Concatenate grid_thw (using cat instead of stack to support multi-image/video)
all_images_grid_thw = (
torch.cat(all_images_grid_thw_list, dim=0)
if len(all_images_grid_thw_list) > 0 else torch.empty((0, 3), dtype=torch.long)
)
all_videos_grid_thw = (
torch.cat(all_videos_grid_thw_list, dim=0)
if len(all_videos_grid_thw_list) > 0 else torch.empty((0, 3), dtype=torch.long)
)
# Expand references
if all_references is not None:
all_references = sum([[ref] * N for ref in all_references], [])
return EasyDict(
all_prompt_token_ids=all_prompt_token_ids_out,
all_prompts=all_prompts_out,
all_images=all_images_out,
all_videos=all_videos_out,
all_images_num=all_images_num,
all_videos_num=all_videos_num,
all_images_pixel_values=all_images_pixel_values_multimodal,
all_videos_pixel_values=all_videos_pixel_values_multimodal,
all_images_grid_thw=all_images_grid_thw,
all_videos_grid_thw=all_videos_grid_thw,
all_references=all_references,
)
class RewardComputationEngine:
"""
Manages reward model inference and score aggregation.
This engine handles both local and remote reward models, supporting:
- Remote HTTP/gRPC reward models
- Local PyTorch reward models
- Custom reward functions and rules
- Multi-model ensemble with custom aggregation
- Optimized batch processing with sample filtering
The engine uses a three-stage pipeline:
1. Gather: Collect or filter samples based on reward recipe
2. Process: Run forward pass through reward model(s)
3. Aggregate: Combine scores using reward_fn
Args:
reward_model: Single reward model or list of models
remote_rm_url: List of remote reward model URLs
custom_reward_func: Custom Python function for reward computation
reward_fn: Aggregation function for multiple reward models
reward_fn_label_map: Mapping from reward model names to indices
tokenizer: Tokenizer for decoding sequences
strategy: Training strategy (for model loading/offloading)
packing_samples: Whether samples are packed
"""
def __init__(
self,
reward_model,
remote_rm_url: Optional[List[str]],
custom_reward_func: Optional[Callable],
reward_fn: Optional[Callable],
reward_fn_label_map: Optional[Dict],
reward_recipe: Optional[Dict],
tokenizer,
strategy,
packing_samples: bool,
):
"""
Initialize the reward computation engine.
:param reward_model: Single reward model or list of models
:type reward_model: Union[torch.nn.Module, List[torch.nn.Module]]
:param remote_rm_url: List of remote reward model URLs
:type remote_rm_url: Optional[List[str]]
:param custom_reward_func: Custom Python function for reward computation
:type custom_reward_func: Optional[Callable]
:param reward_fn: Aggregation function for multiple reward models
:type reward_fn: Optional[Callable]
:param reward_fn_label_map: Mapping from reward model names to indices
:type reward_fn_label_map: Optional[Dict[str, int]]
:param reward_recipe: Recipe configuration for reward computation
:type reward_recipe: Optional[Dict]
:param tokenizer: Tokenizer for decoding sequences
:type tokenizer: transformers.PreTrainedTokenizer
:param strategy: Training strategy (for model loading/offloading)
:type strategy: Any
:param packing_samples: Whether samples are packed
:type packing_samples: bool
"""
self.reward_model = reward_model
self.remote_rm_url = remote_rm_url
self.custom_reward_func = custom_reward_func
self.reward_fn = reward_fn
self.reward_fn_label_map = reward_fn_label_map or {}
self.reward_recipe = reward_recipe or {}
self.tokenizer = tokenizer
self.strategy = strategy
self.packing_samples = packing_samples
# Build inverse label map for quick lookup
self.inv_label_map = {idx: key for key, idx in self.reward_fn_label_map.items()}
# Configuration flag for optimized filtering engine
self.use_filtering_engine = False # TODO: Enable after testing
def compute_rewards(
self,
outputs: List[_SamplesOutput],
vlm_mode: bool,
device: torch.device,
) -> None:
"""
Compute rewards for all samples and store in outputs[i].rewards.
This method dispatches to the appropriate computation path based on
whether remote or local reward models are used.
:param outputs: List of sample outputs to compute rewards for
:type outputs: List[_SamplesOutput]
:param vlm_mode: Whether in vision-language mode
:type vlm_mode: bool
:param device: Device to place reward tensors on
:type device: torch.device
"""
if self.remote_rm_url:
self._compute_remote_rewards(outputs, vlm_mode, device)
else:
self._compute_local_rewards(outputs, vlm_mode, device)
def _compute_remote_rewards(
self,
outputs: List[_SamplesOutput],
vlm_mode: bool,
device: torch.device,
) -> None:
"""
Compute rewards using remote reward models.
This path maintains serial processing for compatibility with HTTP/gRPC APIs.
:param outputs: Sample outputs to compute rewards for
:type outputs: List[_SamplesOutput]
:param vlm_mode: Vision-language mode flag
:type vlm_mode: bool
:param device: Target device for tensors
:type device: torch.device
"""
for output in outputs:
# Decode sequences to text
sequences = (
output.sequences
if not self.packing_samples else unpacking_samples(output.sequences, output.packed_seq_lens)
)
queries = self.tokenizer.batch_decode(sequences, skip_special_tokens=False)
reward_tensors = []
# Custom reward function
if self.custom_reward_func:
if vlm_mode:
scores = self.custom_reward_func(queries, output.prompts, output.references)
else:
scores = self.custom_reward_func(queries, output.prompts, output.labels)
reward_tensors.append(torch.as_tensor(scores, dtype=torch.float32, device=device))
# Remote reward models
for rm_url in self.remote_rm_url[len(reward_tensors):]:
if vlm_mode:
scores = remote_rm_fn(
rm_url,
queries=queries,
prompts=output.prompts,
references=output.references,
raw_images=output.raw_images,
)
else:
scores = remote_rm_fn(
rm_url,
queries=queries,
prompts=output.prompts,
labels=output.labels,
)
reward_tensors.append(torch.as_tensor(scores, dtype=torch.float32, device=device))
# Aggregate rewards
output.rewards = (self.reward_fn(reward_tensors) if len(reward_tensors) > 1 else reward_tensors[0])
def _compute_local_rewards(
self,
outputs: List[_SamplesOutput],
vlm_mode: bool,
device: torch.device,
) -> None:
"""
Compute rewards using local reward models.
Implements batched processing for efficiency. Supports both standard
PyTorch models and custom engine models with optional sample filtering.
:param outputs: Sample outputs to compute rewards for
:type outputs: List[_SamplesOutput]
:param vlm_mode: Vision-language mode flag
:type vlm_mode: bool
:param device: Target device for tensors
:type device: torch.device
"""
# Ensure reward_model is a list
is_multi_rm = isinstance(self.reward_model, (list, tuple))
rm_list = list(self.reward_model) if is_multi_rm else [self.reward_model]
# Load all PyTorch models to GPU
for rm in rm_list:
if isinstance(rm, torch.nn.Module):
self.strategy.reload_model(rm)
# Compute rewards for each RM
# all_rewards_list[rm_idx][micro_batch_idx] = Tensor(batch_size,)
all_rewards_list = []
for rm_idx, rm in enumerate(rm_list):
micro_batch_rewards = self._compute_single_rm_rewards(rm, rm_idx, outputs, vlm_mode, device)
all_rewards_list.append(micro_batch_rewards)
# Offload model immediately after use
if isinstance(rm, torch.nn.Module):
self.strategy.offload_model(rm)
# Aggregate rewards across RMs for each micro-batch
self._aggregate_rewards(outputs, all_rewards_list, is_multi_rm)
def _compute_single_rm_rewards(
self,
rm,
rm_idx: int,
outputs: List[_SamplesOutput],
vlm_mode: bool,
device: torch.device,
) -> List[torch.Tensor]:
"""
Compute rewards for a single reward model across all micro-batches.
:param rm: Reward model instance
:type rm: Union[torch.nn.Module, Any]
:param rm_idx: Index of this RM in the RM list
:type rm_idx: int
:param outputs: Sample outputs
:type outputs: List[_SamplesOutput]
:param vlm_mode: Vision-language mode flag
:type vlm_mode: bool
:param device: Target device
:type device: torch.device
:return: List of reward tensors, one per micro-batch
:rtype: List[torch.Tensor]
"""
# Check if this is a custom engine model (non-torch base_model)
is_custom_engine = (
isinstance(rm, torch.nn.Module) and hasattr(rm, "base_model")
and not isinstance(rm.base_model, torch.nn.Module)
)
if is_custom_engine and self.use_filtering_engine:
return self._compute_filtered_rewards(rm, rm_idx, outputs, device)
elif is_custom_engine:
return self._compute_batched_custom_engine_rewards(rm, outputs, device)
elif isinstance(rm, torch.nn.Module):
return self._compute_standard_torch_rewards(rm, outputs, vlm_mode, device)
else:
raise ValueError(f"Unsupported reward model type: {type(rm)}")
def _compute_filtered_rewards(
self,
rm,
rm_idx: int,
outputs: List[_SamplesOutput],
device: torch.device,
) -> List[torch.Tensor]:
"""
Compute rewards using optimized filtering (only process relevant samples).
This optimization filters samples based on the reward recipe, only running
the forward pass on samples that actually need this specific RM.
:param rm: Custom engine reward model
:type rm: Any
:param rm_idx: RM index for label lookup
:type rm_idx: int
:param outputs: Sample outputs
:type outputs: List[_SamplesOutput]
:param device: Target device
:type device: torch.device
:return: List of reward tensors per micro-batch
:rtype: List[torch.Tensor]
"""
# Get RM key from inverse label map
rm_key = self.inv_label_map.get(rm_idx)
if rm_key is None:
raise ValueError(
f"Filtering engine requires a label map key for RM at index {rm_idx}, "
f"but none was found. Check your reward_fn_label_map configuration."
)
# ========== Gather Stage: Filter samples that need this RM ==========
flat_data = {
"prompt_and_output": [],
"raw_images": [],
"image_num": [],
"references": [],
"labels": [],
}
needed_positions = [] # [(micro_batch_idx, sample_idx), ...]
for mb_idx, output in enumerate(outputs):
for samp_idx, label in enumerate(output.labels):
# Check if this sample's recipe requires this RM
needs_rm = any(
typ == "model" and key == rm_key and float(weight) != 0.0
for typ, key, weight in self.reward_recipe.get(label, [])
)
if needs_rm:
needed_positions.append((mb_idx, samp_idx))
flat_data["prompt_and_output"].append(output.prompt_and_output[samp_idx])
flat_data["raw_images"].append(output.raw_images[samp_idx])
flat_data["image_num"].append(output.image_num[samp_idx])
flat_data["references"].append(output.references[samp_idx])
flat_data["labels"].append(output.labels[samp_idx])
# ========== Process Stage: Compute or skip ==========
if not needed_positions:
# No samples need this RM, return zeros for all micro-batches
return [torch.zeros(len(output.labels), dtype=torch.float32, device=device) for output in outputs]
# Run single forward pass on filtered samples
rm_output = rm(
None,
None,
prompt_and_outputs=flat_data["prompt_and_output"],
raw_images=flat_data["raw_images"],
img_num=flat_data["image_num"],
references=flat_data["references"],
labels=flat_data["labels"],
)
filtered_scores = (rm_output["score"] if isinstance(rm_output, dict) else rm_output).to(device)
# ========== Scatter Stage: Reconstruct micro-batch structure ==========
micro_batch_rewards = [
torch.zeros(len(output.labels), dtype=torch.float32, device=device) for output in outputs
]
for (mb_idx, samp_idx), score in zip(needed_positions, filtered_scores):
micro_batch_rewards[mb_idx][samp_idx] = score
return micro_batch_rewards
def _compute_batched_custom_engine_rewards(
self,
rm,
outputs: List[_SamplesOutput],
device: torch.device, # noqa: ARG002 (unused but kept for API consistency)
) -> List[torch.Tensor]:
"""
Compute rewards using custom engine with full batch processing (legacy path).
:param rm: Custom engine reward model
:type rm: Any
:param outputs: Sample outputs
:type outputs: List[_SamplesOutput]
:param device: Target device (unused but kept for API consistency)
:type device: torch.device
:return: List of reward tensors per micro-batch
:rtype: List[torch.Tensor]
"""
# Flatten all micro-batches into single batch
flat_data = {
"prompt_and_output": [],
"raw_images": [],
"image_num": [],
"references": [],
"labels": [],
}
for output in outputs:
flat_data["prompt_and_output"].extend(output.prompt_and_output)
flat_data["raw_images"].extend(output.raw_images)
flat_data["image_num"].extend(output.image_num)
flat_data["references"].extend(output.references)
flat_data["labels"].extend(output.labels)
# Single forward pass
rm_output = rm(
None,
None,
prompt_and_outputs=flat_data["prompt_and_output"],
raw_images=flat_data["raw_images"],
img_num=flat_data["image_num"],
references=flat_data["references"],
labels=flat_data["labels"],
)
all_scores = rm_output["score"] if isinstance(rm_output, dict) else rm_output
# Split back into micro-batches
batch_sizes = [len(output.prompt_and_output) for output in outputs]
return list(all_scores.split(batch_sizes))
def _compute_standard_torch_rewards(
self,
rm,
outputs: List[_SamplesOutput],
vlm_mode: bool, # noqa: ARG002 (kept for future VLM-specific logic)
device: torch.device,
) -> List[torch.Tensor]:
"""
Compute rewards using standard PyTorch reward model.
Processes each micro-batch sequentially.
:param rm: PyTorch reward model
:type rm: torch.nn.Module
:param outputs: Sample outputs
:type outputs: List[_SamplesOutput]
:param vlm_mode: Vision-language mode flag (reserved for future use)
:type vlm_mode: bool
:param device: Target device
:type device: torch.device
:return: List of reward tensors per micro-batch
:rtype: List[torch.Tensor]
"""
micro_batch_rewards = []
for output in outputs:
# Unpack sequences if needed
sequences = (
output.sequences
if not self.packing_samples else unpacking_samples(output.sequences, output.packed_seq_lens)
)
# Forward pass
rm_output = rm(
sequences,
output.attention_mask,
prompt_and_output=output.prompt_and_output,
raw_images=output.raw_images,
img_num=output.image_num,
**output.inputs_extra_kwargs,
)
score = rm_output["score"] if isinstance(rm_output, dict) else rm_output
micro_batch_rewards.append(torch.as_tensor(score, dtype=torch.float32, device=device))
return micro_batch_rewards
def _aggregate_rewards(
self,
outputs: List[_SamplesOutput],
all_rewards_list: List[List[torch.Tensor]],
is_multi_rm: bool,
) -> None:
"""
Aggregate rewards from multiple RMs and store in outputs.
:param outputs: Sample outputs (modified in-place)
:type outputs: List[_SamplesOutput]
:param all_rewards_list: Nested list [rm_idx][micro_batch_idx] -> Tensor
:type all_rewards_list: List[List[torch.Tensor]]
:param is_multi_rm: Whether using multiple reward models
:type is_multi_rm: bool
"""
num_micro_batches = len(outputs)
num_rms = len(all_rewards_list)
for mb_idx in range(num_micro_batches):
# Collect rewards from all RMs for this micro-batch
same_batch_rewards = [all_rewards_list[rm_idx][mb_idx] for rm_idx in range(num_rms)]
if is_multi_rm:
# Use custom aggregation function
sequences = (
outputs[mb_idx].sequences if not self.packing_samples else
unpacking_samples(outputs[mb_idx].sequences, outputs[mb_idx].packed_seq_lens)
)
queries = self.tokenizer.batch_decode(sequences, skip_special_tokens=False)
rewards, reward_metrics = self.reward_fn(
model_reward_list=same_batch_rewards,
labels=outputs[mb_idx].labels,
queries=queries,
refs=outputs[mb_idx].references,
label_map=self.reward_fn_label_map,
)
outputs[mb_idx].rewards = rewards
outputs[mb_idx].reward_metrics = reward_metrics
else:
# Single RM, use score directly
outputs[mb_idx].rewards = same_batch_rewards[0]
outputs[mb_idx].reward_metrics = None
# ============================================================================
# Main Experience Maker
# ============================================================================
[docs]class FastExperienceMaker(NaiveExperienceMaker):
"""
Optimized experience maker with VLLM/SGLang support and advanced RL features.
This class extends NaiveExperienceMaker to provide:
- High-performance inference via VLLM or SGLang backends
- Multimodal (vision-language) data processing
- Multiple advantage estimation algorithms (GAE, RLOO, REINFORCE, Group Norm)
- Flexible reward model composition with custom aggregation
- Sample packing for improved training efficiency
- Running reward normalization and advantage whitening/clipping
The experience generation pipeline:
1. Sample Generation: Use inference engine to generate responses
2. Shard-Parallel Preprocessing: Distribute samples across shards
3. Model Inference: Batch forward through actor, critic, initial, and reward models
4. Shard-Parallel Postprocessing: Gather results back
5. Reward Processing: Apply transformations (normalization, shaping, filtering)
6. Advantage Estimation: Compute advantages and returns
Args:
packing_samples: Whether to pack multiple sequences into single batch
processor: Multimodal processor for vision-language models
*args, **kwargs: Arguments passed to parent NaiveExperienceMaker
"""
[docs] def __init__(self, *args, packing_samples: bool = False, processor=None, **kwargs):
"""
Initialize FastExperienceMaker.
:param args: Positional arguments for NaiveExperienceMaker
:type args: tuple
:param packing_samples: Enable sample packing for efficiency
:type packing_samples: bool
:param processor: Multimodal processor (required for VLM models)
:type processor: Optional[Any]
:param kwargs: Keyword arguments for NaiveExperienceMaker
:type kwargs: dict
"""
super().__init__(*args, **kwargs)
# Core configuration
self.backend_mp_group = self.strategy.engine_mp_group
self.backend = self.strategy.args.engine_type
self.packing_samples = packing_samples
self.processor = processor
# Initialize tokenizer (extract from processor if needed)
if self.processor is not None:
self.tokenizer = getattr(self.processor, "tokenizer", self.processor)
# Initialize running reward normalization
if self.strategy.args.reward_running_norm:
self.reward_running_moments = RunningMoments()
else:
self.reward_running_moments = None
# Initialize advantage calculator
advantage_estimator = self.strategy.config.advantage_estimator
self.advantage_calculator = get_advantage_calculator(advantage_estimator, self.strategy.config)
# Initialize helper modules
if self.processor is not None:
self.multimodal_processor = MultimodalDataProcessor(
tokenizer=self.tokenizer,
processor=self.processor,
prompt_max_len=self.prompt_max_len,
)
else:
self.multimodal_processor = None
self.reward_engine = RewardComputationEngine(
reward_model=self.reward_model,
remote_rm_url=self.remote_rm_url,
custom_reward_func=getattr(self, "custom_reward_func", None),
reward_fn=self.reward_fn,
reward_fn_label_map=getattr(self, "reward_fn_label_map", None),
reward_recipe=getattr(self, "reward_recipe", None),
tokenizer=self.tokenizer,
strategy=self.strategy,
packing_samples=self.packing_samples,
)
# Cache actor's supported parameters based on its modality
# Default to VISION_LANGUAGE for backward compatibility with models without modality attribute
actor_modality = self.actor.modality
self._actor_supported_params = get_supported_parameters(actor_modality)
# ========================================================================
# Public API Methods
# ========================================================================
@torch.no_grad()
def make_experience_list(
self,
all_prompts: List[str],
all_images: Optional[List] = None,
all_videos: Optional[List] = None,
all_references: Optional[List[str]] = None,
all_labels: Optional[List] = None,
**generate_kwargs,
) -> List[ExperienceVL]:
"""
Generate a list of experiences from prompts and optional multimodal inputs.
This is the main entry point for experience generation. It orchestrates the
entire pipeline from sampling to advantage computation.
:param all_prompts: List of text prompts
:type all_prompts: List[str]
:param all_images: Optional images for multimodal generation
:type all_images: Optional[List]
:param all_references: Optional reference texts for evaluation
:type all_references: Optional[List[str]]
:param all_labels: Optional labels for samples
:type all_labels: Optional[List]
:param all_videos: Optional videos for multimodal generation
:type all_videos: Optional[List]
:param generate_kwargs: Generation parameters (temperature, max_new_tokens, etc.)
:type generate_kwargs: dict
:return: List of Experience or ExperienceVL objects with computed advantages and returns
:rtype: List[Union[Experience, ExperienceVL]]
"""
config = self.strategy.config
# Normalize images if provided
if all_images is not None:
if self.multimodal_processor is None:
raise ValueError(
"Multimodal data (images) provided but processor was not initialized. "
"Please provide a processor when initializing FastExperienceMaker for VLM support."
)
all_images = normalize_images(all_images)
# Normalize videos if provided
if all_videos is not None:
if self.multimodal_processor is None:
raise ValueError(
"Multimodal data (videos) provided but processor was not initialized. "
"Please provide a processor when initializing FastExperienceMaker for VLM support."
)
all_videos = normalize_videos(all_videos)
# Get image counts
images_num = (get_images_num(all_images) if self.multimodal_processor and all_images is not None else None)
# Get video counts
videos_num = (get_videos_num(all_videos) if self.multimodal_processor and all_videos is not None else None)
# ========== Stage 1: Sample Generation ==========
Timer.start(' generate_samples')
samples_list = self.generate_samples(
all_prompts,
all_images=all_images,
images_num=images_num,
all_videos=all_videos,
videos_num=videos_num,
all_references=all_references,
all_labels=all_labels,
**generate_kwargs,
)
Timer.stop(' generate_samples')
torch.distributed.barrier()
torch.cuda.synchronize()
# ========== Stage 2: Shard-Parallel Preprocessing ==========
all_samples = self.strategy.sp_data_processor.preprocess(samples_list)
# ========== Stage 3: Model Inference ==========
Timer.start(' make_experience')
experiences = self._make_experience_list_by_model(all_samples)
Timer.stop(' make_experience')
# ========== Stage 4: Shard-Parallel Postprocessing ==========
experiences = self.strategy.sp_data_processor.postprocess(experiences)
# ========== Stage 5: Reward Processing ==========
experiences, rewards = self._process_experiences( # GRPO's -mean / std operation is performed in this method
experiences, generate_kwargs.get("max_new_tokens", 1024)
)
# ========== Stage 6: Multi-Image/Video Handling ==========
if (images_num is not None and not all(num == 1 for num in images_num)) or \
(videos_num is not None and not all(num == 1 for num in videos_num)):
# Expand image_num by n_samples_per_prompt
expanded_images_num = sum([[num] * config.n_samples_per_prompt
for num in images_num], []) if images_num is not None else None
expanded_videos_num = sum([[num] * config.n_samples_per_prompt
for num in videos_num], []) if videos_num is not None else None
self._process_multi_image_video_thws(experiences, expanded_images_num, expanded_videos_num)
# ========== Stage 7: Advantage Computation ==========
experiences = self._compute_advantages_and_returns(experiences, rewards, generate_kwargs)
return experiences
@torch.no_grad()
def generate_samples(
self,
all_prompts: List[str],
all_images: Optional[List] = None,
all_videos: Optional[List] = None,
images_num: Optional[List[int]] = None,
videos_num: Optional[List[int]] = None,
all_references: Optional[List[str]] = None,
all_labels: Optional[List] = None,
**generate_kwargs,
) -> List[Samples]:
"""
Generate samples using the inference engine (VLLM or SGLang).
This method handles:
- Sampling parameter configuration
- Multimodal data processing
- Inference engine invocation
- Output processing into Samples format
:param all_prompts: List of text prompts
:type all_prompts: List[str]
:param all_images: Optional images for VLM
:type all_images: Optional[List]
:param images_num: Number of images per prompt
:type images_num: Optional[List[int]]
:param all_references: Reference texts
:type all_references: Optional[List[str]]
:param all_labels: Sample labels
:type all_labels: Optional[List]
:param all_videos: Optional videos for VLM
:type all_videos: Optional[List]
:param videos_num: Number of videos per prompt
:type videos_num: Optional[List[int]]
:param generate_kwargs: Generation parameters (temperature, max_new_tokens, etc.)
:type generate_kwargs: dict
:return: List of Samples or SamplesVL objects
:rtype: List[Union[Samples, SamplesVL]]
"""
assert self.strategy.inference_engine is not None, "Inference engine required"
torch.cuda.synchronize()
start_time = time.time()
config = self.strategy.config
is_multimodal = all_images is not None or all_videos is not None
n_samples = config.n_samples_per_prompt
# Initialize multimodal-specific variables to None
all_images_num = None
all_videos_num = None
all_images_pixel_values = None
all_videos_pixel_values = None
all_images_grid_thw = None
all_videos_grid_thw = None
# ========== Configure Sampling Parameters ==========
if config.engine_type == "vllm":
# For vllm>=0.13.0, truncate_prompt_tokens must not exceed max_model_len
# For older versions, we can use 8192 directly without validation
if vllm_ge_0130():
max_model_len = self.strategy.inference_engine.llm_engine.model_config.max_model_len
truncate_tokens = min(8192, max_model_len)
else:
truncate_tokens = 8192
sampling_params = SamplingParams(
temperature=generate_kwargs.get("temperature", 1.0),
top_p=generate_kwargs.get("top_p", 1.0),
top_k=generate_kwargs.get("top_k", -1),
max_tokens=generate_kwargs.get("max_new_tokens", 1024),
min_tokens=generate_kwargs.get("min_new_tokens", 1),
skip_special_tokens=generate_kwargs.get("skip_special_tokens", False),
include_stop_str_in_output=True,
ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1",
truncate_prompt_tokens=truncate_tokens,
)
elif config.engine_type == "sglang":
sampling_params = dict(
n=1,
temperature=generate_kwargs.get("temperature", 1.0),
top_p=generate_kwargs.get("top_p", 1.0),
top_k=generate_kwargs.get("top_k", -1),
max_new_tokens=generate_kwargs.get("max_new_tokens", 1024),
presence_penalty=0.0,
frequency_penalty=0.0,
repetition_penalty=1.0,
skip_special_tokens=generate_kwargs.get("skip_special_tokens", False),
spaces_between_special_tokens=True,
ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1",
)
else:
raise ValueError(f"Unsupported engine type: {config.engine_type}")
# ========== Expand Labels ==========
if all_labels is not None:
all_labels = sum([[label] * n_samples for label in all_labels], [])
# ========== Process Multimodal Data ==========
if is_multimodal:
processed_data = self.multimodal_processor.process_multimodal_batch(
all_prompts=all_prompts,
all_images=all_images,
all_references=all_references,
images_num=images_num,
n_samples_per_prompt=n_samples,
all_videos=all_videos,
videos_num=videos_num,
)
all_prompt_token_ids = processed_data["all_prompt_token_ids"]
all_prompts = processed_data["all_prompts"]
all_images = processed_data["all_images"]
all_videos = processed_data["all_videos"]
all_images_num = processed_data["all_images_num"]
all_videos_num = processed_data["all_videos_num"]
all_images_grid_thw = processed_data["all_images_grid_thw"]
all_videos_grid_thw = processed_data["all_videos_grid_thw"]
all_images_pixel_values = processed_data["all_images_pixel_values"]
all_videos_pixel_values = processed_data["all_videos_pixel_values"]
all_references = processed_data.get("all_references", None)
else:
# Text-only processing
tokenized = self.tokenize_fn(all_prompts, self.prompt_max_len, padding=False)
all_prompt_token_ids = sum([[token_ids] * n_samples for token_ids in tokenized["input_ids"]], [])
# ========== Generate via Inference Engine ==========
# Call fire_sampling function or direct generation
try:
if hasattr(self.strategy.args, 'use_fire') and self.strategy.args.use_fire:
# Use FIRE sampling (Flaming-hot Initiation with Regular Execution)
# According to the paper (https://arxiv.org/abs/2410.21236), FIRE only changes
# the temperature for the first token. All other sampling parameters (top_k, top_p, etc.)
# are kept the same between first token and remaining tokens.
all_outputs = fire_sampling(
all_prompt_token_ids=all_prompt_token_ids,
generate_fn=generate_fn, # noqa: TODO
engine_type=config.engine_type,
first_token_temperature=generate_kwargs.get("first_token_temperature", 10.0),
temperature=generate_kwargs.get("temperature", 1.0),
# Note: first_token_top_k and first_token_top_p are deprecated and ignored
# The function will use top_k and top_p from sampling_params for both stages
is_multimodal=is_multimodal,
all_prompts=all_prompts,
all_images=all_images,
all_videos=all_videos,
all_images_num=all_images_num,
all_videos_num=all_videos_num,
sampling_params=sampling_params,
)
else:
# maybe this can be called in if and else respectively? or like this?
# Use original single-shot generation
all_outputs = self.strategy.gather_and_generate(
sampling_params=sampling_params,
all_prompt_token_ids=all_prompt_token_ids,
all_prompts=all_prompts if is_multimodal else None,
sleep_engine=self.strategy.args.enable_engine_sleep,
all_images=all_images if is_multimodal else None,
all_videos=all_videos if is_multimodal else None,
images_num=all_images_num if is_multimodal else None,
videos_num=all_videos_num if is_multimodal else None,
)
except ValueError as e:
if "prompt" in str(e) and "too long" in str(e):
self.strategy.print(f"[Skip] {e}")
return None # Return None, subsequent experience_maker will ignore
else:
raise
# ========== Process Outputs into Samples ==========
samples_list = []
image_patch_idx = 0
video_patch_idx = 0
image_start_idx = 0
video_start_idx = 0
for i in range(0, len(all_outputs), config.micro_rollout_batch_size):
micro_batch_outputs = all_outputs[i:i + config.micro_rollout_batch_size]
micro_batch_prompts = all_prompts[i:i + config.micro_rollout_batch_size]
# Extract micro-batch data
micro_batch_grid_thw = None
micro_batch_video_grid_thw = None
micro_batch_raw_images = None
if is_multimodal:
rollout_image_count = sum(all_images_num[i:i + config.micro_rollout_batch_size])
micro_batch_grid_thw = all_images_grid_thw[image_start_idx:image_start_idx + rollout_image_count]
micro_batch_raw_images = all_images[i:i + config.micro_rollout_batch_size]
image_start_idx += rollout_image_count
rollout_video_count = sum(all_videos_num[i:i + config.micro_rollout_batch_size])
micro_batch_video_grid_thw = all_videos_grid_thw[video_start_idx:video_start_idx + rollout_video_count]
video_start_idx += rollout_video_count
micro_batch_references = (all_references[i:i + config.micro_rollout_batch_size] if all_references else None)
micro_batch_labels = (all_labels[i:i + config.micro_rollout_batch_size] if all_labels else None)
# Build samples
if not self.packing_samples:
sample, updated_patch_idx, updated_video_patch_idx = self._build_unpacked_sample(
outputs=micro_batch_outputs,
prompts=micro_batch_prompts,
labels=micro_batch_labels,
references=micro_batch_references,
is_multimodal=is_multimodal,
grid_thw=micro_batch_grid_thw,
video_grid_thw=micro_batch_video_grid_thw,
raw_images=micro_batch_raw_images,
pixel_values=all_images_pixel_values if is_multimodal else None,
pixel_values_videos=all_videos_pixel_values if is_multimodal else None,
images_num=all_images_num[i:i + config.micro_rollout_batch_size] if is_multimodal else None,
videos_num=all_videos_num[i:i + config.micro_rollout_batch_size] if is_multimodal else None,
image_patch_idx=image_patch_idx,
video_patch_idx=video_patch_idx,
)
# Update patch indices from the returned values
if updated_patch_idx is not None:
image_patch_idx = updated_patch_idx
if updated_video_patch_idx is not None:
video_patch_idx = updated_video_patch_idx
samples_list.append(sample)
else:
# Packed samples
sample = self._build_packed_sample(
outputs=micro_batch_outputs,
prompts=micro_batch_prompts,
labels=micro_batch_labels,
references=micro_batch_references,
)
samples_list.append(sample)
# Report timing
torch.cuda.synchronize()
gen_time = torch.tensor(time.time() - start_time, device=get_current_device())
torch.distributed.all_reduce(gen_time, op=torch.distributed.ReduceOp.MAX)
self.strategy.print(f"***Rollout engine generation time (global max): {gen_time.item():.4f}s")
self.strategy.report_memory("after rollout engine generation")
return samples_list
def get_advantages_and_returns(
self,
values: torch.Tensor,
rewards: torch.Tensor,
action_mask: torch.Tensor,
gamma: float,
lambd: float,
) -> Tuple[torch.Tensor, torch.Tensor, float]:
"""
Compute advantages and returns using Generalized Advantage Estimation (GAE).
Extends parent method with advantage whitening and clipping.
:param values: Value estimates from critic
:type values: torch.Tensor
:param rewards: Reward signals
:type rewards: torch.Tensor
:param action_mask: Mask for valid action positions
:type action_mask: torch.Tensor
:param gamma: Discount factor
:type gamma: float
:param lambd: GAE lambda parameter
:type lambd: float
:return: Tuple of (advantages, returns, advantage_clip_fraction)
:rtype: Tuple[torch.Tensor, torch.Tensor, float]
"""
# Call parent GAE implementation
advantages, returns = super().get_advantages_and_returns(values, rewards, action_mask, gamma, lambd)
config = self.strategy.config
# Advantage whitening (normalization)
if config.advantages_norm:
masked_adv = torch.masked_select(advantages, action_mask)
adv_mean = masked_adv.mean()
adv_std = masked_adv.std()
advantages = (advantages - adv_mean) / (adv_std + 1e-9)
# Advantage clipping
advantage_clip_frac = 0.0
if config.advantage_clip > 0:
advantages = torch.clamp(advantages, -config.advantage_clip, config.advantage_clip)
advantage_clip_frac = compute_clip_fraction(advantages, config.advantage_clip, -config.advantage_clip)
return advantages, returns, advantage_clip_frac
# ========================================================================
# Private Helper Methods
# ========================================================================
def _process_multi_image_video_thws(
self,
experiences: List[ExperienceVL],
images_num: Optional[List[int]] = None,
videos_num: Optional[List[int]] = None,
) -> None:
"""
Process image_grid_thws and video_grid_thws for multi-image/video scenarios.
Ensures len(experience.sequences) == len(experience.image_grid_thws) by
converting the stacked tensor into a list of per-sequence tensors.
:param experiences: List of experiences to modify in-place
:type experiences: List[ExperienceVL]
:param images_num: Number of images per sample (expanded by n_samples_per_prompt)
:type images_num: Optional[List[int]]
:param videos_num: Number of videos per sample (expanded by n_samples_per_prompt)
:type videos_num: Optional[List[int]]
"""
config = self.strategy.config
for i, experience in enumerate(experiences):
# Get image and video counts for this micro-batch
start_idx = i * config.micro_rollout_batch_size
end_idx = (i + 1) * config.micro_rollout_batch_size
if images_num is not None:
micro_images_num = images_num[start_idx:end_idx]
if sum(micro_images_num) > 0 and experience.image_grid_thws is not None:
image_grid_thw_list = []
image_grid_thws = experience.image_grid_thws
image_grid_thws_unbind = torch.unbind(image_grid_thws)
thw_idx = 0
for num in micro_images_num:
if num > 0:
stacked_thw = torch.stack(image_grid_thws_unbind[thw_idx:thw_idx + num], dim=0).to("cuda")
image_grid_thw_list.append(stacked_thw)
thw_idx += num
else:
image_grid_thw_list.append(None)
experience.image_grid_thws = image_grid_thw_list
else:
experience.image_grid_thws = [None] * len(micro_images_num)
if videos_num is not None:
micro_videos_num = videos_num[start_idx:end_idx]
if sum(micro_videos_num) > 0 and experience.video_grid_thws is not None:
video_grid_thw_list = []
video_grid_thws = experience.video_grid_thws
video_grid_thws_unbind = torch.unbind(video_grid_thws)
v_thw_idx = 0
for num in micro_videos_num:
if num > 0:
v_stacked_thw = torch.stack(video_grid_thws_unbind[v_thw_idx:v_thw_idx + num],
dim=0).to("cuda")
video_grid_thw_list.append(v_stacked_thw)
v_thw_idx += num
else:
video_grid_thw_list.append(None)
experience.video_grid_thws = video_grid_thw_list
else:
experience.video_grid_thws = [None] * len(micro_videos_num)
def _process_experiences(
self,
experiences: List[ExperienceVL],
max_new_tokens: int,
) -> Tuple[List[ExperienceVL], List[torch.Tensor]]:
"""
Apply reward transformations and filtering to experiences.
Handles:
- Overlong sequence penalty
- Dynamic sampling filtering
- Advantage estimation-specific reward shaping (RLOO, REINFORCE, Group Norm)
:param experiences: List of experiences to process
:type experiences: List[Union[Experience, ExperienceVL]]
:param max_new_tokens: Maximum generation length
:type max_new_tokens: int
:return: Tuple of (processed_experiences, shaped_rewards)
:rtype: Tuple[List[Union[Experience, ExperienceVL]], List[torch.Tensor]]
"""
config = self.strategy.config
rewards = torch.cat([exp.info["reward"] for exp in experiences])
# ========== Overlong Sequence Penalty ==========
if config.overlong_buffer:
expected_len = max_new_tokens - config.overlong_buffer_len
actual_lens = torch.cat([exp.action_mask.sum(dim=1) for exp in experiences])
exceed_len = actual_lens - expected_len
# Penalty: clamp(-exceed_len / buffer_len * penalty_factor, max=0)
penalty = torch.clamp(
-exceed_len / config.overlong_buffer_len * config.overlong_buffer_penalty_factor, max=0.0
)
rewards += penalty
# ========== Dynamic Sampling Warning ==========
if config.dynamic_sampling and config.advantage_estimator in ["rloo", "reinforce_baseline"]:
warnings.warn(f"dynamic_sampling not implemented for {config.advantage_estimator}, ignoring", UserWarning)
# ========== Advantage Estimator-Specific Shaping ==========
# Use calculator's preprocess_rewards method
return self.advantage_calculator.preprocess_rewards(rewards, experiences, max_new_tokens)
def _compute_advantages_and_returns(
self,
experiences: List[ExperienceVL],
rewards: List[torch.Tensor],
generate_kwargs: Dict,
) -> List[ExperienceVL]:
"""
Compute advantages and returns for each experience.
Applies reward normalization/clipping, KL penalty, and advantage estimation
based on the configured method (GAE, CPGD, REINFORCE, etc.).
:param experiences: List of experiences to process
:type experiences: List[Union[Experience, ExperienceVL]]
:param rewards: List of reward tensors
:type rewards: List[torch.Tensor]
:param generate_kwargs: Generation parameters (contains gamma, lambd)
:type generate_kwargs: Dict
:return: List of experiences with advantages and returns filled in
:rtype: List[Union[Experience, ExperienceVL]]
"""
config = self.strategy.config
for experience, reward in zip(experiences, rewards):
reward = reward.to("cuda")
processed_reward = reward.clone() # TODO:check
# ========== Reward Normalization ==========
if self.reward_running_moments:
self.reward_running_moments.update(processed_reward)
if config.reward_running_norm_minus_mean:
processed_reward = ((processed_reward - self.reward_running_moments.mean) /
self.reward_running_moments.std)
else:
processed_reward /= self.reward_running_moments.std
# ========== Reward Clipping ==========
if config.reward_clip > 0:
experience.info["reward_clip_frac"] = compute_clip_fraction(
processed_reward, config.reward_clip, -config.reward_clip
)
processed_reward = torch.clamp(processed_reward, -config.reward_clip, config.reward_clip)
# ========== Final Reward (with KL penalty) ==========
final_reward = compute_reward(
processed_reward,
self.kl_ctl.value,
experience.kl,
action_mask=experience.action_mask,
num_actions=experience.info["num_actions"],
)
# ========== Advantage Estimation ==========
# Compute advantages and returns using calculator
gamma = generate_kwargs.pop("gamma", 1.0)
experience.advantages, experience.returns, info_dict = self.advantage_calculator.compute(
experience,
final_reward,
gamma=gamma,
generate_kwargs=generate_kwargs,
)
# Update experience info with calculator's info dict
experience.info.update(info_dict)
# ========== Store Episode Return ==========
if not self.packing_samples:
experience.info["return"] = final_reward.sum(dim=-1)
else:
experience.info["return"] = torch.tensor([r.sum() for r in final_reward], device=final_reward.device)
# Cleanup
experience.kl = None
del experience.info["num_actions"]
# ========== Cross-batch Advantage Normalization ==========
# Use the utility function for cross-batch normalization
experiences = normalize_advantages_cross_batch(experiences, self.advantage_estimator, self.strategy.args)
return experiences
@torch.no_grad()
def _make_experience_list_by_model(
self,
all_samples: List[Union[Samples, SamplesVL]],
) -> List[Union[Experience, ExperienceVL]]:
"""
Batch forward pass through all models to create experiences.
This method implements role-based batching to avoid frequent model switching.
Processing order:
1. Actor (log probabilities)
2. Initial model (reference log probabilities)
3. Critic (value estimates)
4. Reward model(s) (rewards)
5. Assemble Experience objects
:param all_samples: List of Samples/SamplesVL from generate_samples
:type all_samples: List[Union[Samples, SamplesVL]]
:return: List of Experience/ExperienceVL objects with model outputs filled in
:rtype: List[Union[Experience, ExperienceVL]]
"""
device = get_current_device()
vlm_mode = isinstance(all_samples[0], SamplesVL)
# ========== Stage 0: Preprocessing ==========
outputs = [self._preprocess_sample(sample, vlm_mode, device) for sample in all_samples]
# ========== Stage 1: Actor Forward ==========
Timer.start(' actor_logprob')
# Check if we need to compute entropy for high-entropy token filtering
need_entropy = hasattr(self.actor, 'high_entropy_token_ratio') and self.actor.high_entropy_token_ratio > 0.0
for output in outputs:
if need_entropy:
# Request full output to get action_entropy
action_log_probs, model_output = self.actor(
output.sequences,
output.num_actions,
output.attention_mask,
packed_seq_lens=output.packed_seq_lens,
return_output=True,
**output.inputs_extra_kwargs
)
output.action_log_probs = action_log_probs
# Extract action_entropy if available
if "action_entropy" in model_output:
output.action_entropy = model_output["action_entropy"]
else:
output.action_log_probs = self.actor(
output.sequences,
output.num_actions,
output.attention_mask,
packed_seq_lens=output.packed_seq_lens,
**output.inputs_extra_kwargs
)
Timer.stop(' actor_logprob')
# ========== Stage 2: Initial Model ==========
if self.initial_model is not None:
self.strategy.reload_model(self.initial_model)
for output in outputs:
output.base_action_log_probs = self.initial_model(
output.sequences,
output.num_actions,
output.attention_mask,
packed_seq_lens=output.packed_seq_lens,
**output.inputs_extra_kwargs
)
self.strategy.offload_model(self.initial_model)
# ========== Stage 3: Critic ==========
if self.critic is not None:
self.strategy.reload_model(self.critic)
for output in outputs:
output.value = self.critic(
output.sequences, output.num_actions, output.attention_mask, **output.inputs_extra_kwargs
)
self.strategy.offload_model(self.critic)
# ========== Stage 4: Reward Models ==========
self.reward_engine.compute_rewards(outputs, vlm_mode, device)
# ========== Stage 5: Assemble Experiences ==========
return [self._pack_experience(output, vlm_mode) for output in outputs]
def _preprocess_sample(
self,
sample: Union[Samples, SamplesVL],
vlm: bool,
device: torch.device,
) -> _SamplesOutput:
"""
Convert a Samples object to _SamplesOutput for processing.
:param sample: Input sample
:type sample: Union[Samples, SamplesVL]
:param vlm: Vision-language mode flag
:type vlm: bool
:param device: Target device
:type device: torch.device
:return: _SamplesOutput with data ready for model inference
:rtype: _SamplesOutput
"""
# Extract common fields
sequences = sample.sequences.to(device)
attention_mask = sample.attention_mask.to(device)
action_mask = sample.action_mask
num_actions = sample.num_actions
packed_seq_lens = sample.packed_seq_lens
response_length = sample.response_length
total_length = sample.total_length
prompts = sample.prompts
labels = getattr(sample, "labels", None)
references = sample.references
output_texts = getattr(sample, "output_texts", None)
# Build extra kwargs for VLM based on actor's modality
# Only include parameters that the actor's modality supports
extra_kwargs = {}
if vlm:
# Candidate parameters to pass
candidate_params = {
"pixel_values": sample.pixel_values,
"image_grid_thw": sample.image_grid_thws,
"pixel_values_videos": sample.pixel_values_videos,
"video_grid_thw": sample.video_grid_thws,
}
# Filter to only include supported parameters
extra_kwargs = {
key: value
for key, value in candidate_params.items()
if key in self._actor_supported_params
}
# Fix Qwen-VL image token count bug
self._fix_qwen_vl_image_tokens(sequences, sample, vlm)
return _SamplesOutput(
sequences=sequences,
attention_mask=attention_mask,
action_mask=action_mask,
num_actions=num_actions,
packed_seq_lens=packed_seq_lens,
response_length=response_length,
total_length=total_length,
prompts=prompts,
labels=labels,
pixel_values=getattr(sample, "pixel_values", None),
image_grid_thw=getattr(sample, "image_grid_thws", None),
pixel_values_videos=getattr(sample, "pixel_values_videos", None),
video_grid_thw=getattr(sample, "video_grid_thws", None),
raw_images=getattr(sample, "raw_images", None),
image_num=getattr(sample, "image_num", None),
video_num=getattr(sample, "video_num", None),
references=references,
inputs_extra_kwargs=extra_kwargs,
prompt_and_output=([p + (o or "") for p, o in zip(prompts, output_texts)] if output_texts else None),
)
def _fix_qwen_vl_image_tokens(
self,
sequences: torch.Tensor,
sample: SamplesVL,
vlm: bool,
) -> None:
"""
Fix Qwen-VL image token count mismatch.
In some cases, the number of image tokens in sequences doesn't match
the number of pixel value patches. This fixes the discrepancy by replacing
extra image tokens with padding tokens.
:param sequences: Token sequence (modified in-place)
:type sequences: torch.Tensor
:param sample: Original sample
:type sample: SamplesVL
:param vlm: Vision-language mode flag
:type vlm: bool
"""
if not vlm or sample.pixel_values is None:
return
config = self.strategy.unwrap_model(self.actor.model).config
image_token_id = config.image_token_id
num_tokens = (sequences == image_token_id).sum()
num_patches = sample.pixel_values.shape[0] // 4
if num_tokens != num_patches:
self.strategy.print(
f"[Warning] Mismatch found during rollout step. Fixing sequences. "
f"Tokens: {num_tokens}, Patches: {num_patches}"
)
pad_token_id = self.tokenizer.pad_token_id
diff = num_tokens - num_patches
token_positions = (sequences == image_token_id).nonzero()
# Replace extra tokens from the end
for k in range(diff):
pos = token_positions[-(k + 1)]
sequences[pos[0], pos[1]] = pad_token_id
def _pack_experience(
self,
output: _SamplesOutput,
vlm: bool,
) -> Union[Experience, ExperienceVL]:
"""
Pack model outputs into an Experience object.
:param output: Processed sample output
:type output: _SamplesOutput
:param vlm: Vision-language mode flag
:type vlm: bool
:return: Experience or ExperienceVL object
:rtype: Union[Experience, ExperienceVL]
"""
# Compute KL divergence
if self.initial_model is not None and not self.strategy.args.use_kl_loss:
# Note: When use_kl_loss is True, KL is used as a loss term;
# when False, KL is added to reward as augmentation
kl = compute_approx_kl(
output.action_log_probs,
output.base_action_log_probs,
action_mask=output.action_mask,
kl_estimator=self.strategy.args.kl_estimator,
)
else:
kl = torch.zeros_like(output.action_log_probs)
# Compute mean KL
if not self.packing_samples:
kl_mean = masked_mean(kl, output.action_mask, dim=-1)
else:
kl_mean = torch.tensor(
[each.mean() for each in unpacking_samples(kl, output.num_actions)],
device=kl.device,
)
# Clear base log probs if not needed
if not self.strategy.args.use_kl_loss:
output.base_action_log_probs = None
# Build info dict
info = dict(
kl=kl_mean,
reward=output.rewards,
response_length=output.response_length,
total_length=output.total_length,
num_actions=output.num_actions,
)
# Add reward_metrics if available
if output.reward_metrics is not None:
info['reward_metrics'] = output.reward_metrics
# Create Experience object
if vlm:
return ExperienceVL(
sequences=output.sequences,
pixel_values=output.pixel_values,
image_grid_thws=output.image_grid_thw,
raw_images=output.raw_images,
pixel_values_videos=output.pixel_values_videos,
video_grid_thws=output.video_grid_thw,
action_log_probs=output.action_log_probs,
base_action_log_probs=output.base_action_log_probs,
values=output.value,
returns=None, # returns (filled later)
advantages=None, # advantages (filled later)
attention_mask=output.attention_mask,
action_mask=output.action_mask,
info=info,
kl=kl,
action_entropy=output.action_entropy,
)
else:
return Experience(
sequences=output.sequences,
action_log_probs=output.action_log_probs,
base_action_log_probs=output.base_action_log_probs,
values=output.value,
returns=None, # returns (filled later)
advantages=None, # advantages (filled later)
attention_mask=output.attention_mask,
action_mask=output.action_mask,
info=info,
kl=kl,
action_entropy=output.action_entropy,
)
def _build_unpacked_sample(
self,
outputs: List,
prompts: List[str],
labels: Optional[List],
references: Optional[List],
is_multimodal: bool,
**kwargs,
) -> Tuple[Union[Samples, SamplesVL], Optional[int], Optional[int]]:
"""
Build unpacked sample (one sequence per row with padding).
Sample format:
| [PAD] [PAD] prompt_token ... | response_token ... [EOS] [PAD] |
:param outputs: Engine outputs
:type outputs: List
:param prompts: Text prompts
:type prompts: List[str]
:param labels: Sample labels
:type labels: Optional[List]
:param references: Reference texts
:type references: Optional[List]
:param is_multimodal: Whether in VLM mode
:type is_multimodal: bool
:param kwargs: Additional VLM-specific arguments
:type kwargs: dict
:return: Tuple of (Samples/SamplesVL object, updated image_patch_idx, updated video_patch_idx)
:rtype: Tuple[Union[Samples, SamplesVL], Optional[int], Optional[int]]
"""
# Find max lengths
max_input_len = max(len(out.prompt_token_ids) for out in outputs)
max_output_len = max(len(out.output_token_ids) for out in outputs)
pad_token_id = self.tokenizer.pad_token_id
eos_token_id = self.tokenizer.eos_token_id
sequences = []
all_output_ids = []
# VLM data structures
if is_multimodal:
pixel_values = []
image_grid_thw_list = []
all_img_num = []
pixel_values_videos = []
video_grid_thw_list = []
all_vid_num = []
grid_thw = kwargs["grid_thw"]
raw_images = kwargs["raw_images"]
pixel_values_tensor = kwargs["pixel_values"]
images_num = kwargs["images_num"]
image_patch_idx = kwargs["image_patch_idx"]
video_grid_thw = kwargs["video_grid_thw"]
pixel_values_videos_tensor = kwargs["pixel_values_videos"]
videos_num = kwargs["videos_num"]
video_patch_idx = kwargs["video_patch_idx"]
local_grid_idx = 0
local_video_grid_idx = 0
# Process each output
for j, output in enumerate(outputs):
# Left-pad input
input_len = len(output.prompt_token_ids)
input_ids = [pad_token_id] * (max_input_len - input_len) + list(output.prompt_token_ids)
# Right-pad output
output_len = len(output.output_token_ids)
output_ids = list(output.output_token_ids) + [pad_token_id] * (max_output_len - output_len)
all_output_ids.append(output.output_token_ids)
# Process images/videos for this sample
if is_multimodal:
if images_num is not None:
image_num = images_num[j]
all_img_num.append(image_num)
for img_idx in range(image_num):
grid = grid_thw[local_grid_idx + img_idx]
num_patch = grid[0] * grid[1] * grid[2]
image_grid_thw_list.append(grid.clone().unsqueeze(0))
if num_patch > 0:
pixel_slice = pixel_values_tensor[image_patch_idx:image_patch_idx + num_patch]
pixel_values.append(pixel_slice.clone())
image_patch_idx += num_patch
local_grid_idx += image_num
if videos_num is not None:
video_num = videos_num[j]
all_vid_num.append(video_num)
for vid_idx in range(video_num):
grid = video_grid_thw[local_video_grid_idx + vid_idx]
num_patch = grid[0] * grid[1] * grid[2]
video_grid_thw_list.append(grid.clone().unsqueeze(0))
if num_patch > 0:
pixel_slice = pixel_values_videos_tensor[video_patch_idx:video_patch_idx + num_patch]
pixel_values_videos.append(pixel_slice.clone())
video_patch_idx += num_patch
local_video_grid_idx += video_num # Concatenate input and output
sequences.append(input_ids + output_ids)
# Decode output texts
output_texts = self.tokenizer.batch_decode(all_output_ids)
# Process sequences
sequences = torch.tensor(sequences)
sequences, attention_mask, action_mask = self.actor.process_sequences(
sequences, max_input_len, eos_token_id, pad_token_id
)
sequences = sequences.to("cuda")
attention_mask = attention_mask.to("cuda")
action_mask = action_mask.to("cuda")
if not is_multimodal:
return Samples(
sequences=sequences,
attention_mask=attention_mask,
action_mask=action_mask,
num_actions=action_mask.size(1),
packed_seq_lens=None,
response_length=action_mask.float().sum(dim=-1),
total_length=attention_mask.float().sum(dim=-1),
prompts=prompts,
labels=labels,
references=references,
pad_len=None,
), None, None # Return None for patch indices
else:
# Process VLM pixel values
pixel_values = (
torch.cat(pixel_values, dim=0).cuda() if pixel_values and pixel_values[0].shape[0] > 0 else None
)
pixel_values_videos = (
torch.cat(pixel_values_videos, dim=0).cuda()
if pixel_values_videos and pixel_values_videos[0].shape[0] > 0 else None
)
return SamplesVL(
sequences=sequences,
attention_mask=attention_mask,
action_mask=action_mask,
image_grid_thws=(torch.cat(image_grid_thw_list, dim=0).to("cuda") if image_grid_thw_list else None),
video_grid_thws=(torch.cat(video_grid_thw_list, dim=0).to("cuda") if video_grid_thw_list else None),
raw_images=raw_images,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
num_actions=action_mask.size(1),
packed_seq_lens=None,
response_length=action_mask.float().sum(dim=-1),
total_length=attention_mask.float().sum(dim=-1),
references=references,
labels=labels,
prompts=prompts,
output_texts=output_texts,
image_num=all_img_num,
video_num=all_vid_num,
), image_patch_idx, video_patch_idx
def _build_packed_sample(
self,
outputs: List,
prompts: List[str],
labels: Optional[List],
references: Optional[List],
) -> Samples:
"""
Build packed sample (multiple sequences concatenated without padding).
Sample format:
| prompt1 response1 [EOS] | prompt2 response2 [EOS] | prompt3 ... |
:param outputs: Engine outputs
:type outputs: List
:param prompts: Text prompts
:type prompts: List[str]
:param labels: Sample labels
:type labels: Optional[List]
:param references: Reference texts
:type references: Optional[List]
:return: Samples object with packed sequences
:rtype: Samples
"""
sequences = []
packed_seq_lens = []
attention_mask = []
num_actions = []
for idx, output in enumerate(outputs):
input_len = len(output.prompt_token_ids)
output_len = len(output.output_token_ids)
packed_seq_lens.append(input_len + output_len)
sequences.extend(output.prompt_token_ids + list(output.output_token_ids))
attention_mask.extend([idx + 1] * (input_len + output_len))
num_actions.append(max(1, output_len))
sequences = torch.tensor(sequences, device="cuda").unsqueeze(0)
attention_mask = torch.tensor(attention_mask, device="cuda").unsqueeze(0)
response_length = torch.tensor(num_actions, device="cuda", dtype=torch.float)
total_length = torch.tensor(packed_seq_lens, device="cuda", dtype=torch.float)
return Samples(
sequences=sequences,
attention_mask=attention_mask,
action_mask=None,
num_actions=num_actions,
packed_seq_lens=packed_seq_lens,
response_length=response_length,
total_length=total_length,
prompts=prompts,
labels=labels,
references=references,
pad_len=None,
)