# Source code for lightrft.strategy.config
"""
Configuration dataclasses for LightRFT strategy module.
This module provides typed configuration objects to replace the use of getattr
for accessing configuration parameters, improving type safety and code clarity.
"""
from dataclasses import dataclass, field
from typing import Optional, Any, Dict
import torch
import dataclasses
@dataclass
class StrategyConfig:
    """Base configuration for all training strategies.

    Replaces scattered ``getattr(args, ...)`` lookups with a typed dataclass
    so configuration access is explicit and type-checked.  Parameters not
    declared as fields are preserved in :attr:`extra_args` for backward
    compatibility with legacy argument objects.
    """

    # Basic training parameters
    # (int): Random seed, defaults to 42
    seed: int = 42
    # (float): Maximum gradient norm for clipping, defaults to 1.0
    max_norm: float = 1.0
    # (int): Micro batch size for training, defaults to 1
    micro_train_batch_size: int = 1
    # (int): Training batch size, defaults to 128
    train_batch_size: int = 128
    # (bool): Use bfloat16 precision, defaults to True
    bf16: bool = True
    # DeepSpeed specific
    # (int): DeepSpeed Zero optimization stage, defaults to 2
    zero_stage: int = 2
    # FSDP specific
    # (bool): Use FSDP (Fully Sharded Data Parallel), defaults to False
    fsdp: bool = False
    # (bool): Enable FSDP CPU offload, defaults to False
    fsdp_cpu_offload: bool = False
    # Common distributed training parameters
    # (bool): Offload Adam optimizer states, defaults to False
    adam_offload: bool = False
    # (int): ZeRO parallel group size, defaults to 1
    zpg: int = 1
    # (Optional[str]): Gradient accumulation data type, defaults to None
    grad_accum_dtype: Optional[str] = None
    # (bool): Overlap communication and computation, defaults to False
    overlap_comm: bool = False
    # Engine and inference parameters
    # (str): Inference engine type, defaults to "vllm"
    engine_type: str = "vllm"
    # (int): Engine tensor parallelism size, defaults to 1
    engine_tp_size: int = 1
    # (bool): Enable engine sleep mode, defaults to False
    enable_engine_sleep: bool = False
    # (int): Local rank for distributed training, defaults to -1
    local_rank: int = -1
    # Sequence parallel parameters
    # (int): Sequence parallelism size, defaults to 1
    sp_size: int = 1
    # Model parameters
    # (float): Actor model learning rate, defaults to 1e-5
    actor_learning_rate: float = 1e-5
    # (float): Critic model learning rate, defaults to 1e-5
    critic_learning_rate: float = 1e-5
    # (tuple): Adam optimizer beta parameters, defaults to (0.9, 0.95)
    adam_betas: tuple = (0.9, 0.95)
    # (float): L2 regularization coefficient, defaults to 0.0
    l2: float = 0.0
    # (float): Learning rate warmup ratio, defaults to 0.03
    lr_warmup_ratio: float = 0.03
    # Training control
    # (bool): Pretrain critic model, defaults to False
    critic_pretrain: bool = False
    # (Optional[str]): Remote reward model URL, defaults to None
    remote_rm_url: Optional[str] = None
    # (Optional[str]): Pretraining data path, defaults to None
    pretrain_data: Optional[str] = None
    # (bool): Use fused linear layer and logprob computation, defaults to False
    fused_linear_logprob: bool = False
    # Reward and advantage processing
    # (bool): Apply running normalization to rewards, defaults to False
    reward_running_norm: bool = False
    # (bool): When reward_running_norm is True, subtract mean during normalization, defaults to False
    reward_running_norm_minus_mean: bool = False
    # (bool): Normalize advantages, defaults to False
    advantages_norm: bool = False
    # (float): Clip advantages to this value, 0 means no clipping, defaults to 0.0
    advantage_clip: float = 0.0
    # (float): Clip rewards to this value, 0 means no clipping, defaults to 0.0
    reward_clip: float = 0.0
    # Experience generation parameters
    # (int): Batch size for micro rollout during experience generation, defaults to 2
    micro_rollout_batch_size: int = 2
    # (int): Number of samples to generate per prompt, defaults to 8
    n_samples_per_prompt: int = 8
    # Overlong sequence handling
    # (bool): Enable overlong sequence buffer penalty, defaults to False
    overlong_buffer: bool = False
    # (int): Buffer length for overlong sequence penalty calculation, defaults to 1024
    overlong_buffer_len: int = 1024
    # (float): Penalty factor for overlong sequences, defaults to 1.0
    overlong_buffer_penalty_factor: float = 1.0
    # Dynamic sampling and advantage estimation
    # (bool): Enable dynamic sampling for advantage estimation, defaults to False
    dynamic_sampling: bool = False
    # (str): Advantage estimator method, defaults to "group_norm"
    advantage_estimator: str = "group_norm"
    # KL loss and estimation
    # (bool): Use KL loss in training, defaults to False
    use_kl_loss: bool = False
    # (str): KL divergence estimator method, defaults to "k3"
    kl_estimator: str = "k3"
    # FSDP specific parameters
    # (bool): Use mixed precision matrix multiplication data, defaults to False
    mixed_mm_data: bool = False
    # (bool): Use model parallel optimizer, defaults to False
    use_mp_opt: bool = False
    # Analysis and monitoring
    # (int): Plot interval steps, defaults to -1
    plot_every: int = -1
    # (bool): Use TensorBoard for logging, defaults to False
    use_tensorboard: bool = False
    # Additional arguments for backward compatibility
    # (Dict[str, Any]): Extra arguments for backward compatibility, defaults to {}
    extra_args: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_args(cls, args_dict) -> 'StrategyConfig':
        """
        Create StrategyConfig from argparse.Namespace, a plain dict, or a
        similar attribute-style object.

        This method provides backward compatibility by extracting parameters
        that were previously accessed via getattr, ensuring smooth migration
        from legacy configuration systems.  Parameters that do not match a
        declared field are kept in ``extra_args``.

        :param args_dict: Configuration arguments object containing training parameters
        :type args_dict: object
        :return: StrategyConfig instance with extracted parameters
        :rtype: StrategyConfig

        Example::

            # From argparse.Namespace
            args = argparse.Namespace(
                seed=42,
                max_norm=1.0,
                micro_train_batch_size=1,
                # ... other parameters
            )
            config = StrategyConfig.from_args(args)

            # From dictionary
            args_dict = {
                'seed': 42,
                'max_norm': 1.0,
                'micro_train_batch_size': 1,
                # ... other parameters
            }
            config = StrategyConfig.from_args(args_dict)
        """
        config = cls()
        # Accept both mappings and attribute-style objects.  The previous
        # implementation used hasattr()/vars(), which silently ignored every
        # key of a plain dict and then raised TypeError on vars(dict) --
        # despite the docstring advertising dict support.
        if isinstance(args_dict, dict):
            source = args_dict
        else:
            source = vars(args_dict)
        # All declared field names except the catch-all 'extra_args'.
        field_names = [f.name for f in dataclasses.fields(cls) if f.name != 'extra_args']
        # Override defaults with any value the caller supplied.
        for field_name in field_names:
            if field_name in source:
                setattr(config, field_name, source[field_name])
        # Store unrecognized parameters for backward compatibility.
        config.extra_args = {k: v for k, v in source.items() if not hasattr(config, k)}
        config.print_config_summary()
        return config

    def print_config_summary(self) -> None:
        """
        Print a summary of the configuration for verification.

        This method shows which parameters were overridden from defaults
        and which are using default values.
        """
        # Only print on rank 0 GPU to avoid duplicate output in distributed training
        if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
            return
        print("=" * 60)
        print("StrategyConfig Configuration Verification Result")
        print("=" * 60)
        # Default-constructed instance used as the comparison baseline.
        default_config = StrategyConfig()
        print("\nConfiguration Parameters Details:")
        print("-" * 40)
        # (section title, field names) pairs; the "\n" prefixes reproduce the
        # original blank-line separation between sections exactly.
        groups = [
            ("Basic Training Parameters:",
             ['seed', 'max_norm', 'micro_train_batch_size', 'train_batch_size', 'bf16']),
            ("\nDistributed Training Parameters:",
             ['zero_stage', 'fsdp', 'fsdp_cpu_offload', 'adam_offload', 'zpg', 'grad_accum_dtype', 'overlap_comm']),
            ("\nEngine and Inference Parameters:",
             ['engine_type', 'engine_tp_size', 'enable_engine_sleep', 'local_rank', 'sp_size']),
            ("\nModel Parameters:",
             ['actor_learning_rate', 'critic_learning_rate', 'adam_betas', 'l2', 'lr_warmup_ratio']),
            ("\nTraining Control Parameters:",
             ['critic_pretrain', 'remote_rm_url', 'pretrain_data', 'fused_linear_logprob']),
            ("\nReward and Advantage Processing Parameters:",
             ['reward_running_norm', 'reward_running_norm_minus_mean', 'advantages_norm', 'advantage_clip',
              'reward_clip']),
            ("\nExperience Generation Parameters:",
             ['micro_rollout_batch_size', 'n_samples_per_prompt']),
            ("\nOverlong Sequence Handling Parameters:",
             ['overlong_buffer', 'overlong_buffer_len', 'overlong_buffer_penalty_factor']),
            ("\nDynamic Sampling and Advantage Estimation Parameters:",
             ['dynamic_sampling', 'advantage_estimator']),
            ("\nKL Loss and Estimation Parameters:",
             ['use_kl_loss', 'kl_estimator']),
            ("\nFSDP Specific Parameters:",
             ['mixed_mm_data', 'use_mp_opt']),
            ("\nAnalysis and Monitoring Parameters:",
             ['plot_every', 'use_tensorboard']),
        ]
        for title, attrs in groups:
            print(title)
            for attr in attrs:
                current = getattr(self, attr)
                default = getattr(default_config, attr)
                status = "Overridden" if current != default else "Default"
                print(f" {attr}: {current} ({status})")
        # extra_args
        if self.extra_args:
            print("\nExtra Parameters (extra_args):")
            for key, value in self.extra_args.items():
                print(f" {key}: {value}")
        else:
            print("\nExtra Parameters: None")
        print("\n" + "=" * 60)
        print("Configuration verification completed!")
        print("=" * 60)