Source code for ding.policy.qgpo
#############################################################
# This QGPO implementation is adapted from https://github.com/ChenDRAG/CEP-energy-guided-diffusion
#############################################################
from typing import List, Dict, Any
import torch
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate
from ding.torch_utils import to_device
from .base_policy import Policy
@POLICY_REGISTRY.register('qgpo')
class QGPOPolicy(Policy):
"""
Overview:
        Policy class of the QGPO algorithm: Contrastive Energy Prediction for Exact Energy-Guided \
        Diffusion Sampling in Offline Reinforcement Learning (https://arxiv.org/abs/2304.12824).
Interfaces:
``__init__``, ``forward``, ``learn``, ``eval``, ``state_dict``, ``load_state_dict``
"""
config = dict(
# (str) RL policy register name (refer to function "POLICY_REGISTRY").
type='qgpo',
# (bool) Whether to use cuda for network.
cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        # The on-policy setting influences the behaviour of the buffer.
        # Default False in QGPO.
on_policy=False,
multi_agent=False,
model=dict(
qgpo_critic=dict(
# (float) The scale of the energy guidance when training qt.
# \pi_{behavior}\exp(f(s,a)) \propto \pi_{behavior}\exp(alpha * Q(s,a))
alpha=3,
# (float) The scale of the energy guidance when training q0.
# \mathcal{T}Q(s,a)=r(s,a)+\mathbb{E}_{s'\sim P(s'|s,a),a'\sim\pi_{support}(a'|s')}Q(s',a')
# \pi_{support} \propto \pi_{behavior}\exp(q_alpha * Q(s,a))
q_alpha=1,
),
device='cuda',
            # (int) obs_dim: dimension of the observation space, to be specified in the user config.
            # (int) action_dim: dimension of the action space, to be specified in the user config.
),
        learn=dict(
            # (float) Learning rate for training the behavior model.
            learning_rate=1e-4,
            # (int) Batch size for training the behavior model.
            batch_size=4096,
            # (int) Batch size for training the Q value.
            batch_size_q=256,
            # (int) Number of fake actions sampled per state to form the action support.
            M=16,
            # (int) Number of diffusion time steps.
            diffusion_steps=15,
            # (int) Number of training iterations after which the behavior model stops training.
            behavior_policy_stop_training_iter=600000,
            # (int) Number of training iterations after which the energy-guided policy begins training.
            energy_guided_policy_begin_training_iter=600000,
            # (int) Number of training iterations after which the Q value stops training; None means no limit.
            q_value_stop_training_iter=1100000,
        ),
eval=dict(
            # (list) Energy guidance scales for the policy during evaluation.
# \pi_{evaluation} \propto \pi_{behavior}\exp(guidance_scale * alpha * Q(s,a))
guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0],
),
)
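    # Config usage sketch (a sketch only; the concrete ``obs_dim``/``action_dim`` values and the model
    # construction are assumptions, not defined in this file):
    #
    #     cfg = QGPOPolicy.default_config()
    #     cfg.model.obs_dim = 17     # e.g. HalfCheetah-v3 observation dimension
    #     cfg.model.action_dim = 6   # e.g. HalfCheetah-v3 action dimension
    #     # policy = QGPOPolicy(cfg, model=qgpo_model, enable_field=['learn', 'eval'])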
def _init_learn(self) -> None:
"""
Overview:
Learn mode initialization method. For QGPO, it mainly contains the optimizer, \
algorithm-specific arguments such as qt_update_momentum, discount, behavior_policy_stop_training_iter, \
energy_guided_policy_begin_training_iter and q_value_stop_training_iter, etc.
This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
"""
self.cuda = self._cfg.cuda
self.behavior_model_optimizer = torch.optim.Adam(
self._model.score_model.parameters(), lr=self._cfg.learn.learning_rate
)
self.q_optimizer = torch.optim.Adam(self._model.q.q0.parameters(), lr=3e-4)
self.qt_optimizer = torch.optim.Adam(self._model.q.qt.parameters(), lr=3e-4)
self.qt_update_momentum = 0.005
self.discount = 0.99
self.behavior_policy_stop_training_iter = self._cfg.learn.behavior_policy_stop_training_iter
self.energy_guided_policy_begin_training_iter = self._cfg.learn.energy_guided_policy_begin_training_iter
self.q_value_stop_training_iter = self._cfg.learn.q_value_stop_training_iter
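        # Note: ``qt_update_momentum`` (Polyak coefficient of the soft target update in ``_forward_learn``)
        # and ``discount`` (Bellman discount factor) are fixed here rather than read from ``self._cfg.learn``.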
def _forward_learn(self, data: dict) -> Dict[str, Any]:
"""
Overview:
Forward function for learning mode.
The training of QGPO algorithm is based on contrastive energy prediction, \
which needs true action and fake action. The true action is sampled from the dataset, and the fake action \
is sampled from the action support generated by the behavior policy.
            The training process is divided into three stages:
            1. Train the behavior model, which is modeled as a diffusion model by parameterizing the score function.
            2. Train the Q function using the fake action support generated by the behavior model.
            3. Train the energy-guided policy (energy guidance) using the Q function.
Arguments:
- data (:obj:`dict`): Dict type data.
Returns:
- result (:obj:`dict`): Dict type data of algorithm results.
"""
if self.cuda:
data = to_device(data, self._device)
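        # Unpack one offline batch: s/a/r/s_/d are state, action, reward, next state and done flag;
        # fake_a and fake_a_ are the fake action supports (M candidate actions per state) generated
        # beforehand by the behavior model for s and s_ respectively.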
s = data['s']
a = data['a']
r = data['r']
s_ = data['s_']
d = data['d']
fake_a = data['fake_a']
fake_a_ = data['fake_a_']
        # Stage 1: train the behavior (diffusion) model with the score-matching loss on dataset actions.
if self.behavior_policy_stop_training_iter > 0:
behavior_model_training_loss = self._model.score_model_loss_fn(a, s)
self.behavior_model_optimizer.zero_grad()
behavior_model_training_loss.backward()
self.behavior_model_optimizer.step()
self.behavior_policy_stop_training_iter -= 1
behavior_model_training_loss = behavior_model_training_loss.item()
else:
behavior_model_training_loss = 0
        # Stages 2 and 3: train the Q function q0 and the energy guidance network qt.
self.energy_guided_policy_begin_training_iter -= 1
self.q_value_stop_training_iter -= 1
if self.energy_guided_policy_begin_training_iter < 0:
if self.q_value_stop_training_iter > 0:
q0_loss = self._model.q_loss_fn(a, s, r, s_, d, fake_a_, discount=self.discount)
self.q_optimizer.zero_grad()
q0_loss.backward()
self.q_optimizer.step()
                # Soft (Polyak) update of the target network: theta_target <- tau * theta + (1 - tau) * theta_target.
for param, target_param in zip(self._model.q.q0.parameters(), self._model.q.q0_target.parameters()):
target_param.data.copy_(
self.qt_update_momentum * param.data + (1 - self.qt_update_momentum) * target_param.data
)
q0_loss = q0_loss.item()
else:
q0_loss = 0
qt_loss = self._model.qt_loss_fn(s, fake_a)
self.qt_optimizer.zero_grad()
qt_loss.backward()
self.qt_optimizer.step()
qt_loss = qt_loss.item()
else:
q0_loss = 0
qt_loss = 0
total_loss = behavior_model_training_loss + q0_loss + qt_loss
return dict(
total_loss=total_loss,
behavior_model_training_loss=behavior_model_training_loss,
q0_loss=q0_loss,
qt_loss=qt_loss,
)
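    # Training schedule implied by the default ``learn`` config (approximate; the counters are decremented
    # once per ``_forward_learn`` call, so the boundaries may shift by one iteration):
    #
    #     iteration <= 600000             : only the behavior model (score model) is trained
    #     600000 < iteration < 1100000    : q0 (with soft target update) and qt are trained
    #     iteration >= 1100000            : only qt keeps training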
def _init_collect(self) -> None:
"""
Overview:
Collect mode initialization method. Not supported for QGPO.
"""
pass
def _forward_collect(self) -> None:
"""
Overview:
Forward function for collect mode. Not supported for QGPO.
"""
pass
def _init_eval(self) -> None:
"""
Overview:
            Eval mode initialization method. For QGPO, it mainly sets the number of diffusion steps used when \
            sampling actions; the guidance scale is passed to ``_forward_eval`` at evaluation time.
This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
"""
self.diffusion_steps = self._cfg.eval.diffusion_steps
def _forward_eval(self, data: dict, guidance_scale: float) -> dict:
"""
Overview:
Forward function for eval mode. The eval process is based on the energy-guided policy, \
which is modeled as a diffusion model by parameterizing the score function.
Arguments:
- data (:obj:`dict`): Dict type data.
- guidance_scale (:obj:`float`): The scale of the energy guidance.
Returns:
- output (:obj:`dict`): Dict type data of algorithm output.
"""
data_id = list(data.keys())
states = default_collate(list(data.values()))
actions = self._model.select_actions(
states, diffusion_steps=self.diffusion_steps, guidance_scale=guidance_scale
)
output = actions
return {i: {"action": d} for i, d in zip(data_id, output)}
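    # Evaluation usage sketch (assumptions: ``policy`` is a QGPOPolicy with ``eval`` in ``enable_field`` and
    # the observations are torch tensors matching the model's ``obs_dim``):
    #
    #     data = {0: obs_of_env_0, 1: obs_of_env_1}
    #     output = policy.eval_mode.forward(data, guidance_scale=3.0)
    #     action_for_env_0 = output[0]['action']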
def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Overview:
Get the train sample from the replay buffer, currently not supported for QGPO.
Arguments:
- transitions (:obj:`List[Dict[str, Any]]`): The data from replay buffer.
Returns:
- samples (:obj:`List[Dict[str, Any]]`): The data for training.
"""
pass
def _process_transition(self) -> None:
"""
Overview:
Process the transition data, currently not supported for QGPO.
"""
pass
def _state_dict_learn(self) -> Dict[str, Any]:
"""
Overview:
Return the state dict for saving.
Returns:
- state_dict (:obj:`Dict[str, Any]`): Dict type data of state dict.
"""
return {
'model': self._model.state_dict(),
'behavior_model_optimizer': self.behavior_model_optimizer.state_dict(),
}
def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
"""
Overview:
Load the state dict.
Arguments:
- state_dict (:obj:`Dict[str, Any]`): Dict type data of state dict.
"""
self._model.load_state_dict(state_dict['model'])
self.behavior_model_optimizer.load_state_dict(state_dict['behavior_model_optimizer'])
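    # Checkpointing sketch (assumption: these hooks are exposed through ``learn_mode``; file names are
    # illustrative). Note that only the model and the behavior model optimizer states are saved here:
    #
    #     ckpt = policy.learn_mode.state_dict()
    #     torch.save(ckpt, 'qgpo_ckpt.pth.tar')
    #     policy.learn_mode.load_state_dict(torch.load('qgpo_ckpt.pth.tar'))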
def _monitor_vars_learn(self) -> List[str]:
"""
Overview:
            Return the variable names to be monitored during training.
"""
return ['total_loss', 'behavior_model_training_loss', 'q0_loss', 'qt_loss']