Source code for lightrft.trainer.replay_buffer_utils
"""
Utility functions for replay buffer operations in reinforcement learning.
This module provides specialized functions for handling both language model
experiences and vision-language model experiences. It includes utilities for batch
splitting, sequence padding, and experience creation optimized for distributed training.
Key features:
- Automatic detection of experience types
- Efficient batch splitting and creation
- Sequence padding and padding removal
- Support for both packed and unpacked samples
"""
from typing import List, Optional, Union
from dataclasses import dataclass
import torch
import torch.nn.functional as F
from PIL import Image
from .experience_maker import Experience
from .experience_maker_vl import ExperienceVL
@dataclass
class BufferItem:
    """One un-batched language-model experience sample held in the replay buffer.

    Shapes of each tensor (as documented by the original author):

    sequences: (S)
    action_log_probs: (A)
    base_action_log_probs: (A)
    values: (1)
    returns: (1)
    advantages: (1)
    attention_mask: (S)
    action_mask: (A)
    action_entropy: (A) - Entropy values for high-entropy token filtering

    "A" is the number of actions. "S" is presumably the sequence length
    (the original notes do not define it).

    ``base_action_log_probs`` and ``values`` are annotated as optional because the
    helpers in this module (batch splitting and padding removal) explicitly
    tolerate ``None`` for them.
    """

    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    # May be None; the padding-removal helper guards for that case.
    base_action_log_probs: Optional[torch.Tensor]
    # May be None; the padding-removal helper guards for that case.
    values: Optional[torch.Tensor]
    returns: torch.Tensor
    advantages: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]
    # Free-form per-sample metadata (scalars, strings, tensors, ...).
    info: Optional[dict]
    # Entropy for high-entropy token filtering; the only field with a default.
    action_entropy: Optional[torch.Tensor] = None
@dataclass
class BufferItemVL:
    """One un-batched vision-language experience sample held in the replay buffer.

    Shapes of each tensor (as documented by the original author):

    sequences: (S)
    pixel_values: (B*H, W)
    image_grid_thws: (B, 3)
    action_log_probs: (A)
    base_action_log_probs: (A)
    values: (1)
    returns: (1)
    advantages: (1)
    attention_mask: (S)
    action_mask: (A)
    action_entropy: (A) - Entropy values for high-entropy token filtering

    "A" is the number of actions. "S" is presumably the sequence length
    (the original notes do not define it).

    ``pixel_values_videos`` / ``video_grid_thws`` are the video counterparts of
    ``pixel_values`` / ``image_grid_thws``; their shapes are not documented in the
    original notes (presumably analogous to the image fields - TODO confirm).
    ``raw_images`` is a plain list of PIL images, not a tensor.

    Every field except ``sequences`` defaults to ``None``, so all of them are
    annotated as optional.
    """

    sequences: torch.Tensor
    pixel_values: Optional[torch.Tensor] = None  # image pixel processed by HF processor
    image_grid_thws: Optional[torch.Tensor] = None  # image grid thw
    pixel_values_videos: Optional[torch.Tensor] = None  # video pixel processed by HF processor
    video_grid_thws: Optional[torch.Tensor] = None  # video grid thw
    raw_images: Optional[List[Image.Image]] = None  # raw images before processing
    action_log_probs: Optional[torch.Tensor] = None
    base_action_log_probs: Optional[torch.Tensor] = None
    values: Optional[torch.Tensor] = None
    returns: Optional[torch.Tensor] = None
    advantages: Optional[torch.Tensor] = None
    attention_mask: Optional[torch.LongTensor] = None
    action_mask: Optional[torch.BoolTensor] = None
    info: Optional[dict] = None  # free-form per-sample metadata
    action_entropy: Optional[torch.Tensor] = None  # Entropy for high-entropy token filtering
