Source code for ding.policy.iqn
from typing import List, Dict, Any, Tuple, Union
import copy
import torch
from ding.torch_utils import Adam, to_device
from ding.rl_utils import iqn_nstep_td_data, iqn_nstep_td_error, get_train_sample, get_nstep_return_data
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .dqn import DQNPolicy
from .common_utils import default_preprocess_learn
[docs]@POLICY_REGISTRY.register('iqn')
class IQNPolicy(DQNPolicy):
"""
Overview:
Policy class of IQN algorithm. Paper link: https://arxiv.org/pdf/1806.06923.pdf. \
Distributional RL is a new direction of RL, which is more stable than the traditional RL algorithm. \
The core idea of distributional RL is to estimate the distribution of action value instead of the \
expectation. The difference between IQN and DQN is that IQN uses quantile regression to estimate the \
quantile values of the action-value distribution, while DQN only estimates its expectation.
Config:
== ==================== ======== ============== ======================================== =======================
ID Symbol Type Default Value Description Other(Shape)
== ==================== ======== ============== ======================================== =======================
1 ``type`` str iqn | RL policy register name, refer to | this arg is optional,
| registry ``POLICY_REGISTRY`` | a placeholder
2 ``cuda`` bool False | Whether to use cuda for network | this arg can be diff-
| erent from modes
3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
| or off-policy
4 ``priority`` bool False | Whether use priority(PER) | priority sample,
| update priority
6 | ``other.eps`` float 0.95 | Start value for epsilon decay.
| ``.start``
7 | ``other.eps`` float 0.1 | End value for epsilon decay.
| ``.end``
8 | ``discount_`` float 0.97, | Reward's future discount factor, aka. | may be 1 when sparse
| ``factor`` [0.95, 0.999] | gamma | reward env
9 ``nstep`` int 1, | N-step reward discount sum for target
[3, 5] | q_value estimation
10 | ``learn.update`` int 3 | How many updates(iterations) to train | this args can be vary
| ``per_collect`` | after collector's one collection. Only | from envs. Bigger val
| valid in serial training | means more off-policy
11 ``learn.kappa`` float 1.0 | Threshold of Huber loss
== ==================== ======== ============== ======================================== =======================
"""
config = {
    # (str) Name under which this policy is registered (see "POLICY_REGISTRY").
    'type': 'iqn',
    # (bool) Whether the network runs on cuda.
    'cuda': False,
    # (bool) Whether the algorithm is on-policy (True) or off-policy (False).
    'on_policy': False,
    # (bool) Whether to use priority (priority sample, IS weight, update priority).
    'priority': False,
    # (float) Discount factor applied to future rewards, aka. gamma.
    'discount_factor': 0.97,
    # (int) Number of steps used for the n-step target q_value estimation.
    'nstep': 1,
    # learn_mode config
    'learn': {
        # (int) Number of training iterations run after each collection.
        # A bigger "update_per_collect" means training is more off-policy.
        # collect data -> update policy -> collect data -> ...
        'update_per_collect': 3,
        'batch_size': 64,
        'learning_rate': 0.001,
        # ==============================================================
        # The following configs are algorithm-specific
        # ==============================================================
        # (int) Frequency of target network update.
        'target_update_freq': 100,
        # (float) Threshold of Huber loss, denoted by kappa in the IQN paper.
        'kappa': 1.0,
        # (bool) Whether to ignore done (usually for max step termination env).
        'ignore_done': False,
    },
    # collect_mode config
    'collect': {
        # (int) Only one of [n_sample, n_step, n_episode] should be set, e.g. n_sample=8.
        # (int) Trajectories are cut into pieces of length "unroll_len".
        'unroll_len': 1,
    },
    # eval_mode config (no entries by default)
    'eval': {},
    # other config
    'other': {
        # Epsilon greedy with decay.
        'eps': {
            # (str) Decay type. Support ['exp', 'linear'].
            'type': 'exp',
            'start': 0.95,
            'end': 0.1,
            # (int) Decay length (env step).
            'decay': 10000,
        },
        'replay_buffer': {
            'replay_buffer_size': 10000,
        },
    },
}
def default_model(self) -> Tuple[str, List[str]]:
    """
    Overview:
        Tell the caller which registered model this policy uses by default. According to the existing \
        documentation, ``__init__`` calls this method to obtain the default model setting and build the model.
    Returns:
        - model_info (:obj:`Tuple[str, List[str]]`): The registered model name followed by the list of \
            import paths (import_names) where that model is defined.

    .. note::
        A customized network may be used instead, as long as it follows the same interface as the model \
        found under the returned import_names. For IQN the registered name is ``iqn`` and the import path is \
        ``ding.model.template.q_learning``.
    """
    registered_name = 'iqn'
    import_names = ['ding.model.template.q_learning']
    return registered_name, import_names
