Shortcuts

Source code for lightrft.models.actor_vl

"""
Vision-Language Actor Model Module for Reinforcement Learning.

This module provides the ActorVL class, which implements an actor model specifically designed
for vision-language tasks in reinforcement learning scenarios. The actor is responsible for
generating actions (text sequences) based on visual inputs (images and videos) and textual prompts.

The module supports various optimization techniques including:
- LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning
- Flash Attention 2.0 for improved performance
- DeepSpeed integration for distributed training
- Sample packing for efficient batch processing

Key Features:
- Multi-modal input processing (text + vision)
- Flexible model loading from pretrained checkpoints
- Support for various vision-language model architectures
- Gradient checkpointing for memory optimization
- MoE (Mixture of Experts) model support
"""

from typing import Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from transformers import AutoModel, AutoModelForVision2Seq
from transformers.integrations.deepspeed import HfDeepSpeedConfig

from .utils import apply_lora_configuration, log_probs_from_logits, reset_position_ids, entropy_from_logits
from .actor_modality import ActorModality


[docs]class ActorVL(nn.Module): """ Vision-Language Actor model for reinforcement learning applications. This class serves as a foundation for implementing vision-language actor models in RL, which are responsible for generating text sequences (actions) based on visual (images and videos) and textual inputs. The model supports various optimization techniques including LoRA adaptation, quantization, and distributed training. The actor model can be initialized either from a pretrained model path or from an existing model instance, providing flexibility in model deployment scenarios. :param pretrain_or_model: Either a string path to a pretrained model or a model instance :type pretrain_or_model: Union[str, nn.Module] :param use_flash_attention_2: Whether to utilize Flash Attention 2.0 for improved performance :type use_flash_attention_2: bool :param bf16: Enable bfloat16 precision for model computations :type bf16: bool :param lora_rank: Rank for LoRA adaptation (0 disables LoRA) :type lora_rank: int :param lora_alpha: Alpha parameter for LoRA scaling :type lora_alpha: int :param lora_dropout: Dropout rate for LoRA layers :type lora_dropout: float :param target_modules: List of target modules for applying LoRA (auto-detected if None) :type target_modules: Optional[list] :param ds_config: Configuration for DeepSpeed distributed training :type ds_config: Optional[dict] :param device_map: Device mapping for loading the model onto specific devices :type device_map: Optional[dict] :param packing_samples: Whether to pack samples during training for efficiency :type packing_samples: bool Example:: # Initialize with a pretrained model path actor = ActorVL( pretrain_or_model="microsoft/LLaVA-1.5-7b-hf", use_flash_attention_2=True, lora_rank=16, lora_alpha=32 ) # Generate responses sequences, attention_mask, action_mask = actor.generate( input_ids=input_tensor, pixel_values=image_tensor, image_grid_thw=grid_tensor, max_new_tokens=100 ) """ # Model modality declaration - defines what types of inputs this model accepts modality = ActorModality.VISION_LANGUAGE def __init__( self, pretrain_or_model, use_flash_attention_2=False, bf16=True, lora_rank=0, lora_alpha=16, lora_dropout=0, target_modules=None, ds_config=None, device_map=None, packing_samples=False, high_entropy_token_ratio=0.0, **kwargs, ) -> None: super().__init__() self.high_entropy_token_ratio = high_entropy_token_ratio if isinstance(pretrain_or_model, str): self.pretrain_or_model = pretrain_or_model attn_implementation = "flash_attention_2" if use_flash_attention_2 else "eager" # Note: dschf is defined in function scope to avoid global effects # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: dschf = HfDeepSpeedConfig(ds_config) else: dschf = None # noqa: F841 # When using FSDP with meta_init, we should not pass device_map to avoid # "Cannot copy out of meta tensor" error. FSDP will handle device placement. from_pretrained_kwargs = { "trust_remote_code": True, "attn_implementation": attn_implementation, "torch_dtype": torch.bfloat16 if bf16 else "auto", } # Only add device_map if we're not in a meta device context (used by FSDP) # Check if the default tensor creation device is meta try: test_tensor = torch.empty(1) is_meta_context = test_tensor.is_meta except: # noqa is_meta_context = False if not is_meta_context and device_map is not None: from_pretrained_kwargs["device_map"] = device_map self.model = AutoModelForVision2Seq.from_pretrained(pretrain_or_model, **from_pretrained_kwargs) # LoRA if lora_rank > 0: # https://github.com/huggingface/peft/issues/137 self.model = apply_lora_configuration( model=self.model, lora_rank=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, target_modules=target_modules, freeze_vision_tower=True, ) # https://github.com/huggingface/transformers/issues/26877 # Use `model.generate(use_cache=True)` instead.` self.model.config.use_cache = False # packing samples using Flash Attention 2 self.packing_samples = packing_samples else: self.model = pretrain_or_model self.pretrain_or_model = pretrain_or_model.config.model_type print("pretrain_or_model: ", self.pretrain_or_model) @torch.no_grad() def generate( self, input_ids: torch.Tensor, pixel_values: Optional[torch.Tensor] = None, image_grid_thw: Optional[torch.Tensor] = None, pixel_values_videos: Optional[torch.Tensor] = None, video_grid_thw: Optional[torch.Tensor] = None, **kwargs ) -> Union[ Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor], ]: """ Generate text sequences based on input text and visual information. This method performs text generation conditioned on both textual prompts and visual inputs. It handles the generation process with various sampling strategies and returns the generated sequences along with attention masks and action masks for RL training. :param input_ids: Input token IDs representing the text prompt :type input_ids: torch.Tensor :param pixel_values: Preprocessed pixel values of input images :type pixel_values: Optional[torch.Tensor] :param image_grid_thw: Image grid dimensions (time, height, width) :type image_grid_thw: Optional[torch.Tensor] :param pixel_values_videos: Preprocessed pixel values of input videos :type pixel_values_videos: Optional[torch.Tensor] :param video_grid_thw: Video grid dimensions :type video_grid_thw: Optional[torch.Tensor] :param kwargs: Additional generation parameters (top_k, top_p, temperature, etc.) :type kwargs: dict :return: Tuple containing generated sequences, attention mask, and action mask :rtype: Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]] # noqa Example:: sequences, attention_mask, action_mask = actor.generate( input_ids=torch.tensor([[1, 2, 3]]), pixel_values=image_tensor, image_grid_thw=torch.tensor([[1, 24, 24]]), max_new_tokens=50, temperature=0.8, do_sample=True ) """ generate_args = { "input_ids": input_ids, "pixel_values": pixel_values, "image_grid_thw": image_grid_thw, "pixel_values_videos": pixel_values_videos, "video_grid_thw": video_grid_thw, "top_k": kwargs.get("top_k", None), "top_p": kwargs.get("top_p", None), "do_sample": kwargs.get("do_sample", True), "early_stopping": kwargs.get("num_beams", 1) > 1, "temperature": kwargs.get("temperature", 1), "use_cache": True, "num_beams": kwargs.get("num_beams", 1), "attention_mask": kwargs.get("attention_mask"), "eos_token_id": kwargs.get("eos_token_id"), "pad_token_id": kwargs.get("pad_token_id"), "min_new_tokens": kwargs.get("min_new_tokens", 1), } if kwargs.get("max_new_tokens", None): generate_args["max_new_tokens"] = kwargs.get("max_new_tokens") if kwargs.get("max_length", None): generate_args["max_length"] = kwargs.get("max_length") # Call generate sequences = self.model.generate(**generate_args) # Prepare mask tensor eos_token_id = generate_args["eos_token_id"] pad_token_id = generate_args["pad_token_id"] return self.process_sequences(sequences, input_ids.size(1), eos_token_id, pad_token_id)
[docs] def forward( self, sequences: torch.LongTensor, num_actions: Optional[Union[int, list[int]]] = None, attention_mask: Optional[torch.Tensor] = None, pixel_values: Optional[torch.Tensor] = None, image_grid_thw: Optional[torch.Tensor] = None, pixel_values_videos: Optional[torch.Tensor] = None, video_grid_thw: Optional[torch.Tensor] = None, return_output=False, packed_seq_lens: Optional[list[int]] = None, ) -> torch.Tensor: """ Forward pass to compute action log probabilities for reinforcement learning. This method processes input sequences and visual information to compute log probabilities of actions (tokens) for RL training. It supports both standard and packed sequence formats and can return either just the action log probabilities or the full model output. :param sequences: Input token sequences :type sequences: torch.LongTensor :param num_actions: Number of action tokens to extract log probs for :type num_actions: Optional[Union[int, list[int]]] :param attention_mask: Attention mask for the sequences :type attention_mask: Optional[torch.Tensor] :param pixel_values: Preprocessed pixel values of input images :type pixel_values: Optional[torch.Tensor] :param image_grid_thw: Image grid dimensions (time, height, width) :type image_grid_thw: Optional[torch.Tensor] :param pixel_values_videos: Preprocessed pixel values of input videos :type pixel_values_videos: Optional[torch.Tensor] :param video_grid_thw: Video grid dimensions :type video_grid_thw: Optional[torch.Tensor] :param return_output: Whether to return the full model output along with log probs :type return_output: bool :param packed_seq_lens: Sequence lengths for packed samples :type packed_seq_lens: Optional[list[int]] :return: Action log probabilities or tuple of (action_log_probs, output) if return_output=True :rtype: torch.Tensor Example:: # Compute action log probabilities for RL training log_probs = actor( sequences=token_sequences, num_actions=10, pixel_values=image_tensor, image_grid_thw=grid_tensor ) # Get both log probs and full output log_probs, output = actor( sequences=token_sequences, num_actions=10, pixel_values=image_tensor, image_grid_thw=grid_tensor, return_output=True ) """ if not self.packing_samples: position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) else: # convert attention_mask to position_ids position_ids = reset_position_ids(attention_mask) # explicitly ignore attention_mask for packing_samples attention_mask = None output = self.model( sequences, attention_mask=attention_mask, position_ids=position_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw, pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, ) if num_actions is None: # defult assert return_output return output log_probs = log_probs_from_logits(output["logits"][:, :-1, :], sequences[:, 1:]) # Calculate entropy for action tokens (for high-entropy token identification) # Only compute entropy when high_entropy_token_ratio is not 0 if self.high_entropy_token_ratio > 0.0: action_logits = output["logits"][:, :-1, :] # Shape: (batch, seq_len-1, vocab_size) action_entropy = entropy_from_logits(action_logits) # Shape: (batch, seq_len-1) else: action_entropy = None if not self.packing_samples: action_log_probs = log_probs[:, -num_actions:] if action_entropy is not None: action_entropy = action_entropy[:, -num_actions:] else: assert isinstance(num_actions, list) and len(num_actions) == len(packed_seq_lens) action_log_probs = [] action_entropy_list = [] offset = 0 for num_action, seq_len in zip(num_actions, packed_seq_lens): start, end = max(0, offset + seq_len - num_action - 1), offset + seq_len - 1 action_log_probs.append(log_probs[:, start:end]) if action_entropy is not None: action_entropy_list.append(action_entropy[:, start:end]) offset += seq_len action_log_probs = torch.cat(action_log_probs, dim=1) if action_entropy is not None: action_entropy = torch.cat(action_entropy_list, dim=1) if return_output: # Include action_entropy in output if computed if action_entropy is not None: output_dict = dict(output) output_dict["action_entropy"] = action_entropy return (action_log_probs, output_dict) return (action_log_probs, output) else: return action_log_probs
[docs] def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs={"use_reentrant": False}): """ Enable gradient checkpointing to reduce memory usage during training. Gradient checkpointing trades compute for memory by recomputing intermediate activations during the backward pass instead of storing them. This is particularly useful for training large vision-language models with limited GPU memory. :param gradient_checkpointing_kwargs: Additional arguments for gradient checkpointing :type gradient_checkpointing_kwargs: dict Example:: # Enable gradient checkpointing with default settings actor.gradient_checkpointing_enable() # Enable with custom settings actor.gradient_checkpointing_enable({"use_reentrant": True}) """ self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
[docs] def gradient_checkpointing_disable(self): """ Disable gradient checkpointing to use normal forward/backward computation. This method restores the default behavior where all intermediate activations are stored during the forward pass for use in the backward pass. This increases memory usage but reduces computation time. Example:: # Disable gradient checkpointing actor.gradient_checkpointing_disable() """ self.model.gradient_checkpointing_disable()
[docs] def print_trainable_parameters(self): """ Print information about trainable parameters in the model. This method displays the number and percentage of trainable parameters, which is particularly useful when using parameter-efficient methods like LoRA. It helps monitor the efficiency of the fine-tuning approach. Example:: # Print trainable parameter statistics actor.print_trainable_parameters() # Output: trainable params: 4,194,304 || all params: 7,241,732,096 || trainable%: 0.058 """ self.model.print_trainable_parameters()
[docs] def process_sequences(self, sequences: torch.Tensor, input_len: int, eos_token_id: int, pad_token_id: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Called by `trainer/fast_exp_maker.py`. Process generated sequences to create proper attention and action masks. This method post-processes the generated sequences to ensure proper handling of end-of-sequence tokens and creates masks needed for reinforcement learning training. It handles edge cases like multiple EOS tokens and ensures consistent sequence formatting. :param sequences: Generated token sequences :type sequences: torch.Tensor :param input_len: Length of the input prompt :type input_len: int :param eos_token_id: End-of-sequence token ID :type eos_token_id: int :param pad_token_id: Padding token ID :type pad_token_id: int :return: Tuple of processed sequences, attention mask, and action mask :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] """ # Process generated sequences to create proper attention and action masks attention_mask = (sequences.ne(eos_token_id) & sequences.ne(pad_token_id)).to(dtype=torch.long) seq_length = attention_mask.size(1) eos_indices = seq_length - attention_mask.long().fliplr().argmax(dim=1, keepdim=True).clamp(min=1) sequences.scatter_(dim=1, index=eos_indices, value=eos_token_id) # For Llama3 and Qwen2 models, there are some eos_tokens in the middle of the prompt. first_token_indices = attention_mask.long().argmax(dim=1, keepdim=True) mask = torch.arange(seq_length).unsqueeze(0).expand(sequences.size(0), -1).to(device=sequences.device) attention_mask = (mask >= first_token_indices) & (mask <= eos_indices).to(dtype=torch.long) # in RL, state_i (current token) + action_i (next token) -> state_i+1 (next token) state_seq = sequences[:, input_len - 1:-1] action_mask = state_seq.ne(eos_token_id) & state_seq.ne(pad_token_id) action_mask[:, 0] = 1 return sequences, attention_mask, action_mask