LightRFT Strategy Design Philosophy¶
Overview¶
The LightRFT strategy module provides a unified interface for distributed training strategies, enabling seamless switching between different distributed training backends while maintaining a consistent API. This document outlines the design principles, architecture, and usage patterns of the strategy module.
Core Design Principles¶
1. Abstraction and Unification¶
Principle: Provide a unified interface that abstracts away the complexities of different distributed training frameworks.
Implementation:
- All strategies inherit from `StrategyBase`
- Common methods like `backward()`, `optimizer_step()`, and `save_ckpt()` have consistent signatures
- Strategy-specific implementations are encapsulated within concrete strategy classes
2. Configuration-Driven Design¶
Principle: Use typed configuration objects instead of dynamic attribute access for better type safety and code clarity.
Implementation:
- `StrategyConfig` dataclass provides typed access to all configuration parameters
- Eliminates the need for the `getattr(args, "parameter", default)` pattern
- Enables IDE autocompletion and static type checking
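As an illustration, a typed configuration along these lines might look like the following sketch; the field names are taken from the examples in this document, but the real `StrategyConfig` definition may differ:

```python
from dataclasses import dataclass

# Illustrative sketch of a typed configuration dataclass.
# Field names mirror the examples in this document, not the actual definition.
@dataclass
class StrategyConfig:
    seed: int = 42
    max_norm: float = 1.0
    micro_train_batch_size: int = 1
    train_batch_size: int = 32
    bf16: bool = False
    zero_stage: int = 2

# Typed fields mean IDEs can autocomplete and type checkers can verify access
config = StrategyConfig(bf16=True, zero_stage=3)
```

Because every parameter is a declared field, a typo such as `config.bf61` fails loudly instead of silently returning a default.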
3. Backward Compatibility¶
Principle: Maintain compatibility with existing code while introducing improvements.
Implementation:
- `StrategyConfig.from_args()` method extracts parameters from legacy argument objects
- The original `args` object is preserved for compatibility
- `get_extra_arg()` method provides access to non-standard parameters
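A minimal sketch of this bridge pattern, assuming a simplified two-field config; the actual extraction logic in lightrft covers many more parameters:

```python
import argparse
from dataclasses import dataclass, field
from typing import Any, Dict

# Simplified sketch of the from_args() compatibility bridge.
@dataclass
class StrategyConfig:
    seed: int = 42
    max_norm: float = 1.0
    extra: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_args(cls, args):
        # Legacy getattr() access happens exactly once, inside the bridge
        cfg = cls(
            seed=getattr(args, "seed", 42),
            max_norm=getattr(args, "max_norm", 1.0),
        )
        cfg.args = args  # original object preserved for compatibility
        cfg.extra = {k: v for k, v in vars(args).items()
                     if k not in ("seed", "max_norm")}
        return cfg

    def get_extra_arg(self, name, default=None):
        # Non-standard parameters live in a catch-all dict
        return self.extra.get(name, default)

legacy = argparse.Namespace(seed=7, custom_flag=True)
config = StrategyConfig.from_args(legacy)
```

Downstream code reads typed fields, while occasional non-standard parameters remain reachable through `get_extra_arg()`.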
4. Testability¶
Principle: Enable comprehensive testing without requiring distributed environments.
Implementation:
- `FakeStrategy` provides a drop-in replacement for testing
- All strategy methods have mock implementations for single-process testing
- Unit tests verify both functionality and API consistency
Architecture¶
Strategy Hierarchy¶
StrategyBase (ABC)
├── DeepspeedStrategy
├── FSDPV2Strategy
└── FakeStrategy (for testing)
Key Components¶
1. Strategy Factory¶
The get_strategy() function serves as the entry point, automatically selecting the appropriate strategy based on configuration:
from lightrft.strategy import get_strategy
# Automatically selects DeepSpeed or FSDP based on args.fsdp
strategy = get_strategy(args)
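A hedged sketch of how such a factory can dispatch on a single configuration flag; the strategy classes here are bare stand-ins, not the real lightrft implementations:

```python
import argparse

# Stand-in strategy classes for illustration only.
class DeepspeedStrategy:
    name = "deepspeed"

class FSDPV2Strategy:
    name = "fsdp"

def get_strategy(args):
    # One flag decides the backend, so call sites stay backend-agnostic
    if getattr(args, "fsdp", False):
        return FSDPV2Strategy()
    return DeepspeedStrategy()

strategy = get_strategy(argparse.Namespace(fsdp=True))
```

Because selection happens in one place, swapping backends never requires touching the training loop.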
2. Configuration Management¶
The StrategyConfig class centralizes all configuration parameters:
from lightrft.strategy.config import StrategyConfig
config = StrategyConfig.from_args(args)
# Access parameters with type safety
learning_rate = config.actor_learning_rate
use_bf16 = config.bf16
3. Common Interface¶
All strategies implement the same core interface:
class StrategyBase(ABC):
    def setup_distributed(self, timeout=None) -> None: ...
    def create_optimizer(self, model, **kwargs) -> Optimizer: ...
    def prepare(self, *models, is_rlhf=False) -> Any: ...
    def backward(self, loss, model, optimizer, **kwargs) -> None: ...
    def optimizer_step(self, optimizer, model, scheduler, **kwargs) -> None: ...
    def save_ckpt(self, model, save_dir, **kwargs) -> None: ...
    def load_ckpt(self, model, load_dir, **kwargs) -> Any: ...
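To make the contract concrete, here is a minimal sketch of implementing such an interface, reduced to two abstract hooks; the real `StrategyBase` has more methods and actual backend logic:

```python
from abc import ABC, abstractmethod

# Reduced two-method version of the interface, for illustration.
class StrategyBase(ABC):
    @abstractmethod
    def backward(self, loss, model, optimizer, **kwargs): ...

    @abstractmethod
    def optimizer_step(self, optimizer, model, scheduler, **kwargs): ...

class LoggingStrategy(StrategyBase):
    """Hypothetical strategy that records calls instead of training."""
    def __init__(self):
        self.calls = []

    def backward(self, loss, model, optimizer, **kwargs):
        self.calls.append("backward")

    def optimizer_step(self, optimizer, model, scheduler, **kwargs):
        self.calls.append("optimizer_step")

s = LoggingStrategy()
s.backward(loss=None, model=None, optimizer=None)
s.optimizer_step(optimizer=None, model=None, scheduler=None)
```

Any class that fills in the abstract hooks is usable wherever a strategy is expected; the ABC rejects incomplete implementations at instantiation time.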
Usage Patterns¶
1. Basic Usage¶
from lightrft.strategy import get_strategy
# Initialize strategy
strategy = get_strategy(args)
# Prepare models and optimizers
actor, critic, reward_models, initial_model = strategy.prepare_models_and_optimizers(
    actor, critic, reward_models, initial_model, args, max_steps
)
# Training loop
for batch in dataloader:
    loss = compute_loss(batch)
    strategy.backward(loss, actor, actor_optimizer)
    strategy.optimizer_step(actor_optimizer, actor, actor_scheduler)
2. Configuration-Driven Usage¶
from lightrft.strategy.config import StrategyConfig
# Create configuration
config = StrategyConfig(
    seed=42,
    max_norm=1.0,
    micro_train_batch_size=4,
    train_batch_size=32,
    bf16=True,
    zero_stage=2,
)
# Use configuration to create strategy
strategy = get_strategy(config)
3. Testing with FakeStrategy¶
from lightrft.strategy import get_fake_strategy
# Use fake strategy for testing
strategy = get_fake_strategy()
# All operations work without distributed environment
strategy.setup_distributed()
strategy.backward(loss, model, optimizer)
strategy.save_ckpt(model, "checkpoints")
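The idea behind `FakeStrategy` can be sketched in a self-contained form, assuming no-op implementations that record calls instead of touching a distributed backend or disk:

```python
# Self-contained sketch of the FakeStrategy idea: every distributed
# operation becomes a cheap no-op, so training code runs in one process.
class FakeStrategy:
    def __init__(self):
        self.saved = []

    def setup_distributed(self, timeout=None):
        pass  # no process group is created

    def backward(self, loss, model, optimizer, **kwargs):
        pass  # a real strategy would route loss.backward() through the backend

    def optimizer_step(self, optimizer, model, scheduler, **kwargs):
        pass  # no parameters are actually updated

    def save_ckpt(self, model, save_dir, **kwargs):
        self.saved.append(save_dir)  # record the call instead of writing files

strategy = FakeStrategy()
strategy.setup_distributed()
strategy.save_ckpt(model=None, save_dir="checkpoints")
```

Tests can then assert on the recorded calls, verifying the training workflow's control flow without any GPUs or process groups.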
Design Benefits¶
1. Improved Type Safety¶
Before (using getattr):
seed = getattr(args, "seed", 42) # Type: Any
max_norm = getattr(args, "max_norm", 1.0) # Type: Any
After (using StrategyConfig):
config = StrategyConfig.from_args(args)
seed = config.seed # Type: int
max_norm = config.max_norm # Type: float
2. Better Code Organization¶
- Configuration parameters are explicitly defined in `StrategyConfig`
- Strategy-specific logic is encapsulated in concrete strategy classes
- Common functionality is implemented in `StrategyBase`
3. Enhanced Testability¶
- `FakeStrategy` enables testing without distributed setup
- Unit tests can verify all strategy functionality
- Mock implementations ensure consistent behavior
4. Future Extensibility¶
- New strategies can be added by implementing the `StrategyBase` interface
- Configuration can be extended without breaking existing code
- The factory pattern makes it easy to add new strategy types
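One common way a factory stays extensible is a name-to-class registry; this is an illustrative sketch, not lightrft's actual mechanism, and `MegatronStrategy` is a purely hypothetical future backend:

```python
# Illustrative registry-based factory: adding a backend is one decorator,
# with no edits to the factory function itself.
STRATEGY_REGISTRY = {}

def register_strategy(name):
    def decorator(cls):
        STRATEGY_REGISTRY[name] = cls
        return cls
    return decorator

@register_strategy("fake")
class FakeStrategy:
    pass

@register_strategy("megatron")  # hypothetical future backend
class MegatronStrategy:
    pass

def get_strategy(name):
    # Look up and instantiate the registered class
    return STRATEGY_REGISTRY[name]()
```

Existing call sites keep working unchanged as new names are registered, which is the extensibility property described above.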
Best Practices¶
1. Configuration Management¶
- Use `StrategyConfig` for all parameter access
- Avoid direct `getattr` calls on argument objects
- Use `get_extra_arg()` for non-standard parameters
2. Strategy Selection¶
- Use the `get_strategy()` factory function for strategy creation
- Let the factory determine the appropriate strategy based on configuration
- Use `FakeStrategy` for testing and development
3. Error Handling¶
- Strategies should provide clear error messages for unsupported operations
- Use the strategy's `print()` method for logging
- Implement proper cleanup in context managers
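The cleanup advice can be sketched with a context manager that always releases what it acquired, even when the body fails; `inference_engine()` and `engine_awake` are hypothetical names used only for illustration:

```python
from contextlib import contextmanager

# Sketch of cleanup via a context manager: the finally block runs
# whether the body succeeds or raises.
class DemoStrategy:
    def __init__(self):
        self.engine_awake = False

    @contextmanager
    def inference_engine(self):
        self.engine_awake = True       # acquire: wake the engine (placeholder)
        try:
            yield self
        finally:
            self.engine_awake = False  # release: always put it back to sleep

s = DemoStrategy()
try:
    with s.inference_engine():
        raise RuntimeError("generation failed")
except RuntimeError:
    pass
```

Even though the body raised, the engine is guaranteed to be released, which is exactly the property proper cleanup must provide.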
4. Testing¶
- Use `FakeStrategy` for unit tests
- Test both strategy-specific and common functionality
- Verify that all strategies implement the required interface
Conclusion¶
The LightRFT strategy module provides a carefully designed abstraction over distributed training, aiming to improve the flexibility, type safety, and development efficiency of RLHF systems. Through a unified abstraction layer and configuration-driven design, the module achieves interoperability between different training frameworks.
Core Design¶
- Unified Interface Architecture: Encapsulates distributed backends like DeepSpeed and FSDP behind a consistent API, so developers can switch underlying strategies without modifying business code.
- Type-Safe Configuration: Converts dynamic configuration into strongly typed objects via `StrategyConfig`, reducing runtime errors and supporting IDE autocompletion.
- Factory Pattern Selection: `get_strategy()` automatically instantiates strategies based on configuration parameters, simplifying calls while retaining control over backends.
Feature Highlights¶
- Inference Engine Integration: Supports plain-text and multimodal generation through a unified interface, compatible with vLLM and SGLang backends.
- Convenient Testing Support: Provides `FakeStrategy`, allowing training workflows to be tested without a distributed environment and reducing debugging costs.
- Resource Efficiency Optimization: Supports advanced features like inference engine sleep/wake, gradient accumulation, and memory-aware checkpointing to optimize resource usage during large-scale training.
Summary¶
The module strikes a balance between ease of use and flexibility, weighing development experience against performance requirements. It lowers the barrier to entry for distributed training while providing a scalable technical foundation for scenarios ranging from RLHF to RLVR.