grl.generative_models

DiffusionModel

class grl.generative_models.DiffusionModel(config)[source]
Overview:

General diffusion model class that supports various types of continuous-time diffusion paths and provides sampling as well as computation of the score function and the velocity function. The model can be parameterized by a score function, noise function, velocity function, or data prediction function. Both the score matching loss and the flow matching loss are supported.

Interfaces:

__init__, sample, score_function, score_matching_loss, velocity_function, flow_matching_loss.

__init__(config)[source]
Overview:

Initialization of Diffusion Model.

Parameters:

config (EasyDict) – The configuration.

data_prediction_function(t, x, condition=None)[source]
Overview:

Return data prediction function of the model at time t given the initial state.

\[\frac{x_t + \sigma^2(t) \nabla_{x_t} \log p_{\theta}(x_t)}{s(t)}\]
Parameters:
  • t (torch.Tensor) – The input time.

  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

Return type:

Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]

dpo_loss(ref_dm, data, beta)[source]
Overview:

The loss function for training the diffusion process by Direct Policy Optimization (DPO). This is an in-development feature and is not recommended for general use.

Return type:

Tensor

flow_matching_loss(x, condition=None, average=True)[source]
Overview:

Return the flow matching loss function of the model given the initial state and the condition.

Parameters:
  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • average (bool) – Whether to average the loss across the batch.

Return type:

Tensor

forward_sample(x, t_span, condition=None, with_grad=False, solver_config=None)[source]
Overview:

Use the forward path of the diffusion model on a given sample x. Note that this is not the reverse process and is therefore not designed for sampling from the diffusion model. Rather, it is used to encode a sample x into the latent space.

Parameters:
  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.

  • t_span (torch.Tensor) – The time span.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

forward_sample_process(x, t_span, condition=None, with_grad=False, solver_config=None)[source]
Overview:

Use the forward path of the diffusion model on a given sample x. Note that this is not the reverse process and is therefore not designed for sampling from the diffusion model. Rather, it is used to encode a sample x into the latent space. Returns all intermediate states.

Parameters:
  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.

  • t_span (torch.Tensor) – The time span.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

log_prob(x=None, condition=None, using_Hutchinson_trace_estimator=True, with_grad=False)[source]
Overview:

Return the log probability of the model given the initial state and the condition.

\[\log p_{\theta}(x)\]
Parameters:
  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • with_grad (bool) – Whether to return the gradient.
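
Example (illustrative sketch, not part of the grl API): the flag using_Hutchinson_trace_estimator suggests that the divergence term needed for the log-likelihood is estimated stochastically rather than computed exactly. The snippet below shows the idea behind the Hutchinson estimator on a small hand-written matrix standing in for a Jacobian: the trace equals the expectation of vᵀAv over random sign vectors v. The helper names are hypothetical, and how the library applies the estimator internally is not shown here.

```python
import itertools

def quadratic_form(matrix, probe):
    """Return probe^T @ matrix @ probe for a square list-of-lists matrix."""
    n = len(matrix)
    return sum(probe[i] * matrix[i][j] * probe[j] for i in range(n) for j in range(n))

def hutchinson_trace(matrix, probes):
    """Average the quadratic form over the given probe vectors."""
    return sum(quadratic_form(matrix, p) for p in probes) / len(probes)

# A toy 3x3 "Jacobian"; its exact trace is 1.0 + 5.0 + 9.0.
jacobian = [[1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0]]
exact_trace = sum(jacobian[i][i] for i in range(3))

# Averaging over every Rademacher (+1/-1) probe gives the expectation exactly.
# In practice only a few random probes are drawn, which yields a noisy but
# unbiased estimate at a fraction of the cost.
all_probes = list(itertools.product([-1.0, 1.0], repeat=3))
expected_estimate = hutchinson_trace(jacobian, all_probes)
```

The off-diagonal terms cancel in expectation because the probe entries are independent with zero mean, which is why only the diagonal survives.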

noise_function(t, x, condition=None)[source]
Overview:

Return noise function of the model at time t given the initial state.

\[- \sigma(t) \nabla_{x_t} \log p_{\theta}(x_t)\]
Parameters:
  • t (torch.Tensor) – The input time.

  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

Return type:

Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]
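
Example (illustrative sketch, not part of the grl API): the score, noise, and data prediction parameterizations can be converted into one another. The snippet assumes a Gaussian perturbation of the form x_t = s(t) x_0 + σ(t) ε, which is an assumption inferred from the symbols s(t) and σ(t) used on this page, and uses one-dimensional Gaussian data so that every quantity has a closed form. All function names and numbers are hypothetical.

```python
def marginal_score(x_t, s, sigma, mean, var):
    """Score of p_t = N(s * mean, s^2 * var + sigma^2) evaluated at x_t."""
    return -(x_t - s * mean) / (s * s * var + sigma * sigma)

def noise_from_score(score, sigma):
    """Noise prediction: -sigma(t) * score."""
    return -sigma * score

def data_from_noise(x_t, noise, s, sigma):
    """Invert x_t = s * x_0 + sigma * eps for x_0, using the predicted noise."""
    return (x_t - sigma * noise) / s

# Toy setup (all values are illustrative): data x_0 ~ N(mean, var).
mean, var = 1.5, 0.25
s, sigma = 0.8, 0.6          # values of s(t) and sigma(t) at some fixed t
x_t = 0.3

score = marginal_score(x_t, s, sigma, mean, var)
noise = noise_from_score(score, sigma)
x0_hat = data_from_noise(x_t, noise, s, sigma)

# For Gaussian data the exact posterior mean E[x_0 | x_t] is known in closed form.
posterior_mean = mean + s * var * (x_t - s * mean) / (s * s * var + sigma * sigma)
```

With the exact score, the recovered x0_hat coincides with the posterior mean of x_0 given x_t, which is what a data prediction model is trained to approximate.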

sample(t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)[source]
Overview:

Sample from the diffusion model, return the final state.

Parameters:
  • t_span (torch.Tensor) – The time span.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

Returns:

The sampled result.

Return type:

x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])

Shapes:

  • t_span: \((T)\), where \(T\) is the number of time steps.

  • batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).

  • x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).

  • x: \((N, D)\). If an extra batch size \(B\) is provided, the shape will be \((B, N, D)\). If x_0 is not provided, the shape will be \((B, D)\). If x_0 and condition are not provided, the shape will be \((D)\).

sample_forward_process(t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)[source]
Overview:

Sample from the diffusion model, return all intermediate states.

Parameters:
  • t_span (torch.Tensor) – The time span.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – An extra batch size used for repeated sampling with the same initial state.

  • x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

Returns:

The sampled result.

Return type:

x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])

Shapes:

  • t_span: \((T)\), where \(T\) is the number of time steps.

  • batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).

  • x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).

  • x: \((T, N, D)\). If an extra batch size \(B\) is provided, the shape will be \((T, B, N, D)\). If x_0 is not provided, the shape will be \((T, B, D)\). If x_0 and condition are not provided, the shape will be \((T, D)\).
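
Example (illustrative sketch, not part of the grl API): the difference between returning the final state and returning all intermediate states can be seen with a plain fixed-step Euler integrator over t_span. The solver actually used is chosen through solver_config; Euler, the toy velocity field, and the helper names below are stand-ins chosen only for illustration, and the time direction used by the library is not modelled here.

```python
def euler_trajectory(velocity, x_start, t_span):
    """Integrate dx/dt = velocity(t, x) with fixed-step Euler over t_span.

    Returns one state per entry of t_span, starting with x_start.
    """
    states = [x_start]
    x = x_start
    for t_now, t_next in zip(t_span[:-1], t_span[1:]):
        x = x + (t_next - t_now) * velocity(t_now, x)
        states.append(x)
    return states

target = 5.0

def toy_velocity(t, x):
    """Velocity of the straight-line path that reaches `target` at t = 1."""
    return (target - x) / (1.0 - t)

t_span = [0.0, 0.25, 0.5, 0.75, 1.0]
trajectory = euler_trajectory(toy_velocity, 2.0, t_span)  # all intermediate states
final_state = trajectory[-1]                               # what a final-state sampler keeps
```

The trajectory has one entry per time in t_span, which mirrors the leading \(T\) dimension in the shapes listed for the process variants.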

sample_forward_process_with_fixed_x(fixed_x, fixed_mask, t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)[source]
Overview:

Sample from the diffusion model with fixed x, return all intermediate states.

Parameters:
  • fixed_x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The fixed x.

  • fixed_mask (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The fixed mask.

  • t_span (torch.Tensor) – The time span.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

Returns:

The sampled result.

Return type:

x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])

Shapes:

  • fixed_x: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • fixed_mask: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the mask, which could be a scalar or a tensor such as \((D1, D2)\).

  • t_span: \((T)\), where \(T\) is the number of time steps.

  • batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).

  • x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).

  • x: \((T, N, D)\). If an extra batch size \(B\) is provided, the shape will be \((T, B, N, D)\). If x_0 is not provided, the shape will be \((T, B, D)\). If x_0 and condition are not provided, the shape will be \((T, D)\).

sample_with_fixed_x(fixed_x, fixed_mask, t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)[source]
Overview:

Sample from the diffusion model with fixed x, return the final state.

Parameters:
  • fixed_x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The fixed x.

  • fixed_mask (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The fixed mask.

  • t_span (torch.Tensor) – The time span.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

Returns:

The sampled result.

Return type:

x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])

Shapes:

  • fixed_x: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • fixed_mask: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the mask, which could be a scalar or a tensor such as \((D1, D2)\).

  • t_span: \((T)\), where \(T\) is the number of time steps.

  • batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).

  • x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).

  • x: \((N, D)\). If an extra batch size \(B\) is provided, the shape will be \((B, N, D)\). If x_0 is not provided, the shape will be \((B, D)\). If x_0 and condition are not provided, the shape will be \((D,)\).

sample_with_log_prob(t_span, batch_size=None, condition=None, with_grad=False, solver_config=None)[source]
Overview:

Sample from the model and return the log probability of the sampled result.

Parameters:
  • t_span (torch.Tensor) – The time span.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

score_function(t, x, condition=None)[source]
Overview:

Return score function of the model at time t given the initial state, which is the gradient of the log-likelihood.

\[\nabla_{x_t} \log p_{\theta}(x_t)\]
Parameters:
  • t (torch.Tensor) – The input time.

  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

Return type:

Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]

score_matching_loss(x, condition=None, weighting_scheme=None, average=True)[source]
Overview:

The loss function for training an unconditional diffusion model.

Parameters:
  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • weighting_scheme (str) – The weighting scheme for score matching loss, which can be “maximum_likelihood” or “vanilla”.

  • Note:

    • “maximum_likelihood”: The weighting scheme is based on the maximum likelihood estimation. Refer to the paper “Maximum Likelihood Training of Score-Based Diffusion Models” for more details. The weight \(\lambda(t)\) is denoted as:

      \[\lambda(t) = g^2(t)\]

      for numerical stability, we use Monte Carlo sampling to approximate the integral of \(\lambda(t)\).

      \[\lambda(t) = g^2(t) = p(t)\sigma^2(t)\]
    • “vanilla”: The weighting scheme is based on the vanilla score matching, which balances the MSE loss by scaling the model output to the noise value. Refer to the paper “Score-Based Generative Modeling through Stochastic Differential Equations” for more details. The weight \(\lambda(t)\) is denoted as:

      \[\lambda(t) = \sigma^2(t)\]

Return type:

Tensor
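
Example (illustrative sketch, not part of the grl API): with the “vanilla” weight \(\lambda(t) = \sigma^2(t)\), the weighted denoising score matching term reduces to a plain squared error in noise space. The snippet checks this for a single scalar sample, assuming the usual denoising target \(-\epsilon / \sigma(t)\) for the score; the function name and all numbers are hypothetical.

```python
def weighted_dsm_term(score_pred, eps, sigma, weight):
    """weight * (score_pred - target)^2 with denoising target -eps / sigma."""
    target = -eps / sigma
    return weight * (score_pred - target) ** 2

sigma = 0.5        # sigma(t) at the sampled time (illustrative value)
eps = 0.8          # the Gaussian noise that produced x_t
score_pred = -1.1  # a hypothetical model output

vanilla_term = weighted_dsm_term(score_pred, eps, sigma, weight=sigma ** 2)

# With lambda(t) = sigma^2(t), the same quantity is a plain squared error
# between the implied noise prediction -sigma * score and the true noise.
noise_pred = -sigma * score_pred
noise_mse = (noise_pred - eps) ** 2
```

This is the sense in which the vanilla scheme scales the model output to the noise value: the loss magnitude no longer grows as \(\sigma(t)\) shrinks.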

velocity_function(t, x, condition=None)[source]
Overview:

Return velocity of the model at time t given the initial state.

\[v_{\theta}(t, x)\]
Parameters:
  • t (torch.Tensor) – The input time.

  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state at time t.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

Return type:

Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]

EnergyConditionalDiffusionModel

class grl.generative_models.EnergyConditionalDiffusionModel(config, energy_model)[source]
Overview:

Energy Conditional Diffusion Model, which is a diffusion model conditioned on energy.

\[p_{\text{E}}(x|c)\sim \frac{\exp{\mathcal{E}(x,c)}}{Z(c)}p(x|c)\]
Interfaces:

__init__, sample, sample_without_energy_guidance, sample_forward_process, score_function, score_function_with_energy_guidance, score_matching_loss, velocity_function, flow_matching_loss, energy_guidance_loss

__init__(config, energy_model)[source]
Overview:

Initialization of Energy Conditional Diffusion Model.

Parameters:
  • config (EasyDict) – The configuration.

  • energy_model (Union[torch.nn.Module, torch.nn.ModuleDict]) – The energy model.

data_prediction_function(t, x, condition=None)[source]
Overview:

Return data prediction function of the model at time t given the initial state.

\[\frac{x_t + \sigma^2(t) \nabla_{x_t} \log p_{\theta}(x_t)}{s(t)}\]
Parameters:
  • t (torch.Tensor) – The input time.

  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

Return type:

Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]

data_prediction_function_with_energy_guidance(t, x, condition=None, guidance_scale=1.0)[source]
Overview:

The data prediction function for energy guidance.

Parameters:
  • t (torch.Tensor) – The input time.

  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • guidance_scale (float) – The scale of guidance.

Returns:

The data prediction.

Return type:

x (torch.Tensor)

energy_guidance_loss(x, condition=None)[source]
Overview:

The loss function for training the energy guidance with the Contrastive Energy Prediction (CEP) method, as proposed in the paper “Contrastive Energy Prediction for Exact Energy-Guided Diffusion Sampling in Offline Reinforcement Learning”.

Parameters:
  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.
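
Example (illustrative sketch, not part of the grl API): the contrastive objective described in the CEP paper can be read as a cross-entropy between soft labels built from the energies of K candidate samples and the softmax of the guidance model's predictions for those samples. The snippet below is a scalar, standard-library rendering of that idea. The inverse temperature beta, the function names, and the numbers are assumptions, and the library's actual implementation may differ in batching, noising, and normalization.

```python
import math

def log_softmax(values):
    """Numerically stable log-softmax over a list of floats."""
    peak = max(values)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in values))
    return [v - log_norm for v in values]

def contrastive_energy_loss(energies, predicted, beta=1.0):
    """Cross-entropy between softmax(beta * energies) and softmax(predicted)."""
    soft_labels = [math.exp(v) for v in log_softmax([beta * e for e in energies])]
    log_probs = log_softmax(predicted)
    return -sum(w * lp for w, lp in zip(soft_labels, log_probs))

# K = 3 candidate samples for one condition (illustrative numbers).
energies = [0.2, 1.0, -0.5]           # energy of each clean sample
matched = [0.2, 1.0, -0.5]            # predictions equal to beta * energies
mismatched = [1.0, 0.2, -0.5]         # predictions that rank the samples wrongly

loss_matched = contrastive_energy_loss(energies, matched)
loss_mismatched = contrastive_energy_loss(energies, mismatched)
```

The loss is smallest when the predicted values reproduce the softmax of the scaled energies, so predictions that misrank the candidates are penalized.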

flow_matching_loss(x, condition=None, average=True)[source]
Overview:

Return the flow matching loss function of the model given the initial state and the condition.

Parameters:
  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • average (bool) – Whether to average the loss across the batch.

Return type:

Tensor

noise_function(t, x, condition=None)[source]
Overview:

Return noise function of the model at time t given the initial state.

\[- \sigma(t) \nabla_{x_t} \log p_{\theta}(x_t)\]
Parameters:
  • t (torch.Tensor) – The input time.

  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

Return type:

Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]

noise_function_with_energy_guidance(t, x, condition=None, guidance_scale=1.0)[source]
Overview:

The noise function for energy guidance.

Parameters:
  • t (torch.Tensor) – The input time.

  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • guidance_scale (float) – The scale of guidance.

Returns:

The noise function.

Return type:

noise (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])

sample(t_span=None, batch_size=None, x_0=None, condition=None, guidance_scale=1.0, with_grad=False, solver_config=None)[source]
Overview:

Sample from the energy-conditioned diffusion model by using the score function.

\[\nabla \log p_{\text{E}}(x|c) = \nabla \log p(x|c) + \nabla \mathcal{E}(x,c,t)\]
Parameters:
  • t_span (torch.Tensor) – The time span.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • guidance_scale (float) – The scale of guidance.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

Returns:

The sampled result.

Return type:

x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])

Shapes:

  • t_span: \((T)\), where \(T\) is the number of time steps.

  • batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).

  • x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).

  • x: \((N, D)\). If an extra batch size \(B\) is provided, the shape will be \((B, N, D)\). If x_0 is not provided, the shape will be \((B, D)\). If x_0 and condition are not provided, the shape will be \((D)\).

sample_forward_process(t_span=None, batch_size=None, x_0=None, condition=None, guidance_scale=1.0, with_grad=False, solver_config=None)[source]
Overview:

Sample from the diffusion model.

Parameters:
  • t_span (torch.Tensor) – The time span.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • guidance_scale (float) – The scale of guidance.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

Returns:

The sampled result.

Return type:

x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])

sample_forward_process_with_fixed_x(fixed_x, fixed_mask, t_span=None, batch_size=None, x_0=None, condition=None, guidance_scale=1.0, with_grad=False, solver_config=None)[source]
Overview:

Sample from the diffusion model with fixed x.

Parameters:
  • fixed_x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The fixed x.

  • fixed_mask (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The fixed mask.

  • t_span (torch.Tensor) – The time span.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • guidance_scale (float) – The scale of guidance.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

Returns:

The sampled result.

Return type:

x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])

Shapes:

  • fixed_x: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • fixed_mask: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the mask, which could be a scalar or a tensor such as \((D1, D2)\).

  • t_span: \((T)\), where \(T\) is the number of time steps.

  • batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).

  • x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).

  • x: \((T, N, D)\). If an extra batch size \(B\) is provided, the shape will be \((T, B, N, D)\). If x_0 is not provided, the shape will be \((T, B, D)\). If x_0 and condition are not provided, the shape will be \((T, D)\).

sample_with_fixed_x(fixed_x, fixed_mask, t_span=None, batch_size=None, x_0=None, condition=None, guidance_scale=1.0, with_grad=False, solver_config=None)[source]
Overview:

Sample from the diffusion model with fixed x.

Parameters:
  • fixed_x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The fixed x.

  • fixed_mask (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The fixed mask.

  • t_span (torch.Tensor) – The time span.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • guidance_scale (float) – The scale of guidance.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

Returns:

The sampled result.

Return type:

x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])

Shapes:

  • fixed_x: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • fixed_mask: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the mask, which could be a scalar or a tensor such as \((D1, D2)\).

  • t_span: \((T)\), where \(T\) is the number of time steps.

  • batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).

  • x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).

  • x: \((N, D)\). If an extra batch size \(B\) is provided, the shape will be \((B, N, D)\). If x_0 is not provided, the shape will be \((B, D)\). If x_0 and condition are not provided, the shape will be \((D,)\).

sample_with_fixed_x_without_energy_guidance(fixed_x, fixed_mask, t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)[source]
Overview:

Sample from the diffusion model with fixed x without energy guidance.

Parameters:
  • fixed_x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The fixed x.

  • fixed_mask (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The fixed mask.

  • t_span (torch.Tensor) – The time span.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

Returns:

The sampled result.

Return type:

x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])

Shapes:

  • fixed_x: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • fixed_mask: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the mask, which could be a scalar or a tensor such as \((D1, D2)\).

  • t_span: \((T)\), where \(T\) is the number of time steps.

  • batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).

  • x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).

  • x: \((N, D)\). If an extra batch size \(B\) is provided, the shape will be \((B, N, D)\). If x_0 is not provided, the shape will be \((B, D)\). If x_0 and condition are not provided, the shape will be \((D,)\).

sample_without_energy_guidance(t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)[source]
Overview:

Sample from the diffusion model without energy guidance.

Parameters:
  • t_span (torch.Tensor) – The time span.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state. If not provided, it is sampled from the Gaussian distribution.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

Returns:

The sampled result.

Return type:

x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])

Shapes:

  • t_span: \((T)\), where \(T\) is the number of time steps.

  • batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).

  • x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).

  • x: \((T, N, D)\). If an extra batch size \(B\) is provided, the shape will be \((T, B, N, D)\). If x_0 is not provided, the shape will be \((T, B, D)\). If x_0 and condition are not provided, the shape will be \((T, D)\).

score_function(t, x, condition=None)[source]
Overview:

Return score function of the model at time t given the initial state, which is the gradient of the log-likelihood.

\[\nabla_{x_t} \log p_{\theta}(x_t)\]
Parameters:
  • t (torch.Tensor) – The input time.

  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

Return type:

Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]

score_function_with_energy_guidance(t, x, condition=None, guidance_scale=1.0)[source]
Overview:

The score function for energy guidance.

Parameters:
  • t (torch.Tensor) – The input time.

  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • guidance_scale (float) – The scale of guidance.

Returns:

The score function.

Return type:

score (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])
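
Example (illustrative sketch, not part of the grl API): energy guidance adds the gradient of the energy to the score of the base model. The snippet uses a one-dimensional case with a closed form: the base density is N(0, 1) and the energy is linear, so the energy-weighted density is again Gaussian. It assumes that guidance_scale multiplies the energy gradient and works at zero noise level, where the intermediate energy equals the energy itself. All names and numbers are hypothetical.

```python
def base_score(x):
    """Score of the base density N(0, 1)."""
    return -x

def energy_gradient(x, slope):
    """Gradient of the linear energy E(x) = slope * x."""
    return slope

def guided_score(x, slope, guidance_scale=1.0):
    """Base score plus the scaled energy gradient."""
    return base_score(x) + guidance_scale * energy_gradient(x, slope)

# exp(w * a * x) * N(x; 0, 1) is proportional to N(x; w * a, 1),
# whose score is -(x - w * a).
x, slope, scale = 0.7, 2.0, 1.5
closed_form = -(x - scale * slope)
guided = guided_score(x, slope, scale)
```

Setting guidance_scale to zero in this sketch recovers the base score, which corresponds to sampling without energy guidance.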

score_matching_loss(x, condition=None, weighting_scheme=None, average=True)[source]
Overview:

The loss function for training an unconditional diffusion model.

Parameters:
  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • weighting_scheme (str) – The weighting scheme for score matching loss, which can be “maximum_likelihood” or “vanilla”.

  • average (bool) – Whether to average the loss across the batch.

  • Note:

    • “maximum_likelihood”: The weighting scheme is based on the maximum likelihood estimation. Refer to the paper “Maximum Likelihood Training of Score-Based Diffusion Models” for more details. The weight \(\lambda(t)\) is denoted as:

      \[\lambda(t) = g^2(t)\]

      for numerical stability, we use Monte Carlo sampling to approximate the integral of \(\lambda(t)\).

      \[\lambda(t) = g^2(t) = p(t)\sigma^2(t)\]
    • “vanilla”: The weighting scheme is based on the vanilla score matching, which balances the MSE loss by scaling the model output to the noise value. Refer to the paper “Score-Based Generative Modeling through Stochastic Differential Equations” for more details. The weight \(\lambda(t)\) is denoted as:

      \[\lambda(t) = \sigma^2(t)\]

Return type:

Tensor

velocity_function(t, x, condition=None)[source]
Overview:

Return velocity of the model at time t given the initial state.

\[v_{\theta}(t, x)\]
Parameters:
  • t (torch.Tensor) – The input time.

  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state at time t.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

Return type:

Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]

IndependentConditionalFlowModel

class grl.generative_models.IndependentConditionalFlowModel(config)[source]
Overview:

The independent conditional flow model, which is a flow model with independent conditional probability paths.

Interfaces:

__init__, get_type, sample, sample_forward_process, flow_matching_loss

__init__(config)[source]
Overview:

Initialize the model.

Parameters:

config (EasyDict) – The configuration of the model.

flow_matching_loss(x0, x1, condition=None, average=True)[source]
Overview:

Return the flow matching loss function of the model given the initial state and the condition.

Parameters:
  • x0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state.

  • x1 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The final state.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The condition for the flow matching loss.

  • average (bool) – Whether to average the loss across the batch.

Return type:

Tensor
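As an illustration of what a flow matching loss over pairs (x0, x1) computes, here is a minimal dependency-free sketch. It assumes a straight-line path \(x_t = (1 - t) x_0 + t x_1\) with target velocity \(x_1 - x_0\), which may differ from the path this class actually uses. `toy_flow_matching_loss` and `velocity_fn` are hypothetical names, and plain Python lists stand in for tensors.

```python
import random


def toy_flow_matching_loss(velocity_fn, x0, x1, rng, average=True):
    """Squared error between predicted and target velocity on a straight-line path.

    x0, x1: lists of N samples, each a list of D floats (stand-ins for tensors).
    velocity_fn(t, x_t): hypothetical model returning a list of D floats.
    """
    per_sample = []
    for start, end in zip(x0, x1):
        t = rng.random()  # one time per sample, uniform in [0, 1)
        x_t = [(1.0 - t) * a + t * b for a, b in zip(start, end)]
        target = [b - a for a, b in zip(start, end)]  # velocity of the straight line
        pred = velocity_fn(t, x_t)
        per_sample.append(sum((p - q) ** 2 for p, q in zip(pred, target)) / len(target))
    return sum(per_sample) / len(per_sample) if average else per_sample


x0 = [[0.0, 0.0], [1.0, 2.0]]
x1 = [[1.0, 1.0], [2.0, 3.0]]
rng = random.Random(0)

# A field that always predicts the true displacement gives zero loss.
print(toy_flow_matching_loss(lambda t, x: [1.0, 1.0], x0, x1, rng))
# With average=False the per-sample losses are returned instead.
print(toy_flow_matching_loss(lambda t, x: [0.0, 0.0], x0, x1, rng, average=False))
```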

flow_matching_loss_with_mask(x0, x1, condition=None, mask=None, average=True)[source]
Overview:

Return the flow matching loss of the model given the initial state, the final state, and the condition, with masked elements excluded from the loss.

Parameters:
  • x0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state.

  • x1 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The final state.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The condition for the flow matching loss.

  • mask (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The boolean mask for x0, with the same shape as x0. Where the mask is True, the corresponding element of x0 is not used in the loss computation; the true value of that element is usually not provided in condition.

  • average (bool) – Whether to average the loss across the batch.

Return type:

Tensor
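The mask convention described above (True means "exclude from the loss") can be sketched as follows. This is only an illustration on nested Python lists: `masked_squared_error` is a hypothetical helper, and normalizing by the number of unmasked elements is an assumption rather than a documented behavior of this method.

```python
def masked_squared_error(pred, target, mask):
    """Mean squared error over the elements whose mask entry is False.

    pred, target, mask: nested lists of identical shape (N, D).
    Elements with mask == True are skipped, following the convention above.
    """
    total, count = 0.0, 0
    for pred_row, target_row, mask_row in zip(pred, target, mask):
        for p, q, masked in zip(pred_row, target_row, mask_row):
            if not masked:
                total += (p - q) ** 2
                count += 1
    # Assumption: normalize by the number of unmasked elements.
    return total / count if count else 0.0


pred = [[1.0, 5.0], [2.0, 2.0]]
target = [[0.0, 0.0], [2.0, 0.0]]
mask = [[False, True], [False, False]]  # the large error at position (0, 1) is ignored
print(masked_squared_error(pred, target, mask))
```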

forward_sample(x, t_span, condition=None, with_grad=False, solver_config=None)[source]
Overview:

Use the forward path of the flow model given the sampled x. Note that this is not the reverse process, and thus is not designed for sampling from the flow model. Rather, it is used to encode a sampled x into the latent space.

Parameters:
  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.

  • t_span (torch.Tensor) – The time span.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

forward_sample_process(x, t_span, condition=None, with_grad=False, solver_config=None)[source]
Overview:

Use the forward path of the flow model given the sampled x. Note that this is not the reverse process, and thus is not designed for sampling from the flow model. Rather, it is used to encode a sampled x into the latent space. Return all intermediate states.

Parameters:
  • x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input state.

  • t_span (torch.Tensor) – The time span.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

log_prob(x_1, log_prob_x_0=None, function_log_prob_x_0=None, condition=None, t=None, using_Hutchinson_trace_estimator=True)[source]
Overview:

Compute the log probability of the final state given the initial state and the condition.

Parameters:
  • x_1 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The final state.

  • log_prob_x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The log probability of the initial state.

  • function_log_prob_x_0 (Union[callable, nn.Module]) – The function to compute the log probability of the initial state.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The condition.

  • t (torch.Tensor) – The time span.

  • using_Hutchinson_trace_estimator (bool) – Whether to use Hutchinson trace estimator. It is an approximation of the trace of the Jacobian of the drift function, which is faster but less accurate. We recommend setting it to True for high dimensional data.

Returns:

The log likelihood of the final state given the initial state and the condition.

Return type:

log_likelihood (torch.Tensor)
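For context, the log density along an ODE flow \(\dot{x}_t = v(t, x_t)\) changes as \(\frac{d}{dt} \log p_t(x_t) = -\operatorname{tr}\left(\partial v / \partial x\right)\), so a trace of the Jacobian is needed at every solver step. The Hutchinson estimator replaces the exact trace with \(\mathbb{E}_{\epsilon}[\epsilon^{\top} A \epsilon]\), which needs only matrix-vector products. The sketch below illustrates the estimator on an explicit matrix; `hutchinson_trace` and `make_matvec` are hypothetical helpers, and Rademacher probes are one common choice rather than a statement about what this method uses.

```python
import random


def make_matvec(matrix):
    """Wrap an explicit matrix as a function v -> A v (stand-in for a Jacobian-vector product)."""
    return lambda v: [sum(a * x for a, x in zip(row, v)) for row in matrix]


def hutchinson_trace(matvec, dim, num_probes, rng):
    """Estimate tr(A) as the average of eps^T A eps over random +/-1 probe vectors."""
    total = 0.0
    for _ in range(num_probes):
        eps = [rng.choice((-1.0, 1.0)) for _ in range(dim)]
        a_eps = matvec(eps)
        total += sum(e * a for e, a in zip(eps, a_eps))
    return total / num_probes


rng = random.Random(0)
matrix = [[2.0, 1.0], [1.0, 3.0]]  # exact trace is 5
print(hutchinson_trace(make_matvec(matrix), dim=2, num_probes=10_000, rng=rng))
```

The estimate is unbiased but noisy: each probe costs one matrix-vector product instead of the D products needed for the exact trace, which is why the estimator is attractive for high-dimensional data.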

optimal_transport_flow_matching_loss(x0, x1, condition=None, average=True)[source]
Overview:

Return the flow matching loss function of the model given the initial state and the condition, using the optimal transport plan to match samples from two distributions.

Parameters:
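  • x0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state.

  • x1 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The final state.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • average (bool) – Whether to average the loss across the batch.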

Return type:

Tensor

sample(t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)[source]
Overview:

Sample from the model and return the final state.

Parameters:
  • t_span (torch.Tensor) – The time span.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state, if not provided, it will be sampled from the Gaussian distribution.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

Returns:

The sampled result.

Return type:

x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])

Shapes:

  • t_span: \((T)\), where \(T\) is the number of time steps.

  • batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).

  • x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).

  • x: \((N, D)\). If an extra batch size \(B\) is provided, the shape will be \((B, N, D)\). If x_0 is not provided, the shape will be \((B, D)\). If x_0 and condition are not provided, the shape will be \((D)\).

sample_forward_process(t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)[source]
Overview:

Sample from the model and return all intermediate states.

Parameters:
  • t_span (torch.Tensor) – The time span.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – An extra batch size used for repeated sampling with the same initial state.

  • x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state, if not provided, it will be sampled from the Gaussian distribution.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

Returns:

The sampled result.

Return type:

x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])

Shapes:

  • t_span: \((T)\), where \(T\) is the number of time steps.

  • batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).

  • x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).

  • x: \((T, N, D)\). If an extra batch size \(B\) is provided, the shape will be \((T, B, N, D)\). If x_0 is not provided, the shape will be \((T, B, D)\). If x_0 and condition are not provided, the shape will be \((T, D)\).
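To make the \((T, N, D)\) layout concrete, the sketch below integrates a toy velocity field with fixed-step Euler updates and keeps every intermediate state. It is an illustration only: the actual integrator is chosen through solver_config, `euler_trajectory` is a hypothetical helper, and the field \(v(t, x) = -x\) is a stand-in for a trained model.

```python
def euler_trajectory(velocity_fn, x_0, t_span):
    """Integrate dx/dt = v(t, x) with Euler steps and return all states.

    x_0: nested list of shape (N, D); t_span: list of T time points.
    Returns a nested list of shape (T, N, D); the last entry is the final state.
    """
    states = [[list(row) for row in x_0]]
    for t_prev, t_next in zip(t_span[:-1], t_span[1:]):
        dt = t_next - t_prev
        states.append(
            [
                [x + dt * v for x, v in zip(row, velocity_fn(t_prev, row))]
                for row in states[-1]
            ]
        )
    return states


x_0 = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # N = 3 samples, D = 2
t_span = [0.0, 0.5, 1.0]  # T = 3 time points
trajectory = euler_trajectory(lambda t, x: [-xi for xi in x], x_0, t_span)

print(len(trajectory), len(trajectory[0]), len(trajectory[0][0]))  # T, N, D
print(trajectory[-1])  # final state, analogous to what a final-state sampler would return
```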

sample_with_log_prob(t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None, using_Hutchinson_trace_estimator=True)[source]
Overview:

Sample from the model, return the final state and the log probability of the initial state.

Parameters:
  • t_span (torch.Tensor) – The time span.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state, if not provided, it will be sampled from the Gaussian distribution.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

  • using_Hutchinson_trace_estimator (bool) – Whether to use Hutchinson trace estimator. It is an approximation of the trace of the Jacobian of the drift function, which is faster but less accurate. We recommend setting it to True for high dimensional data.

Returns:

  • x – The sampled result.

  • log_prob_x_0 (torch.Tensor) – The log probability of the initial state.

Return type:

x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])

sample_with_mask(t_span=None, batch_size=None, x_0=None, condition=None, mask=None, x_1=None, with_grad=False, solver_config=None)[source]
Overview:

Sample from the model with masked elements and return the final state.

Parameters:
  • t_span (torch.Tensor) – The time span.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state, if not provided, it will be sampled from the Gaussian distribution.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • mask (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The mask.

  • x_1 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The masked elements, with the same shape as mask.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

Returns:

The sampled result.

Return type:

x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])

Shapes:

  • t_span: \((T)\), where \(T\) is the number of time steps.

  • batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).

  • x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).

  • mask: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the mask, which could be a scalar or a tensor such as \((D1, D2)\).

  • x_1: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the masked element, which could be a scalar or a tensor such as \((D1, D2)\).

  • x: \((N, D)\). If an extra batch size \(B\) is provided, the shape will be \((B, N, D)\). If x_0 is not provided, the shape will be \((B, D)\). If x_0 and condition are not provided, the shape will be \((D)\).

sample_with_mask_forward_process(t_span=None, batch_size=None, x_0=None, condition=None, mask=None, x_1=None, with_grad=False, solver_config=None)[source]
Overview:

Sample from the model with masked elements and return all intermediate states.

Parameters:
  • t_span (torch.Tensor) – The time span.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – An extra batch size used for repeated sampling with the same initial state.

  • x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state, if not provided, it will be sampled from the Gaussian distribution.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • mask (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The mask.

  • x_1 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The masked elements, with the same shape as mask.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

Returns:

The sampled result.

Return type:

x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])

Shapes:

  • t_span: \((T)\), where \(T\) is the number of time steps.

  • batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).

  • x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).

  • mask: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the mask, which could be a scalar or a tensor such as \((D1, D2)\).

  • x_1: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the masked element, which could be a scalar or a tensor such as \((D1, D2)\).

  • x: \((T, N, D)\). If an extra batch size \(B\) is provided, the shape will be \((T, B, N, D)\). If x_0 is not provided, the shape will be \((T, B, D)\). If x_0 and condition are not provided, the shape will be \((T, D)\).

OptimalTransportConditionalFlowModel

class grl.generative_models.OptimalTransportConditionalFlowModel(config)[source]
Overview:

The optimal transport conditional flow model, which is based on an optimal transport plan between two distributions.

Interfaces:

__init__, get_type, sample, sample_forward_process, flow_matching_loss

__init__(config)[source]
Overview:

Initialize the model.

Parameters:

config (EasyDict) – The configuration of the model.

flow_matching_loss(x0, x1, condition=None, average=True)[source]
Overview:

Return the flow matching loss function of the model given the initial state and the condition, using the optimal transport plan to match samples from two distributions.

Parameters:
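  • x0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state.

  • x1 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The final state.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • average (bool) – Whether to average the loss across the batch.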

Return type:

Tensor
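Matching samples with an optimal transport plan means re-pairing the rows of x0 and x1 within a batch so that the total transport cost is minimal before the flow matching loss is computed on the pairs. The sketch below shows that idea for a tiny batch by brute force over permutations with a squared Euclidean cost. Both the cost and the exhaustive search are assumptions made for illustration, and `optimal_pairing` is a hypothetical helper; a practical implementation would use an assignment solver such as `scipy.optimize.linear_sum_assignment`.

```python
from itertools import permutations


def squared_distance(a, b):
    """Squared Euclidean distance between two samples given as lists of floats."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def optimal_pairing(x0, x1):
    """Return perm such that x0[i] is paired with x1[perm[i]] at minimal total cost.

    Brute force over all permutations: O(N!) and only usable for very small batches.
    """
    def total_cost(perm):
        return sum(squared_distance(x0[i], x1[j]) for i, j in enumerate(perm))

    return list(min(permutations(range(len(x0))), key=total_cost))


x0 = [[0.0], [10.0], [5.0]]
x1 = [[9.0], [4.0], [1.0]]
perm = optimal_pairing(x0, x1)
x1_matched = [x1[j] for j in perm]  # x1 reordered so that row i matches x0[i]
print(perm, x1_matched)
```

After the re-pairing, each x0[i] is matched with a nearby x1 sample, which tends to produce straighter and less crossing training paths than random pairing.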

flow_matching_loss_small_batch_OT_plan(x0, x1, condition=None, average=True)[source]
Overview:

Return the flow matching loss function of the model given the initial state and the condition, using the optimal transport plan for small batch size to accelerate the computation.

Parameters:
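  • x0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state.

  • x1 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The final state.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • average (bool) – Whether to average the loss across the batch.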

Return type:

Tensor

sample(t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)[source]
Overview:

Sample from the model and return the final state.

Parameters:
  • t_span (torch.Tensor) – The time span.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state, if not provided, it will be sampled from the Gaussian distribution.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

Returns:

The sampled result.

Return type:

x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])

Shapes:

  • t_span: \((T)\), where \(T\) is the number of time steps.

  • batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).

  • x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).

  • x: \((N, D)\). If an extra batch size \(B\) is provided, the shape will be \((B, N, D)\). If x_0 is not provided, the shape will be \((B, D)\). If x_0 and condition are not provided, the shape will be \((D)\).

sample_forward_process(t_span=None, batch_size=None, x_0=None, condition=None, with_grad=False, solver_config=None)[source]
Overview:

Sample from the model and return all intermediate states.

Parameters:
  • t_span (torch.Tensor) – The time span.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – An extra batch size used for repeated sampling with the same initial state.

  • x_0 (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The initial state, if not provided, it will be sampled from the Gaussian distribution.

  • condition (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]) – The input condition.

  • with_grad (bool) – Whether to return the gradient.

  • solver_config (EasyDict) – The configuration of the solver.

Returns:

The sampled result.

Return type:

x (Union[torch.Tensor, TensorDict, treetensor.torch.Tensor])

Shapes:

  • t_span: \((T)\), where \(T\) is the number of time steps.

  • batch_size: \((B)\), where \(B\) is the batch size of data, which could be a scalar or a tensor such as \((B1, B2)\).

  • x_0: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the state, which could be a scalar or a tensor such as \((D1, D2)\).

  • condition: \((N, D)\), where \(N\) is the batch size of data and \(D\) is the dimension of the condition, which could be a scalar or a tensor such as \((D1, D2)\).

  • x: \((T, N, D)\). If an extra batch size \(B\) is provided, the shape will be \((T, B, N, D)\). If x_0 is not provided, the shape will be \((T, B, D)\). If x_0 and condition are not provided, the shape will be \((T, D)\).