# Source code for grl.numerical_methods.numerical_solvers.sde_solver
# (extracted from the rendered documentation page; navigation text removed)


from typing import Any, Dict, List, Tuple, Union

import torch
import torchsde
from torch import nn


class TorchSDE(nn.Module):
    """
    Overview:
        Adapter that packages a drift callable and a diffusion callable as an SDE object for the torchsde library,
        i.e. an object exposing the methods `f` (drift) and `g` (diffusion).
        `g` should return a single tensor of size (batch_size, d) for diagonal noise SDEs, or of size
        (batch_size, d, m) for SDEs of other noise types, where d is the dimensionality of the state and m is the
        dimensionality of the Brownian motion.
    """

    def __init__(
        self,
        drift,
        diffusion,
        noise_type,
        sde_type,
    ):
        """
        Overview:
            Store the drift / diffusion callables together with the SDE metadata.
        Arguments:
            drift (:obj:`nn.Module`): Callable invoked as ``drift(t, y)``; defines the drift of the SDE.
            diffusion (:obj:`nn.Module`): Callable invoked as ``diffusion(t, y)``; defines the diffusion of the SDE.
            noise_type (:obj:`str`): The noise type of the SDE: 'diagonal', 'general', 'scalar' or 'additive'.
            sde_type (:obj:`str`): The SDE type: 'ito' or 'stratonovich'.
        """
        super().__init__()

        # Plain string metadata describing the SDE.
        self.noise_type = noise_type
        self.sde_type = sde_type

        # Drift is assigned before diffusion so that, when both are modules, they register in that order.
        self.drift = drift
        self.diffusion = diffusion

    def f(self, t, y):
        """
        Overview:
            Evaluate the drift term at time ``t`` and state ``y``.
        Returns:
            Whatever ``self.drift(t, y)`` returns, unchanged.
        """
        drift_value = self.drift(t, y)
        return drift_value

    def g(self, t, y):
        """
        Overview:
            Evaluate the diffusion term at time ``t`` and state ``y``.
        Returns:
            Whatever ``self.diffusion(t, y)`` returns, unchanged.
        """
        diffusion_value = self.diffusion(t, y)
        return diffusion_value


class SDESolver:
    """
    Overview:
        Wrapper around ``torchsde.sdeint`` that stores the solver configuration and counts how many times the
        drift and diffusion functions are evaluated during an integration.
    """

    def __init__(
        self,
        sde_solver: str = "euler",
        sde_noise_type: str = "diagonal",
        sde_type: str = "ito",
        dt: float = 0.001,
        atol: float = 1e-5,
        rtol: float = 1e-5,
        library: str = "torchsde",
        **kwargs,
    ):
        """
        Overview:
            Initialize the SDE solver using torchsde library.
        Arguments:
            sde_solver (:obj:`str`): The SDE solver to use (forwarded to ``torchsde.sdeint`` as ``method``).
            sde_noise_type (:obj:`str`): The type of noise of the SDE. It can be 'diagonal', 'general', 'scalar' or 'additive'.
            sde_type (:obj:`str`): The type of the SDE. It can be 'ito' or 'stratonovich'.
            dt (:obj:`float`): The time step.
            atol (:obj:`float`): The absolute tolerance.
            rtol (:obj:`float`): The relative tolerance.
            library (:obj:`str`): The library to use for the SDE solver. Currently, it supports 'torchsde'.
            **kwargs: Additional arguments for the SDE solver; forwarded unchanged to ``torchsde.sdeint``.
        """
        super().__init__()
        self.sde_solver = sde_solver
        self.sde_noise_type = sde_noise_type
        self.sde_type = sde_type
        self.dt = dt
        self.atol = atol
        self.rtol = rtol
        # Number of function evaluations; both counters are reset at the start of every ``integrate`` call.
        self.nfe_drift = 0
        self.nfe_diffusion = 0
        self.kwargs = kwargs
        # Stored for reference only: ``integrate`` always calls ``torchsde.sdeint``.
        self.library = library

    def integrate(
        self,
        drift,
        diffusion,
        x0: torch.Tensor,
        t_span: torch.Tensor,
        logqp: bool = False,
        adaptive: bool = False,
    ) -> torch.Tensor:
        """
        Overview:
            Integrate the SDE from the initial state ``x0`` over the time points in ``t_span``.
        Arguments:
            drift (:obj:`nn.Module`): Callable ``drift(t, x)`` that defines the drift of the SDE. It receives ``x`` \
                with the same shape as ``x0``.
            diffusion (:obj:`nn.Module`): Callable ``diffusion(t, x)`` that defines the diffusion of the SDE. It \
                receives ``x`` with the same shape as ``x0``.
            x0 (:obj:`torch.Tensor`): The initial state, of shape (batch_size, *data_size).
            t_span (:obj:`torch.Tensor`): The time points at which the solution is returned; its first dimension \
                is the number of time points.
            logqp (:obj:`bool`): Forwarded to ``torchsde.sdeint``.
            adaptive (:obj:`bool`): Forwarded to ``torchsde.sdeint``.
        Returns:
            trajectory (:obj:`torch.Tensor`): The solution, of shape (t_span.shape[0], batch_size, *data_size).

        .. note::
            NOTE(review): the final reshape assumes ``torchsde.sdeint`` returns a single tensor. With ``logqp=True``
            torchsde is documented to return additional outputs alongside the states - confirm before enabling it.
        """
        batch_size = x0.shape[0]
        data_size = x0.shape[1:]

        self.nfe_drift = 0
        self.nfe_diffusion = 0

        # torchsde works on flat (batch_size, d) states, so the wrappers below restore the original data shape
        # before calling the user functions and flatten their outputs again afterwards.
        def forward_drift(t, x):
            self.nfe_drift += 1
            x = x.reshape(batch_size, *data_size)
            f = drift(t, x)
            return f.reshape(batch_size, -1)

        def forward_diffusion(t, x):
            self.nfe_diffusion += 1
            x = x.reshape(batch_size, *data_size)
            g = diffusion(t, x)
            # NOTE(review): flattening to (batch_size, -1) matches the (batch_size, d) layout of diagonal noise;
            # other noise types expect (batch_size, d, m), so they are presumably unsupported here - confirm.
            return g.reshape(batch_size, -1)

        sde = TorchSDE(
            drift=forward_drift,
            diffusion=forward_diffusion,
            noise_type=self.sde_noise_type,
            sde_type=self.sde_type,
        )

        x0 = x0.reshape(batch_size, -1)

        trajectory = torchsde.sdeint(
            sde,
            x0,
            t_span,
            method=self.sde_solver,
            dt=self.dt,
            rtol=self.rtol,
            atol=self.atol,
            logqp=logqp,
            adaptive=adaptive,
            **self.kwargs,
        )

        # Restore the original per-sample data shape for every returned time point.
        trajectory = trajectory.reshape(t_span.shape[0], batch_size, *data_size)

        return trajectory