
LightRFT Strategy Design Philosophy

Overview

The LightRFT strategy module provides a unified interface for distributed training strategies, enabling seamless switching between different distributed training backends while maintaining a consistent API. This document outlines the design principles, architecture, and usage patterns of the strategy module.

Core Design Principles

1. Abstraction and Unification

Principle: Provide a unified interface that abstracts away the complexities of different distributed training frameworks.

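Implementation:

  • The StrategyBase abstract base class defines the common interface shared by all backends

  • Concrete strategies (DeepspeedStrategy, FSDPV2Strategy) encapsulate backend-specific logic

  • The get_strategy() factory selects the concrete strategy from configuration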

2. Configuration-Driven Design

Principle: Use typed configuration objects instead of dynamic attribute access for better type safety and code clarity.

Implementation:

  • StrategyConfig dataclass provides typed access to all configuration parameters

  • Eliminates the need for the getattr(args, "parameter", default) pattern

  • Enables IDE autocompletion and static type checking
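As a minimal sketch of this pattern, the hypothetical MiniStrategyConfig below (a stand-in, not the real StrategyConfig, whose fields and from_args() logic may differ) declares each parameter's type and default once and confines attribute lookups on args to a single from_args() method:

```python
from argparse import Namespace
from dataclasses import dataclass, fields


@dataclass
class MiniStrategyConfig:
    """Hypothetical, trimmed-down stand-in for StrategyConfig."""
    seed: int = 42
    max_norm: float = 1.0
    micro_train_batch_size: int = 1
    train_batch_size: int = 8
    bf16: bool = False
    zero_stage: int = 2

    @classmethod
    def from_args(cls, args: Namespace) -> "MiniStrategyConfig":
        # Copy only the attributes this config declares; anything missing
        # from args falls back to the dataclass default.
        values = {f.name: getattr(args, f.name) for f in fields(cls) if hasattr(args, f.name)}
        return cls(**values)


args = Namespace(seed=7, bf16=True, unrelated_flag="ignored")
config = MiniStrategyConfig.from_args(args)
print(config.seed, config.bf16, config.max_norm)  # 7 True 1.0
```

Downstream code then reads config.seed and gets an int, instead of repeating getattr() calls with their own defaults.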

3. Backward Compatibility

Principle: Maintain compatibility with existing code while introducing improvements.

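Implementation:

  • get_strategy() continues to accept the existing args object

  • StrategyConfig.from_args() converts existing argument objects into a typed configuration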

4. Testability

Principle: Enable comprehensive testing without requiring distributed environments.

Implementation:

  • FakeStrategy provides a drop-in replacement for testing

  • All strategy methods have mock implementations for single-process testing

  • Unit tests verify both functionality and API consistency
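The drop-in idea can be sketched with a hypothetical MiniFakeStrategy (not LightRFT's FakeStrategy, whose behavior may differ): it implements the same method names as a real strategy but only records each call, so code written against the strategy interface runs in a single process and can be checked by inspecting the recorded calls.

```python
from typing import Any, List, Tuple


class MiniFakeStrategy:
    """Hypothetical single-process test double; not LightRFT's FakeStrategy."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, tuple]] = []

    def setup_distributed(self, timeout=None) -> None:
        # Nothing to initialize: there is exactly one process.
        self.calls.append(("setup_distributed", (timeout,)))

    def backward(self, loss: Any, model: Any, optimizer: Any, **kwargs) -> None:
        self.calls.append(("backward", (loss,)))

    def optimizer_step(self, optimizer: Any, model: Any, scheduler: Any, **kwargs) -> None:
        self.calls.append(("optimizer_step", ()))

    def save_ckpt(self, model: Any, save_dir: str, **kwargs) -> None:
        # Record the request instead of writing files to disk.
        self.calls.append(("save_ckpt", (save_dir,)))


def train_one_step(strategy, loss, model, optimizer, scheduler) -> None:
    """Code under test: it depends only on the strategy interface."""
    strategy.backward(loss, model, optimizer)
    strategy.optimizer_step(optimizer, model, scheduler)


strategy = MiniFakeStrategy()
strategy.setup_distributed()
train_one_step(strategy, loss=0.5, model=None, optimizer=None, scheduler=None)
print([name for name, _ in strategy.calls])
# ['setup_distributed', 'backward', 'optimizer_step']
```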

Architecture

Strategy Hierarchy

StrategyBase (ABC)
├── DeepspeedStrategy
├── FSDPV2Strategy  
└── FakeStrategy (for testing)

Key Components

1. Strategy Factory

The get_strategy() function serves as the entry point, automatically selecting the appropriate strategy based on configuration:

from lightrft.strategy import get_strategy

# Automatically selects DeepSpeed or FSDP based on args.fsdp
strategy = get_strategy(args)
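A minimal sketch of the dispatch this comment describes, using hypothetical placeholder classes instead of the real strategies (the actual get_strategy() may inspect more than args.fsdp):

```python
from argparse import Namespace


class DeepspeedLike:
    """Hypothetical placeholder for a DeepSpeed-backed strategy."""

    def __init__(self, args: Namespace) -> None:
        self.args = args


class FSDPLike:
    """Hypothetical placeholder for an FSDP-backed strategy."""

    def __init__(self, args: Namespace) -> None:
        self.args = args


def get_strategy_sketch(args: Namespace):
    # args.fsdp picks the backend; callers never name a concrete class.
    if args.fsdp:
        return FSDPLike(args)
    return DeepspeedLike(args)


print(type(get_strategy_sketch(Namespace(fsdp=True))).__name__)   # FSDPLike
print(type(get_strategy_sketch(Namespace(fsdp=False))).__name__)  # DeepspeedLike
```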

2. Configuration Management

The StrategyConfig class centralizes all configuration parameters:

from lightrft.strategy.config import StrategyConfig

config = StrategyConfig.from_args(args)
# Access parameters with type safety
learning_rate = config.actor_learning_rate
use_bf16 = config.bf16

3. Common Interface

All strategies implement the same core interface:

from abc import ABC
from typing import Any

from torch.optim import Optimizer


class StrategyBase(ABC):
    def setup_distributed(self, timeout=None) -> None: ...
    def create_optimizer(self, model, **kwargs) -> Optimizer: ...
    def prepare(self, *models, is_rlhf=False) -> Any: ...
    def backward(self, loss, model, optimizer, **kwargs) -> None: ...
    def optimizer_step(self, optimizer, model, scheduler, **kwargs) -> None: ...
    def save_ckpt(self, model, save_dir, **kwargs) -> None: ...
    def load_ckpt(self, model, load_dir, **kwargs) -> Any: ...
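One way to enforce such an interface, shown here under the assumption that the shared methods are marked with abc.abstractmethod (the listing above does not say whether StrategyBase does this), is to let Python refuse to instantiate any subclass that leaves a method out. MiniStrategyBase and its subclasses are hypothetical:

```python
from abc import ABC, abstractmethod


class MiniStrategyBase(ABC):
    """Hypothetical two-method version of the strategy interface."""

    @abstractmethod
    def backward(self, loss, model, optimizer, **kwargs) -> None: ...

    @abstractmethod
    def optimizer_step(self, optimizer, model, scheduler, **kwargs) -> None: ...


class CompleteStrategy(MiniStrategyBase):
    def backward(self, loss, model, optimizer, **kwargs) -> None:
        pass  # backend-specific gradient computation would go here

    def optimizer_step(self, optimizer, model, scheduler, **kwargs) -> None:
        pass  # backend-specific parameter update would go here


class IncompleteStrategy(MiniStrategyBase):
    # optimizer_step is missing, so this class stays abstract.
    def backward(self, loss, model, optimizer, **kwargs) -> None:
        pass


CompleteStrategy()  # instantiates normally

rejected = False
try:
    IncompleteStrategy()
except TypeError as err:
    rejected = True
    print("rejected:", err)
```

The error surfaces when the strategy object is constructed, rather than midway through a training run when the missing method is first called.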

Usage Patterns

1. Basic Usage

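from lightrft.strategy import get_strategy

# args, max_steps, dataloader and compute_loss are supplied by the training script.

# Initialize strategy
strategy = get_strategy(args)

# Prepare models and optimizers
actor, critic, reward_models, initial_model = strategy.prepare_models_and_optimizers(
    actor, critic, reward_models, initial_model, args, max_steps
)

# Training loop. actor_optimizer and actor_scheduler denote the actor's optimizer
# and LR scheduler; check prepare_models_and_optimizers() for how they are returned.
for batch in dataloader:
    loss = compute_loss(batch)
    strategy.backward(loss, actor, actor_optimizer)
    strategy.optimizer_step(actor_optimizer, actor, actor_scheduler)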

2. Configuration-Driven Usage

from lightrft.strategy import get_strategy
from lightrft.strategy.config import StrategyConfig

# Create configuration
config = StrategyConfig(
    seed=42,
    max_norm=1.0,
    micro_train_batch_size=4,
    train_batch_size=32,
    bf16=True,
    zero_stage=2
)

# Use configuration to create strategy
strategy = get_strategy(config)
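The two batch-size fields above are related. Assuming LightRFT follows DeepSpeed's convention, in which the global batch size equals the per-GPU micro-batch size times the number of gradient accumulation steps times the number of data-parallel ranks, the hypothetical helper below derives how many micro-batches are accumulated per optimizer step for this configuration:

```python
def grad_accum_steps(train_batch_size: int, micro_train_batch_size: int, world_size: int) -> int:
    """Hypothetical helper: micro-batches accumulated per optimizer update."""
    per_step = micro_train_batch_size * world_size
    if train_batch_size % per_step != 0:
        raise ValueError(
            f"train_batch_size={train_batch_size} is not divisible by "
            f"micro_train_batch_size * world_size = {per_step}"
        )
    return train_batch_size // per_step


# train_batch_size=32 and micro_train_batch_size=4, as in the config above:
for world_size in (1, 2, 8):
    print(world_size, grad_accum_steps(32, 4, world_size))
# 1 8
# 2 4
# 8 1
```

Keeping train_batch_size fixed while adding ranks therefore reduces accumulation rather than changing the effective batch size.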

3. Testing with FakeStrategy

from lightrft.strategy import get_fake_strategy

# Use fake strategy for testing
strategy = get_fake_strategy()

# All operations work without a distributed environment
# (loss, model and optimizer are created by the test itself)
strategy.setup_distributed()
strategy.backward(loss, model, optimizer)
strategy.save_ckpt(model, "checkpoints")

Design Benefits

1. Improved Type Safety

Before (using getattr):

seed = getattr(args, "seed", 42)  # Type: Any
max_norm = getattr(args, "max_norm", 1.0)  # Type: Any

After (using StrategyConfig):

config = StrategyConfig.from_args(args)
seed = config.seed  # Type: int
max_norm = config.max_norm  # Type: float
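Typed access also changes how mistakes surface. In this sketch, TinyConfig is a hypothetical two-field stand-in for StrategyConfig: a misspelled name passed to getattr() with a default silently yields the default, while the same typo on a dataclass instance raises AttributeError (and would also be reported by a static type checker).

```python
from argparse import Namespace
from dataclasses import dataclass


@dataclass
class TinyConfig:
    """Hypothetical two-field stand-in for StrategyConfig."""
    seed: int = 42
    max_norm: float = 1.0


args = Namespace(seed=7, max_norm=0.5)

# getattr() with a default hides the misspelling "max_nrom".
legacy_value = getattr(args, "max_nrom", 1.0)
print(legacy_value)  # 1.0, not the configured 0.5

config = TinyConfig(seed=args.seed, max_norm=args.max_norm)
typo_caught = False
try:
    config.max_nrom  # the same typo on the typed config
except AttributeError:
    typo_caught = True
print(typo_caught)  # True
```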

2. Better Code Organization

  • Configuration parameters are explicitly defined in StrategyConfig

  • Strategy-specific logic is encapsulated in concrete strategy classes

  • Common functionality is implemented in StrategyBase

3. Enhanced Testability

  • FakeStrategy enables testing without distributed setup

  • Unit tests can verify all strategy functionality

  • Mock implementations ensure consistent behavior

4. Future Extensibility

  • New strategies can be added by implementing the StrategyBase interface

  • Configuration can be extended without breaking existing code

  • The factory pattern makes it easy to add new strategy types

Best Practices

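1. Configuration Management

  • Access parameters through StrategyConfig rather than calling getattr() on args

  • Use StrategyConfig.from_args() to convert existing argument objects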

2. Strategy Selection

  • Use get_strategy() factory function for strategy creation

  • Let the factory determine the appropriate strategy based on configuration

  • Use FakeStrategy for testing and development

3. Error Handling

  • Strategies should provide clear error messages for unsupported operations

  • Use the strategy’s print() method for logging

  • Implement proper cleanup in context managers
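The sketch below illustrates these three practices on a hypothetical MiniStrategy. Two details are assumptions rather than documented LightRFT behavior: that print() logs only on rank 0 (a common convention in distributed trainers), and the save_optimizer_shards option, invented here as an example of an unsupported operation.

```python
from typing import List


class MiniStrategy:
    """Hypothetical strategy illustrating the error-handling practices above."""

    def __init__(self, rank: int = 0) -> None:
        self.rank = rank
        self.log: List[str] = []
        self.is_set_up = False

    def print(self, *msg) -> None:
        # Assumption: only rank 0 logs, so N processes do not emit N copies.
        if self.rank == 0:
            line = " ".join(str(m) for m in msg)
            self.log.append(line)
            print(line)  # the built-in print, not this method

    def save_ckpt(self, model, save_dir: str, **kwargs) -> None:
        # Reject unsupported options with a specific message instead of ignoring them.
        if kwargs.get("save_optimizer_shards"):
            raise NotImplementedError(
                "MiniStrategy does not support save_optimizer_shards; "
                "use a backend that shards optimizer state"
            )
        self.print(f"saved checkpoint to {save_dir}")

    def __enter__(self) -> "MiniStrategy":
        self.is_set_up = True  # stands in for setup_distributed()
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        # Runs even if the body raised: release resources, then let the error propagate.
        self.is_set_up = False
        self.print("torn down")
        return False


strategy = MiniStrategy(rank=0)
try:
    with strategy:
        strategy.save_ckpt(None, "checkpoints")
        strategy.save_ckpt(None, "checkpoints", save_optimizer_shards=True)
except NotImplementedError as err:
    print("error:", err)
print(strategy.is_set_up)  # False
```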

4. Testing

  • Use FakeStrategy for unit tests

  • Test both strategy-specific and common functionality

  • Verify that all strategies implement the required interface
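Abstract methods only check that a method exists, not that its parameters match. A minimal sketch of a stricter conformance test compares each method's parameter list with the base class using inspect.signature(); MiniStrategyBase and the two subclasses are hypothetical:

```python
import inspect
from abc import ABC, abstractmethod


class MiniStrategyBase(ABC):
    """Hypothetical interface with two of the documented methods."""

    @abstractmethod
    def backward(self, loss, model, optimizer, **kwargs) -> None: ...

    @abstractmethod
    def save_ckpt(self, model, save_dir, **kwargs) -> None: ...


class GoodStrategy(MiniStrategyBase):
    def backward(self, loss, model, optimizer, **kwargs) -> None:
        pass

    def save_ckpt(self, model, save_dir, **kwargs) -> None:
        pass


class DriftedStrategy(MiniStrategyBase):
    def backward(self, loss, model, **kwargs) -> None:  # `optimizer` was dropped
        pass

    def save_ckpt(self, model, save_dir, **kwargs) -> None:
        pass


def interface_mismatches(cls, base=MiniStrategyBase):
    """Return names of abstract base methods whose parameters differ in `cls`."""
    mismatches = []
    for name in base.__abstractmethods__:
        expected = list(inspect.signature(getattr(base, name)).parameters)
        actual = list(inspect.signature(getattr(cls, name)).parameters)
        if expected != actual:
            mismatches.append(name)
    return sorted(mismatches)


print(interface_mismatches(GoodStrategy))     # []
print(interface_mismatches(DriftedStrategy))  # ['backward']
```

DriftedStrategy can still be instantiated, which is why a signature check like this complements the abstract base class rather than replacing it.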

Conclusion

The LightRFT strategy module provides an abstraction layer over distributed training, designed to improve the flexibility, type safety, and development efficiency of RLHF systems. Through a unified interface and configuration-driven development, it makes different training frameworks interchangeable.

Core Design

  • Unified Interface Architecture: Encapsulates distributed backends such as DeepSpeed and FSDP behind a consistent API, so developers can switch the underlying strategy without modifying training code.

  • Type-Safe Configuration: Converts dynamic configuration into strongly typed objects via StrategyConfig, reducing runtime errors and supporting IDE autocompletion.

  • Factory Pattern Selection: get_strategy() automatically instantiates strategies based on configuration parameters, simplifying calls while retaining control over backends.

Feature Highlights

  • Inference Engine Integration: Supports plain text and multimodal generation through a unified interface, compatible with vLLM and SGLang backends.

  • Convenient Testing Support: Provides FakeStrategy, allowing testing of training workflows without a distributed environment, reducing debugging costs.

  • Resource Efficiency Optimization: Supports advanced features like inference engine sleep/wake, gradient accumulation, and memory-aware checkpointing to optimize resource usage during large-scale training.

Summary

The module balances ease of use with flexibility, and developer experience with performance. It lowers the barrier to entry for distributed training while providing an extensible foundation for scenarios ranging from RLHF to RLVR.