Strategy Usage Guide¶
Overview¶
LightRFT’s strategy module extends distributed training capabilities with additional features for efficient reinforcement learning fine-tuning. The strategy provides a unified interface for managing:
Distributed Training Backends: DeepSpeed ZeRO and FSDP (Fully Sharded Data Parallel)
Inference Engine Integration: vLLM and SGLang for high-throughput generation
Memory Optimization: Optimizer offloading, gradient accumulation, and engine sleep modes
Sequence Parallelism: Efficient handling of long sequences across multiple GPUs
Core API Extensions¶
LightRFT adds the following key methods to the strategy interface:
| Method | Purpose |
|---|---|
| setup_inference_engine() | Initialize vLLM or SGLang inference engine |
| update_engine_weights() | Synchronize actor model weights to inference engine |
| gather_and_generate() | Distributed generation with automatic prompt gathering |
| maybe_load_optimizer() | Load optimizer states from CPU (FSDP only) |
| maybe_offload_optimizer() | Offload optimizer states to CPU (FSDP only) |
| wakeup_inference_engine() | Wake up inference engine from sleep mode |
| maybe_sleep_inference_engine() | Put inference engine to sleep to save memory |
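Taken together, these methods cover a typical rollout-and-update cycle. The following sketch strings them together in the order a trainer would usually call them; it is illustrative only and assumes that args, actor, actor_optim, sampling_params, and all_prompt_token_ids are already constructed:
# Illustrative call sequence for the methods listed above (not a full trainer)
from lightrft.strategy import get_strategy

strategy = get_strategy(args)                              # DeepSpeed or FSDP, chosen from args
strategy.setup_inference_engine(args, engine_type='vllm')  # create the inference engine

# Rollout: gather prompts, generate, scatter results, then sleep the engine
all_outputs = strategy.gather_and_generate(
    sampling_params=sampling_params,
    all_prompt_token_ids=all_prompt_token_ids,
    sleep_engine=True,
)

strategy.maybe_load_optimizer(actor_optim)     # FSDP only: move optimizer states to GPU
# ... run PPO updates on the collected experience ...
strategy.maybe_offload_optimizer(actor_optim)  # FSDP only: move optimizer states back to CPU

strategy.update_engine_weights(actor)          # sync updated weights before the next rollout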
Creating a Strategy¶
Basic Setup¶
Use the factory function get_strategy() to create a strategy instance:
from lightrft.strategy import get_strategy
from lightrft.utils import add_arguments
def train(args):
# Create strategy (automatically selects DeepSpeed or FSDP based on args)
strategy = get_strategy(args)
# Setup inference engine for generation
strategy.setup_inference_engine(args, engine_type='vllm')
# Access the engine if needed
vllm_engine = strategy.inference_engine
# Create trainer
trainer = SPMDPPOTrainer(
strategy=strategy,
actor=actor,
critic=critic,
reward_model=reward_model,
initial_model=initial_model,
ema_model=ema_model,
actor_optim=actor_optim,
critic_optim=critic_optim,
actor_scheduler=actor_scheduler,
critic_scheduler=critic_scheduler,
...
)
Strategy Selection¶
The strategy type is automatically determined by configuration arguments:
FSDP: Set the --fsdp flag
DeepSpeed: Default when --fsdp is not set (configurable via --zero_stage)
Using Strategy in Trainers¶
Standard Training Operations¶
The strategy provides standard distributed training operations:
# Backward pass
strategy.backward(loss, model, optimizer)
# Optimizer step with gradient clipping
strategy.optimizer_step(optimizer, model, scheduler, name="actor")
# Distributed communication
averaged_value = strategy.all_reduce(local_value, op="mean")
gathered_values = strategy.all_gather(local_value)
Memory-Optimized Training¶
For FSDP-based training, use optimizer offloading to reduce GPU memory:
def ppo_train(self, global_steps=0):
torch.cuda.synchronize()
train_begin = time.time()
# Load optimizer states from CPU to GPU (FSDP only)
self.strategy.maybe_load_optimizer(self.actor_optim)
# Perform training
train_ret = super().ppo_train(global_steps)
# Offload optimizer states from GPU to CPU (FSDP only)
self.strategy.maybe_offload_optimizer(self.actor_optim)
torch.cuda.synchronize()
self.strategy.print(f"PPO Train TIMECOST {time.time() - train_begin}")
# Synchronize actor weights to inference engine
self.strategy.update_engine_weights(self.actor)
return train_ret
Engine Weight Synchronization¶
After training updates, synchronize model weights to the inference engine:
# Update inference engine with latest actor weights
strategy.update_engine_weights(actor)
This ensures that the inference engine uses the most recent model parameters for generation.
Using Strategy in Experience Makers¶
Text Generation (LLM)¶
Use gather_and_generate() for distributed text generation:
# Tokenize prompts (without padding for efficiency)
all_prompt_token_ids = self.tokenize_fn(
all_prompts,
self.prompt_max_len,
padding=False
)["input_ids"]
# Generate responses with automatic distribution
all_outputs = self.strategy.gather_and_generate(
sampling_params=sampling_params,
all_prompt_token_ids=all_prompt_token_ids,
sleep_engine=True # Automatically sleep engine after generation
)
if dist.get_rank(self.vllm_mp_group) == 0:
self.strategy.print(f"Generated {len(all_outputs)} outputs")
Multimodal Generation (VLM)¶
For vision-language models with images:
# Generate with multimodal inputs
all_outputs = self.strategy.gather_and_generate(
sampling_params=sampling_params,
all_prompts=all_prompts, # Text prompts
all_images=all_images, # Image data
images_num=images_num, # Number of images per prompt
sleep_engine=True
)
How gather_and_generate() Works¶
The method performs the following operations:
Gather: Collects prompts from all ranks within each tensor-parallel group onto the group’s lead rank
Example: With world_size=8 and engine_tp_size=4, ranks [0,1,2,3] gather to rank 0, and ranks [4,5,6,7] gather to rank 4
Generate: Executes inference using the vLLM/SGLang engine on the gathered prompts
Distribute: Scatters the generated outputs back to the originating ranks in the same order
Sleep Management: Automatically handles engine sleep/wake cycles based on the sleep_engine parameter
Note
Users don’t need to manually manage engine sleep states when using this interface.
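To make the flow above concrete, here is a simplified, illustrative sketch of the gather/generate/scatter pattern using torch.distributed object collectives. The group handling, output ordering, and sleep/wake logic are stand-ins for what gather_and_generate() manages internally, and the engine.generate() call is an assumed engine API:
import torch.distributed as dist

def gather_and_generate_sketch(prompt_ids, engine, tp_group, sampling_params):
    group_rank = dist.get_rank(tp_group)
    group_size = dist.get_world_size(tp_group)
    lead_rank = dist.get_global_rank(tp_group, 0)  # first global rank of this TP group

    # 1. Gather: collect each rank's prompts on the group's lead rank
    gathered = [None] * group_size if group_rank == 0 else None
    dist.gather_object(prompt_ids, gathered, dst=lead_rank, group=tp_group)

    # 2. Generate: only the lead rank drives the inference engine
    if group_rank == 0:
        flat_prompts = [p for rank_prompts in gathered for p in rank_prompts]
        flat_outputs = engine.generate(flat_prompts, sampling_params)  # assumed engine call
        # Re-split outputs so each rank receives results for its own prompts, in order
        splits, start = [], 0
        for rank_prompts in gathered:
            splits.append(flat_outputs[start:start + len(rank_prompts)])
            start += len(rank_prompts)
    else:
        splits = None

    # 3. Distribute: scatter each rank's slice of outputs back to it
    my_outputs = [None]
    dist.scatter_object_list(my_outputs, splits, src=lead_rank, group=tp_group)
    return my_outputs[0]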
Required Arguments¶
Add LightRFT-specific arguments to your argument parser:
from lightrft.utils import add_arguments
import argparse
# Create parser
parser = argparse.ArgumentParser()
# Add LightRFT arguments
add_arguments(parser)
# Parse arguments
args = parser.parse_args()
Key Arguments¶
Inference Engine Configuration:
--engine_tp_size 4 # Tensor parallelism size for inference engine
--engine_mem_util 0.85 # GPU memory utilization for KV cache (0.0-1.0)
--engine_type vllm # Engine type: 'vllm' or 'sglang'
--enable_engine_sleep # Enable engine sleep mode (default: True)
--disable_engine_sleep # Disable engine sleep mode
Distributed Training:
--fsdp # Use FSDP instead of DeepSpeed
--zero_stage 2 # DeepSpeed ZeRO stage (1, 2, or 3)
--fsdp_cpu_offload # Offload FSDP optimizer states to CPU
--adam_offload # Offload Adam optimizer states
--sp_size 2 # Sequence parallelism size
Training Optimization:
--packing_samples # Pack multiple samples into sequences
--use_mp_opt # Use mixed precision optimizer (FSDP)
--fused_linear_logprob # Fused linear layer and logprob computation
--chunk_size 4096 # Chunk size for fused operations
Monitoring:
--log_dir ./logs # Directory for logs and visualizations
--plot_every 10 # Plot generation length distribution every N steps
Strategy Implementation Details¶
Available Strategies¶
LightRFT provides two main strategy implementations:
DeepspeedStrategy (default)
Uses DeepSpeed ZeRO for memory-efficient training
Configurable ZeRO stages (1, 2, or 3)
Supports gradient accumulation and mixed precision
Best for: General RLHF training, well-established workflows
FSDPV2Strategy (when --fsdp is set)
Uses PyTorch’s Fully Sharded Data Parallel
Supports CPU offloading for optimizer states
Native PyTorch implementation with better integration
Best for: Maximum memory efficiency, PyTorch-native workflows
Strategy Selection Logic¶
# In get_strategy() function
if args.fsdp:
strategy = FSDPV2Strategy(...)
else:
strategy = DeepspeedStrategy(...)
Engine Sleep/Wake Mechanism¶
The strategy provides automatic memory management through engine sleep modes:
# Engine lifecycle management
strategy.setup_inference_engine(args, engine_type='vllm') # Creates and wakes engine
strategy.maybe_sleep_inference_engine() # Sleep to save memory
strategy.wakeup_inference_engine() # Wake for generation
Important
When using gather_and_generate() with sleep_engine=True, the sleep/wake cycle is handled automatically.
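If you drive the engine yourself instead of going through gather_and_generate(), a simple pattern is to wrap generation in try/finally so the engine is always put back to sleep. The sketch below uses the lifecycle methods shown above; the engine's generate() call is an assumed API:
# Illustrative manual sleep/wake management around a custom generation step
def generate_with_manual_sleep(strategy, prompts, sampling_params):
    strategy.wakeup_inference_engine()           # bring the engine back into GPU memory
    try:
        return strategy.inference_engine.generate(prompts, sampling_params)  # assumed engine API
    finally:
        strategy.maybe_sleep_inference_engine()  # always release KV-cache memory afterwards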
Configuration Examples¶
High-Throughput Setup (8 GPUs, DeepSpeed)¶
# Using DeepSpeed ZeRO-2 with large tensor parallelism
python train.py \
--zero_stage 2 \
--engine_tp_size 4 \
--engine_mem_util 0.9 \
--enable_engine_sleep \
--micro_train_batch_size 1 \
--train_batch_size 128
Memory-Efficient Setup (8 GPUs, FSDP with CPU Offload)¶
# Using FSDP with CPU offloading for maximum memory efficiency
python train.py \
--fsdp \
--fsdp_cpu_offload \
--use_mp_opt \
--engine_tp_size 2 \
--engine_mem_util 0.85 \
--enable_engine_sleep \
--micro_train_batch_size 1 \
--train_batch_size 64
Vision-Language Model Setup¶
# Training VLMs with multimodal data
python train_vl.py \
--fsdp \
--engine_tp_size 4 \
--mixed_mm_data \
--packing_samples \
--enable_engine_sleep \
--plot_every 20
Best Practices¶
1. Tensor Parallelism Configuration¶
Set engine_tp_size to match your model size and GPU count
For 7B models: engine_tp_size=1 or 2
For 13B-70B models: engine_tp_size=4 or 8
Ensure world_size % engine_tp_size == 0 (a sanity check is sketched below)
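A quick sanity check for the divisibility constraint, as an illustrative snippet (not part of the library):
import torch.distributed as dist

def check_engine_tp_size(engine_tp_size: int) -> None:
    # world_size must be an integer multiple of engine_tp_size
    world_size = dist.get_world_size()
    if world_size % engine_tp_size != 0:
        raise ValueError(
            f"engine_tp_size={engine_tp_size} must evenly divide world_size={world_size}"
        )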
2. Memory Management¶
Enable engine sleep mode for memory-constrained setups: --enable_engine_sleep
Adjust engine_mem_util based on available memory (0.5-0.9)
Use FSDP with CPU offload for maximum memory savings: --fsdp --fsdp_cpu_offload
3. Performance Optimization¶
Use --packing_samples for varied sequence lengths
Enable --fused_linear_logprob for large vocabulary models
Set an appropriate micro_train_batch_size to saturate GPU utilization
4. Debugging and Monitoring¶
Use --plot_every with --log_dir to track the generation length distribution
Monitor memory with strategy.report_memory(prefix="checkpoint_name")
Check engine status with strategy.inference_engine_status (usage sketched below)
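As an illustration, the two hooks above might be used around a training step like this; the surrounding loop, dataloader, and trainer variables are assumed context, not LightRFT APIs:
# Illustrative monitoring around a PPO step, using the hooks named above
for step, batch in enumerate(dataloader):        # dataloader/trainer are assumed context
    strategy.report_memory(prefix=f"before_step_{step}")
    trainer.ppo_train(global_steps=step)
    strategy.report_memory(prefix=f"after_step_{step}")

    # Confirm the engine has been put to sleep between generations
    strategy.print(f"engine status: {strategy.inference_engine_status}")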
Advanced Features¶
Sequence Parallelism¶
Enable sequence parallelism for very long sequences:
# In arguments
--sp_size 4 # Split sequence across 4 GPUs
The strategy automatically creates sequence-parallel groups and handles communication.
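Conceptually, each rank in a sequence-parallel group holds a contiguous slice of the sequence. The snippet below is an illustrative sketch of that split only; the actual split and the attention communication are handled by the strategy:
# Illustrative only: how a long sequence might be split across a group of sp_size ranks
def split_for_sequence_parallel(input_ids, sp_rank, sp_size):
    # Each of the sp_size ranks keeps seq_len // sp_size consecutive tokens
    seq_len = input_ids.shape[-1]
    assert seq_len % sp_size == 0, "pad sequences so sp_size divides their length"
    chunk = seq_len // sp_size
    return input_ids[..., sp_rank * chunk:(sp_rank + 1) * chunk]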
Custom Reward Models¶
For multiple reward models or remote reward APIs:
# Multiple reward models
reward_models = [reward_model_1, reward_model_2, reward_model_3]
strategy = get_strategy(args)
# Models are automatically sharded across GPUs
prepared_rms = [strategy.prepare_model(rm, shard_size=8) for rm in reward_models]
Mixed Precision Training¶
Control mixed precision behavior:
# Enable BF16 training
--bf16
# Use mixed precision optimizer (FSDP)
--use_mp_opt
Troubleshooting¶
Common Issues¶
Issue: Out of memory during generation
Solution: Reduce engine_mem_util or increase engine_tp_size
Issue: Engine not updating with new weights
Solution: Ensure update_engine_weights() is called after training
Issue: Slow generation speed
Solution: Increase micro_rollout_batch_size or reduce engine_tp_size
Issue: FSDP optimizer offload errors
Solution: Verify you’re using the FSDP strategy (--fsdp) and calling offload/load in pairs
API Reference¶
For detailed API documentation, see:
lightrft.strategy.strategy_base.StrategyBase - Base strategy class
lightrft.strategy.get_strategy() - Strategy factory function
lightrft.utils.add_arguments() - Argument configuration