# Source code for lightrft.models.utils
"""
Utility functions for computing log probabilities from logits in PyTorch.
This module provides functions to efficiently calculate log probabilities
for token predictions, with optimizations to handle different data types
and reduce memory consumption. It also includes utilities for finding
linear modules in neural networks and handling position IDs for packed
sequences in transformer models.
The module is particularly useful for:
- Computing log probabilities from model logits with memory-efficient approaches
- Finding LoRA-injectable linear modules in various model architectures
- Handling position IDs in packed sequence scenarios for transformer models
"""
from typing import List, Optional, Union, Tuple, Dict, Callable
from loguru import logger
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist
from peft import LoraConfig, TaskType, get_peft_model
def find_all_linear_modules(model: "nn.Module", freeze_vision_tower: bool) -> List[str]:
    """
    Find all linear modules that can be injected with LoRA (Low-Rank Adaptation).

    The model is scanned with ``named_modules()``. Every module whose class name
    contains ``"Linear"`` (and not ``"Embedding"``) contributes the last component of
    its qualified name (e.g. ``"q_proj"`` for ``"model.layers.0.self_attn.q_proj"``).
    Modules whose qualified name contains a forbidden substring are skipped. The
    forbidden set always contains ``"lm_head"`` and is extended depending on
    ``model.config.model_type`` (ChatGLM, LLaVA variants, mllama, PaliGemma,
    Qwen2-VL / Qwen2.5-VL) and on ``freeze_vision_tower``.

    :param model: The neural network model to scan for linear modules. A model without
        a ``config`` attribute is treated as having an unknown model type.
    :type model: nn.Module
    :param freeze_vision_tower: Whether to freeze the vision tower components.
        If True, vision-related modules will be added to the forbidden list.
    :type freeze_vision_tower: bool
    :return: Sorted list of unique module (attribute) names usable as LoRA targets.
    :rtype: List[str]

    Example::

        >>> import torch.nn as nn
        >>> model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
        >>> find_all_linear_modules(model, freeze_vision_tower=False)
        ['0', '2']
    """
    # Plain nn.Module instances have no ``config``; fall back to "unknown model type"
    # instead of raising AttributeError.
    model_type = getattr(getattr(model, "config", None), "model_type", None)

    forbidden = {"lm_head"}
    if model_type == "chatglm":
        forbidden.add("output_layer")
    elif model_type in ("llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"):
        forbidden.add("multi_modal_projector")
    elif model_type in ("qwen2_vl", "qwen2_5_vl"):
        forbidden.add("merger")

    if freeze_vision_tower:
        if model_type == "mllama":
            forbidden.add("vision_model")
        elif model_type in ("qwen2_vl", "qwen2_5_vl"):
            forbidden.add("visual")
        else:
            forbidden.add("vision_tower")

    module_names = set()
    for name, module in model.named_modules():
        # Substring match on the fully qualified name, so children of a forbidden
        # module are excluded as well.
        if any(fm in name for fm in forbidden):
            continue
        class_name = module.__class__.__name__
        if "Linear" in class_name and "Embedding" not in class_name:
            module_names.add(name.split(".")[-1])

    # Sorted so that the result is deterministic across runs.
    return sorted(module_names)
def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """
    Return the Shannon entropy of the categorical distribution defined by ``logits``.

    The last dimension of ``logits`` is interpreted as unnormalized log-probabilities
    over the vocabulary; the entropy is computed over that dimension via
    ``torch.distributions.Categorical``. Larger values mean the prediction at that
    position is more uncertain.

    :param logits: Logits tensor of shape (batch_size, sequence_length, vocab_size)
        or (batch_size, vocab_size)
    :type logits: torch.Tensor
    :return: Entropy per position, of shape (batch_size, sequence_length)
        or (batch_size,)
    :rtype: torch.Tensor

    Example::

        >>> logits = torch.randn(2, 10, 50000)
        >>> entropy_from_logits(logits).shape
        torch.Size([2, 10])
    """
    return dist.Categorical(logits=logits).entropy()
