# Source code for lightrft.strategy.deepspeed.deepspeed_utils
"""
DeepSpeed Configuration and Optimization Utilities Module.
This module provides utility functions for configuring DeepSpeed for training and evaluation,
managing optimizer parameters, and handling DeepSpeed ZeRO stage 3 states. It includes
functions for creating DeepSpeed configurations with various optimization options,
organizing model parameters for optimizers with weight decay control, and
offloading/reloading DeepSpeed states to manage memory efficiently.
"""
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
def get_train_ds_config(  # pylint: disable=R0917
    offload,
    adam_offload=True,
    stage=2,
    bf16=True,
    max_norm=1.0,
    zpg=8,
    grad_accum_dtype=None,
    overlap_comm=False,
):
    """
    Generate a DeepSpeed configuration dictionary for training.

    :param offload: Whether to offload model parameters to CPU
        (effective for ZeRO stage 3).
    :type offload: bool
    :param adam_offload: Whether to offload Adam optimizer states to CPU.
    :type adam_offload: bool
    :param stage: ZeRO optimization stage (0, 1, 2, or 3).
    :type stage: int
    :param bf16: Whether to use bfloat16 precision.
    :type bf16: bool
    :param max_norm: Maximum norm for gradient clipping.
    :type max_norm: float
    :param zpg: ZeRO++ hierarchical partition (hpZ) group size.
    :type zpg: int
    :param grad_accum_dtype: Data type for gradient accumulation, or None
        to let DeepSpeed pick the default.
    :type grad_accum_dtype: str or None
    :param overlap_comm: Whether to overlap communication with computation.
    :type overlap_comm: bool
    :return: DeepSpeed configuration dictionary for training.
    :rtype: dict
    """
    device = "cpu" if offload else "none"
    zero_opt_dict = {
        "stage": stage,
        "offload_param": {
            "device": device
        },
        "offload_optimizer": {
            "device": "cpu" if adam_offload else "none",
            "pin_memory": True,
        },
        # "auto" lets the integration layer size these ZeRO-3 buffers;
        # they are ignored by stages below 3.
        "sub_group_size": "auto",
        "stage3_max_live_parameters": "auto",
        "stage3_max_reuse_distance": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "reduce_bucket_size": "auto",
        # ZeRO++ knobs: hierarchical partitioning on, quantization off.
        "zero_hpz_partition_size": zpg,
        "zero_quantized_weights": False,
        "zero_quantized_gradients": False,
    }
    if overlap_comm:
        # Only set when requested so the keys are absent otherwise.
        zero_opt_dict["overlap_comm"] = True
        zero_opt_dict["contiguous_gradients"] = True
    return {
        "steps_per_print": 100,
        "zero_optimization": zero_opt_dict,
        "bf16": {
            "enabled": bf16,
        },
        "gradient_clipping": max_norm,
        "prescale_gradients": False,
        "wall_clock_breakdown": False,
        "data_types": {
            "grad_accum_dtype": grad_accum_dtype
        },
    }
def get_eval_ds_config(
    offload,
    stage=0,
    bf16=True,
):
    """
    Generate a DeepSpeed configuration dictionary for evaluation.

    No optimizer offload section is emitted since evaluation does not
    step an optimizer.

    :param offload: Whether to offload model parameters to CPU
        (effective for ZeRO stage 3).
    :type offload: bool
    :param stage: ZeRO optimization stage (0, 1, 2, or 3).
    :type stage: int
    :param bf16: Whether to use bfloat16 precision.
    :type bf16: bool
    :return: DeepSpeed configuration dictionary for evaluation.
    :rtype: dict
    """
    zero_opt_dict = {
        "stage": stage,
        # "auto" lets the integration layer size this; ignored below stage 3.
        "stage3_param_persistence_threshold": "auto",
        "offload_param": {
            "device": "cpu" if offload else "none",
            "pin_memory": True,
        },
    }
    return {
        "steps_per_print": 100,
        "zero_optimization": zero_opt_dict,
        "bf16": {
            "enabled": bf16,
        },
        "gradient_clipping": 1.0,
        "prescale_gradients": False,
        "wall_clock_breakdown": False,
    }
