lightrft.strategy.vllm_utils.vllm_worker_wrap_no_ray
This module provides a wrapper that extends the vLLM worker with the ability to update its model weights from a source rank.
This is useful in distributed training or inference, where model weights must be kept synchronized across multiple workers.
WorkerWrap
- class lightrft.strategy.vllm_utils.vllm_worker_wrap_no_ray.WorkerWrap(*args: Any, **kwargs: Any)
A wrapper for the vLLM worker that adds dynamic weight updates.
This class inherits from vLLM’s Worker class and adds a method for updating model weights at runtime. This is useful in distributed setups where weights need to be broadcast from a source rank to all workers.
- Inherits:
vllm.worker.worker.Worker
- update_weight(name, dtype, shape, weight, empty_cache=False)
Broadcast weight to all vLLM workers from source rank 0 (actor model).
This method updates a single named weight tensor in the model. Before loading the weight, it asserts that the data type of the incoming weight matches the model’s configured data type.
- Parameters:
name (str) – The name of the weight tensor to update.
dtype (torch.dtype) – The data type of the weight tensor.
shape (tuple) – The shape of the weight tensor.
weight (torch.Tensor) – The new weight tensor values.
empty_cache (bool) – Whether to empty the CUDA cache after updating the weight. Defaults to False.
- Raises:
AssertionError – If the data type of the weight doesn’t match the model’s configured data type.
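The contract described above (check the data type, load the named weight, optionally clear the cache) can be sketched as follows. This is a minimal illustration under stated assumptions, not the actual WorkerWrap implementation: FakeTensor, FakeModel, SketchWorker, model_dtype, and cache_cleared are hypothetical stand-ins that avoid a dependency on torch and vLLM, and data types are represented as plain strings.

```python
from dataclasses import dataclass, field
from typing import Dict, Iterable, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: carries only dtype and shape."""
    dtype: str
    shape: Tuple[int, ...]


@dataclass
class FakeModel:
    """Hypothetical stand-in for the model held by a worker."""
    weights: Dict[str, FakeTensor] = field(default_factory=dict)

    def load_weights(self, weights: Iterable[Tuple[str, FakeTensor]]) -> None:
        # Store each (name, tensor) pair, replacing any previous value.
        for name, tensor in weights:
            self.weights[name] = tensor


class SketchWorker:
    """Illustrates the documented update_weight contract only."""

    def __init__(self, model_dtype: str) -> None:
        self.model_dtype = model_dtype  # the model's configured data type
        self.model = FakeModel()
        self.cache_cleared = False

    def update_weight(self, name, dtype, shape, weight, empty_cache=False):
        # Documented behaviour: a dtype that differs from the model's
        # configured dtype raises AssertionError.
        assert dtype == self.model_dtype, (
            f"dtype mismatch: got {dtype}, expected {self.model_dtype}"
        )
        # `shape` is accepted to mirror the documented signature; this
        # sketch does not use it.
        self.model.load_weights([(name, weight)])
        if empty_cache:
            # Assumption: a real worker would free cached CUDA memory here,
            # for example with torch.cuda.empty_cache(). The sketch only
            # records that the request was made.
            self.cache_cleared = True


worker = SketchWorker(model_dtype="bfloat16")
new_weight = FakeTensor(dtype="bfloat16", shape=(4, 4))
worker.update_weight(
    "layer.weight", new_weight.dtype, new_weight.shape, new_weight,
    empty_cache=True,
)
```

In this sketch, calling update_weight with a dtype other than "bfloat16" fails the assertion before anything is loaded, which mirrors the AssertionError listed under Raises.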