def is_vl_experience(experience: Union[Experience, ExperienceVL]) -> bool:
    """
    Tell whether an experience object carries vision data.

    The check is purely structural: an object counts as a vision-language
    experience as soon as it exposes a ``pixel_values`` attribute, whatever
    the value of that attribute is (including ``None``).

    :param experience: The experience object to inspect
    :type experience: Union[Experience, ExperienceVL]
    :return: True if the object has a ``pixel_values`` attribute, False otherwise
    :rtype: bool

    Example::

        exp = ExperienceVL(...)
        if is_vl_experience(exp):
            print("This is a vision-language experience")
    """
    try:
        experience.pixel_values
    except AttributeError:
        return False
    return True
def split_experience_batch(experience: Union[Experience, ExperienceVL]) -> List:
    """
    Break a batched experience apart into one item per sample.

    This is the generic entry point: it asks :func:`is_vl_experience` which
    flavour of experience it received and forwards the batch to the matching
    private splitter.

    :param experience: Batch experience to split into individual items
    :type experience: Union[Experience, ExperienceVL]
    :return: List of individual experience items
    :rtype: List

    Example::

        # Split a batch of experiences
        batch_experience = make_experience_batch(items)
        individual_items = split_experience_batch(batch_experience)

        # Process each item individually
        for item in individual_items:
            process_item(item)
    """
    splitter = _split_experience_batch_vl if is_vl_experience(experience) else _split_experience_batch
    return splitter(experience)
def _split_experience_batch(experience: Experience) -> List:
    """
    Turn a batched text-only experience into a list of ``BufferItem`` objects.

    Every known per-sample field is read from ``experience`` (missing or ``None``
    fields become ``None`` on every item). Tensors are unbound along their first
    dimension; any other container is indexed element by element. The ``info``
    dictionary is distributed as follows:

    * tensor values are unbound, and single-element rows are converted to
      Python scalars with ``.item()``;
    * lists whose length equals the batch size are handed out one entry per item;
    * anything else is copied unchanged to every item.

    :param experience: Batch of experiences to split
    :type experience: Experience
    :return: List of individual BufferItem objects
    :rtype: List
    :raises AssertionError: If a field's length differs from the batch size

    Example::

        batch_exp = Experience(
            sequences=torch.tensor([[1,2,3],[4,5,6]]),
            action_log_probs=torch.tensor([[0.1,0.2],[0.3,0.4]]),
            # ... other attributes
        )
        items = _split_experience_batch(batch_exp)
        print(f"Split {len(items)} items from batch")
    """
    batch_size = len(experience.sequences)
    per_item = [{} for _ in range(batch_size)]

    tensor_fields = (
        "sequences",
        "action_log_probs",
        "base_action_log_probs",
        "values",
        "returns",
        "advantages",
        "attention_mask",
        "action_mask",
        "action_entropy",
    )
    for field in tensor_fields:
        # getattr default covers optional attributes such as action_entropy
        batched = getattr(experience, field, None)
        if batched is None:
            pieces = [None] * batch_size
        else:
            pieces = torch.unbind(batched) if isinstance(batched, torch.Tensor) else batched
            assert batch_size == len(pieces)
        for idx, piece in enumerate(pieces):
            per_item[idx][field] = piece

    for slot in per_item:
        slot["info"] = {}

    # info may hold tensors, per-sample lists, or values shared by the whole batch
    if experience.info:
        for name, batched in experience.info.items():
            if isinstance(batched, torch.Tensor):
                rows = torch.unbind(batched)
                assert batch_size == len(rows)
                # single-element rows are stored as plain Python scalars
                per_sample = [row.item() if row.numel() == 1 else row for row in rows]
            elif isinstance(batched, list) and len(batched) == batch_size:
                per_sample = batched
            else:
                # not splittable per sample: every item receives the same object
                per_sample = [batched] * batch_size
            for idx, entry in enumerate(per_sample):
                per_item[idx]["info"][name] = entry

    return [BufferItem(**item_kwargs) for item_kwargs in per_item]
