Source code for ding.policy.qgpo

#############################################################
# This QGPO policy is a modified implementation based on https://github.com/ChenDRAG/CEP-energy-guided-diffusion
#############################################################

from typing import List, Dict, Any
import torch
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate
from ding.torch_utils import to_device
from .base_policy import Policy


@POLICY_REGISTRY.register('qgpo')
class QGPOPolicy(Policy):
    """
    Overview:
        Policy class of the QGPO algorithm (https://arxiv.org/abs/2304.12824):
        Contrastive Energy Prediction for Exact Energy-Guided Diffusion Sampling in Offline Reinforcement Learning.
    Interfaces:
        ``__init__``, ``forward``, ``learn``, ``eval``, ``state_dict``, ``load_state_dict``
    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='qgpo',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) Whether the algorithm is on-policy or off-policy.
        # The on-policy setting influences the behaviour of the buffer.
        # Default False in QGPO.
        on_policy=False,
        multi_agent=False,
        model=dict(
            qgpo_critic=dict(
                # (float) The scale of the energy guidance when training qt.
                # \pi_{behavior}\exp(f(s,a)) \propto \pi_{behavior}\exp(alpha * Q(s,a))
                alpha=3,
                # (float) The scale of the energy guidance when training q0.
                # \mathcal{T}Q(s,a)=r(s,a)+\mathbb{E}_{s'\sim P(s'|s,a),a'\sim\pi_{support}(a'|s')}Q(s',a')
                # \pi_{support} \propto \pi_{behavior}\exp(q_alpha * Q(s,a))
                q_alpha=1,
            ),
            device='cuda',
            # obs_dim
            # action_dim
        ),
        learn=dict(
            # (float) Learning rate for behavior model training.
            learning_rate=1e-4,
            # (int) Batch size during the training of the behavior model.
            batch_size=4096,
            # (int) Batch size during the training of the Q value.
            batch_size_q=256,
            # (int) Number of fake actions in the action support.
            M=16,
            # (int) Number of diffusion time steps.
            diffusion_steps=15,
            # (int) Iteration at which the behavior model stops training.
            behavior_policy_stop_training_iter=600000,
            # (int) Iteration at which the energy-guided policy begins training.
            energy_guided_policy_begin_training_iter=600000,
            # (int) Iteration at which the Q value stops training; default None means no limit.
            q_value_stop_training_iter=1100000,
        ),
        eval=dict(
            # (list) Energy guidance scales for the policy in evaluation.
            # \pi_{evaluation} \propto \pi_{behavior}\exp(guidance_scale * alpha * Q(s,a))
            guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0],
        ),
    )

    def _init_learn(self) -> None:
        """
        Overview:
            Learn mode initialization method. For QGPO, it mainly contains the optimizer and \
            algorithm-specific arguments such as qt_update_momentum, discount, behavior_policy_stop_training_iter, \
            energy_guided_policy_begin_training_iter and q_value_stop_training_iter.
            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
        """
        self.cuda = self._cfg.cuda
        self.behavior_model_optimizer = torch.optim.Adam(
            self._model.score_model.parameters(), lr=self._cfg.learn.learning_rate
        )
        self.q_optimizer = torch.optim.Adam(self._model.q.q0.parameters(), lr=3e-4)
        self.qt_optimizer = torch.optim.Adam(self._model.q.qt.parameters(), lr=3e-4)
        self.qt_update_momentum = 0.005
        self.discount = 0.99

        self.behavior_policy_stop_training_iter = self._cfg.learn.behavior_policy_stop_training_iter
        self.energy_guided_policy_begin_training_iter = self._cfg.learn.energy_guided_policy_begin_training_iter
        self.q_value_stop_training_iter = self._cfg.learn.q_value_stop_training_iter

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        """
        Overview:
            Forward function for learning mode. The training of the QGPO algorithm is based on contrastive \
            energy prediction, which needs both true actions and fake actions. The true action is sampled \
            from the dataset, and the fake actions are sampled from the action support generated by the \
            behavior policy. The training process is divided into three stages:
            1. Train the behavior model, which is modeled as a diffusion model by parameterizing the score function.
            2. Train the Q function on the fake action support generated by the behavior model.
            3. Train the energy-guided policy using the Q function.
        Arguments:
            - data (:obj:`dict`): Dict type data.
        Returns:
            - result (:obj:`dict`): Dict type data of algorithm results.
        """
        if self.cuda:
            data = to_device(data, self._device)

        s = data['s']
        a = data['a']
        r = data['r']
        s_ = data['s_']
        d = data['d']
        fake_a = data['fake_a']
        fake_a_ = data['fake_a_']

        # training behavior model
        if self.behavior_policy_stop_training_iter > 0:
            behavior_model_training_loss = self._model.score_model_loss_fn(a, s)
            self.behavior_model_optimizer.zero_grad()
            behavior_model_training_loss.backward()
            self.behavior_model_optimizer.step()

            self.behavior_policy_stop_training_iter -= 1
            behavior_model_training_loss = behavior_model_training_loss.item()
        else:
            behavior_model_training_loss = 0

        # training Q function
        self.energy_guided_policy_begin_training_iter -= 1
        self.q_value_stop_training_iter -= 1
        if self.energy_guided_policy_begin_training_iter < 0:
            if self.q_value_stop_training_iter > 0:
                q0_loss = self._model.q_loss_fn(a, s, r, s_, d, fake_a_, discount=self.discount)

                self.q_optimizer.zero_grad()
                q0_loss.backward()
                self.q_optimizer.step()

                # Update the target Q network with a soft (Polyak averaging) update.
                for param, target_param in zip(self._model.q.q0.parameters(), self._model.q.q0_target.parameters()):
                    target_param.data.copy_(
                        self.qt_update_momentum * param.data + (1 - self.qt_update_momentum) * target_param.data
                    )

                q0_loss = q0_loss.item()
            else:
                q0_loss = 0
            qt_loss = self._model.qt_loss_fn(s, fake_a)

            self.qt_optimizer.zero_grad()
            qt_loss.backward()
            self.qt_optimizer.step()

            qt_loss = qt_loss.item()
        else:
            q0_loss = 0
            qt_loss = 0

        total_loss = behavior_model_training_loss + q0_loss + qt_loss

        return dict(
            total_loss=total_loss,
            behavior_model_training_loss=behavior_model_training_loss,
            q0_loss=q0_loss,
            qt_loss=qt_loss,
        )

    def _init_collect(self) -> None:
        """
        Overview:
            Collect mode initialization method. Not supported for QGPO.
        """
        pass

    def _forward_collect(self) -> None:
        """
        Overview:
            Forward function for collect mode. Not supported for QGPO.
        """
        pass

    def _init_eval(self) -> None:
        """
        Overview:
            Eval mode initialization method. For QGPO, it mainly contains the guidance_scale and diffusion_steps.
            This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
        """
        self.diffusion_steps = self._cfg.eval.diffusion_steps

    def _forward_eval(self, data: dict, guidance_scale: float) -> dict:
        """
        Overview:
            Forward function for eval mode. The eval process is based on the energy-guided policy, \
            which is modeled as a diffusion model by parameterizing the score function.
        Arguments:
            - data (:obj:`dict`): Dict type data.
            - guidance_scale (:obj:`float`): The scale of the energy guidance.
        Returns:
            - output (:obj:`dict`): Dict type data of algorithm output.
        """
        data_id = list(data.keys())
        states = default_collate(list(data.values()))
        actions = self._model.select_actions(
            states, diffusion_steps=self.diffusion_steps, guidance_scale=guidance_scale
        )
        output = actions
        return {i: {"action": d} for i, d in zip(data_id, output)}

    def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Overview:
            Get the train sample from the replay buffer, currently not supported for QGPO.
        Arguments:
            - transitions (:obj:`List[Dict[str, Any]]`): The data from the replay buffer.
        Returns:
            - samples (:obj:`List[Dict[str, Any]]`): The data for training.
        """
        pass

    def _process_transition(self) -> None:
        """
        Overview:
            Process the transition data, currently not supported for QGPO.
        """
        pass

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state dict for saving.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): Dict type data of state dict.
        """
        return {
            'model': self._model.state_dict(),
            'behavior_model_optimizer': self.behavior_model_optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state dict.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): Dict type data of state dict.
        """
        self._model.load_state_dict(state_dict['model'])
        self.behavior_model_optimizer.load_state_dict(state_dict['behavior_model_optimizer'])

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the variable names to be monitored.
        """
        return ['total_loss', 'behavior_model_training_loss', 'q0_loss', 'qt_loss']