lightrft.strategy.utils.broadcast_utils¶
Module for managing weight synchronization between training and inference engines.
This module provides functionality to broadcast model weights from training to inference engines, supporting different distributed training strategies including DeepSpeed and FSDP (Fully Sharded Data Parallel v2). It handles the complexities of gathering sharded parameters and efficiently transferring them to inference engines like vllm and sglang.
BroadcastManager¶
- class lightrft.strategy.utils.broadcast_utils.BroadcastManager(actor: torch.nn.Module, strategy: Any, inference_engine: Any)[source]¶
Manages weight synchronization between the training and inference engines.
This class handles the broadcasting of model weights from a distributed training setup to inference engines. It supports different distributed training strategies including DeepSpeed ZeRO and PyTorch’s FSDP v2.
- Parameters:
actor – The actor model containing weights to be broadcast
strategy – The training strategy object containing configuration and methods
inference_engine – The inference engine (vllm or sglang) to receive the weights
- __init__(actor: torch.nn.Module, strategy: Any, inference_engine: Any) None[source]¶
Initialize the BroadcastManager with the necessary components.
- Parameters:
actor (torch.nn.Module) – The actor model containing weights to be broadcast
strategy (object) – The training strategy object containing configuration and methods
inference_engine (object) – The inference engine (vllm or sglang) to receive the weights
- broadcast_to_engine()[source]¶
Broadcast model weights to the inference engine.
This method selects the appropriate broadcasting path based on the distributed training configuration, automatically detecting from the strategy configuration whether to use DeepSpeed or FSDP v2 broadcasting.
Example:
# Initialize the broadcast manager
broadcast_manager = BroadcastManager(actor_model, strategy, inference_engine)

# Broadcast weights to inference engine
broadcast_manager.broadcast_to_engine()
- Raises:
NotImplementedError – If an unsupported configuration is used