LightRFT Strategy Design Philosophy
Overview
The LightRFT strategy module provides a unified interface for distributed training strategies, enabling seamless switching between different distributed training backends while maintaining a consistent API. This document outlines the design principles, architecture, and usage patterns of the strategy module.
Core Design Principles
1. Abstraction and Unification
Principle: Provide a unified interface that abstracts away the complexities of different distributed training frameworks.
Implementation:
- All strategies inherit from StrategyBase
- Common methods like backward(), optimizer_step(), and save_ckpt() have consistent signatures (see the sketch below)
- Strategy-specific implementations are encapsulated within concrete strategy classes
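To make the effect of this unification concrete, the snippet below is a minimal sketch (not library code) of a training step written purely against the shared interface; strategy, model, optimizer, scheduler, and loss_fn are placeholders.

# Minimal sketch (not library code): the same training step works with any
# concrete strategy because backward() and optimizer_step() share signatures.
def train_step(strategy, model, optimizer, scheduler, batch, loss_fn):
    loss = loss_fn(model, batch)                          # framework-agnostic loss
    strategy.backward(loss, model, optimizer)             # backend-specific backward
    strategy.optimizer_step(optimizer, model, scheduler)  # backend-specific update
    return loss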
2. Configuration-Driven Design
Principle: Use typed configuration objects instead of dynamic attribute access for better type safety and code clarity.
Implementation:
- StrategyConfig dataclass provides typed access to all configuration parameters (a simplified sketch follows below)
- Eliminates the need for the getattr(args, "parameter", default) pattern
- Enables IDE autocompletion and static type checking
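As a rough illustration of the idea, a typed configuration can be written as a dataclass; the sketch below is not the real StrategyConfig definition, and its fields and defaults simply mirror the examples used elsewhere in this document.

from dataclasses import dataclass

# Simplified sketch of a typed configuration object (not the real class);
# the fields mirror the examples used elsewhere in this document.
@dataclass
class StrategyConfigSketch:
    seed: int = 42
    max_norm: float = 1.0
    micro_train_batch_size: int = 4
    train_batch_size: int = 32
    bf16: bool = True
    zero_stage: int = 2

config = StrategyConfigSketch(bf16=False)
print(config.seed, config.bf16)  # typed access; an IDE can autocomplete these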
3. Backward Compatibility
Principle: Maintain compatibility with existing code while introducing improvements.
Implementation:
- StrategyConfig.from_args() method extracts parameters from legacy argument objects
- Original args object is preserved for compatibility
- get_extra_arg() method provides access to non-standard parameters (sketched below)
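The pattern behind these methods can be sketched as follows; this is an illustration of the idea rather than the actual implementation, and the field names are placeholders.

import argparse
from dataclasses import dataclass, fields
from typing import Any

# Illustrative sketch of the compatibility layer: known fields are copied into
# the typed config, and the original args object is kept around for the rest.
@dataclass
class CompatConfigSketch:
    seed: int = 42
    max_norm: float = 1.0

    @classmethod
    def from_args(cls, args: argparse.Namespace) -> "CompatConfigSketch":
        known = {f.name for f in fields(cls)}
        cfg = cls(**{k: v for k, v in vars(args).items() if k in known})
        cfg._args = args  # preserve the legacy args object for compatibility
        return cfg

    def get_extra_arg(self, name: str, default: Any = None) -> Any:
        # Non-standard parameters fall back to the original args object.
        return getattr(getattr(self, "_args", None), name, default)

args = argparse.Namespace(seed=7, custom_flag=True)
cfg = CompatConfigSketch.from_args(args)
print(cfg.seed, cfg.get_extra_arg("custom_flag"))  # 7 True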
4. Testability
Principle: Enable comprehensive testing without requiring distributed environments.
Implementation:
- FakeStrategy provides a drop-in replacement for testing
- All strategy methods have mock implementations for single-process testing
- Unit tests verify both functionality and API consistency
Architecture
Strategy Hierarchy
StrategyBase (ABC)
├── DeepspeedStrategy
├── FSDPV2Strategy
└── FakeStrategy (for testing)
Key Components
1. Strategy Factory
The get_strategy() function serves as the entry point, automatically selecting the appropriate strategy based on configuration:
from lightrft.strategy import get_strategy
# Automatically selects DeepSpeed or FSDP based on args.fsdp
strategy = get_strategy(args)
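Internally, the factory boils down to choosing a concrete class from the configuration. The sketch below is a simplified illustration of that idea, not the actual implementation; the constructor arguments and the import of the concrete classes are assumptions made for illustration.

from lightrft.strategy import DeepspeedStrategy, FSDPV2Strategy  # import path assumed

# Simplified sketch: choose the concrete strategy from the configuration.
# The real get_strategy() handles more options than this.
def get_strategy_sketch(args):
    if getattr(args, "fsdp", False):
        return FSDPV2Strategy(args)   # FSDP v2 backend
    return DeepspeedStrategy(args)    # DeepSpeed backend otherwise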
2. Configuration Management
The StrategyConfig class centralizes all configuration parameters:
from lightrft.strategy.config import StrategyConfig
config = StrategyConfig.from_args(args)
# Access parameters with type safety
learning_rate = config.actor_learning_rate
use_bf16 = config.bf16
3. Common Interface
All strategies implement the same core interface:
from abc import ABC
from typing import Any

from torch.optim import Optimizer  # type used in the signatures below

class StrategyBase(ABC):
    def setup_distributed(self, timeout=None) -> None: ...
    def create_optimizer(self, model, **kwargs) -> Optimizer: ...
    def prepare(self, *models, is_rlhf=False) -> Any: ...
    def backward(self, loss, model, optimizer, **kwargs) -> None: ...
    def optimizer_step(self, optimizer, model, scheduler, **kwargs) -> None: ...
    def save_ckpt(self, model, save_dir, **kwargs) -> None: ...
    def load_ckpt(self, model, load_dir, **kwargs) -> Any: ...
Usage Patterns
1. Basic Usage
from lightrft.strategy import get_strategy
# Initialize strategy
strategy = get_strategy(args)
# Prepare models and optimizers
actor, critic, reward_models, initial_model = strategy.prepare_models_and_optimizers(
    actor, critic, reward_models, initial_model, args, max_steps
)

# Training loop
for batch in dataloader:
    loss = compute_loss(batch)
    strategy.backward(loss, actor, actor_optimizer)
    strategy.optimizer_step(actor_optimizer, actor, actor_scheduler)
2. Configuration-Driven Usage
from lightrft.strategy.config import StrategyConfig
# Create configuration
config = StrategyConfig(
    seed=42,
    max_norm=1.0,
    micro_train_batch_size=4,
    train_batch_size=32,
    bf16=True,
    zero_stage=2,
)
# Use configuration to create strategy
strategy = get_strategy(config)
3. Testing with FakeStrategy
from lightrft.strategy import get_fake_strategy
# Use fake strategy for testing
strategy = get_fake_strategy()
# All operations work without distributed environment
strategy.setup_distributed()
strategy.backward(loss, model, optimizer)
strategy.save_ckpt(model, "checkpoints")
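A single-process unit test built on the fake strategy might look like the following sketch; the model and optimizer are plain PyTorch stand-ins, and passing None for the scheduler is an assumption about what FakeStrategy accepts.

import torch
from lightrft.strategy import get_fake_strategy

# Sketch of a unit test that exercises the strategy API without any
# distributed setup; the model and optimizer are minimal PyTorch stand-ins.
def test_training_step_with_fake_strategy():
    strategy = get_fake_strategy()
    strategy.setup_distributed()

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    loss = model(torch.randn(2, 4)).mean()
    strategy.backward(loss, model, optimizer)
    strategy.optimizer_step(optimizer, model, None)  # no scheduler in this sketch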
Design Benefits
1. Improved Type Safety
Before (using getattr):
seed = getattr(args, "seed", 42) # Type: Any
max_norm = getattr(args, "max_norm", 1.0) # Type: Any
After (using StrategyConfig):
config = StrategyConfig.from_args(args)
seed = config.seed # Type: int
max_norm = config.max_norm # Type: float
2. Better Code Organization
- Configuration parameters are explicitly defined in StrategyConfig
- Strategy-specific logic is encapsulated in concrete strategy classes
- Common functionality is implemented in StrategyBase
3. Enhanced Testability
- FakeStrategy enables testing without distributed setup
- Unit tests can verify all strategy functionality
- Mock implementations ensure consistent behavior
4. Future Extensibility
- New strategies can be added by implementing the StrategyBase interface (see the sketch below)
- Configuration can be extended without breaking existing code
- The factory pattern makes it easy to add new strategy types
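As a rough sketch of that extension path, a new backend subclasses StrategyBase and fills in the interface methods; the import path, the method bodies, and the subset of methods shown here are illustrative assumptions, and a real strategy would implement the full interface.

from lightrft.strategy import StrategyBase  # import path assumed for illustration

# Sketch of a new backend: only a few interface methods are shown here;
# a complete strategy would implement the entire StrategyBase interface.
class MyBackendStrategy(StrategyBase):
    def setup_distributed(self, timeout=None) -> None:
        ...  # initialize the backend's process groups

    def backward(self, loss, model, optimizer, **kwargs) -> None:
        loss.backward()  # delegate the backward pass to the backend

    def optimizer_step(self, optimizer, model, scheduler, **kwargs) -> None:
        optimizer.step()
        optimizer.zero_grad()
        if scheduler is not None:
            scheduler.step()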
Best Practices
1. Configuration Management
- Use StrategyConfig for all parameter access
- Avoid direct getattr calls on argument objects
- Use get_extra_arg() for non-standard parameters
2. Strategy Selection¶
Use
get_strategy()factory function for strategy creationLet the factory determine the appropriate strategy based on configuration
Use
FakeStrategyfor testing and development
3. Error Handling
- Strategies should provide clear error messages for unsupported operations
- Use the strategy's print() method for logging
- Implement proper cleanup in context managers (see the sketch below)
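One way to combine these conventions is sketched below; strategy.print() is the logging method mentioned above (assumed here to be the preferred way to emit messages), and the engine handling is a placeholder rather than a real API.

from contextlib import contextmanager

# Sketch only: log through the strategy and release resources in a context
# manager. The engine wake/sleep handling below is a placeholder.
@contextmanager
def engine_session(strategy, engine):
    strategy.print("waking inference engine")  # strategy-level logging
    try:
        yield engine
    finally:
        strategy.print("releasing inference engine")
        # put the engine back to sleep / free its memory here (placeholder)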
4. Testing
- Use FakeStrategy for unit tests
- Test both strategy-specific and common functionality
- Verify that all strategies implement the required interface
Conclusion
The LightRFT strategy module provides a unified abstraction over distributed training backends for RLHF training, built around two ideas: a single shared interface for all strategies and configuration-driven development. Together these make it practical to move between distributed training frameworks without rewriting training code, while keeping configuration type-safe and the developer experience consistent.
Key Design Elements
Unified abstraction: The module places DeepSpeed, FSDP, and other distributed training backends behind a single, consistent API, so training code can switch between strategies without modification.
Type-safe configuration: Replacing dynamic attribute access with the typed StrategyConfig object reduces configuration errors, enables IDE autocompletion, and supports static type checking.
Factory-based strategy selection: get_strategy() picks the appropriate backend automatically from the configuration parameters, hiding the details of backend selection while keeping control in explicit configuration.
Advanced Capabilities
Multi-modal inference engine integration: The strategy module supports both text-only and multi-modal generation, managing vLLM and SGLang backends behind a consistent interface.
Testing infrastructure: FakeStrategy lets distributed training workflows be exercised without a distributed environment, which shortens development cycles and improves test coverage.
Memory-efficient training: Features such as inference engine sleep/wake cycles, gradient accumulation optimizations, and memory-aware checkpointing help keep resource usage manageable at scale.
Outlook
The design shows that type safety, a good developer experience, and performance do not have to be traded against each other. Its extensible architecture leaves room to integrate future distributed training backends while remaining backward compatible with existing code, and the module aims to stay approachable for newcomers while still supporting detailed customization across the range of RLHF training requirements.