Source code for grl.numerical_methods.sde
from typing import Callable, Union
from torch import nn
class SDE:
    """
    Overview:
        Base class for stochastic differential equations.
        The SDE is defined as:

        .. math::

            dx = f(x, t)dt + g(x, t)dW

        where :math:`f(x, t)` is the drift term, :math:`g(x, t)` is the diffusion term,
        and :math:`dW` is the increment of a Wiener process.
    Interfaces:
        ``__init__``
    """

    def __init__(
        self,
        drift: Union[nn.Module, Callable, None] = None,
        diffusion: Union[nn.Module, Callable, None] = None,
    ) -> None:
        """
        Overview:
            Store the drift and diffusion terms of the SDE.
        Arguments:
            drift (:obj:`Union[nn.Module, Callable, None]`): The drift term f(x, t). \
                Defaults to ``None`` (not provided).
            diffusion (:obj:`Union[nn.Module, Callable, None]`): The diffusion term g(x, t). \
                Defaults to ``None`` (not provided).

        .. note::
            Both terms are stored by reference exactly as given; no validation or
            evaluation happens here. How they are called (presumably as ``f(x, t)`` /
            ``g(x, t)`` per the class docstring) is up to the code that uses them.
        """
        # ``None`` is an explicit member of the annotated type because it is the
        # default; an implicit Optional is rejected by PEP 484-conforming checkers.
        self.drift = drift
        self.diffusion = diffusion