lightrft.strategy.fsdp.fsdp_utils¶
Gradient scaling and optimization utilities for deep learning.
This module provides tools for gradient handling, norm computation, and optimization in PyTorch. It includes dynamic gradient scaling for mixed precision training, gradient norm computation, and base classes for optimizers.
is_meta_initialized¶
- lightrft.strategy.fsdp.fsdp_utils.is_meta_initialized(model) bool[source]¶
Check if a PyTorch model’s parameters are meta-initialized.
Meta-initialized models contain parameters on a ‘meta’ device, which are placeholders that don’t allocate actual memory. These are useful for model initialization without memory overhead, commonly used in model parallelism and large model initialization.
For more information on meta device and meta tensors, see: https://docs.pytorch.org/docs/stable/meta.html
- Parameters:
model (torch.nn.Module) – The PyTorch module to check.
- Raises:
TypeError – if model is not an instance of torch.nn.Module.
- Returns:
True if any parameter in the model is on a meta device, False otherwise.
- Return type:
bool
Example:
>>> import torch
>>> model = torch.nn.Linear(10, 1)
>>> is_meta_initialized(model)  # False for regular model
False
>>> with torch.device('meta'):
...     meta_model = torch.nn.Linear(10, 1)
>>> is_meta_initialized(meta_model)  # True for meta model
True
multi_tensor_l2norm_torch¶
- lightrft.strategy.fsdp.fsdp_utils.multi_tensor_l2norm_torch(tensor_list, per_tensor)[source]¶
Compute L2 norm of multiple tensors using PyTorch.
This function provides a pure PyTorch implementation for computing L2 norms when APEX is not available. It converts all tensors to float32 for computation and returns both overall and per-tensor norms.
- Parameters:
tensor_list (list[torch.Tensor]) – List of tensors to compute norm for
per_tensor (bool) – Whether to return per-tensor norms
- Returns:
Tuple of (overall L2 norm, per-tensor norms)
- Return type:
tuple[torch.Tensor, torch.Tensor]
Example:
>>> tensors = [torch.randn(3, 3), torch.randn(2, 2)]
>>> overall_norm, per_tensor_norms = multi_tensor_l2norm_torch(tensors, True)
>>> print(overall_norm.shape)  # torch.Size([1])
>>> print(per_tensor_norms.shape)  # torch.Size([2])
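The fallback's math is straightforward: compute each tensor's L2 norm in float32, then combine them as the root of the sum of squares. A minimal pure-Python sketch of that combination (the helper name is illustrative, not the lightrft API):

```python
import math

def l2norm_multi(tensor_list):
    """Combine per-tensor L2 norms into one overall L2 norm.

    Each entry of tensor_list is a flat list of floats standing in for a
    tensor; the overall norm is sqrt(sum of squared per-tensor norms),
    which equals the L2 norm over all elements taken together.
    """
    per_tensor = [math.sqrt(sum(x * x for x in t)) for t in tensor_list]
    overall = math.sqrt(sum(n * n for n in per_tensor))
    return overall, per_tensor

overall, per = l2norm_multi([[3.0, 4.0], [12.0]])
# per == [5.0, 12.0]; overall == 13.0 (the 5-12-13 triangle)
```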
calc_l2_norm¶
- lightrft.strategy.fsdp.fsdp_utils.calc_l2_norm(grads)[source]¶
Calculate L2 norm of gradients using optimized implementation when available.
This function automatically selects the fastest available implementation for computing L2 norms. It uses APEX’s multi-tensor operations when available for better performance, otherwise falls back to PyTorch implementation.
- Parameters:
grads (list[torch.Tensor]) – List of gradient tensors
- Returns:
L2 norm of gradients
- Return type:
torch.Tensor
Example:
>>> grads = [torch.randn(10, 10).requires_grad_(), torch.randn(5, 5).requires_grad_()]
>>> norm = calc_l2_norm(grads)
>>> print(norm.item())  # scalar value
calc_lp¶
- lightrft.strategy.fsdp.fsdp_utils.calc_lp(grads, norm_type)[source]¶
Calculate Lp norm of gradients.
Computes the p-norm of a list of gradient tensors, where p is specified by norm_type. This is useful for gradient clipping and monitoring.
- Parameters:
grads (list[torch.Tensor]) – List of gradient tensors
norm_type (float) – The p in Lp norm
- Returns:
Lp norm of gradients
- Return type:
float
Example:
>>> grads = [torch.randn(3, 3), torch.randn(2, 2)]
>>> l1_norm = calc_lp(grads, 1.0)  # L1 norm
>>> l2_norm = calc_lp(grads, 2.0)  # L2 norm
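The Lp norm underlying this function is (sum of |x|^p over all elements)^(1/p). A pure-Python sketch of that formula, with small values chosen so the results are exact (illustrative helper, not the lightrft implementation):

```python
def lp_norm(tensor_list, p):
    """Lp norm over all elements of all tensors: (sum |x|**p) ** (1/p)."""
    total = sum(abs(x) ** p for t in tensor_list for x in t)
    return total ** (1.0 / p)

grads = [[1.0, -2.0], [2.0]]
lp_norm(grads, 1.0)  # 5.0  (L1: |1| + |-2| + |2|)
lp_norm(grads, 2.0)  # 3.0  (L2: sqrt(1 + 4 + 4))
```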
get_norm¶
- lightrft.strategy.fsdp.fsdp_utils.get_norm(grads, norm_type, enable_cuda_kernels)[source]¶
Get norm of gradients with specified norm type.
This function dispatches to the appropriate norm calculation method based on the norm type and whether CUDA kernels are available. It handles special cases like infinity norm and optimized L2 norm computation.
- Parameters:
grads (list[torch.Tensor]) – List of gradient tensors
norm_type (float) – Type of norm to compute (2.0, inf, etc.)
enable_cuda_kernels (bool) – Whether to use CUDA optimized kernels
- Returns:
Norm of gradients
- Return type:
float or torch.Tensor
Example:
>>> grads = [torch.randn(3, 3), torch.randn(2, 2)]
>>> l2_norm = get_norm(grads, 2.0, True)
>>> inf_norm = get_norm(grads, float('inf'), True)
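The dispatch described above can be sketched in pure Python: the infinity norm is the maximum absolute element, while finite p falls through to the Lp formula (a hedged sketch with an illustrative name; the real function also routes L2 to the optimized kernel path when available):

```python
def get_norm_sketch(grads, norm_type):
    """Dispatch between infinity norm and the generic Lp norm."""
    flat = [abs(x) for t in grads for x in t]  # all elements, absolute value
    if norm_type == float("inf"):
        return max(flat)                        # infinity norm: max |x|
    return sum(x ** norm_type for x in flat) ** (1.0 / norm_type)

get_norm_sketch([[1.0, -4.0], [2.0]], float("inf"))  # 4.0
get_norm_sketch([[3.0], [4.0]], 2.0)                 # 5.0
```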
reduce_grads¶
- lightrft.strategy.fsdp.fsdp_utils.reduce_grads(gradients, parameters)[source]¶
Prepare gradients for norm computation in distributed training.
This function processes gradients to prepare them for norm computation, particularly in distributed training scenarios. It converts gradients to float32 for numerical stability during norm calculations.
- Parameters:
gradients (list[torch.Tensor]) – List of gradient tensors
parameters (list[torch.Tensor]) – List of parameter tensors
- Returns:
List of processed gradient tensors
- Return type:
list[torch.Tensor]
Example:
>>> grads = [torch.randn(3, 3), torch.randn(2, 2)]
>>> params = [torch.randn(3, 3), torch.randn(2, 2)]
>>> processed_grads = reduce_grads(grads, params)
get_tensor_norm¶
- lightrft.strategy.fsdp.fsdp_utils.get_tensor_norm(norm: float | torch.Tensor, move_to_cuda) torch.Tensor[source]¶
Convert norm to tensor and move to appropriate device.
This utility function ensures that norm values are properly converted to tensors and placed on the correct device for further computation. It handles both scalar and tensor inputs.
- Parameters:
norm (Union[float, torch.Tensor]) – Norm value as float or tensor
move_to_cuda (bool) – Whether to move the tensor to CUDA
- Returns:
Norm as tensor on appropriate device
- Return type:
torch.Tensor
Example:
>>> norm_float = 2.5
>>> norm_tensor = get_tensor_norm(norm_float, True)
>>> print(norm_tensor.device)  # cuda:0 (if CUDA available)
compute_norm¶
- lightrft.strategy.fsdp.fsdp_utils.compute_norm(gradients, parameters, norm_type=2)[source]¶
Get the norm across the distributed environment.
This function computes gradient norms in a distributed training setting, handling device placement and distributed reduction. It’s commonly used for gradient clipping and monitoring training stability.
- Parameters:
gradients (list[torch.Tensor]) – The gradient values
parameters (list[torch.Tensor]) – The parameters each gradient corresponds to
norm_type (float or int) – Type of the p-norm to use. Can be 'inf' for infinity norm.
- Returns:
Total norm of the parameters, before the final root; apply total_norm ** (1 / norm_type) before use.
- Return type:
float
Example:
>>> grads = [param.grad for param in model.parameters() if param.grad is not None]
>>> params = [param for param in model.parameters() if param.grad is not None]
>>> total_norm = compute_norm(grads, params)
>>> print(f"Gradient norm: {total_norm}")
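The reason the pth root is deferred becomes clear in the distributed setting: each rank can only sum |g|^p over its own shard, those partial sums are combined by an all-reduce, and only the final global sum may be rooted. A pure-Python sketch of that reduction order (hypothetical helper, with two lists standing in for two ranks' shards; not the lightrft API):

```python
def local_pth_power(grads, p):
    """Per-rank contribution to the global norm: sum(|g|**p) over the shard.

    Summing these across ranks (an all-reduce in real FSDP) yields the
    global pth power; the caller applies ** (1/p) once at the end, which
    is why compute_norm returns the norm before the 1/p root.
    """
    return sum(abs(x) ** p for t in grads for x in t)

shard_a = [[3.0]]  # gradients held by rank 0
shard_b = [[4.0]]  # gradients held by rank 1
total = local_pth_power(shard_a, 2.0) + local_pth_power(shard_b, 2.0)
total ** (1.0 / 2.0)  # 5.0, the global L2 norm over both shards
```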
BaseGradScaler¶
- class lightrft.strategy.fsdp.fsdp_utils.BaseGradScaler(initial_scale: float)[source]¶
A base class for the gradient scaler.
This abstract base class defines the interface for gradient scalers used in mixed precision training. Gradient scalers help prevent gradient underflow in float16 training by scaling up the loss before backpropagation.
- Parameters:
initial_scale (float) – The initial loss scale
Example:
>>> # Subclass implementation
>>> class MyGradScaler(BaseGradScaler):
...     def update(self, overflow: bool) -> None:
...         # Custom update logic
...         pass
- property inv_scale: torch.Tensor¶
Returns the inverse of the loss scale.
The inverse scale is used to unscale gradients after backpropagation to restore their original magnitudes.
- Returns:
Inverse of current loss scale
- Return type:
torch.Tensor
Example:
>>> scaler = DynamicGradScaler(initial_scale=1024.0)
>>> print(scaler.inv_scale.item())  # 0.0009765625 (1/1024)
- load_state_dict(state_dict: Dict) None[source]¶
Load the states of the gradient scaler from a dict object.
- Parameters:
state_dict (Dict) – The states of the gradient scaler
Example:
>>> scaler = DynamicGradScaler()
>>> state = {"scale": torch.tensor([2048.0])}
>>> scaler.load_state_dict(state)
- property scale: torch.Tensor¶
Returns the loss scale.
- Returns:
Current loss scale
- Return type:
torch.Tensor
Example:
>>> scaler = DynamicGradScaler(initial_scale=1024.0)
>>> print(scaler.scale.item())  # 1024.0
DynamicGradScaler¶
- class lightrft.strategy.fsdp.fsdp_utils.DynamicGradScaler(initial_scale: float = 65536, growth_factor: float = 2, backoff_factor: float = 0.5, growth_interval: int = 1000, min_scale: float | None = 1, max_scale: float | None = 16777216, hysteresis: int = 2, dtype: torch.dtype = torch.bfloat16)[source]¶
A gradient scaler which uses dynamic loss scale.
This scaler automatically adjusts the loss scale based on gradient overflow detection. It increases the scale when training is stable and decreases it when overflow occurs, providing automatic mixed precision training support.
- Parameters:
initial_scale (float) – The initial loss scale
growth_factor (float) – The multiplication factor for increasing loss scale
backoff_factor (float) – The multiplication factor for decreasing loss scale
growth_interval (int) – The number of steps to increase loss scale when no overflow occurs
min_scale (Optional[float]) – The minimum loss scale
max_scale (Optional[float]) – The maximum loss scale
hysteresis (int) – The number of overflows before decreasing loss scale
dtype (torch.dtype) – The data type used for training
Example:
>>> scaler = DynamicGradScaler(initial_scale=2**16, growth_factor=2.0)
>>> # In training loop
>>> for epoch in range(num_epochs):
...     for batch in dataloader:
...         # Forward pass with scaled loss
...         scaled_loss = loss * scaler.scale
...         scaled_loss.backward()
...         # Check for overflow and update scaler
...         overflow = check_overflow(model.parameters())
...         scaler.update(overflow)
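The update rule described above (grow after growth_interval clean steps, back off only after hysteresis consecutive overflows, clamp to [min_scale, max_scale]) can be sketched in plain Python. This is a hedged illustration of the algorithm, not the lightrft implementation; field names are chosen for readability:

```python
class TinyDynamicScaler:
    """Illustrative sketch of a dynamic loss-scale update rule."""

    def __init__(self, initial_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=1000,
                 min_scale=1.0, max_scale=2.0 ** 24, hysteresis=2):
        self.scale = initial_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.hysteresis = hysteresis
        self._growth_step = 0       # consecutive overflow-free steps
        self._hysteresis_step = 0   # consecutive overflow steps

    def update(self, overflow):
        if overflow:
            self._growth_step = 0   # reset progress toward growth
            self._hysteresis_step += 1
            if self._hysteresis_step >= self.hysteresis:
                # Back off, but never below min_scale
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
                self._hysteresis_step = 0
        else:
            self._hysteresis_step = 0
            self._growth_step += 1
            if self._growth_step % self.growth_interval == 0:
                # Grow, but never above max_scale
                self.scale = min(self.scale * self.growth_factor, self.max_scale)
```

With `growth_interval=2` and `hysteresis=2`, two clean steps double the scale and two consecutive overflows halve it back.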
- load_state_dict(state_dict)[source]¶
Load the states of the gradient scaler from a dict object.
This method restores the scaler state from a previously saved state dictionary, enabling seamless checkpoint restoration.
- Parameters:
state_dict (dict) – The states of the gradient scaler
Example:
>>> scaler = DynamicGradScaler()
>>> scaler.load_state_dict({
...     "_scale": 2048.0,
...     "_growth_step": 0,
...     "_hysteresis_step": 0
... })
- state_dict()[source]¶
Returns the states of the gradient scaler as a dict object.
This method provides a complete state dictionary that can be saved and restored to maintain training consistency across checkpoints.
- Returns:
A dictionary containing the current state of the gradient scaler
- Return type:
dict
Example:
>>> scaler = DynamicGradScaler()
>>> scaler_state = scaler.state_dict()
>>> print(scaler_state.keys())
dict_keys(['_scale', '_growth_step', '_hysteresis_step'])
- update(overflow: bool) None[source]¶
Update the loss scale based on whether overflow occurred.
This method implements the dynamic scaling algorithm. When overflow occurs, it increments the hysteresis counter and resets growth progress. When no overflow occurs for a sufficient period, it increases the scale to maximize gradient precision.
- Parameters:
overflow (bool) – Whether overflow occurs
Example:
>>> scaler = DynamicGradScaler()
>>> # Simulate training steps
>>> scaler.update(False)  # No overflow, increment growth counter
>>> scaler.update(True)   # Overflow detected, may decrease scale
BaseOptimizer¶
- class lightrft.strategy.fsdp.fsdp_utils.BaseOptimizer(*args: Any, **kwargs: Any)[source]¶
Base Optimizer class that wraps a PyTorch optimizer.
This class provides a wrapper around PyTorch optimizers, exposing the same interface while allowing for additional functionality like custom backward passes and gradient clipping. It serves as a foundation for building more sophisticated optimizers with enhanced features for distributed training, gradient scaling, and custom optimization strategies.
- Parameters:
optim (torch.optim.Optimizer) – The PyTorch optimizer to wrap
Example:
>>> import torch.optim as optim
>>> model = torch.nn.Linear(10, 1)
>>> pytorch_optimizer = optim.Adam(model.parameters(), lr=0.001)
>>> wrapped_optimizer = BaseOptimizer(pytorch_optimizer)
>>> # Use wrapped_optimizer like a regular optimizer
>>> wrapped_optimizer.zero_grad()
>>> loss.backward()
>>> wrapped_optimizer.step()
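The wrapper pattern itself is simple: hold the inner optimizer and forward the standard interface to it, leaving room for subclasses to intercept calls. A minimal pure-Python sketch of that delegation (illustrative class name and a stub in place of a real torch optimizer; the actual BaseOptimizer wraps torch.optim.Optimizer):

```python
class WrappedOptimizer:
    """Sketch of an optimizer wrapper that delegates to the inner optimizer."""

    def __init__(self, optim):
        self.optim = optim  # the wrapped optimizer instance

    @property
    def param_groups(self):
        # Expose the inner optimizer's parameter groups unchanged
        return self.optim.param_groups

    def step(self, *args, **kwargs):
        # Subclasses could clip or unscale gradients here before stepping
        return self.optim.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        return self.optim.zero_grad(*args, **kwargs)
```

Because every method forwards `*args, **kwargs`, the wrapper stays compatible with whatever signature the inner optimizer exposes, which is what lets subclasses add gradient scaling or clipping without re-declaring the whole interface.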
- add_param_group(*args, **kwargs)[source]¶
Add a parameter group to the optimizer.
This method allows adding new parameter groups with potentially different optimization settings during training.
- Parameters:
args – Positional arguments to pass to the wrapped optimizer
kwargs – Keyword arguments to pass to the wrapped optimizer
- Returns:
Result from the wrapped optimizer’s add_param_group method
Example:
>>> optimizer = BaseOptimizer(torch.optim.Adam(model.parameters()))
>>> new_params = [torch.randn(10, requires_grad=True)]
>>> optimizer.add_param_group({'params': new_params, 'lr': 0.01})
- backward(loss)[source]¶
Compute gradients of the loss.
This method performs backpropagation to compute gradients of the loss with respect to the model parameters.
- Parameters:
loss (torch.Tensor) – The loss tensor to compute gradients for
Example:
>>> optimizer = BaseOptimizer(torch.optim.Adam(model.parameters()))
>>> loss = criterion(output, target)
>>> optimizer.backward(loss)
- backward_by_grad(tensor, grad)[source]¶
Compute gradients of the tensor with respect to the provided gradients.
This method allows for custom gradient computation, useful in scenarios like gradient accumulation or custom loss functions.
- Parameters:
tensor (torch.Tensor) – The tensor to compute gradients for
grad (torch.Tensor) – The gradients to backpropagate
Example:
>>> optimizer = BaseOptimizer(torch.optim.Adam(model.parameters()))
>>> output = model(input)
>>> custom_grad = torch.randn_like(output)
>>> optimizer.backward_by_grad(output, custom_grad)
- clip_grad_norm()[source]¶
Clip the gradient norm.
This is a placeholder method that should be implemented by subclasses to provide gradient clipping functionality. Gradient clipping helps prevent gradient explosion in deep networks.
Example:
>>> class MyOptimizer(BaseOptimizer):
...     def clip_grad_norm(self):
...         torch.nn.utils.clip_grad_norm_(self.param_groups[0]['params'], max_norm=1.0)
- property defaults¶
Access to the default parameters of the wrapped optimizer.
- Returns:
Default parameters dictionary
- Return type:
dict
Example:
>>> optimizer = BaseOptimizer(torch.optim.Adam(model.parameters(), lr=0.001))
>>> print(optimizer.defaults['lr'])  # 0.001
- load_state_dict(*args, **kwargs)[source]¶
Load the optimizer state.
This method restores the optimizer’s internal state from a state dictionary, enabling checkpoint restoration and training resumption.
- Parameters:
args – Positional arguments to pass to the wrapped optimizer
kwargs – Keyword arguments to pass to the wrapped optimizer
Example:
>>> optimizer = BaseOptimizer(torch.optim.Adam(model.parameters()))
>>> state_dict = torch.load('optimizer_checkpoint.pth')
>>> optimizer.load_state_dict(state_dict)
- property param_groups¶
Access to the parameter groups of the wrapped optimizer.
Parameter groups allow different sets of parameters to have different optimization settings like learning rates, weight decay, etc.
- Returns:
List of parameter groups
- Return type:
list
Example:
>>> optimizer = BaseOptimizer(torch.optim.Adam(model.parameters()))
>>> print(len(optimizer.param_groups))  # Number of parameter groups
>>> print(optimizer.param_groups[0]['lr'])  # Learning rate of first group
- state_dict()[source]¶
Return the state of the optimizer as a dict.
This method provides the optimizer’s complete state for checkpointing, including parameter states and hyperparameters.
- Returns:
The state of the optimizer
- Return type:
dict
Example:
>>> optimizer = BaseOptimizer(torch.optim.Adam(model.parameters()))
>>> state_dict = optimizer.state_dict()
>>> torch.save(state_dict, 'optimizer_checkpoint.pth')
- step(*args, **kwargs)[source]¶
Perform a single optimization step.
This method executes one optimization step, updating the model parameters based on their gradients and the optimizer’s algorithm.
- Parameters:
args – Positional arguments to pass to the wrapped optimizer
kwargs – Keyword arguments to pass to the wrapped optimizer
- Returns:
Result from the wrapped optimizer’s step method
Example:
>>> optimizer = BaseOptimizer(torch.optim.Adam(model.parameters()))
>>> loss.backward()
>>> optimizer.step()
- zero_grad(*args, **kwargs)[source]¶
Reset the gradients of all optimized tensors.
This method clears the gradients of all parameters, which is typically done before each backward pass to prevent gradient accumulation.
- Parameters:
args – Positional arguments to pass to the wrapped optimizer
kwargs – Keyword arguments to pass to the wrapped optimizer
Example:
>>> optimizer = BaseOptimizer(torch.optim.Adam(model.parameters()))
>>> optimizer.zero_grad()  # Clear gradients
>>> loss.backward()        # Compute new gradients
>>> optimizer.step()       # Update parameters