def create_high_entropy_mask(
    entropy: torch.Tensor,
    action_mask: Optional[torch.Tensor],
    high_entropy_ratio: float = 0.2,
) -> torch.Tensor:
    """
    Create a binary mask selecting the highest-entropy valid tokens of each sequence.

    For every row, the ``max(1, int(num_valid * high_entropy_ratio))`` valid tokens with
    the largest entropy ("forking tokens") are marked with 1.0, everything else with 0.0.
    This follows "Beyond the 80/20 Rule: High-Entropy Minority Tokens Drive Effective
    Reinforcement Learning for LLM Reasoning" (https://arxiv.org/abs/2506.01939), which
    reports that updating on roughly the top 20% highest-entropy tokens is sufficient.

    Positions excluded by ``action_mask`` are never selected, even when the valid
    entropies are zero or tie with the padded entries.

    :param entropy: Entropy values for each token, shape (batch_size, sequence_length)
    :type entropy: torch.Tensor
    :param action_mask: Binary mask indicating valid tokens (non-zero for valid, 0 for padding),
        same shape as ``entropy``; ``None`` treats every token as valid.
    :type action_mask: Optional[torch.Tensor]
    :param high_entropy_ratio: Ratio of high-entropy tokens to keep (e.g., 0.2 means top 20%).
        Values outside the open interval (0, 1) disable the selection.
    :type high_entropy_ratio: float
    :return: Float32 mask of shape (batch_size, sequence_length). When the ratio is outside
        (0, 1), a clone of ``action_mask`` is returned instead (or an all-ones float32
        tensor if ``action_mask`` is None).
    :rtype: torch.Tensor
    :raises ValueError: If ``entropy`` or ``action_mask`` is not 2D, or their shapes differ.

    Example::

        >>> entropy = torch.tensor([[1.0, 5.0, 2.0, 6.0, 3.0], [2.0, 4.0, 1.0, 5.0, 0.0]])
        >>> action_mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])
        >>> create_high_entropy_mask(entropy, action_mask, high_entropy_ratio=0.4)
        tensor([[0., 1., 0., 1., 0.],
                [0., 0., 0., 1., 0.]])

    The first row keeps int(5 * 0.4) = 2 tokens, the second keeps int(4 * 0.4) = 1.
    """
    if high_entropy_ratio <= 0.0 or high_entropy_ratio >= 1.0:
        # Ratio outside (0, 1): no filtering, every valid token is kept.
        if action_mask is not None:
            return action_mask.clone()
        return torch.ones_like(entropy, dtype=torch.float32)

    # Validate shapes
    if len(entropy.shape) != 2:
        raise ValueError(f"entropy must be 2D tensor (batch_size, seq_len), got shape {entropy.shape}")
    batch_size, seq_len = entropy.shape
    if action_mask is not None:
        if len(action_mask.shape) != 2:
            raise ValueError(f"action_mask must be 2D tensor (batch_size, seq_len), got shape {action_mask.shape}")
        if action_mask.shape != entropy.shape:
            raise ValueError(f"action_mask shape {action_mask.shape} must match entropy shape {entropy.shape}")

    high_entropy_mask = torch.zeros_like(entropy, dtype=torch.float32)
    for i in range(batch_size):
        if action_mask is not None:
            valid = action_mask[i].bool()
        else:
            valid = torch.ones(seq_len, dtype=torch.bool, device=entropy.device)

        num_valid = int(valid.sum().item())
        if num_valid == 0:
            continue

        # Number of high-entropy tokens to keep (at least one per non-empty row).
        num_high_entropy = max(1, int(num_valid * high_entropy_ratio))

        # Fill padded positions with -inf rather than multiplying by the mask: a product
        # would turn them into 0, which can outrank or tie with valid entries in topk.
        scores = entropy[i].float().masked_fill(~valid, float("-inf"))
        _, top_indices = torch.topk(scores, k=num_high_entropy, dim=0)
        high_entropy_mask[i, top_indices] = 1.0

    return high_entropy_mask