def _split_experience_batch_vl(experience: ExperienceVL) -> List:
    """
    Split a batch of vision-language experiences into individual items.

    Unlike the text fields, the vision components (images/videos) of a batch are
    stored flattened along their first dimension. This function uses the metadata
    in ``experience.info`` (``image_num`` and ``video_num``) to slice those
    flattened tensors back into per-sample pieces.

    Splitting logic:

    1. Standard tensors (``sequences``, ``values``, ...) are split with ``torch.unbind``.
    2. Grid metadata: ``image_grid_thws`` / ``video_grid_thws`` of shape (N, 3) are sliced
       according to ``experience.info["image_num"]`` / ``["video_num"]``. For example, if
       ``image_num`` is [2, 1], the first sample gets the first 2 grid rows. Counts are
       converted to ``int`` before slicing because they are usually stored as a float
       tensor (and come back as Python floats from ``tolist()``). A grid that is already
       a Python list is treated as pre-split (one entry per sample). Without counts, a
       tensor whose first dimension equals the batch size is split one row per sample.
    3. Pixel values: ``pixel_values`` / ``pixel_values_videos`` are sliced along dim 0;
       each sample takes ``sum(prod(grid_row))`` rows computed from its own grid slice.
       Samples without a grid receive an empty slice.
    4. ``raw_images`` is indexed per sample, and ``info`` is distributed exactly as in the
       text-only splitter (tensors unbound, batch-sized lists distributed, everything
       else broadcast).

    :param experience: Batch of vision-language experiences to split
    :type experience: ExperienceVL
    :return: List of individual BufferItemVL objects
    :rtype: List
    :raises AssertionError: If a standard field's length differs from the batch size
    :raises ValueError: If a flattened grid cannot be attributed to samples (counts are
        missing and the grid is not one row per sample), or if the per-sample count
        list does not have one entry per sample

    Example::

        # Multi-image scenario: Batch size 2
        # Sample 0 has 2 images, Sample 1 has 1 image.
        # Total 3 images in image_grid_thws
        batch_exp = ExperienceVL(
            sequences=torch.zeros(2, 10),
            image_grid_thws=torch.tensor([[1, 10, 10], [1, 20, 20], [1, 15, 15]]),
            pixel_values=torch.randn(100+400+225, 1152),  # flattened patches
            info={
                "image_num": torch.tensor([2, 1], dtype=torch.float32)
            }
        )
        items = _split_experience_batch_vl(batch_exp)
        # items[0].image_grid_thws: Shape [2, 3] (First two rows)
        # items[0].pixel_values: Shape [10*10 + 20*20, 1152]
        # items[1].image_grid_thws: Shape [1, 3] (Last row)
        # items[1].pixel_values: Shape [15*15, 1152]
    """
    batch_size = len(experience.sequences)
    batch_kwargs = [{} for _ in range(batch_size)]

    # First, split standard tensors that always match batch_size
    keys = (
        "sequences",
        "action_log_probs",
        "base_action_log_probs",
        "values",
        "returns",
        "advantages",
        "attention_mask",
        "action_mask",
        "action_entropy",
    )
    for key in keys:
        # getattr default covers optional attributes such as action_entropy
        value = getattr(experience, key, None)
        if value is None:
            for kwargs in batch_kwargs:
                kwargs[key] = None
            continue
        vals = torch.unbind(value) if isinstance(value, torch.Tensor) else value
        assert batch_size == len(vals), f"Key {key} size mismatch: {len(vals)} vs {batch_size}"
        for i, v in enumerate(vals):
            batch_kwargs[i][key] = v

    # Split image_grid_thws and video_grid_thws using the per-sample counts in info
    for grid_key, num_key in (("image_grid_thws", "image_num"), ("video_grid_thws", "video_num")):
        grid_data = getattr(experience, grid_key, None)
        if grid_data is None:
            for kwargs in batch_kwargs:
                kwargs[grid_key] = None
            continue

        # A list is treated as already split: one entry per sample
        if isinstance(grid_data, list):
            for i in range(batch_size):
                batch_kwargs[i][grid_key] = grid_data[i]
            continue

        nums = experience.info.get(num_key) if experience.info else None
        if nums is None:
            # Fallback for the simple case: exactly one grid row per sample
            if isinstance(grid_data, torch.Tensor) and grid_data.size(0) == batch_size:
                for i, v in enumerate(torch.unbind(grid_data)):
                    batch_kwargs[i][grid_key] = v
                continue
            raise ValueError(
                f"Ambiguous {grid_key} split: Total {grid_data.size(0)} vs Batch {batch_size}. "
                f"Missing '{num_key}' in info."
            )

        if isinstance(nums, torch.Tensor):
            nums = nums.tolist()
        if len(nums) != batch_size:
            raise ValueError(
                f"'{num_key}' has {len(nums)} entries but the batch holds {batch_size} samples."
            )
        curr_idx = 0
        for i, n in enumerate(nums):
            # Counts are typically floats (float tensor -> tolist()); slicing needs ints
            n = int(n)
            if n > 0:
                batch_kwargs[i][grid_key] = grid_data[curr_idx:curr_idx + n]
                curr_idx += n
            else:
                batch_kwargs[i][grid_key] = None

    # Split flattened image / video pixels; each sample owns sum(prod(grid_row)) rows
    for pixel_key, grid_key in (
        ("pixel_values", "image_grid_thws"),
        ("pixel_values_videos", "video_grid_thws"),
    ):
        pixels = getattr(experience, pixel_key)
        if not isinstance(pixels, torch.Tensor):
            # None (or a non-tensor container) is left at the item default
            continue
        index = 0
        for kwargs in batch_kwargs:
            grid = kwargs[grid_key]
            if grid is not None:
                num_tokens = int(torch.sum(torch.prod(grid, dim=-1)).item())
            else:
                num_tokens = 0
            kwargs[pixel_key] = pixels[index:index + num_tokens]
            index += num_tokens

    # Split raw images
    if experience.raw_images is not None:
        for i, kwargs in enumerate(batch_kwargs):
            kwargs["raw_images"] = experience.raw_images[i]

    for kwargs in batch_kwargs:
        kwargs["info"] = {}

    # info may hold tensors, per-sample lists, or values shared by the whole batch
    if experience.info:
        for k, v_batch in experience.info.items():
            if isinstance(v_batch, torch.Tensor):
                vals = torch.unbind(v_batch)
                assert batch_size == len(vals)
                for i, vv in enumerate(vals):
                    # single-element rows are stored as plain Python scalars
                    batch_kwargs[i]["info"][k] = vv.item() if vv.numel() == 1 else vv
            elif isinstance(v_batch, list) and len(v_batch) == batch_size:
                # a list (e.g. of strings or dicts) is distributed one entry per sample
                for i in range(batch_size):
                    batch_kwargs[i]["info"][k] = v_batch[i]
            else:
                # anything else is broadcast unchanged to every sample
                for i in range(batch_size):
                    batch_kwargs[i]["info"][k] = v_batch

    return [BufferItemVL(**kwargs) for kwargs in batch_kwargs]
def zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> torch.Tensor:
    """
    Stack variable-length sequences into one tensor, filling the gaps with zeros.

    Each sequence is extended with zeros (on the chosen side) up to the length of
    the longest sequence in the list, and the results are stacked along a new
    leading dimension.

    :param sequences: List of sequences to pad (each sequence is a 1D tensor)
    :type sequences: List[torch.Tensor]
    :param side: Padding side, either "left" or "right"
    :type side: str
    :return: Batched tensor of padded sequences
    :rtype: torch.Tensor
    :raises AssertionError: If side is not "left" or "right"

    Example::

        sequences = [
            torch.tensor([1, 2, 3]),
            torch.tensor([4, 5]),
            torch.tensor([6, 7, 8, 9])
        ]

        # Pad to the right
        padded = zero_pad_sequences(sequences, side="right")
        # Result: tensor([[1, 2, 3, 0],
        #                 [4, 5, 0, 0],
        #                 [6, 7, 8, 9]])

        # Pad to the left
        padded_left = zero_pad_sequences(sequences, side="left")
        # Result: tensor([[0, 1, 2, 3],
        #                 [0, 0, 4, 5],
        #                 [6, 7, 8, 9]])
    """
    assert side in ("left", "right")
    target_len = max(seq.size(0) for seq in sequences)
    pad_on_left = side == "left"

    def _pad(seq: torch.Tensor) -> torch.Tensor:
        missing = target_len - seq.size(0)
        # F.pad takes (amount_before, amount_after) for the last dimension
        return F.pad(seq, (missing, 0) if pad_on_left else (0, missing))

    return torch.stack([_pad(seq) for seq in sequences], dim=0)
