Shortcuts

Source code for ding.model.template.ebm

"""
Vanilla DFO and EBM are adapted from https://github.com/kevinzakka/ibc.
MCMC is adapted from https://github.com/google-research/ibc.
"""
from typing import Callable, Tuple
from functools import wraps

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from abc import ABC, abstractmethod

from ding.utils import MODEL_REGISTRY, STOCHASTIC_OPTIMIZER_REGISTRY
from ding.torch_utils import unsqueeze_repeat
from ding.model.wrapper import IModelWrapper
from ding.model.common import RegressionHead


def create_stochastic_optimizer(device: str, stochastic_optimizer_config: dict):
    """
    Overview:
        Create stochastic optimizer.
    Arguments:
        - device (:obj:`str`): Device.
        - stochastic_optimizer_config (:obj:`dict`): Stochastic optimizer config. The ``type`` entry selects \
            the registered optimizer, every other entry is forwarded as a keyword argument. \
            The passed dict is left untouched.
    Returns:
        - optimizer: The object returned by ``STOCHASTIC_OPTIMIZER_REGISTRY.build``.
    """
    # Pop ``type`` from a shallow copy: popping from the caller's dict would strip the key,
    # so building a second optimizer from the same config would raise ``KeyError``.
    build_kwargs = dict(stochastic_optimizer_config)
    optimizer_type = build_kwargs.pop("type")
    return STOCHASTIC_OPTIMIZER_REGISTRY.build(optimizer_type, device=device, **build_kwargs)


def no_ebm_grad():
    """
    Overview:
        Decorator factory that disables the gradients of the energy based model parameters while the \
        decorated function runs. The ``requires_grad`` flag of every parameter is put back to the value \
        it had before the call, also when the decorated function raises.
    Returns:
        - decorator (:obj:`Callable`): Decorator to apply to a function whose last positional argument \
            is the energy based model.

    .. note:: The model is looked up as ``args[-1]``, so it has to be passed positionally and last; \
        any argument following it must be passed by keyword.
    """

    def ebm_disable_grad_wrapper(func: Callable):

        @wraps(func)
        def wrapper(*args, **kwargs):
            ebm = args[-1]
            assert isinstance(ebm, nn.Module) or isinstance(ebm, IModelWrapper),\
                'Make sure ebm is the last positional arguments.'
            # Remember the per-parameter state instead of blindly re-enabling gradients afterwards:
            # - decorated functions call each other, and the inner call must not unfreeze the model
            #   for the remainder of the outer call;
            # - parameters that were already frozen by the user must stay frozen.
            # NOTE(review): a wrapper object is assumed to expose ``parameters()`` the same way it
            # exposes ``requires_grad_`` - confirm for IModelWrapper.
            params = list(ebm.parameters())
            previous_flags = [param.requires_grad for param in params]
            ebm.requires_grad_(False)
            try:
                return func(*args, **kwargs)
            finally:
                # Runs on both normal return and exception, so the model is never left frozen.
                for param, flag in zip(params, previous_flags):
                    param.requires_grad_(flag)

        return wrapper

    return ebm_disable_grad_wrapper


