
lightrft.models.monkey_patch.llama

This module implements a modified version of the LLaMA attention mechanism with support for Ulysses sequence parallelism. It adapts the original transformers implementation so that attention can be computed across sequence-parallel ranks, improving performance on distributed systems.
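A monkey patch like this one typically works by rebinding the attention class's forward method at import time, so every existing and future instance picks up the replacement. The sketch below illustrates only the pattern: LlamaAttention is a stub standing in for the transformers class, and apply_llama_patch is a hypothetical helper name, not part of this module's documented API.

```python
class LlamaAttention:
    """Stand-in stub for transformers.models.llama.modeling_llama.LlamaAttention."""

    def forward(self, hidden_states):
        return "original attention"


def llama_attn_forward(self, hidden_states):
    # Replacement forward; the real patch adds the Ulysses
    # sequence-parallel logic around the attention computation.
    return "patched attention"


def apply_llama_patch():
    # Rebind the method on the class itself so every instance
    # dispatches to the patched forward.
    LlamaAttention.forward = llama_attn_forward


apply_llama_patch()
attn = LlamaAttention()
print(attn.forward(None))  # -> patched attention
```

Because the rebinding happens on the class, the patch must be applied before (or after) model construction consistently; instances never hold their own copy of forward.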

lightrft.models.monkey_patch.llama.llama_attn_forward(self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: torch.Tensor | None, past_key_value: transformers.cache_utils.Cache | None = None, cache_position: torch.LongTensor | None = None, **kwargs) → Tuple[torch.Tensor, torch.Tensor | None]

Modified LLaMA attention forward pass with Ulysses sequence parallelism support.

This function implements the attention mechanism for LLaMA models with added support for sequence parallelism. It handles the projection of input states to query/key/value, applies rotary position embeddings, and performs the attention computation.
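The sequence of steps described above can be sketched in plain PyTorch, leaving out the sequence-parallel all-to-all: apply the (cos, sin) rotary embeddings to query and key, then perform scaled dot-product attention. The apply_rope helper below mirrors the rotate-half convention used by transformers' LLaMA code; the tensor shapes are toy values for illustration only.

```python
import math

import torch


def apply_rope(x, cos, sin):
    # Rotate-half RoPE, as in transformers' LLaMA implementation:
    # split the head dimension in two, rotate, and recombine.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotated * sin


# Toy shapes: batch=1, heads=2, seq_len=4, head_dim=8
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
cos = torch.randn(1, 1, 4, 8)  # broadcast over the head dimension
sin = torch.randn(1, 1, 4, 8)

q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)

# Scaled dot-product attention
scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
weights = scores.softmax(dim=-1)
out = weights @ v  # shape (1, 2, 4, 8)
```

In the Ulysses variant, an all-to-all would redistribute q/k/v so that each rank attends over the full sequence for a subset of heads before the result is scattered back.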

Parameters:
  • self (LlamaAttention) – The attention module instance

  • hidden_states (torch.Tensor) – Input tensor to compute attention on

  • position_embeddings (Tuple[torch.Tensor, torch.Tensor]) – Tuple of (cos, sin) tensors for rotary position embeddings

  • attention_mask (Optional[torch.Tensor]) – Optional mask to prevent attention to certain positions

  • past_key_value (Optional[Cache]) – Optional cached key and value tensors for incremental decoding

  • cache_position (Optional[torch.LongTensor]) – Optional tensor indicating positions in the cache

  • kwargs (dict) – Additional keyword arguments passed to the attention implementation

Returns:

Tuple containing:

  • attention output tensor

  • attention weights (optional)

Return type:

Tuple[torch.Tensor, Optional[torch.Tensor]]

Raises:

Warning – When using SDPA with output_attentions=True
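This warning follows the usual pattern in transformers' SDPA path: torch's scaled_dot_product_attention does not expose per-head attention weights, so requesting them forces a fallback. A minimal sketch of such a guard, where select_attn_impl is a hypothetical helper and not part of this module's API:

```python
import warnings


def select_attn_impl(attn_implementation: str, output_attentions: bool) -> str:
    # SDPA cannot return attention weights, so warn and fall back
    # to the eager implementation when they are requested.
    if attn_implementation == "sdpa" and output_attentions:
        warnings.warn(
            "SDPA does not support output_attentions=True; "
            "falling back to the eager attention implementation."
        )
        return "eager"
    return attn_implementation


print(select_attn_impl("sdpa", True))   # -> eager
print(select_attn_impl("sdpa", False))  # -> sdpa
```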

Note

Adapted from the transformers 4.49.0 implementation to add Ulysses sequence parallelism support; it targets transformers >= 4.48.0.

This function has been tested only on transformers versions between 4.48.0 and 4.49.0.
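The tested-version constraint above could be enforced with a small runtime guard. The sketch below is an assumption about how such a check might look; check_transformers_version, the simplistic version parsing, and the warning text are all hypothetical, not part of this module.

```python
import warnings


def _parse(version: str) -> tuple:
    # Simplistic parse: keep only the numeric major.minor.patch parts
    # (pre-release suffixes like ".dev0" are not handled here).
    return tuple(int(part) for part in version.split(".")[:3])


def check_transformers_version(installed: str) -> None:
    lo, hi = _parse("4.48.0"), _parse("4.49.0")
    if not (lo <= _parse(installed) <= hi):
        warnings.warn(
            f"llama monkey patch was tested only on transformers "
            f"4.48.0-4.49.0; found {installed}."
        )


check_transformers_version("4.49.0")  # in range: no warning
check_transformers_version("4.50.0")  # out of range: emits a warning
```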