Source code for ding.policy.td3_bc
from typing import List, Dict, Any, Tuple, Union
from easydict import EasyDict
from collections import namedtuple
import torch
import torch.nn.functional as F
import copy
from ding.torch_utils import Adam, to_device
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn
from .ddpg import DDPGPolicy
@POLICY_REGISTRY.register('td3_bc')
class TD3BCPolicy(DDPGPolicy):
r"""
Overview:
Policy class of TD3_BC algorithm.
Since DDPG and TD3 share many common things, we can easily derive this TD3_BC
class from DDPG class by changing ``_actor_update_freq``, ``_twin_critic`` and noise in model wrapper.
https://arxiv.org/pdf/2106.06860.pdf
Property:
learn_mode, collect_mode, eval_mode
Config:
== ==================== ======== ================== ================================= =======================
ID Symbol Type Default Value Description Other(Shape)
== ==================== ======== ================== ================================= =======================
1 ``type`` str td3_bc | RL policy register name, refer | this arg is optional,
| to registry ``POLICY_REGISTRY`` | a placeholder
2 ``cuda`` bool True | Whether to use cuda for network |
3 | ``random_`` int 25000 | Number of randomly collected | Default to 25000 for
| ``collect_size`` | training samples in replay | DDPG/TD3, 10000 for
       |                                                  | buffer when training starts.    | SAC.
4 | ``model.twin_`` bool True | Whether to use two critic | Default True for TD3,
| ``critic`` | networks or only one. | Clipped Double
| | | Q-learning method in
| | | TD3 paper.
5 | ``learn.learning`` float 1e-3 | Learning rate for actor |
| ``_rate_actor`` | network(aka. policy). |
    6  | ``learn.learning``  float    1e-3               | Learning rate for critic        |
| ``_rate_critic`` | network (aka. Q-network). |
7 | ``learn.actor_`` int 2 | When critic network updates | Default 2 for TD3, 1
       | ``update_freq``                                  | once, how many times the actor  | for DDPG. Delayed
       |                                                  | network updates.                | Policy Updates method
| | | in TD3 paper.
8 | ``learn.noise`` bool True | Whether to add noise on target | Default True for TD3,
| | network's action. | False for DDPG.
| | | Target Policy Smoo-
| | | thing Regularization
| | | in TD3 paper.
9 | ``learn.noise_`` dict | dict(min=-0.5, | Limit for range of target |
| ``range`` | max=0.5,) | policy smoothing noise, |
| | | aka. noise_clip. |
10 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only
| ``ignore_done`` | done flag. | in halfcheetah env.
11 | ``learn.-`` float 0.005 | Used for soft update of the | aka. Interpolation
| ``target_theta`` | target network. | factor in polyak aver
| | | aging for target
| | | networks.
    12 | ``collect.-``       float    0.1                | Used to add noise during co-    | Sample noise from dis
| ``noise_sigma`` | llection, through controlling | tribution, Ornstein-
| | the sigma of distribution | Uhlenbeck process in
       |                                                  |                                 | DDPG paper, Gaussian
| | | process in ours.
== ==================== ======== ================== ================================= =======================
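
    Examples:
        A minimal usage sketch, assuming hypothetical MuJoCo-like observation/action shapes and the
        standard DI-engine policy interface (details may differ across DI-engine versions)::

            >>> import copy
            >>> from easydict import EasyDict
            >>> cfg = EasyDict(copy.deepcopy(TD3BCPolicy.config))
            >>> cfg.model.update(dict(obs_shape=17, action_shape=6))  # hypothetical shapes
            >>> policy = TD3BCPolicy(cfg, enable_field=['learn'])
            >>> # train_info = policy.learn_mode.forward(batch)  # ``batch``: a list of offline transitions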
"""
# You can refer to DDPG's default config for more details.
config = dict(
# (str) RL policy register name (refer to function "POLICY_REGISTRY").
type='td3_bc',
# (bool) Whether to use cuda for network.
cuda=False,
        # (bool) on_policy: Determine whether the algorithm is on-policy or off-policy.
        # The on-policy setting influences the behaviour of the replay buffer.
        # Default False in TD3.
on_policy=False,
        # (bool) Whether to use priority (priority sampling, IS weight, priority update).
        # Default False in TD3.
priority=False,
        # (bool) Whether to use Importance Sampling weight to correct the biased update. If True, priority must be True.
priority_IS_weight=False,
        # (int) Number of training samples (randomly collected) in the replay buffer when training starts.
# Default 25000 in DDPG/TD3.
random_collect_size=25000,
        # (bool) Whether to use batch normalization for the reward.
reward_batch_norm=False,
action_space='continuous',
model=dict(
# (bool) Whether to use two critic networks or only one.
# Clipped Double Q-Learning for Actor-Critic in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
# Default True for TD3, False for DDPG.
twin_critic=True,
            # (str) action_space: Use the regression trick for continuous actions.
action_space='regression',
# (int) Hidden size for actor network head.
actor_head_hidden_size=256,
# (int) Hidden size for critic network head.
critic_head_hidden_size=256,
),
learn=dict(
            # (int) How many updates (iterations) to train after one collection by the collector.
            # A bigger "update_per_collect" means more off-policy training.
            # collect data -> update policy -> collect data -> ...
update_per_collect=1,
# (int) Minibatch size for gradient descent.
batch_size=256,
            # (float) Learning rate for the actor network (aka. policy).
learning_rate_actor=1e-3,
            # (float) Learning rate for the critic network (aka. Q-network).
learning_rate_critic=1e-3,
            # (bool) Whether to ignore the done flag (usually for max-step-termination envs, e.g. pendulum).
            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
            # These limit HalfCheetah, and several other MuJoCo envs, to a max episode length of 1000.
            # However, interaction with HalfCheetah should always return done=False:
            # when the episode step exceeds the max episode step, we replace done==True with done==False
            # to keep the TD-error computation accurate (``reward + gamma * (1 - done) * next_v``).
ignore_done=False,
            # (float) target_theta: Used for the soft update of the target network,
            # aka. the interpolation factor in polyak averaging for target networks.
            # Default to 0.005.
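            # A rough sketch of the soft update performed by the target model wrapper:
            #   target_param <- (1 - target_theta) * target_param + target_theta * param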
target_theta=0.005,
# (float) discount factor for the discounted sum of rewards, aka. gamma.
discount_factor=0.99,
            # (int) When the critic network updates once, how many times the actor network updates.
# Delayed Policy Updates in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
# Default 1 for DDPG, 2 for TD3.
actor_update_freq=2,
# (bool) Whether to add noise on target network's action.
# Target Policy Smoothing Regularization in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
# Default True for TD3, False for DDPG.
noise=True,
# (float) Sigma for smoothing noise added to target policy.
noise_sigma=0.2,
# (dict) Limit for range of target policy smoothing noise, aka. noise_clip.
noise_range=dict(
min=-0.5,
max=0.5,
),
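            # (float) alpha: TD3_BC-specific weight that balances RL and behavior cloning in the actor loss;
            # the Q-value term is scaled by lambda = alpha / mean(|Q|) before the BC (MSE) term is added.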
alpha=2.5,
),
collect=dict(
# (int) Cut trajectories into pieces with length "unroll_len".
unroll_len=1,
            # (float) Noise must be added during collection, so "noise" is omitted here and only "noise_sigma" is set.
noise_sigma=0.1,
),
eval=dict(
evaluator=dict(
# (int) Evaluate every "eval_freq" training iterations.
eval_freq=5000,
),
),
other=dict(
replay_buffer=dict(
# (int) Maximum size of replay buffer.
replay_buffer_size=1000000,
),
),
)
def default_model(self) -> Tuple[str, List[str]]:
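        """
        Overview:
            Return this algorithm's default model setting: the registered model name and its import path,
            i.e. the continuous-action QAC model shared with DDPG/TD3.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): Registered model name and import_names.
        """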
return 'continuous_qac', ['ding.model.template.qac']
def _init_learn(self) -> None:
"""
Overview:
            Learn mode init method. Called by ``self.__init__``. Initialize the actor and critic optimizers
            and the algorithm config.
"""
super(TD3BCPolicy, self)._init_learn()
self._alpha = self._cfg.learn.alpha
# actor and critic optimizer
self._optimizer_actor = Adam(
self._model.actor.parameters(),
lr=self._cfg.learn.learning_rate_actor,
grad_clip_type='clip_norm',
clip_value=1.0,
)
self._optimizer_critic = Adam(
self._model.critic.parameters(),
lr=self._cfg.learn.learning_rate_critic,
grad_clip_type='clip_norm',
clip_value=1.0,
)
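        # Target policy smoothing parameters: sigma of the Gaussian noise and its clip range,
        # used when computing the target action in ``_forward_learn``.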
self.noise_sigma = self._cfg.learn.noise_sigma
self.noise_range = self._cfg.learn.noise_range
def _forward_learn(self, data: dict) -> Dict[str, Any]:
r"""
Overview:
Forward and backward function of learn mode.
Arguments:
- data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
Returns:
            - info_dict (:obj:`Dict[str, Any]`): Including at least the actor and critic learning rates and the different losses.
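        Examples:
            The TD3_BC actor objective combines the deterministic policy gradient term with a behavior
            cloning term. A standalone sketch with hypothetical tensors ``q`` (critic values of the
            policy's actions), ``pi_action`` (policy actions) and ``a`` (dataset actions)::

                >>> import torch
                >>> import torch.nn.functional as F
                >>> alpha, q = 2.5, torch.randn(256, 1)  # hypothetical alpha and critic values
                >>> pi_action, a = torch.randn(256, 6), torch.randn(256, 6)  # hypothetical actions
                >>> lmbda = alpha / q.abs().mean().detach()
                >>> actor_loss = lmbda * (-q.mean()) + F.mse_loss(pi_action, a)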
"""
loss_dict = {}
data = default_preprocess_learn(
data,
use_priority=self._cfg.priority,
use_priority_IS_weight=self._cfg.priority_IS_weight,
ignore_done=self._cfg.learn.ignore_done,
use_nstep=False
)
if self._cuda:
data = to_device(data, self._device)
# ====================
# critic learn forward
# ====================
self._learn_model.train()
self._target_model.train()
next_obs = data['next_obs']
reward = data['reward']
if self._reward_batch_norm:
reward = (reward - reward.mean()) / (reward.std() + 1e-8)
# current q value
q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
q_value_dict = {}
if self._twin_critic:
q_value_dict['q_value'] = q_value[0].mean()
q_value_dict['q_value_twin'] = q_value[1].mean()
else:
q_value_dict['q_value'] = q_value.mean()
# target q value.
with torch.no_grad():
next_action = self._target_model.forward(next_obs, mode='compute_actor')['action']
noise = (torch.randn_like(next_action) *
self.noise_sigma).clamp(self.noise_range['min'], self.noise_range['max'])
next_action = (next_action + noise).clamp(-1, 1)
next_data = {'obs': next_obs, 'action': next_action}
target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
if self._twin_critic:
# TD3: two critic networks
target_q_value = torch.min(target_q_value[0], target_q_value[1]) # find min one as target q value
# critic network1
td_data = v_1step_td_data(q_value[0], target_q_value, reward, data['done'], data['weight'])
critic_loss, td_error_per_sample1 = v_1step_td_error(td_data, self._gamma)
loss_dict['critic_loss'] = critic_loss
# critic network2(twin network)
td_data_twin = v_1step_td_data(q_value[1], target_q_value, reward, data['done'], data['weight'])
critic_twin_loss, td_error_per_sample2 = v_1step_td_error(td_data_twin, self._gamma)
loss_dict['critic_twin_loss'] = critic_twin_loss
td_error_per_sample = (td_error_per_sample1 + td_error_per_sample2) / 2
else:
# DDPG: single critic network
td_data = v_1step_td_data(q_value, target_q_value, reward, data['done'], data['weight'])
critic_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
loss_dict['critic_loss'] = critic_loss
# ================
# critic update
# ================
self._optimizer_critic.zero_grad()
for k in loss_dict:
if 'critic' in k:
loss_dict[k].backward()
self._optimizer_critic.step()
# ===============================
# actor learn forward and update
# ===============================
# actor updates every ``self._actor_update_freq`` iters
if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')
actor_data['obs'] = data['obs']
if self._twin_critic:
q_value = self._learn_model.forward(actor_data, mode='compute_critic')['q_value'][0]
actor_loss = -q_value.mean()
else:
q_value = self._learn_model.forward(actor_data, mode='compute_critic')['q_value']
actor_loss = -q_value.mean()
            # Compute the behavior cloning loss weight lambda = alpha / mean(|Q|) (detached).
lmbda = self._alpha / q_value.abs().mean().detach()
# bc_loss = ((actor_data['action'] - data['action'])**2).mean()
bc_loss = F.mse_loss(actor_data['action'], data['action'])
actor_loss = lmbda * actor_loss + bc_loss
loss_dict['actor_loss'] = actor_loss
# actor update
self._optimizer_actor.zero_grad()
actor_loss.backward()
self._optimizer_actor.step()
# =============
# after update
# =============
loss_dict['total_loss'] = sum(loss_dict.values())
self._forward_learn_cnt += 1
self._target_model.update(self._learn_model.state_dict())
return {
'cur_lr_actor': self._optimizer_actor.defaults['lr'],
'cur_lr_critic': self._optimizer_critic.defaults['lr'],
# 'q_value': np.array(q_value).mean(),
'action': data.get('action').mean(),
'priority': td_error_per_sample.abs().tolist(),
'td_error': td_error_per_sample.abs().mean(),
**loss_dict,
**q_value_dict,
}
def _forward_eval(self, data: dict) -> dict:
r"""
Overview:
Forward function of eval mode, similar to ``self._forward_collect``.
Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output (action), \
                values are torch.Tensor, np.ndarray, or dict/list combinations; keys are env_ids indicated by integers.
Returns:
- output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
ReturnsKeys
- necessary: ``action``
- optional: ``logit``
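        Examples:
            A shape-only sketch of the expected input/output mapping, with a hypothetical 4-dim
            observation and two environments (env_ids 0 and 1), assuming ``policy`` is an already
            instantiated TD3BCPolicy::

                >>> data = {0: torch.randn(4), 1: torch.randn(4)}
                >>> output = policy.eval_mode.forward(data)
                >>> assert set(output.keys()) == {0, 1} and 'action' in output[0]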
"""
data_id = list(data.keys())
data = default_collate(list(data.values()))
if self._cuda:
data = to_device(data, self._device)
self._eval_model.eval()
with torch.no_grad():
output = self._eval_model.forward(data, mode='compute_actor')
if self._cuda:
output = to_device(output, 'cpu')
output = default_decollate(output)
return {i: d for i, d in zip(data_id, output)}