lightrft.models.monkey_patch.llama
This module implements a modified version of the LLaMA attention mechanism with support for Ulysses sequence parallelism. It adapts the original Hugging Face transformers implementation so that attention works when each device holds only a shard of the sequence, which spreads the memory and compute cost of long sequences across devices.
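The core data movement behind Ulysses-style sequence parallelism can be simulated in a single process. This is an illustrative sketch, not LightRFT's code: the helpers make_sequence_shards and all_to_all_seq_to_head are hypothetical, and a real multi-GPU setup would perform the exchange with a collective such as torch.distributed.all_to_all. Before attention, each rank holds a contiguous slice of tokens for every head; one all-to-all regroups the data so each rank holds the full sequence for a subset of heads, which is the layout attention needs.

```python
from typing import List, Tuple

# Each element is tagged (token_index, head_index) so the layout is easy to inspect.
Shard = List[List[Tuple[int, int]]]  # indexed [local_token][local_head]


def make_sequence_shards(seq_len: int, num_heads: int, world_size: int) -> List[Shard]:
    """Before attention: rank r holds a contiguous slice of tokens and every head."""
    assert seq_len % world_size == 0, "sequence length must divide evenly across ranks"
    chunk = seq_len // world_size
    return [
        [[(t, h) for h in range(num_heads)] for t in range(r * chunk, (r + 1) * chunk)]
        for r in range(world_size)
    ]


def all_to_all_seq_to_head(shards: List[Shard], num_heads: int) -> List[Shard]:
    """Simulate the all-to-all that trades sequence sharding for head sharding.

    Every rank sends head group g of its local tokens to rank g; afterwards
    rank g holds the full sequence, but only for its own head group.
    """
    world_size = len(shards)
    assert num_heads % world_size == 0, "heads must divide evenly across ranks"
    heads_per_rank = num_heads // world_size
    out: List[Shard] = []
    for g in range(world_size):
        head_slice = slice(g * heads_per_rank, (g + 1) * heads_per_rank)
        # Concatenate contributions in source-rank order to rebuild the full sequence.
        out.append([token[head_slice] for src in shards for token in src])
    return out


shards = make_sequence_shards(seq_len=8, num_heads=4, world_size=2)
gathered = all_to_all_seq_to_head(shards, num_heads=4)
print(len(gathered[0]), len(gathered[0][0]))  # 8 2
print(gathered[1][5])  # [(5, 2), (5, 3)]
```

In the Ulysses scheme, a second all-to-all in the opposite direction restores the sequence-sharded layout after attention. The sketch asserts that the head count divides evenly across ranks, a constraint this kind of head-splitting relies on.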
- lightrft.models.monkey_patch.llama.llama_attn_forward(self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: torch.Tensor | None, past_key_value: transformers.cache_utils.Cache | None = None, cache_position: torch.LongTensor | None = None, **kwargs) → Tuple[torch.Tensor, torch.Tensor | None]
Modified LLaMA attention forward pass with Ulysses sequence parallelism support.
This function implements the attention mechanism for LLaMA models with added support for sequence parallelism. It projects the input hidden states to query, key, and value tensors, applies rotary position embeddings, and computes attention.
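The attention computation step can be illustrated for a single head. This is a minimal pure-Python sketch with stated simplifications (no batching, no grouped-query heads, no dropout); the function names are hypothetical, and the real function dispatches to tensor kernels rather than Python loops. It computes softmax(Q K^T / sqrt(d) + mask) V with an additive causal mask, where masked scores are set to -inf.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    # Subtract the max for numerical stability; exp(-inf) == 0.0 zeroes masked slots.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def causal_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Single-head softmax(Q K^T / sqrt(d) + causal_mask) V."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out: Matrix = []
    for i, qi in enumerate(q):
        scores = []
        for j, kj in enumerate(k):
            s = sum(a * b for a, b in zip(qi, kj)) * scale
            # Additive causal mask: token i may not attend to later tokens j > i.
            scores.append(s if j <= i else float("-inf"))
        weights = softmax(scores)
        # Each output row is a weighted average of the value rows.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out


q = k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out = causal_attention(q, k, v)
print(out[0])  # [1.0, 2.0]: the first token can only attend to itself
```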
- Parameters:
self (LlamaAttention) – The attention module instance
hidden_states (torch.Tensor) – Input tensor to compute attention on
position_embeddings (Tuple[torch.Tensor, torch.Tensor]) – Tuple of (cos, sin) tensors for rotary position embeddings
attention_mask (Optional[torch.Tensor]) – Optional mask to prevent attention to certain positions
past_key_value (Optional[Cache]) – Optional cached key and value tensors for incremental decoding
cache_position (Optional[torch.LongTensor]) – Optional tensor indicating positions in the cache
kwargs (dict) – Additional keyword arguments passed to the attention implementation
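The position_embeddings parameter carries the (cos, sin) pair used by rotary position embeddings (RoPE). The sketch below assumes the LLaMA convention from transformers: the head dimension is split into two halves, rotate_half maps [x1, x2] to [-x2, x1], the rotated vector is x * cos + rotate_half(x) * sin, and the frequency base is 10000. It works on plain Python lists for a single position; rope_cos_sin, apply_rope, and rope_score are illustrative names, not LightRFT's API.

```python
import math
from typing import List, Tuple

Vector = List[float]


def rope_cos_sin(position: int, dim: int, base: float = 10000.0) -> Tuple[Vector, Vector]:
    """cos/sin for one position, duplicated across both halves of the head dim."""
    inv_freq = [1.0 / (base ** (2 * i / dim)) for i in range(dim // 2)]
    angles = [position * f for f in inv_freq] * 2  # list repetition: [freqs, freqs]
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


def rotate_half(x: Vector) -> Vector:
    """Map [x1, x2] to [-x2, x1], where x1 and x2 are the two halves of x."""
    half = len(x) // 2
    return [-val for val in x[half:]] + x[:half]


def apply_rope(x: Vector, cos: Vector, sin: Vector) -> Vector:
    """Element-wise x * cos + rotate_half(x) * sin."""
    return [a * c + r * s for a, r, c, s in zip(x, rotate_half(x), cos, sin)]


q, k = [0.3, -1.2, 0.7, 2.0], [1.1, 0.4, -0.5, 0.9]


def rope_score(m: int, n: int) -> float:
    """Dot product of q rotated to position m and k rotated to position n."""
    qm = apply_rope(q, *rope_cos_sin(m, len(q)))
    kn = apply_rope(k, *rope_cos_sin(n, len(k)))
    return sum(a * b for a, b in zip(qm, kn))


# Same offset (n - m = 3), so the scores match up to floating-point error.
print(rope_score(2, 5), rope_score(10, 13))
```

The demo checks the property that makes RoPE useful for attention: the query/key score depends only on the relative offset n - m, not on the absolute positions.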
- Returns:
A tuple containing the attention output tensor and, optionally, the attention weights (None when they are not returned).
- Return type:
Tuple[torch.Tensor, Optional[torch.Tensor]]
- Raises:
Warning – Issued when SDPA attention is used with output_attentions=True, because SDPA cannot return attention weights.
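The SDPA warning is easiest to understand as a dispatch decision. This is a minimal sketch, assuming the patched function keeps the upstream transformers 4.48/4.49 behavior of logging a warning and falling back to eager attention when output_attentions=True is requested under SDPA. The kernels, the ATTENTION_FUNCTIONS registry, and select_attention are hypothetical stand-ins that return strings instead of tensors.

```python
import logging
from typing import Callable, Dict, Optional, Tuple

logger = logging.getLogger("attention_dispatch_sketch")

AttnResult = Tuple[str, Optional[str]]


def eager_attention(query: str) -> AttnResult:
    # Hypothetical stand-in: an explicit softmax(QK^T) path can return its weights.
    return f"eager({query})", "weights"


def sdpa_attention(query: str) -> AttnResult:
    # Hypothetical stand-in: SDPA returns only the output, so weights are None.
    return f"sdpa({query})", None


ATTENTION_FUNCTIONS: Dict[str, Callable[[str], AttnResult]] = {
    "eager": eager_attention,
    "sdpa": sdpa_attention,
}


def select_attention(attn_implementation: str, output_attentions: bool) -> Callable[[str], AttnResult]:
    """Pick a kernel, falling back to eager when SDPA is asked for attention weights."""
    if attn_implementation == "sdpa" and output_attentions:
        # A logged warning, not an exception: the forward pass continues.
        logger.warning("SDPA cannot return attention weights; falling back to eager attention.")
        return eager_attention
    return ATTENTION_FUNCTIONS[attn_implementation]


print(select_attention("sdpa", False)("q"))  # ('sdpa(q)', None)
print(select_attention("sdpa", True)("q"))  # ('eager(q)', 'weights')
```

This also explains why the second element of the return value is optional: only a kernel that materializes the attention matrix can hand it back.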
Note
Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.
This function has been tested only on transformers versions between 4.48.0 and 4.49.0.
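Given the narrow tested range, a caller applying the patch may want to check the installed version first. This is a minimal sketch, assuming the monkey patch is applied by assigning the new function to the class's forward attribute; DummyAttention, patched_forward, parse_version, and apply_patch are hypothetical stand-ins, and in practice the version string would come from transformers.__version__. The tested range is treated as inclusive, which is an assumption.

```python
import warnings
from itertools import takewhile
from typing import Callable, Tuple

# Tested range from the note, treated as inclusive (an assumption).
TESTED_MIN = (4, 48, 0)
TESTED_MAX = (4, 49, 0)


def parse_version(version: str) -> Tuple[int, int, int]:
    """Parse 'X.Y.Z' into integers. Simplification: suffixes such as
    '.dev0' or 'rc1' are dropped, so pre-releases map to the final release."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = "".join(takewhile(str.isdigit, piece))
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return (parts[0], parts[1], parts[2])


def apply_patch(cls: type, new_forward: Callable, transformers_version: str) -> bool:
    """Replace cls.forward; return whether the version is in the tested range."""
    tested = TESTED_MIN <= parse_version(transformers_version) <= TESTED_MAX
    if not tested:
        warnings.warn(
            f"transformers {transformers_version} is outside the tested range 4.48.0 to 4.49.0"
        )
    cls.forward = new_forward
    return tested


class DummyAttention:
    """Hypothetical stand-in for transformers' LlamaAttention."""

    def forward(self, hidden_states):
        return "original"


def patched_forward(self, hidden_states):
    """Hypothetical stand-in for llama_attn_forward."""
    return "patched"


print(apply_patch(DummyAttention, patched_forward, "4.48.3"))  # True
print(DummyAttention().forward(None))  # patched
```

Warning rather than refusing keeps newer versions usable while making it visible that they fall outside the tested range.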