
Source code for lightrft.strategy.utils.broadcast_utils

"""
Module for managing weight synchronization between training and inference engines.

This module provides functionality to broadcast model weights from training to inference engines,
supporting different distributed training strategies, including DeepSpeed ZeRO and FSDP v2
(Fully Sharded Data Parallel). It handles the complexities of gathering sharded parameters and
efficiently transferring them to inference engines such as vllm and sglang.
"""

from typing import Any

import deepspeed
import torch
from torch.distributed.tensor import DTensor

from lightrft.utils import get_current_device


class BroadcastManager:
    """
    Manage weight synchronization between the training and inference engines.

    This class handles broadcasting model weights from a distributed training setup to
    inference engines. It supports different distributed training strategies, including
    DeepSpeed ZeRO and PyTorch's FSDP v2.

    :param actor: The actor model containing the weights to be broadcast
    :param strategy: The training strategy object containing configuration and methods
    :param inference_engine: The inference engine (vllm or sglang) to receive the weights
    """
    def __init__(self, actor: torch.nn.Module, strategy: Any, inference_engine: Any) -> None:
        """
        Initialize the BroadcastManager with the necessary components.

        :param actor: The actor model containing the weights to be broadcast
        :param strategy: The training strategy object containing configuration and methods
        :param inference_engine: The inference engine (vllm or sglang) to receive the weights
        :type actor: torch.nn.Module
        :type strategy: Any
        :type inference_engine: Any
        """
        self.actor = actor
        self.strategy = strategy
        self.inference_engine = inference_engine
    def _map_weight_name_for_sglang(self, name: str) -> str:
        """
        Map weight names from the training-model format to the SGLang format.

        Training model (Qwen2.5-VL with wrapper):
            - model.visual.xxx
            - model.language_model.embed_tokens
            - model.language_model.layers.xxx
            - model.language_model.norm
            - model.language_model.lm_head

        SGLang expects:
            - visual.xxx
            - model.embed_tokens
            - model.layers.xxx
            - model.norm
            - lm_head

        :param name: Original weight name from the training model
        :return: Mapped weight name for SGLang
        """
        # Step 1: Remove the outermost "model." prefix if present
        if name.startswith("model."):
            name = name[6:]  # Remove "model."

        # Step 2: Handle the language_model prefix mapping
        if name.startswith("language_model."):
            name = name[15:]  # Remove "language_model."

            # lm_head is kept as is (no "model." prefix)
            if name.startswith("lm_head"):
                return name

            # Other components (embed_tokens, layers, norm) get a "model." prefix
            return f"model.{name}"

        # Step 3: Return as is for other cases (e.g., visual.xxx)
        return name

    def _deepspeed_broadcast(self):
        """
        Broadcast model weights using DeepSpeed's ZeRO optimization.

        This method gathers parameters that are sharded under ZeRO-3 and broadcasts them
        to all inference engines. Parameters are processed one by one to avoid memory
        issues. For ZeRO-3, DeepSpeed's GatheredParameters context manager is used to
        collect the sharded parameters before broadcasting.

        :raises NotImplementedError: If an unsupported inference engine is specified
        """
        # avoid OOM
        torch.cuda.empty_cache()

        model = self.actor.model.module
        count, num_params = 0, len(list(model.named_parameters()))
        for name, param in model.named_parameters():
            count += 1  # empty_cache at last param

            # For ZeRO-3, allgather the sharded parameter and broadcast it to all vllm
            # engines by rank 0
            with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3):
                shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
                kwargs = dict(
                    name=name, dtype=param.dtype, shape=shape, weight=param.data, empty_cache=(count == num_params)
                )

                if self.strategy.engine_type == "vllm":
                    self.inference_engine.llm_engine.model_executor.collective_rpc("update_weight", kwargs=kwargs)
                elif self.strategy.engine_type == "sglang":
                    if self.strategy.args.text_only:  # for LLM
                        self.inference_engine.update_weights_from_tensor(
                            name, param.data, flush_cache=(count == num_params)
                        )
                    else:  # for VLM
                        # Map weight names from the training model to the SGLang format.
                        # Training model: model.visual.xxx, model.language_model.xxx
                        # SGLang expects: visual.xxx, model.xxx (for the language model), lm_head
                        sglang_name = self._map_weight_name_for_sglang(name)
                        self.inference_engine.update_weights_from_tensor(
                            sglang_name, param.data, flush_cache=(count == num_params)
                        )

    def _fsdp_v2_broadcast(self):
        """
        Broadcast model weights using PyTorch's FSDP v2.

        This method gathers and broadcasts weights for FSDP v2, which has a different API
        from v1. It handles DTensor parameters by converting them to full tensors before
        broadcasting.

        :raises NotImplementedError: If an unsupported inference engine is specified
        """
        model = self.actor.model
        count, num_params = 0, len(list(model.named_parameters()))
        dst_dtype = torch.bfloat16 if self.strategy.args.bf16 else torch.float16
        for name, param in model.named_parameters():
            count += 1  # empty_cache at last param
            param_on_device = param.to(get_current_device())
            if isinstance(param, DTensor):
                full_param = param_on_device.full_tensor().to(dst_dtype)
            else:
                full_param = param_on_device.to(dst_dtype)

            if self.strategy.engine_type == "vllm":
                kwargs = dict(
                    name=name,
                    dtype=full_param.dtype,
                    shape=full_param.shape,
                    weight=full_param.data,
                    empty_cache=(count == num_params),
                )
                self.inference_engine.llm_engine.model_executor.collective_rpc("update_weight", kwargs=kwargs)
            elif self.strategy.engine_type == "sglang":
                if self.strategy.args.text_only:  # for LLM
                    self.inference_engine.update_weights_from_tensor(
                        name, full_param.data, flush_cache=(count == num_params)
                    )
                else:  # for VLM
                    # Map weight names from the training model to the SGLang format.
                    # Training model: model.visual.xxx, model.language_model.xxx
                    # SGLang expects: visual.xxx, model.xxx (for the language model), lm_head
                    sglang_name = self._map_weight_name_for_sglang(name)
                    self.inference_engine.update_weights_from_tensor(
                        sglang_name, full_param.data, flush_cache=(count == num_params)
                    )

            del param_on_device
            del full_param
    def broadcast_to_engine(self):
        """
        Broadcast model weights to the inference engine.

        This method selects the appropriate broadcasting path (DeepSpeed or FSDP v2)
        based on the distributed training configuration in the strategy.

        Example::

            # Initialize the broadcast manager
            broadcast_manager = BroadcastManager(actor_model, strategy, inference_engine)

            # Broadcast weights to the inference engine
            broadcast_manager.broadcast_to_engine()

        :raises NotImplementedError: If an unsupported configuration is used
        """
        if self.strategy.args.fsdp:
            self._fsdp_v2_broadcast()
        else:
            self._deepspeed_broadcast()
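
The weight-name mapping described above can be exercised without building any models, since
_map_weight_name_for_sglang does not read instance state. A minimal sketch, assuming the package
is installed; the parameter names below are hypothetical Qwen2.5-VL-style names used only for
illustration::

    from lightrft.strategy.utils.broadcast_utils import BroadcastManager

    # The mapping helper uses no instance attributes, so an uninitialized instance
    # is enough to demonstrate it (illustration only).
    probe = BroadcastManager.__new__(BroadcastManager)

    # Vision-tower weights lose only the outer "model." prefix.
    assert probe._map_weight_name_for_sglang("model.visual.patch_embed.proj.weight") == "visual.patch_embed.proj.weight"

    # Language-model weights trade the "language_model." prefix for a "model." prefix.
    assert probe._map_weight_name_for_sglang("model.language_model.layers.0.mlp.up_proj.weight") == "model.layers.0.mlp.up_proj.weight"

    # lm_head is left without any prefix.
    assert probe._map_weight_name_for_sglang("model.language_model.lm_head.weight") == "lm_head.weight"

    # Typical wiring after a training step (actor_model, strategy, and inference_engine
    # are created by the surrounding training setup, not shown here):
    #
    #     manager = BroadcastManager(actor_model, strategy, inference_engine)
    #     manager.broadcast_to_engine()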