LightRFT Models Design Document

Overview

The lightrft/models module provides a comprehensive framework for implementing actor and reward models in reinforcement learning scenarios, specifically designed for language model fine-tuning and human feedback integration. This document outlines the design philosophy, architecture, and implementation details of the models package.

Design Philosophy

1. Modular Architecture

The models package follows a modular design approach that separates concerns and promotes code reusability:

  • Actor Base Classes: Provide foundational functionality for different types of actors

  • Reward Model Base Classes: Provide foundational functionality for different types of reward models

  • Utility Functions: Common operations and helper functions shared across models

  • Model Patches: Specialized adaptations for specific model architectures

2. Flexibility and Extensibility

The design prioritizes flexibility to support various model types and use cases:

  • Support for both text-only and vision-language models

  • Configurable optimization strategies (LoRA, quantization, Flash Attention)

  • Adaptable to different model architectures and sizes

3. Performance Optimization

Built-in optimizations for efficient training and inference:

  • Memory-efficient implementations with gradient checkpointing

  • Support for distributed training with DeepSpeed and FSDP

  • Sample packing for improved batch processing efficiency

Architecture Components

Core Classes

1. ActorModality (Modality Definition)

Purpose: Located in models/actor_modality.py, it defines the modality types for actor models and manages the parameters supported by each modality.

Key Features:

  • Categorization: Defines model types via the ActorModality enum (e.g., LANGUAGE_ONLY, VISION_LANGUAGE, AUDIO_LANGUAGE, OMNI).

  • Parameter Mapping: The MODALITY_PARAMETERS dictionary defines special parameters supported by each modality (e.g., pixel_values for vision models, audio_values for audio models).

  • Decoupled Design: The Trainer dynamically retrieves required parameters via the get_supported_parameters interface, decoupling training logic from specific model input formats (see the sketch below).
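
A minimal sketch of this pattern; the enum members match those listed above, while the parameter names in the mapping are illustrative assumptions (the authoritative definitions live in models/actor_modality.py):

from enum import Enum

class ActorModality(Enum):
    LANGUAGE_ONLY = "language_only"
    VISION_LANGUAGE = "vision_language"
    AUDIO_LANGUAGE = "audio_language"
    OMNI = "omni"

# Extra forward() parameters each modality accepts beyond the usual
# input_ids/attention_mask (parameter names here are illustrative).
MODALITY_PARAMETERS = {
    ActorModality.LANGUAGE_ONLY: [],
    ActorModality.VISION_LANGUAGE: ["pixel_values", "image_grid_thw"],
    ActorModality.AUDIO_LANGUAGE: ["audio_values", "audio_attention_mask"],
    ActorModality.OMNI: ["pixel_values", "image_grid_thw", "audio_values", "audio_attention_mask"],
}

def get_supported_parameters(modality: ActorModality) -> list:
    # The Trainer queries this instead of hard-coding model input formats.
    return MODALITY_PARAMETERS.get(modality, [])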

2. ActorLanguage

Purpose: General-purpose actor for text-only language models.

Key Features:

  • Text-only Support: Modality is explicitly declared as ActorModality.LANGUAGE_ONLY.

  • Wide Compatibility: Supports most Causal Language Model architectures available on HuggingFace.

  • Performance Optimization: Simple LoRA injection with auto-detection and Flash Attention 2.0 integration.

3. ActorVL (Vision-Language)

Purpose: Specialized actor for vision-language models, handling images, videos, and multi-modal inputs.

Key Features:

  • Multi-modal Capability: Modality declared as ActorModality.VISION_LANGUAGE.

  • Architecture Adaptation: Supports various VLM architectures including Qwen2-VL and Qwen2.5-VL.

  • Input Handling: Manages image grids and variable-length visual sequences internally.

4. ActorAL (Audio-Language)

Purpose: Specialized actor for audio-language models (e.g., Qwen2-Audio), supporting audio capture and processing.

Key Features:

  • Multi-modal Capability: Modality declared as ActorModality.AUDIO_LANGUAGE, processing combined audio and textual inputs.

  • Architecture Adaptation: Supports various audio-language model architectures like Qwen2-Audio.

  • Optimization & MoE: Supports Mixture of Experts (MoE) models, memory optimization through gradient checkpointing, and efficiency via LoRA and Flash Attention.

5. Reward Models

Purpose: Scalar (SRM) or generative (GRM) reward models for evaluating response quality.

Key Classes:

  • ScalarRewardModelVL/AL: Scalar Reward Models (SRM) mapping multimodal inputs to scalar scores. Supports Bradley-Terry preference loss (sketched after the example below).

  • GenerativeRewardModelVL: Generative Reward Models (GRM) leveraging generation capabilities to output text-based evaluations with reasoning (CoT).

Example Behavior:

# Common Input
system_prompt = "You are a helpful visual assistant."
image = "path/to/dog.jpg" # Visual input
query = "What is shown in this image?"
response = "The image shows a cute brown dog playing in the park."

# 1. Scalar Reward Model
srm_score = srm(image=image, system_prompt=system_prompt, query=query, response=response)
print(srm_score)
# Output: tensor([0.88]) -> a single scalar reward score

# 2. Generative Reward Model
grm_output = grm(image=image, system_prompt=system_prompt, query=query, response=response)
print(grm_output)
# Output: "<reasoning>The response accurately describes the content of the image.</reasoning><score>5</score>" -> Yields textual reasoning and score

Utility Functions

1. LoRA Configuration (apply_lora_configuration)

Purpose: Centralized LoRA setup and configuration

Design Rationale:

  • Eliminates code duplication across different actor types

  • Provides consistent LoRA configuration across the framework

Design Features:

  • Model-architecture-aware target module detection

  • Configurable exclusion of specific modules (vision towers, etc.)

  • Support for various model types and architectures
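
A minimal sketch of what centralized setup can look like, built on HuggingFace PEFT; the function body and the auto-detection heuristic are illustrative assumptions, not the exact LightRFT implementation:

from peft import LoraConfig, get_peft_model

def apply_lora_configuration(model, lora_rank=16, lora_alpha=32,
                             lora_dropout=0.05, target_modules=None):
    if target_modules is None:
        # Illustrative auto-detection: adapt the attention projections of the
        # language backbone while excluding vision-tower modules.
        target_modules = sorted({
            name.split(".")[-1]
            for name, _ in model.named_modules()
            if name.split(".")[-1] in ("q_proj", "k_proj", "v_proj", "o_proj")
            and "visual" not in name
        })
    config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=target_modules,
        task_type="CAUSAL_LM",
    )
    return get_peft_model(model, config)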

2. Log Probability Computation (log_probs_from_logits)

Purpose: Efficient computation of log probabilities from model logits

Design Features:

  • Memory-optimized implementation with row-by-row processing

  • Support for different data types (float32, float16, bfloat16)

  • Flash Attention integration for improved performance

  • Automatic fallback for unsupported configurations
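
A minimal sketch of the row-by-row approach described above (illustrative; the real implementation adds the Flash Attention path and the fallbacks):

import torch
import torch.nn.functional as F

def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Process one batch row at a time so the full [batch, seq, vocab]
    # log-softmax never materializes at once.
    rows = []
    for row_logits, row_labels in zip(logits, labels):
        # Upcast to float32 for numerical stability with fp16/bf16 logits.
        log_probs = F.log_softmax(row_logits.float(), dim=-1)
        rows.append(log_probs.gather(-1, row_labels.unsqueeze(-1)).squeeze(-1))
    return torch.stack(rows)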

3. Position ID Management (reset_position_ids)

Purpose: Handle position IDs for packed sequences

Design Rationale:

  • Essential for sample packing optimization

  • Maintains correct positional encoding across concatenated sequences

  • Supports variable-length sequences in packed format
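
A minimal sketch of the idea, assuming the attention mask tags each packed sample with a distinct segment index (1, 2, 3, ...) and padding with 0; the actual LightRFT implementation may differ:

import torch

def reset_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
    # Positions restart from 0 at every segment boundary, so each packed
    # sample keeps the positional encoding it would have had on its own.
    position_ids = torch.zeros_like(attention_mask, dtype=torch.long)
    for b in range(attention_mask.size(0)):
        mask = attention_mask[b]
        for seq_id in mask.unique():
            if seq_id == 0:  # skip padding
                continue
            idx = (mask == seq_id).nonzero(as_tuple=True)[0]
            position_ids[b, idx] = torch.arange(len(idx), device=mask.device)
    return position_ids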

Model Patches

Purpose

The monkey_patch directory contains model-specific adaptations and optimizations:

  • Architecture-specific optimizations: Tailored improvements for specific model architectures

  • Generation method patches: Enhanced generation capabilities

  • Performance optimizations: Model-specific performance improvements

Implementation Details

1. Model Initialization Strategy

The models support two initialization patterns:

Pattern A: From Pretrained Path

actor = ActorLanguage(
    pretrain_or_model="model_path",   # HuggingFace repo ID or local checkpoint path
    lora_rank=16,                     # enable LoRA adapters with rank 16
    use_flash_attention_2=True        # opt into Flash Attention 2
)

Pattern B: From Existing Model

actor = ActorLanguage(
    pretrain_or_model=existing_model,  # wrap an already-instantiated model
    packing_samples=True               # enable sample packing
)

Design Rationale:

  • Supports both training from scratch and fine-tuning existing models

  • Enables flexible model deployment scenarios

  • Maintains backward compatibility with existing workflows

2. Generation and Forward Pass Design

Generation Method

  • Input Processing: Handles various input formats and parameters

  • Model Generation: Delegates to underlying model with configured parameters

  • Post-processing: Creates attention masks and action masks for RL training

Forward Method

  • Position ID Handling: Manages positional encoding for different sequence formats

  • Log Probability Computation: Efficiently computes action probabilities

  • Packed Sequence Support: Handles multiple sequences in a single batch
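
Putting these pieces together, a minimal sketch of the forward path, assuming the utility functions sketched earlier and a helper signature of our own invention (the real actor classes define their own):

import torch

def actor_forward(model, sequences, num_actions, attention_mask, packing_samples=False):
    if packing_samples:
        # Restart positions at every packed-sequence boundary.
        position_ids = reset_position_ids(attention_mask)
    else:
        # Standard causal positions derived from the attention mask.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
    output = model(sequences, attention_mask=attention_mask, position_ids=position_ids)
    # Logits at position t predict token t+1, so shift before gathering.
    log_probs = log_probs_from_logits(output.logits[:, :-1, :], sequences[:, 1:])
    # Keep only the response ("action") tokens for the RL objective.
    return log_probs[:, -num_actions:]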

3. Memory and Performance Optimizations

Gradient Checkpointing

  • Optional memory-saving technique

  • Configurable via gradient_checkpointing_enable/disable

  • Balances memory usage with computational overhead
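
For example, using the toggles named above (call sites are illustrative; exact placement depends on the training loop):

actor.gradient_checkpointing_enable()    # trade extra compute for lower activation memory
loss = training_step(actor, batch)       # hypothetical training step
loss.backward()
actor.gradient_checkpointing_disable()   # restore full-speed forward passes, e.g. for rollout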

Sample Packing

  • Concatenates multiple sequences for efficient batch processing

  • Maintains correct attention patterns through position ID management

  • Significantly improves training throughput for variable-length sequences
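
To make the mechanism concrete, a small illustration (values chosen for this example) of how segment IDs in the attention mask drive position resets:

import torch

# Three variable-length samples packed into a single row; distinct segment
# IDs mark sequence boundaries, 0 marks padding.
packed_attention_mask = torch.tensor([[1, 1, 1, 2, 2, 3, 3, 3, 3, 0]])
# reset_position_ids(packed_attention_mask), as sketched earlier, yields:
# tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3, 0]])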

Configuration and Customization

1. LoRA Configuration

  • Rank and Alpha: Configurable LoRA dimensions and scaling

  • Target Modules: Automatic detection with manual override capability

  • Dropout: Configurable regularization strength

2. Attention Mechanisms

  • Flash Attention 2.0: Optional high-performance attention implementation

  • Fallback Support: Automatic fallback to standard attention when needed

  • Architecture Compatibility: Works across different model architectures
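
A minimal sketch of this opt-in-with-fallback pattern using the standard transformers loading API (the attn_implementation argument is HuggingFace's; LightRFT's internal wiring may differ):

import torch
from transformers import AutoModelForCausalLM

try:
    model = AutoModelForCausalLM.from_pretrained(
        "model_path",
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,  # Flash Attention 2 requires fp16/bf16
    )
except (ImportError, ValueError):
    # Fall back to the default attention implementation when the
    # flash-attn kernels are unavailable.
    model = AutoModelForCausalLM.from_pretrained("model_path", torch_dtype=torch.bfloat16)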

3. Device and Distributed Training

  • Device Mapping: Flexible device placement for multi-GPU setups

  • DeepSpeed Integration: Native support for DeepSpeed ZeRO optimization

  • FSDP Compatibility: Support for Fully Sharded Data Parallel training

Error Handling and Robustness

1. Graceful Degradation

  • Automatic fallback for unsupported features

  • Clear error messages for configuration issues

  • Compatibility checks for model requirements

2. Validation and Assertions

  • Input validation for critical parameters

  • Assertion checks for incompatible configurations

  • Runtime validation of model compatibility
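
A hypothetical example of this kind of assertion-based validation (the concrete checks in LightRFT may differ):

def validate_actor_config(lora_rank: int, packing_samples: bool, use_flash_attention_2: bool):
    assert lora_rank >= 0, f"lora_rank must be non-negative, got {lora_rank}"
    if packing_samples:
        # Hypothetical constraint: packed sequences rely on attention kernels
        # that honor custom position IDs, so require Flash Attention here.
        assert use_flash_attention_2, "packing_samples requires use_flash_attention_2=True"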

Conclusion

The LightRFT models package provides a robust, flexible, and efficient foundation for reinforcement learning with language models. The modular design ensures maintainability and extensibility while the comprehensive optimization support enables efficient training and deployment across various hardware configurations and model architectures.

The design prioritizes:

  • Simplicity: Easy to use and understand

  • Flexibility: Adaptable to various use cases

  • Performance: Optimized for efficiency

  • Reliability: Robust error handling and validation

  • Extensibility: Easy to add new features and model types

This architecture serves as a solid foundation for current needs while providing a clear path for future enhancements and adaptations.