lightrft.models.actor_al
Audio-language actor for reinforcement learning.
Provides the ActorAL (Audio-language) class: an actor that generates text (actions) from audio and text inputs. Supports LoRA, Flash Attention 2, DeepSpeed, sample packing, gradient checkpointing, and MoE.
- class lightrft.models.actor_al.ActorAL(*args: Any, **kwargs: Any)
Bases: Module
Audio-language actor for RL: generates text (actions) from audio and text inputs.
Supports LoRA, quantization, and distributed training. Can be initialized from a pretrained path or an existing model instance.
- Parameters:
pretrain_or_model (Union[str, nn.Module]) – Path to a pretrained model or an existing model instance.
use_flash_attention_2 (bool) – Whether to use Flash Attention 2 for faster, more memory-efficient attention
bf16 (bool) – Enable bfloat16 precision for model computations
lora_rank (int) – Rank for LoRA adaptation (0 disables LoRA)
lora_alpha (int) – Alpha parameter for LoRA scaling
lora_dropout (float) – Dropout rate for LoRA layers
target_modules (Optional[list]) – List of target modules for applying LoRA (auto-detected if None)
ds_config (Optional[dict]) – Configuration for DeepSpeed distributed training
device_map (Optional[dict]) – Device mapping for loading the model onto specific devices
packing_samples (bool) – Whether to pack samples during training for efficiency
Example:
```python
# Initialize with a pretrained model path
actor = ActorAL(
    pretrain_or_model="Qwen/Qwen2-Audio-7B-Instruct",
    use_flash_attention_2=True,
    lora_rank=16,
    lora_alpha=32,
)

# Generate responses
sequences, attention_mask, action_mask = actor.generate(
    input_ids=input_tensor,
    audio_values=audio_features_tensor,
    max_new_tokens=100,
)
```
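The LoRA parameters above follow the standard LoRA formulation: a frozen weight W is augmented by a low-rank update A·B scaled by lora_alpha / lora_rank. The pure-Python sketch below illustrates that arithmetic for a single linear layer. It assumes ActorAL delegates to a standard LoRA implementation (such as PEFT) rather than reproducing its code; lora_linear and matmul are hypothetical helpers, and dropout is omitted (eval-mode forward).

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    # Plain (rows x inner) @ (inner x cols) matrix product.
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def lora_linear(
    x: Matrix, W: Matrix, A: Matrix, B: Matrix, lora_rank: int, lora_alpha: int
) -> Matrix:
    """Frozen base projection plus a scaled low-rank update (hypothetical helper).

    Shapes: x (n, d_in), W (d_in, d_out), A (d_in, r), B (r, d_out).
    """
    base = matmul(x, W)
    if lora_rank == 0:
        # Mirrors the documented convention that lora_rank=0 disables LoRA.
        return base
    scaling = lora_alpha / lora_rank
    delta = matmul(matmul(x, A), B)
    return [[b + scaling * d for b, d in zip(rb, rd)] for rb, rd in zip(base, delta)]
```

With lora_rank=16 and lora_alpha=32, as in the example above, the low-rank update is scaled by 2.0. Only A and B are trained; W stays frozen, which is why the trainable-parameter share reported by print_trainable_parameters is small.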
- forward(sequences: torch.LongTensor, num_actions: int | list[int] | None = None, attention_mask: torch.Tensor | None = None, pixel_values: torch.Tensor | None = None, image_grid_thw: torch.Tensor | None = None, pixel_values_videos: torch.Tensor | None = None, video_grid_thw: torch.Tensor | None = None, return_output=False, packed_seq_lens: list[int] | None = None, audio_values: torch.Tensor | None = None) → torch.Tensor
Forward pass to compute action log probabilities for reinforcement learning.
This method processes input sequences and audio information to compute log probabilities of actions (tokens) for RL training. It supports both standard and packed sequence formats and can return either just the action log probabilities or the full model output.
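A minimal sketch of how action log probabilities are typically derived from causal-LM logits, assuming (as is conventional, not confirmed by this page) that the logits at position t score the token at position t + 1 and that the action tokens are the final num_actions tokens of each sequence. action_log_probs and log_softmax are hypothetical pure-Python helpers operating on one unpadded sequence; the real method works on batched tensors.

```python
import math
from typing import List


def log_softmax(row: List[float]) -> List[float]:
    # Numerically stable log-softmax: subtract the max before exponentiating.
    m = max(row)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_sum_exp for v in row]


def action_log_probs(
    logits: List[List[float]], sequence: List[int], num_actions: int
) -> List[float]:
    """Log-probabilities of the last num_actions tokens of one sequence.

    Assumes logits[t] scores the token at position t + 1 (causal LM shift).
    """
    # Position t predicts token t + 1, so the first token gets no log-prob.
    per_token = [
        log_softmax(logits[t])[sequence[t + 1]] for t in range(len(sequence) - 1)
    ]
    # Keep only the generated (action) tokens at the end of the sequence.
    return per_token[len(per_token) - num_actions:]
```

Slicing with len(per_token) - num_actions, rather than -num_actions, keeps num_actions=0 returning an empty list instead of the whole sequence.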
Callers pass preprocessed audio as audio_values; the pipeline maps from the VL pixel_values slot to audio_values for this actor.
- Parameters:
sequences (torch.LongTensor) – Input token sequences
num_actions (Optional[Union[int, list[int]]]) – Number of action tokens to extract log probs for
attention_mask (Optional[torch.Tensor]) – Attention mask for the sequences
pixel_values (Optional[torch.Tensor]) – Unused (VL compatibility; audio pipeline passes audio_values)
image_grid_thw (Optional[torch.Tensor]) – Unused (accepted for VL pipeline compatibility)
pixel_values_videos (Optional[torch.Tensor]) – Unused (accepted for VL pipeline compatibility)
video_grid_thw (Optional[torch.Tensor]) – Unused (accepted for VL pipeline compatibility)
return_output (bool) – Whether to return the full model output along with log probs
packed_seq_lens (Optional[list[int]]) – Sequence lengths for packed samples
audio_values (Optional[torch.Tensor]) – Preprocessed audio features (mel-spectrogram from pipeline)
- Returns:
Action log probabilities or tuple of (action_log_probs, output) if return_output=True
- Return type:
torch.Tensor (a tuple of (action_log_probs, output) when return_output=True)
Example:
```python
# Compute action log probabilities for RL training
log_probs = actor(
    sequences=token_sequences,
    num_actions=10,
    audio_values=input_features_tensor,
)

# Get both log probs and full output
log_probs, output = actor(
    sequences=token_sequences,
    num_actions=10,
    audio_values=input_features_tensor,
    return_output=True,
)
```
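When packed_seq_lens is given, several samples share one row, and num_actions may be a list with one entry per sample (as the signature allows). The sketch below shows one plausible way per-token results from a packed row are split back per sample. split_packed_action_log_probs is a hypothetical helper, and it assumes token_log_probs[j] already holds the log-prob of token j conditioned only on earlier tokens of the same sample; keeping attention from crossing sample boundaries is handled by the model and is out of scope here.

```python
from itertools import accumulate
from typing import List


def split_packed_action_log_probs(
    token_log_probs: List[float], packed_seq_lens: List[int], num_actions: List[int]
) -> List[List[float]]:
    """Split per-token log-probs of one packed row into per-sample action log-probs."""
    if len(token_log_probs) != sum(packed_seq_lens):
        raise ValueError("packed_seq_lens must sum to the packed row length")
    if len(packed_seq_lens) != len(num_actions):
        raise ValueError("need one num_actions entry per packed sample")
    ends = list(accumulate(packed_seq_lens))
    starts = [0] + ends[:-1]
    result = []
    for start, end, n in zip(starts, ends, num_actions):
        if n > end - start:
            raise ValueError("num_actions exceeds the sample length")
        # Each sample's actions are the last n tokens of its own segment.
        result.append(token_log_probs[end - n:end])
    return result
```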
- generate(input_ids: torch.Tensor, audio_values: torch.Tensor = None, **kwargs) → Tuple[torch.LongTensor, torch.LongTensor] | Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]
Generate text sequences based on input text and audio information.
This method performs text generation conditioned on both textual prompts and audio inputs. It handles the generation process with various sampling strategies and returns the generated sequences along with attention masks and action masks for RL training.
- Parameters:
input_ids (torch.Tensor) – Input token IDs representing the text prompt
audio_values (torch.Tensor) – Preprocessed audio features (mel-spectrogram) for Qwen2-Audio
kwargs (dict) – Additional generation parameters (top_k, top_p, temperature, etc.)
- Returns:
Tuple containing generated sequences, attention mask, and action mask
- Return type:
Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]
Example:
```python
import torch

sequences, attention_mask, action_mask = actor.generate(
    input_ids=torch.tensor([[1, 2, 3]]),
    audio_values=audio_features_tensor,
    max_new_tokens=50,
    temperature=0.8,
    do_sample=True,
)
```
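The sampling kwargs (temperature, top_k, top_p) are presumably forwarded to a Hugging Face-style generate call; that forwarding is an assumption, not stated on this page. The pure-Python sketch below shows what those three parameters conventionally mean for a single decoding step: temperature rescales the logits, top_k keeps the k most likely tokens, and top_p keeps the smallest set whose cumulative probability reaches p. sample_next_token is a hypothetical helper, not part of ActorAL.

```python
import math
import random
from typing import List, Optional


def sample_next_token(
    logits: List[float],
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Temperature, then top-k, then top-p (nucleus) sampling over one logit vector."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0; use greedy decoding instead")
    rng = rng or random.Random()
    scaled = [v / temperature for v in logits]
    # Candidate token ids, most likely first.
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    # Softmax over the surviving candidates only (renormalized).
    m = scaled[order[0]]
    weights = [math.exp(scaled[i] - m) for i in order]
    total = sum(weights)
    probs = [w / total for w in weights]
    # Nucleus: keep the smallest prefix whose cumulative probability reaches top_p.
    kept_ids, kept_probs, cumulative = [], [], 0.0
    for token_id, p in zip(order, probs):
        kept_ids.append(token_id)
        kept_probs.append(p)
        cumulative += p
        if cumulative >= top_p:
            break
    return rng.choices(kept_ids, weights=kept_probs, k=1)[0]
```

Lower temperatures sharpen the distribution toward the argmax; with do_sample=False, generation is typically greedy and these parameters have no effect.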
- gradient_checkpointing_disable()
Disable gradient checkpointing to use normal forward/backward computation.
This method restores the default behavior where all intermediate activations are stored during the forward pass for use in the backward pass. This increases memory usage but reduces computation time.
Example:
```python
# Disable gradient checkpointing
actor.gradient_checkpointing_disable()
```
- gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})
Enable gradient checkpointing to reduce memory usage during training.
Gradient checkpointing trades compute for memory by recomputing intermediate activations during the backward pass instead of storing them. This is particularly useful for training large audio-language models with limited GPU memory.
- Parameters:
gradient_checkpointing_kwargs (dict) – Additional arguments for gradient checkpointing
Example:
```python
# Enable gradient checkpointing with default settings
actor.gradient_checkpointing_enable()

# Enable with custom settings
actor.gradient_checkpointing_enable({"use_reentrant": True})
```
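To make the compute-for-memory trade-off concrete, the toy below backpropagates through a scalar chain of tanh layers in two ways: storing every layer input, or storing only the input of each segment and recomputing the rest during the backward pass. It is a pure-Python illustration of the technique, not how PyTorch or ActorAL implement it; the use_reentrant flag selects between PyTorch's two torch.utils.checkpoint implementations and has no analogue here. All function names are hypothetical.

```python
import math
from typing import Dict, List, Tuple


def layer(x: float, w: float) -> float:
    return math.tanh(w * x)


def layer_backward(x: float, w: float, grad_out: float) -> Tuple[float, float]:
    # y = tanh(w * x): dy/dx = w * (1 - y^2), dy/dw = x * (1 - y^2)
    y = math.tanh(w * x)
    local = grad_out * (1.0 - y * y)
    return local * w, local * x  # (grad w.r.t. input, grad w.r.t. weight)


def grads_store_all(x0: float, ws: List[float]) -> Tuple[List[float], int]:
    """Ordinary backprop: keep every layer input from the forward pass."""
    acts = [x0]
    for w in ws[:-1]:
        acts.append(layer(acts[-1], w))
    grad, grads = 1.0, [0.0] * len(ws)  # loss = final output, so dL/dy = 1
    for i in reversed(range(len(ws))):
        grad, grads[i] = layer_backward(acts[i], ws[i], grad)
    return grads, len(acts)


def grads_checkpointed(x0: float, ws: List[float], segment: int) -> Tuple[List[float], int]:
    """Keep only each segment's input; recompute the others during backward."""
    checkpoints: Dict[int, float] = {0: x0}
    x = x0
    for i, w in enumerate(ws[:-1]):
        x = layer(x, w)
        if (i + 1) % segment == 0:
            checkpoints[i + 1] = x
    grad, grads = 1.0, [0.0] * len(ws)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, len(ws))
        # Recompute this segment's layer inputs from its checkpoint.
        acts = [checkpoints[start]]
        for i in range(start, end - 1):
            acts.append(layer(acts[-1], ws[i]))
        for i in reversed(range(start, end)):
            grad, grads[i] = layer_backward(acts[i - start], ws[i], grad)
    return grads, len(checkpoints)
```

For eight layers with segment=4, the checkpointed version keeps 2 activations instead of 8 and produces identical gradients, at the cost of re-running most of the forward pass inside the backward pass.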
- modality = 'audio'
- print_trainable_parameters()
Print information about trainable parameters in the model.
This method displays the number and percentage of trainable parameters, which is particularly useful when using parameter-efficient methods like LoRA. It helps monitor the efficiency of the fine-tuning approach.
Example:
```python
# Print trainable parameter statistics
actor.print_trainable_parameters()
# Output: trainable params: 4,194,304 || all params: 7,241,732,096 || trainable%: 0.058
```
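The statistics in that output line are simple arithmetic over per-parameter element counts. The hypothetical helper below reproduces the format shown in the example from (numel, requires_grad) pairs; with PyTorch, such pairs could be built as (p.numel(), p.requires_grad) for p in model.parameters(). The three-decimal percentage matches the example output and is an assumption about the real method's formatting.

```python
from typing import Iterable, Tuple


def trainable_summary(params: Iterable[Tuple[int, bool]]) -> str:
    """Format trainable/total parameter counts like the example output above."""
    trainable = total = 0
    for numel, requires_grad in params:
        total += numel
        if requires_grad:
            trainable += numel
    pct = 100.0 * trainable / total if total else 0.0
    return (
        f"trainable params: {trainable:,} || all params: {total:,} "
        f"|| trainable%: {pct:.3f}"
    )


# Reproduces the example: 4,194,304 trainable out of 7,241,732,096 total.
print(trainable_summary([(4_194_304, True), (7_241_732_096 - 4_194_304, False)]))
```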
- process_sequences(sequences: torch.Tensor, input_len: int, eos_token_id: int, pad_token_id: int) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Called by trainer/fast_exp_maker.py.
Process generated sequences to create proper attention and action masks.
This method post-processes the generated sequences to ensure proper handling of end-of-sequence tokens and creates masks needed for reinforcement learning training. It handles edge cases like multiple EOS tokens and ensures consistent sequence formatting.
- Parameters:
sequences (torch.Tensor) – Generated token sequences
input_len (int) – Length of the input prompt
eos_token_id (int) – End-of-sequence token ID
pad_token_id (int) – Padding token ID
- Returns:
Tuple of processed sequences, attention mask, and action mask
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
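One plausible reading of the masking described above, sketched in pure Python on unbatched lists: the response is kept up to and including its first EOS token, anything after it (including further EOS tokens) is replaced with padding, the attention mask covers non-pad prompt tokens plus the kept response, and the action mask covers only the kept response tokens. process_sequences_sketch is hypothetical; the real method operates on batched tensors and may differ in details, for example when pad_token_id equals eos_token_id or when prompts are not left-padded.

```python
from typing import List, Tuple


def process_sequences_sketch(
    sequences: List[List[int]], input_len: int, eos_token_id: int, pad_token_id: int
) -> Tuple[List[List[int]], List[List[int]], List[List[bool]]]:
    out_seqs, attn_masks, action_masks = [], [], []
    for seq in sequences:
        prompt, response = seq[:input_len], seq[input_len:]
        # Keep the response up to and including the first EOS; pad the rest.
        if eos_token_id in response:
            end = response.index(eos_token_id) + 1
        else:
            end = len(response)
        cleaned = response[:end] + [pad_token_id] * (len(response) - end)
        # Prompt: attend to every non-pad token (prompts assumed left-padded).
        prompt_mask = [int(t != pad_token_id) for t in prompt]
        # Response: attend to, and train on, the kept tokens only.
        response_mask = [1] * end + [0] * (len(response) - end)
        out_seqs.append(prompt + cleaned)
        attn_masks.append(prompt_mask + response_mask)
        action_masks.append([bool(m) for m in response_mask])
    return out_seqs, attn_masks, action_masks
```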