
lightrft.strategy.utils.broadcast_utils

Module for managing weight synchronization between training and inference engines.

This module provides functionality to broadcast model weights from the training engine to inference engines. It supports several distributed training strategies, including DeepSpeed and FSDP v2 (Fully Sharded Data Parallel v2), and handles gathering sharded parameters into full weights and transferring them efficiently to inference engines such as vLLM and SGLang.
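Part of transferring weights efficiently is bounding how much gathered data exists at any one time. One common technique, shown here only as an illustrative sketch and not as this module's confirmed implementation, is to group parameters into size-limited buckets and gather and send one bucket at a time. The `plan_buckets` helper and its `(name, numel, bytes_per_element)` tuples are hypothetical.

```python
from typing import Iterable, List, Tuple

# Hypothetical parameter description: (name, number of elements, bytes per element).
ParamInfo = Tuple[str, int, int]


def plan_buckets(params: Iterable[ParamInfo], budget_bytes: int) -> List[List[str]]:
    """Group parameter names into buckets whose total size stays within budget_bytes.

    A parameter larger than the budget gets a bucket of its own, so no
    parameter is ever split across buckets.
    """
    buckets: List[List[str]] = []
    current: List[str] = []
    current_bytes = 0
    for name, numel, elem_size in params:
        size = numel * elem_size
        # Flush the current bucket if adding this parameter would exceed the budget.
        if current and current_bytes + size > budget_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += size
    if current:
        buckets.append(current)
    return buckets
```

For example, with 2-byte (fp16) parameters and a 1000-byte budget, `plan_buckets([("embed.weight", 1000, 2), ("layer0.weight", 400, 2), ("layer0.bias", 20, 2), ("lm_head.weight", 1000, 2)], 1000)` returns `[["embed.weight"], ["layer0.weight", "layer0.bias"], ["lm_head.weight"]]`. A smaller budget lowers peak memory at the cost of more, smaller transfers.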

BroadcastManager

class lightrft.strategy.utils.broadcast_utils.BroadcastManager(actor: torch.nn.Module, strategy: Any, inference_engine: Any)

Manage weight synchronization between the training and inference engines.

This class handles the broadcasting of model weights from a distributed training setup to inference engines. It supports different distributed training strategies including DeepSpeed ZeRO and PyTorch’s FSDP v2.

Parameters:
  • actor – The actor model containing the weights to be broadcast

  • strategy – The training strategy object containing configuration and methods

  • inference_engine – The inference engine (vLLM or SGLang) to receive the weights
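Gathering is needed because both supported strategies keep only a slice of each parameter on every rank. The simplified, pure-Python model below illustrates the two layouts: DeepSpeed ZeRO-3 flattens each parameter, pads it to a multiple of the world size and splits it evenly, while FSDP v2 shards each parameter along dimension 0. All helper names are hypothetical, plain lists stand in for tensors, and the real libraries reassemble shards with collective all-gathers rather than local concatenation.

```python
import math
from typing import List, Sequence


def flat_shards(values: Sequence[float], world_size: int) -> List[List[float]]:
    """ZeRO-3-style layout: flatten, pad to a multiple of world_size, split evenly."""
    chunk = math.ceil(len(values) / world_size)
    padded = list(values) + [0.0] * (chunk * world_size - len(values))
    return [padded[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def gather_flat(shards: List[List[float]], numel: int) -> List[float]:
    """Stand-in for an all-gather: concatenate rank shards and drop the padding."""
    full = [x for shard in shards for x in shard]
    return full[:numel]


def dim0_shards(rows: Sequence[Sequence[float]], world_size: int) -> List[List[List[float]]]:
    """FSDP v2-style layout: contiguous blocks of rows; trailing ranks may hold fewer rows."""
    chunk = math.ceil(len(rows) / world_size)
    return [[list(r) for r in rows[i * chunk:(i + 1) * chunk]] for i in range(world_size)]


def gather_dim0(shards: List[List[List[float]]]) -> List[List[float]]:
    """Concatenate dim-0 shards back into the full 2-D parameter."""
    return [row for shard in shards for row in shard]
```

In both cases no single rank holds a complete parameter, so the full weights must be reassembled before an inference engine, which expects unsharded tensors, can load them.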

__init__(actor: torch.nn.Module, strategy: Any, inference_engine: Any) → None

Initialize the BroadcastManager with the necessary components.

Parameters:
  • actor (torch.nn.Module) – The actor model containing the weights to be broadcast

  • strategy (object) – The training strategy object containing configuration and methods

  • inference_engine (object) – The inference engine (vLLM or SGLang) to receive the weights

broadcast_to_engine()

Broadcast model weights to the inference engine.

This method automatically selects the broadcasting path (DeepSpeed or FSDP v2) based on the distributed training configuration held by the strategy object.

Example:

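# Import the manager (actor_model, strategy and inference_engine are assumed to exist)
from lightrft.strategy.utils.broadcast_utils import BroadcastManager

# Initialize the broadcast manager
broadcast_manager = BroadcastManager(actor_model, strategy, inference_engine)

# Broadcast weights to the inference engine
broadcast_manager.broadcast_to_engine()

Raises: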

NotImplementedError – If the strategy uses an unsupported distributed training configuration
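The automatic path selection described for `broadcast_to_engine()`, including the `NotImplementedError` for unsupported setups, can be sketched with stand-in objects. Everything here is hypothetical: the `use_deepspeed`/`use_fsdp2` flags, the engine's `update_weight` method and the dict standing in for the actor's named parameters are assumptions rather than LightRFT's actual attributes, and the method returns the chosen path only so the sketch is easy to check.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class HypotheticalStrategy:
    # Hypothetical flags; the real strategy object's attributes are not documented here.
    use_deepspeed: bool = False
    use_fsdp2: bool = False


@dataclass
class RecordingEngine:
    # Stand-in for a vLLM/SGLang engine that just records what it was asked to load.
    received: List[Tuple[str, Any]] = field(default_factory=list)

    def update_weight(self, name: str, value: Any) -> None:  # hypothetical method name
        self.received.append((name, value))


class SketchBroadcastManager:
    def __init__(self, actor: Dict[str, Any], strategy: HypotheticalStrategy,
                 engine: RecordingEngine) -> None:
        self.actor = actor
        self.strategy = strategy
        self.engine = engine

    def broadcast_to_engine(self) -> str:
        # Pick the gather path from the strategy configuration.
        if self.strategy.use_deepspeed:
            path = "deepspeed"
        elif self.strategy.use_fsdp2:
            path = "fsdp2"
        else:
            raise NotImplementedError("unsupported distributed training configuration")
        for name, value in self.actor.items():
            # Placeholder: a real DeepSpeed or FSDP v2 path would gather the
            # full, unsharded parameter here before sending it.
            self.engine.update_weight(name, value)
        return path
```

In a real implementation each branch would most likely gather parameters with its framework's own mechanism, for example DeepSpeed's `deepspeed.zero.GatheredParameters` context manager or `DTensor.full_tensor()` under FSDP v2, before handing them to the engine.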