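"""Source code for ding.policy.cql.

Conservative Q-Learning (CQL) policies: ``CQLPolicy`` for continuous control and ``DiscreteCQLPolicy`` for
discrete action spaces. Paper link: https://arxiv.org/abs/2006.04779.
"""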

from typing import List, Dict, Any, Tuple, Union
import copy
import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Normal, Independent

from ding.torch_utils import Adam, to_device
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \
    qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_nstep_return_data
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .sac import SACPolicy
from .qrdqn import QRDQNPolicy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('cql')
class CQLPolicy(SACPolicy):
    """
    Overview:
        Policy class of CQL algorithm for continuous control. Paper link: https://arxiv.org/abs/2006.04779.

    Config:
        == ============================== ===== ======= ==============================================================
        ID Symbol                         Type  Default Description
        == ============================== ===== ======= ==============================================================
        1  ``type``                       str   cql     RL policy register name, refer to ``POLICY_REGISTRY``.
        2  ``cuda``                       bool  False   Whether to use cuda for network.
        3  ``random_collect_size``        int   10000   Number of randomly collected samples in the buffer at start.
        4  ``model.twin_critic``          bool  True    Whether to use two critics (clipped double-Q learning).
        5  ``learn.learning_rate_q``      float 3e-4    Learning rate for the soft Q network(s).
        6  ``learn.learning_rate_policy`` float 3e-4    Learning rate for the policy network.
        7  ``learn.learning_rate_alpha``  float 3e-4    Learning rate for the auto temperature ``alpha``.
        8  ``learn.alpha``                float 0.2     Entropy coefficient (initial value when ``auto_alpha``).
        9  ``learn.auto_alpha``           bool  True    Whether to tune ``alpha`` automatically.
        10 ``learn.ignore_done``          bool  False   Whether to ignore the done flag.
        11 ``learn.target_theta``         float 0.005   Polyak interpolation factor for the soft target update.
        12 ``learn.num_actions``          int   10      Actions sampled per state for the conservative term.
        13 ``learn.with_lagrange``        bool  False   Tune the conservative weight with a Lagrange multiplier.
        14 ``learn.lagrange_thresh``      float -1      Target gap of the Lagrangian variant (used only if > 0).
        15 ``learn.min_q_weight``         float 1.0     Loss weight of the conservative term.
        16 ``learn.with_q_entropy``       bool  False   Whether to subtract the entropy term in the target Q.
        == ============================== ===== ======= ==============================================================
    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='cql',
        # (bool) Whether to use cuda for policy.
        cuda=False,
        # (bool) on_policy: Determine whether on-policy or off-policy.
        # on-policy setting influences the behaviour of buffer.
        on_policy=False,
        # (bool) priority: Determine whether to use priority in buffer sample.
        priority=False,
        # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (int) Number of training samples(randomly collected) in replay buffer when training starts.
        random_collect_size=10000,
        model=dict(
            # (bool type) twin_critic: Determine whether to use double-soft-q-net for target q computation.
            # Please refer to TD3 about Clipped Double-Q Learning trick, which learns two Q-functions instead of one.
            # Default to True.
            twin_critic=True,
            # (str type) action_space: Use reparameterization trick for continuous action.
            action_space='reparameterization',
            # (int) Hidden size for actor network head.
            actor_head_hidden_size=256,
            # (int) Hidden size for critic network head.
            critic_head_hidden_size=256,
        ),
        # learn_mode config
        learn=dict(
            # (int) How many updates (iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            update_per_collect=1,
            # (int) Minibatch size for gradient descent.
            batch_size=256,
            # (float) learning_rate_q: Learning rate for soft q network.
            learning_rate_q=3e-4,
            # (float) learning_rate_policy: Learning rate for policy network.
            learning_rate_policy=3e-4,
            # (float) learning_rate_alpha: Learning rate for auto temperature parameter ``alpha``.
            learning_rate_alpha=3e-4,
            # (float) target_theta: Used for soft update of the target network,
            # aka. Interpolation factor in polyak averaging for target networks.
            target_theta=0.005,
            # (float) discount factor for the discounted sum of rewards, aka. gamma.
            discount_factor=0.99,
            # (float) alpha: Entropy regularization coefficient.
            # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details.
            # If auto_alpha is set to `True`, alpha is initialization for auto `\alpha`.
            # Default to 0.2.
            alpha=0.2,
            # (bool) auto_alpha: Determine whether to use auto temperature parameter `\alpha` .
            # Temperature parameter determines the relative importance of the entropy term against the reward.
            # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details.
            # Note that: Using auto alpha needs to set learning_rate_alpha in `cfg.policy.learn`.
            auto_alpha=True,
            # (bool) log_space: Determine whether to use auto `\alpha` in log space.
            log_space=True,
            # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum)
            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
            # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000.
            # However, interaction with HalfCheetah always gets done with done is False,
            # Since we inplace done==True with done==False to keep
            # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``),
            # when the episode step is greater than max episode step.
            ignore_done=False,
            # (float) Weight uniform initialization range in the last output layer.
            init_w=3e-3,
            # (int) The numbers of action sample each at every state s from a uniform-at-random.
            num_actions=10,
            # (bool) Whether use lagrange multiplier in q value loss.
            with_lagrange=False,
            # (float) The threshold for difference in Q-values.
            lagrange_thresh=-1,
            # (float) Loss weight for conservative item.
            min_q_weight=1.0,
            # (bool) Whether to use entropy in target q.
            with_q_entropy=False,
        ),
        eval=dict(),  # for compatibility
    )

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For CQL, it mainly \
            contains two optimizers (Q network and policy network), the optional temperature / Lagrange \
            optimizers, algorithm-specific arguments such as gamma, min_q_weight, with_lagrange and \
            with_q_entropy, and the main and target models. This method will be called in ``__init__`` method if \
            ``learn`` field is in ``enable_field``.

        .. note::
            ``learn.target_entropy`` is optional: when it is absent or ``None`` and ``auto_alpha`` is enabled, \
            the target entropy defaults to ``-prod(model.action_shape)``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self._twin_critic = self._cfg.model.twin_critic
        self._num_actions = self._cfg.learn.num_actions

        # Only the importance-sampled variant (3) is configured here; ``_conservative_q_losses`` keeps the
        # plain logsumexp variant for any other value.
        self._min_q_version = 3
        self._min_q_weight = self._cfg.learn.min_q_weight
        # The threshold must be read before it is used in the ``with_lagrange`` condition below.
        self._lagrange_thresh = self._cfg.learn.lagrange_thresh
        self._with_lagrange = self._cfg.learn.with_lagrange and (self._lagrange_thresh > 0)
        if self._with_lagrange:
            self.target_action_gap = self._lagrange_thresh
            self.log_alpha_prime = torch.tensor(0.).to(self._device).requires_grad_()
            self.alpha_prime_optimizer = Adam(
                [self.log_alpha_prime],
                lr=self._cfg.learn.learning_rate_q,
            )

        self._with_q_entropy = self._cfg.learn.with_q_entropy

        # Weight Init: uniform init of the last layers of actor (mu and log_sigma) and critic(s).
        init_w = self._cfg.learn.init_w
        actor_last = self._model.actor_head[-1]
        actor_last.mu.weight.data.uniform_(-init_w, init_w)
        actor_last.mu.bias.data.uniform_(-init_w, init_w)
        actor_last.log_sigma_layer.weight.data.uniform_(-init_w, init_w)
        actor_last.log_sigma_layer.bias.data.uniform_(-init_w, init_w)
        if self._twin_critic:
            critic_last_layers = [self._model.critic_head[0][-1].last, self._model.critic_head[1][-1].last]
        else:
            # Use the same (last) module for weight and bias.
            critic_last_layers = [self._model.critic_head[-1].last]
        for layer in critic_last_layers:
            layer.weight.data.uniform_(-init_w, init_w)
            layer.bias.data.uniform_(-init_w, init_w)

        # Optimizers
        self._optimizer_q = Adam(
            self._model.critic.parameters(),
            lr=self._cfg.learn.learning_rate_q,
        )
        self._optimizer_policy = Adam(
            self._model.actor.parameters(),
            lr=self._cfg.learn.learning_rate_policy,
        )

        # Algorithm config
        self._gamma = self._cfg.learn.discount_factor
        # Init auto alpha
        if self._cfg.learn.auto_alpha:
            # ``target_entropy`` is not part of the default config, so fall back to ``None`` when missing.
            target_entropy = self._cfg.learn.get('target_entropy', None)
            if target_entropy is None:
                assert 'action_shape' in self._cfg.model, "CQL need network model with action_shape variable"
                self._target_entropy = -np.prod(self._cfg.model.action_shape)
            else:
                self._target_entropy = target_entropy
            if self._cfg.learn.log_space:
                self._log_alpha = torch.log(torch.FloatTensor([self._cfg.learn.alpha]))
                self._log_alpha = self._log_alpha.to(self._device).requires_grad_()
                self._alpha_optim = torch.optim.Adam([self._log_alpha], lr=self._cfg.learn.learning_rate_alpha)
                assert self._log_alpha.shape == torch.Size([1]) and self._log_alpha.requires_grad
                self._alpha = self._log_alpha.detach().exp()
                self._auto_alpha = True
                self._log_space = True
            else:
                self._alpha = torch.FloatTensor([self._cfg.learn.alpha]).to(self._device).requires_grad_()
                self._alpha_optim = torch.optim.Adam([self._alpha], lr=self._cfg.learn.learning_rate_alpha)
                self._auto_alpha = True
                self._log_space = False
        else:
            self._alpha = torch.tensor(
                [self._cfg.learn.alpha], requires_grad=False, device=self._device, dtype=torch.float32
            )
            self._auto_alpha = False

        # Main and target models
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_theta}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='base')
        self._learn_model.reset()
        self._target_model.reset()

        self._forward_learn_cnt = 0

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the offline dataset and then returns the output \
            result, including various training information such as loss, action, priority.
        Arguments:
            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
                training samples. For each element in list, the key of the dict is the name of data items and the \
                value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
                combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \
                dimension by some utility functions such as ``default_preprocess_learn``. \
                For CQL, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \
                ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
                recorded in text log and tensorboard. It contains the learning rates, per-sample priority, mean TD \
                error, current ``alpha``, mean target Q value and every entry of the loss dict.

        .. note::
            Both the twin-critic and the single-critic settings are supported: the conservative term is computed \
            once per critic head.
        """
        loss_dict = {}
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        if len(data.get('action').shape) == 1:
            data['action'] = data['action'].reshape(-1, 1)

        if self._cuda:
            data = to_device(data, self._device)

        self._learn_model.train()
        self._target_model.train()
        obs = data['obs']
        next_obs = data['next_obs']
        reward = data['reward']
        done = data['done']

        # 1. predict q value
        q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']

        # 2. predict target value
        with torch.no_grad():
            next_action, next_log_prob = self._sample_squashed_action(next_obs)
            next_data = {'obs': next_obs, 'action': next_action}
            target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
            # the value of a policy according to the maximum entropy objective
            if self._twin_critic:
                # take the min of the two critics as target q value (clipped double-Q)
                target_q_value = torch.min(target_q_value[0], target_q_value[1])
            if self._with_q_entropy:
                target_q_value = target_q_value - self._alpha * next_log_prob.squeeze(-1)

        # 3. compute q loss
        if self._twin_critic:
            q_data0 = v_1step_td_data(q_value[0], target_q_value, reward, done, data['weight'])
            loss_dict['critic_loss'], td_error_per_sample0 = v_1step_td_error(q_data0, self._gamma)
            q_data1 = v_1step_td_data(q_value[1], target_q_value, reward, done, data['weight'])
            loss_dict['twin_critic_loss'], td_error_per_sample1 = v_1step_td_error(q_data1, self._gamma)
            td_error_per_sample = (td_error_per_sample0 + td_error_per_sample1) / 2
        else:
            q_data = v_1step_td_data(q_value, target_q_value, reward, done, data['weight'])
            loss_dict['critic_loss'], td_error_per_sample = v_1step_td_error(q_data, self._gamma)

        # 4. add CQL conservative term (one entry per critic head)
        min_q_losses = self._conservative_q_losses(data, q_value)
        if self._with_lagrange:
            alpha_prime = torch.clamp(self.log_alpha_prime.exp(), min=0.0, max=1000000.0)
            min_q_losses = [alpha_prime * (penalty - self.target_action_gap) for penalty in min_q_losses]

            self.alpha_prime_optimizer.zero_grad()
            # alpha_prime maximizes the weighted gap, averaged over critic heads.
            alpha_prime_loss = -sum(min_q_losses) / len(min_q_losses)
            # retain_graph: the critic losses below backpropagate through the same graph.
            alpha_prime_loss.backward(retain_graph=True)
            self.alpha_prime_optimizer.step()

        loss_dict['critic_loss'] = loss_dict['critic_loss'] + min_q_losses[0]
        if self._twin_critic:
            loss_dict['twin_critic_loss'] = loss_dict['twin_critic_loss'] + min_q_losses[1]

        # 5. update q network
        self._optimizer_q.zero_grad()
        loss_dict['critic_loss'].backward(retain_graph=True)
        if self._twin_critic:
            loss_dict['twin_critic_loss'].backward()
        self._optimizer_q.step()

        # 6. evaluate to get action distribution
        action, log_prob = self._sample_squashed_action(data['obs'])
        eval_data = {'obs': obs, 'action': action}
        new_q_value = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
        if self._twin_critic:
            new_q_value = torch.min(new_q_value[0], new_q_value[1])

        # 7. compute policy loss
        policy_loss = (self._alpha * log_prob - new_q_value.unsqueeze(-1)).mean()
        loss_dict['policy_loss'] = policy_loss

        # 8. update policy network
        self._optimizer_policy.zero_grad()
        loss_dict['policy_loss'].backward()
        self._optimizer_policy.step()

        # 9. compute alpha loss
        if self._auto_alpha:
            if self._log_space:
                log_prob = log_prob + self._target_entropy
                loss_dict['alpha_loss'] = -(self._log_alpha * log_prob.detach()).mean()

                self._alpha_optim.zero_grad()
                loss_dict['alpha_loss'].backward()
                self._alpha_optim.step()
                self._alpha = self._log_alpha.detach().exp()
            else:
                log_prob = log_prob + self._target_entropy
                loss_dict['alpha_loss'] = -(self._alpha * log_prob.detach()).mean()

                self._alpha_optim.zero_grad()
                loss_dict['alpha_loss'].backward()
                self._alpha_optim.step()
                # Clamp in place: rebinding (e.g. ``max(0, alpha)``) could turn alpha into a python int and
                # detach it from ``self._alpha_optim``.
                with torch.no_grad():
                    self._alpha.clamp_(min=0.)

        loss_dict['total_loss'] = sum(loss_dict.values())

        # =============
        # after update
        # =============
        self._forward_learn_cnt += 1
        # target update
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr_q': self._optimizer_q.defaults['lr'],
            'cur_lr_p': self._optimizer_policy.defaults['lr'],
            'priority': td_error_per_sample.abs().tolist(),
            'td_error': td_error_per_sample.detach().mean().item(),
            'alpha': self._alpha.item(),
            'target_q_value': target_q_value.detach().mean().item(),
            **loss_dict
        }

    def _conservative_q_losses(self, data: Dict, q_value: Any) -> List[torch.Tensor]:
        """
        Overview:
            Compute the CQL conservative term for every critic head: logsumexp over Q values of random, current \
            policy and next-state policy actions (importance-sampled when ``_min_q_version == 3``) minus the mean \
            Q value of the dataset actions, both scaled by ``min_q_weight``.
        Arguments:
            - data (:obj:`Dict`): Preprocessed batch with at least ``obs`` and ``next_obs``.
            - q_value (:obj:`Any`): Critic output for the dataset actions; a list of two tensors with twin \
                critics, otherwise a single tensor.
        Returns:
            - losses (:obj:`List[torch.Tensor]`): One scalar loss per critic head (2 with twin critics, else 1).
        """
        obs = data['obs']
        num_actions = self._num_actions
        curr_actions_tensor, curr_log_pis = self._get_policy_actions(data, num_actions)
        new_curr_actions_tensor, new_log_pis = self._get_policy_actions({'obs': data['next_obs']}, num_actions)
        random_actions_tensor = torch.FloatTensor(curr_actions_tensor.shape).uniform_(-1, 1).to(curr_actions_tensor.device)

        obs_repeat = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1])
        q_rand = self._get_q_value({'obs': obs_repeat, 'action': random_actions_tensor})
        q_curr_actions = self._get_q_value({'obs': obs_repeat, 'action': curr_actions_tensor})
        # Actions sampled at next_obs are evaluated at the current obs.
        q_next_actions = self._get_q_value({'obs': obs_repeat, 'action': new_curr_actions_tensor})

        def per_head(value: Any) -> List[torch.Tensor]:
            # Twin critics return one entry per head; a single critic returns a bare tensor.
            return list(value) if self._twin_critic else [value]

        # log density of the uniform distribution over [-1, 1]^action_dim
        random_density = np.log(0.5 ** curr_actions_tensor.shape[-1])
        losses = []
        for q_data, q_r, q_n, q_c in zip(
                per_head(q_value), per_head(q_rand), per_head(q_next_actions), per_head(q_curr_actions)
        ):
            if self._min_q_version == 3:
                # importance sampled version
                cat_q = torch.cat(
                    [q_r - random_density, q_n - new_log_pis.detach(), q_c - curr_log_pis.detach()], 1
                )
            else:
                cat_q = torch.cat([q_r, q_data.reshape(-1, 1, 1), q_n, q_c], 1)
            min_q_loss = torch.logsumexp(cat_q, dim=1).mean() * self._min_q_weight
            # Subtract the Q value of the dataset actions.
            min_q_loss = min_q_loss - q_data.mean() * self._min_q_weight
            losses.append(min_q_loss)
        return losses

    def _sample_squashed_action(self, obs: torch.Tensor, epsilon: float = 1e-6) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Sample a reparameterized action from the actor's Gaussian, squash it with ``tanh`` and return its \
            log probability corrected by the tanh Jacobian (``epsilon`` avoids ``log(0)``).
        Returns:
            - action (:obj:`torch.Tensor`): Squashed action in ``(-1, 1)``.
            - log_prob (:obj:`torch.Tensor`): Log probability with a trailing singleton dimension.
        """
        (mu, sigma) = self._learn_model.forward(obs, mode='compute_actor')['logit']
        dist = Independent(Normal(mu, sigma), 1)
        pred = dist.rsample()
        action = torch.tanh(pred)
        y = 1 - action.pow(2) + epsilon
        log_prob = dist.log_prob(pred).unsqueeze(-1)
        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
        return action, log_prob

    def _get_policy_actions(self, data: Dict, num_actions: int = 10,
                            epsilon: float = 1e-6) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Sample ``num_actions`` squashed policy actions for every observation in ``data['obs']``.
        Returns:
            - action (:obj:`torch.Tensor`): Actions for the repeated observations (batch * num_actions rows).
            - log_prob (:obj:`torch.Tensor`): Log probabilities reshaped to ``(-1, num_actions, 1)``.
        """
        obs = data['obs']
        # assumes obs is (batch, obs_dim) -- TODO confirm for image observations
        obs = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1])
        action, log_prob = self._sample_squashed_action(obs, epsilon)
        return action, log_prob.view(-1, num_actions, 1)

    def _get_q_value(self, data: Dict, keep: bool = True) -> torch.Tensor:
        """
        Overview:
            Evaluate the critic(s) on repeated observation/action pairs and reshape to ``(-1, num_actions, 1)``. \
            With twin critics and ``keep=False`` the element-wise minimum of both heads is returned.
        """
        new_q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
        if self._twin_critic:
            new_q_value = [value.view(-1, self._num_actions, 1) for value in new_q_value]
        else:
            new_q_value = new_q_value.view(-1, self._num_actions, 1)
        if self._twin_critic and not keep:
            new_q_value = torch.min(new_q_value[0], new_q_value[1])
        return new_q_value
@POLICY_REGISTRY.register('discrete_cql')
class DiscreteCQLPolicy(QRDQNPolicy):
    """
    Overview:
        Policy class of discrete CQL algorithm in discrete action space environments. It trains a QR-DQN \
        style quantile Q network and adds the CQL conservative term (logsumexp over actions minus the Q value \
        of the dataset action) to the n-step quantile TD loss.
        Paper link: https://arxiv.org/abs/2006.04779.
    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='discrete_cql',
        # (bool) Whether to use cuda for policy.
        cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether use priority(priority sample, IS weight, update priority)
        priority=False,
        # (float) Reward's future discount factor, aka. gamma.
        discount_factor=0.97,
        # (int) N-step reward for target q_value estimation
        nstep=1,
        # learn_mode config
        learn=dict(
            # (int) How many updates (iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            update_per_collect=1,
            # (int) Minibatch size for one gradient descent.
            batch_size=64,
            # (float) Learning rate of the Adam optimizer over the whole (quantile) Q network.
            learning_rate=0.001,
            # (int) Frequency (in training iterations) of the hard target network update.
            target_update_freq=100,
            # (bool) Whether ignore done(usually for max step termination env).
            ignore_done=False,
            # (float) Loss weight for the conservative term.
            min_q_weight=1.0,
        ),
        eval=dict(),  # for compatibility
    )

    def _init_learn(self) -> None:
        """
        Overview:
            Set up learn mode: the conservative loss weight, the priority flag, one Adam optimizer over all \
            model parameters, gamma / nstep, and the wrapped main (``argmax_sample``) and target (hard \
            ``assign`` update every ``target_update_freq`` iterations) models. Called from ``__init__`` when \
            ``learn`` is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
        """
        learn_cfg = self._cfg.learn
        self._min_q_weight = learn_cfg.min_q_weight
        self._priority = self._cfg.priority

        # A single optimizer covers every parameter of the Q network.
        self._optimizer = Adam(self._model.parameters(), lr=learn_cfg.learning_rate)

        self._gamma = self._cfg.discount_factor
        self._nstep = self._cfg.nstep

        # Model wrappers are used here instead of plugins.
        target_net = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            target_net,
            wrapper_name='target',
            update_type='assign',
            update_kwargs={'freq': learn_cfg.target_update_freq},
        )
        self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._learn_model.reset()
        self._target_model.reset()

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Run one training step on a batch of offline transitions: quantile n-step TD loss plus \
            ``min_q_weight`` times the CQL term, followed by an optimizer step and a target-model update.
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): Transitions with at least ``obs``, ``action``, ``reward``, \
                ``next_obs`` and ``done``; optionally ``weight`` and ``value_gamma`` (n-step return).
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): ``cur_lr``, ``total_loss``, per-sample ``priority``, \
                ``q_target`` and ``q_value`` (the latter two are means of the target / current quantile outputs).
        """
        batch = default_preprocess_learn(
            data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True
        )
        if self._cuda:
            batch = to_device(batch, self._device)
        actions = batch['action']
        # Flatten (B, 1) action indices to (B,).
        if actions.dim() == 2 and actions.shape[-1] == 1:
            actions = actions.squeeze(-1)
            batch['action'] = actions

        # --------------------
        # Q-learning forward
        # --------------------
        self._learn_model.train()
        self._target_model.train()
        online_out = self._learn_model.forward(batch['obs'])
        q_value, tau = online_out['q'], online_out['tau']
        with torch.no_grad():
            target_q_value = self._target_model.forward(batch['next_obs'])['q']
            # Greedy next action chosen by the online model (double-Q style).
            target_q_action = self._learn_model.forward(batch['next_obs'])['action']

        # CQL term: logsumexp over actions of the quantile-averaged Q, minus the Q value of the dataset action.
        # presumably q_value is (batch, action_dim, num_quantiles) -- TODO confirm against the QR-DQN model
        mean_q = q_value.mean(-1)
        action_mask = F.one_hot(actions, self._cfg.model.action_shape)
        dataset_expec = (mean_q * action_mask).sum(dim=1).mean()
        negative_sampling = torch.logsumexp(mean_q, dim=1).mean()
        min_q_loss = negative_sampling - dataset_expec

        td_data = qrdqn_nstep_td_data(
            q_value, target_q_value, actions, target_q_action, batch['reward'], batch['done'], tau, batch['weight']
        )
        loss, td_error_per_sample = qrdqn_nstep_td_error(
            td_data, self._gamma, nstep=self._nstep, value_gamma=batch.get('value_gamma')
        )
        loss += self._min_q_weight * min_q_loss

        # --------------------
        # Q-learning update
        # --------------------
        self._optimizer.zero_grad()
        loss.backward()
        if self._cfg.multi_gpu:
            self.sync_gradients(self._learn_model)
        self._optimizer.step()

        # after update: the target wrapper decides (by its freq) whether to copy the weights
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            'priority': td_error_per_sample.abs().tolist(),
            'q_target': target_q_value.mean().item(),
            'q_value': q_value.mean().item(),
            # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
            # '[histogram]action_distribution': data['action'],
        }

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Keys of the ``_forward_learn`` return dict that the text / tensorboard loggers should record.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        keys = ['cur_lr', 'total_loss', 'q_target', 'q_value']
        return keys