# Source code for lightrft.strategy.fsdp.fsdpv2
"""
Hugging Face FSDP (Fully Sharded Data Parallel) Strategy Module.
This module provides implementations for distributed training using PyTorch's FSDP.
It includes utilities for model wrapping, optimization, checkpointing, and state management
in a distributed training environment. The module supports FSDP v2 strategy,
with special handling for model sharding, mixed precision training, and optimizer state
management.
"""
import os
import shutil
from collections import defaultdict
from contextlib import contextmanager
from typing import List, Tuple, Union
import torch
from torch import distributed as dist
from torch import nn, optim
try:
from torch.distributed.fsdp import (
CPUOffloadPolicy,
FSDPModule,
MixedPrecisionPolicy,
OffloadPolicy,
fully_shard,
register_fsdp_forward_method,
)
except ImportError:
from torch.distributed._composable.fsdp import (
fully_shard,
register_fsdp_forward_method,
MixedPrecisionPolicy,
CPUOffloadPolicy,
OffloadPolicy,
FSDPModule,
)
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_optimizer_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.optim import Optimizer
from transformers.trainer_pt_utils import get_module_class_from_name
from lightrft.strategy.strategy_base import StrategyBase, is_actor
from lightrft.strategy.utils.optimizer_utils import group_parameters_for_optimizer_dtensor
from lightrft.strategy.utils.ckpt_utils import find_latest_checkpoint_dir
from .fsdp_optimizer import (
FSDPadaptOptimizer,
load_fsdp_optimizer,
offload_fsdp_optimizer,
)
from .fsdp_utils import is_meta_initialized
# Type aliases: a (model, optimizer) pair, and "either a bare model or such a pair".
ModelOptimPair = Tuple[nn.Module, Optimizer]
ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]

# Vision-transformer block class names. They are listed separately because the
# FSDP wrapping code removes them from the per-layer wrap list in some setups.
vit_transformer_cls_names = ["Qwen2VLVisionBlock", "Qwen2_5_VLVisionBlock"]

# Class names that are wrapped as individual FSDP units in addition to whatever
# the model itself advertises through `_no_split_modules`.
manual_transformer_cls_names_to_wrap = [
    "Embedding",
    "Qwen2VLDecoderLayer",
    *vit_transformer_cls_names,
    "Qwen2_5_VLDecoderLayer",
    "Qwen2DecoderLayer",
    "LlamaDecoderLayer",  # for DeepSeek-R1-Distill-Llama-70B
    "DeepseekDecoderLayer",
]
class FSDPV2Strategy(StrategyBase):
    """
    Training strategy built on PyTorch's Fully Sharded Data Parallel V2 (FSDP2).

    Models are sharded with ``fully_shard`` so that large models can be trained
    (or served for inference) across several GPUs with a reduced per-device
    memory footprint.

    :param seed: Random seed for reproducibility.
    :type seed: int
    :param max_norm: Maximum gradient norm used for gradient clipping.
    :type max_norm: float
    :param micro_train_batch_size: Batch size for a single training step.
    :type micro_train_batch_size: int
    :param train_batch_size: Total batch size for training.
    :type train_batch_size: int
    :param bf16: Whether to use bfloat16 precision.
    :type bf16: bool
    :param args: Additional arguments for the strategy.
    :type args: object
    """
def __init__(  # pylint: disable=R0917
    self,
    seed: int = 42,
    max_norm: float = 0.0,
    micro_train_batch_size: int = 1,
    train_batch_size: int = 1,
    bf16: bool = True,
    args=None,
) -> None:
    """
    Set up the FSDP V2 strategy.

    The first five arguments except ``bf16`` are forwarded unchanged to the base
    strategy constructor; the remaining state is derived from ``self.config``.

    :param seed: Random seed for reproducibility
    :type seed: int
    :param max_norm: Maximum gradient norm used for gradient clipping
    :type max_norm: float
    :param micro_train_batch_size: Batch size for a single training step
    :type micro_train_batch_size: int
    :param train_batch_size: Total batch size for training
    :type train_batch_size: int
    :param bf16: Whether to use bfloat16 precision
    :type bf16: bool
    :param args: Additional arguments for the strategy
    :type args: object
    """
    super().__init__(seed, max_norm, micro_train_batch_size, train_batch_size, args)

    cfg = self.config
    self.bf16 = bf16
    self.mixed_mm_data = cfg.mixed_mm_data
    # "naive" means a plain torch optimizer; otherwise the FSDPadaptOptimizer wrapper is used.
    self.use_naive_opt = not cfg.use_mp_opt
    # Per-model micro-step counters used for gradient accumulation.
    self.cur_step = defaultdict(int)

    if cfg.fsdp_cpu_offload:
        # FSDP CPU offload already offloads the optimizer, so the separate
        # adam_offload switch is turned off.
        cfg.adam_offload = False
        self.print("FSDPV2Strategy fsdp_cpu_offload is True")

    self.print("FSDPV2Strategy Initialized")
