lightrft.strategy.fsdp.fsdp_utils

Gradient scaling and optimization utilities for deep learning.

This module provides tools for gradient handling, norm computation, and optimization in PyTorch. It includes dynamic gradient scaling for mixed precision training, gradient norm computation, and base classes for optimizers.

is_meta_initialized

lightrft.strategy.fsdp.fsdp_utils.is_meta_initialized(model) bool[source]

Check if a PyTorch model’s parameters are meta-initialized.

Meta-initialized models contain parameters on a ‘meta’ device, which are placeholders that don’t allocate actual memory. These are useful for model initialization without memory overhead, commonly used in model parallelism and large model initialization.

For more information on meta device and meta tensors, see: https://docs.pytorch.org/docs/stable/meta.html

Parameters:

model (torch.nn.Module) – The PyTorch module to check.

Raises:

TypeError – if model is not an instance of torch.nn.Module.

Returns:

True if any parameter in the model is on a meta device, False otherwise.

Return type:

bool

Example:

>>> import torch
>>> model = torch.nn.Linear(10, 1)
>>> is_meta_initialized(model)  # False for regular model
False
>>> with torch.device('meta'):
...     meta_model = torch.nn.Linear(10, 1)
>>> is_meta_initialized(meta_model)  # True for meta model
True
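The check described above can be sketched in a few lines; `is_meta_initialized_sketch` is a hypothetical stand-in for the real function, which may differ in detail:

```python
import torch

def is_meta_initialized_sketch(model):
    # A plausible implementation: a model counts as meta-initialized
    # if any of its parameters lives on the 'meta' device.
    if not isinstance(model, torch.nn.Module):
        raise TypeError("model must be a torch.nn.Module")
    return any(p.is_meta for p in model.parameters())
```

Each parameter's `is_meta` flag is True only for placeholder tensors that allocate no storage.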

multi_tensor_l2norm_torch

lightrft.strategy.fsdp.fsdp_utils.multi_tensor_l2norm_torch(tensor_list, per_tensor)[source]

Compute L2 norm of multiple tensors using PyTorch.

This function provides a pure PyTorch implementation for computing L2 norms when APEX is not available. It converts all tensors to float32 for computation and returns both overall and per-tensor norms.

Parameters:
  • tensor_list (list[torch.Tensor]) – List of tensors to compute norm for

  • per_tensor (bool) – Whether to return per-tensor norms

Returns:

Tuple of (overall L2 norm, per-tensor norms)

Return type:

tuple[torch.Tensor, torch.Tensor]

Example:

>>> tensors = [torch.randn(3, 3), torch.randn(2, 2)]
>>> overall_norm, per_tensor_norms = multi_tensor_l2norm_torch(tensors, True)
>>> print(overall_norm.shape)  # torch.Size([1])
>>> print(per_tensor_norms.shape)  # torch.Size([2])
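A minimal reference version of this computation, assuming the behavior described above (float32 conversion, overall norm plus per-tensor norms); `multi_tensor_l2norm_reference` is a hypothetical name:

```python
import torch

def multi_tensor_l2norm_reference(tensor_list, per_tensor):
    # Convert each tensor to float32 and take its individual L2 norm.
    norms = torch.stack([t.float().norm(2) for t in tensor_list])
    # The overall L2 norm equals the L2 norm of the per-tensor norms.
    total = norms.norm(2).unsqueeze(0)
    return total, (norms if per_tensor else None)
```

This mirrors the documented output shapes: a 1-element overall norm and one entry per input tensor.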

calc_l2_norm

lightrft.strategy.fsdp.fsdp_utils.calc_l2_norm(grads)[source]

Calculate L2 norm of gradients using optimized implementation when available.

This function automatically selects the fastest available implementation for computing L2 norms. It uses APEX’s multi-tensor operations when available for better performance, and otherwise falls back to the pure PyTorch implementation.

Parameters:

grads (list[torch.Tensor]) – List of gradient tensors

Returns:

L2 norm of gradients

Return type:

torch.Tensor

Example:

>>> grads = [torch.randn(10, 10).requires_grad_(), torch.randn(5, 5).requires_grad_()]
>>> norm = calc_l2_norm(grads)
>>> print(norm.item())  # scalar value
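The dispatch logic might look like the following sketch; `calc_l2_norm_sketch` is a hypothetical name, and the exact APEX call is an assumption based on the standard `amp_C.multi_tensor_l2norm` fused kernel:

```python
import torch

def calc_l2_norm_sketch(grads):
    try:
        import amp_C  # APEX fused kernels (optional dependency)
        from apex.multi_tensor_apply import multi_tensor_applier
        # The fused kernel operates on CUDA tensors only.
        use_apex = grads[0].is_cuda
    except ImportError:
        use_apex = False
    if use_apex:
        overflow_buf = torch.zeros(1, dtype=torch.int, device=grads[0].device)
        norm, _ = multi_tensor_applier(amp_C.multi_tensor_l2norm,
                                       overflow_buf, [grads], False)
        return norm
    # Pure-PyTorch fallback: norm of per-tensor norms.
    return torch.norm(torch.stack([g.float().norm(2) for g in grads]), 2)
```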

calc_lp

lightrft.strategy.fsdp.fsdp_utils.calc_lp(grads, norm_type)[source]

Calculate Lp norm of gradients.

Computes the p-norm of a list of gradient tensors, where p is specified by norm_type. This is useful for gradient clipping and monitoring.

Parameters:
  • grads (list[torch.Tensor]) – List of gradient tensors

  • norm_type (float) – The p in Lp norm

Returns:

Lp norm of gradients

Return type:

float

Example:

>>> grads = [torch.randn(3, 3), torch.randn(2, 2)]
>>> l1_norm = calc_lp(grads, 1.0)  # L1 norm
>>> l2_norm = calc_lp(grads, 2.0)  # L2 norm
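A sketch of the p-norm computation over a list of tensors, under the assumption that all elements contribute to a single flattened norm; `calc_lp_sketch` is a hypothetical name:

```python
import torch

def calc_lp_sketch(grads, norm_type):
    # Sum |g|^p over every element of every tensor, then take the p-th root.
    total = sum(g.float().abs().pow(norm_type).sum() for g in grads)
    return total.pow(1.0 / norm_type)
```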

get_norm

lightrft.strategy.fsdp.fsdp_utils.get_norm(grads, norm_type, enable_cuda_kernels)[source]

Get norm of gradients with specified norm type.

This function dispatches to the appropriate norm calculation method based on the norm type and whether CUDA kernels are available. It handles special cases like infinity norm and optimized L2 norm computation.

Parameters:
  • grads (list[torch.Tensor]) – List of gradient tensors

  • norm_type (float) – Type of norm to compute (2.0, inf, etc.)

  • enable_cuda_kernels (bool) – Whether to use CUDA optimized kernels

Returns:

Norm of gradients

Return type:

float or torch.Tensor

Example:

>>> grads = [torch.randn(3, 3), torch.randn(2, 2)]
>>> l2_norm = get_norm(grads, 2.0, True)
>>> inf_norm = get_norm(grads, float('inf'), True)
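The dispatch described above can be sketched as follows, assuming the infinity norm is the maximum absolute gradient value; `get_norm_sketch` is a hypothetical name, and the CUDA-kernel branch is elided (it would route p=2 through the fused path when `enable_cuda_kernels` is True):

```python
import torch

def get_norm_sketch(grads, norm_type, enable_cuda_kernels=False):
    if norm_type == float('inf'):
        # Infinity norm: the largest absolute value across all gradients.
        return max(g.abs().max() for g in grads)
    # General Lp norm: p-norm of the per-tensor p-norms.
    return torch.stack([g.float().norm(norm_type) for g in grads]).norm(norm_type)
```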

reduce_grads

lightrft.strategy.fsdp.fsdp_utils.reduce_grads(gradients, parameters)[source]

Prepare gradients for norm computation in distributed training.

This function processes gradients to prepare them for norm computation, particularly in distributed training scenarios. It converts gradients to float32 for numerical stability during norm calculations.

Parameters:
  • gradients (list[torch.Tensor]) – List of gradient tensors

  • parameters (list[torch.Tensor]) – List of parameter tensors

Returns:

List of processed gradient tensors

Return type:

list[torch.Tensor]

Example:

>>> grads = [torch.randn(3, 3), torch.randn(2, 2)]
>>> params = [torch.randn(3, 3), torch.randn(2, 2)]
>>> processed_grads = reduce_grads(grads, params)
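The float32 conversion described above amounts to the following sketch; `reduce_grads_sketch` is a hypothetical stand-in, and any sharding-specific handling the real function performs is omitted:

```python
import torch

def reduce_grads_sketch(gradients, parameters):
    # Cast every gradient to float32 so subsequent norm computations
    # are numerically stable, regardless of the training dtype.
    return [g.float() for g in gradients]
```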

get_tensor_norm

lightrft.strategy.fsdp.fsdp_utils.get_tensor_norm(norm: float | torch.Tensor, move_to_cuda) torch.Tensor[source]

Convert norm to tensor and move to appropriate device.

This utility function ensures that norm values are properly converted to tensors and placed on the correct device for further computation. It handles both scalar and tensor inputs.

Parameters:
  • norm (Union[float, torch.Tensor]) – Norm value as float or tensor

  • move_to_cuda (bool) – Whether to move the tensor to CUDA

Returns:

Norm as tensor on appropriate device

Return type:

torch.Tensor

Example:

>>> norm_float = 2.5
>>> norm_tensor = get_tensor_norm(norm_float, True)
>>> print(norm_tensor.device)  # cuda:0 (if CUDA available)
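A plausible implementation of this conversion, with the CUDA move guarded by availability (an assumption; `get_tensor_norm_sketch` is a hypothetical name):

```python
import torch

def get_tensor_norm_sketch(norm, move_to_cuda):
    # Wrap scalar inputs in a 1-element float32 tensor.
    if not isinstance(norm, torch.Tensor):
        norm = torch.tensor([norm], dtype=torch.float32)
    # Move to GPU when requested and available.
    if move_to_cuda and torch.cuda.is_available():
        norm = norm.cuda()
    return norm
```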

compute_norm

lightrft.strategy.fsdp.fsdp_utils.compute_norm(gradients, parameters, norm_type=2)[source]

Get the norm across the distributed environment.

This function computes gradient norms in a distributed training setting, handling device placement and distributed reduction. It’s commonly used for gradient clipping and monitoring training stability.

Parameters:
  • gradients (list[torch.Tensor]) – The gradient values

  • parameters (list[torch.Tensor]) – The parameters each gradient corresponds to

  • norm_type (float or int) – Type of the used p-norm. Can be 'inf' for infinity norm

Returns:

Total norm of the parameters. Note that the returned value is the sum of p-th powers; apply total_norm ** (1 / norm_type) before use.

Return type:

float

Example:

>>> grads = [param.grad for param in model.parameters() if param.grad is not None]
>>> params = [param for param in model.parameters() if param.grad is not None]
>>> total_norm = compute_norm(grads, params)
>>> print(f"Gradient norm: {total_norm}")
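The distributed reduction can be sketched as below; `compute_norm_sketch` is a hypothetical stand-in, and the all-reduce is an assumption about how shard-local sums would be combined under FSDP (each rank holds only a shard of the gradients):

```python
import torch
import torch.distributed as dist

def compute_norm_sketch(gradients, norm_type=2.0):
    # Shard-local sum of |g|^p over all gradient elements.
    total = sum(g.float().abs().pow(norm_type).sum() for g in gradients)
    # Combine partial sums across ranks when running distributed.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
    # Caller applies total ** (1 / norm_type) to obtain the actual norm.
    return total
```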

BaseGradScaler

class lightrft.strategy.fsdp.fsdp_utils.BaseGradScaler(initial_scale: float)[source]

A base class for the gradient scaler.

This abstract base class defines the interface for gradient scalers used in mixed precision training. Gradient scalers help prevent gradient underflow in float16 training by scaling up the loss before backpropagation.

Parameters:

initial_scale (float) – The initial loss scale

Example:

>>> # Subclass implementation
>>> class MyGradScaler(BaseGradScaler):
...     def update(self, overflow: bool) -> None:
...         # Custom update logic
...         pass
property inv_scale: torch.Tensor

Returns the inverse of the loss scale.

The inverse scale is used to unscale gradients after backpropagation to restore their original magnitudes.

Returns:

Inverse of current loss scale

Return type:

torch.Tensor

Example:

>>> scaler = DynamicGradScaler(initial_scale=1024.0)
>>> print(scaler.inv_scale.item())  # 0.0009765625 (1/1024)
load_state_dict(state_dict: Dict) None[source]

Load the states of the gradient scaler from a dict object.

Parameters:

state_dict (Dict) – The states of the gradient scaler

Example:

>>> scaler = DynamicGradScaler()
>>> state = {"scale": torch.tensor([2048.0])}
>>> scaler.load_state_dict(state)
property scale: torch.Tensor

Returns the loss scale.

Returns:

Current loss scale

Return type:

torch.Tensor

Example:

>>> scaler = DynamicGradScaler(initial_scale=1024.0)
>>> print(scaler.scale.item())  # 1024.0
state_dict() Dict[source]

Returns the states of the gradient scaler as a dict object.

Returns:

State dictionary containing scale

Return type:

Dict

Example:

>>> scaler = DynamicGradScaler()
>>> state = scaler.state_dict()
>>> print(state.keys())
dict_keys(['scale'])
abstract update(overflow: bool) None[source]

Update the loss scale.

This abstract method must be implemented by subclasses to define how the loss scale should be updated based on overflow detection.

Parameters:

overflow (bool) – Whether overflow occurred

DynamicGradScaler

class lightrft.strategy.fsdp.fsdp_utils.DynamicGradScaler(initial_scale: float = 65536, growth_factor: float = 2, backoff_factor: float = 0.5, growth_interval: int = 1000, min_scale: float | None = 1, max_scale: float | None = 16777216, hysteresis: int = 2, dtype: torch.dtype = torch.bfloat16)[source]

A gradient scaler which uses dynamic loss scale.

This scaler automatically adjusts the loss scale based on gradient overflow detection. It increases the scale when training is stable and decreases it when overflow occurs, providing automatic mixed precision training support.

Parameters:
  • initial_scale (float) – The initial loss scale

  • growth_factor (float) – The multiplication factor for increasing loss scale

  • backoff_factor (float) – The multiplication factor for decreasing loss scale

  • growth_interval (int) – The number of steps to increase loss scale when no overflow occurs

  • min_scale (Optional[float]) – The minimum loss scale

  • max_scale (Optional[float]) – The maximum loss scale

  • hysteresis (int) – The number of overflows before decreasing loss scale

  • dtype (torch.dtype) – The data type used for training

Example:

>>> scaler = DynamicGradScaler(initial_scale=2**16, growth_factor=2.0)
>>> # In training loop
>>> for epoch in range(num_epochs):
...     for batch in dataloader:
...         # Forward pass with scaled loss
...         scaled_loss = loss * scaler.scale
...         scaled_loss.backward()
...         # Check for overflow and update scaler
...         overflow = check_overflow(model.parameters())
...         scaler.update(overflow)
load_state_dict(state_dict)[source]

Load the states of the gradient scaler from a dict object.

This method restores the scaler state from a previously saved state dictionary, enabling seamless checkpoint restoration.

Parameters:

state_dict (dict) – The states of the gradient scaler

Example:

>>> scaler = DynamicGradScaler()
>>> scaler.load_state_dict({
...     "_scale": 2048.0,
...     "_growth_step": 0,
...     "_hysteresis_step": 0
... })
state_dict()[source]

Returns the states of the gradient scaler as a dict object.

This method provides a complete state dictionary that can be saved and restored to maintain training consistency across checkpoints.

Returns:

A dictionary containing the current state of the gradient scaler

Return type:

dict

Example:

>>> scaler = DynamicGradScaler()
>>> scaler_state = scaler.state_dict()
>>> print(scaler_state.keys())
dict_keys(['_scale', '_growth_step', '_hysteresis_step'])
update(overflow: bool) None[source]

Update the loss scale based on whether overflow occurred.

This method implements the dynamic scaling algorithm. When overflow occurs, it increments the hysteresis counter and resets growth progress. When no overflow occurs for a sufficient period, it increases the scale to maximize gradient precision.

Parameters:

overflow (bool) – Whether overflow occurred

Example:

>>> scaler = DynamicGradScaler()
>>> # Simulate training steps
>>> scaler.update(False)  # No overflow, increment growth counter
>>> scaler.update(True)   # Overflow detected, may decrease scale
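The growth/backoff algorithm described above can be sketched as a small class; `DynamicScaleSketch` is illustrative only (the real class also clamps the scale between min_scale and max_scale, omitted here for brevity):

```python
class DynamicScaleSketch:
    """Illustrative dynamic loss-scale logic, not the actual class."""

    def __init__(self, initial_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=1000, hysteresis=2):
        self.scale = float(initial_scale)
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.hysteresis = hysteresis
        self._growth_step = 0
        self._hysteresis_step = 0

    def update(self, overflow):
        if overflow:
            self._growth_step = 0           # reset growth progress
            self._hysteresis_step += 1
            if self._hysteresis_step >= self.hysteresis:
                self.scale *= self.backoff_factor   # back off the scale
                self._hysteresis_step = 0
        else:
            self._growth_step += 1
            if self._growth_step >= self.growth_interval:
                self.scale *= self.growth_factor    # grow after a stable run
                self._growth_step = 0
                self._hysteresis_step = 0
```

The hysteresis counter means a single isolated overflow does not immediately shrink the scale; only repeated overflows trigger backoff.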

BaseOptimizer

class lightrft.strategy.fsdp.fsdp_utils.BaseOptimizer(*args: Any, **kwargs: Any)[source]

Base Optimizer class that wraps a PyTorch optimizer.

This class provides a wrapper around PyTorch optimizers, exposing the same interface while allowing for additional functionality like custom backward passes and gradient clipping. It serves as a foundation for building more sophisticated optimizers with enhanced features for distributed training, gradient scaling, and custom optimization strategies.

Parameters:

optim (torch.optim.Optimizer) – The PyTorch optimizer to wrap

Example:

>>> import torch.optim as optim
>>> model = torch.nn.Linear(10, 1)
>>> pytorch_optimizer = optim.Adam(model.parameters(), lr=0.001)
>>> wrapped_optimizer = BaseOptimizer(pytorch_optimizer)
>>> # Use wrapped_optimizer like a regular optimizer
>>> wrapped_optimizer.zero_grad()
>>> loss.backward()
>>> wrapped_optimizer.step()
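The delegation pattern behind this wrapper can be sketched as follows; `BaseOptimizerSketch` is a hypothetical minimal version showing only the forwarding behavior described above:

```python
import torch

class BaseOptimizerSketch:
    """Minimal wrapper that forwards to an inner PyTorch optimizer."""

    def __init__(self, optim):
        self.optim = optim

    @property
    def param_groups(self):
        return self.optim.param_groups

    def zero_grad(self, *args, **kwargs):
        return self.optim.zero_grad(*args, **kwargs)

    def backward(self, loss):
        # Plain backprop; subclasses could scale the loss first.
        loss.backward()

    def step(self, *args, **kwargs):
        return self.optim.step(*args, **kwargs)
```

Because every method forwards to the wrapped optimizer, subclasses can override individual steps (e.g. `backward` for loss scaling, `step` for gradient clipping) without reimplementing the rest.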
add_param_group(*args, **kwargs)[source]

Add a parameter group to the optimizer.

This method allows adding new parameter groups with potentially different optimization settings during training.

Parameters:
  • args – Positional arguments to pass to the wrapped optimizer

  • kwargs – Keyword arguments to pass to the wrapped optimizer

Returns:

Result from the wrapped optimizer’s add_param_group method

Example:

>>> optimizer = BaseOptimizer(torch.optim.Adam(model.parameters()))
>>> new_params = [torch.randn(10, requires_grad=True)]
>>> optimizer.add_param_group({'params': new_params, 'lr': 0.01})
backward(loss)[source]

Compute gradients of the loss.

This method performs backpropagation to compute gradients of the loss with respect to the model parameters.

Parameters:

loss (torch.Tensor) – The loss tensor to compute gradients for

Example:

>>> optimizer = BaseOptimizer(torch.optim.Adam(model.parameters()))
>>> loss = criterion(output, target)
>>> optimizer.backward(loss)
backward_by_grad(tensor, grad)[source]

Compute gradients of the tensor with respect to the provided gradients.

This method allows for custom gradient computation, useful in scenarios like gradient accumulation or custom loss functions.

Parameters:
  • tensor (torch.Tensor) – The tensor to compute gradients for

  • grad (torch.Tensor) – The gradients to backpropagate

Example:

>>> optimizer = BaseOptimizer(torch.optim.Adam(model.parameters()))
>>> output = model(input)
>>> custom_grad = torch.randn_like(output)
>>> optimizer.backward_by_grad(output, custom_grad)
clip_grad_norm()[source]

Clip the gradient norm.

This is a placeholder method that should be implemented by subclasses to provide gradient clipping functionality. Gradient clipping helps prevent gradient explosion in deep networks.

Example:

>>> class MyOptimizer(BaseOptimizer):
...     def clip_grad_norm(self):
...         torch.nn.utils.clip_grad_norm_(self.param_groups[0]['params'], max_norm=1.0)
property defaults

Access to the default parameters of the wrapped optimizer.

Returns:

Default parameters dictionary

Return type:

dict

Example:

>>> optimizer = BaseOptimizer(torch.optim.Adam(model.parameters(), lr=0.001))
>>> print(optimizer.defaults['lr'])  # 0.001
load_state_dict(*args, **kwargs)[source]

Load the optimizer state.

This method restores the optimizer’s internal state from a state dictionary, enabling checkpoint restoration and training resumption.

Parameters:
  • args – Positional arguments to pass to the wrapped optimizer

  • kwargs – Keyword arguments to pass to the wrapped optimizer

Example:

>>> optimizer = BaseOptimizer(torch.optim.Adam(model.parameters()))
>>> state_dict = torch.load('optimizer_checkpoint.pth')
>>> optimizer.load_state_dict(state_dict)
property param_groups

Access to the parameter groups of the wrapped optimizer.

Parameter groups allow different sets of parameters to have different optimization settings like learning rates, weight decay, etc.

Returns:

List of parameter groups

Return type:

list

Example:

>>> optimizer = BaseOptimizer(torch.optim.Adam(model.parameters()))
>>> print(len(optimizer.param_groups))  # Number of parameter groups
>>> print(optimizer.param_groups[0]['lr'])  # Learning rate of first group
state_dict()[source]

Return the state of the optimizer as a dict.

This method provides the optimizer’s complete state for checkpointing, including parameter states and hyperparameters.

Returns:

The state of the optimizer

Return type:

dict

Example:

>>> optimizer = BaseOptimizer(torch.optim.Adam(model.parameters()))
>>> state_dict = optimizer.state_dict()
>>> torch.save(state_dict, 'optimizer_checkpoint.pth')
step(*args, **kwargs)[source]

Perform a single optimization step.

This method executes one optimization step, updating the model parameters based on their gradients and the optimizer’s algorithm.

Parameters:
  • args – Positional arguments to pass to the wrapped optimizer

  • kwargs – Keyword arguments to pass to the wrapped optimizer

Returns:

Result from the wrapped optimizer’s step method

Example:

>>> optimizer = BaseOptimizer(torch.optim.Adam(model.parameters()))
>>> loss.backward()
>>> optimizer.step()
zero_grad(*args, **kwargs)[source]

Reset the gradients of all optimized tensors.

This method clears the gradients of all parameters, which is typically done before each backward pass to prevent gradient accumulation.

Parameters:
  • args – Positional arguments to pass to the wrapped optimizer

  • kwargs – Keyword arguments to pass to the wrapped optimizer

Example:

>>> optimizer = BaseOptimizer(torch.optim.Adam(model.parameters()))
>>> optimizer.zero_grad()  # Clear gradients
>>> loss.backward()        # Compute new gradients
>>> optimizer.step()       # Update parameters