grl.numerical_methods
ODE
SDE
DPMSolver
- class grl.numerical_methods.DPMSolver(order, device, atol=1e-05, rtol=1e-05, steps=None, type='dpm_solver', method='singlestep', solver_type='dpm_solver', skip_type='time_uniform', denoise=False)
- Overview:
The DPM-Solver for sampling from the diffusion process.
- Interface:
__init__, integrate
- __init__(order, device, atol=1e-05, rtol=1e-05, steps=None, type='dpm_solver', method='singlestep', solver_type='dpm_solver', skip_type='time_uniform', denoise=False)
- Overview:
Initialize the DPM-Solver.
- Parameters:
order (int) – The order of the DPM-Solver, which should be 1, 2, or 3.
device (str) – The device for the computation.
denoise (bool) – Whether to denoise at the final step.
atol (float) – The absolute tolerance for the adaptive solver.
rtol (float) – The relative tolerance for the adaptive solver.
steps (int) – The total number of function evaluations (NFE).
type (str) – The type for the DPM-Solver, which should be ‘dpm_solver’ or ‘dpm_solver++’.
method (str) – The method for the DPM-Solver, which should be ‘singlestep’, ‘multistep’, ‘singlestep_fixed’, or ‘adaptive’.
solver_type (str) – The type for the high-order solvers, which should be ‘dpm_solver’ or ‘taylor’. The type slightly affects performance; ‘dpm_solver’ is recommended.
skip_type (str) – The type for the spacing of the time steps, which should be ‘logSNR’, ‘time_uniform’, or ‘time_quadratic’.
- integrate(diffusion_process, noise_function, data_prediction_function, x, condition=None, steps=None, save_intermediate=False)
- Overview:
Integrate the diffusion process by the DPM-Solver.
- Parameters:
diffusion_process (DiffusionProcess) – The diffusion process.
noise_function (Callable) – The noise prediction model.
data_prediction_function (Callable) – The data prediction model.
x (Union[torch.Tensor, TensorDict]) – The initial value at time t_start.
condition (Union[torch.Tensor, TensorDict]) – The condition for the data prediction model.
steps (int) – The total number of function evaluations (NFE).
save_intermediate (bool) – If true, also return the intermediate model values.
- Returns:
The approximated solution at time t_end.
- Return type:
x_end (torch.Tensor)
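The skip_type option controls how the time grid between t_start and t_end is spaced. The sketch below builds ‘time_uniform’ and ‘time_quadratic’ grids in plain Python to show the difference. The helper name time_grid is hypothetical, and spacing ‘time_quadratic’ uniformly in sqrt(t) is an assumption modeled on the reference DPM-Solver code rather than something confirmed from grl; ‘logSNR’ spacing is omitted because it needs the noise schedule.

```python
def time_grid(skip_type, t_start, t_end, n_steps):
    """Hypothetical helper: return n_steps + 1 time points from t_start to t_end."""
    if n_steps < 1:
        raise ValueError("n_steps must be at least 1")
    fractions = [i / n_steps for i in range(n_steps + 1)]
    if skip_type == "time_uniform":
        # Equal spacing in t.
        return [t_start + (t_end - t_start) * u for u in fractions]
    if skip_type == "time_quadratic":
        # Equal spacing in sqrt(t), then squared (assumed convention).
        a, b = t_start ** 0.5, t_end ** 0.5
        return [(a + (b - a) * u) ** 2 for u in fractions]
    raise ValueError(f"unsupported skip_type: {skip_type}")


# Sampling runs backward in time, from t = 1 toward a small positive t.
uniform = time_grid("time_uniform", 1.0, 0.001, 4)
quadratic = time_grid("time_quadratic", 1.0, 0.001, 4)
```

Both grids share the same endpoints; the quadratic grid places more of its points near small t, where the signal-to-noise ratio changes fastest.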
ODESolver
- class grl.numerical_methods.ODESolver(ode_solver='euler', dt=0.01, atol=1e-05, rtol=1e-05, library='torchdyn', **kwargs)
- Overview:
The ODE solver class.
- Interfaces:
__init__, integrate
- __init__(ode_solver='euler', dt=0.01, atol=1e-05, rtol=1e-05, library='torchdyn', **kwargs)
- Overview:
Initialize the ODE solver using torchdiffeq or torchdyn library.
- Parameters:
ode_solver (str) – The ODE solver to use.
dt (float) – The time step.
atol (float) – The absolute tolerance.
rtol (float) – The relative tolerance.
library (str) – The library to use for the ODE solver. Currently, it supports ‘torchdiffeq’ and ‘torchdyn’.
**kwargs – Additional arguments for the ODE solver.
- integrate(drift, x0, t_span, **kwargs)
- Overview:
Integrate the ODE.
- Parameters:
drift (Union[nn.Module, Callable]) – The drift term of the ODE.
x0 (Union[torch.Tensor, TensorDict]) – The input initial state.
t_span (torch.Tensor) – The times at which to evaluate the ODE. The first element is the initial time, and the last element is the final time. For example, t_span = torch.tensor([0.0, 1.0]).
- Returns:
The output trajectory of the ODE, which has the same data type as x0 and shape (len(t_span), *x0.shape).
- Return type:
trajectory (Union[torch.Tensor, TensorDict])
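To make the roles of dt and t_span concrete, here is a minimal fixed-step Euler sketch in plain Python. It is not the torchdyn or torchdiffeq implementation; the function name euler_integrate and the drift(t, x) calling convention are assumptions for illustration. As in the description above, one state is recorded per entry of t_span.

```python
import math


def euler_integrate(drift, x0, t_span, dt=0.01):
    """Hypothetical fixed-step Euler integrator (not the torchdyn/torchdiffeq code).

    drift(t, x) returns dx/dt; the state is recorded at every time in t_span.
    """
    t, x = t_span[0], x0
    trajectory = [x]
    for t_next in t_span[1:]:
        # Split each interval into sub-steps no longer than dt.
        n_sub = max(1, math.ceil(abs(t_next - t) / dt))
        h = (t_next - t) / n_sub
        for _ in range(n_sub):
            x = x + h * drift(t, x)
            t = t + h
        t = t_next  # snap to the grid point to avoid accumulated round-off
        trajectory.append(x)
    return trajectory


# dx/dt = -x with x(0) = 1 has the exact solution exp(-t).
trajectory = euler_integrate(lambda t, x: -x, 1.0, [0.0, 0.5, 1.0], dt=0.001)
```

The returned list has len(t_span) entries, and its last entry approximates exp(-1) with an error that shrinks linearly as dt decreases.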
SDESolver
- class grl.numerical_methods.SDESolver(sde_solver='euler', sde_noise_type='diagonal', sde_type='ito', dt=0.001, atol=1e-05, rtol=1e-05, library='torchsde', **kwargs)
- __init__(sde_solver='euler', sde_noise_type='diagonal', sde_type='ito', dt=0.001, atol=1e-05, rtol=1e-05, library='torchsde', **kwargs)
- Overview:
Initialize the SDE solver using torchsde library.
- Parameters:
sde_solver (str) – The SDE solver to use.
sde_noise_type (str) – The type of noise of the SDE. It can be ‘diagonal’, ‘general’, ‘scalar’ or ‘additive’.
sde_type (str) – The type of the SDE. It can be ‘ito’ or ‘stratonovich’.
dt (float) – The time step.
atol (float) – The absolute tolerance.
rtol (float) – The relative tolerance.
library (str) – The library to use for the SDE solver. Currently, it supports ‘torchsde’.
**kwargs – Additional arguments for the SDE solver.
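For intuition about what an ‘euler’ solver does with an ‘ito’ SDE, the following is a minimal Euler-Maruyama sketch for a scalar state in plain Python. It is not the torchsde implementation; the function name euler_maruyama and the drift(t, x) / diffusion(t, x) calling convention are assumptions for illustration.

```python
import math
import random


def euler_maruyama(drift, diffusion, x0, t0, t1, dt=0.001, seed=0):
    """Hypothetical Euler-Maruyama loop for a scalar Ito SDE
    dx = drift(t, x) dt + diffusion(t, x) dw. Not the torchsde implementation."""
    rng = random.Random(seed)
    n_steps = max(1, round((t1 - t0) / dt))
    h = (t1 - t0) / n_steps
    t, x = t0, x0
    for _ in range(n_steps):
        dw = rng.gauss(0.0, math.sqrt(h))  # Wiener increment with variance h
        x = x + drift(t, x) * h + diffusion(t, x) * dw
        t += h
    return x


# With zero noise the scheme reduces to Euler for dx/dt = -x.
x_det = euler_maruyama(lambda t, x: -x, lambda t, x: 0.0, 1.0, 0.0, 1.0)
# One Ornstein-Uhlenbeck sample path: dx = -x dt + 0.5 dw.
x_ou = euler_maruyama(lambda t, x: -x, lambda t, x: 0.5, 1.0, 0.0, 1.0)
```

Fixing the seed makes the sample path reproducible, which is convenient when comparing step sizes.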
GaussianConditionalProbabilityPath
- class grl.numerical_methods.GaussianConditionalProbabilityPath(config)
- Overview:
Gaussian conditional probability path.
General case:
\[p(x(t)|x(0))=\mathcal{N}(x(t);\mu(t,x(0)),\sigma^2(t,x(0))I)\]
If written in the form of SDE:
\[\mathrm{d}x=f(x,t)\mathrm{d}t+g(t)\mathrm{d}w_{t}\]
where \(f(x,t)\) is the drift term, \(g(t)\) is the diffusion term, and \(w_{t}\) is the Wiener process.
For diffusion model:
\[p(x(t)|x(0))=\mathcal{N}(x(t);s(t)x(0),\sigma^2(t)I)\]
or,
\[p(x(t)|x(0))=\mathcal{N}(x(t);s(t)x(0),s^2(t)e^{-2\lambda(t)}I)\]
If written in the form of SDE:
\[\mathrm{d}x=\frac{s'(t)}{s(t)}x(t)\mathrm{d}t+s(t)\sqrt{-2\lambda'(t)}e^{-\lambda(t)}\mathrm{d}w_{t}\]
or,
\[\mathrm{d}x=f(t)x(t)\mathrm{d}t+g(t)\mathrm{d}w_{t}\]
where \(s(t)\) is the scale factor, \(\sigma^2(t)I\) is the covariance matrix, \(\sigma(t)\) is the standard deviation with the scale factor, \(e^{-2\lambda(t)}I\) is the covariance matrix without the scale factor, and \(\lambda(t)\) is the half-logSNR, the difference between the log scale factor and the log standard deviation: \(\lambda(t)=\log(s(t))-\log(\sigma(t))\).
For VP SDE:
\[p(x(t)|x(0))=\mathcal{N}(x(t);x(0)e^{-\frac{1}{2}\int_{0}^{t}{\beta(s)\mathrm{d}s}},(1-e^{-\int_{0}^{t}{\beta(s)\mathrm{d}s}})I)\]
For Linear VP SDE:
\[p(x(t)|x(0))=\mathcal{N}(x(t);x(0)e^{-\frac{\beta(1)-\beta(0)}{4}t^2-\frac{\beta(0)}{2}t},(1-e^{-\frac{\beta(1)-\beta(0)}{2}t^2-\beta(0)t})I)\]
#TODO: add more details for Cosine VP SDE; General VE SDE; OPT-Flow;
- HalfLogSNR(t)
- Overview:
Compute the half-logSNR of the Gaussian conditional probability path, which is
\[\log(s(t))-\log(\sigma(t))\]
- Parameters:
t (torch.Tensor) – The input time.
- Returns:
The half-logSNR.
- Return type:
HalfLogSNR (torch.Tensor)
- InverseHalfLogSNR(HalfLogSNR)
- Overview:
Compute the inverse function of the half-logSNR of the Gaussian conditional probability path. Since the half-logSNR is an invertible function, we can compute the time t from the half-logSNR. For linear VP SDE, the inverse function is
\[t(\lambda)=\frac{1}{\beta_1-\beta_0}(\sqrt{\beta_0^2+2(\beta_1-\beta_0)\log{(e^{-2\lambda}+1)}}-\beta_0)\]
or,
\[t(\lambda)=\frac{2\log{(e^{-2\lambda}+1)}}{\sqrt{\beta_0^2+2(\beta_1-\beta_0)\log{(e^{-2\lambda}+1)}}+\beta_0}\]
- Parameters:
HalfLogSNR (torch.Tensor) – The input half-logSNR.
- Returns:
The time.
- Return type:
t (torch.Tensor)
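A round-trip check is a simple way to see the inverse in action for the linear VP SDE. The sketch below is plain Python, not the grl code: the function names and the values of \(\beta_0\) and \(\beta_1\) are assumptions. It uses the rationalized form of the first closed-form expression, obtained by multiplying numerator and denominator by \(\sqrt{\cdot}+\beta_0\), which avoids subtracting two nearly equal numbers when t is close to 0.

```python
import math

BETA_0, BETA_1 = 0.1, 20.0  # assumed example values


def half_log_snr(t):
    """lambda(t) for the linear VP SDE, with variance 1 - s(t)^2."""
    log_s = -(BETA_1 - BETA_0) * t**2 / 4.0 - BETA_0 * t / 2.0
    log_sigma = 0.5 * math.log(-math.expm1(2.0 * log_s))
    return log_s - log_sigma


def inverse_half_log_snr(lam):
    """t(lambda) in the rationalized form, which avoids cancellation near t = 0."""
    big_l = math.log1p(math.exp(-2.0 * lam))  # log(e^{-2 lambda} + 1)
    delta = BETA_0**2 + 2.0 * (BETA_1 - BETA_0) * big_l
    return 2.0 * big_l / (math.sqrt(delta) + BETA_0)


times = (0.01, 0.3, 1.0)
round_trip = [inverse_half_log_snr(half_log_snr(t)) for t in times]
```

Each entry of round_trip should reproduce the corresponding input time up to floating-point error.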
- __init__(config)
- Overview:
Initialize the Gaussian conditional probability path.
- Parameters:
config (EasyDict) – The configuration of the Gaussian conditional probability path.
- covariance(t)
- Overview:
Compute the covariance matrix of the Gaussian conditional probability path, which is
\[\Sigma(t)\]
or
\[\sigma^2(t)I\]
or
\[s^2(t)e^{-2\lambda(t)}I\]
- Parameters:
t (torch.Tensor) – The input time.
- Returns:
The covariance matrix.
- Return type:
covariance (torch.Tensor)
- d_covariance_dt(t)
- Overview:
Compute the time derivative of the covariance matrix of the Gaussian conditional probability path, which is
\[\frac{\mathrm{d}\Sigma(t)}{\mathrm{d}t}\]
- Parameters:
t (torch.Tensor) – The input time.
- Returns:
The time derivative of the covariance matrix.
- Return type:
d_covariance_dt (torch.Tensor)
- d_log_scale_dt(t)
- Overview:
Compute the time derivative of the log scale factor of the Gaussian conditional probability path, which is
\[\frac{\mathrm{d}\log(s(t))}{\mathrm{d}t}=\frac{s'(t)}{s(t)}\]
- Parameters:
t (torch.Tensor) – The input time.
- Returns:
The time derivative of the log scale factor.
- Return type:
d_log_scale_dt (Union[torch.Tensor, TensorDict])
- d_scale_dt(t)
- Overview:
Compute the time derivative of the scale factor of the Gaussian conditional probability path, which is
\[s'(t)\]
- Parameters:
t (torch.Tensor) – The input time.
- Returns:
The time derivative of the scale factor.
- Return type:
d_scale_dt (Union[torch.Tensor, TensorDict])
- d_std_dt(t)
- Overview:
Compute the time derivative of the standard deviation of the Gaussian conditional probability path, which is
\[\frac{\mathrm{d}\sigma(t)}{\mathrm{d}t}\]
- Parameters:
t (torch.Tensor) – The input time.
- Return type:
Tensor
- diffusion(t)
- Overview:
Return the diffusion term of the Gaussian conditional probability path. The diffusion term is given by the following:
\[g(x,t)\]
- Parameters:
t (torch.Tensor) – The input time.
x (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
The output diffusion term.
- Return type:
diffusion (Union[torch.Tensor, TensorDict])
- diffusion_squared(t)
- Overview:
Return the squared diffusion term of the Gaussian conditional probability path. The squared diffusion term is given by the following:
\[g^2(x,t)\]
- Parameters:
t (torch.Tensor) – The input time.
x (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
The output squared diffusion term.
- Return type:
diffusion (Union[torch.Tensor, TensorDict])
- drift(t, x=None)
- Overview:
Return the drift term of the Gaussian conditional probability path. The drift term is given by the following:
\[f(x,t)\]
- Parameters:
t (torch.Tensor) – The input time.
x (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
The output drift term.
- Return type:
drift (Union[torch.Tensor, TensorDict])
- drift_coefficient(t)
- Overview:
Return the drift coefficient of the Gaussian conditional probability path. The drift coefficient is given by the following:
\[f(t)\]
which satisfies the following SDE:
\[\mathrm{d}x=f(t)x(t)\mathrm{d}t+g(t)\mathrm{d}w_{t}\]
- Parameters:
t (torch.Tensor) – The input time.
- Returns:
The output drift coefficient.
- Return type:
drift (Union[torch.Tensor, TensorDict])
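For a linear SDE of this form, the drift coefficient and squared diffusion term are tied to the path by \(f(t)=\frac{\mathrm{d}\log s(t)}{\mathrm{d}t}\) and \(g^2(t)=\frac{\mathrm{d}\sigma^2(t)}{\mathrm{d}t}-2f(t)\sigma^2(t)\). The sketch below checks both relations by finite differences for the linear VP SDE, where \(f(t)=-\beta(t)/2\) and \(g^2(t)=\beta(t)\). It is plain Python with assumed function names and assumed \(\beta\) endpoint values, not the grl implementation.

```python
import math

BETA_0, BETA_1 = 0.1, 20.0  # assumed example values


def beta(t):
    return BETA_0 + (BETA_1 - BETA_0) * t


def log_scale(t):
    return -(BETA_1 - BETA_0) * t**2 / 4.0 - BETA_0 * t / 2.0


def variance(t):
    return -math.expm1(2.0 * log_scale(t))  # sigma^2(t) = 1 - s(t)^2


def drift_coefficient(t):
    return -0.5 * beta(t)  # f(t) = d log s(t) / dt


def diffusion_squared(t):
    return beta(t)  # g^2(t)


def central_diff(fn, t, h=1e-5):
    """Second-order finite-difference estimate of fn'(t)."""
    return (fn(t + h) - fn(t - h)) / (2.0 * h)


t = 0.4
f_numeric = central_diff(log_scale, t)
g2_numeric = central_diff(variance, t) - 2.0 * drift_coefficient(t) * variance(t)
```

Both numerical estimates should agree with the closed-form values up to finite-difference error.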
- log_scale(t)
- Overview:
Compute the log scale factor of the Gaussian conditional probability path, which is
\[\log(s(t))\]
- Parameters:
t (torch.Tensor) – The input time.
- Returns:
The log scale factor.
- Return type:
log_scale (torch.Tensor)