def create_optimizer(self, model, **kwargs) -> Optimizer:
    """
    Create an AdamW optimizer for the model with proper parameter grouping.

    Parameters are grouped by ``group_parameters_for_optimizer_dtensor`` (per the
    original author: by dtype, dtensor shard size and weight decay) to avoid
    errors during gradient clipping and optimizer steps. The grouping is stored
    on ``self.grouped_params`` because ``optimizer_step`` clips each group
    separately.

    :param model: The model for which to create the optimizer. If it is an actor
        wrapper, its inner ``model`` attribute is used.
    :type model: torch.nn.Module
    :param kwargs: Keyword arguments forwarded to ``torch.optim.AdamW``. Must
        contain ``weight_decay``. ``fused`` defaults to ``True`` but may be
        overridden by the caller.
    :return: The AdamW optimizer itself when ``self.use_naive_opt`` is set,
        otherwise the optimizer wrapped in ``FSDPadaptOptimizer``.
    :rtype: torch.optim.Optimizer
    :raises KeyError: If ``weight_decay`` is not supplied.

    Example::

        >>> optimizer = strategy.create_optimizer(model, lr=1e-4, weight_decay=0.01)
    """
    if is_actor(model):
        model = model.model

    # Keys of the grouping are 3-tuples whose first element is the group's weight decay.
    self.grouped_params = group_parameters_for_optimizer_dtensor(model, kwargs["weight_decay"])

    # Convert the grouped parameters into the param-group format expected by the optimizer.
    optim_params = [
        {"params": params_list, "weight_decay": wd_val}
        for (wd_val, _, _), params_list in self.grouped_params.items()
    ]

    # Default to the fused kernel, but let an explicit `fused=...` from the caller win
    # instead of failing with "got multiple values for keyword argument 'fused'".
    kwargs.setdefault("fused", True)
    base_optimizer = torch.optim.AdamW(optim_params, **kwargs)

    self.print(f"Creating optimizer with {self.use_naive_opt=} ")
    if self.use_naive_opt:
        return base_optimizer
    return FSDPadaptOptimizer(base_optimizer)
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
    """
    Run back-propagation for ``loss``.

    Only ``loss`` is used; ``model``, ``optimizer`` and ``kwargs`` are accepted
    so the signature matches the other strategy implementations.

    :param loss: The loss tensor to differentiate
    :type loss: torch.Tensor
    :param model: The model (unused here)
    :type model: torch.nn.Module
    :param optimizer: The optimizer (unused here)
    :type optimizer: torch.optim.Optimizer
    :param kwargs: Additional arguments (unused here)
    """
    loss.backward()
def optimizer_step(
    self,
    optimizer: optim.Optimizer,
    model: nn.Module,
    scheduler,
    name="model",
    **kwargs,
) -> None:
    """
    Perform an optimization step with gradient accumulation.

    Every call increments the micro-step counter for ``name``. Only when the
    counter reaches ``self.accumulated_gradient`` are gradients clipped (one
    ``clip_grad_norm_`` call per group in ``self.grouped_params``), the
    optimizer and scheduler stepped, gradients zeroed and the counter reset.
    If any group's gradient norm is not finite the update is skipped, but
    gradients are still zeroed and the counter is still reset.

    A non-positive ``self.max_norm`` means "no clipping": the norms are still
    computed (for the finiteness check) but gradients are left unscaled.
    Passing ``max_norm=0.0`` straight to ``clip_grad_norm_`` would instead
    scale every gradient to zero.

    :param optimizer: The optimizer
    :type optimizer: torch.optim.Optimizer
    :param model: The model (not used by this implementation; kept for interface compatibility)
    :type model: torch.nn.Module
    :param scheduler: The learning rate scheduler
    :param name: Name identifier for the model, used as the accumulation-counter key
    :type name: str
    :param kwargs: Additional arguments (unused)
    """
    self.cur_step[name] += 1
    if self.cur_step[name] != self.accumulated_gradient:
        return

    # An infinite threshold makes clip_grad_norm_ report the norm without rescaling.
    max_norm = self.max_norm
    clip_threshold = max_norm if max_norm and max_norm > 0 else float("inf")

    grad_norms = [
        torch.nn.utils.clip_grad_norm_(param_group, max_norm=clip_threshold)
        for param_group in self.grouped_params.values()
    ]

    if all(torch.isfinite(grad_norm) for grad_norm in grad_norms):
        optimizer.step()
        scheduler.step()
    else:
        # if grad_norm is not finite, skip the update
        print(f"WARN: grad_norm is not finite: {grad_norms}")

    optimizer.zero_grad()
    self.cur_step[name] = 0
