lightrft.strategy.sglang_utils.sglang_engine

This module provides a distributed reinforcement learning generation engine for language models.

The RLGenerationEngine class facilitates text generation across multiple processes and nodes, leveraging PyTorch’s distributed capabilities. It supports tensor parallelism for efficient model execution and provides methods for generating text, updating model weights, and managing memory resources.

The module is designed to work with different versions of SGLang:
  • Supports both old and new Engine import paths

  • Compatible with various SGLang API changes across versions
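Supporting more than one import path is commonly done with a fallback import. The sketch below shows that general pattern; `first_importable` is a hypothetical helper, not part of lightrft, and the two SGLang module paths in the trailing comment are assumptions that should be checked against the installed SGLang release.

```python
import importlib
from typing import Any, Iterable, Tuple


def first_importable(candidates: Iterable[Tuple[str, str]]) -> Any:
    """Return the first attribute that imports cleanly from (module, attr) pairs."""
    errors = []
    for module_name, attr_name in candidates:
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr_name)
        except (ImportError, AttributeError) as exc:
            # Remember why this candidate failed, then try the next one.
            errors.append(f"{module_name}.{attr_name}: {exc}")
    raise ImportError("no candidate could be imported:\n" + "\n".join(errors))


# Assumed usage for SGLang (newer path first, older path as fallback):
# Engine = first_importable([
#     ("sglang.srt.entrypoints.engine", "Engine"),
#     ("sglang.srt.server", "Engine"),
# ])
```

Listing the newer path first keeps the common case fast and leaves the older path as a fallback only.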

The module is designed to work with the SGLang runtime (srt) system and supports features such as batch processing, custom sampling parameters, and LoRA fine-tuning.

RLGenerationEngine

class lightrft.strategy.sglang_utils.sglang_engine.RLGenerationEngine(tp_group_cpu, num_gpu_per_node: int = 8, **kwargs)

A distributed reinforcement learning generation engine for language models.

This class manages text generation across multiple processes and nodes using tensor parallelism. It wraps the SGLang Engine to provide distributed generation capabilities with efficient memory management and weight updating functionality.

The engine automatically handles distributed coordination, ensuring that only the leader rank performs actual generation while broadcasting results to all participating processes. It supports various input formats including text prompts, token IDs, and image data.

Parameters:
  • tp_group_cpu (torch.distributed.ProcessGroup) – PyTorch process group for tensor parallelism on CPU

  • num_gpu_per_node (int) – Number of GPUs per node, defaults to 8

  • kwargs – Additional arguments to pass to the underlying SGLang Engine

Example:

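>>> import torch.distributed as dist
>>> from lightrft.strategy.sglang_utils.sglang_engine import RLGenerationEngine
>>> # Initialize the distributed environment
>>> dist.init_process_group("nccl")
>>> # tp_group_cpu is a CPU-side group, so create it with the gloo backend
>>> tp_group = dist.new_group(backend="gloo")
>>> engine = RLGenerationEngine(
...     tp_group_cpu=tp_group,
...     num_gpu_per_node=8,
...     model="llama2-7b"
... )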
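The leader-generates-then-broadcasts coordination described above can be sketched in a single process. This is an illustration of the pattern only, not the engine's actual implementation: `generate_on_leader` and `make_fake_broadcast` are hypothetical names, and the fake broadcast stands in for a real collective such as `torch.distributed.broadcast_object_list`.

```python
from typing import Any, Callable, Dict, List


def generate_on_leader(
    rank: int,
    leader_rank: int,
    run_generation: Callable[[], Any],
    broadcast: Callable[[List[Any], int], None],
) -> Any:
    """Run generation on one rank and hand the same result to every rank."""
    # Only the leader does the expensive work; other ranks hold a placeholder.
    payload = [run_generation() if rank == leader_rank else None]
    # `broadcast` must fill payload[0] on every rank with the leader's value.
    broadcast(payload, leader_rank)
    return payload[0]


def make_fake_broadcast(rank: int, wire: Dict[str, Any]) -> Callable[[List[Any], int], None]:
    """Simulated broadcast: the source rank publishes, every other rank reads."""
    def broadcast(payload: List[Any], src: int) -> None:
        if rank == src:
            wire["value"] = payload[0]
        else:
            payload[0] = wire["value"]
    return broadcast


calls = []


def run_generation() -> Dict[str, str]:
    calls.append("generated")  # Track how many times generation actually runs.
    return {"text": "bonjour"}


wire: Dict[str, Any] = {}
# Simulate rank 0 (leader) followed by rank 1 (follower).
results = [
    generate_on_leader(rank, 0, run_generation, make_fake_broadcast(rank, wire))
    for rank in (0, 1)
]
```

Both simulated ranks end up with the same result while generation runs only once, which is the property the engine relies on.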
generate(prompt: List[str] | str | None = None, sampling_params: List[Dict] | Dict | None = None, input_ids: List[List[int]] | List[int] | None = None, image_data: List[str] | str | None = None, return_logprob: List[bool] | bool | None = False, logprob_start_len: List[int] | int | None = None, top_logprobs_num: List[int] | int | None = None, lora_path: List[str | None] | None = None, custom_logit_processor: List[str] | str | None = None, gather_inputs=False) → Dict

Generate text using the language model in a distributed manner.

This method coordinates text generation across multiple processes, with only the leader rank performing actual generation and broadcasting results to all participants. It supports various input formats and generation parameters.

The arguments of this function are the same as sglang/srt/managers/io_struct.py::GenerateReqInput. Please refer to GenerateReqInput for detailed documentation of each parameter.

Parameters:
  • prompt (Optional[Union[List[str], str]]) – The input prompt(s) for text generation

  • sampling_params (Optional[Union[List[Dict], Dict]]) – Parameters controlling the sampling strategy (temperature, top_p, etc.)

  • input_ids (Optional[Union[List[List[int]], List[int]]]) – Token IDs for input text (alternative to using prompt)

  • image_data (Optional[Union[List[str], str]]) – Image input as file name, URL, or base64 encoded string

  • return_logprob (Optional[Union[List[bool], bool]]) – Whether to return log probabilities for generated tokens

  • logprob_start_len (Optional[Union[List[int], int]]) – Start position for log probability calculation

  • top_logprobs_num (Optional[Union[List[int], int]]) – Number of top log probabilities to return

  • lora_path (Optional[List[Optional[str]]]) – Paths to LoRA weights to apply during generation

  • custom_logit_processor (Optional[Union[List[str], str]]) – Custom logit processor for modifying logits

  • gather_inputs (bool) – Whether to gather inputs across all ranks before generation

Returns:

Generation results including generated text and metadata

Return type:

Dict

Example:

>>> result = engine.generate(
...     prompt="Translate this to French: Hello, world!",
...     sampling_params={"temperature": 0.7, "max_tokens": 50}
... )
>>> print(result["text"])
>>>
>>> # Batch generation with different parameters
>>> results = engine.generate(
...     prompt=["Hello", "Goodbye"],
...     sampling_params=[{"temperature": 0.5}, {"temperature": 0.9}],
...     return_logprob=True
... )
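Most arguments of generate() accept either a single value or one value per prompt. The sketch below shows one plausible way such arguments can be expanded to per-request lists. It illustrates the convention only and is not SGLang's normalization code; `per_request` and `normalize_batch` are hypothetical helpers.

```python
import copy
from typing import Any, Dict, List, Union


def per_request(value: Any, batch_size: int) -> List[Any]:
    """Expand a single value to one entry per request; validate list lengths."""
    if isinstance(value, list):
        if len(value) != batch_size:
            raise ValueError(f"expected {batch_size} entries, got {len(value)}")
        return value
    # Copy so that requests do not share one mutable dict.
    return [copy.copy(value) for _ in range(batch_size)]


def normalize_batch(
    prompt: Union[str, List[str]],
    sampling_params: Union[Dict, List[Dict], None] = None,
    return_logprob: Union[bool, List[bool]] = False,
) -> Dict[str, List[Any]]:
    """Turn scalar-or-list arguments into aligned per-request lists."""
    prompts = [prompt] if isinstance(prompt, str) else list(prompt)
    batch_size = len(prompts)
    if sampling_params is None:
        sampling_params = {}
    return {
        "prompt": prompts,
        "sampling_params": per_request(sampling_params, batch_size),
        "return_logprob": per_request(return_logprob, batch_size),
    }


# One shared sampling dict and one shared flag, applied to two prompts.
batch = normalize_batch(["Hello", "Goodbye"], {"temperature": 0.5}, True)
```

Passing a list whose length differs from the number of prompts raises a ValueError here, which surfaces mismatched batches early.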
shutdown()

Shut down the engine and release all resources.

Call this method when the engine is no longer needed. It releases GPU memory, process groups, and any background threads so that the engine terminates cleanly.

Example:

>>> try:
...     # Use engine for generation
...     results = engine.generate("Hello")
... finally:
...     engine.shutdown()  # Always clean up resources
sleep(release_weights: bool = False)

Release memory resources temporarily to free up GPU memory.

This method releases KV cache and CUDA graph memory to free up GPU resources during idle periods. By default, model weights are kept in memory to avoid the overhead and risk of saving/restoring them.

Parameters:

release_weights (bool) – Whether to also release weights memory. Default is False. Set to True only if you need maximum memory savings and understand the risks (SGLang may not properly restore weights).

Note:
  • By default, only KV cache and CUDA graph are released (recommended)

  • Weights are kept in memory unless explicitly requested

  • After calling sleep(), you must call wake_up() before using the engine again

  • ⚠️ WARNING: Releasing weights may cause generation issues due to SGLang limitations

Example:

>>> # Standard usage: keep weights in memory (recommended)
>>> engine.sleep()
>>> # ... other operations ...
>>> engine.wake_up()
>>>
>>> # Maximum memory savings: also release weights (use with caution)
>>> engine.sleep(release_weights=True)
>>> engine.wake_up(release_weights=True)
update_weights_from_tensor(name: str, tensor: torch.Tensor, flush_cache: bool = False, load_format: str | None = None)

Update model weights with the provided tensor in a distributed manner.

This method allows dynamic updating of model parameters during runtime, which is particularly useful for reinforcement learning scenarios where model weights need to be updated based on training feedback. The method handles distributed coordination to ensure all ranks participate in the weight update process.

Parameters:
  • name (str) – Name of the weight tensor to update (e.g., “model.layers.0.self_attn.q_proj.weight”)

  • tensor (torch.Tensor) – New weight tensor to replace the existing weights

  • flush_cache (bool) – Whether to flush the KV cache after updating weights

  • load_format (Optional[str]) – Format specification for loading the weights

Example:

>>> import torch
>>> # Update a specific layer's weights
>>> new_weights = torch.randn(768, 768)
>>> engine.update_weights_from_tensor(
...     "model.layers.0.self_attn.q_proj.weight",
...     new_weights,
...     flush_cache=True
... )
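In a training loop, every parameter is typically pushed after an optimizer step. A minimal sketch of that loop follows, assuming it is enough to flush the KV cache once, on the final update; that assumption is not confirmed by this reference. `push_weights` and `RecordingEngine` are hypothetical names, and the stand-in engine only records calls.

```python
from typing import Any, Iterable, List, Tuple


def push_weights(engine: Any, named_tensors: Iterable[Tuple[str, Any]]) -> int:
    """Send (name, tensor) pairs one by one; request a cache flush only on the last."""
    iterator = iter(named_tensors)
    try:
        current = next(iterator)
    except StopIteration:
        return 0  # Nothing to update.
    count = 0
    # One-item lookahead avoids holding every tensor in memory at once.
    for upcoming in iterator:
        engine.update_weights_from_tensor(current[0], current[1], flush_cache=False)
        count += 1
        current = upcoming
    engine.update_weights_from_tensor(current[0], current[1], flush_cache=True)
    return count + 1


class RecordingEngine:
    """Stand-in for RLGenerationEngine that only records calls (illustration)."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, bool]] = []

    def update_weights_from_tensor(self, name, tensor, flush_cache=False, load_format=None):
        self.calls.append((name, flush_cache))


engine = RecordingEngine()
# Strings stand in for tensors here; a real caller would pass torch.Tensor objects.
sent = push_weights(
    engine,
    [("layer.0.weight", "w0"), ("layer.1.weight", "w1"), ("layer.2.weight", "w2")],
)
```

With a real model, the pairs could come from something like `model.named_parameters()`, provided every participating rank makes the same sequence of calls.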
wake_up(release_weights: bool = False)

Resume memory occupation after a call to sleep().

This method should be called before using the engine for generation after a previous call to sleep(). It restores the engine to its fully operational state with all necessary memory allocations.

Parameters:

release_weights (bool) – Must match the value used in sleep(). Default is False.

Important:
  • The release_weights parameter must match what was used in sleep()

  • If you called sleep(release_weights=False), call wake_up(release_weights=False)

  • If you called sleep(release_weights=True), call wake_up(release_weights=True)

Example:

>>> # Standard usage (recommended)
>>> engine.sleep()              # Only KV cache & CUDA graph released
>>> # ... do other work ...
>>> engine.wake_up()            # Only KV cache & CUDA graph restored
>>> result = engine.generate("Hello world")
>>>
>>> # If weights were released (use with caution)
>>> engine.sleep(release_weights=True)
>>> engine.wake_up(release_weights=True)  # Must match!
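Because sleep() and wake_up() must receive the same release_weights value, pairing them in a context manager removes the chance of a mismatch. The sketch below is one way to do that; `engine_asleep` and `FakeEngine` are hypothetical names used for illustration and are not part of lightrft.

```python
from contextlib import contextmanager
from typing import Any, Iterator, List, Tuple


@contextmanager
def engine_asleep(engine: Any, release_weights: bool = False) -> Iterator[None]:
    """Keep the engine asleep for the duration of the block, then wake it up."""
    engine.sleep(release_weights=release_weights)
    try:
        yield
    finally:
        # The flag is captured once, so sleep() and wake_up() cannot disagree.
        engine.wake_up(release_weights=release_weights)


class FakeEngine:
    """Stand-in engine that logs sleep/wake_up calls (illustration only)."""

    def __init__(self) -> None:
        self.log: List[Tuple[str, bool]] = []

    def sleep(self, release_weights: bool = False) -> None:
        self.log.append(("sleep", release_weights))

    def wake_up(self, release_weights: bool = False) -> None:
        self.log.append(("wake_up", release_weights))


engine = FakeEngine()
with engine_asleep(engine):
    pass  # For example, run a training step while the KV cache memory is free.
```

The `finally` clause wakes the engine even if the block raises, so a failed training step does not leave the engine asleep.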