LightRFT Models Design Document

Overview

The lightrft/models module provides a framework for implementing actor models in reinforcement learning, designed for language model fine-tuning and human-feedback integration. This document outlines the design philosophy, architecture, and implementation details of the models package.

Design Philosophy

1. Modular Architecture

The models package follows a modular design approach that separates concerns and promotes code reusability:

  • Actor Base Classes: Provide foundational functionality for different types of actors

  • Utility Functions: Common operations and helper functions shared across models

  • Model Patches: Specialized adaptations for specific model architectures

2. Flexibility and Extensibility

The design prioritizes flexibility to support various model types and use cases:

  • Support for both text-only and vision-language models

  • Configurable optimization strategies (LoRA, quantization, Flash Attention)

  • Adaptable to different model architectures and sizes

3. Performance Optimization

Built-in optimizations for efficient training and inference:

  • Memory-efficient implementations with gradient checkpointing

  • Support for distributed training with DeepSpeed and FSDP

  • Sample packing for improved batch processing efficiency

Architecture Components

Core Classes

1. ActorText

Purpose: General-purpose actor for text-only language models

Key Features:

  • Supports various causal language model architectures

  • Configurable LoRA adaptation with auto-detection of target modules

  • Flash Attention 2.0 integration for improved performance

Design Decisions:

  • Generic implementation that works with any HuggingFace causal LM

  • Automatic detection of linear modules for LoRA injection

  • Flexible generation parameters with post-processing for RL training

2. ActorVL (Vision-Language)

Purpose: Specialized actor for vision-language models

Key Features:

  • Multi-modal input processing (text + vision)

  • Support for various VL architectures (LLaVA, Qwen2-VL, Qwen2.5-VL, Keye-VL, Qwen3-VL, etc.)

  • Image grid processing for different aspect ratios

  • Specialized handling for different model types

Design Decisions:

  • Separate class to handle the complexity of multi-modal inputs

  • Model-specific adaptations for different VL architectures

  • Flexible pixel value and grid dimension handling

Utility Functions

1. LoRA Configuration (apply_lora_configuration)

Purpose: Centralized LoRA setup and configuration

Design Rationale:

  • Eliminates code duplication across different actor types

  • Provides consistent LoRA configuration across the framework

2. Log Probability Computation (log_probs_from_logits)

Purpose: Efficient computation of log probabilities from model logits

Design Features:

  • Memory-optimized implementation with row-by-row processing

  • Support for different data types (float32, float16, bfloat16)

  • Flash Attention integration for improved performance

  • Automatic fallback for unsupported configurations
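The core of this computation can be sketched in plain Python. This is a dependency-free stand-in for the tensor implementation, shown only to illustrate the row-by-row design; the actual function operates on model logit tensors.

```python
import math

def log_probs_from_logits(logits, labels):
    """Return the log-probability the model assigned to each observed label.

    logits: list of rows, each a list of vocabulary scores
    labels: list of token ids, one per row

    Processing one row at a time mirrors the memory-optimized design:
    the full log-softmax matrix is never materialized at once.
    """
    out = []
    for row, label in zip(logits, labels):
        # log-softmax with the max-subtraction trick for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        out.append(row[label] - log_z)
    return out
```

The same row-wise loop is what keeps peak memory flat when the vocabulary is large: only one row of the log-softmax exists at any time.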

3. Position ID Management (reset_position_ids)

Purpose: Handle position IDs for packed sequences

Design Rationale:

  • Essential for sample packing optimization

  • Maintains correct positional encoding across concatenated sequences

  • Supports variable-length sequences in packed format
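The idea can be sketched as follows. The sketch assumes a segment-id style attention mask (each packed sequence marked by a distinct positive value, padding by 0); the library's actual mask convention may differ.

```python
def reset_position_ids(attention_mask):
    """Rebuild position ids for packed sequences.

    attention_mask: list of segment ids per token, e.g. [1, 1, 1, 2, 2, 0],
    where each positive value marks one packed sequence and 0 is padding.
    Positions restart from 0 at every sequence boundary so that each
    concatenated sequence keeps its own positional encoding.
    """
    position_ids = []
    prev_seg, pos = None, 0
    for seg in attention_mask:
        if seg == 0:  # padding: position is unused, fill with 0
            position_ids.append(0)
            prev_seg = None
            continue
        pos = pos + 1 if seg == prev_seg else 0
        position_ids.append(pos)
        prev_seg = seg
    return position_ids
```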

4. Target Module Detection

Purpose: Identify which linear modules should receive LoRA adapters

Design Features:

  • Model-architecture-aware detection

  • Configurable exclusion of specific modules (vision towers, etc.)

  • Support for various model types and architectures
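A minimal sketch of this detection logic, using a plain name-to-type mapping as a stand-in for iterating a model's named modules (the function name and exclusion default are illustrative, not the library's actual API):

```python
def find_lora_target_modules(named_module_types, exclude_prefixes=("vision_tower",)):
    """Collect names of linear layers eligible for LoRA injection,
    skipping excluded submodules such as vision towers.

    named_module_types: mapping of dotted module name -> type name string,
    a stand-in for iterating model.named_modules() in a real model.
    """
    targets = set()
    for name, type_name in named_module_types.items():
        if any(name.startswith(p) for p in exclude_prefixes):
            continue  # e.g. keep the vision tower frozen, no adapters there
        if type_name == "Linear":
            # LoRA targets are addressed by their leaf name (e.g. "q_proj")
            targets.add(name.split(".")[-1])
    return sorted(targets)
```

Excluding the vision tower keeps adaptation focused on the language backbone, which is the common choice when fine-tuning VL models with LoRA.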

Model Patches

Purpose

The monkey_patch directory contains model-specific adaptations and optimizations:

  • Architecture-specific optimizations: Tailored improvements for specific model architectures

  • Generation method patches: Enhanced generation capabilities

  • Performance optimizations: Model-specific performance improvements

Implementation Details

1. Model Initialization Strategy

The models support two initialization patterns:

Pattern A: From Pretrained Path

```python
actor = ActorText(
    pretrain_or_model="model_path",
    lora_rank=16,
    use_flash_attention_2=True
)
```

Pattern B: From Existing Model

```python
actor = ActorText(
    pretrain_or_model=existing_model,
    packing_samples=True
)
```

Design Rationale:

  • Supports both training from scratch and fine-tuning existing models

  • Enables flexible model deployment scenarios

  • Maintains backward compatibility with existing workflows
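The dual-pattern dispatch amounts to a type check on the first argument. The toy class below illustrates the shape of that dispatch; `load_pretrained` is a hypothetical stand-in for the HuggingFace loader, not the actual constructor signature.

```python
class ActorSketch:
    """Minimal dispatch mirroring the two initialization patterns:
    a string is treated as a pretrained path to load from; anything
    else is treated as an already-constructed model to wrap directly.
    """
    def __init__(self, pretrain_or_model, load_pretrained=None):
        if isinstance(pretrain_or_model, str):
            # Pattern A: load weights from a pretrained path
            loader = load_pretrained or (lambda path: {"path": path})
            self.model = loader(pretrain_or_model)
        else:
            # Pattern B: reuse the existing model object as-is
            self.model = pretrain_or_model
```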

2. Generation and Forward Pass Design

Generation Method

  • Input Processing: Handles various input formats and parameters

  • Model Generation: Delegates to underlying model with configured parameters

  • Post-processing: Creates attention masks and action masks for RL training
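The post-processing step can be sketched for a single sequence. The mask conventions here (attention mask covering all real tokens, action mask covering generated tokens up to and including the first EOS) are illustrative assumptions about how RL credit assignment is typically scoped, not the library's exact implementation.

```python
def build_masks(sequence, prompt_len, eos_token_id, pad_token_id):
    """Post-process one generated sequence into the masks RL training needs.

    attention_mask: 1 for every real (non-pad) token.
    action_mask: 1 only for generated tokens (after the prompt) up to and
    including the first EOS, i.e. the tokens the policy is credited for.
    """
    attention_mask, action_mask = [], []
    seen_eos = False
    for i, tok in enumerate(sequence):
        attention_mask.append(0 if tok == pad_token_id else 1)
        is_action = i >= prompt_len and not seen_eos and tok != pad_token_id
        action_mask.append(1 if is_action else 0)
        if i >= prompt_len and tok == eos_token_id:
            seen_eos = True
    return attention_mask, action_mask
```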

Forward Method

  • Position ID Handling: Manages positional encoding for different sequence formats

  • Log Probability Computation: Efficiently computes action probabilities

  • Packed Sequence Support: Handles multiple sequences in a single batch

3. Memory and Performance Optimizations

Gradient Checkpointing

  • Optional memory-saving technique

  • Configurable via gradient_checkpointing_enable/disable

  • Balances memory usage with computational overhead

Sample Packing

  • Concatenates multiple sequences for efficient batch processing

  • Maintains correct attention patterns through position ID management

  • Significantly improves training throughput for variable-length sequences
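The throughput gain is easy to quantify: a padded batch pays for max-length × batch-size tokens, while a packed stream pays only for the real ones. A small arithmetic sketch (illustrative helper, not a library function):

```python
def padding_overhead(lengths):
    """Compare tokens processed when padding every sequence to the batch
    maximum versus packing them back-to-back into one stream.

    Returns (padded_total, packed_total).
    """
    padded = max(lengths) * len(lengths)  # every row padded to the longest
    packed = sum(lengths)                 # only real tokens are processed
    return padded, packed
```

For a batch of lengths [100, 20, 30, 50], padding processes 400 token slots while packing processes 200, so packing halves the wasted compute in this example; the skew grows with length variance.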

Configuration and Customization

1. LoRA Configuration

  • Rank and Alpha: Configurable LoRA dimensions and scaling

  • Target Modules: Automatic detection with manual override capability

  • Dropout: Configurable regularization strength

2. Attention Mechanisms

  • Flash Attention 2.0: Optional high-performance attention implementation

  • Fallback Support: Automatic fallback to standard attention when needed

  • Architecture Compatibility: Works across different model architectures

3. Device and Distributed Training

  • Device Mapping: Flexible device placement for multi-GPU setups

  • DeepSpeed Integration: Native support for DeepSpeed ZeRO optimization

  • FSDP Compatibility: Support for Fully Sharded Data Parallel training

Error Handling and Robustness

1. Graceful Degradation

  • Automatic fallback for unsupported features

  • Clear error messages for configuration issues

  • Compatibility checks for model requirements

2. Validation and Assertions

  • Input validation for critical parameters

  • Assertion checks for incompatible configurations

  • Runtime validation of model compatibility

Conclusion

The LightRFT models package provides a robust, flexible, and efficient foundation for reinforcement learning with language models. The modular design ensures maintainability and extensibility while the comprehensive optimization support enables efficient training and deployment across various hardware configurations and model architectures.

The design prioritizes:

  • Simplicity: Easy to use and understand

  • Flexibility: Adaptable to various use cases

  • Performance: Optimized for efficiency

  • Reliability: Robust error handling and validation

  • Extensibility: Easy to add new features and model types

This architecture serves as a solid foundation for current needs while providing a clear path for future enhancements and adaptations.