def prepare_model(self,
                  model,
                  is_training=False,
                  shard_size=-1,
                  reshard_after_forward=True) -> "Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]":
    """
    Prepare a model for FSDP training or inference.

    Objects that cannot or should not be wrapped are returned unchanged:
    ``None``, anything that is not a ``torch.nn.Module``, and models whose
    ``base_model`` attribute is either not a ``torch.nn.Module`` (e.g. an
    inference engine) or already an ``FSDPModule``. These checks run before
    the automatic shard-size computation, so ``shard_size="auto"`` is safe to
    combine with such objects.

    :param model: The model to prepare.
    :type model: torch.nn.Module or None
    :param is_training: Whether the model is being prepared for training.
    :type is_training: bool
    :param shard_size: Number of ranks per shard group, ``-1`` to shard across
        all ranks, or ``"auto"`` to choose from the parameter count.
    :type shard_size: int or str
    :param reshard_after_forward: Forwarded to the FSDP wrapping step.
    :type reshard_after_forward: bool
    :return: The prepared model wrapped with FSDP, or ``model`` unchanged if it was skipped.
    :rtype: torch.nn.Module

    Example::

        >>> prepared_model = strategy.prepare_model(model)
    """
    def get_auto_shard_size(model, is_training):
        """
        Pick a shard size from the model size and training mode.

        :param model: The model to analyze
        :type model: torch.nn.Module
        :param is_training: Whether the model is being prepared for training
        :type is_training: bool
        :return: ``-1`` when training or for models with >= 80B parameters,
            ``1`` below 10B parameters, ``8`` in between
        :rtype: int
        """
        if is_training:
            return -1
        total_params = sum(p.numel() for p in model.parameters())
        if total_params < 1e10:  # < 10B
            return 1
        if total_params < 8e10:  # 10B-80B
            return 8
        return -1  # >= 80B

    # Guard first: the auto shard-size helper calls model.parameters().
    if model is None or not isinstance(model, torch.nn.Module):
        return model

    if hasattr(model, "base_model"):
        base_model = model.base_model
        # when RM.base_model is an Engine or already wrapped by FSDP, skip fsdp init
        if not isinstance(base_model, torch.nn.Module) or isinstance(base_model, FSDPModule):
            return model

    if shard_size == "auto":
        shard_size = get_auto_shard_size(model, is_training)

    return self._fsdp_init_model(
        model, is_training=is_training, shard_size=shard_size, reshard_after_forward=reshard_after_forward
    )
