lightrft.models.srm_al

class lightrft.models.srm_al.ScalarRewardModelAL(*args: Any, **kwargs: Any)[source]

Bases: Module

Scalar Reward Model for reinforcement learning applications with audio-language backbones.

This class wraps a pretrained audio-language model and adds reward heads that produce scalar scores. Each reward head is a feed-forward network that takes the hidden states from a chosen layer of the backbone model and outputs a scalar value.

Parameters:
  • pretrain_or_model (Union[str, nn.Module]) – Either a string path to a pretrained model or a model instance

  • use_flash_attention_2 (bool) – Whether to utilize Flash Attention 2.0 for improved performance

  • bf16 (bool) – Enable bfloat16 precision for model computations

  • lora_rank (int) – Rank for LoRA adaptation (0 disables LoRA)

  • lora_alpha (int) – Alpha parameter for LoRA scaling

  • lora_dropout (float) – Dropout rate for LoRA layers

  • target_modules (Optional[list]) – List of target modules for applying LoRA (auto-detected if None)

  • ds_config (Optional[Dict]) – Configuration for DeepSpeed distributed training

  • device_map (Optional[Dict]) – Device mapping for loading the model onto specific devices

  • pooling_method (str) – Pooling method for aggregating hidden states (‘attn’ or ‘last’). Defaults to ‘attn’, which generally delivers the best performance.

  • probing_layer (int) – Index of the layer from which to extract hidden states for the reward heads. Defaults to -1, i.e., the last layer.

  • scale_for_train (bool) – Whether to scale scores during training. Defaults to True; we recommend enabling it for better performance.

  • head_types (list[str]) – List of head types for scoring (e.g., [“preference”, “alignment”]). Defaults to [“preference”]; must be consistent with the training data annotations.
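The following is a minimal, self-contained sketch (illustrative only, not the lightrft implementation) of what pooling_method=‘last’ means: for each sequence, the hidden state of the last non-padded token is selected, and a feed-forward reward head then maps it to a scalar. The function names and toy tensors here are hypothetical.

```python
# Hypothetical sketch of "last" pooling followed by a scalar reward head.
# Plain Python lists stand in for torch tensors.

def last_token_pool(hidden_states, attention_mask):
    """hidden_states: [batch][seq_len][dim]; attention_mask: [batch][seq_len]."""
    pooled = []
    for states, mask in zip(hidden_states, attention_mask):
        # index of the last position where the mask is 1 (last real token)
        last_idx = max(i for i, m in enumerate(mask) if m == 1)
        pooled.append(states[last_idx])
    return pooled

def scalar_head(pooled, weights, bias):
    """A single reward head: linear map (dot product + bias) -> scalar per item."""
    return [sum(w * x for w, x in zip(weights, vec)) + bias for vec in pooled]

hidden = [[[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]],   # last real token at index 1
          [[0.5, 0.5], [1.5, 2.5], [2.0, 1.0]]]   # last real token at index 2
mask = [[1, 1, 0],
        [1, 1, 1]]
pooled = last_token_pool(hidden, mask)
scores = scalar_head(pooled, weights=[1.0, -1.0], bias=0.0)
print(scores)  # [-1.0, 1.0]
```

The ‘attn’ pooling mode instead learns attention weights over all positions; the scalar head on top works the same way.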

forward(sequences: torch.LongTensor, attention_mask: torch.Tensor | None = None, input_features: torch.Tensor | None = None, feature_attention_mask: torch.Tensor | None = None) → Dict[source]

The forward pass takes token sequences and audio features as input and returns a dictionary mapping each head type to its reward scores.

Parameters:
  • sequences (torch.LongTensor) – Input token sequences

  • attention_mask (Optional[torch.Tensor]) – Attention mask for the sequences

  • input_features (Optional[torch.Tensor]) – Preprocessed audio features of the input audio

  • feature_attention_mask (Optional[torch.Tensor]) – Attention mask for the audio features

Returns:

A dictionary containing reward scores from different heads

Return type:

Dict

Example:

# Compute reward scores from sequences and audio inputs
# Suppose `reward_model` has two heads: "preference" and "alignment"
scores = reward_model(
    sequences=input_ids,
    attention_mask=attention_mask,
    input_features=input_features,
    feature_attention_mask=feature_attention_mask
)
preference_score = scores["preference"]
alignment_score = scores["alignment"]
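To make the returned dictionary concrete, here is a minimal sketch (illustrative only, not lightrft internals) of how independent heads each map the pooled hidden state to one scalar; the head weights and helper name are hypothetical.

```python
# Hypothetical sketch: each reward head is an independent linear map from the
# pooled hidden state to a scalar, and the results are collected into a dict.

def multi_head_scores(pooled_state, heads):
    """pooled_state: [dim]; heads: {name: (weights, bias)} -> {name: score}."""
    return {
        name: sum(w * x for w, x in zip(weights, pooled_state)) + bias
        for name, (weights, bias) in heads.items()
    }

heads = {
    "preference": ([0.5, 0.5], 0.0),
    "alignment": ([1.0, 0.0], -0.5),
}
scores = multi_head_scores([2.0, 4.0], heads)
print(scores)  # {'preference': 3.0, 'alignment': 1.5}
```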

gradient_checkpointing_disable()[source]

Disable gradient checkpointing to use normal forward/backward computation.

This method restores the default behavior where all intermediate activations are stored during the forward pass for use in the backward pass. This increases memory usage but reduces computation time.

Example:

# Disable gradient checkpointing
reward_model.gradient_checkpointing_disable()

gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})[source]

Enable gradient checkpointing to reduce memory usage during training.

Gradient checkpointing trades compute for memory by recomputing intermediate activations during the backward pass instead of storing them. This is particularly useful for training large audio-language models with limited GPU memory.

Parameters:

gradient_checkpointing_kwargs (dict) – Additional arguments for gradient checkpointing

Example:

# Enable gradient checkpointing with default settings
reward_model.gradient_checkpointing_enable()

# Enable with custom settings
reward_model.gradient_checkpointing_enable({"use_reentrant": True})

print_trainable_parameters()[source]

Print information about trainable parameters in the model.

This method displays the number and percentage of trainable parameters, which is particularly useful when using parameter-efficient methods like LoRA. It helps monitor the efficiency of the fine-tuning approach.

Example:

# Print trainable parameter statistics
reward_model.print_trainable_parameters()
# Output: trainable params: 4,194,304 || all params: 7,241,732,096 || trainable%: 0.058
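The statistic above is straightforward to reproduce. Below is a sketch (assumed formula, mirroring the printed output) of how the trainable percentage is computed from per-parameter counts and their requires_grad flags; the helper name and parameter counts are illustrative.

```python
# Sketch of the trainable-parameter statistic: total vs. trainable counts
# and their percentage, as with a LoRA-adapted model (adapters trainable,
# backbone frozen). Counts below are illustrative.

def trainable_stats(params):
    """params: list of (numel, requires_grad) pairs -> (trainable, total, pct)."""
    total = sum(n for n, _ in params)
    trainable = sum(n for n, grad in params if grad)
    return trainable, total, 100.0 * trainable / total

params = [(4_194_304, True), (7_237_537_792, False)]
trainable, total, pct = trainable_stats(params)
print(f"trainable params: {trainable:,} || all params: {total:,} || trainable%: {pct:.3f}")
# trainable params: 4,194,304 || all params: 7,241,732,096 || trainable%: 0.058
```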