lightrft.models.srm_vl

class lightrft.models.srm_vl.ScalarRewardModelVL(*args: Any, **kwargs: Any)[source]

Bases: Module

Scalar Reward Model for reinforcement learning applications with vision-language backbones.

This class wraps a pretrained vision-language model and adds reward heads that produce scalar scores. Each reward head is a feed-forward network that takes the hidden states from a chosen layer of the backbone model and outputs a scalar value.
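
The idea can be sketched in a few lines. The two-layer network below is a hypothetical stand-in for a reward head, not LightRFT's actual implementation; the dimensions and weights are made up for illustration:

```python
import math
import random

random.seed(0)
hidden_size = 8  # hypothetical hidden-state width, not the real model's

# A reward head is a small feed-forward network: one hidden layer with a
# tanh non-linearity, then a linear map down to a single scalar.
W1 = [[random.gauss(0, 1) for _ in range(hidden_size)] for _ in range(hidden_size)]
w2 = [random.gauss(0, 1) for _ in range(hidden_size)]

def reward_head(hidden_state):
    """Map one pooled backbone hidden state to a scalar reward."""
    h = [math.tanh(sum(w * x for w, x in zip(row, hidden_state))) for row in W1]
    return sum(w * x for w, x in zip(w2, h))

pooled = [random.gauss(0, 1) for _ in range(hidden_size)]  # stands in for a pooled hidden state
score = reward_head(pooled)  # one scalar per head
```

With several head types configured, the model runs one such head per type and returns the scores in a dictionary keyed by head name.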

Parameters:
  • pretrain_or_model (Union[str, nn.Module]) – Either a string path to a pretrained model or a model instance

  • use_flash_attention_2 (bool) – Whether to utilize Flash Attention 2.0 for improved performance

  • bf16 (bool) – Enable bfloat16 precision for model computations

  • lora_rank (int) – Rank for LoRA adaptation (0 disables LoRA)

  • lora_alpha (int) – Alpha parameter for LoRA scaling

  • lora_dropout (float) – Dropout rate for LoRA layers

  • target_modules (Optional[list]) – List of target modules for applying LoRA (auto-detected if None)

  • ds_config (Optional[Dict]) – Configuration for DeepSpeed distributed training

  • device_map (Optional[Dict]) – Device mapping for loading the model onto specific devices

  • pooling_method (str) – Pooling method for aggregating hidden states (‘attn’ or ‘last’). Defaults to ‘attn’, which generally delivers better performance.

  • probing_layer (int) – Index of the layer from which to extract hidden states for reward heads. Defaults to -1 (the last layer).

  • scale_for_train (bool) – Whether to scale the scores during training. Defaults to True; we recommend enabling it for better performance.

  • head_types (list[str]) – List of head types for scoring (e.g., [“preference”, “alignment”]). Defaults to [“preference”]. Must be consistent with the training data annotations.
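
To make the pooling_method choice concrete, here is a plain-Python sketch of the two strategies over per-token hidden states; this is illustrative only, and in particular the actual ‘attn’ pooling uses learned attention parameters, whereas the scores here are fixed by hand:

```python
import math

# Hypothetical per-token hidden states (seq_len=4, hidden_size=3) with
# right padding; the attention mask marks real tokens with 1.
hidden_states = [[0.1, 0.2, 0.3],
                 [0.4, 0.5, 0.6],
                 [0.7, 0.8, 0.9],
                 [0.0, 0.0, 0.0]]  # padding position
attention_mask = [1, 1, 1, 0]

def pool_last(states, mask):
    """'last': take the hidden state at the final non-padded position."""
    last = max(i for i, m in enumerate(mask) if m)
    return states[last]

def pool_attn(states, mask, scores):
    """'attn': mask-aware softmax over per-token scores, then a weighted sum."""
    exp = [math.exp(s) if m else 0.0 for s, m in zip(scores, mask)]
    total = sum(exp)
    weights = [e / total for e in exp]
    dim = len(states[0])
    return [sum(w * st[d] for w, st in zip(weights, states)) for d in range(dim)]

print(pool_last(hidden_states, attention_mask))  # [0.7, 0.8, 0.9]
print(pool_attn(hidden_states, attention_mask, [0.5, 1.0, 2.0, 0.0]))
```

Both strategies ignore padded positions; ‘attn’ blends information from every real token rather than relying on the last one alone.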

forward(sequences: torch.LongTensor, attention_mask: torch.Tensor | None = None, pixel_values: torch.Tensor | None = None, image_grid_thw: torch.Tensor | None = None, pixel_values_videos: torch.Tensor | None = None, video_grid_thw: torch.Tensor | None = None) → Dict[source]

The forward pass takes token sequences and, if present, visual content as input, and returns a dictionary mapping each reward head to its scalar scores.

Parameters:
  • sequences (torch.LongTensor) – Input token sequences

  • attention_mask (Optional[torch.Tensor]) – Attention mask for the sequences

  • pixel_values (Optional[torch.Tensor]) – Preprocessed pixel values of input images

  • image_grid_thw (Optional[torch.Tensor]) – Image grid dimensions (time, height, width)

  • pixel_values_videos (Optional[torch.Tensor]) – Preprocessed pixel values of input videos

  • video_grid_thw (Optional[torch.Tensor]) – Video grid dimensions (time, height, width)

Returns:

A dictionary mapping each configured head type (e.g., “preference”) to its reward scores

Return type:

Dict

Example:

# Compute reward scores from sequences and visual inputs
# Suppose `reward_model` has two heads: "preference" and "alignment"
scores = reward_model(
    sequences=input_ids,
    attention_mask=attention_mask,
    pixel_values=pixel_values,
    image_grid_thw=image_grid_thw,
    pixel_values_videos=pixel_values_videos,
    video_grid_thw=video_grid_thw,
)
preference_score = scores["preference"]
alignment_score = scores["alignment"]

gradient_checkpointing_disable()[source]

Disable gradient checkpointing to use normal forward/backward computation.

This method restores the default behavior where all intermediate activations are stored during the forward pass for use in the backward pass. This increases memory usage but reduces computation time.

Example:

# Disable gradient checkpointing
actor.gradient_checkpointing_disable()

gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})[source]

Enable gradient checkpointing to reduce memory usage during training.

Gradient checkpointing trades compute for memory by recomputing intermediate activations during the backward pass instead of storing them. This is particularly useful for training large vision-language models with limited GPU memory.

Parameters:

gradient_checkpointing_kwargs (dict) – Additional keyword arguments forwarded to the underlying gradient-checkpointing implementation (e.g., {'use_reentrant': False})

Example:

# Enable gradient checkpointing with default settings
actor.gradient_checkpointing_enable()

# Enable with custom settings
actor.gradient_checkpointing_enable({"use_reentrant": True})

print_trainable_parameters()[source]

Print information about trainable parameters in the model.

This method displays the number and percentage of trainable parameters, which is particularly useful when using parameter-efficient methods like LoRA. It helps monitor the efficiency of the fine-tuning approach.

Example:

# Print trainable parameter statistics
actor.print_trainable_parameters()
# Output: trainable params: 4,194,304 || all params: 7,241,732,096 || trainable%: 0.058
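
The percentage in the printed line is simply trainable parameters divided by total parameters; the arithmetic for the numbers in the example output above can be checked directly:

```python
# Reproduce the trainable% figure from the example output above.
trainable_params = 4_194_304
all_params = 7_241_732_096

trainable_pct = 100 * trainable_params / all_params
print(f"trainable params: {trainable_params:,} || all params: {all_params:,} "
      f"|| trainable%: {trainable_pct:.3f}")  # trainable%: 0.058
```

A sub-0.1% trainable fraction like this is typical of LoRA fine-tuning, where only the low-rank adapter weights (and the reward heads) are updated.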