
lightrft.models.utils

Utility functions for computing log probabilities from logits in PyTorch.

This module provides functions to efficiently calculate log probabilities for token predictions, with optimizations to handle different data types and reduce memory consumption. It also includes utilities for finding linear modules in neural networks and handling position IDs for packed sequences in transformer models.
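The linear-module discovery mentioned above usually comes down to walking `named_modules()` and keeping the `nn.Linear` layers. The helper below is a minimal sketch, assuming LoRA targets are given as short module names (such as `q_proj`) and that the output head should be excluded. `find_linear_module_names` and its `exclude` parameter are hypothetical and are not the API of `lightrft.models.utils`.

```python
from typing import Iterable, List

import torch.nn as nn


def find_linear_module_names(model: nn.Module, exclude: Iterable[str] = ("lm_head",)) -> List[str]:
    """Hypothetical helper: short names of every nn.Linear submodule.

    Short names (the last component of the dotted path, e.g. "q_proj") are the
    form typically passed as LoRA target modules.
    """
    excluded = set(exclude)
    names = set()
    for full_name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            short_name = full_name.split(".")[-1]
            if short_name not in excluded:
                names.add(short_name)
    return sorted(names)


class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(8, 8)
        self.v_proj = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)  # not linear, so never collected


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([TinyBlock(), TinyBlock()])
        self.lm_head = nn.Linear(8, 16)  # excluded by default


print(find_linear_module_names(TinyModel()))  # ['q_proj', 'v_proj']
```

Collecting short names rather than full dotted paths keeps the list independent of depth, so the same two names cover every transformer block.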

The module is particularly useful for:
  • Computing log probabilities from model logits with memory-efficient approaches

  • Finding LoRA-injectable linear modules in various model architectures

  • Handling position IDs in packed sequence scenarios for transformer models
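When several sequences are packed into one row, the position IDs need to restart at 0 at each sequence boundary so every sequence sees the positions it would have had on its own. The snippet below is a minimal sketch of that idea, assuming the packed sequence lengths are known. `packed_position_ids` is a hypothetical helper, not a documented LightRFT function.

```python
from typing import List

import torch


def packed_position_ids(seq_lens: List[int]) -> torch.Tensor:
    """Hypothetical helper: position IDs for sequences packed into a single row.

    Each packed sequence gets positions 0..len-1, and the result has shape
    (1, sum(seq_lens)) to match a batch of one packed row.
    """
    return torch.cat([torch.arange(n, dtype=torch.long) for n in seq_lens]).unsqueeze(0)


print(packed_position_ids([3, 2, 4]))
# tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
```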

log_probs_from_logits

lightrft.models.utils.log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor, disable_logprobs_flashattn: bool = False) → torch.Tensor

Compute log probabilities for the given labels from logits.

The computation strategy depends on the dtype of logits and is chosen to limit memory usage. For float32/float64 tensors, the function computes the log probabilities directly, while for lower-precision types such as float16 and bfloat16 it applies PyTorch’s log_softmax one row at a time to reduce peak memory consumption.
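The dtype-dependent strategy can be illustrated with plain PyTorch. This is a sketch under stated assumptions, not the LightRFT source: it assumes the "direct" path gathers each label's logit and subtracts the logsumexp over the vocabulary, and it does not model the flash-attention option. `log_probs_sketch` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def log_probs_sketch(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Illustrative sketch of a dtype-dependent log-prob computation.

    logits: (B, T, V) or (B, V); labels: (B, T) or (B,). Returns labels' shape.
    """
    if logits.dtype in (torch.float32, torch.float64):
        # Assumed direct path: log p(y) = logit_y - logsumexp(logits).
        label_logits = torch.gather(logits, -1, labels.unsqueeze(-1)).squeeze(-1)
        return label_logits - torch.logsumexp(logits, dim=-1)
    # Low-precision path: log_softmax over one first-dim slice at a time, so only
    # that slice's full-vocabulary log-softmax is materialized at once.
    rows = []
    for row_logits, row_labels in zip(logits, labels):
        row_log_probs = F.log_softmax(row_logits, dim=-1)
        rows.append(torch.gather(row_log_probs, -1, row_labels.unsqueeze(-1)).squeeze(-1))
    return torch.stack(rows)


torch.manual_seed(0)
logits = torch.randn(2, 3, 5)
labels = torch.randint(0, 5, (2, 3))
print(log_probs_sketch(logits, labels).shape)  # torch.Size([2, 3])
```

Both branches compute the same quantity, log_softmax(logits) evaluated at each label. The split only trades one large temporary for a loop over smaller ones.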

Parameters:
  • logits (torch.Tensor) – Logits tensor of shape (batch_size, sequence_length, vocab_size) or (batch_size, vocab_size)

  • labels (torch.Tensor) – Labels tensor containing token indices, of shape (batch_size, sequence_length) or (batch_size,)

  • disable_logprobs_flashattn (bool) – If True, do not use the flash-attention cross-entropy kernel when computing the log probabilities. Defaults to False.

Returns:

Log probabilities for the given labels, with the same shape as labels

Return type:

torch.Tensor

Example::
>>> import torch
>>> from lightrft.models.utils import log_probs_from_logits
>>> logits = torch.randn(2, 3, 5)  # batch_size=2, seq_len=3, vocab_size=5
>>> labels = torch.randint(0, 5, (2, 3))  # batch_size=2, seq_len=3
>>> log_probs = log_probs_from_logits(logits, labels)
>>> log_probs.shape
torch.Size([2, 3])
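A common way to use this function is scoring the tokens of a causal language model, where the logits at position t predict the token at position t + 1. The sketch below assumes a Hugging Face-style model whose logits align position-by-position with input_ids (random tensors stand in for real model output). With LightRFT installed, the final step would be `log_probs_from_logits(shifted_logits, shifted_labels)`. Here the same quantity is written with plain PyTorch so the snippet runs on its own.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, vocab_size = 2, 6, 11
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
logits = torch.randn(batch_size, seq_len, vocab_size)  # stand-in for model(input_ids).logits

# Logits at position t score the token at t + 1: drop the last logit
# and the first token so the two tensors line up.
shifted_logits = logits[:, :-1, :]
shifted_labels = input_ids[:, 1:]

# Plain-PyTorch equivalent of log_probs_from_logits(shifted_logits, shifted_labels).
token_log_probs = (
    F.log_softmax(shifted_logits.float(), dim=-1)
    .gather(-1, shifted_labels.unsqueeze(-1))
    .squeeze(-1)
)
print(token_log_probs.shape)  # torch.Size([2, 5])

# Summing over positions gives a per-sequence log-likelihood.
seq_log_probs = token_log_probs.sum(dim=-1)
```

In practice, padding or prompt positions are usually masked out before the sum so that only response tokens contribute.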