lightrft.models.grm_vl

class lightrft.models.grm_vl.GenerativeRewardModelVL(*args: Any, **kwargs: Any)[source]

Bases: Module

Generative reward model for reinforcement learning applications.

This class wraps a pretrained vision-language model to serve as a generative reward model, which is capable of processing both textual and visual inputs and generating textual outputs.

Parameters:
  • pretrain_or_model (Union[str, nn.Module]) – Either a string path to a pretrained model or a model instance

  • use_flash_attention_2 (bool) – Whether to utilize Flash Attention 2.0 for improved performance

  • bf16 (bool) – Enable bfloat16 precision for model computations

  • lora_rank (int) – Rank for LoRA adaptation (0 disables LoRA)

  • lora_alpha (int) – Alpha parameter for LoRA scaling

  • lora_dropout (float) – Dropout rate for LoRA layers

  • target_modules (Optional[list]) – List of target modules for applying LoRA (auto-detected if None)

  • ds_config (Optional[dict]) – Configuration for DeepSpeed distributed training

  • device_map (Optional[dict]) – Device mapping for loading the model onto specific devices
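
The LoRA hyperparameters above interact in the standard way: the rank controls how many parameters the adapter adds, and alpha sets the scale of the low-rank update relative to the rank. A minimal, framework-free sketch of that arithmetic for a single linear layer (the layer shape and helper names below are illustrative assumptions, not lightrft internals):

```python
def lora_extra_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters added by a LoRA adapter on a d_in x d_out linear layer.

    LoRA factorizes the weight update as B @ A with A of shape
    (rank, d_in) and B of shape (d_out, rank), so the adapter adds
    rank * (d_in + d_out) trainable parameters. rank = 0 adds nothing,
    matching "0 disables LoRA" above.
    """
    return rank * (d_in + d_out)


def lora_scaling(lora_alpha: int, rank: int) -> float:
    """LoRA scales the low-rank update by alpha / rank before adding
    it to the frozen base weight."""
    return lora_alpha / rank


# Illustrative shapes (assumed): one 4096 x 4096 attention projection
# adapted with lora_rank=8, lora_alpha=16.
extra = lora_extra_params(4096, 4096, rank=8)   # 65,536 extra params
scale = lora_scaling(lora_alpha=16, rank=8)     # update scaled by 2.0
print(extra, scale)
```

With rank 8 the adapter adds only 65,536 parameters per such layer, which is why the trainable fraction stays tiny even for multi-billion-parameter base models.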

forward(sequences: torch.LongTensor, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, pixel_values: torch.Tensor = None, image_grid_thw: torch.Tensor = None, pixel_values_videos: torch.Tensor = None, video_grid_thw: torch.Tensor = None, return_outputs: bool = True) transformers.modeling_outputs.ModelOutput | torch.Tensor[source]

The forward pass takes token sequences and visual content (if present) as input and returns the model output.

Parameters:
  • sequences (torch.LongTensor) – Input token sequences

  • attention_mask (Optional[torch.Tensor]) – Attention mask for the sequences

  • labels (Optional[torch.Tensor]) – Target labels for computing loss

  • pixel_values (torch.Tensor) – Preprocessed pixel values of input images

  • image_grid_thw (torch.Tensor) – Image grid dimensions (time, height, width)

  • pixel_values_videos (torch.Tensor) – Preprocessed pixel values of input videos

  • video_grid_thw (torch.Tensor) – Video grid dimensions (time, height, width)

  • return_outputs (bool) – Whether to return the full model output instead of only the logits

Returns:

Model output or logits based on return_outputs flag

Return type:

Union[ModelOutput, torch.Tensor]

Example:

# Compute logits from sequences and visual inputs
logits = reward_model(
    sequences=input_ids,
    attention_mask=attention_mask,
    pixel_values=pixel_values,
    image_grid_thw=image_grid_thw,
    pixel_values_videos=pixel_values_videos,
    video_grid_thw=video_grid_thw,
    return_outputs=False
)

# Get full model output including loss
outputs = reward_model(
    sequences=input_ids,
    attention_mask=attention_mask,
    labels=labels,
    pixel_values=pixel_values,
    image_grid_thw=image_grid_thw,
    pixel_values_videos=pixel_values_videos,
    video_grid_thw=video_grid_thw,
    return_outputs=True
)

generate(input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None, pixel_values: torch.Tensor = None, image_grid_thw: torch.Tensor = None, pixel_values_videos: torch.Tensor = None, video_grid_thw: torch.Tensor = None, max_new_tokens: int | None = None, do_sample: bool = True, temperature: float = 1.0, top_k: int | None = None, top_p: float | None = None, num_beams: int = 1, **kwargs) transformers.modeling_outputs.ModelOutput | torch.LongTensor

Generate text sequences based on input text and visual information.

Parameters:
  • input_ids (torch.Tensor) – Input token IDs representing the text prompt

  • attention_mask (Optional[torch.Tensor]) – Attention mask for the sequences

  • pixel_values (torch.Tensor) – Preprocessed pixel values of input images

  • image_grid_thw (torch.Tensor) – Image grid dimensions (time, height, width)

  • pixel_values_videos (torch.Tensor) – Preprocessed pixel values of input videos

  • video_grid_thw (torch.Tensor) – Video grid dimensions (time, height, width)

  • max_new_tokens (Optional[int]) – Maximum number of new tokens to generate

  • do_sample (bool) – Whether to use sampling for generation

  • temperature (float) – The value used to modulate the next token probabilities

  • top_k (Optional[int]) – The number of highest probability vocabulary tokens to keep for top-k-filtering

  • top_p (Optional[float]) – If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation

  • num_beams (int) – Number of beams for beam search

  • kwargs (dict) – Additional generation parameters

Returns:

Generated sequences

Return type:

Union[ModelOutput, torch.LongTensor]
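
The sampling parameters above (temperature, top_k, top_p) follow the usual semantics: temperature rescales logits before the softmax, top-k keeps only the k most probable tokens, and top-p keeps the smallest set of tokens whose cumulative probability reaches top_p. A rough, dependency-free sketch of how they shape the next-token distribution (the function name and toy logits are illustrative, not lightrft internals):

```python
import math


def next_token_probs(logits, temperature=1.0, top_k=None, top_p=None):
    """Turn raw logits into a filtered sampling distribution."""
    # Temperature scaling, then a numerically stable softmax.
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(l - m) for l in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Token indices sorted by probability, highest first.
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    keep = set(order)
    if top_k is not None:
        # Top-k filtering: keep only the k highest-probability tokens.
        keep &= set(order[:top_k])
    if top_p is not None:
        # Nucleus (top-p) filtering: smallest prefix of the sorted
        # tokens whose cumulative probability reaches top_p.
        cum, nucleus = 0.0, set()
        for i in order:
            nucleus.add(i)
            cum += probs[i]
            if cum >= top_p:
                break
        keep &= nucleus

    # Zero out filtered tokens and renormalize.
    filtered = [p if i in keep else 0.0 for i, p in enumerate(probs)]
    z = sum(filtered)
    return [p / z for p in filtered]


# With top_k=2, only the two most probable tokens can be sampled.
probs = next_token_probs([2.0, 1.0, 0.1], temperature=1.0, top_k=2)
print(probs)
```

With do_sample=False or num_beams > 1, generation instead follows greedy or beam search and these filters do not apply in the same way.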

gradient_checkpointing_disable()[source]

Disable gradient checkpointing to use normal forward/backward computation.

This method restores the default behavior where all intermediate activations are stored during the forward pass for use in the backward pass. This increases memory usage but reduces computation time.

Example:

# Disable gradient checkpointing
actor.gradient_checkpointing_disable()

gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})[source]

Enable gradient checkpointing to reduce memory usage during training.

Gradient checkpointing trades compute for memory by recomputing intermediate activations during the backward pass instead of storing them. This is particularly useful for training large vision-language models with limited GPU memory.

Parameters:

gradient_checkpointing_kwargs (dict) – Additional arguments for gradient checkpointing

Example:

# Enable gradient checkpointing with default settings
actor.gradient_checkpointing_enable()

# Enable with custom settings
actor.gradient_checkpointing_enable({"use_reentrant": True})

print_trainable_parameters()[source]

Print information about trainable parameters in the model.

This method displays the number and percentage of trainable parameters, which is particularly useful when using parameter-efficient methods like LoRA. It helps monitor the efficiency of the fine-tuning approach.

Example:

# Print trainable parameter statistics
actor.print_trainable_parameters()
# Output: trainable params: 4,194,304 || all params: 7,241,732,096 || trainable%: 0.058
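
The summary line amounts to a simple ratio over the model's parameter counts. A stand-alone sketch of that computation (the helper name is hypothetical; the counts are the figures from the example output above):

```python
def format_trainable_report(trainable: int, total: int) -> str:
    """Reproduce the 'trainable params' summary line: raw counts plus
    the percentage of parameters that receive gradient updates."""
    pct = 100 * trainable / total
    return (f"trainable params: {trainable:,} || "
            f"all params: {total:,} || trainable%: {pct:.3f}")


# Counts from the example output above: a LoRA fine-tune where only
# ~0.06% of a ~7B-parameter model is trainable.
print(format_trainable_report(4_194_304, 7_241_732_096))
# trainable params: 4,194,304 || all params: 7,241,732,096 || trainable%: 0.058
```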