Source code for grl.numerical_methods.numerical_solvers.dpm_solver
#############################################################
# This DPM-Solver snippet is from https://github.com/ChenDRAG/CEP-energy-guided-diffusion
# which is based on https://github.com/LuChengTHU/dpm-solver
#############################################################
from typing import Any, Callable, Dict, List, Tuple, Union
import torch
from tensordict import TensorDict
from torch import nn
class DPMSolver:
"""
Overview:
The DPM-Solver for sampling from the diffusion process.
Interface:
``__init__``, ``integrate``
"""
    def __init__(
self,
order: int,
device: str,
atol: float = 1e-5,
rtol: float = 1e-5,
steps: int = None,
type: str = "dpm_solver",
method: str = "singlestep",
solver_type: str = "dpm_solver",
skip_type: str = "time_uniform",
denoise: bool = False,
):
"""
Overview:
Initialize the DPM-Solver.
Arguments:
order (:obj:`int`): The order of the DPM-Solver, which should be 1, 2, or 3.
device (:obj:`str`): The device for the computation.
atol (:obj:`float`): The absolute tolerance for the adaptive solver.
rtol (:obj:`float`): The relative tolerance for the adaptive solver.
steps (:obj:`int`): The total number of function evaluations (NFE).
type (:obj:`str`): The type for the DPM-Solver, which should be 'dpm_solver' or 'dpm_solver++'.
method (:obj:`str`): The method for the DPM-Solver, which should be 'singlestep', 'multistep', 'singlestep_fixed', or 'adaptive'.
            solver_type (:obj:`str`): The type of the high-order solvers, which should be 'dpm_solver' or 'taylor'.
                The type has a slight impact on performance; we recommend using the 'dpm_solver' type.
skip_type (:obj:`str`): The type for the spacing of the time steps, which should be 'logSNR', 'time_uniform', or 'time_quadratic'.
denoise (:obj:`bool`): Whether to denoise at the final step.
"""
self.type = type
assert self.type in ["dpm_solver", "dpm_solver++"]
if self.type == "dpm_solver++":
self.use_dpm_solver_plus_plus = True
else:
self.use_dpm_solver_plus_plus = False
self.atol = atol
self.rtol = rtol
self.steps = steps
self.order = order
assert self.order in [1, 2, 3]
self.method = method
assert self.method in [
"singlestep",
"multistep",
"singlestep_fixed",
"adaptive",
]
self.solver_type = solver_type
assert self.solver_type in ["dpm_solver", "taylor"]
if self.solver_type == "dpm_solver":
self.default_high_order_solvers = True
else:
self.default_high_order_solvers = False
self.skip_type = skip_type
assert self.skip_type in ["logSNR", "time_uniform", "time_quadratic"]
self.denoise = denoise
self.device = device
self.nfe = 0
# TODO: support dynamic thresholding for dpm_solver++
    def integrate(
self,
diffusion_process,
noise_function: Callable,
data_prediction_function: Callable,
x: Union[torch.Tensor, TensorDict],
condition: Union[torch.Tensor, TensorDict] = None,
steps: int = None,
save_intermediate: bool = False,
):
"""
Overview:
Integrate the diffusion process by the DPM-Solver.
Arguments:
diffusion_process (:obj:`DiffusionProcess`): The diffusion process.
noise_function (:obj:`Callable`): The noise prediction model.
data_prediction_function (:obj:`Callable`): The data prediction model.
            x (:obj:`Union[torch.Tensor, TensorDict]`): The initial value at the starting time `t_T` (i.e., `diffusion_process.t_max`).
condition (:obj:`Union[torch.Tensor, TensorDict]`): The condition for the data prediction model.
steps (:obj:`int`): The total number of function evaluations (NFE).
save_intermediate (:obj:`bool`): If true, also return the intermediate model values.
        Returns:
            x (:obj:`Union[torch.Tensor, TensorDict]`): The approximated solution at the final time. If `save_intermediate` is True, the stacked intermediate values are returned instead.
"""
steps = (
steps if steps is not None else self.steps if self.steps is not None else 20
)
def model_fn(
x: Union[torch.Tensor, TensorDict], t: torch.Tensor
) -> Union[torch.Tensor, TensorDict]:
"""
Overview:
Convert the model to the noise prediction model or the data prediction model.
Arguments:
x (:obj:`Union[torch.Tensor, TensorDict]`): The input tensor.
t (:obj:`torch.Tensor`): The time tensor.
"""
if self.use_dpm_solver_plus_plus:
return data_prediction_function(t, x, condition)
else:
return noise_function(t, x, condition)
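        # Note: DPM-Solver++ integrates in the data-prediction (x0) parameterization,
        # while vanilla DPM-Solver integrates in the noise (epsilon) parameterization;
        # model_fn simply dispatches to the corresponding callable.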
def get_time_steps(t_T, t_0, N):
"""
Overview:
Compute the intermediate time steps for sampling.
Arguments:
t_T (:obj:`float`): The starting time of the sampling (default is T).
t_0 (:obj:`float`): The ending time of the sampling (default is epsilon).
N (:obj:`int`): The total number of the spacing of the time steps.
Returns:
t (:obj:`torch.Tensor`): A pytorch tensor of the time steps, with the shape (N + 1,).
"""
if self.skip_type == "logSNR":
lambda_T = diffusion_process.HalfLogSNR(t_T).to(self.device)
lambda_0 = diffusion_process.HalfLogSNR(t_0).to(self.device)
                logSNR_steps = torch.linspace(lambda_T, lambda_0, N + 1).to(self.device)
                return diffusion_process.InverseHalfLogSNR(logSNR_steps)
elif self.skip_type == "time_uniform":
return torch.linspace(t_T, t_0, N + 1).to(self.device)
elif self.skip_type == "time_quadratic":
t_order = 2
t = (
torch.linspace(
t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1
)
.pow(t_order)
.to(self.device)
)
return t
else:
raise ValueError(
"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(
self.skip_type
)
)
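        # Illustration (approximate values, not computed here): for t_T = 1.0,
        # t_0 = 0.001 and N = 4, the spacings are roughly
        #   time_uniform   -> [1.000, 0.750, 0.500, 0.251, 0.001]
        #   time_quadratic -> [1.000, 0.574, 0.266, 0.075, 0.001]
        # while 'logSNR' places the steps uniformly in the half-logSNR lambda(t)
        # and maps them back through InverseHalfLogSNR (schedule-dependent).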
def get_orders_for_singlestep_solver(steps: int, order: int) -> List[int]:
"""
Overview:
Get the order of each step for sampling by the singlestep DPM-Solver.
Arguments:
steps (:obj:`int`): The total number of function evaluations (NFE).
order (:obj:`int`): The max order for the solver (2 or 3).
Returns:
orders (:obj:`List[int]`): A list of the solver order of each step.
.. note::
                We combine DPM-Solver-1, 2 and 3 to use all the function evaluations, which is called "DPM-Solver-fast".
Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
If order == 1:
We take `steps` of DPM-Solver-1 (i.e. DDIM).
If order == 2:
                    Denote K = (steps // 2 + 1). We take K intermediate time steps for sampling.
                    If steps % 2 == 0, we use (K - 2) steps of DPM-Solver-2 and 2 steps of DPM-Solver-1.
                    If steps % 2 == 1, we use (K - 1) steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
If order == 3:
Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, \
and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2
"""
            if order == 3:
                K = steps // 3 + 1
                if steps % 3 == 0:
                    orders = [3] * (K - 2) + [2, 1]
                elif steps % 3 == 1:
                    orders = [3] * (K - 1) + [1]
                else:
                    orders = [3] * (K - 1) + [2]
                return orders
            elif order == 2:
                K = steps // 2
                if steps % 2 == 0:
                    # orders = [2] * K
                    K = steps // 2 + 1
                    orders = [2] * (K - 2) + [1] * 2
                else:
                    orders = [2] * K + [1]
                return orders
            elif order == 1:
                return [1] * steps
            else:
                raise ValueError("'order' must be '1' or '2' or '3'.")
def denoise_fn(
x: Union[torch.Tensor, TensorDict],
s: torch.Tensor,
condition: Union[torch.Tensor, TensorDict],
) -> Union[torch.Tensor, TensorDict]:
"""
Overview:
Denoise at the final step, which is equivalent to solve the ODE \
from lambda_s to infty by first-order discretization.
Arguments:
x (:obj:`Union[torch.Tensor, TensorDict]`): The input tensor.
s (:obj:`torch.Tensor`): The time tensor.
condition (:obj:`Union[torch.Tensor, TensorDict]`): The condition for the data prediction model.
Returns:
x (:obj:`Union[torch.Tensor, TensorDict]`): The denoised output.
"""
return data_prediction_function(s, x, condition)
def dpm_solver_first_update(
x: Union[torch.Tensor, TensorDict],
s: torch.Tensor,
t: torch.Tensor,
model_s: Union[torch.Tensor, TensorDict] = None,
return_intermediate: bool = False,
):
"""
Overview:
DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
Arguments:
x (:obj:`Union[torch.Tensor, TensorDict]`): The initial value at time `s`.
s (:obj:`torch.Tensor`): The starting time, with the shape (x.shape[0],).
t (:obj:`torch.Tensor`): The ending time, with the shape (x.shape[0],).
model_s (:obj:`Union[torch.Tensor, TensorDict]`): The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
return_intermediate (:obj:`bool`): If true, also return the model value at time `s`.
Returns:
x_t (:obj:`Union[torch.Tensor, TensorDict]`): The approximated solution at time `t`.
"""
lambda_s = diffusion_process.HalfLogSNR(s, x)
lambda_t = diffusion_process.HalfLogSNR(t, x)
h = lambda_t - lambda_s
log_alpha_s = diffusion_process.log_scale(s, x)
log_alpha_t = diffusion_process.log_scale(t, x)
sigma_s = diffusion_process.std(s, x)
sigma_t = diffusion_process.std(t, x)
alpha_t = torch.exp(log_alpha_t)
if self.use_dpm_solver_plus_plus:
phi_1 = torch.expm1(-h)
if model_s is None:
model_s = model_fn(x, s)
x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s
if return_intermediate:
return x_t, {"model_s": model_s}
else:
return x_t
else:
phi_1 = torch.expm1(h)
if model_s is None:
model_s = model_fn(x, s)
x_t = (
torch.exp(log_alpha_t - log_alpha_s) * x - sigma_t * phi_1 * model_s
)
if return_intermediate:
return x_t, {"model_s": model_s}
else:
return x_t
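        # For reference, the closed-form update implemented above, with
        # h = lambda_t - lambda_s the half-logSNR increment:
        #   'dpm_solver'   (noise prediction):
        #       x_t = (alpha_t / alpha_s) * x_s - sigma_t * (e^h - 1) * eps_theta(x_s, s)
        #   'dpm_solver++' (data prediction):
        #       x_t = (sigma_t / sigma_s) * x_s - alpha_t * (e^{-h} - 1) * x0_theta(x_s, s)
        # i.e. the DDIM step written in the (lambda, alpha, sigma) parameterization.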
def singlestep_dpm_solver_second_update(
x: Union[torch.Tensor, TensorDict],
s: torch.Tensor,
t: torch.Tensor,
r1: float = 0.5,
model_s: Union[torch.Tensor, TensorDict] = None,
return_intermediate: bool = False,
):
"""
Overview:
Singlestep solver DPM-Solver-2 from time `s` to time `t`.
Arguments:
x (:obj:`Union[torch.Tensor, TensorDict]`): The initial value at time `s`.
s (:obj:`torch.Tensor`): The starting time, with the shape (x.shape[0],).
t (:obj:`torch.Tensor`): The ending time, with the shape (x.shape[0],).
r1 (:obj:`float`): The hyperparameter of the second-order solver.
model_s (:obj:`Union[torch.Tensor, TensorDict]`): The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
return_intermediate (:obj:`bool`): If true, also return the model value at time `s` and `s1` (the intermediate time).
Returns:
x_t (:obj:`Union[torch.Tensor, TensorDict]`): The approximated solution at time `t`.
"""
if r1 is None:
r1 = 0.5
lambda_s = diffusion_process.HalfLogSNR(s, x)
lambda_t = diffusion_process.HalfLogSNR(t, x)
h = lambda_t - lambda_s
lambda_s1 = lambda_s + r1 * h
s1 = diffusion_process.InverseHalfLogSNR(lambda_s1)[:, 0]
log_alpha_s = diffusion_process.log_scale(s, x)
log_alpha_s1 = diffusion_process.log_scale(s1, x)
log_alpha_t = diffusion_process.log_scale(t, x)
sigma_s = diffusion_process.std(s, x)
sigma_s1 = diffusion_process.std(s1, x)
sigma_t = diffusion_process.std(t, x)
alpha_s1 = torch.exp(log_alpha_s1)
alpha_t = torch.exp(log_alpha_t)
if self.use_dpm_solver_plus_plus:
phi_11 = torch.expm1(-r1 * h)
phi_1 = torch.expm1(-h)
if model_s is None:
model_s = model_fn(x, s)
x_s1 = sigma_s1 / sigma_s * x - alpha_s1 * phi_11 * model_s
model_s1 = model_fn(x_s1, s1)
if self.default_high_order_solvers:
x_t = (
sigma_t / sigma_s * x
- alpha_t * phi_1 * model_s
- (0.5 / r1) * alpha_t * phi_1 * (model_s1 - model_s)
)
else:
x_t = (
sigma_t / sigma_s * x
- alpha_t * phi_1 * model_s
+ (1.0 / r1)
* alpha_t
* ((torch.exp(-h) - 1.0) / h + 1.0)
* (model_s1 - model_s)
)
else:
phi_11 = torch.expm1(r1 * h)
phi_1 = torch.expm1(h)
if model_s is None:
model_s = model_fn(x, s)
x_s1 = (
torch.exp(log_alpha_s1 - log_alpha_s) * x
- sigma_s1 * phi_11 * model_s
)
model_s1 = model_fn(x_s1, s1)
if self.default_high_order_solvers:
x_t = (
torch.exp(log_alpha_t - log_alpha_s) * x
- sigma_t * phi_1 * model_s
- (0.5 / r1) * sigma_t * phi_1 * (model_s1 - model_s)
)
else:
x_t = (
torch.exp(log_alpha_t - log_alpha_s) * x
- sigma_t * phi_1 * model_s
- (1.0 / r1)
* sigma_t
* ((torch.exp(h) - 1.0) / h - 1.0)
* (model_s1 - model_s)
)
if return_intermediate:
return x_t, {"model_s": model_s, "model_s1": model_s1}
else:
return x_t
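        # For reference: with the intermediate point s1 at lambda_s + r1 * h, the
        # default 'dpm_solver' branch above computes
        #   x_t = (alpha_t / alpha_s) * x_s
        #         - sigma_t * (e^h - 1) * eps_s
        #         - (0.5 / r1) * sigma_t * (e^h - 1) * (eps_s1 - eps_s),
        # a first-order step plus a correction built from the finite difference
        # (eps_s1 - eps_s) of the noise prediction.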
def singlestep_dpm_solver_third_update(
x: Union[torch.Tensor, TensorDict],
s: torch.Tensor,
t: torch.Tensor,
r1: float = 1.0 / 3.0,
r2: float = 2.0 / 3.0,
model_s: Union[torch.Tensor, TensorDict] = None,
model_s1: Union[torch.Tensor, TensorDict] = None,
return_intermediate: bool = False,
):
"""
Overview:
Singlestep solver DPM-Solver-3 from time `s` to time `t`.
Arguments:
x (:obj:`Union[torch.Tensor, TensorDict]`): The initial value at time `s`.
s (:obj:`torch.Tensor`): The starting time, with the shape (x.shape[0],).
t (:obj:`torch.Tensor`): The ending time, with the shape (x.shape[0],).
r1 (:obj:`float`): The hyperparameter of the third-order solver.
r2 (:obj:`float`): The hyperparameter of the third-order solver.
model_s (:obj:`Union[torch.Tensor, TensorDict]`): The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
model_s1 (:obj:`Union[torch.Tensor, TensorDict]`): The model function evaluated at time `s1` (the intermediate time given by `r1`).
If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
return_intermediate (:obj:`bool`): If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
Returns:
x_t (:obj:`Union[torch.Tensor, TensorDict]`): The approximated solution at time `t`.
"""
if r1 is None:
r1 = 1.0 / 3.0
if r2 is None:
r2 = 2.0 / 3.0
lambda_s = diffusion_process.HalfLogSNR(s, x)
lambda_t = diffusion_process.HalfLogSNR(t, x)
h = lambda_t - lambda_s
lambda_s1 = lambda_s + r1 * h
lambda_s2 = lambda_s + r2 * h
s1 = diffusion_process.InverseHalfLogSNR(lambda_s1)[:, 0]
s2 = diffusion_process.InverseHalfLogSNR(lambda_s2)[:, 0]
log_alpha_s = diffusion_process.log_scale(s, x)
log_alpha_s1 = diffusion_process.log_scale(s1, x)
log_alpha_s2 = diffusion_process.log_scale(s2, x)
log_alpha_t = diffusion_process.log_scale(t, x)
sigma_s = diffusion_process.std(s, x)
sigma_s1 = diffusion_process.std(s1, x)
sigma_s2 = diffusion_process.std(s2, x)
sigma_t = diffusion_process.std(t, x)
alpha_s1 = torch.exp(log_alpha_s1)
alpha_s2 = torch.exp(log_alpha_s2)
alpha_t = torch.exp(log_alpha_t)
if self.use_dpm_solver_plus_plus:
phi_11 = torch.expm1(-r1 * h)
phi_12 = torch.expm1(-r2 * h)
phi_1 = torch.expm1(-h)
phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0
phi_2 = phi_1 / h + 1.0
phi_3 = phi_2 / h - 0.5
if model_s is None:
model_s = model_fn(x, s)
if model_s1 is None:
x_s1 = sigma_s1 / sigma_s * x - alpha_s1 * phi_11 * model_s
model_s1 = model_fn(x_s1, s1)
x_s2 = (
sigma_s2 / sigma_s * x
- alpha_s2 * phi_12 * model_s
+ r2 / r1 * alpha_s2 * phi_22 * (model_s1 - model_s)
)
model_s2 = model_fn(x_s2, s2)
if self.default_high_order_solvers:
x_t = (
sigma_t / sigma_s * x
- alpha_t * phi_1 * model_s
+ (1.0 / r2) * alpha_t * phi_2 * (model_s2 - model_s)
)
else:
D1_0 = (1.0 / r1) * (model_s1 - model_s)
D1_1 = (1.0 / r2) * (model_s2 - model_s)
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
x_t = (
sigma_t / sigma_s * x
- alpha_t * phi_1 * model_s
+ alpha_t * phi_2 * D1
- alpha_t * phi_3 * D2
)
else:
phi_11 = torch.expm1(r1 * h)
phi_12 = torch.expm1(r2 * h)
phi_1 = torch.expm1(h)
phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0
phi_2 = phi_1 / h - 1.0
phi_3 = phi_2 / h - 0.5
if model_s is None:
model_s = model_fn(x, s)
if model_s1 is None:
x_s1 = (
torch.exp(log_alpha_s1 - log_alpha_s) * x
- sigma_s1 * phi_11 * model_s
)
model_s1 = model_fn(x_s1, s1)
x_s2 = (
torch.exp(log_alpha_s2 - log_alpha_s) * x
- sigma_s2 * phi_12 * model_s
- r2 / r1 * sigma_s2 * phi_22 * (model_s1 - model_s)
)
model_s2 = model_fn(x_s2, s2)
if self.default_high_order_solvers:
x_t = (
torch.exp(log_alpha_t - log_alpha_s) * x
- sigma_t * phi_1 * model_s
- (1.0 / r2) * sigma_t * phi_2 * (model_s2 - model_s)
)
else:
D1_0 = (1.0 / r1) * (model_s1 - model_s)
D1_1 = (1.0 / r2) * (model_s2 - model_s)
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
x_t = (
torch.exp(log_alpha_t - log_alpha_s) * x
- sigma_t * phi_1 * model_s
- sigma_t * phi_2 * D1
- sigma_t * phi_3 * D2
)
if return_intermediate:
return x_t, {
"model_s": model_s,
"model_s1": model_s1,
"model_s2": model_s2,
}
else:
return x_t
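        # For reference: in the 'taylor' branch above, D1_0 and D1_1 are scaled first
        # differences of the model output at the two intermediate points, D1 and D2
        # approximate its first and second derivatives with respect to lambda at s,
        # and phi_2, phi_3 are the matching exponential-integrator coefficients.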
def multistep_dpm_solver_second_update(
x: Union[torch.Tensor, TensorDict],
model_prev_list: List[Union[torch.Tensor, TensorDict]],
t_prev_list: List[torch.Tensor],
t: torch.Tensor,
):
"""
Overview:
Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
Arguments:
                x (:obj:`Union[torch.Tensor, TensorDict]`): The initial value at time `t_prev_list[-1]`.
model_prev_list (:obj:`List[Union[torch.Tensor, TensorDict]]`): The previous computed model values.
t_prev_list (:obj:`List[torch.Tensor]`): The previous times, each time has the shape (x.shape[0],).
t (:obj:`torch.Tensor`): The ending time, with the shape (x.shape[0],).
Returns:
x_t (:obj:`Union[torch.Tensor, TensorDict]`): The approximated solution at time `t`.
"""
model_prev_1, model_prev_0 = model_prev_list
t_prev_1, t_prev_0 = t_prev_list
lambda_prev_1 = diffusion_process.HalfLogSNR(t=t_prev_1, x=x)
lambda_prev_0 = diffusion_process.HalfLogSNR(t=t_prev_0, x=x)
lambda_t = diffusion_process.HalfLogSNR(t=t, x=x)
log_alpha_prev_0 = diffusion_process.log_scale(t=t_prev_0, x=x)
log_alpha_t = diffusion_process.log_scale(t=t, x=x)
sigma_prev_0 = diffusion_process.std(t=t_prev_0, x=x)
sigma_t = diffusion_process.std(t=t, x=x)
alpha_t = torch.exp(log_alpha_t)
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0 = h_0 / h
D1_0 = 1.0 / r0 * (model_prev_0 - model_prev_1)
if self.use_dpm_solver_plus_plus:
if self.solver_type == "dpm_solver":
x_t = (
sigma_t / sigma_prev_0 * x
- alpha_t * (torch.exp(-h) - 1.0) * model_prev_0
- 0.5 * alpha_t * (torch.exp(-h) - 1.0) * D1_0
)
elif self.solver_type == "taylor":
x_t = (
sigma_t / sigma_prev_0 * x
- alpha_t * (torch.exp(-h) - 1.0) * model_prev_0
+ alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0) * D1_0
)
else:
if self.solver_type == "dpm_solver":
x_t = (
torch.exp(log_alpha_t - log_alpha_prev_0) * x
- sigma_t * (torch.exp(h) - 1.0) * model_prev_0
- 0.5 * sigma_t * (torch.exp(h) - 1.0) * D1_0
)
elif self.solver_type == "taylor":
x_t = (
torch.exp(log_alpha_t - log_alpha_prev_0) * x
- sigma_t * (torch.exp(h) - 1.0) * model_prev_0
- sigma_t * ((torch.exp(h) - 1.0) / h - 1.0) * D1_0
)
return x_t
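        # For reference: h_0 and h are the two most recent half-logSNR increments,
        # r0 = h_0 / h, and D1_0 = (model_prev_0 - model_prev_1) / r0 is the backward
        # difference supplying the second-order correction; the 'taylor' variant uses
        # the coefficient (e^h - 1) / h - 1 in place of 0.5 * (e^h - 1).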
def multistep_dpm_solver_third_update(
x: Union[torch.Tensor, TensorDict],
model_prev_list: List[Union[torch.Tensor, TensorDict]],
t_prev_list: List[torch.Tensor],
t: torch.Tensor,
):
"""
Overview:
Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
Arguments:
                x (:obj:`Union[torch.Tensor, TensorDict]`): The initial value at time `t_prev_list[-1]`.
model_prev_list (:obj:`List[Union[torch.Tensor, TensorDict]]`): The previous computed model values.
t_prev_list (:obj:`List[torch.Tensor]`): The previous times, each time has the shape (x.shape[0],).
t (:obj:`torch.Tensor`): The ending time, with the shape (x.shape[0],).
Returns:
x_t (:obj:`Union[torch.Tensor, TensorDict]`): The approximated solution at time `t`.
"""
model_prev_2, model_prev_1, model_prev_0 = model_prev_list
t_prev_2, t_prev_1, t_prev_0 = t_prev_list
lambda_prev_2 = diffusion_process.HalfLogSNR(t=t_prev_2, x=x)
lambda_prev_1 = diffusion_process.HalfLogSNR(t=t_prev_1, x=x)
lambda_prev_0 = diffusion_process.HalfLogSNR(t=t_prev_0, x=x)
lambda_t = diffusion_process.HalfLogSNR(t=t, x=x)
log_alpha_prev_0 = diffusion_process.log_scale(t=t_prev_0, x=x)
log_alpha_t = diffusion_process.log_scale(t=t, x=x)
sigma_prev_0 = diffusion_process.std(t=t_prev_0, x=x)
sigma_t = diffusion_process.std(t=t, x=x)
alpha_t = torch.exp(log_alpha_t)
h_1 = lambda_prev_1 - lambda_prev_2
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0, r1 = h_0 / h, h_1 / h
D1_0 = 1.0 / r0 * (model_prev_0 - model_prev_1)
D1_1 = 1.0 / r1 * (model_prev_1 - model_prev_2)
D1 = D1_0 + r0 / (r0 + r1) * (D1_0 - D1_1)
D2 = 1.0 / (r0 + r1) * (D1_0 - D1_1)
if self.use_dpm_solver_plus_plus:
x_t = (
sigma_t / sigma_prev_0 * x
- alpha_t * (torch.exp(-h) - 1.0) * model_prev_0
+ alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0) * D1
- alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5) * D2
)
else:
x_t = (
torch.exp(log_alpha_t - log_alpha_prev_0) * x
- sigma_t * (torch.exp(h) - 1.0) * model_prev_0
- sigma_t * ((torch.exp(h) - 1.0) / h - 1.0) * D1
- sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5) * D2
)
return x_t
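        # For reference: D1 and D2 are divided-difference approximations of the first
        # and second lambda-derivatives of the model output, built from the three cached
        # model values, so each multistep update needs only one new model evaluation.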
def singlestep_dpm_solver_update(
x: Union[torch.Tensor, TensorDict],
s: torch.Tensor,
t: torch.Tensor,
            order: int,
return_intermediate: bool = False,
r1: float = None,
r2: float = None,
):
"""
Overview:
Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
Arguments:
x (:obj:`Union[torch.Tensor, TensorDict]`): The initial value at time `s`.
s (:obj:`torch.Tensor`): The starting time, with the shape (x.shape[0],).
t (:obj:`torch.Tensor`): The ending time, with the shape (x.shape[0],).
order (:obj:`int`): The order of DPM-Solver. We only support order == 1 or 2 or 3.
return_intermediate (:obj:`bool`): If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
r1 (:obj:`float`): The hyperparameter of the second-order or third-order solver.
r2 (:obj:`float`): The hyperparameter of the third-order solver.
Returns:
x_t (:obj:`Union[torch.Tensor, TensorDict]`): The approximated solution at time `t`.
"""
if order == 1:
return dpm_solver_first_update(
x, s, t, return_intermediate=return_intermediate
)
elif order == 2:
return singlestep_dpm_solver_second_update(
x, s, t, return_intermediate=return_intermediate, r1=r1
)
elif order == 3:
return singlestep_dpm_solver_third_update(
x, s, t, return_intermediate=return_intermediate, r1=r1, r2=r2
)
else:
raise ValueError(
"Solver order must be 1 or 2 or 3, got {}".format(order)
)
def multistep_dpm_solver_update(
x: Union[torch.Tensor, TensorDict],
model_prev_list: List[Union[torch.Tensor, TensorDict]],
t_prev_list: List[torch.Tensor],
t: torch.Tensor,
order: int,
):
"""
Overview:
Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
Arguments:
                x (:obj:`Union[torch.Tensor, TensorDict]`): The initial value at time `t_prev_list[-1]`.
model_prev_list (:obj:`List[Union[torch.Tensor, TensorDict]]`): The previous computed model values.
t_prev_list (:obj:`List[torch.Tensor]`): The previous times, each time has the shape (x.shape[0],).
t (:obj:`torch.Tensor`): The ending time, with the shape (x.shape[0],).
order (:obj:`int`): The order of DPM-Solver. We only support order == 1 or 2 or 3.
Returns:
x_t (:obj:`Union[torch.Tensor, TensorDict]`): The approximated solution at time `t`.
"""
if order == 1:
return dpm_solver_first_update(
x, t_prev_list[-1], t, model_s=model_prev_list[-1]
)
elif order == 2:
return multistep_dpm_solver_second_update(
x, model_prev_list, t_prev_list, t
)
elif order == 3:
return multistep_dpm_solver_third_update(
x, model_prev_list, t_prev_list, t
)
else:
raise ValueError(
"Solver order must be 1 or 2 or 3, got {}".format(order)
)
def dpm_solver_adaptive(
x: Union[torch.Tensor, TensorDict],
t_T: float,
t_0: float,
h_init: float = 0.05,
theta: float = 0.9,
t_err: float = 1e-5,
save_intermediate: bool = False,
):
"""
Overview:
The adaptive step size solver based on singlestep DPM-Solver.
Arguments:
x (:obj:`Union[torch.Tensor, TensorDict]`): The initial value at time `t_T`.
t_T (:obj:`float`): The starting time of the sampling (default is T).
t_0 (:obj:`float`): The ending time of the sampling (default is epsilon).
h_init (:obj:`float`): The initial step size (for logSNR).
theta (:obj:`float`): The safety hyperparameter for adapting the step size.
t_err (:obj:`float`): The tolerance for the time.
save_intermediate (:obj:`bool`): If true, also return the intermediate values.
Returns:
x_0 (:obj:`Union[torch.Tensor, TensorDict]`): The approximated solution at time `t_0`.
References:
[1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, \
"Gotta go fast when generating data with score-based models," \
arXiv preprint arXiv:2105.14080, 2021.
"""
s = t_T * torch.ones((x.shape[0],)).to(x)
lambda_s = diffusion_process.HalfLogSNR(t=t_T, x=x)
lambda_0 = diffusion_process.HalfLogSNR(t=t_0, x=x)
h = h_init * torch.ones_like(s).to(x)
x_prev = x
x_list = []
if self.order == 2:
r1 = 0.5
lower_update = lambda x, s, t: dpm_solver_first_update(
x, s, t, return_intermediate=True
)
higher_update = (
lambda x, s, t, **kwargs: singlestep_dpm_solver_second_update(
x, s, t, r1=r1, **kwargs
)
)
elif self.order == 3:
r1, r2 = 1.0 / 3.0, 2.0 / 3.0
lower_update = lambda x, s, t: singlestep_dpm_solver_second_update(
x, s, t, r1=r1, return_intermediate=True
)
higher_update = (
lambda x, s, t, **kwargs: singlestep_dpm_solver_third_update(
x, s, t, r1=r1, r2=r2, **kwargs
)
)
else:
raise ValueError(
"For adaptive step size solver, order must be 2 or 3, got {}".format(
self.order
)
)
while torch.abs((s - t_0)).mean() > t_err:
t = diffusion_process.InverseHalfLogSNR(HalfLogSNR=lambda_s + h)[:, 0]
x_lower, lower_noise_kwargs = lower_update(x, s, t)
x_higher = higher_update(x, s, t, **lower_noise_kwargs)
delta = torch.max(
torch.ones_like(x).to(x) * self.atol,
self.rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)),
)
norm_fn = lambda v: torch.sqrt(
torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)
)
E = norm_fn((x_higher - x_lower) / delta).max()
if save_intermediate:
x_list.append(x_higher.clone())
if torch.all(E <= 1.0):
x = x_higher
s = t
x_prev = x_lower
lambda_s = diffusion_process.HalfLogSNR(t=t, x=x)
h = torch.min(
theta * h * torch.float_power(E, -1.0 / self.order).float(),
lambda_0 - lambda_s,
)
self.nfe += self.order
if save_intermediate:
return x, x_list
else:
return x
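        # For reference: each iteration compares a lower-order and a higher-order
        # proposal; the step is accepted when the scaled error
        # E = ||(x_higher - x_lower) / delta|| <= 1, and (accepted or not) the logSNR
        # step size is updated as h <- min(theta * h * E^(-1 / order), lambda_0 - lambda_s),
        # following Jolicoeur-Martineau et al., 2021.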
t_0 = 0.00001
t_T = diffusion_process.t_max
if save_intermediate:
x_list = [x.clone()]
if self.method == "adaptive":
with torch.no_grad():
if save_intermediate:
x, x_list_ = dpm_solver_adaptive(
x, t_T=t_T, t_0=t_0, save_intermediate=True
)
x_list.extend(x_list_)
else:
x = dpm_solver_adaptive(
x, t_T=t_T, t_0=t_0, save_intermediate=False
)
elif self.method == "multistep":
assert steps >= self.order
timesteps = get_time_steps(t_T=t_T, t_0=t_0, N=steps)
assert timesteps.shape[0] - 1 == steps
with torch.no_grad():
vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [model_fn(x, vec_t)]
t_prev_list = [vec_t]
# Init the first `order` values by lower order multistep DPM-Solver.
for init_order in range(1, self.order):
vec_t = timesteps[init_order].expand(x.shape[0])
x = multistep_dpm_solver_update(
x, model_prev_list, t_prev_list, vec_t, init_order
)
if save_intermediate:
x_list.append(x.clone())
model_prev_list.append(model_fn(x, vec_t))
t_prev_list.append(vec_t)
# Compute the remaining values by `order`-th order multistep DPM-Solver.
for step in range(self.order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0])
x = multistep_dpm_solver_update(
x, model_prev_list, t_prev_list, vec_t, self.order
)
if save_intermediate:
x_list.append(x.clone())
for i in range(self.order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = vec_t
# We do not need to evaluate the final model value.
if step < steps:
model_prev_list[-1] = model_fn(x, vec_t)
elif self.method in ["singlestep", "singlestep_fixed"]:
if self.method == "singlestep":
orders = get_orders_for_singlestep_solver(steps=steps, order=self.order)
timesteps = get_time_steps(t_T=t_T, t_0=t_0, N=steps)
elif self.method == "singlestep_fixed":
K = steps // order
orders = [
order,
] * K
timesteps = get_time_steps(t_T=t_T, t_0=t_0, N=(K * order))
with torch.no_grad():
i = 0
for order in orders:
vec_s, vec_t = timesteps[i].expand(x.shape[0]), timesteps[
i + order
].expand(x.shape[0])
h = diffusion_process.HalfLogSNR(
t=timesteps[i + order]
) - diffusion_process.HalfLogSNR(t=timesteps[i])
r1 = (
None
if order <= 1
else (
diffusion_process.HalfLogSNR(t=timesteps[i + 1])
- diffusion_process.HalfLogSNR(t=timesteps[i])
)
/ h
)
r2 = (
None
if order <= 2
else (
diffusion_process.HalfLogSNR(t=timesteps[i + 2])
- diffusion_process.HalfLogSNR(t=timesteps[i])
)
/ h
)
x = singlestep_dpm_solver_update(
x, vec_s, vec_t, order, r1=r1, r2=r2
)
if save_intermediate:
x_list.append(x.clone())
i += order
if self.denoise:
            x = denoise_fn(
                x, torch.ones((x.shape[0],)).to(self.device) * t_0, condition
            )
if save_intermediate:
x_list[-1] = x.clone()
if save_intermediate:
return torch.stack(x_list, dim=0)
else:
return x
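# A minimal usage sketch (illustrative only; `my_diffusion_process`, `my_noise_function`,
# `my_data_prediction_function`, `batch_size` and `dim` are hypothetical placeholders
# that would come from the surrounding grl pipeline):
#
#     solver = DPMSolver(order=2, device="cuda", steps=20, method="multistep")
#     x_T = torch.randn(batch_size, dim, device="cuda")
#     x_0 = solver.integrate(
#         diffusion_process=my_diffusion_process,
#         noise_function=my_noise_function,
#         data_prediction_function=my_data_prediction_function,
#         x=x_T,
#     )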