
Source code for ding.model.template.qgpo

#############################################################
# This QGPO model is a modified implementation based on https://github.com/ChenDRAG/CEP-energy-guided-diffusion
#############################################################

from easydict import EasyDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from ding.torch_utils import MLP
from ding.torch_utils.diffusion_SDE.dpm_solver_pytorch import DPM_Solver, NoiseScheduleVP
from ding.model.common.encoder import GaussianFourierProjectionTimeEncoder
from ding.torch_utils.network.res_block import TemporalSpatialResBlock


def marginal_prob_std(t, device):
    """
    Overview:
        Compute the mean coefficient and standard deviation of $p_{0t}(x(t) | x(0))$.
    Arguments:
        - t (:obj:`torch.Tensor`): The input time.
        - device (:obj:`torch.device`): The device to use.
    """

    t = torch.tensor(t, device=device)
    beta_1 = 20.0
    beta_0 = 0.1
    log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
    alpha_t = torch.exp(log_mean_coeff)
    std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
    return alpha_t, std
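
# A minimal sanity-check sketch (the helper name and test values below are
# illustrative assumptions, not part of the original module): for the VP-SDE
# coefficients above, std = sqrt(1 - alpha_t ** 2), so alpha_t ** 2 + std ** 2
# should be numerically close to 1 for any t in (0, 1].
def _example_check_marginal_prob_std(device='cpu'):
    t = [0.1, 0.25, 0.5, 0.75, 1.0]
    alpha_t, std = marginal_prob_std(t, device=device)
    # every entry of the returned tensor should be close to 1.0
    return alpha_t ** 2 + std ** 2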


class TwinQ(nn.Module):
    """
    Overview:
        Twin Q network for QGPO, which has two Q networks.
    Interfaces:
        ``__init__``, ``forward``, ``both``
    """

    def __init__(self, action_dim, state_dim):
        """
        Overview:
            Initialization of Twin Q.
        Arguments:
            - action_dim (:obj:`int`): The dimension of action.
            - state_dim (:obj:`int`): The dimension of state.
        """
        super().__init__()
        self.q1 = MLP(
            in_channels=state_dim + action_dim,
            hidden_channels=256,
            out_channels=1,
            activation=nn.ReLU(),
            layer_num=4,
            output_activation=False
        )
        self.q2 = MLP(
            in_channels=state_dim + action_dim,
            hidden_channels=256,
            out_channels=1,
            activation=nn.ReLU(),
            layer_num=4,
            output_activation=False
        )

    def both(self, action, condition=None):
        """
        Overview:
            Return the output of two Q networks.
        Arguments:
            - action (:obj:`torch.Tensor`): The input action.
            - condition (:obj:`torch.Tensor`): The input condition.
        """
        as_ = torch.cat([action, condition], -1) if condition is not None else action
        return self.q1(as_), self.q2(as_)

    def forward(self, action, condition=None):
        """
        Overview:
            Return the minimum output of two Q networks.
        Arguments:
            - action (:obj:`torch.Tensor`): The input action.
            - condition (:obj:`torch.Tensor`): The input condition.
        """
        return torch.min(*self.both(action, condition))
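
# A minimal usage sketch (the helper name, dimensions, and batch size are illustrative
# assumptions, not part of the original module): TwinQ takes a batch of actions with
# the states as condition; ``forward`` returns the element-wise minimum of the two
# Q heads, while ``both`` exposes them separately.
def _example_twinq_usage(action_dim=6, state_dim=17, batch_size=4):
    q = TwinQ(action_dim, state_dim)
    action = torch.randn(batch_size, action_dim)
    state = torch.randn(batch_size, state_dim)
    q_min = q(action, condition=state)  # <batch_size, 1>, min of the two heads
    q1, q2 = q.both(action, condition=state)  # each head separately
    return q_min, q1, q2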


class GuidanceQt(nn.Module):
    """
    Overview:
        Energy Guidance Qt network for QGPO. \
            In the original paper, the energy guidance is trained with the CEP method.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, action_dim, state_dim, time_embed_dim=32):
        """
        Overview:
            Initialization of Guidance Qt.
        Arguments:
            - action_dim (:obj:`int`): The dimension of action.
            - state_dim (:obj:`int`): The dimension of state.
            - time_embed_dim (:obj:`int`): The dimension of time embedding. \
                The time embedding is a Gaussian Fourier Feature tensor.
        """
        super().__init__()
        self.qt = MLP(
            in_channels=action_dim + time_embed_dim + state_dim,
            hidden_channels=256,
            out_channels=1,
            activation=torch.nn.SiLU(),
            layer_num=4,
            output_activation=False
        )
        self.embed = nn.Sequential(
            GaussianFourierProjectionTimeEncoder(embed_dim=time_embed_dim), nn.Linear(time_embed_dim, time_embed_dim)
        )

    def forward(self, action, t, condition=None):
        """
        Overview:
            Return the output of Guidance Qt.
        Arguments:
            - action (:obj:`torch.Tensor`): The input action.
            - t (:obj:`torch.Tensor`): The input time.
            - condition (:obj:`torch.Tensor`): The input condition.
        """
        embed = self.embed(t)
        ats = torch.cat([action, embed, condition], -1) if condition is not None else torch.cat([action, embed], -1)
        return self.qt(ats)
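
# A minimal usage sketch (the helper name, dimensions, and batch size are illustrative
# assumptions, not part of the original module): GuidanceQt evaluates the energy
# guidance for a batch of actions at sampled diffusion times, conditioned on states.
def _example_guidance_qt_usage(action_dim=6, state_dim=17, batch_size=4):
    qt = GuidanceQt(action_dim, state_dim)
    action = torch.randn(batch_size, action_dim)
    state = torch.randn(batch_size, state_dim)
    t = torch.rand(batch_size)  # diffusion times sampled uniformly in [0, 1)
    return qt(action, t, condition=state)  # <batch_size, 1>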


class QGPOCritic(nn.Module):
    """
    Overview:
        QGPO critic network.
    Interfaces:
        ``__init__``, ``forward``, ``calculateQ``, ``calculate_guidance``
    """

    def __init__(self, device, cfg, action_dim, state_dim) -> None:
        """
        Overview:
            Initialization of QGPO critic.
        Arguments:
            - device (:obj:`torch.device`): The device to use.
            - cfg (:obj:`EasyDict`): The config dict.
            - action_dim (:obj:`int`): The dimension of action.
            - state_dim (:obj:`int`): The dimension of state.
        """

        super().__init__()
        # a state_dim of 0 would mean unconditional guidance
        assert state_dim > 0
        # only conditional sampling is supported here
        self.device = device
        self.q0 = TwinQ(action_dim, state_dim).to(self.device)
        self.q0_target = copy.deepcopy(self.q0).requires_grad_(False).to(self.device)
        self.qt = GuidanceQt(action_dim, state_dim).to(self.device)

        self.alpha = cfg.alpha
        self.q_alpha = cfg.q_alpha

    def calculate_guidance(self, a, t, condition=None, guidance_scale=1.0):
        """
        Overview:
            Calculate the guidance for conditional sampling.
        Arguments:
            - a (:obj:`torch.Tensor`): The input action.
            - t (:obj:`torch.Tensor`): The input time.
            - condition (:obj:`torch.Tensor`): The input condition.
            - guidance_scale (:obj:`float`): The scale of guidance.
        """

        with torch.enable_grad():
            a.requires_grad_(True)
            Q_t = self.qt(a, t, condition)
            guidance = guidance_scale * torch.autograd.grad(torch.sum(Q_t), a)[0]
        return guidance.detach()

    def forward(self, a, condition=None):
        """
        Overview:
            Return the output of QGPO critic.
        Arguments:
            - a (:obj:`torch.Tensor`): The input action.
            - condition (:obj:`torch.Tensor`): The input condition.
        """

        return self.q0(a, condition)

    def calculateQ(self, a, condition=None):
        """
        Overview:
            Return the output of QGPO critic.
        Arguments:
            - a (:obj:`torch.Tensor`): The input action.
            - condition (:obj:`torch.Tensor`): The input condition.
        """

        return self(a, condition)
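
# A minimal usage sketch (the helper name, config values, and dimensions are illustrative
# assumptions, not part of the original module): the critic's energy guidance is the
# gradient of Qt with respect to the action, scaled by ``guidance_scale``.
def _example_qgpo_critic_guidance(action_dim=6, state_dim=17, batch_size=4):
    cfg = EasyDict(alpha=3, q_alpha=1)
    critic = QGPOCritic('cpu', cfg, action_dim=action_dim, state_dim=state_dim)
    a = torch.randn(batch_size, action_dim)
    t = torch.rand(batch_size)
    s = torch.randn(batch_size, state_dim)
    # same shape as a: <batch_size, action_dim>
    return critic.calculate_guidance(a, t, condition=s, guidance_scale=1.0)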


class ScoreNet(nn.Module):
    """
    Overview:
        Score-based generative model for QGPO.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, device, input_dim, output_dim, embed_dim=32):
        """
        Overview:
            Initialization of ScoreNet.
        Arguments:
            - device (:obj:`torch.device`): The device to use.
            - input_dim (:obj:`int`): The dimension of input.
            - output_dim (:obj:`int`): The dimension of output.
            - embed_dim (:obj:`int`): The dimension of time embedding. \
                The time embedding is a Gaussian Fourier Feature tensor.
        """

        super().__init__()

        # original score-based network backbone
        self.output_dim = output_dim
        self.embed = nn.Sequential(
            GaussianFourierProjectionTimeEncoder(embed_dim=embed_dim), nn.Linear(embed_dim, embed_dim)
        )

        self.device = device
        self.pre_sort_condition = nn.Sequential(nn.Linear(input_dim - output_dim, 32), torch.nn.SiLU())
        self.sort_t = nn.Sequential(
            nn.Linear(64, 128),
            torch.nn.SiLU(),
            nn.Linear(128, 128),
        )
        self.down_block1 = TemporalSpatialResBlock(output_dim, 512)
        self.down_block2 = TemporalSpatialResBlock(512, 256)
        self.down_block3 = TemporalSpatialResBlock(256, 128)
        self.middle1 = TemporalSpatialResBlock(128, 128)
        self.up_block3 = TemporalSpatialResBlock(256, 256)
        self.up_block2 = TemporalSpatialResBlock(512, 512)
        self.last = nn.Linear(1024, output_dim)

    def forward(self, x, t, condition):
        """
        Overview:
            Return the output of ScoreNet.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
            - t (:obj:`torch.Tensor`): The input time.
            - condition (:obj:`torch.Tensor`): The input condition.
        """

        embed = self.embed(t)
        embed = torch.cat([self.pre_sort_condition(condition), embed], dim=-1)
        embed = self.sort_t(embed)
        d1 = self.down_block1(x, embed)
        d2 = self.down_block2(d1, embed)
        d3 = self.down_block3(d2, embed)
        u3 = self.middle1(d3, embed)
        u2 = self.up_block3(torch.cat([d3, u3], dim=-1), embed)
        u1 = self.up_block2(torch.cat([d2, u2], dim=-1), embed)
        u0 = torch.cat([d1, u1], dim=-1)
        h = self.last(u0)
        self.h = h
        # Normalize output
        return h / marginal_prob_std(t, device=self.device)[1][..., None]
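
# A minimal usage sketch (the helper name, dimensions, and batch size are illustrative
# assumptions, not part of the original module): ScoreNet consumes perturbed actions,
# diffusion times, and conditioning states, and returns a score rescaled by the
# marginal standard deviation.
def _example_scorenet_usage(action_dim=6, state_dim=17, batch_size=4):
    net = ScoreNet(device='cpu', input_dim=state_dim + action_dim, output_dim=action_dim)
    x = torch.randn(batch_size, action_dim)  # perturbed actions
    t = torch.rand(batch_size) * (1. - 1e-3) + 1e-3  # diffusion times in (0, 1]
    s = torch.randn(batch_size, state_dim)  # conditioning states
    return net(x, t, condition=s)  # <batch_size, action_dim>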


class QGPO(nn.Module):
    """
    Overview:
        Model of QGPO algorithm.
    Interfaces:
        ``__init__``, ``calculateQ``, ``select_actions``, ``sample``, ``score_model_loss_fn``, ``q_loss_fn``, \
        ``qt_loss_fn``
    """

    def __init__(self, cfg: EasyDict) -> None:
        """
        Overview:
            Initialization of QGPO.
        Arguments:
            - cfg (:obj:`EasyDict`): The config dict.
        """
        super(QGPO, self).__init__()
        self.device = cfg.device
        self.obs_dim = cfg.obs_dim
        self.action_dim = cfg.action_dim

        self.noise_schedule = NoiseScheduleVP(schedule='linear')
        self.score_model = ScoreNet(
            device=self.device,
            input_dim=self.obs_dim + self.action_dim,
            output_dim=self.action_dim,
        )
        self.q = QGPOCritic(self.device, cfg.qgpo_critic, action_dim=self.action_dim, state_dim=self.obs_dim)

    def calculateQ(self, s, a):
        """
        Overview:
            Calculate the Q value.
        Arguments:
            - s (:obj:`torch.Tensor`): The input state.
            - a (:obj:`torch.Tensor`): The input action.
        """
        return self.q(a, s)

    def select_actions(self, states, diffusion_steps=15, guidance_scale=1.0):
        """
        Overview:
            Select actions for conditional sampling.
        Arguments:
            - states (:obj:`list`): The input states.
            - diffusion_steps (:obj:`int`): The diffusion steps.
            - guidance_scale (:obj:`float`): The scale of guidance.
        """

        def forward_dpm_wrapper_fn(x, t):
            score = self.score_model(x, t, condition=states)
            result = -(score + self.q.calculate_guidance(x, t, states, guidance_scale=guidance_scale)) * \
                marginal_prob_std(t, device=self.device)[1][..., None]
            return result

        self.eval()
        multiple_input = True
        with torch.no_grad():
            states = torch.FloatTensor(states).to(self.device)
            if states.dim() == 1:
                # promote a single flat state vector to a batch of one
                states = states.unsqueeze(0)
                multiple_input = False
            num_states = states.shape[0]
            init_x = torch.randn(states.shape[0], self.action_dim, device=self.device)
            results = DPM_Solver(
                forward_dpm_wrapper_fn, self.noise_schedule, predict_x0=True
            ).sample(
                init_x, steps=diffusion_steps, order=2
            ).cpu().numpy()
            actions = results.reshape(num_states, self.action_dim).copy()  # <bz, A>
        out_actions = [actions[i] for i in range(actions.shape[0])] if multiple_input else actions[0]
        self.train()
        return out_actions

    def sample(self, states, sample_per_state=16, diffusion_steps=15, guidance_scale=1.0):
        """
        Overview:
            Sample actions for conditional sampling.
        Arguments:
            - states (:obj:`list`): The input states.
            - sample_per_state (:obj:`int`): The number of samples per state.
            - diffusion_steps (:obj:`int`): The diffusion steps.
            - guidance_scale (:obj:`float`): The scale of guidance.
        """

        def forward_dpm_wrapper_fn(x, t):
            score = self.score_model(x, t, condition=states)
            result = -(score + self.q.calculate_guidance(x, t, states, guidance_scale=guidance_scale)) \
                * marginal_prob_std(t, device=self.device)[1][..., None]
            return result

        self.eval()
        num_states = states.shape[0]
        with torch.no_grad():
            states = torch.FloatTensor(states).to(self.device)
            states = torch.repeat_interleave(states, sample_per_state, dim=0)
            init_x = torch.randn(states.shape[0], self.action_dim, device=self.device)
            results = DPM_Solver(
                forward_dpm_wrapper_fn, self.noise_schedule, predict_x0=True
            ).sample(
                init_x, steps=diffusion_steps, order=2
            ).cpu().numpy()
            actions = results.reshape(num_states, sample_per_state, self.action_dim).copy()
        self.train()
        return actions

    def score_model_loss_fn(self, x, s, eps=1e-3):
        """
        Overview:
            The loss function for training score-based generative models.
        Arguments:
            - x (:obj:`torch.Tensor`): A mini-batch of training data (actions).
            - s (:obj:`torch.Tensor`): The condition (states) of the mini-batch.
            - eps (:obj:`float`): A tolerance value for numerical stability.
        """
        random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
        z = torch.randn_like(x)
        alpha_t, std = marginal_prob_std(random_t, device=x.device)
        perturbed_x = x * alpha_t[:, None] + z * std[:, None]
        score = self.score_model(perturbed_x, random_t, condition=s)
        loss = torch.mean(torch.sum((score * std[:, None] + z) ** 2, dim=(1, )))
        return loss

    def q_loss_fn(self, a, s, r, s_, d, fake_a_, discount=0.99):
        """
        Overview:
            The loss function for training the Q function.
        Arguments:
            - a (:obj:`torch.Tensor`): The input action.
            - s (:obj:`torch.Tensor`): The input state.
            - r (:obj:`torch.Tensor`): The input reward.
            - s\_ (:obj:`torch.Tensor`): The input next state.
            - d (:obj:`torch.Tensor`): The input done flag.
            - fake_a\_ (:obj:`torch.Tensor`): The fake actions sampled for the next state.
            - discount (:obj:`float`): The discount factor.
        """
        with torch.no_grad():
            softmax = nn.Softmax(dim=1)
            next_energy = self.q.q0_target(fake_a_, torch.stack([s_] * fake_a_.shape[1], axis=1)).detach().squeeze()
            next_v = torch.sum(softmax(self.q.q_alpha * next_energy) * next_energy, dim=-1, keepdim=True)
        # update the Q function towards the softmax-weighted value target
        targets = r + (1. - d.float()) * discount * next_v.detach()
        qs = self.q.q0.both(a, s)
        q_loss = sum(F.mse_loss(q, targets) for q in qs) / len(qs)
        return q_loss

    def qt_loss_fn(self, s, fake_a):
        """
        Overview:
            The loss function for training Guidance Qt.
        Arguments:
            - s (:obj:`torch.Tensor`): The input state.
            - fake_a (:obj:`torch.Tensor`): The input fake actions.
        """
        # input: states of shape <bz, S>, fake actions of shape <bz, M, A>
        energy = self.q.q0_target(fake_a, torch.stack([s] * fake_a.shape[1], axis=1)).detach().squeeze()

        # CEP guidance method, as proposed in the paper
        logsoftmax = nn.LogSoftmax(dim=1)
        softmax = nn.Softmax(dim=1)

        x0_data_energy = energy * self.q.alpha
        random_t = torch.rand((fake_a.shape[0], ), device=self.device) * (1. - 1e-3) + 1e-3
        random_t = torch.stack([random_t] * fake_a.shape[1], dim=1)
        z = torch.randn_like(fake_a)
        alpha_t, std = marginal_prob_std(random_t, device=self.device)
        perturbed_fake_a = fake_a * alpha_t[..., None] + z * std[..., None]
        xt_model_energy = self.q.qt(perturbed_fake_a, random_t, torch.stack([s] * fake_a.shape[1], axis=1)).squeeze()
        p_label = softmax(x0_data_energy)  # <bz, M>
        qt_loss = -torch.mean(torch.sum(p_label * logsoftmax(xt_model_energy), axis=-1))
        return qt_loss
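
# A minimal end-to-end usage sketch (the helper name and all concrete config values are
# illustrative assumptions, not part of the original module): the config fields mirror
# what ``QGPO.__init__`` and ``QGPOCritic.__init__`` read.
def _example_qgpo_usage():
    cfg = EasyDict(
        device='cpu',
        obs_dim=17,
        action_dim=6,
        qgpo_critic=EasyDict(alpha=3, q_alpha=1),
    )
    model = QGPO(cfg)
    states = torch.randn(4, cfg.obs_dim).numpy()
    # behavior policy: one guided-diffusion action per state
    actions = model.select_actions(states, diffusion_steps=15, guidance_scale=1.0)
    # support actions for training the critic and the energy guidance
    fake_actions = model.sample(states, sample_per_state=16, diffusion_steps=15)
    return actions, fake_actions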