# Source code for lightrft.trainer.replay_buffer_vl
from typing import List
from abc import ABC
import random
import torch
from .experience_maker_vl import ExperienceVL
from .replay_buffer_utils import (
BufferItemVL, split_experience_batch, make_experience_batch, remove_padding_in_sequences
)
class NaiveReplayBufferVL(ABC):
    """
    Naive replay buffer class for Vision-Language models. It stores experience samples.

    :param sample_batch_size: Batch size when sampling (train_micro_batch_size).
    :type sample_batch_size: int
    :param limit: Limit of number of experience samples. A number <= 0 means unlimited, defaults to 0.
    :type limit: int
    :param cpu_offload: Whether to offload experience to CPU when sampling, defaults to True.
    :type cpu_offload: bool
    :param packing_samples: Whether to use packed samples format, defaults to False.
    :type packing_samples: bool
    """

    def __init__(
        self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True, packing_samples: bool = False
    ) -> None:
        super().__init__()
        self.sample_batch_size = sample_batch_size
        # Limit <= 0 means unlimited
        self.limit = limit
        self.cpu_offload = cpu_offload
        self.packing_samples = packing_samples
        # Querying torch.cuda.current_device() eagerly raises on CPU-only hosts;
        # fall back to CPU so the buffer can be constructed anywhere (behavior is
        # unchanged on machines where CUDA is available).
        if torch.cuda.is_available():
            self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}")
        else:
            self.target_device = torch.device("cpu")
        self.items: List[BufferItemVL] = []

    @torch.no_grad()
    def append(self, experience: ExperienceVL) -> None:
        """
        Append experience to the replay buffer.

        Splits the batched experience into per-sample items, optionally strips
        padding (non-packed format), and evicts the oldest items when the
        buffer exceeds ``limit``.

        :param experience: Experience batch to append.
        :type experience: ExperienceVL
        """
        if self.cpu_offload:
            experience.to_device(torch.device("cpu"))
        items = split_experience_batch(experience)
        # The packed samples come with no padding
        if not self.packing_samples:
            items = remove_padding_in_sequences(items)
        self.items.extend(items)
        if self.limit > 0:
            samples_to_remove = len(self.items) - self.limit
            if samples_to_remove > 0:
                # FIFO eviction: drop the oldest items first.
                self.items = self.items[samples_to_remove:]

    @torch.no_grad()
    def sample(self) -> ExperienceVL:
        """
        Sample a batch of experiences from the replay buffer.

        :return: Batch of sampled experiences.
        :rtype: ExperienceVL
        :raises ValueError: If the buffer holds fewer than ``sample_batch_size``
            items (propagated from :func:`random.sample`).
        """
        items = random.sample(self.items, self.sample_batch_size)
        experience = make_experience_batch(items, self.packing_samples)
        if self.cpu_offload:
            # Items were stored on CPU; move the batch back to the training device.
            experience.to_device(self.target_device)
        return experience

    def __len__(self) -> int:
        """
        Get the number of items in the replay buffer.

        :return: Number of items.
        :rtype: int
        """
        return len(self.items)

    def __getitem__(self, idx: int) -> BufferItemVL:
        """
        Get an item from the replay buffer by index.

        :param idx: Index of the item.
        :type idx: int
        :return: Buffer item at the specified index.
        :rtype: BufferItemVL
        """
        return self.items[idx]

    def collate_fn(self, batch) -> ExperienceVL:
        """
        Collate function for DataLoader.

        :param batch: Batch of buffer items.
        :type batch: List[BufferItemVL]
        :return: Batched experience.
        :rtype: ExperienceVL
        """
        experience = make_experience_batch(batch, self.packing_samples)
        return experience

    def normalize(self, attribute: str, strategy) -> None:
        """
        Normalize a specified attribute across all items in the buffer.

        This method computes the global (all-reduced across data-parallel ranks)
        mean and standard deviation of the specified attribute and rescales each
        item in place. Currently only supports "advantages".

        :param attribute: Name of the attribute to normalize (currently only "advantages" is supported).
        :type attribute: str
        :param strategy: Distributed training strategy for all_reduce operations.
        :type strategy: Strategy
        """
        assert attribute == "advantages"
        items = []
        action_masks = []
        for item in self:
            items.append(getattr(item, attribute))
            action_masks.append(item.action_mask)
        items_vector = torch.cat(items).float().flatten()
        if action_masks[0] is None:
            # Packing samples has no action mask
            action_masks_vector = 1
            num_actions = items_vector.numel()
        else:
            action_masks_vector = torch.cat(action_masks).flatten()
            num_actions = action_masks_vector.sum()
        # For distributed data parallel: compute mean
        # NOTE(review): the sum includes masked positions — assumes masked
        # advantages are already zeroed upstream; verify against experience maker.
        sum_and_count = torch.tensor([items_vector.sum(), num_actions], device=items_vector.device)
        all_sum, all_count = strategy.all_reduce(sum_and_count, "sum")
        mean = all_sum / all_count
        # Compute standard deviation (masked positions contribute zero)
        std = ((items_vector - mean).pow(2) * action_masks_vector).sum()
        all_std = strategy.all_reduce(std, "sum")
        # Reciprocal std, clamped to avoid division blow-up for near-constant values.
        rstd = (all_std / all_count).clamp(min=1e-8).rsqrt()
        for i, item in enumerate(self):
            setattr(item, attribute, (items[i] - mean) * rstd)