def make_experience_batch(items: List, packing_samples: bool = False) -> Union[Experience, ExperienceVL]:
    """
    Assemble individual buffer items into one batched experience.

    The first item decides which builder is used: items exposing a
    ``pixel_values`` attribute go through the vision-language builder, all
    others through the text-only builder. ``packing_samples`` is forwarded
    unchanged.

    :param items: List of individual experience items to batch
    :type items: List
    :param packing_samples: Whether to pack samples without padding (True) or use padding (False)
    :type packing_samples: bool
    :return: Batched experience (either Experience or ExperienceVL)
    :rtype: Union[Experience, ExperienceVL]
    :raises ValueError: If items list is empty

    Example::

        # Create batch from items
        items = [BufferItem(...), BufferItem(...)]
        batch_exp = make_experience_batch(items, packing_samples=False)

        # Create batch from vision-language items
        vl_items = [BufferItemVL(...), BufferItemVL(...)]
        batch_vl_exp = make_experience_batch(vl_items, packing_samples=True)
    """
    if not items:
        raise ValueError("items list cannot be empty")

    # The first item is taken as representative of the whole list
    builder = _make_experience_batch_vl if hasattr(items[0], 'pixel_values') else _make_experience_batch
    return builder(items, packing_samples)
def _make_experience_batch(items: List, packing_samples: bool = False) -> Experience:
    """
    Build a batched text-only ``Experience`` from individual items.

    For each standard field the first item decides whether the field is present
    at all (``None`` on the first item yields ``None`` for the batch). Present
    fields are either kept as a plain list of per-sample tensors
    (``packing_samples=True``) or left-padded with zeros and stacked
    (``packing_samples=False``).

    ``action_entropy`` is batched only when every item provides it: padded batches
    are left-padded; packed batches are stacked when all lengths agree and
    left-padded otherwise.

    ``info`` entries follow the keys of the first item: numeric values
    (int/float/bool) are converted to a tensor when possible, everything else is
    kept as a Python list.

    :param items: List of individual experience items
    :type items: List
    :param packing_samples: Whether to pack samples without padding (True) or use padding (False)
    :type packing_samples: bool
    :return: Batched experience
    :rtype: Experience

    Example::

        items = [
            BufferItem(sequences=torch.tensor([1,2,3]), ...),
            BufferItem(sequences=torch.tensor([4,5,6,7]), ...)
        ]

        # With padding (packing_samples=False)
        batch_exp = _make_experience_batch(items, packing_samples=False)
        # sequences will be padded to length 4

        # Without padding (packing_samples=True)
        packed_batch = _make_experience_batch(items, packing_samples=True)
        # sequences will remain as a list of variable-length tensors
    """
    kwargs = {}

    standard_fields = (
        "sequences",
        "action_log_probs",
        "base_action_log_probs",
        "values",
        "returns",
        "advantages",
        "attention_mask",
        "action_mask",
    )
    for field in standard_fields:
        column = [getattr(item, field) for item in items]
        if column[0] is None:
            kwargs[field] = None
        elif packing_samples:
            kwargs[field] = column
        else:
            kwargs[field] = zero_pad_sequences(column, "left")

    # action_entropy has shape (A,) per item and is only batched when all items have it
    entropy_batch = None
    if items and getattr(items[0], 'action_entropy', None) is not None:
        entropy_column = [getattr(item, 'action_entropy', None) for item in items]
        if not any(entry is None for entry in entropy_column):
            if packing_samples and len({len(entry) for entry in entropy_column}) == 1:
                # packed batch with uniform lengths: a plain stack is enough
                entropy_batch = torch.stack(entropy_column, dim=0)
            else:
                # padded batch, or packed batch with differing lengths
                entropy_batch = zero_pad_sequences(entropy_column, "left")
    kwargs["action_entropy"] = entropy_batch

    batched_info = {}
    if items and items[0].info:
        for name in items[0].info.keys():
            column = [item.info[name] for item in items]
            merged = column
            # Only numeric leading values are worth trying to tensorize
            if isinstance(column[0], (int, float, bool)):
                try:
                    merged = torch.tensor(column)
                except (TypeError, ValueError):
                    # mixed types: keep the plain Python list
                    merged = column
            batched_info[name] = merged
    kwargs["info"] = batched_info

    return Experience(**kwargs)
