• Docs >
  • Module code >
  • grl.generative_models.diffusion_model.energy_conditional_diffusion_model
Shortcuts

Source code for grl.generative_models.diffusion_model.energy_conditional_diffusion_model

from typing import Any, Callable, Dict, List, Tuple, Union

import torch
import torch.nn as nn
import treetensor
from easydict import EasyDict
from tensordict import TensorDict

from grl.generative_models.diffusion_process import DiffusionProcess
from grl.generative_models.intrinsic_model import IntrinsicModel
from grl.generative_models.model_functions.data_prediction_function import (
    DataPredictionFunction,
)
from grl.generative_models.model_functions.noise_function import NoiseFunction
from grl.generative_models.model_functions.score_function import ScoreFunction
from grl.generative_models.model_functions.velocity_function import VelocityFunction
from grl.generative_models.random_generator import gaussian_random_variable
from grl.numerical_methods.numerical_solvers import get_solver
from grl.numerical_methods.numerical_solvers.dpm_solver import DPMSolver
from grl.numerical_methods.numerical_solvers.ode_solver import (
    DictTensorODESolver,
    ODESolver,
)
from grl.numerical_methods.numerical_solvers.sde_solver import SDESolver
from grl.numerical_methods.probability_path import GaussianConditionalProbabilityPath
from grl.utils import find_parameters


