Shortcuts

Source code for ding.torch_utils.model_helper

import torch


def get_num_params(model: torch.nn.Module) -> int:
    """
    Overview:
        Return the number of parameters in the model, i.e. the sum of the element counts of \
        every tensor yielded by ``model.parameters()``. No filtering on ``requires_grad`` is \
        applied, so every parameter yielded is counted whether or not it requires gradients.
    Arguments:
        - model (:obj:`torch.nn.Module`): The model object to calculate the parameter number.
    Returns:
        - n_params (:obj:`int`): The calculated number of parameters (``0`` if the model has none).
    Examples:
        >>> model = torch.nn.Linear(3, 5)
        >>> num = get_num_params(model)
        >>> assert num == 20  # 3 * 5 weight elements + 5 bias elements
    """
    # ``numel()`` gives the element count of each parameter tensor; ``sum`` over an
    # empty iterator is 0, so a parameter-free module is handled without a special case.
    return sum(p.numel() for p in model.parameters())