def _make_experience_batch_vl(items: List, packing_samples: bool = False) -> ExperienceVL:
    """
    Create a batch of vision-language experiences from individual items.

    This function aggregates individual ``BufferItemVL`` objects into a single
    ``ExperienceVL`` batch:

    * Standard fields are left-padded and stacked (``packing_samples=False``) or kept
      as a list of per-sample tensors (``packing_samples=True``); a field that is
      ``None`` on the first item is ``None`` for the whole batch.
    * ``action_entropy`` is batched only when every item provides it.
    * Visual data (pixels and grids) is concatenated along dim 0 into flattened
      tensors; empty / missing pixel tensors are skipped and 1-D grids are promoted
      to a single row.
    * The number of grid rows contributed by each sample is recorded in
      ``info["image_num"]`` and ``info["video_num"]`` (float32 tensors, one entry per
      item, 0 for items without a grid) so that the batch can later be split back
      into items. These computed counts take precedence over same-named entries
      carried in the items' own ``info``.
    * Remaining ``info`` entries follow the keys of the first item: numeric values are
      tensorized when possible, everything else is kept as a Python list.

    :param items: List of individual vision-language experience items
    :type items: List
    :param packing_samples: Whether to pack samples without padding (True) or use padding (False)
    :type packing_samples: bool
    :return: Batched vision-language experience
    :rtype: ExperienceVL

    Example::

        # Create a batch from two items
        item1 = BufferItemVL(
            sequences=torch.zeros(5),
            image_grid_thws=torch.tensor([[1, 5, 5], [1, 8, 8]]),  # 2 images
            pixel_values=torch.randn(25+64, 1152)
        )
        item2 = BufferItemVL(
            sequences=torch.zeros(8),
            image_grid_thws=torch.tensor([[1, 10, 10]]),  # 1 image
            pixel_values=torch.randn(100, 1152)
        )
        batch = _make_experience_batch_vl([item1, item2])
        # batch.image_grid_thws: Shape [3, 3] (2 + 1 rows concatenated)
        # batch.info["image_num"]: tensor([2., 1.], dtype=float32)
    """
    kwargs = {}
    keys = (
        "sequences",
        "action_log_probs",
        "base_action_log_probs",
        "values",
        "returns",
        "advantages",
        "attention_mask",
        "action_mask",
    )
    for key in keys:
        vals = [getattr(item, key) for item in items]
        if vals[0] is None:
            batch_data = None
        elif packing_samples:
            batch_data = vals
        else:
            batch_data = zero_pad_sequences(vals, "left")
        kwargs[key] = batch_data

    # action_entropy has shape (A,) per item and is only batched when all items have it
    kwargs["action_entropy"] = None
    if items and getattr(items[0], 'action_entropy', None) is not None:
        entropy_vals = [getattr(item, 'action_entropy', None) for item in items]
        if all(v is not None for v in entropy_vals):
            if packing_samples and len({len(v) for v in entropy_vals}) == 1:
                # packed batch with uniform lengths: a plain stack is enough
                kwargs["action_entropy"] = torch.stack(entropy_vals, dim=0)
            else:
                # padded batch, or packed batch with differing lengths
                kwargs["action_entropy"] = zero_pad_sequences(entropy_vals, "left")

    # Flattened pixel data: concatenate every non-empty per-item tensor
    for pixel_key in ("pixel_values", "pixel_values_videos"):
        chunks = [
            pixels
            for pixels in (getattr(item, pixel_key) for item in items)
            if pixels is not None and pixels.numel() > 0
        ]
        kwargs[pixel_key] = torch.cat(chunks, dim=0) if chunks else None

    # Grid metadata: concatenate rows and remember how many rows each item contributed
    component_counts = {}
    for grid_key, num_key in (("image_grid_thws", "image_num"), ("video_grid_thws", "video_num")):
        grids = []
        nums = []
        for item in items:
            grid = getattr(item, grid_key)
            if grid is None:
                nums.append(0)
                continue
            if grid.dim() == 1:
                # a single (t, h, w) triple: promote it to one row
                grid = grid.unsqueeze(0)
            grids.append(grid)
            nums.append(grid.size(0) if grid.dim() > 1 else 1)
        kwargs[grid_key] = torch.cat(grids, dim=0) if grids else None
        component_counts[num_key] = torch.tensor(nums, dtype=torch.float32)

    raw_images_list = [item.raw_images for item in items]
    kwargs["raw_images"] = raw_images_list if raw_images_list and raw_images_list[0] is not None else None

    kwargs["info"] = {}
    if items and items[0].info:
        for key in items[0].info.keys():
            vals = [item.info[key] for item in items]
            merged = vals
            # Only numeric leading values are worth trying to tensorize
            if isinstance(vals[0], (int, float, bool)):
                try:
                    merged = torch.tensor(vals)
                except (TypeError, ValueError):
                    # mixed types: keep the plain Python list
                    merged = vals
            kwargs["info"][key] = merged

    # Record the per-sample component counts last so the freshly computed values win
    # over any stale "image_num"/"video_num" entries carried in the items' info.
    kwargs["info"].update(component_counts)

    return ExperienceVL(**kwargs)