class EnergyGuidance(nn.Module):
    """
    Overview:
        Energy Guidance for Energy Conditional Diffusion Model. Wraps an ``IntrinsicModel`` built
        from the given configuration and exposes the gradient of its output with respect to the
        input as a guidance term for sampling.
    Interfaces:
        ``__init__``, ``forward``, ``calculate_energy_guidance``
    """

    def __init__(self, config: EasyDict):
        """
        Overview:
            Initialization of Energy Guidance.

        Arguments:
            config (:obj:`EasyDict`): The configuration, passed unchanged to ``IntrinsicModel``.
        """
        super().__init__()
        self.config = config
        self.model = IntrinsicModel(self.config)

    def forward(
        self,
        t: torch.Tensor,
        x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
        condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Overview:
            Return output of Energy Guidance, i.e. the wrapped model evaluated at ``(t, x, condition)``.

        Arguments:
            t (:obj:`torch.Tensor`): The input time.
            x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input.
            condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
        Returns:
            output (:obj:`torch.Tensor`): The output of the wrapped model.
        """

        return self.model(t, x, condition)

    def calculate_energy_guidance(
        self,
        t: torch.Tensor,
        x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
        condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
        guidance_scale: float = 1.0,
    ) -> Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]:
        """
        Overview:
            Calculate the guidance for sampling: the gradient of the summed model output with
            respect to ``x``, multiplied by ``guidance_scale``. The input ``x`` is not modified;
            the gradient is taken with respect to a detached copy of it.

        Arguments:
            t (:obj:`torch.Tensor`): The input time.
            x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input.
            condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
            guidance_scale (:obj:`float`): The scale of guidance.
        Returns:
            guidance (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The guidance for sampling, \
                detached from the autograd graph.
        """

        # TODO: make it compatible with TensorDict
        # Gradients are enabled locally so this also works when the caller runs under ``torch.no_grad()``.
        with torch.enable_grad():
            # Differentiate with respect to a detached leaf copy so the caller's tensor is not
            # left with ``requires_grad=True`` as a side effect of this call.
            x_in = x.detach().requires_grad_(True)
            energy = self.forward(t, x_in, condition)
            guidance = guidance_scale * torch.autograd.grad(torch.sum(energy), x_in)[0]
        return guidance.detach()


[docs]class EnergyConditionalDiffusionModel(nn.Module): r""" Overview: Energy Conditional Diffusion Model, which is a diffusion model conditioned on energy. .. math:: p_{\text{E}}(x|c)\sim \frac{\exp{\mathcal{E}(x,c)}}{Z(c)}p(x|c) Interfaces: ``__init__``, ``sample``, ``sample_without_energy_guidance``, ``sample_forward_process``, ``score_function``, \ ``score_function_with_energy_guidance``, ``score_matching_loss``, ``velocity_function``, ``flow_matching_loss``, \ ``energy_guidance_loss`` """
def __init__(
    self,
    config: EasyDict,
    energy_model: Union[torch.nn.Module, torch.nn.ModuleDict, Callable],
) -> None:
    """
    Overview:
        Initialization of Energy Conditional Diffusion Model. Builds the probability path(s),
        the intrinsic model, the diffusion process(es), the derived model functions, the
        energy guidance network and, if configured, the solver.

    Arguments:
        config (:obj:`EasyDict`): The configuration. The fields ``x_size``, ``device``, ``alpha``, \
            ``path``, ``model.type``, ``model.args`` and ``energy_guidance`` are read; \
            ``reverse_path`` and ``solver`` are used only when present.
        energy_model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The energy model.
    """
    super().__init__()

    # Plain configuration values.
    self.config = config
    self.x_size = config.x_size
    self.device = config.device
    self.alpha = config.alpha

    self.gaussian_generator = gaussian_random_variable(config.x_size, config.device)

    # Forward path is mandatory; the reverse path exists only if configured.
    self.path = GaussianConditionalProbabilityPath(config.path)
    self.reverse_path = (
        GaussianConditionalProbabilityPath(config.reverse_path)
        if hasattr(config, "reverse_path")
        else None
    )

    # Validate the kind of quantity the intrinsic model predicts before building it.
    model_type = config.model.type
    self.model_type = model_type
    supported_model_types = (
        "score_function",
        "data_prediction_function",
        "velocity_function",
        "noise_function",
    )
    assert (
        self.model_type in supported_model_types
    ), "Unknown type of model {}".format(self.model_type)
    self.model = IntrinsicModel(config.model.args)

    forward_process = DiffusionProcess(self.path)
    self.diffusion_process = forward_process
    self.reverse_diffusion_process = (
        DiffusionProcess(self.reverse_path) if self.reverse_path is not None else None
    )

    # Views of the same intrinsic model expressed as score / velocity / noise / data prediction.
    self.score_function_ = ScoreFunction(model_type, forward_process)
    self.velocity_function_ = VelocityFunction(model_type, forward_process)
    self.noise_function_ = NoiseFunction(model_type, forward_process)
    self.data_prediction_function_ = DataPredictionFunction(model_type, forward_process)

    self.energy_model = energy_model
    self.energy_guidance = EnergyGuidance(self.config.energy_guidance)

    # ``self.solver`` is intentionally left undefined when no solver is configured.
    if hasattr(config, "solver"):
        solver_cls = get_solver(config.solver.type)
        self.solver = solver_cls(**config.solver.args)
def get_type(self):
    """
    Overview:
        Return the name identifying this kind of generative model.

    Returns:
        type_name (:obj:`str`): Always ``"EnergyConditionalDiffusionModel"``.
    """

    type_name = "EnergyConditionalDiffusionModel"
    return type_name
def sample(
    self,
    t_span: torch.Tensor = None,
    batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None,
    x_0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
    condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
    guidance_scale: float = 1.0,
    with_grad: bool = False,
    solver_config: EasyDict = None,
):
    r"""
    Overview:
        Draw samples from the energy conditioned diffusion model using the guided score function, \
        and keep only the last state of the sampling trajectory.

        .. math::
            \nabla p_{\text{E}}(x|c) = \nabla p(x|c) + \nabla \mathcal{E}(x,c,t)

    Arguments:
        t_span (:obj:`torch.Tensor`): The time span.
        batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`): The batch size.
        x_0 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The initial state, if not provided, \
            it will be sampled from the Gaussian distribution.
        condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
        guidance_scale (:obj:`float`): The scale of guidance.
        with_grad (:obj:`bool`): Whether to return the gradient.
        solver_config (:obj:`EasyDict`): The configuration of the solver.

    Returns:
        x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The sampled result, i.e. the last \
            element of what ``sample_forward_process`` returns.

    Shapes:
        t_span: :math:`(T)`, where :math:`T` is the number of time steps.
        batch_size: :math:`(B)`, where :math:`B` is the batch size of data, which could be a scalar or a tensor such as :math:`(B1, B2)`.
        x_0: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the state, which could be a scalar or a tensor such as :math:`(D1, D2)`.
        condition: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the condition, which could be a scalar or a tensor such as :math:`(D1, D2)`.
        x: :math:`(N, D)`, if extra batch size :math:`B` is provided, the shape will be :math:`(B, N, D)`. If x_0 is not provided, the shape will be :math:`(B, D)`. If x_0 and condition are not provided, the shape will be :math:`(D)`.
    """

    # All the work happens in ``sample_forward_process``; this method only selects the end point.
    trajectory = self.sample_forward_process(
        t_span=t_span,
        batch_size=batch_size,
        x_0=x_0,
        condition=condition,
        guidance_scale=guidance_scale,
        with_grad=with_grad,
        solver_config=solver_config,
    )
    final_state = trajectory[-1]
    return final_state
def sample_without_energy_guidance(
    self,
    t_span: torch.Tensor = None,
    batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None,
    x_0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
    condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
    with_grad: bool = False,
    solver_config: EasyDict = None,
):
    """
    Overview:
        Sample from the diffusion model without energy guidance. This is ``sample`` called with \
        ``guidance_scale`` fixed to ``0.0``.

    Arguments:
        t_span (:obj:`torch.Tensor`): The time span.
        batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`): The batch size.
        x_0 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The initial state, if not provided, \
            it will be sampled from the Gaussian distribution.
        condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
        with_grad (:obj:`bool`): Whether to return the gradient.
        solver_config (:obj:`EasyDict`): The configuration of the solver.

    Returns:
        x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The sampled result, exactly what \
            ``sample`` returns (the final state, not the whole trajectory).

    Shapes:
        t_span: :math:`(T)`, where :math:`T` is the number of time steps.
        batch_size: :math:`(B)`, where :math:`B` is the batch size of data, which could be a scalar or a tensor such as :math:`(B1, B2)`.
        x_0: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the state, which could be a scalar or a tensor such as :math:`(D1, D2)`.
        condition: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the condition, which could be a scalar or a tensor such as :math:`(D1, D2)`.
        x: :math:`(N, D)`, if extra batch size :math:`B` is provided, the shape will be :math:`(B, N, D)`. If x_0 is not provided, the shape will be :math:`(B, D)`. If x_0 and condition are not provided, the shape will be :math:`(D)`.
    """

    # A zero guidance scale is how the energy term is switched off.
    unguided_kwargs = dict(
        t_span=t_span,
        batch_size=batch_size,
        x_0=x_0,
        condition=condition,
        with_grad=with_grad,
        guidance_scale=0.0,
        solver_config=solver_config,
    )
    return self.sample(**unguided_kwargs)
def sample_forward_process(
    self,
    t_span: torch.Tensor = None,
    batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None,
    x_0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
    condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
    guidance_scale: float = 1.0,
    with_grad: bool = False,
    solver_config: EasyDict = None,
):
    """
    Overview:
        Sample from the energy conditioned diffusion model and return every state produced by the solver. \
        The initial state and the condition are repeated along the batch axis for the extra batch size, \
        integrated as one flat batch, and the result is reshaped back to separate batch axes.

    Arguments:
        t_span (:obj:`torch.Tensor`): The time span. Must be ``None`` when a ``DPMSolver`` is used.
        batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`): The extra batch size.
        x_0 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The initial state, if not provided, \
            it will be sampled from the Gaussian distribution.
        condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
        guidance_scale (:obj:`float`): The scale of guidance.
        with_grad (:obj:`bool`): Whether to return the gradient. When ``False`` the solver runs under \
            ``torch.no_grad()``.
        solver_config (:obj:`EasyDict`): The configuration of the solver. When ``None``, the solver built \
            from the model config is used.

    Returns:
        x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The sampled result.

    Shapes:
        t_span: :math:`(T)`, where :math:`T` is the number of time steps.
        batch_size: :math:`(B)`, where :math:`B` is the batch size of data, which could be a scalar or a tensor such as :math:`(B1, B2)`.
        x_0: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the state, which could be a scalar or a tensor such as :math:`(D1, D2)`.
        condition: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the condition, which could be a scalar or a tensor such as :math:`(D1, D2)`.
        x: :math:`(T, N, D)`, if extra batch size :math:`B` is provided, the shape will be :math:`(T, B, N, D)`. If x_0 is not provided, the shape will be :math:`(T, B, D)`. If x_0 and condition are not provided, the shape will be :math:`(T, D)`.

    Raises:
        AssertionError: On an invalid ``batch_size`` or ``x_size``, mismatching batch sizes or ``x_0`` shape, \
            a missing solver, a ``t_span`` passed to ``DPMSolver``, or a missing reverse path for ``SDESolver``.
        NotImplementedError: On an unsupported condition type, solver type or solver output type.
    """

    def configured_x_shape():
        # Normalise ``self.x_size`` (an int or a sequence of ints) to a tuple.
        if isinstance(self.x_size, int):
            return (self.x_size,)
        if isinstance(self.x_size, (tuple, list, torch.Size)):
            return tuple(self.x_size)
        assert False, "Invalid x_size"

    def run_with_grad_mode(integrate):
        # Evaluate ``integrate`` (argument construction included) with or without autograd.
        if with_grad:
            return integrate()
        with torch.no_grad():
            return integrate()

    if t_span is not None:
        t_span = t_span.to(self.device)

    # ``extra_batch_size`` is a 1-D tensor holding the extra batch dimensions (B1, B2, ...).
    if batch_size is None:
        extra_batch_size = torch.tensor((1,), device=self.device)
    elif isinstance(batch_size, int):
        extra_batch_size = torch.tensor((batch_size,), device=self.device)
    elif isinstance(batch_size, (torch.Size, tuple, list)):
        extra_batch_size = torch.tensor(batch_size, device=self.device)
    else:
        assert False, "Invalid batch size"
    num_repeats = torch.prod(extra_batch_size)
    # Plain Python ints, used for reshaping and for counting the extra batch axes.
    extra_batch_dims = extra_batch_size.reshape(-1).tolist()

    if x_0 is not None and condition is not None:
        assert (
            x_0.shape[0] == condition.shape[0]
        ), "The batch size of x_0 and condition must be the same"
        data_batch_size = x_0.shape[0]
    elif x_0 is not None:
        data_batch_size = x_0.shape[0]
    elif condition is not None:
        data_batch_size = condition.shape[0]
    else:
        data_batch_size = 1

    if solver_config is not None:
        solver = get_solver(solver_config.type)(**solver_config.args)
    else:
        assert hasattr(
            self, "solver"
        ), "solver must be specified in config or solver_config"
        solver = self.solver

    if x_0 is None:
        x = self.gaussian_generator(batch_size=num_repeats * data_batch_size)
    else:
        assert (
            torch.Size(configured_x_shape()) == x_0[0].shape
        ), "The shape of x_0 must be the same as the x_size that is specified in the config"
        x = torch.repeat_interleave(x_0, num_repeats, dim=0)
    # x.shape = (B*N, D)

    if condition is not None:
        if isinstance(condition, torch.Tensor):
            condition = torch.repeat_interleave(condition, num_repeats, dim=0)
        elif isinstance(condition, treetensor.torch.Tensor):
            # NOTE(review): the repeated tensors are written back into the object passed by the caller.
            for key in condition.keys():
                condition[key] = torch.repeat_interleave(
                    condition[key], num_repeats, dim=0
                )
        elif isinstance(condition, TensorDict):
            condition = TensorDict(
                {
                    key: torch.repeat_interleave(condition[key], num_repeats, dim=0)
                    for key in condition.keys()
                },
                batch_size=num_repeats * condition.shape,
                device=condition.device,
            )
        else:
            raise NotImplementedError("Not implemented")
        # condition.shape = (B*N, D)

    def guided_score(t, x, condition):
        # For the ODE solvers the shape of t is (B*N,).
        return self.score_function_with_energy_guidance(
            t, x, condition, guidance_scale
        )

    def guided_score_for_sde(t, x, condition):
        # For the SDE solver the shape of t is (,), so it is expanded to one value per sample.
        if len(t.shape) == 0:
            t = t.repeat(x.shape[0])
        return self.score_function_with_energy_guidance(
            t, x, condition, guidance_scale
        )

    def reverse_ode_drift():
        return self.diffusion_process.reverse_ode(
            function=guided_score,
            function_type="score_function",
            condition=condition,
        ).drift

    # TODO: make it compatible with TensorDict
    if isinstance(solver, DPMSolver):

        def guided_noise(t, x, condition):
            return self.noise_function_with_energy_guidance(
                t, x, condition, guidance_scale
            )

        def guided_data_prediction(t, x, condition):
            return self.data_prediction_function_with_energy_guidance(
                t, x, condition, guidance_scale
            )

        # Note: DPMSolver does not support t_span argument assignment
        assert (
            t_span is None
        ), "DPMSolver does not support t_span argument assignment"

        data = run_with_grad_mode(
            lambda: solver.integrate(
                diffusion_process=self.diffusion_process,
                noise_function=guided_noise,
                data_prediction_function=guided_data_prediction,
                x=x,
                condition=condition,
                save_intermediate=True,
            )
        )
    elif isinstance(solver, ODESolver):
        data = run_with_grad_mode(
            lambda: solver.integrate(
                drift=reverse_ode_drift(),
                x0=x,
                t_span=t_span,
            )
        )
    elif isinstance(solver, DictTensorODESolver):
        data = run_with_grad_mode(
            lambda: solver.integrate(
                drift=reverse_ode_drift(),
                x0=x,
                t_span=t_span,
                batch_size=num_repeats * data_batch_size,
                x_size=x.shape,
            )
        )
    elif isinstance(solver, SDESolver):
        # TODO: validate the implementation
        assert (
            self.reverse_diffusion_process is not None
        ), "reverse_path must be specified in config"
        sde = self.diffusion_process.reverse_sde(
            function=guided_score_for_sde,
            function_type="score_function",
            condition=condition,
            reverse_diffusion_function=self.reverse_diffusion_process.diffusion,
            reverse_diffusion_squared_function=self.reverse_diffusion_process.diffusion_squared,
        )
        data = run_with_grad_mode(
            lambda: solver.integrate(
                drift=sde.drift,
                diffusion=sde.diffusion,
                x0=x,
                t_span=t_span,
            )
        )
    else:
        # Report the solver actually in use: it may come from ``solver_config``, in which case
        # ``self.config.solver`` is not guaranteed to exist.
        raise NotImplementedError(
            "Solver type {} is not implemented".format(type(solver).__name__)
        )

    def restore_batch_axes(flat):
        # flat.shape = (T, B*N, D) -> (T, B1, ..., Bk, N, D)
        shaped = flat.reshape(
            -1, *extra_batch_dims, data_batch_size, *configured_x_shape()
        )
        no_data_batch = x_0 is None and condition is None
        if batch_size is None:
            if no_data_batch:
                # Drop both the implicit B=1 axis and the implicit N=1 axis: (T, D)
                return shaped.squeeze(1).squeeze(1)
            # Drop the implicit B=1 axis: (T, N, D)
            return shaped.squeeze(1)
        if no_data_batch:
            # The implicit N=1 axis sits right after the time axis and all k extra batch axes,
            # so its index is 1 + k (not 1 + the rank of the size tensor, which is always 1).
            return shaped.squeeze(1 + len(extra_batch_dims))
        # (T, B, N, D)
        return shaped

    if isinstance(data, torch.Tensor):
        data = restore_batch_axes(data)
    elif isinstance(data, TensorDict):
        raise NotImplementedError("TensorDict is not supported yet")
    elif isinstance(data, treetensor.torch.Tensor):
        for key in data.keys():
            data[key] = restore_batch_axes(data[key])
    else:
        raise NotImplementedError("Unknown data type")

    return data
def sample_with_fixed_x(
    self,
    fixed_x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
    fixed_mask: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
    t_span: torch.Tensor = None,
    batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None,
    x_0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
    condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
    guidance_scale: float = 1.0,
    with_grad: bool = False,
    solver_config: EasyDict = None,
):
    """
    Overview:
        Sample from the diffusion model with fixed x, keeping only the last state of the trajectory \
        computed by ``sample_forward_process_with_fixed_x``.

    Arguments:
        fixed_x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The fixed x.
        fixed_mask (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The fixed mask.
        t_span (:obj:`torch.Tensor`): The time span.
        batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`): The batch size.
        x_0 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The initial state, if not provided, \
            it will be sampled from the Gaussian distribution.
        condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
        guidance_scale (:obj:`float`): The scale of guidance.
        with_grad (:obj:`bool`): Whether to return the gradient.
        solver_config (:obj:`EasyDict`): The configuration of the solver.

    Returns:
        x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The sampled result.

    Shapes:
        fixed_x: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the state, which could be a scalar or a tensor such as :math:`(D1, D2)`.
        fixed_mask: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the mask, which could be a scalar or a tensor such as :math:`(D1, D2)`.
        t_span: :math:`(T)`, where :math:`T` is the number of time steps.
        batch_size: :math:`(B)`, where :math:`B` is the batch size of data, which could be a scalar or a tensor such as :math:`(B1, B2)`.
        x_0: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the state, which could be a scalar or a tensor such as :math:`(D1, D2)`.
        condition: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the condition, which could be a scalar or a tensor such as :math:`(D1, D2)`.
        x: :math:`(N, D)`, if extra batch size :math:`B` is provided, the shape will be :math:`(B, N, D)`. If x_0 is not provided, the shape will be :math:`(B, D)`. If x_0 and condition are not provided, the shape will be :math:`(D,)`.
    """

    trajectory = self.sample_forward_process_with_fixed_x(
        fixed_x=fixed_x,
        fixed_mask=fixed_mask,
        t_span=t_span,
        batch_size=batch_size,
        x_0=x_0,
        condition=condition,
        guidance_scale=guidance_scale,
        with_grad=with_grad,
        solver_config=solver_config,
    )
    # Only the end point of the trajectory is returned.
    final_state = trajectory[-1]
    return final_state
def sample_forward_process_with_fixed_x(
    self,
    fixed_x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
    fixed_mask: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
    t_span: torch.Tensor = None,
    batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None,
    x_0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
    condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
    guidance_scale: float = 1.0,
    with_grad: bool = False,
    solver_config: EasyDict = None,
):
    """
    Overview:
        Sample from the diffusion model with fixed x and return every state produced by the solver. \
        Entries where ``fixed_mask`` is 0 are taken from ``fixed_x`` and receive zero drift (and zero \
        diffusion for the SDE solver); entries where ``fixed_mask`` is 1 are generated.

    Arguments:
        fixed_x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The fixed x.
        fixed_mask (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The fixed mask.
        t_span (:obj:`torch.Tensor`): The time span.
        batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`): The batch size.
        x_0 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The initial state, if not provided, \
            it will be sampled from the Gaussian distribution.
        condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
        guidance_scale (:obj:`float`): The scale of guidance.
        with_grad (:obj:`bool`): Whether to return the gradient.
        solver_config (:obj:`EasyDict`): The configuration of the solver.

    Returns:
        x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The sampled result.

    Shapes:
        fixed_x: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the state, which could be a scalar or a tensor such as :math:`(D1, D2)`.
        fixed_mask: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the mask, which could be a scalar or a tensor such as :math:`(D1, D2)`.
        t_span: :math:`(T)`, where :math:`T` is the number of time steps.
        batch_size: :math:`(B)`, where :math:`B` is the batch size of data, which could be a scalar or a tensor such as :math:`(B1, B2)`.
        x_0: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the state, which could be a scalar or a tensor such as :math:`(D1, D2)`.
        condition: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the condition, which could be a scalar or a tensor such as :math:`(D1, D2)`.
        x: :math:`(T, N, D)`, if extra batch size :math:`B` is provided, the shape will be :math:`(T, B, N, D)`.
    """

    def run_with_grad_mode(integrate):
        # Evaluate ``integrate`` (argument construction included) with or without autograd.
        if with_grad:
            return integrate()
        with torch.no_grad():
            return integrate()

    if t_span is not None:
        t_span = t_span.to(self.device)

    if batch_size is None:
        extra_batch_size = torch.tensor((1,), device=self.device)
    elif isinstance(batch_size, int):
        extra_batch_size = torch.tensor((batch_size,), device=self.device)
    elif isinstance(batch_size, (torch.Size, tuple, list)):
        extra_batch_size = torch.tensor(batch_size, device=self.device)
    else:
        assert False, "Invalid batch size"

    # The data batch size always comes from fixed_x; every other input has to agree with it.
    data_batch_size = fixed_x.shape[0]
    assert (
        fixed_x.shape[0] == fixed_mask.shape[0]
    ), "The batch size of fixed_x and fixed_mask must be the same"
    if x_0 is not None:
        if condition is not None:
            assert (
                x_0.shape[0] == condition.shape[0]
            ), "The batch size of x_0 and condition must be the same"
        assert (
            x_0.shape[0] == fixed_x.shape[0]
        ), "The batch size of x_0 and fixed_x must be the same"
    elif condition is not None:
        assert (
            condition.shape[0] == fixed_x.shape[0]
        ), "The batch size of condition and fixed_x must be the same"

    if solver_config is None:
        assert hasattr(
            self, "solver"
        ), "solver must be specified in config or solver_config"
        solver = self.solver
    else:
        solver = get_solver(solver_config.type)(**solver_config.args)

    if x_0 is None:
        x = self.gaussian_generator(
            batch_size=torch.prod(extra_batch_size) * data_batch_size
        )
    else:
        if isinstance(self.x_size, int):
            expected_shape = torch.Size([self.x_size])
        elif isinstance(self.x_size, (tuple, list, torch.Size)):
            expected_shape = torch.Size(self.x_size)
        else:
            assert False, "Invalid x_size"
        assert (
            expected_shape == x_0[0].shape
        ), "The shape of x_0 must be the same as the x_size that is specified in the config"
        x = torch.repeat_interleave(x_0, torch.prod(extra_batch_size), dim=0)
    # x.shape = (B*N, D)

    if condition is not None:
        condition = torch.repeat_interleave(
            condition, torch.prod(extra_batch_size), dim=0
        )
        # condition.shape = (B*N, D)

    fixed_x = torch.repeat_interleave(fixed_x, torch.prod(extra_batch_size), dim=0)
    fixed_mask = torch.repeat_interleave(
        fixed_mask, torch.prod(extra_batch_size), dim=0
    )

    def blend_with_fixed(t, free_part):
        # Noised fixed entries (mask == 0) combined with the current free entries (mask == 1).
        noised_fixed = self.diffusion_process.direct_sample(
            self.diffusion_process.t_max - t, fixed_x
        )
        return noised_fixed * (1 - fixed_mask) + free_part * fixed_mask

    if isinstance(solver, DPMSolver):
        # TODO: make it compatible with DPM solver
        assert False, "Not implemented"
    elif isinstance(solver, (ODESolver, DictTensorODESolver)):

        def guided_score(t, x, condition):
            # for SDE solver, the shape of t is (,) while for ODE solver, the shape of t is (B*N,)
            return self.score_function_with_energy_guidance(
                t, x, condition, guidance_scale
            )

        # TODO: make it compatible with TensorDict
        x = fixed_x * (1 - fixed_mask) + x * fixed_mask

        def drift_fixed_x(t, x):
            ode = self.diffusion_process.reverse_ode(
                function=guided_score,
                function_type="score_function",
                condition=condition,
            )
            return fixed_mask * ode.drift(t, blend_with_fixed(t, x))

        if isinstance(solver, ODESolver):
            # Checked first, exactly as a standalone ``ODESolver`` branch would be.
            data = run_with_grad_mode(
                lambda: solver.integrate(
                    drift=drift_fixed_x,
                    x0=x,
                    t_span=t_span,
                )
            )
        else:
            data = run_with_grad_mode(
                lambda: solver.integrate(
                    drift=drift_fixed_x,
                    x0=x,
                    t_span=t_span,
                    batch_size=torch.prod(extra_batch_size) * data_batch_size,
                    x_size=x.shape,
                )
            )
    elif isinstance(solver, SDESolver):

        def guided_score(t, x, condition):
            # for SDE solver, the shape of t is (,) while for ODE solver, the shape of t is (B*N,)
            if len(t.shape) == 0:
                t = t.repeat(x.shape[0])
            return self.score_function_with_energy_guidance(
                t, x, condition, guidance_scale
            )

        # TODO: make it compatible with TensorDict
        # TODO: validate the implementation
        assert (
            self.reverse_diffusion_process is not None
        ), "reverse_path must be specified in config"

        x = fixed_x * (1 - fixed_mask) + x * fixed_mask

        sde = self.diffusion_process.reverse_sde(
            function=guided_score,
            function_type="score_function",
            condition=condition,
            reverse_diffusion_function=self.reverse_diffusion_process.diffusion,
            reverse_diffusion_squared_function=self.reverse_diffusion_process.diffusion_squared,
        )

        def drift_fixed_x(t, x):
            return fixed_mask * sde.drift(t, blend_with_fixed(t, x))

        def diffusion_fixed_x(t, x):
            return fixed_mask * sde.diffusion(t, blend_with_fixed(t, x))

        data = run_with_grad_mode(
            lambda: solver.integrate(
                drift=drift_fixed_x,
                diffusion=diffusion_fixed_x,
                x0=x,
                t_span=t_span,
            )
        )
    else:
        raise NotImplementedError(
            "Solver type {} is not implemented".format(self.config.solver.type)
        )

    def restore_batch_axes(flat):
        # flat.shape = (T, B*N, D) -> (T, B, N, D)
        if len(extra_batch_size.shape) == 0:
            leading = (extra_batch_size,)
        else:
            leading = tuple(extra_batch_size)
        if isinstance(self.x_size, int):
            trailing = (self.x_size,)
        elif isinstance(self.x_size, (tuple, list, torch.Size)):
            trailing = tuple(self.x_size)
        else:
            assert False, "Invalid x_size"
        shaped = flat.reshape(-1, *leading, data_batch_size, *trailing)
        if batch_size is None:
            # Drop the implicit B=1 axis: (T, N, D)
            return shaped.squeeze(1)
        return shaped

    if isinstance(data, torch.Tensor):
        data = restore_batch_axes(data)
    elif isinstance(data, TensorDict):
        raise NotImplementedError("TensorDict is not supported yet")
    elif isinstance(data, treetensor.torch.Tensor):
        for key in data.keys():
            data[key] = restore_batch_axes(data[key])
    else:
        raise NotImplementedError("Unknown data type")

    return data
def sample_with_fixed_x_without_energy_guidance(
    self,
    fixed_x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
    fixed_mask: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
    t_span: torch.Tensor = None,
    batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None,
    x_0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
    condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
    with_grad: bool = False,
    solver_config: EasyDict = None,
):
    """
    Overview:
        Sample from the diffusion model with fixed x without energy guidance. This is \
        ``sample_with_fixed_x`` called with ``guidance_scale`` fixed to ``0.0``.

    Arguments:
        fixed_x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The fixed x.
        fixed_mask (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The fixed mask.
        t_span (:obj:`torch.Tensor`): The time span.
        batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`): The batch size.
        x_0 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The initial state, if not provided, \
            it will be sampled from the Gaussian distribution.
        condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
        with_grad (:obj:`bool`): Whether to return the gradient.
        solver_config (:obj:`EasyDict`): The configuration of the solver.

    Returns:
        x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The sampled result.

    Shapes:
        fixed_x: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the state, which could be a scalar or a tensor such as :math:`(D1, D2)`.
        fixed_mask: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the mask, which could be a scalar or a tensor such as :math:`(D1, D2)`.
        t_span: :math:`(T)`, where :math:`T` is the number of time steps.
        batch_size: :math:`(B)`, where :math:`B` is the batch size of data, which could be a scalar or a tensor such as :math:`(B1, B2)`.
        x_0: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the state, which could be a scalar or a tensor such as :math:`(D1, D2)`.
        condition: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the condition, which could be a scalar or a tensor such as :math:`(D1, D2)`.
        x: :math:`(N, D)`, if extra batch size :math:`B` is provided, the shape will be :math:`(B, N, D)`. If x_0 is not provided, the shape will be :math:`(B, D)`. If x_0 and condition are not provided, the shape will be :math:`(D,)`.
    """

    # A zero guidance scale is how the energy term is switched off.
    unguided_kwargs = dict(
        fixed_x=fixed_x,
        fixed_mask=fixed_mask,
        t_span=t_span,
        batch_size=batch_size,
        x_0=x_0,
        condition=condition,
        guidance_scale=0.0,
        with_grad=with_grad,
        solver_config=solver_config,
    )
    return self.sample_with_fixed_x(**unguided_kwargs)
def score_function(
    self,
    t: torch.Tensor,
    x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
    condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
) -> Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]:
    r"""
    Overview:
        Return score function of the model at time t given the initial state, which is the gradient of the log-likelihood. \
        No energy guidance is added here.

        .. math::
            \nabla_{x_t} \log p_{\theta}(x_t)

    Arguments:
        t (:obj:`torch.Tensor`): The input time.
        x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input state.
        condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.

    Returns:
        score (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The output of ``self.score_function_.forward``.
    """

    score = self.score_function_.forward(self.model, t, x, condition)
    return score
def score_function_with_energy_guidance(
    self,
    t: torch.Tensor,
    x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
    condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
    guidance_scale: float = 1.0,
) -> Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]:
    """
    Overview:
        The score function for energy guidance: the score of the diffusion model plus the energy \
        guidance term computed by ``self.energy_guidance``.

    Arguments:
        t (:obj:`torch.Tensor`): The input time.
        x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input state.
        condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
        guidance_scale (:obj:`float`): The scale of guidance.

    Returns:
        score (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The score function.
    """

    # The model score is evaluated first, then the guidance term.
    model_score = self.score_function_.forward(self.model, t, x, condition)
    guidance = self.energy_guidance.calculate_energy_guidance(
        t, x, condition, guidance_scale
    )
    return model_score + guidance
def score_matching_loss(
    self,
    x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
    condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
    weighting_scheme: str = None,
    average: bool = True,
) -> torch.Tensor:
    r"""
    Overview:
        The score matching loss used to train the diffusion model ``self.model``. The computation is
        delegated to ``self.score_function_.score_matching_loss``.
    Arguments:
        x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input.
        condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
        weighting_scheme (:obj:`str`): The weighting scheme for score matching loss, which can be \
            "maximum_likelihood" or "vanilla".
        average (:obj:`bool`): Whether to average the loss across the batch.
    Returns:
        loss (:obj:`torch.Tensor`): The score matching loss.

    .. note::
        - "maximum_likelihood": The weighting scheme is based on the maximum likelihood estimation.
          Refer to the paper "Maximum Likelihood Training of Score-Based Diffusion Models" for more details.
          The weight :math:`\lambda(t)` is denoted as:

            .. math::
                \lambda(t) = g^2(t)

          for numerical stability, we use Monte Carlo sampling to approximate the integral of :math:`\lambda(t)`.

            .. math::
                \lambda(t) = g^2(t) = p(t)\sigma^2(t)

        - "vanilla": The weighting scheme is based on the vanilla score matching, which balances the MSE loss
          by scaling the model output to the noise value. Refer to the paper "Score-Based Generative Modeling
          through Stochastic Differential Equations" for more details. The weight :math:`\lambda(t)` is
          denoted as:

            .. math::
                \lambda(t) = \sigma^2(t)
    """
    # NOTE: the docstring above must stay a raw string; it contains LaTeX backslashes
    # (``\lambda``, ``\sigma``) that are invalid escape sequences in a normal string literal.
    return self.score_function_.score_matching_loss(
        self.model, x, condition, self.gaussian_generator, weighting_scheme, average
    )
def velocity_function(
    self,
    t: torch.Tensor,
    x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
    condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
) -> Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]:
    r"""
    Overview:
        Evaluate the velocity field of the (unguided) diffusion model at time ``t`` and state ``x``.

        .. math::
            v_{\theta}(t, x)

    Arguments:
        t (:obj:`torch.Tensor`): The input time.
        x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input state at time t.
        condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
    Returns:
        velocity (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The value produced by \
            ``self.velocity_function_.forward`` for ``self.model``.
    """
    # Pure delegation to the wrapped velocity function.
    velocity_fn = self.velocity_function_
    return velocity_fn.forward(self.model, t, x, condition)
def flow_matching_loss(
    self,
    x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
    condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
    average: bool = True,
) -> torch.Tensor:
    """
    Overview:
        Flow matching loss of the diffusion model for the given data and condition. The computation is
        delegated to ``self.velocity_function_.flow_matching_loss``.
    Arguments:
        x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input state.
        condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
        average (:obj:`bool`): Whether to average the loss across the batch.
    Returns:
        loss (:obj:`torch.Tensor`): The flow matching loss.
    """
    # The noise source is fixed to this model's own Gaussian generator.
    velocity_fn = self.velocity_function_
    loss = velocity_fn.flow_matching_loss(
        self.model, x, condition, self.gaussian_generator, average
    )
    return loss
def noise_function(
    self,
    t: torch.Tensor,
    x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
    condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
) -> Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]:
    r"""
    Overview:
        Evaluate the noise prediction of the (unguided) diffusion model at time ``t`` and state ``x``.

        .. math::
            - \sigma(t) \nabla_{x_t} \log p_{\theta}(x_t)

    Arguments:
        t (:obj:`torch.Tensor`): The input time.
        x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input state.
        condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
    Returns:
        noise (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The value produced by \
            ``self.noise_function_.forward`` for ``self.model``.
    """
    # Pure delegation to the wrapped noise function.
    noise_fn = self.noise_function_
    return noise_fn.forward(self.model, t, x, condition)
def noise_function_with_energy_guidance(
    self,
    t: torch.Tensor,
    x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
    condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
    guidance_scale: float = 1.0,
) -> Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]:
    """
    Overview:
        Noise prediction with energy guidance, obtained from the guided score as
        ``-score * std(t, x)``.
    Arguments:
        t (:obj:`torch.Tensor`): The input time.
        x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input state.
        condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
        guidance_scale (:obj:`float`): The scale of guidance.
    Returns:
        noise (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The noise function.
    """
    # Guided score first, then the standard deviation of the diffusion process at (t, x).
    guided_score = self.score_function_with_energy_guidance(
        t, x, condition, guidance_scale
    )
    std = self.diffusion_process.std(t, x)
    return -guided_score * std
def data_prediction_function(
    self,
    t: torch.Tensor,
    x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
    condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
) -> Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]:
    r"""
    Overview:
        Evaluate the data prediction of the (unguided) diffusion model at time ``t`` and state ``x``.

        .. math::
            \frac{- \sigma(t) x_t + \sigma^2(t) \nabla_{x_t} \log p_{\theta}(x_t)}{s(t)}

    Arguments:
        t (:obj:`torch.Tensor`): The input time.
        x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input state.
        condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
    Returns:
        x_0 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The value produced by \
            ``self.data_prediction_function_.forward`` for ``self.model``.
    """
    # Pure delegation to the wrapped data prediction function.
    data_prediction_fn = self.data_prediction_function_
    return data_prediction_fn.forward(self.model, t, x, condition)
def data_prediction_function_with_energy_guidance(
    self,
    t: torch.Tensor,
    x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
    condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
    guidance_scale: float = 1.0,
) -> Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]:
    """
    Overview:
        Data prediction with energy guidance, computed from the guided score as
        ``(-std * x + covariance * score) / scale``, where ``std``, ``covariance`` and ``scale``
        are queried from ``self.diffusion_process`` at ``(t, x)``.
    Arguments:
        t (:obj:`torch.Tensor`): The input time.
        x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input state.
        condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
        guidance_scale (:obj:`float`): The scale of guidance.
    Returns:
        x (:obj:`torch.Tensor`): The data prediction with energy guidance.
    """
    # NOTE(review): the ``-std * x`` term is kept exactly as implemented. The textbook x_0 estimate
    # is ``(x_t + sigma^2 * score) / s`` -- confirm against the diffusion process conventions.
    process = self.diffusion_process
    state_term = -process.std(t, x) * x
    covariance = process.covariance(t, x)
    guided_score = self.score_function_with_energy_guidance(
        t, x, condition, guidance_scale
    )
    numerator = state_term + covariance * guided_score
    return numerator / process.scale(t, x)
def energy_guidance_loss(
    self,
    x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
    condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
):
    """
    Overview:
        The loss function for training Energy Guidance with the CEP guidance method, as proposed in the paper \
        "Contrastive Energy Prediction for Exact Energy-Guided Diffusion Sampling in Offline Reinforcement Learning".
        For every batch element, the softmax of ``self.alpha * energy`` over the candidate samples
        (computed on ``x`` with ``self.energy_model`` and detached) is used as the soft target for the
        log-softmax of the ``self.energy_guidance`` output on the diffused samples; the cross-entropy is
        averaged over the batch.
    Arguments:
        x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input. It is indexed as \
            ``(batch, candidates, ...)``; the softmax runs over the second axis.
        condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition, \
            repeated once per candidate before being passed to the models.
    Returns:
        loss (:obj:`torch.Tensor`): The scalar contrastive energy prediction loss.
    """
    # TODO: check math correctness
    # TODO: make it compatible with TensorDict
    # TODO: check eps = 1e-3
    eps = 1e-3

    def _repeat_rows(value, repeats):
        # (B, ...) -> (B * repeats, ...): every row is repeated ``repeats`` times consecutively.
        stacked = torch.stack([value] * repeats, dim=1)
        return stacked.reshape(
            stacked.shape[0] * stacked.shape[1], *stacked.shape[2:]
        )

    def _flatten_candidates(value):
        # (B, K, ...) -> (B * K, ...)
        return value.reshape(value.shape[0] * value.shape[1], *value.shape[2:])

    def _expand_condition(repeats, tensors_only):
        # Repeat the condition so that it lines up with the flattened candidates.
        if isinstance(condition, TensorDict):
            expanded = TensorDict({}, batch_size=[x.shape[0] * x.shape[1]]).to(
                x.device
            )
            for key, value in condition.items():
                if tensors_only and not isinstance(value, torch.Tensor):
                    continue
                expanded.set(key, _repeat_rows(value, repeats))
            return expanded
        return _repeat_rows(condition, repeats)

    # One diffusion time per batch element in [eps, 1), shared by all of its candidates.
    t_random = torch.rand((x.shape[0],), device=self.device) * (1.0 - eps) + eps
    t_random = torch.stack([t_random] * x.shape[1], dim=1)

    # Target energies on the clean samples; no gradient flows into the energy model.
    x_flat = _flatten_candidates(x)
    if condition is not None:
        # NOTE(review): only plain tensor entries of a TensorDict condition reach the energy model,
        # while the guidance model below receives every entry -- kept as in the original, confirm intent.
        energy = self.energy_model(
            x_flat, _expand_condition(x.shape[1], tensors_only=True)
        ).detach()
    else:
        energy = self.energy_model(x_flat).detach()
    # No squeeze here: with a single candidate it would drop the axis that the softmax runs over.
    energy = energy.reshape(x.shape[0], x.shape[1])

    # Guidance prediction on the diffused samples.
    x_t = self.diffusion_process.direct_sample(t_random, x, condition)
    x_t_flat = _flatten_candidates(x_t)
    t_flat = t_random.reshape(t_random.shape[0] * t_random.shape[1])
    if condition is not None:
        xt_energy_guidance = self.energy_guidance(
            t_flat, x_t_flat, _expand_condition(x_t.shape[1], tensors_only=False)
        )
    else:
        xt_energy_guidance = self.energy_guidance(t_flat, x_t_flat)
    xt_energy_guidance = xt_energy_guidance.reshape(x_t.shape[0], x_t.shape[1])

    # Cross-entropy between the target and the predicted distribution over candidates.
    log_xt_relative_energy = torch.log_softmax(xt_energy_guidance, dim=1)
    x0_relative_energy = torch.softmax(energy * self.alpha, dim=1)
    loss = -torch.mean(
        torch.sum(x0_relative_energy * log_xt_relative_energy, dim=-1)
    )
    return loss