def log_probs_from_logits(
    logits: torch.Tensor, labels: torch.Tensor, disable_logprobs_flashattn: bool = False
) -> torch.Tensor:
    """
    Compute log probabilities for the given labels from logits.

    Two code paths are used depending on ``logits.dtype``:

    - float32 / float64: ``log_softmax(x)_i = x_i - logsumexp(x)``. If allowed, available and
      the logits live on a CUDA device, the Triton ``cross_entropy_loss`` kernel from
      ``flash_attn`` is used to reduce peak memory; otherwise ``logsumexp`` is evaluated
      one batch row at a time.
    - other dtypes (e.g. float16, bfloat16): ``F.log_softmax`` is applied one batch row at
      a time to reduce peak memory consumption.

    :param logits: Logits tensor of shape (batch_size, sequence_length, vocab_size)
        or (batch_size, vocab_size)
    :type logits: torch.Tensor
    :param labels: Labels tensor containing token indices, of shape (batch_size, sequence_length)
        or (batch_size,)
    :type labels: torch.Tensor
    :param disable_logprobs_flashattn: If True, never use the ``flash_attn`` cross-entropy
        kernel. Defaults to False.
    :type disable_logprobs_flashattn: bool
    :return: Log probabilities for the given labels, of shape matching labels
    :rtype: torch.Tensor

    Example::

        >>> logits = torch.randn(2, 3, 5)  # batch_size=2, seq_len=3, vocab_size=5
        >>> labels = torch.randint(0, 5, (2, 3))  # batch_size=2, seq_len=3
        >>> log_probs = log_probs_from_logits(logits, labels)
        >>> log_probs.shape
        torch.Size([2, 3])
    """
    if logits.dtype in (torch.float32, torch.float64):
        batch_dim = logits.shape[:-1]
        last_dim = logits.shape[-1]

        cross_entropy_loss = None
        # The flash-attn kernel is implemented in Triton, so it is only attempted for CUDA
        # tensors; CPU inputs go straight to the pure-PyTorch fallback.
        if not disable_logprobs_flashattn and logits.is_cuda:
            try:
                from flash_attn.ops.triton.cross_entropy import cross_entropy_loss
            except ImportError:
                logger.warning("Failed to import cross_entropy_loss from flash_attn")

        if cross_entropy_loss is not None:
            # use cross_entropy_loss from flash_attn to reduce peak mem consumption;
            # the per-token loss is the negative log-probability of the label.
            output = cross_entropy_loss(logits.reshape(-1, last_dim), labels.reshape(-1))
            return -output[0].view(*batch_dim)

        logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
        # loop over batch rows to reduce peak mem consumption
        logsumexp_values = torch.stack([torch.logsumexp(row, dim=-1) for row in logits])
        return logits_labels - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)

    # Reduced-precision dtypes: row-by-row log_softmax to reduce peak mem consumption.
    per_row = [
        F.log_softmax(row_logits, dim=-1).gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
        for row_logits, row_labels in zip(logits, labels)
    ]
    return torch.stack(per_row)
