lightrft.trainer.replay_buffer_utils¶
Utility functions for replay buffer operations in reinforcement learning.
This module provides specialized functions for handling both language model experiences and vision-language model experiences. It includes utilities for batch splitting, sequence padding, and experience creation optimized for distributed training.
Key features:
- Automatic detection of experience types
- Efficient batch splitting and creation
- Sequence padding and padding removal
- Support for both packed and unpacked samples
- class lightrft.trainer.replay_buffer_utils.BufferItem(sequences: torch.Tensor, action_log_probs: torch.Tensor, base_action_log_probs: torch.Tensor, values: torch.Tensor, returns: torch.Tensor, advantages: torch.Tensor, attention_mask: torch.LongTensor | None, action_mask: torch.BoolTensor | None, info: dict | None, action_entropy: torch.Tensor | None = None)[source]¶
Bases: object
BufferItem is an item of experience data.
Shapes of each tensor:
- sequences: (S)
- action_log_probs: (A)
- base_action_log_probs: (A)
- values: (1)
- returns: (1)
- advantages: (1)
- attention_mask: (S)
- action_mask: (A)
- action_entropy: (A), entropy values for high-entropy token filtering
“S” is the sequence length and “A” is the number of actions.
- action_entropy: torch.Tensor | None = None¶
- action_log_probs: torch.Tensor¶
- action_mask: torch.BoolTensor | None¶
- advantages: torch.Tensor¶
- attention_mask: torch.LongTensor | None¶
- base_action_log_probs: torch.Tensor¶
- info: dict | None¶
- returns: torch.Tensor¶
- sequences: torch.Tensor¶
- values: torch.Tensor¶
- class lightrft.trainer.replay_buffer_utils.BufferItemVL(sequences: torch.Tensor, pixel_values: torch.Tensor | None = None, image_grid_thws: torch.Tensor | None = None, pixel_values_videos: torch.Tensor | None = None, video_grid_thws: torch.Tensor | None = None, raw_images: List[Image] | None = None, action_log_probs: torch.Tensor = None, base_action_log_probs: torch.Tensor = None, values: torch.Tensor = None, returns: torch.Tensor = None, advantages: torch.Tensor = None, attention_mask: torch.LongTensor | None = None, action_mask: torch.BoolTensor | None = None, info: dict | None = None, action_entropy: torch.Tensor | None = None)[source]¶
Bases: object
BufferItemVL is an item of experience data.
Shapes of each tensor:
- sequences: (S)
- pixel_values: (B*H, W)
- image_grid_thws: (B, 3)
- raw_images: Optional[List[Image.Image]], raw images before processing
- action_log_probs: (A)
- base_action_log_probs: (A)
- values: (1)
- returns: (1)
- advantages: (1)
- attention_mask: (S)
- action_mask: (A)
- action_entropy: (A), entropy values for high-entropy token filtering
“S” is the sequence length and “A” is the number of actions.
- action_entropy: torch.Tensor | None = None¶
- action_log_probs: torch.Tensor = None¶
- action_mask: torch.BoolTensor | None = None¶
- advantages: torch.Tensor = None¶
- attention_mask: torch.LongTensor | None = None¶
- base_action_log_probs: torch.Tensor = None¶
- image_grid_thws: torch.Tensor | None = None¶
- info: dict | None = None¶
- pixel_values: torch.Tensor | None = None¶
- pixel_values_videos: torch.Tensor | None = None¶
- raw_images: List[Image] | None = None¶
- returns: torch.Tensor = None¶
- sequences: torch.Tensor¶
- values: torch.Tensor = None¶
- video_grid_thws: torch.Tensor | None = None¶
- lightrft.trainer.replay_buffer_utils.is_vl_experience(experience: Experience | ExperienceVL) → bool[source]¶
Determine if an experience is a vision-language experience.
Checks for the presence of vision-specific attributes to distinguish between language model experiences and vision-language experiences.
- Parameters:
experience (Union[Experience, ExperienceVL]) – The experience object to check
- Returns:
True if the experience contains vision data, False otherwise
- Return type:
bool
Example:
exp = ExperienceVL(...)
if is_vl_experience(exp):
    print("This is a vision-language experience")
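The detection described above can be sketched as a simple attribute check. This is a hypothetical reimplementation based on the documented behavior; the actual function may inspect a different set of attributes.

```python
from types import SimpleNamespace


def is_vl_experience_sketch(experience) -> bool:
    # Heuristic: treat the experience as vision-language when any
    # vision-specific attribute is present and non-None.
    vision_attrs = (
        "pixel_values",
        "image_grid_thws",
        "pixel_values_videos",
        "video_grid_thws",
    )
    return any(getattr(experience, attr, None) is not None for attr in vision_attrs)


# Stand-ins for Experience / ExperienceVL objects.
lm_exp = SimpleNamespace(sequences=[1, 2, 3])
vl_exp = SimpleNamespace(sequences=[1, 2, 3], pixel_values=[[0.1, 0.2]])
```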
- lightrft.trainer.replay_buffer_utils.make_experience_batch(items: List, packing_samples: bool = False) → Experience | ExperienceVL[source]¶
Create a batch experience from individual items.
This generic function automatically detects the item type and delegates to the appropriate batch creation function. It handles both packed and unpacked samples efficiently.
- Parameters:
items (List) – List of individual experience items to batch
packing_samples (bool) – Whether to pack samples without padding (True) or use padding (False)
- Returns:
Batched experience (either Experience or ExperienceVL)
- Return type:
Union[Experience, ExperienceVL]
- Raises:
ValueError – If items list is empty
Example:
# Create batch from items
items = [BufferItem(...), BufferItem(...)]
batch_exp = make_experience_batch(items, packing_samples=False)

# Create batch from vision-language items
vl_items = [BufferItemVL(...), BufferItemVL(...)]
batch_vl_exp = make_experience_batch(vl_items, packing_samples=True)
- lightrft.trainer.replay_buffer_utils.remove_padding_in_sequences(items: List) → List[source]¶
Remove padding from sequences in experience items.
This generic function automatically detects the item type and delegates to the appropriate padding removal function. It removes both left and right padding from sequences to restore their original lengths.
- Parameters:
items (List) – List of experience items with padded sequences
- Returns:
List of experience items with padding removed
- Return type:
List
Example:
# Remove padding from items
padded_items = [BufferItem(sequences=torch.tensor([0, 0, 1, 2, 3, 0, 0]), ...)]
clean_items = remove_padding_in_sequences(padded_items)
# Result: sequences become torch.tensor([1, 2, 3])

# Remove padding from vision-language items
padded_vl_items = [BufferItemVL(sequences=torch.tensor([0, 0, 4, 5, 6, 0]), ...)]
clean_vl_items = remove_padding_in_sequences(padded_vl_items)
# Result: sequences become torch.tensor([4, 5, 6])
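The core per-sequence operation can be sketched as below. Note that this simplified version keys off zero token ids, matching the example above, whereas the real implementation presumably uses the attention and action masks, since token id 0 can be a legitimate token.

```python
import torch


def strip_zero_padding(seq: torch.Tensor) -> torch.Tensor:
    # Keep only the span between the first and last non-zero entries,
    # removing both left and right padding (simplified sketch).
    nz = torch.nonzero(seq).flatten()
    if nz.numel() == 0:
        return seq[:0]  # all padding: return an empty tensor
    return seq[nz[0] : nz[-1] + 1]
```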
- lightrft.trainer.replay_buffer_utils.split_experience_batch(experience: Experience | ExperienceVL) → List[source]¶
Split a batch of experiences into individual items.
Automatically detects the experience type and delegates to the appropriate splitting function. This is a generic interface that handles both types of experiences.
- Parameters:
experience (Union[Experience, ExperienceVL]) – Batch experience to split into individual items
- Returns:
List of individual experience items
- Return type:
List
Example:
# Split a batch of experiences
batch_experience = make_experience_batch(items)
individual_items = split_experience_batch(batch_experience)

# Process each item individually
for item in individual_items:
    process_item(item)
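The splitting step can be sketched as the inverse of batching: each (B, ...) field is sliced row by row back into per-item tensors. The code below is a hypothetical illustration over two fields only; the real function handles every field of the experience.

```python
import torch
from types import SimpleNamespace


def split_batch_sketch(batch):
    # Slice each batched (B, ...) tensor field row-by-row into
    # per-item objects (hypothetical simplified sketch).
    batch_size = batch.sequences.size(0)
    return [
        SimpleNamespace(
            sequences=batch.sequences[i],
            action_log_probs=batch.action_log_probs[i],
        )
        for i in range(batch_size)
    ]


# Stand-in for a batched experience with two samples.
batch = SimpleNamespace(
    sequences=torch.tensor([[1, 2], [3, 4]]),
    action_log_probs=torch.tensor([[0.1], [0.2]]),
)
items = split_batch_sketch(batch)
```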
- lightrft.trainer.replay_buffer_utils.zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') → torch.Tensor[source]¶
Zero-pad a list of sequences to the same length.
This utility function pads sequences to the maximum length in the batch, either on the left or right side. It is used for creating batched tensors from variable-length sequences.
- Parameters:
sequences (List[torch.Tensor]) – List of sequences to pad (each sequence is a 1D tensor)
side (str) – Padding side, either “left” or “right”
- Returns:
Batched tensor of padded sequences
- Return type:
torch.Tensor
- Raises:
AssertionError – If side is not “left” or “right”
Example:
sequences = [
    torch.tensor([1, 2, 3]),
    torch.tensor([4, 5]),
    torch.tensor([6, 7, 8, 9]),
]

# Pad to the right
padded = zero_pad_sequences(sequences, side="right")
# Result: tensor([[1, 2, 3, 0],
#                 [4, 5, 0, 0],
#                 [6, 7, 8, 9]])

# Pad to the left
padded_left = zero_pad_sequences(sequences, side="left")
# Result: tensor([[0, 1, 2, 3],
#                 [0, 0, 4, 5],
#                 [6, 7, 8, 9]])
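For reference, an equivalent implementation fits in a few lines. This is a plausible reconstruction of the documented behavior using torch.nn.functional.pad, not necessarily the library's exact code.

```python
import torch
import torch.nn.functional as F


def zero_pad_sketch(sequences, side="left"):
    # Pad every 1-D sequence with zeros to the batch maximum length on
    # the chosen side, then stack into one (batch, max_len) tensor.
    assert side in ("left", "right")
    max_len = max(s.numel() for s in sequences)
    rows = []
    for s in sequences:
        n = max_len - s.numel()
        # F.pad takes (left_pad, right_pad) for a 1-D tensor.
        rows.append(F.pad(s, (n, 0) if side == "left" else (0, n)))
    return torch.stack(rows)


seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6, 7, 8, 9])]
```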