grl.generative_models
DiffusionModel
- class grl.generative_models.DiffusionModel(config)
- Overview:
A general diffusion model class that supports various types of continuous-time diffusion paths. It supports sampling and computing the score function and velocity function. The model can be parameterized by a score function, noise function, velocity function, or data prediction function. It can be trained with either a score matching loss or a flow matching loss.
- Interfaces:
__init__, sample, score_function, score_matching_loss, velocity_function, flow_matching_loss.
- __init__(config)
- Overview:
Initialization of Diffusion Model.
- Parameters:
config (EasyDict) – The configuration.
- data_prediction_function(t, x, condition=None)
- Overview:
Return the data prediction function of the model at time t, given the state x at time t.
\[\frac{x_t + \sigma^2(t) \nabla_{x_t} \log p_{\theta}(x_t)}{s(t)}\]
- Parameters:
t (torch.Tensor) – The input time.
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
- Return type:
Union[Tensor, TensorDict, Tensor]
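The sketch below shows how the score, noise, and data prediction functions relate. It assumes one-dimensional Gaussian data x_0 ~ N(m, v) and the Gaussian path x_t = s(t) x_0 + σ(t) ε with ε ~ N(0, 1). Under these assumptions all three quantities have closed forms. The helper names are hypothetical and are not GRL functions. The sketch computes the data prediction as (x_t + σ²(t) ∇ log p(x_t)) / s(t) and checks it against the exact posterior mean E[x_0 | x_t] (Tweedie's formula).

```python
def gaussian_path_quantities(x_t: float, m: float, v: float, s: float, sigma: float):
    """Closed-form score, noise, and data prediction for x_0 ~ N(m, v), x_t = s*x_0 + sigma*eps.

    Hypothetical helper for illustration; not part of GRL.
    """
    var_t = s * s * v + sigma * sigma               # marginal variance of x_t
    score = -(x_t - s * m) / var_t                  # grad_x log p_t(x_t)
    noise = -sigma * score                          # noise function: -sigma(t) * score
    data_pred = (x_t + sigma * sigma * score) / s   # data prediction from the score
    return score, noise, data_pred


def posterior_mean_x0(x_t: float, m: float, v: float, s: float, sigma: float) -> float:
    """Exact E[x_0 | x_t] obtained by conditioning a joint Gaussian."""
    var_t = s * s * v + sigma * sigma
    return m + (s * v / var_t) * (x_t - s * m)


m, v = 1.5, 0.4        # assumed data mean and variance
s, sigma = 0.8, 0.6    # assumed values of s(t) and sigma(t) at some fixed t
x_t = 2.0
score, noise, data_pred = gaussian_path_quantities(x_t, m, v, s, sigma)
print(data_pred, posterior_mean_x0(x_t, m, v, s, sigma))
# The same data prediction can be recovered from the noise prediction.
print((x_t - sigma * noise) / s)
```

Because the noise function is just a rescaled score, any one of the three parameterizations determines the other two once s(t) and σ(t) are known.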
- dpo_loss(ref_dm, data, beta)
- Return type:
Tensor
- Overview:
The loss function for training the diffusion process with Direct Preference Optimization (DPO). This is an in-development feature and is not recommended for general use.
- flow_matching_loss(x, condition=None, average=True)
- Overview:
Return the flow matching loss of the model given the input state and the condition.
- Parameters:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
average (bool) – Whether to average the loss across the batch.
- Return type:
Tensor
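The following is a minimal sketch of a flow matching loss. It assumes a rectified-flow style path x_t = (1 - t) x_0 + t ε, so s(t) = 1 - t and σ(t) = t, with data at t = 0 and noise at t = 1. Under this path the conditional target velocity s'(t) x_0 + σ'(t) ε reduces to ε - x_0. Whether GRL uses this particular path depends on its configuration. The model can be any callable model(t, x). The `average` flag mirrors the documented parameter: it returns either the batch mean or the per-sample losses. Scalar states and explicit t and ε values keep the result easy to check by hand.

```python
from typing import Callable, List, Sequence, Union


def flow_matching_loss(
    model: Callable[[float, float], float],
    x0: Sequence[float],
    eps: Sequence[float],
    t: Sequence[float],
    average: bool = True,
) -> Union[float, List[float]]:
    """Flow matching loss for the assumed path x_t = (1 - t) * x0 + t * eps (hypothetical sketch)."""
    losses = []
    for x0_i, eps_i, t_i in zip(x0, eps, t):
        x_t = (1.0 - t_i) * x0_i + t_i * eps_i  # point on the path at time t_i
        target = eps_i - x0_i                   # d x_t / d t along this path
        losses.append((model(t_i, x_t) - target) ** 2)
    return sum(losses) / len(losses) if average else losses


def zero_model(t: float, x: float) -> float:
    """A trivial stand-in for a velocity network that always predicts 0."""
    return 0.0


print(flow_matching_loss(zero_model, [1.0, 2.0], [0.5, -1.0], [0.5, 0.5]))
print(flow_matching_loss(zero_model, [1.0, 2.0], [0.5, -1.0], [0.5, 0.5], average=False))
```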
- forward_sample(x, t_span, condition=None, with_grad=False, solver_config=None)
- Overview:
Run the forward path of the diffusion model on a sampled x. Note that this is not the reverse process, and thus it is not designed for sampling from the diffusion model. Rather, it is used to encode a sampled x into the latent space.
- Parameters:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.
t_span (torch.Tensor) – The time span.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
with_grad (bool) – Whether to return the gradient.
solver_config (EasyDict) – The configuration of the solver.
- forward_sample_process(x, t_span, condition=None, with_grad=False, solver_config=None)
- Overview:
Run the forward path of the diffusion model on a sampled x. Note that this is not the reverse process, and thus it is not designed for sampling from the diffusion model. Rather, it is used to encode a sampled x into the latent space. It returns all intermediate states.
- Parameters:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.
t_span (torch.Tensor) – The time span.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
with_grad (bool) – Whether to return the gradient.
solver_config (EasyDict) – The configuration of the solver.
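The forward (encoding) direction can be sketched by integrating the probability-flow ODE dx/dt = v(t, x) from t = 0 (data) to t = 1 (latent) with explicit Euler steps over t_span. The sketch keeps every intermediate state, mirroring the documented behavior of forward_sample_process, and its last state plays the role of forward_sample's output.

It relies on several assumptions:
- The data are one-dimensional Gaussian, N(M, VAR).
- The path is x_t = (1 - t) x_0 + t ε, whose exact marginal velocity field is affine in x.
- The helper names are hypothetical, and GRL selects its own solver through solver_config.

In this Gaussian case the exact encoder maps x_0 to (x_0 - M) / √VAR, which the Euler result approximates.

```python
import math
from typing import List

M, VAR = 1.0, 4.0  # assumed data distribution N(M, VAR)


def velocity(t: float, x: float) -> float:
    """Exact marginal velocity for x_t = (1 - t) * x_0 + t * eps with x_0 ~ N(M, VAR)."""
    mean_t = (1.0 - t) * M
    var_t = (1.0 - t) ** 2 * VAR + t ** 2
    return -M + (t - (1.0 - t) * VAR) / var_t * (x - mean_t)


def forward_sample_process(x: float, t_span: List[float]) -> List[float]:
    """Euler-integrate dx/dt = velocity(t, x) along t_span and return every state."""
    states = [x]
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        x = x + (t1 - t0) * velocity(t0, x)
        states.append(x)
    return states


n_steps = 1000
t_span = [i / n_steps for i in range(n_steps + 1)]
trajectory = forward_sample_process(3.0, t_span)
latent = trajectory[-1]  # what a forward_sample-style call would return
print(len(trajectory), latent, (3.0 - M) / math.sqrt(VAR))
```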
- log_prob(x=None, condition=None, using_Hutchinson_trace_estimator=True, with_grad=False)
- Overview:
Return the log probability of the input state under the model, given the condition.
\[\log p_{\theta}(x)\]
- Parameters:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
with_grad (bool) – Whether to return the gradient.
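For a continuous-time model, log p_θ(x) is typically computed with the instantaneous change-of-variables formula, whose integrand is the divergence of the velocity field. The using_Hutchinson_trace_estimator flag suggests that this divergence is estimated stochastically. The sketch below shows the Hutchinson estimator div v(x) ≈ E[εᵀ J_v(x) ε] with Rademacher probes ε. It approximates the Jacobian-vector product with central finite differences, whereas a real implementation would more likely use autograd. The velocity field here is a hypothetical linear map v(x) = A x, whose exact divergence is trace(A).

```python
import random
from typing import Callable, List

Vector = List[float]


def hutchinson_divergence(
    v: Callable[[Vector], Vector],
    x: Vector,
    n_samples: int,
    rng: random.Random,
    h: float = 1e-4,
) -> float:
    """Estimate div v(x) = trace(J_v(x)) by averaging eps^T J eps over Rademacher probes."""
    total = 0.0
    for _ in range(n_samples):
        eps = [rng.choice((-1.0, 1.0)) for _ in x]
        plus = v([xi + h * ei for xi, ei in zip(x, eps)])
        minus = v([xi - h * ei for xi, ei in zip(x, eps)])
        jvp = [(p - q) / (2.0 * h) for p, q in zip(plus, minus)]  # approx. J_v(x) @ eps
        total += sum(e * j for e, j in zip(eps, jvp))               # eps^T J_v(x) eps
    return total / n_samples


A = [[1.0, 0.1, 0.0], [0.1, -0.5, 0.1], [0.0, 0.1, 2.0]]  # hypothetical linear field


def linear_field(x: Vector) -> Vector:
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in A]


estimate = hutchinson_divergence(linear_field, [0.3, -1.2, 0.7], n_samples=2000, rng=random.Random(0))
exact = sum(A[i][i] for i in range(3))
print(estimate, exact)
```

The estimator is unbiased, and its variance comes only from the off-diagonal Jacobian entries. This is why it trades exactness for a cost of one Jacobian-vector product per probe instead of a full Jacobian.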
- noise_function(t, x, condition=None)
- Overview:
Return the noise function of the model at time t, given the state x at time t.
\[- \sigma(t) \nabla_{x_t} \log p_{\theta}(x_t)\]
- Parameters:
t (torch.Tensor) – The input time.
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
- Return type:
Union[Tensor, TensorDict, Tensor]
- sample(t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)
- Overview:
Sample from the diffusion model and return the final state.
- Parameters:
t_span (torch.Tensor) – The time span.
batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
with_grad (bool) – Whether to return the gradient.
solver_config (EasyDict) – The configuration of the solver.
- Returns:
The sampled result.
- Return type:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- Shapes:
t_span: \((T)\), where \(T\) is the number of time steps.
batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).
x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).
x: \((N, D)\). If an extra batch size \(B\) is provided, the shape will be \((B, N, D)\). If x_0 is not provided, the shape will be \((B, D)\). If neither x_0 nor condition is provided, the shape will be \((D)\).
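Sampling runs the generative process backward, starting from a Gaussian at the end of the time span and ending at t = 0. The sketch below uses Euler-Maruyama on the reverse-time SDE dx = -g²(t) ∇ log p_t(x) dt + g(t) dw̄ as one illustration. GRL's actual solver is chosen through solver_config and is not reproduced here.

The sketch makes the following assumptions:
- The path is variance-exploding: x_t = x_0 + t ε, so g²(t) = dσ²(t)/dt = 2t.
- The data are one-dimensional Gaussian, N(M, VAR), so the exact score -(x - M) / (VAR + t²) stands in for a trained score network.
- The starting distribution at T is the exact marginal N(M, VAR + T²).

```python
import math
import random
from typing import List

M, VAR, T = 2.0, 0.25, 1.0  # assumed data distribution N(M, VAR) and final time T


def score(t: float, x: float) -> float:
    """Exact score of p_t = N(M, VAR + t^2) for the path x_t = x_0 + t * eps."""
    return -(x - M) / (VAR + t * t)


def sample(batch_size: int, n_steps: int, rng: random.Random) -> List[float]:
    """Euler-Maruyama on the reverse-time SDE, stepping from t = T down to t = 0."""
    dt = T / n_steps
    xs = [rng.gauss(M, math.sqrt(VAR + T * T)) for _ in range(batch_size)]  # start at time T
    for k in range(n_steps):
        t = T - k * dt
        g2 = 2.0 * t  # g(t)^2 = d sigma(t)^2 / dt with sigma(t) = t
        xs = [x + g2 * score(t, x) * dt + math.sqrt(g2 * dt) * rng.gauss(0.0, 1.0) for x in xs]
    return xs


samples = sample(batch_size=3000, n_steps=200, rng=random.Random(0))
mean = sum(samples) / len(samples)
var = sum((x - mean) ** 2 for x in samples) / (len(samples) - 1)
print(mean, var)  # expected to be close to M = 2.0 and VAR = 0.25
```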
- sample_forward_process(t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)
- Overview:
Sample from the diffusion model and return all intermediate states.
- Parameters:
t_span (torch.Tensor) – The time span.
batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – An extra batch size used for repeated sampling with the same initial state.
x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
with_grad (bool) – Whether to return the gradient.
solver_config (EasyDict) – The configuration of the solver.
- Returns:
The sampled result.
- Return type:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- Shapes:
t_span: \((T)\), where \(T\) is the number of time steps.
batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).
x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).
x: \((T, N, D)\). If an extra batch size \(B\) is provided, the shape will be \((T, B, N, D)\). If x_0 is not provided, the shape will be \((T, B, D)\). If neither x_0 nor condition is provided, the shape will be \((T, D)\).
- sample_forward_process_with_fixed_x(fixed_x, fixed_mask, t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)
- Overview:
Sample from the diffusion model with part of x held fixed (given by fixed_x and fixed_mask), and return all intermediate states.
- Parameters:
fixed_x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The values of x to hold fixed.
fixed_mask (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The mask indicating which elements of x are held fixed.
t_span (torch.Tensor) – The time span.
batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
with_grad (bool) – Whether to return the gradient.
solver_config (EasyDict) – The configuration of the solver.
- Returns:
The sampled result.
- Return type:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- Shapes:
fixed_x: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
fixed_mask: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the mask, which could be a scalar or a tensor such as \((D1, D2)\).
t_span: \((T)\), where \(T\) is the number of time steps.
batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).
x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).
x: \((T, N, D)\). If an extra batch size \(B\) is provided, the shape will be \((T, B, N, D)\). If x_0 is not provided, the shape will be \((T, B, D)\). If neither x_0 nor condition is provided, the shape will be \((T, D)\).
- sample_with_fixed_x(fixed_x, fixed_mask, t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)
- Overview:
Sample from the diffusion model with part of x held fixed (given by fixed_x and fixed_mask), and return the final state.
- Parameters:
fixed_x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The values of x to hold fixed.
fixed_mask (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The mask indicating which elements of x are held fixed.
t_span (torch.Tensor) – The time span.
batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
with_grad (bool) – Whether to return the gradient.
solver_config (EasyDict) – The configuration of the solver.
- Returns:
The sampled result.
- Return type:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- Shapes:
fixed_x: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
fixed_mask: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the mask, which could be a scalar or a tensor such as \((D1, D2)\).
t_span: \((T)\), where \(T\) is the number of time steps.
batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).
x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).
x: \((N, D)\). If an extra batch size \(B\) is provided, the shape will be \((B, N, D)\). If x_0 is not provided, the shape will be \((B, D)\). If neither x_0 nor condition is provided, the shape will be \((D,)\).
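A common way to sample with part of x held fixed is a replacement scheme: after every solver step, the fixed entries are overwritten so that they never drift. We have not confirmed that GRL uses this exact scheme. Replacement methods also differ, for example, in whether the fixed values are first noised to the current time. The sketch below makes these assumptions:
- fixed_mask is hypothetically True wherever an entry should be held at fixed_x.
- A toy velocity field and explicit Euler steps drive the free entries.
- The fixed values are simply re-imposed after each step.

```python
from typing import Callable, List


def sample_with_fixed_x(
    velocity: Callable[[float, List[float]], List[float]],
    x_init: List[float],
    fixed_x: List[float],
    fixed_mask: List[bool],
    t_span: List[float],
) -> List[float]:
    """Euler integration along t_span that keeps masked entries equal to fixed_x."""

    def impose(x: List[float]) -> List[float]:
        # Replace every masked entry with its fixed value; leave the rest unchanged.
        return [f if m else xi for xi, f, m in zip(x, fixed_x, fixed_mask)]

    x = impose(x_init)
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        dx = velocity(t0, x)
        x = impose([xi + (t1 - t0) * di for xi, di in zip(x, dx)])
    return x


def toy_velocity(t: float, x: List[float]) -> List[float]:
    """Hypothetical field that pulls every coordinate toward the mean of x."""
    center = sum(x) / len(x)
    return [center - xi for xi in x]


t_span = [i / 100 for i in range(101)]  # integration direction is arbitrary for this toy field
result = sample_with_fixed_x(toy_velocity, [0.0, 5.0, -3.0], [2.0, 0.0, 0.0], [True, False, False], t_span)
print(result)
```

The free coordinates are still influenced by the fixed one through the velocity field, which is what makes this useful for inpainting-style conditional generation.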
- sample_with_log_prob(t_span, batch_size=None, condition=None, with_grad=False, solver_config=None)
- Overview:
Sample from the model and return the log probability of the sampled result.
- Parameters:
t_span (torch.Tensor) – The time span.
batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
with_grad (bool) – Whether to return the gradient.
solver_config (EasyDict) – The configuration of the solver.
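Sampling and log-probability can be computed together by integrating an augmented ODE. Along a trajectory of dx/dt = v(t, x), the log-density obeys d log p_t(x_t)/dt = -div v(t, x_t), which is the instantaneous change of variables. The sketch below does the following:
- It starts from the standard Gaussian at t = 1.
- It integrates backward to t = 0 with Euler steps.
- It accumulates the log-density along the way.

It assumes one-dimensional Gaussian data N(M, VAR) and the path x_t = (1 - t) x_0 + t ε. The exact velocity for this path is affine in x, so its divergence is simply the slope. These choices are illustrative and do not describe GRL's implementation.

```python
import math
import random

M, VAR = 1.0, 4.0  # assumed data distribution N(M, VAR)


def velocity_and_divergence(t: float, x: float):
    """Exact velocity for x_t = (1 - t) * x_0 + t * eps, together with d v / d x."""
    var_t = (1.0 - t) ** 2 * VAR + t ** 2
    slope = (t - (1.0 - t) * VAR) / var_t
    return -M + slope * (x - (1.0 - t) * M), slope


def normal_log_pdf(x: float, mean: float, var: float) -> float:
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def sample_with_log_prob(rng: random.Random, n_steps: int = 2000):
    x = rng.gauss(0.0, 1.0)               # latent sample at t = 1
    log_p = normal_log_pdf(x, 0.0, 1.0)   # its log-density under the prior
    dt = 1.0 / n_steps
    for k in range(n_steps):
        t = 1.0 - k * dt
        v, div = velocity_and_divergence(t, x)
        x -= v * dt                        # step backward in time
        log_p += div * dt                  # d log p / dt = -div v, integrated backward
    return x, log_p


x, log_p = sample_with_log_prob(random.Random(0))
print(log_p, normal_log_pdf(x, M, VAR))
```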
- score_function(t, x, condition=None)
- Overview:
Return the score function of the model at time t, given the state x at time t. This is the gradient of the log-density with respect to \(x_t\).
\[\nabla_{x_t} \log p_{\theta}(x_t)\]
- Parameters:
t (torch.Tensor) – The input time.
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
- Return type:
Union[Tensor, TensorDict, Tensor]
- score_matching_loss(x, condition=None, weighting_scheme=None, average=True)
- Overview:
The score matching loss function for training the diffusion model.
- Parameters:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
weighting_scheme (str) – The weighting scheme for the score matching loss, which can be “maximum_likelihood” or “vanilla”.
- Note:
“maximum_likelihood”: The weighting scheme is based on maximum likelihood estimation. Refer to the paper “Maximum Likelihood Training of Score-Based Diffusion Models” for more details. The weight \(\lambda(t)\) is:
\[\lambda(t) = g^2(t)\]
For numerical stability, Monte Carlo sampling is used to approximate the integral of \(\lambda(t)\), with
\[\lambda(t) = g^2(t) = p(t)\sigma^2(t)\]
“vanilla”: The weighting scheme is based on vanilla score matching, which balances the MSE loss by scaling the model output to the noise value. Refer to the paper “Score-Based Generative Modeling through Stochastic Differential Equations” for more details. The weight \(\lambda(t)\) is:
\[\lambda(t) = \sigma^2(t)\]
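Assume the Gaussian path \(x_t = s(t) x_0 + \sigma(t)\epsilon\) with \(\epsilon \sim \mathcal{N}(0, I)\). Under this path the conditional score is \(\nabla_{x_t} \log p(x_t \mid x_0) = -\epsilon / \sigma(t)\). The short derivation below shows why the “vanilla” weighting \(\lambda(t) = \sigma^2(t)\) turns denoising score matching into a plain mean-squared error between the true noise and the noise function \(-\sigma(t) \nabla_{x_t} \log p_{\theta}(x_t)\):

\[
\lambda(t) \left\| \nabla_{x_t} \log p_{\theta}(x_t) - \nabla_{x_t} \log p(x_t \mid x_0) \right\|^2
= \sigma^2(t) \left\| \nabla_{x_t} \log p_{\theta}(x_t) + \frac{\epsilon}{\sigma(t)} \right\|^2
= \left\| \epsilon - \left( -\sigma(t) \nabla_{x_t} \log p_{\theta}(x_t) \right) \right\|^2
\]

With this weighting, the loss at every t is on the same scale as the unit-variance noise \(\epsilon\), so no time step dominates merely because \(\sigma(t)\) is small.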
- Return type:
Tensor
- velocity_function(t, x, condition=None)
- Overview:
Return the velocity of the model at time t, given the state x at time t.
\[v_{\theta}(t, x)\]
- Parameters:
t (torch.Tensor) – The input time.
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state at time t.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
- Return type:
Union[Tensor, TensorDict, Tensor]
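The velocity and score functions describe the same Gaussian path in two parameterizations. Assume x_t = s(t) x_0 + σ(t) ε. Two posterior identities follow:
- E[x_0 | x_t] = (x_t + σ²(t) ∇ log p_t(x_t)) / s(t)
- E[ε | x_t] = -σ(t) ∇ log p_t(x_t)

Combining them gives v(t, x) = (s'(t)/s(t)) x + (s'(t) σ²(t)/s(t) - σ(t) σ'(t)) ∇ log p_t(x).

The sketch checks this identity numerically against the velocity computed directly from the affine transport of Gaussian marginals. It assumes one-dimensional Gaussian data N(M, VAR), s(t) = 1 - t, and σ(t) = t. Whether GRL converts between the two parameterizations in exactly this form is an assumption.

```python
M, VAR = 1.0, 4.0  # assumed data distribution N(M, VAR)


def s(t: float) -> float:
    return 1.0 - t


def ds(t: float) -> float:
    return -1.0


def sigma(t: float) -> float:
    return t


def dsigma(t: float) -> float:
    return 1.0


def score(t: float, x: float) -> float:
    """Exact score of p_t = N(s(t) M, s(t)^2 VAR + sigma(t)^2)."""
    return -(x - s(t) * M) / (s(t) ** 2 * VAR + sigma(t) ** 2)


def velocity_from_score(t: float, x: float) -> float:
    """v = (s'/s) x + (s' sigma^2 / s - sigma sigma') * score; requires s(t) != 0, i.e. t < 1."""
    coef = ds(t) * sigma(t) ** 2 / s(t) - sigma(t) * dsigma(t)
    return ds(t) / s(t) * x + coef * score(t, x)


def velocity_direct(t: float, x: float) -> float:
    """Velocity of the affine map x_t = mean_t + std_t * z between Gaussian marginals."""
    var_t = s(t) ** 2 * VAR + sigma(t) ** 2
    dvar_t = 2.0 * s(t) * ds(t) * VAR + 2.0 * sigma(t) * dsigma(t)
    return ds(t) * M + dvar_t / (2.0 * var_t) * (x - s(t) * M)


for t, x in [(0.2, 0.5), (0.5, -1.0), (0.9, 2.0)]:
    print(t, x, velocity_from_score(t, x), velocity_direct(t, x))
```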
EnergyConditionalDiffusionModel
- class grl.generative_models.EnergyConditionalDiffusionModel(config, energy_model)
- Overview:
Energy Conditional Diffusion Model, which is a diffusion model conditioned on energy.
\[p_{\text{E}}(x|c) = \frac{\exp(\mathcal{E}(x,c))}{Z(c)}p(x|c)\]
- Interfaces:
__init__, sample, sample_without_energy_guidance, sample_forward_process, score_function, score_function_with_energy_guidance, score_matching_loss, velocity_function, flow_matching_loss, energy_guidance_loss
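Because p_E(x|c) is proportional to exp(E(x, c)) p(x|c), taking ∇_x log of both sides removes the normalizer Z(c): ∇_x log p_E(x|c) = ∇_x log p(x|c) + ∇_x E(x, c). The toy check below illustrates this identity.

Its assumptions are:
- The base distribution is unconditional, p(x) = N(0, 1).
- The energy is a hypothetical linear function, E(x) = BETA · x.
- guidance_scale multiplies the energy gradient.

Under these assumptions the tilted density is exactly N(guidance_scale · BETA, 1), so its score can be compared with the guided score. The identity is exact only at t = 0. At intermediate times, the guidance term presumably comes from the trained energy model's estimate of the intermediate energy, which this sketch does not model.

```python
BETA = 0.7            # slope of the hypothetical linear energy E(x) = BETA * x
GUIDANCE_SCALE = 2.0  # assumed to multiply the energy gradient


def base_score(x: float) -> float:
    return -x  # score of N(0, 1)


def energy_grad(x: float) -> float:
    return BETA  # d/dx of BETA * x


def guided_score(x: float, guidance_scale: float) -> float:
    return base_score(x) + guidance_scale * energy_grad(x)


def tilted_score(x: float, guidance_scale: float) -> float:
    # exp(guidance_scale * BETA * x) * N(x; 0, 1) is proportional to N(x; guidance_scale * BETA, 1)
    return -(x - guidance_scale * BETA)


for x in (-1.0, 0.0, 2.5):
    print(x, guided_score(x, GUIDANCE_SCALE), tilted_score(x, GUIDANCE_SCALE))
```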
- __init__(config, energy_model)
- Overview:
Initialization of Energy Conditional Diffusion Model.
- Parameters:
config (EasyDict) – The configuration.
energy_model (Union[torch.nn.Module, torch.nn.ModuleDict]) – The energy model.
- data_prediction_function(t, x, condition=None)
- Overview:
Return the data prediction function of the model at time t, given the state x at time t.
\[\frac{x_t + \sigma^2(t) \nabla_{x_t} \log p_{\theta}(x_t)}{s(t)}\]
- Parameters:
t (torch.Tensor) – The input time.
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
- Return type:
Union[Tensor, TensorDict, Tensor]
- data_prediction_function_with_energy_guidance(t, x, condition=None, guidance_scale=1.0)
- Overview:
Return the data prediction function with energy guidance.
- Parameters:
t (torch.Tensor) – The input time.
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
guidance_scale (float) – The scale of guidance.
- Returns:
The data prediction.
- Return type:
x (torch.Tensor)
- energy_guidance_loss(x, condition=None)
- Overview:
The loss function for training the energy guidance with the Contrastive Energy Prediction (CEP) method, as proposed in the paper “Contrastive Energy Prediction for Exact Energy-Guided Diffusion Sampling in Offline Reinforcement Learning”.
- Parameters:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
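As we understand the CEP method from the cited paper, the energy model f is trained with a contrastive softmax cross-entropy objective. The procedure for each condition is:
1. Draw K candidate samples x^(1..K).
2. Form the target distribution softmax(β · E(x^(i), c)) over the candidates.
3. Push the model's predictions f(x_t^(i), c, t) on the noised candidates toward that target through a log-softmax.

The sketch computes this loss for one condition with plain Python lists. β, the list inputs, and the function names are assumptions. GRL's tensor implementation may differ in batching and in how β is set.

```python
import math
from typing import List, Sequence


def log_softmax(values: Sequence[float]) -> List[float]:
    peak = max(values)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in values))  # stable log-sum-exp
    return [v - log_norm for v in values]


def cep_loss(energies: Sequence[float], predicted: Sequence[float], beta: float = 1.0) -> float:
    """Cross-entropy between softmax(beta * energies) and softmax(predicted) over K candidates."""
    target = [math.exp(lp) for lp in log_softmax([beta * e for e in energies])]
    log_pred = log_softmax(predicted)
    return -sum(p * lq for p, lq in zip(target, log_pred))


energies = [0.0, 1.0, -0.5]  # hypothetical energies E(x_i, c) of K = 3 candidates
matched = cep_loss(energies, [0.0, 1.0, -0.5])
mismatched = cep_loss(energies, [1.0, -1.0, 0.0])
print(matched, mismatched)  # predictions matching beta * energies give the lower loss
```

By Gibbs' inequality, the loss is minimized when the predicted softmax equals the target softmax, that is, when f matches β · E up to an additive constant per condition.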
- flow_matching_loss(x, condition=None, average=True)
- Overview:
Return the flow matching loss of the model given the input state and the condition.
- Parameters:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
average (bool) – Whether to average the loss across the batch.
- Return type:
Tensor
- noise_function(t, x, condition=None)
- Overview:
Return the noise function of the model at time t, given the state x at time t.
\[- \sigma(t) \nabla_{x_t} \log p_{\theta}(x_t)\]
- Parameters:
t (torch.Tensor) – The input time.
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
- Return type:
Union[Tensor, TensorDict, Tensor]
- noise_function_with_energy_guidance(t, x, condition=None, guidance_scale=1.0)
- Overview:
Return the noise function with energy guidance.
- Parameters:
t (torch.Tensor) – The input time.
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
guidance_scale (float) – The scale of guidance.
- Returns:
The noise function.
- Return type:
noise (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- sample(t_span=None, batch_size=None, x_0=None, condition=None, guidance_scale=1.0, with_grad=False, solver_config=None)
- Overview:
Sample from the energy-conditioned diffusion model using the energy-guided score function:
\[\nabla_x \log p_{\text{E}}(x|c) = \nabla_x \log p(x|c) + \nabla_x \mathcal{E}(x,c,t)\]
- Parameters:
t_span (torch.Tensor) – The time span.
batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
guidance_scale (float) – The scale of guidance.
with_grad (bool) – Whether to return the gradient.
solver_config (EasyDict) – The configuration of the solver.
- Returns:
The sampled result.
- Return type:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- Shapes:
t_span: \((T)\), where \(T\) is the number of time steps.
batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).
x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).
x: \((N, D)\). If an extra batch size \(B\) is provided, the shape will be \((B, N, D)\). If x_0 is not provided, the shape will be \((B, D)\). If neither x_0 nor condition is provided, the shape will be \((D)\).
- sample_forward_process(t_span=None, batch_size=None, x_0=None, condition=None, guidance_scale=1.0, with_grad=False, solver_config=None)
- Overview:
Sample from the energy-guided diffusion model and return all intermediate states.
- Parameters:
t_span (torch.Tensor) – The time span.
batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
guidance_scale (float) – The scale of guidance.
with_grad (bool) – Whether to return the gradient.
solver_config (EasyDict) – The configuration of the solver.
- Returns:
The sampled result.
- Return type:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- sample_forward_process_with_fixed_x(fixed_x, fixed_mask, t_span=None, batch_size=None, x_0=None, condition=None, guidance_scale=1.0, with_grad=False, solver_config=None)
- Overview:
Sample from the energy-guided diffusion model with part of x held fixed (given by fixed_x and fixed_mask), and return all intermediate states.
- Parameters:
fixed_x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The values of x to hold fixed.
fixed_mask (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The mask indicating which elements of x are held fixed.
t_span (torch.Tensor) – The time span.
batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
guidance_scale (float) – The scale of guidance.
with_grad (bool) – Whether to return the gradient.
solver_config (EasyDict) – The configuration of the solver.
- Returns:
The sampled result.
- Return type:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- Shapes:
fixed_x: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
fixed_mask: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the mask, which could be a scalar or a tensor such as \((D1, D2)\).
t_span: \((T)\), where \(T\) is the number of time steps.
batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).
x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).
x: \((T, N, D)\). If an extra batch size \(B\) is provided, the shape will be \((T, B, N, D)\). If x_0 is not provided, the shape will be \((T, B, D)\). If neither x_0 nor condition is provided, the shape will be \((T, D)\).
- sample_with_fixed_x(fixed_x, fixed_mask, t_span=None, batch_size=None, x_0=None, condition=None, guidance_scale=1.0, with_grad=False, solver_config=None)
- Overview:
Sample from the energy-guided diffusion model with part of x held fixed (given by fixed_x and fixed_mask), and return the final state.
- Parameters:
fixed_x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The values of x to hold fixed.
fixed_mask (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The mask indicating which elements of x are held fixed.
t_span (torch.Tensor) – The time span.
batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
guidance_scale (float) – The scale of guidance.
with_grad (bool) – Whether to return the gradient.
solver_config (EasyDict) – The configuration of the solver.
- Returns:
The sampled result.
- Return type:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- Shapes:
fixed_x: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
fixed_mask: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the mask, which could be a scalar or a tensor such as \((D1, D2)\).
t_span: \((T)\), where \(T\) is the number of time steps.
batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).
x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).
x: \((N, D)\). If an extra batch size \(B\) is provided, the shape will be \((B, N, D)\). If x_0 is not provided, the shape will be \((B, D)\). If neither x_0 nor condition is provided, the shape will be \((D,)\).
- sample_with_fixed_x_without_energy_guidance(fixed_x, fixed_mask, t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)
- Overview:
Sample from the diffusion model with part of x held fixed (given by fixed_x and fixed_mask), without energy guidance, and return the final state.
- Parameters:
fixed_x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The values of x to hold fixed.
fixed_mask (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The mask indicating which elements of x are held fixed.
t_span (torch.Tensor) – The time span.
batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
with_grad (bool) – Whether to return the gradient.
solver_config (EasyDict) – The configuration of the solver.
- Returns:
The sampled result.
- Return type:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- Shapes:
fixed_x: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
fixed_mask: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the mask, which could be a scalar or a tensor such as \((D1, D2)\).
t_span: \((T)\), where \(T\) is the number of time steps.
batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).
x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).
x: \((N, D)\). If an extra batch size \(B\) is provided, the shape will be \((B, N, D)\). If x_0 is not provided, the shape will be \((B, D)\). If neither x_0 nor condition is provided, the shape will be \((D,)\).
- sample_without_energy_guidance(t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)
- Overview:
Sample from the diffusion model without energy guidance.
- Parameters:
t_span (torch.Tensor) – The time span.
batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
with_grad (bool) – Whether to return the gradient.
solver_config (EasyDict) – The configuration of the solver.
- Returns:
The sampled result.
- Return type:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- Shapes:
t_span: \((T)\), where \(T\) is the number of time steps.
batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).
x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).
x: \((T, N, D)\). If an extra batch size \(B\) is provided, the shape will be \((T, B, N, D)\). If x_0 is not provided, the shape will be \((T, B, D)\). If neither x_0 nor condition is provided, the shape will be \((T, D)\).
- score_function(t, x, condition=None)
- Overview:
Return the score function of the model at time t, given the state x at time t. This is the gradient of the log-density with respect to \(x_t\).
\[\nabla_{x_t} \log p_{\theta}(x_t)\]
- Parameters:
t (torch.Tensor) – The input time.
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
- Return type:
Union[Tensor, TensorDict, Tensor]
- score_function_with_energy_guidance(t, x, condition=None, guidance_scale=1.0)
- Overview:
Return the score function with energy guidance.
- Parameters:
t (torch.Tensor) – The input time.
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
guidance_scale (float) – The scale of guidance.
- Returns:
The score function.
- Return type:
score (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- score_matching_loss(x, condition=None, weighting_scheme=None, average=True)
- Overview:
The score matching loss function for training the diffusion model.
- Parameters:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
weighting_scheme (str) – The weighting scheme for the score matching loss, which can be “maximum_likelihood” or “vanilla”.
average (bool) – Whether to average the loss across the batch.
- Note:
“maximum_likelihood”: The weighting scheme is based on maximum likelihood estimation. Refer to the paper “Maximum Likelihood Training of Score-Based Diffusion Models” for more details. The weight \(\lambda(t)\) is:
\[\lambda(t) = g^2(t)\]
For numerical stability, Monte Carlo sampling is used to approximate the integral of \(\lambda(t)\), with
\[\lambda(t) = g^2(t) = p(t)\sigma^2(t)\]
“vanilla”: The weighting scheme is based on vanilla score matching, which balances the MSE loss by scaling the model output to the noise value. Refer to the paper “Score-Based Generative Modeling through Stochastic Differential Equations” for more details. The weight \(\lambda(t)\) is:
\[\lambda(t) = \sigma^2(t)\]
- Return type:
Tensor
- velocity_function(t, x, condition=None)
- Overview:
Return the velocity of the model at time t, given the state x at time t.
\[v_{\theta}(t, x)\]
- Parameters:
t (torch.Tensor) – The input time.
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state at time t.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
- Return type:
Union[Tensor, TensorDict, Tensor]
IndependentConditionalFlowModel
- class grl.generative_models.IndependentConditionalFlowModel(config)
- Overview:
The independent conditional flow model, which is a flow model with independent conditional probability paths.
- Interfaces:
__init__, get_type, sample, sample_forward_process, flow_matching_loss
- __init__(config)
- Overview:
Initialize the model.
- Parameters:
config (EasyDict) – The configuration of the model.
- flow_matching_loss(x0, x1, condition=None, average=True)
- Overview:
Return the flow matching loss of the model given the initial state, the final state, and the condition.
- Parameters:
x0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state.
x1 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The final state.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The condition for the flow matching loss.
average (bool) – Whether to average the loss across the batch.
- Return type:
Tensor
- flow_matching_loss_with_mask(x0, x1, condition=None, mask=None, average=True)
- Overview:
Return the flow matching loss of the model given the initial state, the final state, and the condition, with a mask applied to x0.
- Parameters:
x0 (
Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]
) – The initial state.x1 (
Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]
) – The final state.condition (
Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]
) – The condition for the flow matching loss.mask (
Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]
) – The mask signal for x0, which is either True or False and has same shape as x0, if it is True, the corresponding element in x0 will not be used for the loss computation, and the true value of that element is usually not provided in condition.average (
bool
) – Whether to average the loss across the batch.
- Return type:
Tensor
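The masking convention above (True means "exclude from the loss") can be illustrated with a masked mean squared error. This is a hypothetical helper on flat lists. It assumes the loss is normalized by the number of unmasked elements, which is one reasonable choice; the library may normalize differently.

```python
def masked_mse(pred, target, mask):
    """Mean squared error over elements whose mask is False.

    Follows the documented convention: mask == True means the element is excluded.
    Normalizing by the count of kept elements is an assumption of this sketch.
    """
    kept = [(p - u) ** 2 for p, u, m in zip(pred, target, mask) if not m]
    if not kept:
        return 0.0  # every element is masked, so nothing contributes to the loss
    return sum(kept) / len(kept)

pred   = [0.0, 1.0, 5.0, 2.0]
target = [0.0, 2.0, 0.0, 2.0]
mask   = [False, False, True, False]   # the third element is hidden, so its large error is ignored
print(masked_mse(pred, target, mask))  # (0 + 1 + 0) / 3
```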
- forward_sample(x, t_span, condition=None, with_grad=False, solver_config=None)[source]¶
- Overview:
Run the forward path of the flow model on a given sample x. Note that this is not the reverse process, so it is not intended for sampling from the flow model; rather, it is used to encode a sample x into the latent space.
- Parameters:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.
t_span (torch.Tensor) – The time span.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
with_grad (bool) – Whether to enable gradient computation.
solver_config (EasyDict) – The configuration of the solver.
- forward_sample_process(x, t_span, condition=None, with_grad=False, solver_config=None)[source]¶
- Overview:
Run the forward path of the flow model on a given sample x and return all intermediate states. Note that this is not the reverse process, so it is not intended for sampling from the flow model; rather, it is used to encode a sample x into the latent space.
- Parameters:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.
t_span (torch.Tensor) – The time span.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
with_grad (bool) – Whether to enable gradient computation.
solver_config (EasyDict) – The configuration of the solver.
- log_prob(x_1, log_prob_x_0=None, function_log_prob_x_0=None, condition=None, t=None, using_Hutchinson_trace_estimator=True)[source]¶
- Overview:
Compute the log probability of the final state given the initial state and the condition.
- Parameters:
x_1 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The final state.
log_prob_x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The log probability of the initial state.
function_log_prob_x_0 (Union[callable, nn.Module]) – The function to compute the log probability of the initial state.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The condition.
t (torch.Tensor) – The time span.
using_Hutchinson_trace_estimator (bool) – Whether to use Hutchinson's trace estimator, a stochastic approximation of the trace of the Jacobian of the drift function that is faster but less accurate than computing the exact trace. We recommend setting it to True for high-dimensional data.
- Returns:
log_likelihood – The log likelihood of the final state given the initial state and the condition.
- Return type:
torch.Tensor
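For context, continuous normalizing flows use the instantaneous change of variables \(\frac{d}{dt}\log p(x_t) = -\operatorname{tr}\left(\frac{\partial v}{\partial x}\right)\), so the log likelihood requires integrating a Jacobian trace along the path. Hutchinson's estimator replaces the exact trace with \(\mathbb{E}[\epsilon^\top J \epsilon]\) for random probes satisfying \(\mathbb{E}[\epsilon \epsilon^\top] = I\). The sketch below shows only that estimator on an explicit matrix. `matvec` is a hypothetical stand-in for the Jacobian-vector product a real implementation would obtain from autograd, and Rademacher probes are assumed; the library's probe distribution is not stated here.

```python
import random

def hutchinson_trace(matvec, dim, num_samples, rng=random):
    """Unbiased estimate of tr(J) with Rademacher probes: E[eps^T J eps] = tr(J).

    matvec: callable returning J @ eps. In a flow model this would be a
    Jacobian-vector product, so J never has to be formed explicitly.
    """
    total = 0.0
    for _ in range(num_samples):
        eps = [rng.choice((-1.0, 1.0)) for _ in range(dim)]
        j_eps = matvec(eps)
        total += sum(e * je for e, je in zip(eps, j_eps))
    return total / num_samples

def exact_trace(J):
    return sum(J[i][i] for i in range(len(J)))

J = [[2.0, 0.5, -1.0],
     [0.3, -1.0, 0.8],
     [0.0, 0.4, 0.5]]

def matvec(v):
    return [sum(row[j] * v[j] for j in range(len(v))) for row in J]

random.seed(0)
print(exact_trace(J))                      # 1.5
print(hutchinson_trace(matvec, 3, 20000))  # close to 1.5, but random
```

The exact trace needs one Jacobian-vector product per dimension, while the estimator's cost is set by the number of probes. This trade-off is why the documentation recommends the estimator for high-dimensional data.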
- optimal_transport_flow_matching_loss(x0, x1, condition=None, average=True)[source]¶
- Overview:
Return the flow matching loss function of the model given the initial state and the condition, using the optimal transport plan to match samples from two distributions.
- Parameters:
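x0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state.
x1 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The final state.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
average (bool) – Whether to average the loss across the batch.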
- Return type:
Tensor
- sample(t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)[source]¶
- Overview:
Sample from the model, return the final state.
- Parameters:
t_span (torch.Tensor) – The time span.
batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
with_grad (bool) – Whether to enable gradient computation.
solver_config (EasyDict) – The configuration of the solver.
- Returns:
The sampled result.
- Return type:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- Shapes:
t_span: \((T)\), where \(T\) is the number of time steps.
batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).
x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).
x: \((N, D)\). If extra batch size \(B\) is provided, the shape will be \((B, N, D)\). If x_0 is not provided, the shape will be \((B, D)\). If x_0 and condition are not provided, the shape will be \((D)\).
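The following sketch shows what sampling does conceptually: draw x_0 from a Gaussian when it is not supplied, integrate the velocity over t_span, and keep only the final state. It assumes a fixed-step Euler solver; the real solver is chosen through solver_config. The velocity field here is a hypothetical constant, plain lists stand in for tensors, and the extra batch dimension \(B\) from the Shapes rules above is not modelled.

```python
import random

def sample(velocity_fn, t_span, batch_size=None, x_0=None, dim=2, rng=random):
    """Toy sampler: x_0 ~ N(0, I) when not given, then fixed-step Euler over t_span.

    Returns only the final state, one row per sample.
    """
    if x_0 is None:
        n = batch_size if batch_size is not None else 1
        x_0 = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n)]
    states = x_0
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        dt = t1 - t0
        states = [[xi + dt * vi for xi, vi in zip(x, velocity_fn(t0, x))] for x in states]
    return states

def constant_velocity(t, x):
    return [3.0, -1.0]    # hypothetical field: every sample drifts by (3, -1) over [0, 1]

print(sample(constant_velocity, [0.0, 0.5, 1.0], x_0=[[0.0, 0.0], [1.0, 1.0]]))
random.seed(1)
print(sample(constant_velocity, [0.0, 0.5, 1.0], batch_size=5))  # 5 Gaussian starts, shifted
```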
- sample_forward_process(t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)[source]¶
- Overview:
Sample from the model, return all intermediate states.
- Parameters:
t_span (torch.Tensor) – The time span.
batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – An extra batch size used for repeated sampling with the same initial state.
x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
with_grad (bool) – Whether to enable gradient computation.
solver_config (EasyDict) – The configuration of the solver.
- Returns:
The sampled result.
- Return type:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- Shapes:
t_span: \((T)\), where \(T\) is the number of time steps.
batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).
x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).
x: \((T, N, D)\). If extra batch size \(B\) is provided, the shape will be \((T, B, N, D)\). If x_0 is not provided, the shape will be \((T, B, D)\). If x_0 and condition are not provided, the shape will be \((T, D)\).
- sample_with_log_prob(t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None, using_Hutchinson_trace_estimator=True)[source]¶
- Overview:
Sample from the model, return the final state and the log probability of the initial state.
- Parameters:
t_span (torch.Tensor) – The time span.
batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
with_grad (bool) – Whether to enable gradient computation.
solver_config (EasyDict) – The configuration of the solver.
using_Hutchinson_trace_estimator (bool) – Whether to use Hutchinson's trace estimator, a stochastic approximation of the trace of the Jacobian of the drift function that is faster but less accurate than computing the exact trace. We recommend setting it to True for high-dimensional data.
- Returns:
x – The sampled result.
log_prob_x_0 (torch.Tensor) – The log probability of the initial state.
- Return type:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- sample_with_mask(t_span=None, batch_size=None, x_0=None, condition=None, mask=None, x_1=None, with_grad=False, solver_config=None)[source]¶
- Overview:
Sample from the model with masked elements, return the final state.
- Parameters:
t_span (torch.Tensor) – The time span.
batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
mask (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The mask.
x_1 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The masked element, with the same shape as x_0.
with_grad (bool) – Whether to enable gradient computation.
solver_config (EasyDict) – The configuration of the solver.
- Returns:
The sampled result.
- Return type:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- Shapes:
t_span: \((T)\), where \(T\) is the number of time steps.
batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).
x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).
mask: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the mask, which could be a scalar or a tensor such as \((D1, D2)\).
x_1: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the masked element, which could be a scalar or a tensor such as \((D1, D2)\).
x: \((N, D)\). If extra batch size \(B\) is provided, the shape will be \((B, N, D)\). If x_0 is not provided, the shape will be \((B, D)\). If x_0 and condition are not provided, the shape will be \((D)\).
- sample_with_mask_forward_process(t_span=None, batch_size=None, x_0=None, condition=None, mask=None, x_1=None, with_grad=False, solver_config=None)[source]¶
- Overview:
Sample from the model with masked elements, return all intermediate states.
- Parameters:
t_span (torch.Tensor) – The time span.
batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – An extra batch size used for repeated sampling with the same initial state.
x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
mask (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The mask.
x_1 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The masked element, with the same shape as x_0.
with_grad (bool) – Whether to enable gradient computation.
solver_config (EasyDict) – The configuration of the solver.
- Returns:
The sampled result.
- Return type:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- Shapes:
t_span: \((T)\), where \(T\) is the number of time steps.
batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).
x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).
mask: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the mask, which could be a scalar or a tensor such as \((D1, D2)\).
x_1: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the masked element, which could be a scalar or a tensor such as \((D1, D2)\).
x: \((T, N, D)\). If extra batch size \(B\) is provided, the shape will be \((T, B, N, D)\). If x_0 is not provided, the shape will be \((T, B, D)\). If x_0 and condition are not provided, the shape will be \((T, D)\).
OptimalTransportConditionalFlowModel¶
- class grl.generative_models.OptimalTransportConditionalFlowModel(config)[source]¶
- Overview:
The optimal transport conditional flow model, which is based on an optimal transport plan between two distributions.
- Interfaces:
__init__, get_type, sample, sample_forward_process, flow_matching_loss
- __init__(config)[source]¶
- Overview:
Initialize the model.
- Parameters:
config (EasyDict) – The configuration of the model.
- flow_matching_loss(x0, x1, condition=None, average=True)[source]¶
- Overview:
Return the flow matching loss function of the model given the initial state and the condition, using the optimal transport plan to match samples from two distributions.
- Parameters:
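x0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state.
x1 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The final state.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
average (bool) – Whether to average the loss across the batch.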
- Return type:
Tensor
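For two minibatches of equal size with uniform weights, an optimal transport plan under squared Euclidean cost can always be chosen to be a permutation. Matching samples from the two distributions therefore reduces to an assignment problem. The brute-force sketch below uses a hypothetical helper that is only feasible for a handful of samples. A practical implementation would use a dedicated solver such as the Hungarian algorithm or an OT library, and this sketch is not necessarily how the model computes its plan.

```python
import itertools

def sq_dist(a, b):
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def ot_pairing(x0_batch, x1_batch):
    """Exact minibatch OT coupling by brute force over all permutations.

    Returns (perm, cost) such that x0_batch[i] is paired with x1_batch[perm[i]].
    """
    n = len(x0_batch)
    best_perm, best_cost = None, float("inf")
    for perm in itertools.permutations(range(n)):
        cost = sum(sq_dist(x0_batch[i], x1_batch[perm[i]]) for i in range(n))
        if cost < best_cost:
            best_perm, best_cost = perm, cost
    return list(best_perm), best_cost

x0_batch = [[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]]
x1_batch = [[0.0, 11.0], [1.0, 0.0], [11.0, 0.0]]
perm, cost = ot_pairing(x0_batch, x1_batch)
print(perm)   # [1, 2, 0]: each x0 is paired with its nearest x1
print(cost)   # 3.0
```

After pairing, the flow matching loss would be computed on the pairs (x0[i], x1[perm[i]]) instead of on independently matched pairs, which tends to give straighter conditional paths.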
- flow_matching_loss_small_batch_OT_plan(x0, x1, condition=None, average=True)[source]¶
- Overview:
Return the flow matching loss function of the model given the initial state and the condition, using the optimal transport plan for small batch size to accelerate the computation.
- Parameters:
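x0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state.
x1 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The final state.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
average (bool) – Whether to average the loss across the batch.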
- Return type:
Tensor
- sample(t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)[source]¶
- Overview:
Sample from the model, return the final state.
- Parameters:
t_span (torch.Tensor) – The time span.
batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
with_grad (bool) – Whether to enable gradient computation.
solver_config (EasyDict) – The configuration of the solver.
- Returns:
The sampled result.
- Return type:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- Shapes:
t_span: \((T)\), where \(T\) is the number of time steps.
batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).
x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).
x: \((N, D)\). If extra batch size \(B\) is provided, the shape will be \((B, N, D)\). If x_0 is not provided, the shape will be \((B, D)\). If x_0 and condition are not provided, the shape will be \((D)\).
- sample_forward_process(t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)[source]¶
- Overview:
Sample from the model, return all intermediate states.
- Parameters:
t_span (torch.Tensor) – The time span.
batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – An extra batch size used for repeated sampling with the same initial state.
x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.
condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
with_grad (bool) – Whether to enable gradient computation.
solver_config (EasyDict) – The configuration of the solver.
- Returns:
The sampled result.
- Return type:
x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
- Shapes:
t_span: \((T)\), where \(T\) is the number of time steps.
batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).
x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).
condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).
x: \((T, N, D)\). If extra batch size \(B\) is provided, the shape will be \((T, B, N, D)\). If x_0 is not provided, the shape will be \((T, B, D)\). If x_0 and condition are not provided, the shape will be \((T, D)\).