def _init_learn(self) -> None:
    """
    Overview:
        Set up everything the learn mode needs: the priority flag, the Adam optimizer over the model \
        parameters, the cached hyper-parameters (gamma, nstep, kappa), and the wrapped main/target models. \
        According to the existing documentation, this method is called from ``__init__`` when the ``learn`` \
        field is in ``enable_field``.

    .. note::
        Member variables that must be saved and restored are handled by ``_state_dict_learn`` and \
        ``_load_state_dict_learn``.

    .. note::
        Member variables that should be monitored are listed by the ``_monitor_vars_learn`` method.

    .. note::
        Extra member variables added here are best prefixed with ``_learn_`` (e.g. ``self._learn_attr1``) \
        so that they do not clash with attributes of the other modes.
    """
    cfg = self._cfg
    learn_cfg = cfg.learn

    self._priority = cfg.priority
    # Optimizer over the parameters of the main (online) model.
    self._optimizer = Adam(self._model.parameters(), lr=learn_cfg.learning_rate)
    # Hyper-parameters cached for the TD-error computation.
    self._gamma = cfg.discount_factor
    self._nstep = cfg.nstep
    self._kappa = learn_cfg.kappa

    # The target network starts as an independent deep copy of the main model and is wrapped with the
    # 'target' wrapper (update_type='assign', frequency taken from ``learn.target_update_freq``).
    # The copy is taken before the main model itself gets wrapped below.
    self._target_model = model_wrap(
        copy.deepcopy(self._model),
        wrapper_name='target',
        update_type='assign',
        update_kwargs={'freq': learn_cfg.target_update_freq},
    )
    # The main model is wrapped with 'argmax_sample'; ``_forward_learn`` reads the 'action' entry of its output.
    self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')

    # Reset both wrappers, main model first.
    for wrapped_model in (self._learn_model, self._target_model):
        wrapped_model.reset()
def _forward_learn(self, data: List[Dict[int, Any]]) -> Dict[str, Any]:
    """
    Overview:
        Run one training iteration in learn mode: preprocess the sampled batch, compute the n-step IQN \
        TD loss from the main and target models, back-propagate it, step the optimizer, and finally \
        notify the target model wrapper with the new main-model weights.
    Arguments:
        - data (:obj:`List[Dict[int, Any]]`): A list of training samples from the replay buffer. Each \
            element is a dict keyed by the data item name. After preprocessing, this method reads the keys \
            ``obs``, ``action``, ``reward``, ``next_obs``, ``done`` and ``weight``; ``value_gamma`` is \
            optional and is looked up with ``dict.get``.
    Returns:
        - info_dict (:obj:`Dict[str, Any]`): Training information with the keys ``cur_lr`` (learning rate \
            from the optimizer defaults), ``total_loss`` (python scalar) and ``priority`` (list of absolute \
            per-sample TD errors).

    .. note::
        Samples are stacked into a batch by ``default_preprocess_learn``; which value types are supported \
        ultimately depends on the model in use.

    .. note::
        For more detailed examples, please refer to the unittest for IQNPolicy: ``ding.policy.tests.test_iqn``.
    """
    # Stack the samples into one batch (n-step layout) and move it to the training device if needed.
    # NOTE(review): ``weight`` is presumably filled in by ``default_preprocess_learn`` - confirm there.
    batch = default_preprocess_learn(
        data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True
    )
    if self._cuda:
        batch = to_device(batch, self._device)

    # ====================
    # Q-learning forward
    # ====================
    self._learn_model.train()
    self._target_model.train()

    # Main model on the current observation: q values plus the quantiles used to produce them.
    online_out = self._learn_model.forward(batch['obs'])
    current_q = online_out['q']
    sampled_quantiles = online_out['quantiles']

    # No gradient flows through the next-state branch.
    with torch.no_grad():
        # The target model provides the next-state q values ...
        next_target_q = self._target_model.forward(batch['next_obs'])['q']
        # ... while the main model picks the next-state action.
        next_action = self._learn_model.forward(batch['next_obs'])['action']

    td_input = iqn_nstep_td_data(
        current_q, next_target_q, batch['action'], next_action, batch['reward'], batch['done'],
        sampled_quantiles, batch['weight']
    )
    loss, td_error_per_sample = iqn_nstep_td_error(
        td_input, self._gamma, nstep=self._nstep, kappa=self._kappa, value_gamma=batch.get('value_gamma')
    )

    # ====================
    # Q-learning update
    # ====================
    self._optimizer.zero_grad()
    loss.backward()
    if self._cfg.multi_gpu:
        # Gradients are synchronized before the optimizer step when multi_gpu is enabled.
        self.sync_gradients(self._learn_model)
    self._optimizer.step()

    # =============
    # after update
    # =============
    # Hand the fresh main-model weights to the target wrapper.
    self._target_model.update(self._learn_model.state_dict())

    # A '[histogram]action_distribution' entry (data['action']) could additionally be reported for
    # discrete actions satisfying len(data['action'])==1; it is not returned here.
    info_dict = {
        'cur_lr': self._optimizer.defaults['lr'],
        'total_loss': loss.item(),
        'priority': td_error_per_sample.abs().tolist(),
    }
    return info_dict
def _state_dict_learn(self) -> Dict[str, Any]:
"""
Overview:
Return the state_dict of learn mode, usually including model, target_model and optimizer.
Returns:
- state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
"""
return {
'model': self._learn_model.state_dict(),
'target_model': self._target_model.state_dict(),
'optimizer': self._optimizer.state_dict(),
}
def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
"""
Overview:
Load the state_dict variable into policy learn mode.
Arguments:
- state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
.. tip::
If you want to only load some parts of model, you can simply set the ``strict`` argument in \
load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
complicated operation.
"""
self._learn_model.load_state_dict(state_dict['model'])
self._target_model.load_state_dict(state_dict['target_model'])
self._optimizer.load_state_dict(state_dict['optimizer'])