lightrft.strategy.vllm_utils¶
This module provides utilities for initializing and configuring a vLLM engine.
The module simplifies the process of creating a vLLM engine with specific configurations for large language model inference, particularly in reinforcement learning from human feedback (RLHF) contexts. It offers both high-level and low-level functions for engine creation, with support for tensor parallelism, memory optimization, and multimodal capabilities.
- Note:
This module uses lazy imports for vLLM, so it can be imported even when vLLM is not installed (e.g. when using the SGLang backend). An ImportError is raised only when a vLLM function is actually called with engine_type="vllm".
To use the vLLM backend, install with: pip install "LightRFT[vllm]"
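The lazy-import behavior described in the note can be sketched with a small helper. This is an illustrative pattern, not the module's actual implementation; the helper name `lazy_import` and its signature are assumptions for the example:

```python
import importlib


def lazy_import(module_name: str, install_hint: str):
    """Import `module_name` only when called, deferring any ImportError.

    Code that depends on an optional backend can call this inside the
    functions that need the backend, so importing the enclosing module
    itself never fails when the backend is absent.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"{module_name} is required for this function but is not "
            f"installed. {install_hint}"
        ) from exc


# A vLLM-backed function would then call, e.g.:
#     vllm = lazy_import("vllm", 'Install it with: pip install "LightRFT[vllm]"')
```

With this pattern, the import error carries the install hint and surfaces at call time rather than at module import time.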
get_vllm_engine¶
- lightrft.strategy.vllm_utils.get_vllm_engine(pretrain_name_or_path: str, dtype: str = 'bfloat16', tp_size: int = 1, mem_util: float = 0.5, max_model_len: int = 4096, enable_sleep: bool = True, **kwargs: Any)[source]¶
Create and configure a vLLM engine with specified parameters.
This is the core function for initializing a vLLM engine with custom configurations. It sets up the engine with distributed execution capabilities, memory optimization, and custom worker classes for RLHF training scenarios.
- Parameters:
pretrain_name_or_path (str) – Path or name of the pretrained model to load.
dtype (str) – Data type for model weights, either "bfloat16" or "float16". Defaults to "bfloat16".
tp_size (int) – Tensor parallel size for distributed inference. Defaults to 1.
mem_util (float) – GPU memory utilization ratio (0.0 to 1.0). Defaults to 0.5.
max_model_len (int) – Maximum sequence length the model can handle. Defaults to 4096.
enable_sleep (bool) – Whether to enable sleep mode for memory efficiency. Defaults to True.
kwargs (Any) – Additional keyword arguments passed to the LLM constructor.
- Returns:
Configured vLLM engine instance.
- Return type:
vllm.LLM
- Raises:
ImportError – If vLLM is not installed when this function is called.
Example:
>>> engine = get_vllm_engine(
...     "Qwen/Qwen2.5-14B-Instruct",
...     dtype="bfloat16",
...     tp_size=2,
...     mem_util=0.8,
...     max_model_len=2048,
...     enable_sleep=True,
... )
- Note:
Uses external launcher for distributed execution and custom worker class for integration with lightrft strategy components.
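To build intuition for mem_util: in vLLM, the GPU memory utilization ratio caps the fraction of total GPU memory the engine may use for weights, activations, and KV cache. The sketch below does the back-of-envelope arithmetic; the helper `kv_cache_budget_gib` and all numbers are illustrative assumptions, not measurements from this library:

```python
def kv_cache_budget_gib(total_gpu_gib: float, mem_util: float,
                        weights_gib: float, activation_gib: float) -> float:
    """Rough KV-cache budget left after weights and activation workspace.

    mem_util caps the engine's total footprint at mem_util * total_gpu_gib;
    whatever remains after model weights and activations is (approximately)
    what the engine can devote to the KV cache.
    """
    budget = total_gpu_gib * mem_util
    return max(0.0, budget - weights_gib - activation_gib)


# Illustrative numbers only: an 80 GiB GPU with mem_util=0.5, a model whose
# bf16 weights take ~28 GiB, and ~4 GiB of activation workspace.
print(kv_cache_budget_gib(80.0, 0.5, 28.0, 4.0))  # 8.0
```

This is why a low mem_util (such as the default 0.5) can leave little or no room for the KV cache with larger models: the remaining budget, not the whole GPU, bounds how many concurrent sequences the engine can serve.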