class StochasticOptimizer(ABC):
    """
    Overview:
        Abstract base class of the stochastic optimizers. It provides uniform action sampling, \
        selection of the lowest-energy action and storage of the action bounds. \
        Subclasses are expected to provide ``self.device`` and to implement ``sample`` and ``infer``.
    Interface:
        ``_sample``, ``_get_best_action_sample``, ``set_action_bounds``, ``sample``, ``infer``
    """

    def _sample(self, obs: torch.Tensor, num_samples: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Draw actions uniformly inside the stored action bounds and tile the observations \
            so that they line up with the drawn actions.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation.
            - num_samples (:obj:`int`): The number of negative samples.
        Returns:
            - tiled_obs (:obj:`torch.Tensor`): Observations tiled.
            - action (:obj:`torch.Tensor`): Action sampled.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - num_samples (:obj:`int`): :math:`N`.
            - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> opt = DFO()  # any concrete subclass, the base class itself is abstract
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> tiled_obs, action = opt._sample(obs, 8)
        """
        sample_shape = (obs.shape[0], num_samples, self.action_bounds.shape[1])
        lower = self.action_bounds[0, :]
        upper = self.action_bounds[1, :]
        # Uniform noise in [0, 1) rescaled to [lower, upper).
        unit_noise = torch.rand(sample_shape).to(self.device)
        sampled_actions = lower + (upper - lower) * unit_noise
        return unsqueeze_repeat(obs, num_samples, 1), sampled_actions

    @staticmethod
    @torch.no_grad()
    def _get_best_action_sample(obs: torch.Tensor, action_samples: torch.Tensor, ebm: nn.Module):
        """
        Overview:
            Pick, for every batch entry, the candidate action with the highest probability (lowest energy).
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation, forwarded to ``ebm`` together with ``action_samples``.
            - action_samples (:obj:`torch.Tensor`): Candidate actions.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Best action.
        Shapes:
            - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        Examples:
            >>> obs = torch.randn(2, 8, 4)
            >>> action_samples = torch.randn(2, 8, 5)
            >>> ebm = EBM(4, 5)
            >>> best_action_samples = StochasticOptimizer._get_best_action_sample(obs, action_samples, ebm)
        """
        # (B, N) energies turned into a distribution over the N candidates.
        candidate_probs = F.softmax(-1.0 * ebm.forward(obs, action_samples), dim=-1)
        # (B, ) index of the most likely candidate per batch entry.
        winner = torch.argmax(candidate_probs, dim=-1)
        batch_index = torch.arange(action_samples.size(0))
        return action_samples[batch_index, winner]

    def set_action_bounds(self, action_bounds: np.ndarray):
        """
        Overview:
            Store the action bounds calculated from the dataset statistics as a float32 tensor \
            on ``self.device`` (attribute ``action_bounds``).
        Arguments:
            - action_bounds (:obj:`np.ndarray`): Array of shape (2, A), \
                where action_bounds[0] is lower bound and action_bounds[1] is upper bound.
        Shapes:
            - action_bounds (:obj:`np.ndarray`): :math:`(2, A)`.
        Examples:
            >>> opt = DFO()  # any concrete subclass, the base class itself is abstract
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
        """
        bounds = torch.as_tensor(action_bounds, dtype=torch.float32)
        self.action_bounds = bounds.to(self.device)

    @abstractmethod
    def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Create tiled observations and sample counter-negatives for InfoNCE loss.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - tiled_obs (:obj:`torch.Tensor`): Tiled observations.
            - action (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.

        .. note:: In the case of derivative-free optimization, this function will simply call _sample.
        """
        raise NotImplementedError

    @abstractmethod
    def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Optimize for the best action conditioned on the current observation.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Best actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        """
        raise NotImplementedError


@STOCHASTIC_OPTIMIZER_REGISTRY.register('dfo')
class DFO(StochasticOptimizer):
    """
    Overview:
        Derivative-Free Optimizer in paper Implicit Behavioral Cloning.
        https://arxiv.org/abs/2109.00137
    Interface:
        ``__init__``, ``sample``, ``infer``
    """

    def __init__(
        self,
        noise_scale: float = 0.33,
        noise_shrink: float = 0.5,
        iters: int = 3,
        train_samples: int = 8,
        inference_samples: int = 16384,
        device: str = 'cpu',
    ):
        """
        Overview:
            Initialize the Derivative-Free Optimizer. The action bounds start as ``None`` and \
            have to be provided through ``set_action_bounds`` before sampling.
        Arguments:
            - noise_scale (:obj:`float`): Initial noise scale.
            - noise_shrink (:obj:`float`): Noise scale shrink rate.
            - iters (:obj:`int`): Number of iterations.
            - train_samples (:obj:`int`): Number of samples for training.
            - inference_samples (:obj:`int`): Number of samples for inference.
            - device (:obj:`str`): Device.
        """
        self.action_bounds = None
        self.device = device
        self.iters = iters
        self.train_samples = train_samples
        self.inference_samples = inference_samples
        self.noise_scale = noise_scale
        self.noise_shrink = noise_shrink

    def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Draw ``train_samples`` uniform action samples per observation and tile the observations \
            to the same shape. The energy based model is not used here.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model (unused).
        Returns:
            - tiled_obs (:obj:`torch.Tensor`): Tiled observation.
            - action_samples (:obj:`torch.Tensor`): Action samples.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> ebm = EBM(4, 5)
            >>> opt = DFO()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> tiled_obs, action_samples = opt.sample(obs, ebm)
        """
        return self._sample(obs, self.train_samples)

    @torch.no_grad()
    def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Optimize for the best action conditioned on the current observation by repeatedly \
            resampling the candidates according to their energy and perturbing them with shrinking noise.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> ebm = EBM(4, 5)
            >>> opt = DFO()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> best_action_samples = opt.infer(obs, ebm)
        """
        # (B, N, O), (B, N, A)
        tiled_obs, candidates = self._sample(obs, self.inference_samples)
        lower, upper = self.action_bounds[0, :], self.action_bounds[1, :]

        sigma = self.noise_scale
        for _ in range(self.iters):
            # (B, N) probabilities: low energy means high probability.
            weights = F.softmax(-1.0 * ebm.forward(tiled_obs, candidates), dim=-1)

            # Resample with replacement, row-wise over the batch.
            picked = torch.multinomial(weights, self.inference_samples, replacement=True)
            batch_rows = torch.arange(candidates.size(0)).unsqueeze(-1)
            candidates = candidates[batch_rows, picked]

            # Perturb the survivors and keep them inside the action bounds.
            perturbed = candidates + torch.randn_like(candidates) * sigma
            candidates = perturbed.clamp(min=lower, max=upper)

            sigma *= self.noise_shrink

        # Return target with highest probability.
        return self._get_best_action_sample(tiled_obs, candidates, ebm)


@STOCHASTIC_OPTIMIZER_REGISTRY.register('ardfo')
class AutoRegressiveDFO(DFO):
    """
    Overview:
        AutoRegressive Derivative-Free Optimizer in paper Implicit Behavioral Cloning.
        https://arxiv.org/abs/2109.00137
    Interface:
        ``__init__``, ``infer``
    """

    def __init__(
        self,
        noise_scale: float = 0.33,
        noise_shrink: float = 0.5,
        iters: int = 3,
        train_samples: int = 8,
        inference_samples: int = 4096,
        device: str = 'cpu',
    ):
        """
        Overview:
            Initialize the AutoRegressive Derivative-Free Optimizer. All arguments are handed to ``DFO``; \
            only the default of ``inference_samples`` differs.
        Arguments:
            - noise_scale (:obj:`float`): Initial noise scale.
            - noise_shrink (:obj:`float`): Noise scale shrink rate.
            - iters (:obj:`int`): Number of iterations.
            - train_samples (:obj:`int`): Number of samples for training.
            - inference_samples (:obj:`int`): Number of samples for inference.
            - device (:obj:`str`): Device.
        """
        super().__init__(
            noise_scale=noise_scale,
            noise_shrink=noise_shrink,
            iters=iters,
            train_samples=train_samples,
            inference_samples=inference_samples,
            device=device,
        )

    @torch.no_grad()
    def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Optimize for the best action conditioned on the current observation, refining one action \
            dimension at a time with the energy that the model outputs for that dimension.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model whose output has one energy per action dimension.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> ebm = AutoregressiveEBM(4, 5)
            >>> opt = AutoRegressiveDFO()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> best_action_samples = opt.infer(obs, ebm)
        """
        # (B, N, O), (B, N, A)
        tiled_obs, candidates = self._sample(obs, self.inference_samples)
        action_dim = candidates.shape[-1]

        sigma = self.noise_scale
        for _ in range(self.iters):
            for dim_idx in range(action_dim):
                # (B, N) energy of the current action dimension.
                dim_energy = ebm.forward(tiled_obs, candidates)[..., dim_idx]
                weights = F.softmax(-1.0 * dim_energy, dim=-1)

                # Resample with replacement, row-wise over the batch.
                picked = torch.multinomial(weights, self.inference_samples, replacement=True)
                batch_rows = torch.arange(candidates.size(0)).unsqueeze(-1)
                candidates = candidates[batch_rows, picked]

                # Only the current dimension is perturbed and clipped to its bounds.
                column = candidates[..., dim_idx]
                noisy_column = column + torch.randn_like(column) * sigma
                candidates[..., dim_idx] = noisy_column.clamp(
                    min=self.action_bounds[0, dim_idx], max=self.action_bounds[1, dim_idx]
                )

            sigma *= self.noise_shrink

        # (B, N) the energy of the last dimension ranks the final candidates.
        final_energy = ebm.forward(tiled_obs, candidates)[..., -1]
        # (B, )
        winner = F.softmax(-1.0 * final_energy, dim=-1).argmax(dim=-1)
        batch_index = torch.arange(candidates.size(0))
        return candidates[batch_index, winner]


@STOCHASTIC_OPTIMIZER_REGISTRY.register('mcmc')
class MCMC(StochasticOptimizer):
    """
    Overview:
        MCMC method as stochastic optimizers in paper Implicit Behavioral Cloning.
        https://arxiv.org/abs/2109.00137
    Interface:
        ``__init__``, ``sample``, ``infer``, ``grad_penalty``
    """

    class BaseScheduler(ABC):
        """
        Overview:
            Base class for learning rate scheduler.
        Interface:
            ``get_rate``
        """

        @abstractmethod
        def get_rate(self, index):
            """
            Overview:
                Abstract method for getting learning rate.
            """
            raise NotImplementedError

    class ExponentialScheduler(BaseScheduler):
        """
        Overview:
            Exponential learning rate schedule for Langevin sampler.
        Interface:
            ``__init__``, ``get_rate``
        """

        def __init__(self, init, decay):
            """
            Overview:
                Initialize the ExponentialScheduler.
            Arguments:
                - init (:obj:`float`): Initial learning rate.
                - decay (:obj:`float`): Decay rate.
            """
            self._decay = decay
            self._latest_lr = init

        def get_rate(self, index):
            """
            Overview:
                Get learning rate. Assumes calling sequentially: every call returns the current rate \
                and then multiplies the stored rate by the decay.
            Arguments:
                - index (:obj:`int`): Current iteration (ignored).
            """
            del index
            lr = self._latest_lr
            self._latest_lr *= self._decay
            return lr

    class PolynomialScheduler(BaseScheduler):
        """
        Overview:
            Polynomial learning rate schedule for Langevin sampler.
        Interface:
            ``__init__``, ``get_rate``
        """

        def __init__(self, init, final, power, num_steps):
            """
            Overview:
                Initialize the PolynomialScheduler.
            Arguments:
                - init (:obj:`float`): Initial learning rate.
                - final (:obj:`float`): Final learning rate.
                - power (:obj:`float`): Power of polynomial.
                - num_steps (:obj:`int`): Number of steps.
            """
            self._init = init
            self._final = final
            self._power = power
            self._num_steps = num_steps

        def get_rate(self, index):
            """
            Overview:
                Get learning rate for index. Index ``-1`` yields the initial rate, index ``0`` yields \
                the initial rate as well and index ``num_steps - 1`` yields the final rate.
            Arguments:
                - index (:obj:`int`): Current iteration.
            """
            if index == -1:
                return self._init
            if self._num_steps == 1:
                # A single-step schedule has nothing to interpolate; the formula below would
                # divide by ``num_steps - 1 == 0``. The schedule is already at its end.
                return self._final
            progress = float(index) / float(self._num_steps - 1)
            return ((self._init - self._final) * ((1 - progress) ** (self._power))) + self._final

    def __init__(
        self,
        iters: int = 100,
        use_langevin_negative_samples: bool = True,
        train_samples: int = 8,
        inference_samples: int = 512,
        stepsize_scheduler: dict = dict(
            init=0.5,
            final=1e-5,
            power=2.0,
            # num_steps,
        ),
        optimize_again: bool = True,
        again_stepsize_scheduler: dict = dict(
            init=1e-5,
            final=1e-5,
            power=2.0,
            # num_steps,
        ),
        device: str = 'cpu',
        # langevin_step
        noise_scale: float = 0.5,
        grad_clip=None,
        delta_action_clip: float = 0.5,
        add_grad_penalty: bool = True,
        grad_norm_type: str = 'inf',
        grad_margin: float = 1.0,
        grad_loss_weight: float = 1.0,
        **kwargs,
    ):
        """
        Overview:
            Initialize the MCMC.
        Arguments:
            - iters (:obj:`int`): Number of iterations.
            - use_langevin_negative_samples (:obj:`bool`): Whether to use Langevin sampler.
            - train_samples (:obj:`int`): Number of samples for training.
            - inference_samples (:obj:`int`): Number of samples for inference.
            - stepsize_scheduler (:obj:`dict`): Keyword arguments of the ``PolynomialScheduler`` used by the \
                Langevin sampler; ``num_steps`` is filled in with ``iters``.
            - optimize_again (:obj:`bool`): Whether to run a second optimization.
            - again_stepsize_scheduler (:obj:`dict`): Step size scheduler for the second optimization.
            - device (:obj:`str`): Device.
            - noise_scale (:obj:`float`): Initial noise scale.
            - grad_clip (:obj:`float`): Gradient clip, no clipping when it is ``None`` or ``0``.
            - delta_action_clip (:obj:`float`): Action clip.
            - add_grad_penalty (:obj:`bool`): Whether to add gradient penalty.
            - grad_norm_type (:obj:`str`): Gradient norm type, one of ``'1'``, ``'2'``, ``'inf'``.
            - grad_margin (:obj:`float`): Gradient margin.
            - grad_loss_weight (:obj:`float`): Gradient loss weight.
            - kwargs: Accepted and ignored.
        """
        # Same initial state as ``DFO``; the bounds are provided later through ``set_action_bounds``.
        self.action_bounds = None
        self.iters = iters
        self.use_langevin_negative_samples = use_langevin_negative_samples
        self.train_samples = train_samples
        self.inference_samples = inference_samples
        # ``num_steps`` is written into these dicts later on. Keep private copies so that neither the
        # shared default dicts of the signature nor a dict owned by the caller is modified.
        self.stepsize_scheduler = dict(stepsize_scheduler)
        self.optimize_again = optimize_again
        self.again_stepsize_scheduler = dict(again_stepsize_scheduler)
        self.device = device

        self.noise_scale = noise_scale
        self.grad_clip = grad_clip
        self.delta_action_clip = delta_action_clip
        self.add_grad_penalty = add_grad_penalty
        self.grad_norm_type = grad_norm_type
        self.grad_margin = grad_margin
        self.grad_loss_weight = grad_loss_weight

    @staticmethod
    def _gradient_wrt_act(
            obs: torch.Tensor,
            action: torch.Tensor,
            ebm: nn.Module,
            create_graph: bool = False,
    ) -> torch.Tensor:
        """
        Overview:
            Calculate gradient w.r.t action. ``action`` is switched to ``requires_grad=True`` for the \
            computation and to ``requires_grad=False`` afterwards.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - action (:obj:`torch.Tensor`): Actions.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
            - create_graph (:obj:`bool`): Whether to create graph.
        Returns:
            - grad (:obj:`torch.Tensor`): Gradient w.r.t action.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
            - grad (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        """
        action.requires_grad_(True)
        # Summing keeps the per-sample gradients intact because every energy depends on its own action only.
        energy = ebm.forward(obs, action).sum()
        # `create_graph` set to `True` when second order derivative
        #  is needed i.e, d(de/da)/d_param
        grad = torch.autograd.grad(energy, action, create_graph=create_graph)[0]
        action.requires_grad_(False)
        return grad

    def grad_penalty(self, obs: torch.Tensor, action: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Calculate gradient penalty: the squared amount by which the norm of the energy gradient \
            w.r.t. the action exceeds ``grad_margin``, averaged over all samples and scaled by \
            ``grad_loss_weight``.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - action (:obj:`torch.Tensor`): Actions.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - loss (:obj:`torch.Tensor`): Gradient penalty, or the float ``0.`` when \
                ``add_grad_penalty`` is disabled.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N+1, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N+1, A)`.
            - loss (:obj:`torch.Tensor`): scalar.
        """
        if not self.add_grad_penalty:
            return 0.
        # (B, N+1, A), this gradient is differentiable w.r.t model parameters
        de_dact = MCMC._gradient_wrt_act(obs, action, ebm, create_graph=True)

        # An unknown ``grad_norm_type`` raises ``KeyError`` here.
        norm_order = {
            '1': 1,
            '2': 2,
            'inf': float('inf'),
        }[self.grad_norm_type]

        # (B, N+1)
        grad_norms = torch.linalg.norm(de_dact, norm_order, dim=-1)
        # Only the part of the norm above the margin is penalized.
        grad_norms = grad_norms - self.grad_margin
        grad_norms = grad_norms.clamp(min=0., max=1e10)
        grad_norms = grad_norms.pow(2)

        grad_loss = grad_norms.mean()
        return grad_loss * self.grad_loss_weight

    # can not use @torch.no_grad() during the inference
    # because we need to calculate gradient w.r.t inputs as MCMC updates.
    @no_ebm_grad()
    def _langevin_step(self, obs: torch.Tensor, action: torch.Tensor, stepsize: float, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Run one langevin MCMC step.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - action (:obj:`torch.Tensor`): Actions.
            - stepsize (:obj:`float`): Step size.
            - ebm (:obj:`torch.nn.Module`): Energy based model, has to stay the last positional argument.
        Returns:
            - action (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        """
        l_lambda = 1.0
        de_dact = MCMC._gradient_wrt_act(obs, action, ebm)

        if self.grad_clip:
            de_dact = de_dact.clamp(min=-self.grad_clip, max=self.grad_clip)

        # Scaled energy gradient plus Gaussian noise.
        gradient_scale = 0.5
        de_dact = (gradient_scale * l_lambda * de_dact + torch.randn_like(de_dact) * l_lambda * self.noise_scale)

        # Limit the move per step to a fraction of half the action range.
        delta_action = stepsize * de_dact
        delta_action_clip = self.delta_action_clip * 0.5 * (self.action_bounds[1] - self.action_bounds[0])
        delta_action = delta_action.clamp(min=-delta_action_clip, max=delta_action_clip)

        # Descend the energy and stay inside the action bounds.
        action = action - delta_action
        action = action.clamp(min=self.action_bounds[0], max=self.action_bounds[1])

        return action

    @no_ebm_grad()
    def _langevin_action_given_obs(
            self,
            obs: torch.Tensor,
            action: torch.Tensor,
            ebm: nn.Module,
            scheduler: BaseScheduler = None
    ) -> torch.Tensor:
        """
        Overview:
            Run langevin MCMC for `self.iters` steps.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - action (:obj:`torch.Tensor`): Actions.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
            - scheduler (:obj:`BaseScheduler`): Learning rate scheduler. Defaults to a ``PolynomialScheduler`` \
                built from ``self.stepsize_scheduler``.
        Returns:
            - action (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.

        .. note:: ``no_ebm_grad`` takes the last positional argument as the model, so ``scheduler`` \
            has to be passed by keyword.
        """
        if scheduler is None:
            # The dict is a private copy made in ``__init__``, updating it does not leak anywhere.
            self.stepsize_scheduler['num_steps'] = self.iters
            scheduler = MCMC.PolynomialScheduler(**self.stepsize_scheduler)
        # The first step uses the rate of index -1; the rate fetched after step ``i`` drives step ``i + 1``.
        stepsize = scheduler.get_rate(-1)
        for i in range(self.iters):
            action = self._langevin_step(obs, action, stepsize, ebm)
            stepsize = scheduler.get_rate(i)
        return action

    @no_ebm_grad()
    def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Create tiled observations and sample counter-negatives for InfoNCE loss. The uniform samples \
            are refined with Langevin MCMC unless ``use_langevin_negative_samples`` is disabled.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - tiled_obs (:obj:`torch.Tensor`): Tiled observations.
            - action_samples (:obj:`torch.Tensor`): Action samples.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> ebm = EBM(4, 5)
            >>> opt = MCMC()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> tiled_obs, action_samples = opt.sample(obs, ebm)
        """
        obs, uniform_action_samples = self._sample(obs, self.train_samples)
        if not self.use_langevin_negative_samples:
            return obs, uniform_action_samples
        langevin_action_samples = self._langevin_action_given_obs(obs, uniform_action_samples, ebm)
        return obs, langevin_action_samples

    @no_ebm_grad()
    def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Optimize for the best action conditioned on the current observation.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> ebm = EBM(4, 5)
            >>> opt = MCMC()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> best_action_samples = opt.infer(obs, ebm)
        """
        # (B, N, O), (B, N, A)
        obs, uniform_action_samples = self._sample(obs, self.inference_samples)
        action_samples = self._langevin_action_given_obs(
            obs,
            uniform_action_samples,
            ebm,
        )

        # Run a second optimization, a trick for more precise inference
        if self.optimize_again:
            # Private copy made in ``__init__``, see there.
            self.again_stepsize_scheduler['num_steps'] = self.iters
            action_samples = self._langevin_action_given_obs(
                obs,
                action_samples,
                ebm,
                scheduler=MCMC.PolynomialScheduler(**self.again_stepsize_scheduler),
            )

        # action_samples: B, N, A
        return self._get_best_action_sample(obs, action_samples, ebm)


@MODEL_REGISTRY.register('ebm')
class EBM(nn.Module):
    """
    Overview:
        Energy based model.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(
        self,
        obs_shape: int,
        action_shape: int,
        hidden_size: int = 512,
        hidden_layer_num: int = 4,
        **kwargs,
    ):
        """
        Overview:
            Initialize the EBM: a linear layer with ReLU over the concatenated observation and action, \
            followed by a ``RegressionHead`` with a single output.
        Arguments:
            - obs_shape (:obj:`int`): Observation shape.
            - action_shape (:obj:`int`): Action shape.
            - hidden_size (:obj:`int`): Hidden size.
            - hidden_layer_num (:obj:`int`): Number of hidden layers.
            - kwargs: Accepted and ignored.
        """
        super().__init__()
        # Observation and action are concatenated along the last dimension in ``forward``.
        input_size = obs_shape + action_shape
        input_layer = nn.Linear(input_size, hidden_size)
        activation = nn.ReLU()
        energy_head = RegressionHead(
            hidden_size,
            1,
            hidden_layer_num,
            final_tanh=False,
        )
        self.net = nn.Sequential(input_layer, activation, energy_head)

    def forward(self, obs, action):
        """
        Overview:
            Forward computation graph of EBM.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation of shape (B, N, O).
            - action (:obj:`torch.Tensor`): Action of shape (B, N, A).
        Returns:
            - pred (:obj:`torch.Tensor`): Energy of shape (B, N).
        Examples:
            >>> obs = torch.randn(2, 3, 4)
            >>> action = torch.randn(2, 3, 5)
            >>> ebm = EBM(4, 5)
            >>> pred = ebm(obs, action)
        """
        joint_input = torch.cat([obs, action], -1)
        # The head returns a dict, the energy is stored under ``pred``.
        return self.net(joint_input)['pred']
@MODEL_REGISTRY.register('arebm')
class AutoregressiveEBM(nn.Module):
    """
    Overview:
        Autoregressive energy based model.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(
        self,
        obs_shape: int,
        action_shape: int,
        hidden_size: int = 512,
        hidden_layer_num: int = 4,
    ):
        """
        Overview:
            Initialize the AutoregressiveEBM with one ``EBM`` per action dimension; the i-th model \
            takes the first ``i + 1`` action dimensions as its action input.
        Arguments:
            - obs_shape (:obj:`int`): Observation shape.
            - action_shape (:obj:`int`): Action shape.
            - hidden_size (:obj:`int`): Hidden size.
            - hidden_layer_num (:obj:`int`): Number of hidden layers.
        """
        super().__init__()
        self.ebm_list = nn.ModuleList(
            [EBM(obs_shape, prefix_len, hidden_size, hidden_layer_num) for prefix_len in range(1, action_shape + 1)]
        )

    def forward(self, obs, action):
        """
        Overview:
            Forward computation graph of AutoregressiveEBM.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation of shape (B, N, O).
            - action (:obj:`torch.Tensor`): Action of shape (B, N, A).
        Returns:
            - pred (:obj:`torch.Tensor`): Energy of shape (B, N, A).
        Examples:
            >>> obs = torch.randn(2, 3, 4)
            >>> action = torch.randn(2, 3, 5)
            >>> arebm = AutoregressiveEBM(4, 5)
            >>> pred = arebm(obs, action)
        """
        # The i-th model only sees the action dimensions up to and including i.
        per_dim_energy = [sub_ebm(obs, action[..., :idx + 1]) for idx, sub_ebm in enumerate(self.ebm_list)]
        return torch.stack(per_dim_energy, dim=-1)