def _z3_params_to_fetch(param_list):
"""
Filter parameters that need to be fetched in ZeRO stage 3.
:param param_list: List of parameters to check.
:type param_list: list
:return: List of parameters that are not available and need to be fetched.
:rtype: list
"""
return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE]
def offload_deepspeed_states(model, pin_memory=True, non_blocking=True):
    """
    Offload DeepSpeed optimizer states to CPU to save GPU memory.

    This function is particularly useful for ZeRO stage 3 when not using Adam
    optimizer offloading. It offloads optimizer states, the contiguous gradient
    buffer, and high-precision params to CPU, empties the partition cache, and
    synchronizes devices.

    :param model: DeepSpeed model with optimizer.
    :type model: deepspeed.DeepSpeedEngine
    :param pin_memory: Whether to use pinned memory for offloaded states.
    :type pin_memory: bool
    :param non_blocking: Whether to perform non-blocking transfers.
    :type non_blocking: bool
    :raises NotImplementedError: If ZeRO stage is not 3 (and Adam offload is off).
    """
    zero_stage = model.zero_optimization_stage()  # config['zero_optimization']['stage']
    adam_offload = model.config["zero_optimization"]["offload_optimizer"]["device"] == "cpu"
    # Optimizer states already live on CPU when Adam offload is enabled,
    # so there is nothing to move.
    if adam_offload:
        return
    if zero_stage != 3:
        raise NotImplementedError("Only Zero stage 3 is currently supported")
    # Imported lazily so the early-return paths above never require
    # torch / deepspeed offload support to be present.
    import torch
    from deepspeed.runtime.zero.offload_config import (
        OffloadDeviceEnum,
        OffloadStateTypeEnum,
    )
    model.optimizer.offload_states(
        include=[
            OffloadStateTypeEnum.optim_states,
            OffloadStateTypeEnum.contiguous_grad_buffer,
            OffloadStateTypeEnum.hp_params,
            # Not released yet, fixed in https://github.com/deepspeedai/DeepSpeed/pull/7050
            # OffloadStateTypeEnum.lp_grads,
            # OffloadStateTypeEnum.lp_params,
        ],
        device=OffloadDeviceEnum.cpu,
        pin_memory=pin_memory,
        non_blocking=non_blocking,
    )
    model.empty_partition_cache()
    torch.cuda.empty_cache()
    # Barrier before synchronize so all ranks finish their transfers together.
    torch.distributed.barrier()
    torch.cuda.synchronize()
def reload_deepspeed_states(model, non_blocking=True):
    """
    Reload DeepSpeed optimizer states from CPU back to GPU.

    This function restores states previously offloaded with
    :func:`offload_deepspeed_states`.

    :param model: DeepSpeed model with optimizer.
    :type model: deepspeed.DeepSpeedEngine
    :param non_blocking: Whether to perform non-blocking transfers.
    :type non_blocking: bool
    :raises NotImplementedError: If ZeRO stage is not 3 (and Adam offload is off).
    """
    zero_stage = model.zero_optimization_stage()  # config['zero_optimization']['stage']
    adam_offload = model.config["zero_optimization"]["offload_optimizer"]["device"] == "cpu"
    # Nothing was offloaded by offload_deepspeed_states() when Adam offload
    # is enabled, so there is nothing to reload.
    if adam_offload:
        return
    if zero_stage != 3:
        raise NotImplementedError("Only Zero stage 3 is currently supported")
    # Imported lazily so the early-return paths above never require torch.
    import torch
    # NOTE(review): offload goes through model.optimizer.offload_states while
    # reload goes through the engine's model.reload_states — confirm both
    # target the same state set on the installed DeepSpeed version.
    model.reload_states(non_blocking=non_blocking)
    torch.cuda.empty_cache()
    # Barrier before synchronize so all ranks finish their transfers together.
    torch.distributed.barrier()
    torch.cuda.synchronize()