# Source code for lightrft.strategy.utils.optimizer_utils
"""
PyTorch Optimization Utilities Module
This module provides utility functions for optimizing PyTorch models,
particularly focused on parameter grouping for optimizers with
customized weight decay settings. It includes support for both regular
tensors and distributed tensors (DTensor) with specialized grouping
strategies for optimal performance in distributed training scenarios.
"""
from collections import defaultdict
from typing import List, Optional
import torch
from torch.distributed.tensor import DTensor
# Default name patterns for parameters that should NOT receive weight decay.
# The grouping helpers in this module fall back to this list when called with
# ``no_decay_name_list=None``; a trainable parameter is treated as "no decay"
# when any of these strings occurs anywhere in its fully-qualified name from
# ``model.named_parameters()`` (plain substring test, not a regex or suffix match).
# NOTE: "norm.weight" is itself a substring of "layer_norm.weight" and
# "layernorm.weight", so those two entries never change the outcome; they are
# kept to make the intent explicit.
_DEFAULT_NO_DECAY_NAME_LIST = [
    "bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight",
]
def get_optimizer_grouped_parameters(
    model: torch.nn.Module,
    weight_decay: float,
    no_decay_name_list: Optional[List[str]] = None,
) -> List[dict]:
    """
    Prepare parameter groups for optimizer with weight decay control.

    Trainable parameters (``requires_grad=True``) are split into two groups:

    - Parameters that should have weight decay applied.
    - Parameters that should not have weight decay applied (with the default
      patterns: biases and parameters whose name contains ``norm.weight`` /
      ``ln_f.weight``).

    A parameter goes into the no-decay group when its fully-qualified name, as
    yielded by ``model.named_parameters()``, contains any entry of
    ``no_decay_name_list`` as a substring. Matching is purely name-based. For
    example, the weight of an ``nn.LayerNorm`` stored at index 1 of an
    ``nn.Sequential`` is named ``"1.weight"``, so it still receives weight decay
    with the default patterns.

    Frozen parameters (``requires_grad=False``) are left out of both groups.
    Both groups are always returned, in the order [decay, no-decay]. Either
    group may have an empty ``"params"`` list. Within each group, parameters
    keep the order of ``model.named_parameters()``.

    :param model: The model whose parameters will be organized.
    :type model: torch.nn.Module
    :param weight_decay: Weight decay value to apply to applicable parameters.
    :type weight_decay: float
    :param no_decay_name_list: List of parameter name substrings that should not have weight decay.
        If None, defaults to _DEFAULT_NO_DECAY_NAME_LIST.
    :type no_decay_name_list: Optional[List[str]]
    :return: Two parameter-group dicts with keys ``"params"`` and ``"weight_decay"``,
        suitable for passing to a ``torch.optim`` optimizer.
    :rtype: list

    Example::

        >>> import torch.nn as nn
        >>> model = nn.Sequential(nn.Linear(10, 10), nn.LayerNorm(10))
        >>> grouped_params = get_optimizer_grouped_parameters(model, weight_decay=0.01)
        >>> optimizer = torch.optim.AdamW(grouped_params)
    """
    if no_decay_name_list is None:
        no_decay_name_list = _DEFAULT_NO_DECAY_NAME_LIST

    decay_params = []
    no_decay_params = []
    # One pass over the parameters: the decay/no-decay decision is evaluated once
    # per parameter, so the two groups can never disagree or overlap.
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(pattern in name for pattern in no_decay_name_list):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
def group_parameters_for_optimizer_dtensor(
model: torch.nn.Module, weight_decay: float, no_decay_name_list: Optional[List[str]] = None
):
"""
Groups model parameters for optimizer by weight decay, dtype, and device mesh for DTensor.
This function creates parameter groups optimized for distributed tensor scenarios by
considering not only weight decay patterns but also data types and device mesh
configurations. This grouping strategy helps optimize memory usage and communication
patterns in distributed training setups.
:param model: The model whose parameters will be organized.
:type model: torch.nn.Module
:param weight_decay: Weight decay value to apply to applicable parameters.
:type weight_decay: float
:param no_decay_name_list: List of parameter name patterns that should not have weight decay.
If None, defaults to _DEFAULT_NO_DECAY_NAME_LIST.
:type no_decay_name_list: Optional[List[str]]
:return: Dictionary mapping group keys to parameter lists. Group keys are tuples of
(weight_decay_value, dtype, mesh_info) where mesh_info describes the
distributed tensor configuration.
:rtype: defaultdict[tuple, list]
Example::
>>> import torch.nn as nn
>>> model = nn.Sequential(nn.Linear(10, 10), nn.LayerNorm(10))
>>> grouped_params = group_parameters_for_optimizer_dtensor(model, weight_decay=0.01)
>>> # Convert to optimizer format
>>> optimizer_groups = [{"params": params, "weight_decay": wd}
... for (wd, dtype, mesh), params in grouped_params.items()]
"""
if no_decay_name_list is None:
no_decay_name_list = _DEFAULT_NO_DECAY_NAME_LIST
# Use a dict to store unique groups, keyed by (weight_decay_status, dtype, mesh_info)
grouped_params_temp = defaultdict(list)
for name, param in model.named_parameters():
if not param.requires_grad:
continue
# Determine weight decay application
apply_weight_decay = not any(nd in name for nd in no_decay_name_list)
current_weight_decay = weight_decay if apply_weight_decay else 0.0
# Determine device mesh info
param_mesh_info = "full_tensor"
# Check if param is a DTensor (assuming it has _is_dtensor and device_mesh)
if isinstance(param, DTensor):
param_mesh_info = f"dtensor_{param.device_mesh.shape[0]}"
# Grouping key: (weight_decay_value, param_dtype, param_mesh_info)
group_key = (current_weight_decay, param.dtype, param_mesh_info)
grouped_params_temp[group_key].append(param)
return grouped_params_temp