@torch.no_grad()
def _fsdp_init_model(self, model, is_training, shard_size=-1, reshard_after_forward=True):
    """
    Wrap a model with FSDP2 (``fully_shard``).

    Steps, in order: pick the module to wrap (the inner ``model`` of an actor,
    otherwise the model itself); resolve the layer classes to wrap
    individually; build the mixed-precision policy, the 2-D
    ``(replicate, shard)`` device mesh and the offload policy; apply
    ``fully_shard`` to every matching layer, optionally to ``lm_head``, and
    finally to the root module.

    The list of class names is built on a copy of the model's
    ``_no_split_modules`` so that attribute is never mutated, and a missing
    or ``None`` attribute is treated as an empty list.

    :param model: The model to initialize with FSDP
    :type model: torch.nn.Module
    :param is_training: Whether the model is being prepared for training
    :type is_training: bool
    :param shard_size: Ranks per shard group; ``-1`` shards across the whole world size
    :type shard_size: int
    :param reshard_after_forward: Whether to reshard parameters after forward pass
    :type reshard_after_forward: bool
    :return: ``model`` (the original object; wrapping is applied in place)
    :rtype: torch.nn.Module
    :raises NotImplementedError: If none of the candidate layer classes exists in the model
    """
    self.report_memory("before FSDP2 wrap model")
    naive_mp_training = self.use_naive_opt and is_training
    model_to_wrap = model.model if is_actor(model) else model
    if isinstance(model_to_wrap, FSDPModule):
        # Already wrapped, nothing to do.
        return model
    self.report_memory("before FSDP2 wrap model pos2")

    # `_no_split_modules` alone is not sufficient (for example it only lists
    # Qwen2DecoderLayer for qwen2), so the manual rules are merged in.
    # Work on a copy: the attribute belongs to the model and must not be modified.
    transformer_cls_names_to_wrap = list(getattr(model_to_wrap, "_no_split_modules", None) or [])
    for cls_name in manual_transformer_cls_names_to_wrap:
        if cls_name not in transformer_cls_names_to_wrap:
            transformer_cls_names_to_wrap.append(cls_name)

    if self.mixed_mm_data:
        # Note: if we have mixed multi-modal data across DP ranks
        # (e.g. some ranks pure text, other ranks contain images)
        # we either keep the vision model in full state, or keep it in FSDP's root module.
        # Here the ViT blocks stay in the root module to avoid getting stuck.
        # (Keeping the ViT in full state via a no-shard mesh is the less
        # memory-efficient alternative and is not used.)
        transformer_cls_names_to_wrap = [
            cls_name for cls_name in transformer_cls_names_to_wrap if cls_name not in vit_transformer_cls_names
        ]

    # Resolve names to the classes actually present in this model.
    transformer_cls_to_wrap = []
    for layer_class in transformer_cls_names_to_wrap:
        transformer_cls = get_module_class_from_name(model_to_wrap, layer_class)
        if transformer_cls is not None:
            transformer_cls_to_wrap.append(transformer_cls)

    if not transformer_cls_to_wrap:
        self.print("len(transformer_cls_to_wrap)=0", model_to_wrap)
        raise NotImplementedError("len(transformer_cls_to_wrap) == 0, please check the wrapping rules!")

    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16 if naive_mp_training else None,
        reduce_dtype=torch.float32 if naive_mp_training else None,
    )

    world_size = torch.distributed.get_world_size()
    if shard_size != -1:
        assert world_size % shard_size == 0, (
            f"world_size ({world_size}) must be divisible by shard_size ({shard_size})"
        )
        mesh_shape = (world_size // shard_size, shard_size)
    else:
        mesh_shape = (1, world_size)
    mesh = init_device_mesh("cuda", mesh_shape, mesh_dim_names=("replicate", "shard"))
    # Currently unused (only the disabled no-shard ViT path needed it); still
    # created so the set of device meshes initialised here is unchanged.
    no_shard_mesh = init_device_mesh("cuda", (world_size, 1), mesh_dim_names=("replicate", "shard"))  # noqa

    offload_policy = CPUOffloadPolicy() if is_training and self.args.fsdp_cpu_offload else OffloadPolicy()

    fsdp_kwargs = {
        "reshard_after_forward": reshard_after_forward,
        "mp_policy": mp_policy,
        "offload_policy": offload_policy,
        "mesh": mesh,
    }

    for cls_to_wrap in transformer_cls_to_wrap:
        for module in model_to_wrap.modules():
            if isinstance(module, cls_to_wrap):
                fully_shard(module, **fsdp_kwargs)

    if not self.args.fused_linear_logprob:
        # In the fused linear logprob implementation, lm_head.weight is used directly
        # to calculate logprob; if lm_head were sharded here the actor forward may hang.
        for name, module in model_to_wrap.named_modules():
            if "lm_head" in name:
                fully_shard(module, **fsdp_kwargs)

    # The root module is wrapped last, after all of its sub-units.
    fully_shard(model_to_wrap, **fsdp_kwargs)

    if naive_mp_training:
        # cast model into fp32 to create optimizer with fp32 states
        # https://github.com/pytorch/torchtitan/issues/1133#issuecomment-2824429682
        model_to_wrap = model_to_wrap.to(torch.float32)

    if hasattr(model_to_wrap, "generate"):
        register_fsdp_forward_method(model_to_wrap, "generate")

    if is_meta_initialized(model_to_wrap):
        # Materialise meta-device parameters on the GPU (uninitialised storage).
        model.to_empty(device="cuda")

    self.print(f"after _fsdp2_init_model: {model_to_wrap}")
    self.report_memory("after FSDP2 wrap model")
    return model
@contextmanager
def init_model_context(self, meta_init=False):
    """
    Context manager wrapping model construction.

    With ``meta_init`` set, the body runs under ``torch.device("meta")`` so a
    large model can be created without allocating real storage. Memory usage
    is reported when the context exits, whether or not the body raised.

    :param meta_init: Whether to initialize on the meta device
    :type meta_init: bool
    """
    try:
        if not meta_init:
            yield
        else:
            with torch.device("meta"):
                yield
    finally:
        self.report_memory("Finished init_model_context")
def unwrap_model(self, model) -> nn.Module:
    """
    Strip one level of wrapping from a model.

    :param model: The model to unwrap.
    :type model: torch.nn.Module
    :return: ``model.module`` if the attribute exists, otherwise ``model`` itself.
    :rtype: torch.nn.Module

    Example::

        >>> unwrapped_model = strategy.unwrap_model(wrapped_model)
    """
    return getattr(model, "module", model)
def save_model(self, *args, **kwargs) -> None:
    """
    Placeholder for exporting the model; nothing is written.

    Saving a full (unsharded) model is not implemented for this strategy. All
    arguments are ignored and a message pointing to offline conversion tools
    is printed instead.

    :param args: Ignored positional arguments (e.g. model, tokenizer, output_dir)
    :param kwargs: Ignored keyword arguments
    """
    notice = "FSDP save model is not implemented, please use offline tools to convert to huggingface model"
    self.print(notice)
[docs] def save_ckpt(
self,
model,
save_dir,
tag=None,
max_num=3,
max_mem=1000,
client_state={},
save_latest=True,
optimizer=None,
scheduler=None,
): # pylint: disable=R0917,W0102
"""
Save model checkpoints in a distributed environment with automatic rotation.
This method saves the sharded model weights to disk and manages the number and total size
of checkpoints by removing older ones when necessary. It ensures proper synchronization
between distributed processes.
:param model: The model to save, typically an FSDP-wrapped model
:type model: torch.nn.Module
:param save_dir: Directory where checkpoints will be saved
:type save_dir: str
:param optimizer: The optimizer to save
:type optimizer: torch.optim, optional
:param scheduler: The scheduler to save
:type scheduler: torch.lr_scheduler, optional
:param tag: Subdirectory name for this specific checkpoint
:type tag: str, optional
:param max_num: Maximum number of checkpoints to keep
:type max_num: int, default=3
:param max_mem: Maximum disk space in GB for all checkpoints
:type max_mem: int, default=1000
:param client_state: Additional state to save (not currently used)
:type client_state: dict, default={}
:param save_latest: Whether to save a copy as the latest checkpoint (not used)
:type save_latest: bool, default=True
Example::
>>> trainer.save_ckpt(model, "checkpoints", tag="step_1000", max_num=5)
"""
if self.is_rank_0():
os.makedirs(save_dir, exist_ok=True)
while True:
subdirs = sorted(
[(os.path.join(save_dir, d), os.path.getmtime(os.path.join(save_dir, d)))
for d in os.listdir(save_dir)
if os.path.isdir(os.path.join(save_dir, d))],
key=lambda x: x[1],
)
if len(subdirs) >= max_num:
oldest_dir = subdirs[0][0]
if os.path.exists(oldest_dir):
shutil.rmtree(oldest_dir)
self.print(f"Deleted oldest ckpt {oldest_dir}")
else:
break
dist.barrier()
fsdp_state_dict = get_model_state_dict(model)
fp = os.path.join(save_dir, tag)
os.makedirs(fp, exist_ok=True)
dcp.save(fsdp_state_dict, checkpoint_id=fp)
self.print(f"DCP checkpoint saved to {fp}")
if optimizer is not None:
opt_base_dir = os.path.join(save_dir, tag, "optim_states")
os.makedirs(opt_base_dir, exist_ok=True)
if is_mp_optimizer(optimizer):
opt_ckpt_path = os.path.join(opt_base_dir, f"_rank{torch.distributed.get_rank()}")
torch.save(optimizer.state_dict(), opt_ckpt_path)
else:
# DCP can only be use with naive optimizer
fsdp_optim_state_dict = get_optimizer_state_dict(model, optimizer)
dcp.save(fsdp_optim_state_dict, checkpoint_id=opt_base_dir)
client_ckpt_path = os.path.join(fp, "client_state.pt")
torch.save(client_state, client_ckpt_path)
self.print(f"client_state save to {client_ckpt_path}, content: {client_state}")
if scheduler is not None:
sched_ckpt_path = os.path.join(fp, "scheduler_state.pt")
scheduler_state = scheduler.state_dict()
torch.save(scheduler_state, sched_ckpt_path)
self.print(f"scheduler_state save to {sched_ckpt_path}")
def load_ckpt(  # pylint: disable=R0917
    self,
    model,
    load_dir,
    optimizer=None,
    scheduler=None,
    load_module_strict=True,
    load_optimizer_states=True,
    load_lr_scheduler_states=True,
    **kwargs,
):
    """
    Restore model, optimizer, scheduler and client state from a checkpoint.

    The checkpoint directory is resolved with ``find_latest_checkpoint_dir``.
    Model weights are always loaded through DCP. Optimizer and scheduler state
    are loaded only when the object is given, the matching flag is true and
    the files exist; otherwise a warning is printed and loading continues.

    :param model: The model to load weights into, typically an FSDP-wrapped model
    :type model: torch.nn.Module
    :param load_dir: Directory containing the checkpoints
    :type load_dir: str
    :param optimizer: The optimizer to load state into
    :type optimizer: torch.optim, optional
    :param scheduler: The scheduler to load state into
    :type scheduler: torch.lr_scheduler, optional
    :param load_module_strict: Accepted for interface compatibility; not used by this implementation
    :type load_module_strict: bool, default=True
    :param load_optimizer_states: Whether to load optimizer states
    :type load_optimizer_states: bool, default=True
    :param load_lr_scheduler_states: Whether to load learning rate scheduler states
    :type load_lr_scheduler_states: bool, default=True
    :param kwargs: Additional arguments (unused)
    :return: A tuple ``(checkpoint_dir, client_states)``; ``client_states`` is an
        empty dict when no ``client_state.pt`` exists
    :rtype: tuple

    Example::

        >>> load_dir, client_states = trainer.load_ckpt(model, "checkpoints/step_1000")
    """
    ckpt_dir = find_latest_checkpoint_dir(load_dir)
    self.print(f"Loading DCP checkpoint from {ckpt_dir}")

    # DCP loads in place into the state dict obtained from the live model.
    model_state = get_model_state_dict(model)
    dcp.load(state_dict=model_state, checkpoint_id=ckpt_dir)
    set_model_state_dict(model, model_state)

    if optimizer is not None and load_optimizer_states:
        optim_dir = os.path.join(ckpt_dir, "optim_states")
        if os.path.exists(optim_dir):
            self.print(f"Loading DCP checkpoint from {optim_dir}")
            if is_mp_optimizer(optimizer):
                # Per-rank plain torch file written by save_ckpt.
                rank_file = os.path.join(optim_dir, f"_rank{torch.distributed.get_rank()}")
                optimizer.load_state_dict(torch.load(rank_file))
            else:
                # DCP can only be used with the naive optimizer
                optim_state = get_optimizer_state_dict(model, optimizer)
                dcp.load(state_dict=optim_state, checkpoint_id=optim_dir)
                set_optimizer_state_dict(model, optimizer, optim_state)
        else:
            self.print(f"WARNING: Opt ckpt {optim_dir} does not exist! Skipping ... ")

    if scheduler is not None and load_lr_scheduler_states:
        sched_file = os.path.join(ckpt_dir, "scheduler_state.pt")
        if os.path.exists(sched_file):
            self.print(f"Loading lr_scheduler_states from {sched_file}")
            scheduler.load_state_dict(torch.load(sched_file))
        else:
            self.print(f"WARNING: Scheduler ckpt {sched_file} does not exist! Skipping ... ")

    client_file = os.path.join(ckpt_dir, "client_state.pt")
    client_states = torch.load(client_file) if os.path.exists(client_file) else {}
    self.print(f"Loaded client states: {client_states=}")

    self.sync_and_clear_cache()
    return ckpt_dir, client_states
def maybe_offload_optimizer(self, optimizer):
    """
    Offload FSDP optimizer states to CPU if ``adam_offload`` is enabled.

    :param optimizer: The optimizer to potentially offload
    :type optimizer: torch.optim.Optimizer
    :return: The result of ``offload_fsdp_optimizer`` if ``self.args.adam_offload``
        is set, otherwise the original optimizer unchanged
    :rtype: torch.optim.Optimizer
    """
    if not self.args.adam_offload:
        # Return the optimizer as documented instead of an implicit None.
        return optimizer
    return offload_fsdp_optimizer(optimizer)
def maybe_load_optimizer(self, optimizer, device=None):
    """
    Load FSDP optimizer states back to GPU if ``adam_offload`` is enabled.

    The target device is resolved when the method is called. Resolving
    ``torch.cuda.current_device()`` in the signature would run once at
    definition time, which touches CUDA on import and freezes whichever
    device was current at that moment.

    :param optimizer: The optimizer to potentially load
    :type optimizer: torch.optim.Optimizer
    :param device: The device to load the optimizer to; ``None`` means the
        current CUDA device at call time
    :type device: torch.device or int, optional
    :return: The result of ``load_fsdp_optimizer`` if ``self.args.adam_offload``
        is set, otherwise the original optimizer unchanged
    :rtype: torch.optim.Optimizer
    """
    if not self.args.adam_offload:
        # Return the optimizer as documented instead of an implicit None.
        return optimizer
    if device is None:
        device = torch.cuda.current_device()
    return load_fsdp_optimizer(optimizer, device)
@torch.no_grad()
def offload_model(self, models, empty_cache: bool = True):
    """
    Move model(s) to CPU to free GPU memory.

    Each ``torch.nn.Module`` in ``models`` is moved with ``Module.to``; any
    other entry is skipped silently. Moving parameters and buffers one by one
    was tried by the original author and left memory unreleased for
    QwenVL + FSDP2, so the whole-module move is used.

    :param models: Single model or list/tuple of models to offload
    :type models: torch.nn.Module or list/tuple of torch.nn.Module
    :param empty_cache: Whether to call ``self.sync_and_clear_cache()`` after offloading
    :type empty_cache: bool, default=True

    Example::

        >>> strategy.offload_model([actor_model, critic_model])
    """
    batch = models if isinstance(models, (list, tuple)) else [models]
    for candidate in batch:
        if isinstance(candidate, torch.nn.Module):
            candidate.to(torch.device("cpu"))

    if empty_cache:
        self.sync_and_clear_cache()
    self.report_memory("after offload_model")
@torch.no_grad()
def reload_model(self, models):
    """
    Move model(s) back to the current CUDA device.

    The device is read once, before any model is inspected. Each
    ``torch.nn.Module`` in ``models`` is moved with ``Module.to``; any other
    entry is skipped silently. Moving parameters and buffers one by one was
    tried by the original author and left memory unreleased for
    QwenVL + FSDP2, so the whole-module move is used.

    :param models: Single model or list/tuple of models to reload
    :type models: torch.nn.Module or list/tuple of torch.nn.Module

    Example::

        >>> strategy.reload_model([actor_model, critic_model])
    """
    target = torch.cuda.current_device()

    batch = models if isinstance(models, (list, tuple)) else [models]
    for candidate in batch:
        if isinstance(candidate, torch.nn.Module):
            candidate.to(target)

    self.report_memory("after reload_model")
def is_mp_optimizer(optim):
    """
    Tell whether an optimizer is the ``FSDPadaptOptimizer`` wrapper.

    The checkpoint code uses this to choose between per-rank ``torch.save``
    files (wrapper) and DCP (plain optimizer).

    :param optim: The optimizer to check
    :type optim: torch.optim.Optimizer or similar
    :return: True if ``optim`` is an instance of FSDPadaptOptimizer, False otherwise
    :rtype: bool

    Example::

        >>> optimizer = FSDPadaptOptimizer(torch.optim.AdamW(params, lr=0.01))
        >>> is_mp_optimizer(optimizer)
        True
    """
    return isinstance(optim, FSDPadaptOptimizer)