def reset_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Build position IDs for packed sequences described by an integer attention mask.

    Each row of ``attention_mask`` may hold several independent samples packed back to
    back; the tokens of the k-th sample carry the value ``k`` (1, 2, 3, ...). The returned
    tensor numbers the tokens of every sample from 0 upwards. Positions whose mask value
    is not one of ``1..max`` (e.g. 0) keep position 0.

    :param attention_mask: A 2D tensor of shape (batch_size, sequence_length)
        where different positive integers mark different samples within the
        same row, and 0 typically represents padding.
    :type attention_mask: torch.Tensor
    :return: A ``torch.long`` tensor of the same shape as ``attention_mask`` holding
        the per-sample position IDs.
    :rtype: torch.Tensor

    Example::

        >>> attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2, 3, 3, 0]])
        >>> reset_position_ids(attention_mask)
        tensor([[0, 1, 2, 0, 1, 2, 0, 1, 0]])
    """
    position_ids = torch.zeros_like(attention_mask, dtype=torch.long)

    for row in range(attention_mask.size(0)):
        row_mask = attention_mask[row]
        # The largest id in the row equals the number of packed samples,
        # e.g. [1, 1, 2, 2, 2, 0] -> 2.
        num_packed = row_mask.max().item()

        for sample_id in range(1, num_packed + 1):
            belongs_to_sample = row_mask == sample_id
            token_count = belongs_to_sample.sum().item()
            # Tokens of this sample are numbered 0 .. token_count - 1 in order of appearance.
            position_ids[row, belongs_to_sample] = torch.arange(token_count, device=row_mask.device)

    return position_ids
def apply_lora_configuration(
    model: "nn.Module",
    lora_rank: int,
    lora_alpha: int,
    lora_dropout: float,
    target_modules: Optional[List[str]] = None,
    freeze_vision_tower: bool = True,
) -> "nn.Module":
    """
    Apply LoRA (Low-Rank Adaptation) configuration to a model.

    Enables input gradients on the model, builds a causal-LM ``LoraConfig`` (no bias
    adaptation) and wraps the model with ``get_peft_model``.

    :param model: The model to apply LoRA configuration to. It must provide
        ``enable_input_require_grads()``.
    :type model: nn.Module
    :param lora_rank: Rank for LoRA adaptation
    :type lora_rank: int
    :param lora_alpha: Alpha parameter for LoRA scaling
    :type lora_alpha: int
    :param lora_dropout: Dropout rate for LoRA layers
    :type lora_dropout: float
    :param target_modules: List of target modules for applying LoRA. If None, they are
        auto-detected with ``find_all_linear_modules``.
    :type target_modules: Optional[List[str]]
    :param freeze_vision_tower: Whether vision tower components are excluded during
        auto-detection (only used when ``target_modules`` is None).
    :type freeze_vision_tower: bool
    :return: The model with LoRA configuration applied
    :rtype: nn.Module

    Example::

        >>> model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
        >>> model = apply_lora_configuration(
        ...     model=model,
        ...     lora_rank=16,
        ...     lora_alpha=32,
        ...     lora_dropout=0.1
        ... )
    """
    # Enable input require gradients for LoRA
    model.enable_input_require_grads()

    # Auto-detect target modules if not provided
    if target_modules is None:
        target_modules = find_all_linear_modules(model, freeze_vision_tower)
    # Report through the module logger instead of printing to stdout.
    logger.info("target_modules: {}", target_modules)

    # Configure LoRA
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=lora_dropout,
        bias="none",
    )

    # Apply LoRA to the model
    return get_peft_model(model, lora_config)
def compute_approx_kl(
    log_probs: torch.Tensor,
    log_probs_base: torch.Tensor,
    action_mask: Optional[torch.Tensor] = None,
    kl_estimator: str = "k1",
) -> torch.Tensor:
    """
    Estimate the per-token KL divergence between a policy and a reference policy.

    With ``d = log_probs - log_probs_base`` (computed in float32 and multiplied by
    ``action_mask`` when one is given), the estimators from Schulman's blog post
    (http://joschu.net/blog/kl-approx.html) are:

    - ``"k1"``: ``d``
    - ``"k2"``: ``d ** 2 / 2`` (non-negative; approximately equivalent to the one-step KL
      penalty with the k1 estimator used in https://arxiv.org/pdf/2310.10505)
    - ``"k3"``: ``exp(-d) - 1 + d`` (non-negative)

    :param log_probs: Log probabilities of the new distribution
    :type log_probs: torch.Tensor
    :param log_probs_base: Log probabilities of the base/reference distribution
    :type log_probs_base: torch.Tensor
    :param action_mask: Binary mask indicating valid action positions (1 for valid, 0 for padding)
    :type action_mask: Optional[torch.Tensor]
    :param kl_estimator: Type of KL estimator to use ("k1", "k2", or "k3")
    :type kl_estimator: str
    :return: Approximate KL divergence values, same shape as ``log_probs``
    :rtype: torch.Tensor

    Example::

        >>> log_probs = torch.tensor([[0.1, -0.2, 0.3], [-0.1, 0.2, 0.1]])
        >>> log_probs_base = torch.tensor([[0.2, -0.1, 0.2], [-0.2, 0.1, 0.2]])
        >>> action_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
        >>> kl = compute_approx_kl(log_probs, log_probs_base, action_mask, "k1")
        >>> kl.shape
        torch.Size([2, 3])
    """
    assert kl_estimator in ["k1", "k2", "k3"], f"Invalid kl_estimator: {kl_estimator}"

    # The masked log-ratio is shared by all three estimators.
    masked_diff = log_probs.float() - log_probs_base.float()
    if action_mask is not None:
        masked_diff = masked_diff * action_mask

    if kl_estimator == "k3":
        negated = -masked_diff
        log_ratio = negated.exp() - 1 - negated
    elif kl_estimator == "k2":
        log_ratio = masked_diff ** 2 / 2.0
    elif kl_estimator == "k1":
        log_ratio = masked_diff
    return log_ratio
def compute_reward(
    r: Union[torch.Tensor, float],
    kl_coef: float,
    kl: Union[torch.Tensor, list[torch.Tensor]],
    action_mask: Optional[torch.Tensor] = None,
    num_actions: Optional[Union[int, list[int]]] = None,
    reward_clip_range: Optional[Tuple[float, float]] = None,
) -> Union[torch.Tensor, list[torch.Tensor]]:
    """
    Compute final reward by combining base reward with KL penalty.

    Every token receives ``-kl_coef * kl``; the (optionally clipped) base reward ``r`` is
    added on the last action token of each sequence. Two modes are supported:

    - ``action_mask`` given: ``kl`` is a ``(batch_size, num_tokens)`` tensor and the last
      action token is the last position where ``action_mask`` is non-zero.
    - ``action_mask`` is None: ``kl`` is a sequence of 1D tensors and the last action token
      of sequence ``i`` is at index ``num_actions[i] - 1``.

    :param r: Base reward, either a tensor with one entry per sequence or a scalar that
        is applied to every sequence
    :type r: Union[torch.Tensor, float]
    :param kl_coef: KL penalty coefficient (<=0 disables penalty)
    :type kl_coef: float
    :param kl: KL divergence values as tensor or list
    :type kl: Union[torch.Tensor, list[torch.Tensor]]
    :param action_mask: Binary mask for valid action positions
    :type action_mask: Optional[torch.Tensor]
    :param num_actions: Number of actions per sequence (no mask mode); a single int is
        used for every sequence
    :type num_actions: Optional[Union[int, list[int]]]
    :param reward_clip_range: (min, max) to clip base reward
    :type reward_clip_range: Optional[Tuple[float, float]]
    :return: Final reward tensor or list
    :rtype: Union[torch.Tensor, list[torch.Tensor]]
    :raises ValueError: If neither ``action_mask`` nor ``num_actions`` is provided.

    Example::

        >>> r = torch.tensor([1.0, 2.0])
        >>> kl_coef = 0.1
        >>> kl = torch.tensor([[0.1, 0.2, 0.3], [0.2, 0.1, 0.4]])
        >>> action_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
        >>> reward = compute_reward(r, kl_coef, kl, action_mask)
        >>> reward.shape
        torch.Size([2, 3])
    """
    if kl_coef <= 0.0:
        kl_coef = 0.0

    # The signature allows a plain Python number; turn it into a 0-d tensor so the
    # tensor-only operations below (clamp / unsqueeze / indexing) work uniformly.
    if not torch.is_tensor(r):
        r = torch.as_tensor(r, dtype=torch.float32)

    if reward_clip_range:
        r = r.clamp(min=reward_clip_range[0], max=reward_clip_range[1])

    if action_mask is not None:
        kl_reward = -kl_coef * kl
        if r.dim() == 0:
            # Scalar reward: the same value for every sequence in the batch.
            r = r.expand(kl.size(0)).contiguous()

        # The following code is equivalent to:
        #
        # last_reward = torch.zeros_like(kl)
        # for i in range(last_reward.size(0)):
        #     for t in reversed(range(last_reward.size(1))):
        #         if action_mask[i][t] > 0.5:
        #             last_reward[i][t] = r[i]
        #             break
        #
        # argmax on the flipped mask finds the first non-zero entry from the right.
        eos_indices = action_mask.size(1) - 1 - action_mask.long().fliplr().argmax(dim=1, keepdim=True)
        src = r.unsqueeze(1).to(device=kl.device, dtype=kl.dtype)
        last_reward = torch.zeros_like(kl).scatter_(dim=1, index=eos_indices, src=src)
        return last_reward + kl_reward

    if num_actions is None:
        raise ValueError("num_actions must be provided when action_mask is None")

    # A single int means every sequence has the same number of actions.
    action_lens = [num_actions] * len(kl) if isinstance(num_actions, int) else num_actions
    scalar_reward = r.dim() == 0

    # TODO: write a more efficient version
    reward = []
    for i, (kl_seg, action_len) in enumerate(zip(kl, action_lens)):
        # ``-kl_coef * kl_seg`` allocates a new tensor, so the caller's ``kl`` is not mutated.
        kl_reward = -kl_coef * kl_seg
        kl_reward[action_len - 1] += r if scalar_reward else r[i]
        reward.append(kl_reward)
    return reward
def masked_mean(tensor: torch.Tensor, mask: Optional[torch.Tensor], dim: Optional[int] = None) -> torch.Tensor:
    """
    Compute mean of tensor excluding masked (padded) values.

    The result is ``sum(tensor * mask) / sum(mask)`` over ``dim`` (or over all elements
    when ``dim`` is None). Reductions whose mask sums to zero return 0 instead of the
    NaN that ``0 / 0`` would produce.

    :param tensor: Input tensor to average
    :type tensor: torch.Tensor
    :param mask: Binary mask (1 for valid, 0 for padding). None for regular mean.
    :type mask: Optional[torch.Tensor]
    :param dim: Dimension(s) to compute mean along. None for global mean.
    :type dim: Optional[int]
    :return: Mean value(s) with masked positions excluded
    :rtype: torch.Tensor

    Example::

        >>> tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        >>> mask = torch.tensor([[1, 1, 0], [1, 0, 0]])
        >>> masked_mean(tensor, mask)
        tensor(2.3333)
        >>> masked_mean(tensor, mask, dim=1)
        tensor([1.5000, 4.0000])
    """
    if mask is None:
        return tensor.mean() if dim is None else tensor.mean(dim=dim)

    weighted = tensor * mask
    if dim is None:
        total, count = weighted.sum(), mask.sum()
    else:
        total, count = weighted.sum(dim=dim), mask.sum(dim=dim)

    # Where nothing is selected the numerator is 0 as well; dividing by 1 there yields 0
    # instead of NaN and leaves every other entry unchanged.
    safe_count = torch.where(count == 0, torch.ones_like(count), count)
    return total / safe_count
def unpacking_samples(values: torch.Tensor, packed_seqlens: list[int]) -> list[torch.Tensor]:
    """
    Unpack concatenated sequences into individual sequences.

    Splits a packed tensor into consecutive slices along its first (token) dimension,
    one slice per entry of ``packed_seqlens``.

    :param values: Concatenated tensor (1, total_length) or (total_length,)
    :type values: torch.Tensor
    :param packed_seqlens: List of original sequence lengths
    :type packed_seqlens: list[int]
    :return: List of unpacked sequence tensors
    :rtype: list[torch.Tensor]

    Example::

        >>> values = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
        >>> packed_seqlens = [3, 2, 3]
        >>> unpacked = unpacking_samples(values, packed_seqlens)
        >>> [t.tolist() for t in unpacked]
        [[1, 2, 3], [4, 5], [6, 7, 8]]
    """
    # Only drop the leading batch dimension when there is one. Squeezing a 1D tensor of
    # length 1 would produce a 0-d tensor that cannot be sliced.
    if values.dim() > 1:
        values = values.squeeze(0)

    unpacked_values = []
    offset = 0
    for seqlen in packed_seqlens:
        unpacked_values.append(values[offset:offset + seqlen])
        offset += seqlen
    return unpacked_values
def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor:
    """
    Left-pad ``tensor`` with ``pad_value`` until it has ``length`` entries along ``dim``.

    :param tensor: Input tensor to be padded.
    :type tensor: torch.Tensor
    :param length: Target length along ``dim``. If the input is already
        at least this length, the tensor is returned unchanged.
    :type length: int
    :param pad_value: Scalar pad value to use for the new elements.
    :type pad_value: int or float
    :param dim: Dimension along which to pad (default: ``-1``).
    :type dim: int
    :returns: The original tensor object when no padding is needed; otherwise a new
        tensor whose leading entries along ``dim`` hold ``pad_value``.
    :rtype: torch.Tensor
    """
    current = tensor.size(dim)
    if current >= length:
        return tensor

    # Shape of the block that is prepended: identical to the input except along ``dim``.
    filler_shape = list(tensor.shape)
    filler_shape[dim] = length - current
    filler = pad_value * torch.ones(*filler_shape, dtype=tensor.dtype, device=tensor.device)
    # left pad
    return torch.cat([filler, tensor], dim=dim)
def _cat_optional(first: Optional[torch.Tensor], second: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Concatenate the tensors that are not None along dim 0; return None if both are None."""
    present = [t for t in (first, second) if t is not None]
    if not present:
        return None
    if len(present) == 1:
        return present[0]
    return torch.cat(present, dim=0)


def concatenated_forward(
    model: Callable, input0_ids: torch.Tensor, input0_mask: torch.Tensor, input1_ids: torch.Tensor,
    input1_mask: torch.Tensor, input0_img_pixels: Optional[torch.Tensor], input0_img_grid_thws: Optional[torch.Tensor],
    input1_img_pixels: Optional[torch.Tensor], input1_img_grid_thws: Optional[torch.Tensor],
    input0_video_pixels: Optional[torch.Tensor], input0_video_grid_thws: Optional[torch.Tensor],
    input1_video_pixels: Optional[torch.Tensor], input1_video_grid_thws: Optional[torch.Tensor], pad_token_id: int
) -> Tuple[Union[torch.Tensor, Dict[str, torch.Tensor]], Union[torch.Tensor, Dict[str, torch.Tensor]]]:
    """
    Concatenates paired candidate inputs and runs a forward pass for vision-language models.

    This utility is used in preference/reward modeling scenarios where two candidates
    (e.g., chosen vs. rejected) are processed together for efficiency. Text sequences
    from both candidates are left-padded to the maximum length across the pair and stacked
    along the batch dimension (candidate 0 first). Multimodal inputs (images/videos) are
    concatenated along dim 0 when provided; if only one candidate supplies them, that
    candidate's tensors are passed on unchanged.

    :param model: Callable model that accepts input ids, attention masks, and optional multimodal inputs.
    :type model: Callable
    :param input0_ids: Token ids for candidate 0.
    :type input0_ids: torch.LongTensor of shape ``(B, T0)``
    :param input0_mask: Attention mask for candidate 0 (1 = attend, 0 = pad).
    :type input0_mask: torch.LongTensor of shape ``(B, T0)``
    :param input1_ids: Token ids for candidate 1.
    :type input1_ids: torch.LongTensor of shape ``(B, T1)``
    :param input1_mask: Attention mask for candidate 1 (1 = attend, 0 = pad).
    :type input1_mask: torch.LongTensor of shape ``(B, T1)``
    :param input0_img_pixels: Image pixel tensor for candidate 0, or ``None`` if not used.
    :type input0_img_pixels: Optional[torch.Tensor]
    :param input0_img_grid_thws: Image grid metadata (e.g., THW) for candidate 0, or ``None``.
    :type input0_img_grid_thws: Optional[torch.Tensor]
    :param input1_img_pixels: Image pixel tensor for candidate 1, or ``None`` if not used.
    :type input1_img_pixels: Optional[torch.Tensor]
    :param input1_img_grid_thws: Image grid metadata (e.g., THW) for candidate 1, or ``None``.
    :type input1_img_grid_thws: Optional[torch.Tensor]
    :param input0_video_pixels: Video pixel tensor for candidate 0, or ``None`` if not used.
    :type input0_video_pixels: Optional[torch.Tensor]
    :param input0_video_grid_thws: Video grid metadata (e.g., THW) for candidate 0, or ``None``.
    :type input0_video_grid_thws: Optional[torch.Tensor]
    :param input1_video_pixels: Video pixel tensor for candidate 1, or ``None`` if not used.
    :type input1_video_pixels: Optional[torch.Tensor]
    :param input1_video_grid_thws: Video grid metadata (e.g., THW) for candidate 1, or ``None``.
    :type input1_video_grid_thws: Optional[torch.Tensor]
    :param pad_token_id: Token id used for left-padding text sequences to equal length.
    :type pad_token_id: int
    :return: A tuple ``(scores0, scores1)`` where each element is either a tensor of shape
        ``(B, ...)`` or a dict mapping head names to tensors, mirroring the model output
        for each candidate.
    :rtype: Tuple[Union[torch.Tensor, Dict[str, torch.Tensor]], Union[torch.Tensor, Dict[str, torch.Tensor]]]
    """
    # Compute shared maximum lengths across the pair for text ids and masks.
    max_length_ids = max(input0_ids.shape[1], input1_ids.shape[1])
    max_length_mask = max(input0_mask.shape[1], input1_mask.shape[1])

    input_ids = torch.cat(
        (
            pad_to_length(input0_ids, max_length_ids, pad_token_id),
            pad_to_length(input1_ids, max_length_ids, pad_token_id),
        ),
        dim=0,
    )
    att_masks = torch.cat(
        (pad_to_length(input0_mask, max_length_mask, 0), pad_to_length(input1_mask, max_length_mask, 0)), dim=0
    )

    # Multimodal inputs stay None unless at least one candidate provides them.
    with torch.no_grad():
        pixel_values = _cat_optional(input0_img_pixels, input1_img_pixels)
        image_grid_thws = _cat_optional(input0_img_grid_thws, input1_img_grid_thws)
        pixel_values_videos = _cat_optional(input0_video_pixels, input1_video_pixels)
        video_grid_thws = _cat_optional(input0_video_grid_thws, input1_video_grid_thws)

    # Forward pass over the concatenated batch (size 2 * B).
    scores = model(
        input_ids,
        attention_mask=att_masks,
        pixel_values=pixel_values,
        image_grid_thw=image_grid_thws,
        pixel_values_videos=pixel_values_videos,
        video_grid_thw=video_grid_thws
    )

    # Candidate 0 occupies the first ``batch_size_0`` rows of the stacked batch.
    batch_size_0 = input0_ids.shape[0]
    if isinstance(scores, torch.Tensor):
        return scores[:batch_size_0], scores[batch_size_0:]

    scores0 = {head_type: score[:batch_size_0] for head_type, score in scores.items()}
    scores1 = {head_type: score[batch_size_0:] for head_type, score in scores.items()}
    return scores0, scores1
class AttentionPooling(nn.Module):
    """
    Attention pooling over the sequence dimension of VLM hidden states.

    This module compresses a sequence of hidden states into a single fixed-size
    representation by attending from a learnable global query to the sequence.

    :param hidden_size: Hidden size of the backbone model. Must be divisible by ``num_heads``.
    :type hidden_size: int
    :param num_heads: Number of attention heads used for pooling. Defaults to ``4``.
    :type num_heads: int, optional
    :param qkv_bias: Whether to use bias terms in the key and value projection layers. Defaults to ``False``.
    :type qkv_bias: bool, optional
    :param position_bias: If ``True``, add a linear 1-D positional bias to attention logits. Defaults to ``False``.
    :type position_bias: bool, optional
    :param position_bias_scale: Scale factor for the positional bias; larger values more strongly favor later positions.
    :type position_bias_scale: float, optional
    :raises ValueError: If ``hidden_size`` is not divisible by ``num_heads``.

    .. note::
        The learnable query has ``hidden_size`` entries and is shared across the batch; it is
        split into ``num_heads`` chunks of ``head_dim`` entries, one per head. Attention logits
        are scaled by ``1 / sqrt(head_dim)`` where ``head_dim = hidden_size // num_heads``.

    Example::

        pool = AttentionPooling(hidden_size=1024, num_heads=8).to(torch.bfloat16).cuda()
        x = torch.randn(2, 128, 1024, dtype=torch.bfloat16, device='cuda')  # (B=2, S=128, C=1024)
        y = pool(x)
        assert y.shape == (2, 1024)
    """
    def __init__(
        self,
        hidden_size: int,
        num_heads: int = 4,
        qkv_bias: bool = False,
        position_bias: bool = False,
        position_bias_scale: float = 3.0,
    ) -> None:
        super().__init__()
        # Fail early with a clear message instead of a reshape error inside forward().
        if hidden_size % num_heads != 0:
            raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")

        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim ** -0.5
        self.position_bias = position_bias
        self.position_bias_scale = position_bias_scale

        self.k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        self.v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        # 0.02 for better initialization
        self.query = nn.Parameter(torch.randn(hidden_size) * 0.02)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Apply attention pooling over the sequence of hidden states.

        :param hidden_states: Hidden states to pool, of shape ``(B, S, C)``.
        :type hidden_states: torch.Tensor
        :returns: Pooled hidden states of shape ``(B, C)``.
        :rtype: torch.Tensor
        """
        B, S, C = hidden_states.shape

        # Multi-head projection for key and value
        k = self.k(hidden_states).reshape(B, S, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # B, H, S, D
        v = self.v(hidden_states).reshape(B, S, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # B, H, S, D

        # Split the query into per-head chunks; the batch dimension is broadcast by matmul.
        q = self.query.view(1, self.num_heads, 1, self.head_dim)  # 1, H, 1, D

        # Attention weights
        attn = (q @ k.transpose(-2, -1)) * self.scale  # B, H, 1, S

        # Add position bias
        if self.position_bias:
            # Build the ramp in float32 for accuracy, then cast to the logits' dtype.
            # Adding a float32 tensor directly would promote half-precision logits to
            # float32 and make the ``attn @ v`` matmul below fail on mismatched dtypes.
            ramp = torch.arange(S, device=k.device, dtype=torch.float32) / S * self.position_bias_scale
            attn = attn + ramp.to(attn.dtype).view(1, 1, 1, -1)

        # Attention pooling
        attn = torch.softmax(attn, dim=-1)  # B, H, 1, S
        out = (attn @ v).squeeze(2)  # B, H, D
        return out.reshape(B, -1)  # B, C