Shortcuts

Source code for ding.policy.td3_bc

from typing import List, Dict, Any, Tuple, Union
from easydict import EasyDict
from collections import namedtuple
import torch
import torch.nn.functional as F
import copy

from ding.torch_utils import Adam, to_device
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn
from .ddpg import DDPGPolicy


@POLICY_REGISTRY.register('td3_bc')
class TD3BCPolicy(DDPGPolicy):
    """
    Overview:
        Policy class of the TD3_BC algorithm (https://arxiv.org/pdf/2106.06860.pdf).
        DDPG and TD3 have a lot in common, so this class is derived from ``DDPGPolicy`` and relies on
        ``_actor_update_freq``, ``_twin_critic`` and target-action noise for the TD3 part. On top of that,
        the actor loss in ``_forward_learn`` mixes a Q-value term with a behaviour cloning (MSE) term.

    Property:
        learn_mode, collect_mode, eval_mode

    Config:
        == ============================== ===== ======================= ================================================
        ID Symbol                         Type  Default Value           Description
        == ============================== ===== ======================= ================================================
        1  ``type``                       str   td3_bc                  Policy name registered in ``POLICY_REGISTRY``.
        2  ``cuda``                       bool  False                   Whether to use cuda for network.
        3  ``random_collect_size``        int   25000                   Random samples collected before training.
        4  ``model.twin_critic``          bool  True                    Use two critic networks instead of one.
        5  ``learn.learning_rate_actor``  float 1e-3                    Learning rate of the actor network.
        6  ``learn.learning_rate_critic`` float 1e-3                    Learning rate of the critic network.
        7  ``learn.actor_update_freq``    int   2                       Delayed policy update interval.
        8  ``learn.noise``                bool  True                    Whether to add noise to target actions.
        9  ``learn.noise_sigma``          float 0.2                     Sigma of target policy smoothing noise.
        10 ``learn.noise_range``          dict  dict(min=-0.5, max=0.5) Clip range of the smoothing noise.
        11 ``learn.ignore_done``          bool  False                   Whether to ignore the done flag.
        12 ``learn.target_theta``         float 0.005                   Soft-update factor of the target network.
        13 ``learn.alpha``                float 2.5                     Scale of the Q term weight in actor loss.
        14 ``collect.noise_sigma``        float 0.1                     Sigma of the noise added during collection.
        == ============================== ===== ======================= ================================================
    """

    # Default configuration. DDPG's default config documents the shared fields in more detail.
    config = dict(
        # (str) Name under which this policy is registered in ``POLICY_REGISTRY``.
        type='td3_bc',
        # (bool) Run the networks on cuda.
        cuda=False,
        # (bool) On-policy vs. off-policy; the setting influences the behaviour of the buffer.
        # TD3 is off-policy, hence False.
        on_policy=False,
        # (bool) Enable priority (priority sampling, IS weight, priority update). False by default in TD3.
        priority=False,
        # (bool) Correct the biased update with Importance Sampling weights. Requires ``priority=True``.
        priority_IS_weight=False,
        # (int) Amount of randomly collected samples in the replay buffer when training starts.
        # 25000 is the DDPG/TD3 default.
        random_collect_size=25000,
        # (bool) Normalise rewards over the batch (zero mean, unit std) inside ``_forward_learn``.
        reward_batch_norm=False,
        action_space='continuous',
        model=dict(
            # (bool) Two critic networks (Clipped Double Q-Learning, https://arxiv.org/pdf/1802.09477.pdf)
            # or a single one. True for TD3, False for DDPG.
            twin_critic=True,
            # (str) Continuous actions are produced with the regression trick.
            action_space='regression',
            # (int) Hidden size of the actor head.
            actor_head_hidden_size=256,
            # (int) Hidden size of the critic head.
            critic_head_hidden_size=256,
        ),
        learn=dict(
            # (int) Number of training iterations after each collection
            # (collect data -> update policy -> collect data -> ...).
            # A larger value makes training more off-policy.
            update_per_collect=1,
            # (int) Minibatch size for gradient descent.
            batch_size=256,
            # (float) Learning rate of the actor network (aka. policy).
            learning_rate_actor=1e-3,
            # (float) Learning rate of the critic network (aka. Q-network).
            learning_rate_critic=1e-3,
            # (bool) Ignore the done flag (typically for envs terminated by a max-step limit, e.g. pendulum).
            # Gym wraps MuJoCo envs such as HalfCheetah in a TimeLimit wrapper capping episodes at 1000 steps,
            # although the underlying env never terminates on its own. Replacing done==True with done==False
            # once the step limit is exceeded keeps the TD target
            # (``gamma * (1 - done) * next_v + reward``) accurate.
            ignore_done=False,
            # (float) Interpolation factor of the polyak-averaged (soft) target network update.
            target_theta=0.005,
            # (float) Discount factor (gamma) of the discounted reward sum.
            discount_factor=0.99,
            # (int) Delayed Policy Updates of the TD3 paper (https://arxiv.org/pdf/1802.09477.pdf).
            # ``_forward_learn`` updates the critic every iteration and the actor once every
            # ``self._actor_update_freq`` iterations (presumably initialised from this value by the parent
            # policy - TODO confirm). 1 for DDPG, 2 for TD3.
            actor_update_freq=2,
            # (bool) Target Policy Smoothing Regularization of the TD3 paper: noise on the target action.
            # True for TD3, False for DDPG.
            # NOTE(review): this flag is not read inside this class; ``_forward_learn`` always adds the
            # smoothing noise. How the parent ``DDPGPolicy`` uses the flag is not visible here - confirm.
            noise=True,
            # (float) Sigma of the smoothing noise added to the target policy action.
            noise_sigma=0.2,
            # (dict) Clipping range of the smoothing noise, aka. noise_clip.
            noise_range=dict(
                min=-0.5,
                max=0.5,
            ),
            # (float) Numerator of the weight applied to the Q term of the actor loss
            # (``alpha / mean(|Q|)``), see ``_forward_learn``.
            alpha=2.5,
        ),
        collect=dict(
            # (int) Trajectories are cut into pieces of length ``unroll_len``.
            unroll_len=1,
            # (float) Noise is mandatory during collection, so only its sigma is configurable.
            noise_sigma=0.1,
        ),
        eval=dict(
            evaluator=dict(
                # (int) Run an evaluation every ``eval_freq`` training iterations.
                eval_freq=5000,
            ),
        ),
        other=dict(
            replay_buffer=dict(
                # (int) Maximum number of samples kept in the replay buffer.
                replay_buffer_size=1000000,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Name the default model of this policy together with the module path(s) to import it from.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The model name and the list of import paths.
        """
        model_name = 'continuous_qac'
        import_paths = ['ding.model.template.qac']
        return model_name, import_paths

    def _init_learn(self) -> None:
        """
        Overview:
            Learn mode init method. Runs the parent initialisation first, then stores the TD3_BC specific
            settings (``alpha``, smoothing noise parameters) and builds the actor and critic optimizers.
        """
        super()._init_learn()
        learn_cfg = self._cfg.learn
        self._alpha = learn_cfg.alpha

        # Both optimizers clip the gradient norm at 1.0.
        clip_kwargs = dict(grad_clip_type='clip_norm', clip_value=1.0)
        self._optimizer_actor = Adam(
            self._model.actor.parameters(),
            lr=learn_cfg.learning_rate_actor,
            **clip_kwargs,
        )
        self._optimizer_critic = Adam(
            self._model.critic.parameters(),
            lr=learn_cfg.learning_rate_critic,
            **clip_kwargs,
        )

        # Parameters of the smoothing noise applied to the target action in ``_forward_learn``.
        self.noise_sigma = learn_cfg.noise_sigma
        self.noise_range = learn_cfg.noise_range

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        """
        Overview:
            Forward and backward function of learn mode: one critic update, a delayed actor update with the
            behaviour cloning term, and the target network update.
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']. \
                After preprocessing, 'done' and 'weight' are read as well.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Actor and critic lr, mean action, per-sample priority, mean \
                TD error, every computed loss (plus ``total_loss``) and the mean q value(s).
        """
        # Attributes such as ``_cuda``, ``_learn_model``, ``_target_model``, ``_gamma``, ``_twin_critic``,
        # ``_reward_batch_norm`` and ``_forward_learn_cnt`` are not assigned in this class; they are
        # presumably prepared by the parent policy.
        losses = {}
        data = default_preprocess_learn(
            data,
            use_priority=self._cfg.priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False,
        )
        if self._cuda:
            data = to_device(data, self._device)

        # --------------------
        # critic: forward pass
        # --------------------
        self._learn_model.train()
        self._target_model.train()
        next_obs = data['next_obs']
        reward = data['reward']
        if self._reward_batch_norm:
            reward = (reward - reward.mean()) / (reward.std() + 1e-8)

        # Q value(s) of the sampled (obs, action) pairs; indexable by critic when twin critics are used.
        q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
        if self._twin_critic:
            q_stats = {'q_value': q_value[0].mean(), 'q_value_twin': q_value[1].mean()}
        else:
            q_stats = {'q_value': q_value.mean()}

        # Target Q value, computed without gradient tracking: the target actor's action is perturbed by
        # clipped gaussian noise and clamped to [-1, 1] before being scored by the target critic.
        with torch.no_grad():
            next_action = self._target_model.forward(next_obs, mode='compute_actor')['action']
            smoothing = torch.randn_like(next_action) * self.noise_sigma
            smoothing = smoothing.clamp(self.noise_range['min'], self.noise_range['max'])
            next_action = (next_action + smoothing).clamp(-1, 1)
            target_q_value = self._target_model.forward(
                {'obs': next_obs, 'action': next_action}, mode='compute_critic'
            )['q_value']

        done, weight = data['done'], data['weight']
        if self._twin_critic:
            # TD3: the smaller of the two target estimates is the shared target of both critics.
            target_q_value = torch.min(target_q_value[0], target_q_value[1])
            first_loss, first_td_error = v_1step_td_error(
                v_1step_td_data(q_value[0], target_q_value, reward, done, weight), self._gamma
            )
            losses['critic_loss'] = first_loss
            twin_loss, twin_td_error = v_1step_td_error(
                v_1step_td_data(q_value[1], target_q_value, reward, done, weight), self._gamma
            )
            losses['critic_twin_loss'] = twin_loss
            td_error_per_sample = (first_td_error + twin_td_error) / 2
        else:
            # DDPG: a single critic network.
            single_loss, td_error_per_sample = v_1step_td_error(
                v_1step_td_data(q_value, target_q_value, reward, done, weight), self._gamma
            )
            losses['critic_loss'] = single_loss

        # --------------
        # critic: update
        # --------------
        self._optimizer_critic.zero_grad()
        for loss_name, loss_value in losses.items():
            if 'critic' in loss_name:
                loss_value.backward()
        self._optimizer_critic.step()

        # -------------------------
        # actor: forward and update
        # -------------------------
        # The actor is only updated once every ``self._actor_update_freq`` iterations.
        if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
            actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')
            actor_data['obs'] = data['obs']
            actor_q = self._learn_model.forward(actor_data, mode='compute_critic')['q_value']
            if self._twin_critic:
                # Only the first critic scores the actor's action.
                actor_q = actor_q[0]
            policy_loss = -actor_q.mean()
            # Weight of the Q term: alpha over the mean absolute Q value, detached so that no gradient
            # flows through the weight itself.
            q_term_weight = self._alpha / actor_q.abs().mean().detach()
            # Behaviour cloning term: MSE between the actor's action and the action stored in the data.
            bc_loss = F.mse_loss(actor_data['action'], data['action'])
            actor_loss = q_term_weight * policy_loss + bc_loss
            losses['actor_loss'] = actor_loss

            self._optimizer_actor.zero_grad()
            actor_loss.backward()
            self._optimizer_actor.step()

        # ------------
        # after update
        # ------------
        losses['total_loss'] = sum(losses.values())
        self._forward_learn_cnt += 1
        self._target_model.update(self._learn_model.state_dict())

        info = {
            'cur_lr_actor': self._optimizer_actor.defaults['lr'],
            'cur_lr_critic': self._optimizer_critic.defaults['lr'],
            'action': data.get('action').mean(),
            'priority': td_error_per_sample.abs().tolist(),
            'td_error': td_error_per_sample.abs().mean(),
        }
        info.update(losses)
        info.update(q_stats)
        return info

    def _forward_eval(self, data: dict) -> dict:
        """
        Overview:
            Forward function of eval mode, similar to ``self._forward_collect``.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting \
                policy_output(action); values are torch.Tensor or np.ndarray or dict/list combinations, \
                keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
        ReturnsKeys
            - necessary: ``action``
            - optional: ``logit``
        """
        env_ids = list(data)
        batch = default_collate(list(data.values()))
        if self._cuda:
            batch = to_device(batch, self._device)

        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(batch, mode='compute_actor')
        if self._cuda:
            # Results are handed back on cpu.
            output = to_device(output, 'cpu')

        # Split the batched output again and pair each piece with its env id.
        per_env_output = default_decollate(output)
        return dict(zip(env_ids, per_env_output))