# Source code for grl.numerical_methods.ode
from typing import Callable, Optional, Union

from torch import nn


class ODE:
    """
    Overview:
        Base class for ordinary differential equations.
        The ODE is defined as:

        .. math::
            dx = f(x, t)dt

        where f(x, t) is the drift term.
    Interfaces:
        ``__init__``
    """

    def __init__(
        self,
        drift: Optional[Union[nn.Module, Callable]] = None,
    ):
        """
        Overview:
            Initialize the ODE with its drift term.
        Arguments:
            drift (:obj:`Optional[Union[nn.Module, Callable]]`): The drift term f(x, t).
                It is stored as-is on ``self.drift``; no validation is performed here.
                Defaults to ``None``.
        """
        # The annotation is spelled out as Optional because ``None`` is the default;
        # PEP 484 no longer treats a ``None`` default as implicitly Optional.
        self.drift = drift