def remove_padding_in_sequences(items: List) -> List:
    """
    Strip padding from the sequences of a list of experience items.

    Generic entry point: an empty input is returned unchanged; otherwise the
    first item decides whether the vision-language or the text-only padding
    remover is used (items exposing ``pixel_values`` are treated as
    vision-language items).

    :param items: List of experience items with padded sequences
    :type items: List
    :return: List of experience items with padding removed
    :rtype: List

    Example::

        # Remove padding from items
        padded_items = [BufferItem(sequences=torch.tensor([0,0,1,2,3,0,0]), ...)]
        clean_items = remove_padding_in_sequences(padded_items)

        # Remove padding from vision-language items
        padded_vl_items = [BufferItemVL(sequences=torch.tensor([0,0,4,5,6,0]), ...)]
        clean_vl_items = remove_padding_in_sequences(padded_vl_items)
    """
    if not items:
        return items

    # The first item is taken as representative of the whole list
    strip = _remove_padding_in_sequences_vl if hasattr(items[0], 'pixel_values') else _remove_padding_in_sequences
    return strip(items)
def _remove_padding_in_sequences(items: List) -> List:
"""
Remove padding from sequences in experience items.
This function processes experience items and removes both left and right
padding from sequences and related tensors. It uses attention masks and action masks
to determine the original sequence boundaries.
:param items: List of experience items with padded sequences
:type items: List
:return: List of experience items with padding removed
:rtype: List
Example::
# Item with left and right padding
item = BufferItem(
sequences=torch.tensor([0, 0, 1, 2, 3, 0, 0]), # padded sequence
attention_mask=torch.tensor([0, 0, 1, 1, 1, 0, 0]),
action_mask=torch.tensor([0, 0, 1, 1, 1, 0, 0]),
# ... other attributes
)
# Remove padding
clean_item = _remove_padding_in_sequences([item])[0]
# Result: sequences become torch.tensor([1, 2, 3])
"""
for item in items:
seq, act_log_prob, base_act_log_prob, value, ret, adv, att_mask, act_mask = (
item.sequences,
item.action_log_probs,
item.base_action_log_probs,
item.values,
item.returns,
item.advantages,
item.attention_mask,
item.action_mask,
)
# Get action_entropy if it exists
action_entropy = getattr(item, 'action_entropy', None)
right_pad = (1 - act_mask.long()).sum()
right_pad = None if right_pad == 0 else -right_pad
# left_pad for seq and att_mask
left_pad = att_mask.long().argmax()
(
item.sequences,
item.action_log_probs,
item.base_action_log_probs,
item.values,
item.returns,
item.advantages,
item.attention_mask,
item.action_mask,
) = (
seq[left_pad:right_pad],
act_log_prob[:right_pad],
base_act_log_prob[:right_pad] if item.base_action_log_probs is not None else None,
value[:right_pad] if item.values is not None else None,
ret[:right_pad],
adv[:right_pad],
att_mask[left_pad:right_pad],
act_mask[:right_pad],
)
# Remove padding from action_entropy if it exists
if action_entropy is not None:
item.action_entropy = action_entropy[:right_pad]
return items
def _remove_padding_in_sequences_vl(items: List) -> List:
"""
Remove padding from sequences in vision-language experience items.
This function processes vision-language experience items and removes both left
and right padding from sequences and related tensors. The vision data (pixel values,
image grids, etc.) remains unchanged as they don't require padding removal.
:param items: List of vision-language experience items with padded sequences
:type items: List
:return: List of vision-language experience items with padding removed
:rtype: List
Example::
# Vision-language item with padding
item = BufferItemVL(
sequences=torch.tensor([0, 0, 4, 5, 6, 0]), # padded sequence
attention_mask=torch.tensor([0, 0, 1, 1, 1, 0]),
action_mask=torch.tensor([0, 0, 1, 1, 1, 0]),
pixel_values=torch.randn(1, 3, 224, 224), # vision data unchanged
# ... other attributes
)
# Remove padding
clean_item = _remove_padding_in_sequences_vl([item])[0]
# Result: sequences become torch.tensor([4, 5, 6])
"""
for item in items:
seq, act_log_prob, base_act_log_prob, value, ret, adv, att_mask, act_mask = (
item.sequences,
item.action_log_probs,
item.base_action_log_probs,
item.values,
item.returns,
item.advantages,
item.attention_mask,
item.action_mask,
)
# Get action_entropy if it exists
action_entropy = getattr(item, 'action_entropy', None)
right_pad = (1 - act_mask.long()).sum()
right_pad = None if right_pad == 0 else -right_pad
# left_pad for seq and att_mask
left_pad = att_mask.long().argmax()
(
item.sequences,
item.action_log_probs,
item.base_action_log_probs,
item.values,
item.returns,
item.advantages,
item.attention_mask,
item.action_mask,
) = (
seq[left_pad:right_pad],
act_log_prob[:right_pad],
base_act_log_prob[:right_pad] if item.base_action_log_probs is not None else None,
value[:right_pad] if item.values is not None else None,
ret[:right_pad],
adv[:right_pad],
att_mask[left_pad:right_pad],
act_mask[:right_pad],
)
# Remove padding from action_entropy if it exists
if action_entropy is not None:
item.action_entropy = action_entropy[:right_pad]
return items