# LightRFT Strategy Design Philosophy ## Overview The LightRFT strategy module provides a unified interface for distributed training strategies, enabling seamless switching between different distributed training backends while maintaining a consistent API. This document outlines the design principles, architecture, and usage patterns of the strategy module. ## Core Design Principles ### 1. Abstraction and Unification **Principle**: Provide a unified interface that abstracts away the complexities of different distributed training frameworks. **Implementation**: - All strategies inherit from [`StrategyBase`](lightrft/strategy/strategy_base.py:63) - Common methods like [`backward()`](lightrft/strategy/strategy_base.py:257), [`optimizer_step()`](lightrft/strategy/strategy_base.py:272), and [`save_ckpt()`](lightrft/strategy/strategy_base.py:340) have consistent signatures - Strategy-specific implementations are encapsulated within concrete strategy classes ### 2. Configuration-Driven Design **Principle**: Use typed configuration objects instead of dynamic attribute access for better type safety and code clarity. **Implementation**: - [`StrategyConfig`](lightrft/strategy/config.py:16) dataclass provides typed access to all configuration parameters - Eliminates the need for `getattr(args, "parameter", default)` pattern - Enables IDE autocompletion and static type checking ### 3. Backward Compatibility **Principle**: Maintain compatibility with existing code while introducing improvements. **Implementation**: - [`StrategyConfig.from_args()`](lightrft/strategy/config.py:67) method extracts parameters from legacy argument objects - Original `args` object is preserved for compatibility - [`get_extra_arg()`](lightrft/strategy/config.py:130) method provides access to non-standard parameters ### 4. Testability **Principle**: Enable comprehensive testing without requiring distributed environments. **Implementation**: - [`FakeStrategy`](lightrft/strategy/fake_strategy.py:14) provides a drop-in replacement for testing - All strategy methods have mock implementations for single-process testing - Unit tests verify both functionality and API consistency ## Architecture ### Strategy Hierarchy ``` StrategyBase (ABC) ├── DeepspeedStrategy ├── FSDPV2Strategy └── FakeStrategy (for testing) ``` ### Key Components #### 1. Strategy Factory The [`get_strategy()`](lightrft/strategy/strategy.py:20) function serves as the entry point, automatically selecting the appropriate strategy based on configuration: ```python from lightrft.strategy import get_strategy # Automatically selects DeepSpeed or FSDP based on args.fsdp strategy = get_strategy(args) ``` #### 2. Configuration Management The [`StrategyConfig`](lightrft/strategy/config.py:16) class centralizes all configuration parameters: ```python from lightrft.strategy.config import StrategyConfig config = StrategyConfig.from_args(args) # Access parameters with type safety learning_rate = config.actor_learning_rate use_bf16 = config.bf16 ``` #### 3. Common Interface All strategies implement the same core interface: ```python class StrategyBase(ABC): def setup_distributed(self, timeout=None) -> None: ... def create_optimizer(self, model, **kwargs) -> Optimizer: ... def prepare(self, *models, is_rlhf=False) -> Any: ... def backward(self, loss, model, optimizer, **kwargs) -> None: ... def optimizer_step(self, optimizer, model, scheduler, **kwargs) -> None: ... def save_ckpt(self, model, save_dir, **kwargs) -> None: ... def load_ckpt(self, model, load_dir, **kwargs) -> Any: ... ``` ## Usage Patterns ### 1. Basic Usage ```python from lightrft.strategy import get_strategy # Initialize strategy strategy = get_strategy(args) # Prepare models and optimizers actor, critic, reward_models, initial_model = strategy.prepare_models_and_optimizers( actor, critic, reward_models, initial_model, args, max_steps ) # Training loop for batch in dataloader: loss = compute_loss(batch) strategy.backward(loss, actor, actor_optimizer) strategy.optimizer_step(actor_optimizer, actor, actor_scheduler) ``` ### 2. Configuration-Driven Usage ```python from lightrft.strategy.config import StrategyConfig # Create configuration config = StrategyConfig( seed=42, max_norm=1.0, micro_train_batch_size=4, train_batch_size=32, bf16=True, zero_stage=2 ) # Use configuration to create strategy strategy = get_strategy(config) ``` ### 3. Testing with FakeStrategy ```python from lightrft.strategy import get_fake_strategy # Use fake strategy for testing strategy = get_fake_strategy() # All operations work without distributed environment strategy.setup_distributed() strategy.backward(loss, model, optimizer) strategy.save_ckpt(model, "checkpoints") ``` ## Design Benefits ### 1. Improved Type Safety **Before** (using getattr): ```python seed = getattr(args, "seed", 42) # Type: Any max_norm = getattr(args, "max_norm", 1.0) # Type: Any ``` **After** (using StrategyConfig): ```python config = StrategyConfig.from_args(args) seed = config.seed # Type: int max_norm = config.max_norm # Type: float ``` ### 2. Better Code Organization - Configuration parameters are explicitly defined in [`StrategyConfig`](lightrft/strategy/config.py:16) - Strategy-specific logic is encapsulated in concrete strategy classes - Common functionality is implemented in [`StrategyBase`](lightrft/strategy/strategy_base.py:63) ### 3. Enhanced Testability - [`FakeStrategy`](lightrft/strategy/fake_strategy.py:14) enables testing without distributed setup - Unit tests can verify all strategy functionality - Mock implementations ensure consistent behavior ### 4. Future Extensibility - New strategies can be added by implementing the [`StrategyBase`](lightrft/strategy/strategy_base.py:63) interface - Configuration can be extended without breaking existing code - The factory pattern makes it easy to add new strategy types ## Best Practices ### 1. Configuration Management - Use [`StrategyConfig`](lightrft/strategy/config.py:16) for all parameter access - Avoid direct `getattr` calls on argument objects - Use [`get_extra_arg()`](lightrft/strategy/config.py:130) for non-standard parameters ### 2. Strategy Selection - Use [`get_strategy()`](lightrft/strategy/strategy.py:20) factory function for strategy creation - Let the factory determine the appropriate strategy based on configuration - Use [`FakeStrategy`](lightrft/strategy/fake_strategy.py:14) for testing and development ### 3. Error Handling - Strategies should provide clear error messages for unsupported operations - Use the strategy's [`print()`](lightrft/strategy/strategy_base.py:450) method for logging - Implement proper cleanup in context managers ### 4. Testing - Use [`FakeStrategy`](lightrft/strategy/fake_strategy.py:14) for unit tests - Test both strategy-specific and common functionality - Verify that all strategies implement the required interface ## Conclusion The LightRFT strategy module represents a paradigm shift in distributed training abstraction design, establishing a new standard for flexibility, type safety, and developer experience in RLHF training systems. By embracing a holistic design philosophy centered on abstraction unification and configuration-driven development, the module achieves unprecedented levels of interoperability across diverse distributed training frameworks. ### Foundational Design Innovations **Unified Abstraction Architecture**: The module introduces a sophisticated abstraction layer that seamlessly unifies DeepSpeed, FSDP, and other distributed training backends under a single, consistent API. This architectural breakthrough eliminates the traditional fragmentation between different distributed training approaches, enabling developers to switch between strategies with zero code changes. **Type-Safe Configuration Ecosystem**: By replacing dynamic attribute access patterns with strongly-typed configuration objects through [`StrategyConfig`](lightrft/strategy/config.py:16), the module establishes a new paradigm for configuration management that eliminates runtime errors, enables IDE autocompletion, and provides compile-time safety guarantees. **Intelligent Strategy Selection**: The factory-based approach in [`get_strategy()`](lightrft/strategy/strategy.py:21) implements automatic strategy selection based on configuration parameters, abstracting away the complexity of backend selection while maintaining full control through explicit configuration. ### Advanced Capabilities **Multi-Modal Inference Engine Integration**: The strategy module pioneers unified support for both text-only and multi-modal generation through sophisticated engine management, supporting vLLM and SGLang backends with consistent interfaces for diverse AI workloads. **Comprehensive Testing Infrastructure**: Through [`FakeStrategy`](lightrft/strategy/fake_strategy.py:24), the module enables full-featured testing of distributed training workflows without requiring complex distributed environments, dramatically improving development velocity and test coverage. **Memory-Efficient Training Orchestration**: Advanced features like inference engine sleep/wake cycles, gradient accumulation optimizations, and memory-aware checkpointing demonstrate the module's commitment to resource efficiency at scale. ### Industry Impact This design philosophy has established new best practices for distributed training abstraction, demonstrating that type safety, developer experience, and performance optimization are not mutually exclusive goals. The module's extensible architecture ensures long-term viability, providing a solid foundation for integrating future distributed training technologies while maintaining backward compatibility with existing codebases. The LightRFT strategy module stands as a reference implementation for next-generation distributed training systems, proving that thoughtful abstraction design can simultaneously achieve simplicity for beginners and powerful customization for experts across the entire spectrum of RLHF training requirements.