# Source code for ding.policy.ppg

from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
import copy
import torch
from torch.utils.data import Dataset, DataLoader

from ding.utils import POLICY_REGISTRY, split_data_generator, RunningMeanStd
from ding.utils.data import default_collate, default_decollate
from ding.torch_utils import Adam, to_device
from ding.rl_utils import get_gae_with_default_last_value, get_train_sample, gae, gae_data, get_gae, \
    ppo_policy_data, ppo_policy_error, ppo_value_data, ppo_value_error, ppg_data, ppg_joint_error
from ding.model import model_wrap
from .base_policy import Policy


class ExperienceDataset(Dataset):
    """
    Overview:
        A map-style dataset over a dict of array-like experience fields. Indexing the dataset returns a dict with \
        the same keys, where every field has been indexed with the same index.

    Interface:
        ``__init__``, ``__len__``, ``__getitem__``.
    """

    def __init__(self, data):
        """
        Arguments:
            - data (:obj:`dict`): Mapping from field name (e.g. ``obs``, ``action``) to an indexable array/tensor. \
                All fields are presumably expected to share the same first dimension (only the first one is used \
                to compute the length).
        """
        super().__init__()
        self.data = data

    def __len__(self):
        # The length is taken from whichever field comes first in the dict's iteration order.
        first_field = list(self.data.keys())[0]
        return self.data[first_field].shape[0]

    def __getitem__(self, ind):
        # Apply the same index to every field, preserving the key order of ``self.data``.
        return {name: field[ind] for name, field in self.data.items()}


def create_shuffled_dataloader(data, batch_size):
    """
    Overview:
        Wrap a dict of experience fields into a ``DataLoader`` that yields shuffled mini-batches.
    Arguments:
        - data (:obj:`dict`): Experience data, forwarded unchanged to ``ExperienceDataset``.
        - batch_size (:obj:`int`): Number of samples per mini-batch.
    Returns:
        - dataloader (:obj:`DataLoader`): A loader over ``ExperienceDataset(data)`` with ``shuffle=True``.
    """
    return DataLoader(
        dataset=ExperienceDataset(data),
        batch_size=batch_size,
        shuffle=True,
    )


@POLICY_REGISTRY.register('ppg')
class PPGPolicy(Policy):
    """
    Overview:
        Policy class of PPG algorithm. PPG is a policy gradient algorithm with auxiliary phase training. \
        The auxiliary phase training is proposed to distill the value into the policy network, \
        while making sure the policy network does not change the action predictions (kl div loss). \
        Paper link: https://arxiv.org/abs/2009.04416.

    Interface:
        ``_init_learn``, ``_data_preprocess_learn``, ``_forward_learn``, ``_state_dict_learn``, \
        ``_load_state_dict_learn``, ``_init_collect``, ``_forward_collect``, ``_process_transition``, \
        ``_get_train_sample``, ``_get_batch_size``, ``_init_eval``, ``_forward_eval``, ``default_model``, \
        ``_monitor_vars_learn``, ``learn_aux``.

    Config:
        == ==================== ======== ============== ======================================== =======================
        ID Symbol               Type     Default Value  Description                              Other(Shape)
        == ==================== ======== ============== ======================================== =======================
        1  ``type``             str      ppg            | RL policy register name, refer to      | this arg is optional,
                                                        | registry ``POLICY_REGISTRY``           | a placeholder
        2  ``cuda``             bool     False          | Whether to use cuda for network        | this arg can be diff-
                                                                                                 | erent from modes
        3  ``on_policy``        bool     True           | Whether the RL algorithm is on-policy
                                                        | or off-policy
        4  ``priority``         bool     False          | Whether use priority(PER)              | not implemented in PPG,
                                                                                                 | must be False
        5  | ``priority_``      bool     False          | Whether use Importance Sampling        | not implemented in PPG,
           | ``IS_weight``                              | Weight to correct biased update.       | must be False
        6  | ``learn.actor_``   int      1              | How many passes the policy phase makes
           | ``epoch_per_``                             | over one collected batch
           | ``collect``
        7  | ``learn.critic_``  int      1              | How many passes the value phase makes
           | ``epoch_per_``                             | over one collected batch
           | ``collect``
        8  | ``learn.value_``   float    0.5            | The loss weight of value network       | policy network weight
           | ``weight``                                                                          | is set to 1
        9  | ``learn.entropy_`` float    0.01           | The loss weight of entropy             | policy network weight
           | ``weight``                                 | regularization                         | is set to 1
        10 | ``learn.clip_``    float    0.2            | PPO clip ratio
           | ``ratio``
        11 | ``learn.adv_``     bool     False          | Whether to normalize the advantage     | applied per
           | ``norm``                                                                            | mini-batch
        12 | ``learn.aux_``     int      8              | The frequency(normal update times)
           | ``freq``                                   | of auxiliary phase training
        13 | ``learn.aux_``     int      6              | The training epochs of auxiliary
           | ``train_epoch``                            | phase
        14 | ``learn.aux_``     int      1              | The loss weight of behavioral_cloning
           | ``bc_weight``                              | in auxiliary phase
        15 | ``collect.dis``    float    0.99           | Reward's future discount factor, aka.  | may be 1 when sparse
           | ``count_factor``                           | gamma                                  | reward env
        16 | ``collect.gae_``   float    0.95           | GAE lambda factor for the balance
           | ``lambda``                                 | of bias and variance(1-step td and mc)
        == ==================== ======== ============== ======================================== =======================
    """
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='ppg',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy. (Note: in practice PPO can be off-policy used)
        on_policy=True,
        # (bool) Whether to use priority (PER). Not implemented in PPG: ``_init_learn`` asserts it is False.
        priority=False,
        # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
        # Not implemented in PPG: ``_init_learn`` asserts it is False.
        priority_IS_weight=False,
        learn=dict(
            # (int) Number of passes over the collected data in the policy (actor) phase.
            actor_epoch_per_collect=1,
            # (int) Number of passes over the collected data in the value (critic) phase.
            critic_epoch_per_collect=1,
            # (int) Mini-batch size used by the policy, value and auxiliary phases.
            batch_size=64,
            # (float) Learning rate of both Adam optimizers.
            learning_rate=0.001,
            # ==============================================================
            # The following configs is algorithm-specific
            # ==============================================================
            # (float) The loss weight of value network, policy network weight is set to 1
            value_weight=0.5,
            # (float) The loss weight of entropy regularization, policy network weight is set to 1
            entropy_weight=0.01,
            # (float) PPO clip ratio, defaults to 0.2
            clip_ratio=0.2,
            # (bool) Whether to normalize value targets with a running mean/std.
            value_norm=False,
            # (bool) Whether to normalize the advantage (applied per mini-batch in ``_forward_learn``).
            adv_norm=False,
            # (int) The frequency(normal update times) of auxiliary phase training
            aux_freq=8,
            # (int) The training epochs of auxiliary phase
            aux_train_epoch=6,
            # (int) The loss weight of behavioral_cloning in auxiliary phase
            aux_bc_weight=1,
            # NOTE(review): grad_clip_* are not read anywhere in this class; confirm whether they are used elsewhere.
            grad_clip_type='clip_norm',
            grad_clip_value=10,
            # (bool) Whether to ignore the done flag (e.g. in envs with fake termination).
            ignore_done=False,
        ),
        collect=dict(
            # n_sample=64,
            # (int) Unroll length of a train sample.
            unroll_len=1,
            # ==============================================================
            # The following configs is algorithm-specific
            # ==============================================================
            # (float) Reward's future discount factor, aka. gamma.
            discount_factor=0.99,
            # (float) GAE lambda factor for the balance of bias and variance(1-step td and mc)
            gae_lambda=0.95,
        ),
        eval=dict(),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \
            automatically call this method to get the default model setting and create model.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
        """
        return 'ppg', ['ding.model.template.ppg']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For PPG, it creates two \
            Adam optimizers (one over ``actor_critic`` parameters, one over ``aux_critic`` parameters), reads the \
            algorithm-specific arguments such as aux_bc_weight and aux_train_epoch, and, when ``value_norm`` is \
            enabled, prepares a running mean/std monitor for value. \
            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        # Optimizer: ``_optimizer_ac`` updates ``actor_critic`` (policy phase and auxiliary distillation),
        # ``_optimizer_aux_critic`` updates ``aux_critic`` (value phase and extra value training in the aux phase).
        self._optimizer_ac = Adam(self._model.actor_critic.parameters(), lr=self._cfg.learn.learning_rate)
        self._optimizer_aux_critic = Adam(self._model.aux_critic.parameters(), lr=self._cfg.learn.learning_rate)
        self._learn_model = model_wrap(self._model, wrapper_name='base')

        # Algorithm config
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in PPG"
        self._value_weight = self._cfg.learn.value_weight
        self._entropy_weight = self._cfg.learn.entropy_weight
        self._value_norm = self._cfg.learn.value_norm
        if self._value_norm:
            self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device)
        self._clip_ratio = self._cfg.learn.clip_ratio
        self._adv_norm = self._cfg.learn.adv_norm

        # Main model
        self._learn_model.reset()

        # Auxiliary memories: one preprocessed batch is appended per ``_forward_learn`` call and consumed
        # (then cleared) by ``learn_aux`` every ``aux_freq`` iterations.
        self._aux_train_epoch = self._cfg.learn.aux_train_epoch
        self._train_iteration = 0
        self._aux_memories = []
        self._aux_bc_weight = self._cfg.learn.aux_bc_weight

    def _data_preprocess_learn(self, data: List[Any]) -> dict:
        """
        Overview:
            Preprocess the data to fit the required data format for learning: collate (stack data into batch), \
            handle ``done`` (set to ``None`` when ``ignore_done`` is enabled, otherwise cast to float), set \
            ``weight`` to ``None`` (priority is not supported in PPG), and move tensors to the cuda device if needed.
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): The data collected from collect function.
        Returns:
            - data (:obj:`Dict[str, Any]`): The processed data, including at least ['done', 'weight'].
        """
        # data preprocess
        data = default_collate(data)
        ignore_done = self._cfg.learn.ignore_done
        if ignore_done:
            data['done'] = None
        else:
            data['done'] = data['done'].float()
        data['weight'] = None
        if self._cuda:
            data = to_device(data, self._device)
        return data

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Forward and backward function of learn mode. It runs the policy phase (PPO clipped policy loss with \
            entropy bonus on ``actor_critic``), then the value phase (clipped value loss on ``aux_critic``), stores \
            the batch into the auxiliary memory and, every ``aux_freq`` calls, runs the auxiliary phase \
            (``learn_aux``).
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): Input data used for policy forward, including the \
                collected training samples from replay buffer. Each element is a dict whose keys are the names of \
                data items; values are usually torch.Tensor or np.ndarray or their dict/list combinations. \
                The samples are stacked in the batch dimension by ``_data_preprocess_learn``. \
                For PPG, each element contains at least the following keys: ``obs``, ``action``, ``reward``, \
                ``logit``, ``value``, ``done`` and ``adv``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
                recorded in text log and tensorboard, values are python scalar or a list of scalars. \
                For the detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For the data type that not supported, the main reason is that the corresponding model does not support it. \
            You can implement your own model rather than use the default model. For more information, please raise an \
            issue in GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for PPGPolicy: ``ding.policy.tests.test_ppg``.
        """
        data = self._data_preprocess_learn(data)
        # ====================
        # PPG forward
        # ====================
        self._learn_model.train()
        if self._value_norm:
            # ``data['value']`` is in normalized units; map it back with the current std, then normalize the
            # return target with the same (pre-update) std before updating the running statistics.
            unnormalized_return = data['adv'] + data['value'] * self._running_mean_std.std
            data['return'] = unnormalized_return / self._running_mean_std.std
            self._running_mean_std.update(unnormalized_return.cpu().numpy())
        else:
            data['return'] = data['adv'] + data['value']

        # Policy Phase (Policy): PPO clipped surrogate plus entropy bonus, optimizing ``actor_critic``.
        for _ in range(self._cfg.learn.actor_epoch_per_collect):
            for policy_data in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
                policy_adv = policy_data['adv']
                if self._adv_norm:
                    # Normalize advantage within the current mini-batch
                    policy_adv = (policy_adv - policy_adv.mean()) / (policy_adv.std() + 1e-8)
                policy_output = self._learn_model.forward(policy_data['obs'], mode='compute_actor')
                policy_error_data = ppo_policy_data(
                    policy_output['logit'], policy_data['logit'], policy_data['action'], policy_adv,
                    policy_data['weight']
                )
                ppo_policy_loss, ppo_info = ppo_policy_error(policy_error_data, self._clip_ratio)
                policy_loss = ppo_policy_loss.policy_loss - self._entropy_weight * ppo_policy_loss.entropy_loss
                self._optimizer_ac.zero_grad()
                policy_loss.backward()
                self._optimizer_ac.step()

        # Policy Phase (Value): clipped value loss, optimizing ``aux_critic``. The advantage is not used here,
        # so it is not (re)normalized.
        for _ in range(self._cfg.learn.critic_epoch_per_collect):
            for value_data in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
                value_output = self._learn_model.forward(value_data['obs'], mode='compute_critic')
                value_error_data = ppo_value_data(
                    value_output['value'], value_data['value'], value_data['return'], value_data['weight']
                )
                value_loss = self._value_weight * ppo_value_error(value_error_data, self._clip_ratio)
                self._optimizer_aux_critic.zero_grad()
                value_loss.backward()
                self._optimizer_aux_critic.step()

        # Record data for the auxiliary phase (``learn_aux`` reads the ``return_`` key).
        data['return_'] = data['return']
        self._aux_memories.append(copy.deepcopy(data))

        self._train_iteration += 1

        # Statistics of the last policy/value mini-batch.
        info = {
            'policy_cur_lr': self._optimizer_ac.defaults['lr'],
            'value_cur_lr': self._optimizer_aux_critic.defaults['lr'],
            'policy_loss': ppo_policy_loss.policy_loss.item(),
            'value_loss': value_loss.item(),
            'entropy_loss': ppo_policy_loss.entropy_loss.item(),
            'policy_adv_abs_max': policy_adv.abs().max().item(),
            'approx_kl': ppo_info.approx_kl,
            'clipfrac': ppo_info.clipfrac,
        }

        # ====================
        # PPG update
        # Auxiliary phase every ``aux_freq`` iterations; ``learn_aux`` also resets the aux memories.
        # ====================
        if self._train_iteration % self._cfg.learn.aux_freq == 0:
            aux_loss, bc_loss, aux_value_loss = self.learn_aux()
            info.update(
                {
                    'aux_value_loss': aux_value_loss,
                    'auxiliary_loss': aux_loss,
                    'behavioral_cloning_loss': bc_loss,
                }
            )
        return info

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state_dict of learn mode: the model and both optimizers.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
        """
        return {
            'model': self._learn_model.state_dict(),
            'optimizer_ac': self._optimizer_ac.state_dict(),
            'optimizer_aux_critic': self._optimizer_aux_critic.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state_dict variable into policy learn mode.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before. It must contain the \
                model and both optimizers: ``_optimizer_ac`` (used for the policy network) and \
                ``_optimizer_aux_critic`` (used for the value network). Two optimizers are needed because, when \
                the value is distilled into the policy network, the policy network must not change its action \
                predictions.

        .. tip::
            If you want to only load some parts of model, you can simply set the ``strict`` argument in \
            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
            complicated operation.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._optimizer_ac.load_state_dict(state_dict['optimizer_ac'])
        self._optimizer_aux_critic.load_state_dict(state_dict['optimizer_aux_critic'])

    def _init_collect(self) -> None:
        """
        Overview:
            Initialize the collect mode of policy, including related attributes and modules. For PPG, it contains the \
            collect_model to balance the exploration and exploitation (the multinomial sample mechanism in \
            discrete action space), and other algorithm-specific arguments such as unroll_len and gae_lambda. \
            This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
            with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
        # TODO continuous action space exploration
        self._collect_model.reset()
        self._gamma = self._cfg.collect.discount_factor
        self._gae_lambda = self._cfg.collect.gae_lambda

    def _forward_collect(self, data: dict) -> dict:
        """
        Overview:
            Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
            that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
            data, such as the action to interact with the envs.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward (``compute_actor_critic`` mode), \
                including at least the action and other necessary data (action logit and value) for learn mode \
                defined in ``self._process_transition`` method. The key of the dict is the same as the input data, \
                i.e. environment id.

        .. tip::
            If you want to add more tricks on this policy, like temperature factor in multinomial sample, you can pass \
            related data as extra keyword arguments of this method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For the data type that not supported, the main reason is that the corresponding model does not support it. \
            You can implement your own model rather than use the default model. For more information, please raise an \
            issue in GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for PPGPolicy: ``ding.policy.tests.test_ppg``.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            output = self._collect_model.forward(data, mode='compute_actor_critic')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        """
        Overview:
            Process and pack one timestep transition data into a dict, which can be directly used for training and \
            saved in replay buffer. For PPG, it contains obs, next_obs, action, reward, done, logit, value.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): The output of the policy network with the observation \
                as input. For PPG, it contains the state value, action and the logit of the action.
            - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step \
                method, except all the elements have been transformed into tensor data. Usually, it contains the next \
                obs, reward, done, info, etc.
        Returns:
            - transition (:obj:`dict`): The processed transition data of the current timestep.

        .. note::
            ``next_obs`` is used by ``_get_train_sample`` to bootstrap the last value of a non-terminated trajectory, \
            so it is placed into the transition by default.
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'logit': model_output['logit'],
            'action': model_output['action'],
            'value': model_output['value'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _get_train_sample(self, data: List[Dict[str, Any]]) -> Union[None, List[Any]]:
        """
        Overview:
            For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \
            can be used for training directly. In PPG, a train sample is a processed transition with new computed \
            ``adv`` field. This method is usually used in collectors to execute necessary \
            RL data preprocessing before training, which can help learner amortize relevant time consumption. \
            In addition, you can also implement this method as an identity function and do the data processing \
            in ``self._forward_learn`` method.
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transition), each element is \
                the same format as the return value of ``self._process_transition`` method.
        Returns:
            - samples (:obj:`List[Dict[str, Any]]`): The processed train samples (the result of \
                ``get_train_sample``), each element is the similar format as input transitions, but may contain \
                more data for training, such as GAE advantage.

        .. note::
            NOTE(review): ``self._value_norm`` and ``self._running_mean_std`` are created in ``_init_learn``, so this \
            method assumes learn mode has also been initialized.
        """
        data = to_device(data, self._device)
        if self._cfg.learn.ignore_done:
            data[-1]['done'] = False

        if data[-1]['done']:
            last_value = torch.zeros_like(data[-1]['value'])
        else:
            # Bootstrap the value of the state after the last transition.
            with torch.no_grad():
                last_value = self._collect_model.forward(
                    data[-1]['next_obs'].unsqueeze(0), mode='compute_actor_critic'
                )['value']
        if self._value_norm:
            # With value_norm, the critic outputs normalized values; rescale them to return units for GAE.
            last_value *= self._running_mean_std.std
            for transition in data:
                transition['value'] *= self._running_mean_std.std
        data = get_gae(
            data,
            to_device(last_value, self._device),
            gamma=self._gamma,
            gae_lambda=self._gae_lambda,
            cuda=False,
        )
        if self._value_norm:
            # Store values back in normalized units, as expected by ``_forward_learn``.
            for transition in data:
                transition['value'] /= self._running_mean_std.std

        return get_train_sample(data, self._unroll_len)

    def _get_batch_size(self) -> Dict[str, int]:
        """
        Overview:
            Get learn batch size. In the PPG algorithm, different networks require different data: \
            data['policy'] trains the policy net and data['value'] trains the value net. This function returns \
            the batch size of both, which are both ``cfg.learn.batch_size``.
        Returns:
            - output (:obj:`Dict[str, int]`): Mapping from data name (``policy``/``value``) to its batch size.
        """
        bs = self._cfg.learn.batch_size
        return {'policy': bs, 'value': bs}

    def _init_eval(self) -> None:
        """
        Overview:
            Initialize the eval mode of policy, including related attributes and modules. For PPG, it contains the \
            eval model to select optimal action (greedily select action with argmax mechanism in discrete \
            action). This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
            with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> dict:
        """
        Overview:
            Policy forward function of eval mode (evaluating policy performance by interacting with envs). Forward \
            means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
            action to interact with the envs. ``_forward_eval`` in PPG uses the argmax (deterministic) sample method \
            to get actions while ``_forward_collect`` uses the multinomial (stochastic) sample method to balance \
            exploration and exploitation.
        Arguments:
            - data (:obj:`Dict[str, Any]`): The input data used for policy forward, including at least the obs. The \
                key of the dict is environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
                key of the dict is the same as the input data, i.e. environment id.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For the data type that not supported, the main reason is that the corresponding model does not support it. \
            You can implement your own model rather than use the default model. For more information, please raise an \
            issue in GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for PPGPolicy: ``ding.policy.tests.test_ppg``.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, mode='compute_actor')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
            as text logger, tensorboard logger, will use these keys to save the corresponding data. The last three \
            keys are only present in iterations where the auxiliary phase runs.
        Returns:
            - vars (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        return [
            'policy_cur_lr',
            'value_cur_lr',
            'policy_loss',
            'value_loss',
            'entropy_loss',
            'policy_adv_abs_max',
            'approx_kl',
            'clipfrac',
            'aux_value_loss',
            'auxiliary_loss',
            'behavioral_cloning_loss',
        ]

    def learn_aux(self) -> Tuple[float, float, float]:
        """
        Overview:
            The auxiliary phase training, where the value is distilled into the policy network. In PPG algorithm, \
            we use the value function loss as the auxiliary objective, thereby sharing features between the policy \
            and value function while minimizing distortions to the policy. We also use behavioral cloning loss to \
            optimize the auxiliary objective while otherwise preserving the original policy. \
            All stored auxiliary memories are concatenated, trained on for ``aux_train_epoch`` epochs in shuffled \
            mini-batches, and then cleared.
        Returns:
            - aux_loss (:obj:`Tuple[float, float, float]`): Including average auxiliary loss, \
                average behavioral cloning loss, and average auxiliary value loss (averaged over all mini-batches; \
                all zeros if no mini-batch was produced).
        """
        # Gather states and target values of every stored batch into one tensor per field.
        states = []
        actions = []
        returns = []
        old_values = []
        weights = []
        for memory in self._aux_memories:
            states.append(memory['obs'])
            actions.append(memory['action'])
            returns.append(memory['return_'])
            old_values.append(memory['value'])
            if memory['weight'] is None:
                weight = torch.ones_like(memory['action'])
            else:
                weight = torch.as_tensor(memory['weight'])
            weights.append(weight)

        aux_data = {
            'obs': torch.cat(states),
            'action': torch.cat(actions),
            'return_': torch.cat(returns),
            'value': torch.cat(old_values),
            'weight': torch.cat(weights).float(),
        }
        # Logits of the current (pre-aux-phase) policy, used as the behavioral cloning target.
        with torch.no_grad():
            aux_data['logit_old'] = self._model.forward(aux_data['obs'], mode='compute_actor')['logit']

        # Prepared dataloader for auxiliary phase training.
        dl = create_shuffled_dataloader(aux_data, self._cfg.learn.batch_size)

        # The proposed auxiliary phase training, where the value is distilled into the policy network,
        # while making sure the policy network does not change the action predictions (kl div loss).
        num_batches = 0
        auxiliary_loss_sum = 0
        behavioral_cloning_loss_sum = 0
        value_loss_sum = 0

        for _ in range(self._aux_train_epoch):
            for batch in dl:
                policy_output = self._model.forward(batch['obs'], mode='compute_actor_critic')

                # ppg_data fields: 'logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'return_', 'weight'
                data_ppg = ppg_data(
                    policy_output['logit'], batch['logit_old'], batch['action'], policy_output['value'],
                    batch['value'], batch['return_'], batch['weight']
                )
                ppg_joint_loss = ppg_joint_error(data_ppg, self._clip_ratio)
                # Policy network loss: auxiliary (value distillation) loss plus weighted behavioral cloning loss.
                total_loss = (
                    ppg_joint_loss.auxiliary_loss + self._aux_bc_weight * ppg_joint_loss.behavioral_cloning_loss
                )

                self._optimizer_ac.zero_grad()
                total_loss.backward()
                self._optimizer_ac.step()

                # The paper says it is important to train the value network extra during the auxiliary phase.
                # ppo_value_data fields: 'value_new', 'value_old', 'return_', 'weight'
                values = self._model.forward(batch['obs'], mode='compute_critic')['value']
                data_aux = ppo_value_data(values, batch['value'], batch['return_'], batch['weight'])

                value_loss = ppo_value_error(data_aux, self._clip_ratio)

                self._optimizer_aux_critic.zero_grad()
                value_loss.backward()
                self._optimizer_aux_critic.step()

                auxiliary_loss_sum += ppg_joint_loss.auxiliary_loss.item()
                behavioral_cloning_loss_sum += ppg_joint_loss.behavioral_cloning_loss.item()
                value_loss_sum += value_loss.item()
                num_batches += 1

        self._aux_memories = []

        # Guard against an empty dataloader (e.g. aux_train_epoch == 0) to avoid ZeroDivisionError.
        denom = max(num_batches, 1)
        return auxiliary_loss_sum / denom, behavioral_cloning_loss_sum / denom, value_loss_sum / denom
@POLICY_REGISTRY.register('ppg_offpolicy')
class PPGOffPolicy(Policy):
    """
    Overview:
        Policy class of PPG algorithm with off-policy training mode. Off-policy PPG uses two different data \
        buffers that can be reused differently. The policy buffer offers data for the policy update, while the \
        value buffer offers data for the value update and is also recorded for the auxiliary phase. The whole \
        training procedure is similar to off-policy PPO, but additionally executes an auxiliary phase with a \
        fixed frequency.
    Interface:
        ``_init_learn``, ``_data_preprocess_learn``, ``_forward_learn``, ``_state_dict_learn``, \
        ``_load_state_dict_learn``, ``_init_collect``, ``_forward_collect``, ``_process_transition``, \
        ``_get_train_sample``, ``_get_batch_size``, ``_init_eval``, ``_forward_eval``, ``default_model``, \
        ``_monitor_vars_learn``, ``learn_aux``.
    Config:
        == ==================== ======== ============== ======================================== =======================
        ID Symbol               Type     Default Value  Description                              Other(Shape)
        == ==================== ======== ============== ======================================== =======================
        1  ``type``             str      ppg_offpolicy  | RL policy register name, refer to      | this arg is optional,
                                                        | registry ``POLICY_REGISTRY``           | a placeholder
        2  ``cuda``             bool     False          | Whether to use cuda for network        | this arg can be diff-
                                                                                                 | erent from modes
        3  ``on_policy``        bool     False          | Whether the RL algorithm is on-policy
                                                        | or off-policy
        4  ``priority``         bool     False          | Whether use priority(PER)              | priority sample,
                                                                                                 | update priority
        5  | ``priority_``      bool     False          | Whether use Importance Sampling        | IS weight
           | ``IS_weight``                              | Weight to correct biased update.
        6  | ``learn.update``   int      5              | How many updates(iterations) to train  | this args can be vary
           | ``_per_collect``                           | after collector's one collection. Only | from envs. Bigger val
                                                        | valid in serial training               | means more off-policy
        7  | ``learn.value_``   float    0.5            | The loss weight of value network       | policy network weight
           | ``weight``                                                                          | is set to 1
        8  | ``learn.entropy_`` float    0.01           | The loss weight of entropy             | policy network weight
           | ``weight``                                 | regularization                         | is set to 1
        9  | ``learn.clip_``    float    0.2            | PPO clip ratio
           | ``ratio``
        10 | ``learn.adv_``     bool     False          | Whether to use advantage norm in
           | ``norm``                                   | a whole training batch
        11 | ``learn.aux_``     int      5              | The frequency(normal update times)
           | ``freq``                                   | of auxiliary phase training
        12 | ``learn.aux_``     int      6              | The training epochs of auxiliary
           | ``train_epoch``                            | phase
        13 | ``learn.aux_``     int      1              | The loss weight of behavioral_cloning
           | ``bc_weight``                              | in auxiliary phase
        14 | ``collect.dis``    float    0.99           | Reward's future discount factor, aka.  | may be 1 when sparse
           | ``count_factor``                           | gamma                                  | reward env
        15 | ``collect.gae_``   float    0.95           | GAE lambda factor for the balance
           | ``lambda``                                 | of bias and variance(1-step td and mc)
        == ==================== ======== ============== ======================================== =======================
    """
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='ppg_offpolicy',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy. (Note: in practice PPO can be off-policy used)
        on_policy=False,
        # (bool) Whether to use priority (PER) sampling. Priority is not implemented in PPG, so keep it False.
        priority=False,
        # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (bool) Whether to need policy data in process transition
        transition_with_policy_data=True,
        learn=dict(
            update_per_collect=5,
            batch_size=64,
            learning_rate=0.001,
            # ==============================================================
            # The following configs is algorithm-specific
            # ==============================================================
            # (float) The loss weight of value network, policy network weight is set to 1
            value_weight=0.5,
            # (float) The loss weight of entropy regularization, policy network weight is set to 1
            entropy_weight=0.01,
            # (float) PPO clip ratio, defaults to 0.2
            clip_ratio=0.2,
            # (bool) Whether to use advantage norm in a whole training batch
            adv_norm=False,
            # (int) The frequency(normal update times) of auxiliary phase training
            aux_freq=5,
            # (int) The training epochs of auxiliary phase
            aux_train_epoch=6,
            # (int) The loss weight of behavioral_cloning in auxiliary phase
            aux_bc_weight=1,
            # (bool) Whether to ignore the ``done`` flag (useful in envs with fake termination)
            ignore_done=False,
        ),
        collect=dict(
            # n_sample=64,
            unroll_len=1,
            # ==============================================================
            # The following configs is algorithm-specific
            # ==============================================================
            # (float) Reward's future discount factor, aka. gamma.
            discount_factor=0.99,
            # (float) GAE lambda factor for the balance of bias and variance(1-step td and mc)
            gae_lambda=0.95,
        ),
        eval=dict(),
        other=dict(
            replay_buffer=dict(
                # PPG use two separate buffer for different reuse
                multi_buffer=True,
                policy=dict(replay_buffer_size=1000, ),
                value=dict(replay_buffer_size=1000, ),
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \
            automatically call this method to get the default model setting and create model.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.

        .. note::
            The user can define and use customized network model but must obey the same interface definition \
            indicated by import_names path.
        """
        return 'ppg', ['ding.model.template.ppg']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For PPG, it mainly \
            contains two optimizers (one for the actor-critic network, one for the auxiliary critic) and \
            algorithm-specific arguments such as aux_bc_weight and aux_train_epoch. It also resets the auxiliary \
            memory and the training iteration counter used to schedule the auxiliary phase. \
            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        # Optimizer
        self._optimizer_ac = Adam(self._model.actor_critic.parameters(), lr=self._cfg.learn.learning_rate)
        self._optimizer_aux_critic = Adam(self._model.aux_critic.parameters(), lr=self._cfg.learn.learning_rate)
        self._learn_model = model_wrap(self._model, wrapper_name='base')

        # Algorithm config
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in PPG"
        self._value_weight = self._cfg.learn.value_weight
        self._entropy_weight = self._cfg.learn.entropy_weight
        self._clip_ratio = self._cfg.learn.clip_ratio
        self._adv_norm = self._cfg.learn.adv_norm

        # Main model
        self._learn_model.reset()

        # Auxiliary memories
        self._aux_train_epoch = self._cfg.learn.aux_train_epoch
        self._train_iteration = 0
        self._aux_memories = []
        self._aux_bc_weight = self._cfg.learn.aux_bc_weight

    def _data_preprocess_learn(self, data: Dict[str, List[Any]]) -> Dict[str, Any]:
        """
        Overview:
            Preprocess the data to fit the required data format for learning, including \
            collate(stack data into batch), ignore done(in some fake terminate env), \
            prepare loss weight per training sample, and cpu tensor to cuda.
        Arguments:
            - data (:obj:`Dict[str, List[Any]]`): The data sampled from the replay buffers, keyed by buffer name \
                (``policy`` and ``value``), each value being a list of training samples.
        Returns:
            - data (:obj:`Dict[str, Any]`): The processed data, one collated dict per buffer name, each including \
                at least ['done', 'weight']. The input dict is updated in place.
        """
        ignore_done = self._cfg.learn.ignore_done
        for buffer_name, buffer_data in data.items():
            buffer_data = default_collate(buffer_data)
            if ignore_done:
                buffer_data['done'] = None
            else:
                buffer_data['done'] = buffer_data['done'].float()
            # Priority / IS weight is not supported in PPG (see the assert in ``_init_learn``).
            buffer_data['weight'] = None
            # Only existing keys are reassigned, so updating the dict while iterating over it is safe.
            data[buffer_name] = buffer_data
        if self._cuda:
            data = to_device(data, self._device)
        return data

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        """
        Overview:
            Forward and backward function of learn mode. It runs one policy-phase update (actor update with the \
            ``policy`` batch, critic update with the ``value`` batch), records the ``value`` batch for the \
            auxiliary phase, and runs ``learn_aux`` every ``learn.aux_freq`` iterations.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Input data used for policy forward, a dict with the keys ``policy`` \
                and ``value`` (one per replay buffer, see ``_get_batch_size``). Each value is a list of training \
                samples, where each sample is a dict containing at least the following keys: ``obs``, \
                ``action``, ``reward``, ``logit``, ``value``, ``done`` and ``adv``. The samples are stacked in \
                the batch dimension by ``_data_preprocess_learn``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
                recorded in text log and tensorboard, values are python scalar or a list of scalars. \
                For the detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.

        ReturnsKeys:
            - necessary: "policy_cur_lr", "value_cur_lr", "policy_loss", "value_loss", "entropy_loss", \
                "policy_adv_abs_max", "approx_kl", "clipfrac", "total_loss".
            - optional (only in iterations that run the auxiliary phase): "aux_value_loss", "auxiliary_loss", \
                "behavioral_cloning_loss".
            - policy_cur_lr (:obj:`float`): Current learning rate of the actor-critic optimizer.
            - value_cur_lr (:obj:`float`): Current learning rate of the auxiliary critic optimizer.
            - total_loss (:obj:`float`): The sum of all the losses computed in this call.
            - policy_loss (:obj:`float`): The policy(actor) loss of ppg.
            - value_loss (:obj:`float`): The value(critic) loss of ppg, scaled by ``value_weight``.
            - entropy_loss (:obj:`float`): The entropy loss.
            - auxiliary_loss (:obj:`float`): The auxiliary loss, we use the value function loss \
                as the auxiliary objective, thereby sharing features between the policy and value function \
                while minimizing distortions to the policy.
            - aux_value_loss (:obj:`float`): The auxiliary value loss, we need to train the value network extra \
                during the auxiliary phase, it's the value loss we train the value network during auxiliary phase.
            - behavioral_cloning_loss (:obj:`float`): The behavioral cloning loss, used to optimize the auxiliary \
                objective while otherwise preserving the original policy.
        """
        data = self._data_preprocess_learn(data)
        # ====================
        # PPG forward
        # ====================
        self._learn_model.train()
        policy_data, value_data = data['policy'], data['value']
        policy_adv, value_adv = policy_data['adv'], value_data['adv']
        # The value target is built from the raw (un-normalized) advantage.
        return_ = value_data['value'] + value_adv
        if self._adv_norm:
            # Normalize advantage in a total train_batch. ``value_adv`` is only used to build ``return_``
            # above, so only the policy advantage needs to be normalized.
            policy_adv = (policy_adv - policy_adv.mean()) / (policy_adv.std() + 1e-8)

        # Policy Phase(Policy)
        policy_output = self._learn_model.forward(policy_data['obs'], mode='compute_actor')
        policy_error_data = ppo_policy_data(
            policy_output['logit'], policy_data['logit'], policy_data['action'], policy_adv, policy_data['weight']
        )
        ppo_policy_loss, ppo_info = ppo_policy_error(policy_error_data, self._clip_ratio)
        policy_loss = ppo_policy_loss.policy_loss - self._entropy_weight * ppo_policy_loss.entropy_loss
        self._optimizer_ac.zero_grad()
        policy_loss.backward()
        self._optimizer_ac.step()

        # Policy Phase(Value)
        value_output = self._learn_model.forward(value_data['obs'], mode='compute_critic')
        value_error_data = ppo_value_data(value_output['value'], value_data['value'], return_, value_data['weight'])
        value_loss = self._value_weight * ppo_value_error(value_error_data, self._clip_ratio)
        self._optimizer_aux_critic.zero_grad()
        value_loss.backward()
        self._optimizer_aux_critic.step()

        # ====================
        # PPG update
        # use aux loss after iterations and reset aux_memories
        # ====================
        # Auxiliary Phase: record the value batch (with its detached return target) for the auxiliary head.
        value_data['return_'] = return_.data
        self._aux_memories.append(copy.deepcopy(value_data))

        self._train_iteration += 1
        total_loss = (policy_loss + value_loss).item()
        info = {
            'policy_cur_lr': self._optimizer_ac.defaults['lr'],
            'value_cur_lr': self._optimizer_aux_critic.defaults['lr'],
            'policy_loss': ppo_policy_loss.policy_loss.item(),
            'value_loss': value_loss.item(),
            'entropy_loss': ppo_policy_loss.entropy_loss.item(),
            'policy_adv_abs_max': policy_adv.abs().max().item(),
            'approx_kl': ppo_info.approx_kl,
            'clipfrac': ppo_info.clipfrac,
        }
        if self._train_iteration % self._cfg.learn.aux_freq == 0:
            aux_loss, bc_loss, aux_value_loss = self.learn_aux()
            total_loss += aux_loss + bc_loss + aux_value_loss
            info['aux_value_loss'] = aux_value_loss
            info['auxiliary_loss'] = aux_loss
            info['behavioral_cloning_loss'] = bc_loss
        info['total_loss'] = total_loss
        return info

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state_dict of learn mode, usually including model and optimizer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
        """
        return {
            'model': self._learn_model.state_dict(),
            'optimizer_ac': self._optimizer_ac.state_dict(),
            'optimizer_aux_critic': self._optimizer_aux_critic.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state_dict variable into policy learn mode.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before. \
                When the value is distilled into the policy network, we need to make sure the policy \
                network does not change the action predictions, so we need two optimizers: \
                _optimizer_ac is used in policy net, and _optimizer_aux_critic is used in value net.

        .. tip::
            If you want to only load some parts of model, you can simply set the ``strict`` argument in \
            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
            complicated operation.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._optimizer_ac.load_state_dict(state_dict['optimizer_ac'])
        self._optimizer_aux_critic.load_state_dict(state_dict['optimizer_aux_critic'])

    def _init_collect(self) -> None:
        """
        Overview:
            Initialize the collect mode of policy, including related attributes and modules. For PPG, it contains \
            the collect_model to balance the exploration and exploitation (e.g. the multinomial sample mechanism \
            in discrete action space), and other algorithm-specific arguments such as unroll_len and gae_lambda.
            This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
            with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
        # TODO continuous action space exploration
        self._collect_model.reset()
        self._gamma = self._cfg.collect.discount_factor
        self._gae_lambda = self._cfg.collect.gae_lambda

    def _forward_collect(self, data: dict) -> dict:
        """
        Overview:
            Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
            that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
            data, such as the action to interact with the envs.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
                other necessary data (action logit and value) for learn mode defined in \
                ``self._process_transition`` method. The key of the dict is the same as the input data, \
                i.e. environment id.

        .. tip::
            If you want to add more tricks on this policy, like temperature factor in multinomial sample, you can pass \
            related data as extra keyword arguments of this method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For the data type that not supported, the main reason is that the corresponding model does not support it. \
            You can implement you own model rather than use the default model. For more information, please raise an \
            issue in GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for PPGOffPolicy: ``ding.policy.tests.test_ppg``.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            output = self._collect_model.forward(data, mode='compute_actor_critic')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        """
        Overview:
            Process and pack one timestep transition data into a dict, which can be directly used for training and \
            saved in replay buffer. For PPG, it contains obs, next_obs, action, reward, done, logit, value.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): The output of the policy network with the observation \
                as input. For PPG, it contains the state value, action and the logit of the action.
            - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step \
                method, except all the elements have been transformed into tensor data. Usually, it contains the next \
                obs, reward, done, info, etc.
        Returns:
            - transition (:obj:`dict`): The processed transition data of the current timestep.

        .. note::
            ``next_obs`` is used to calculate nstep return when necessary, so we place it into transition by default. \
            You can delete this field to save memory occupancy if you do not need nstep return.
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'logit': model_output['logit'],
            'action': model_output['action'],
            'value': model_output['value'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        """
        Overview:
            For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \
            can be used for training directly. In PPG, a train sample is a processed transition with new computed \
            ``adv`` field. This method is usually used in collectors to execute necessary \
            RL data preprocessing before training, which can help learner amortize relevant time consumption. \
            In addition, you can also implement this method as an identity function and do the data processing \
            in ``self._forward_learn`` method.
        Arguments:
            - data (:obj:`list`): The trajectory data (a list of transition), each element is \
                the same format as the return value of ``self._process_transition`` method.
        Returns:
            - samples (:obj:`list`): The processed train samples, each element is the similar format \
                as input transitions, but may contain more data for training, such as GAE advantage. Each sample \
                is tagged with both buffer names (``policy`` and ``value``) in its ``buffer_name`` field, \
                presumably so that it is pushed into both replay buffers.
        """
        data = get_gae_with_default_last_value(
            data,
            data[-1]['done'],
            gamma=self._gamma,
            gae_lambda=self._gae_lambda,
            cuda=False,
        )
        data = get_train_sample(data, self._unroll_len)
        for d in data:
            d['buffer_name'] = ["policy", "value"]
        return data

    def _get_batch_size(self) -> Dict[str, int]:
        """
        Overview:
            Get learn batch size. In the PPG algorithm, different networks require different data. \
            We need to get data['policy'] and data['value'] to train policy net and value net, \
            this function is used to get the batch size of data['policy'] and data['value'].
        Returns:
            - output (:obj:`dict[str, int]`): The batch size for each buffer, keyed by buffer name \
                (``policy`` and ``value``); both use ``learn.batch_size``.
        """
        bs = self._cfg.learn.batch_size
        return {'policy': bs, 'value': bs}

    def _init_eval(self) -> None:
        """
        Overview:
            Initialize the eval mode of policy, including related attributes and modules. For PPG, it contains the \
            eval model to select optimal action (e.g. greedily select action with argmax mechanism in discrete \
            action). This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
            with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
            means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
            action to interact with the envs. ``_forward_eval`` in PPG often uses deterministic sample method to get \
            actions while ``_forward_collect`` usually uses stochastic sample method for balance exploration and \
            exploitation.
        Arguments:
            - data (:obj:`Dict[str, Any]`): The input data used for policy forward, including at least the obs. The \
                key of the dict is environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
                key of the dict is the same as the input data, i.e. environment id.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For the data type that not supported, the main reason is that the corresponding model does not support it. \
            You can implement you own model rather than use the default model. For more information, please raise an \
            issue in GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for PPGOffPolicy: ``ding.policy.tests.test_ppg``.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, mode='compute_actor')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
            as text logger, tensorboard logger, will use these keys to save the corresponding data.
        Returns:
            - vars (:obj:`List[str]`): The list of the necessary keys to be logged. The three auxiliary keys are \
                only present in iterations that run the auxiliary phase.
        """
        return [
            'policy_cur_lr',
            'value_cur_lr',
            'policy_loss',
            'value_loss',
            'entropy_loss',
            'policy_adv_abs_max',
            'approx_kl',
            'clipfrac',
            'aux_value_loss',
            'auxiliary_loss',
            'behavioral_cloning_loss',
            'total_loss',
        ]

    def learn_aux(self) -> Tuple[float, float, float]:
        """
        Overview:
            The auxiliary phase training, where the value is distilled into the policy network. In PPG algorithm, \
            we use the value function loss as the auxiliary objective, thereby sharing features between the policy \
            and value function while minimizing distortions to the policy. We also use behavioral cloning loss to \
            optimize the auxiliary objective while otherwise preserving the original policy. The recorded \
            auxiliary memories are cleared afterwards.
        Returns:
            - aux_loss (:obj:`Tuple[float, float, float]`): Including average auxiliary loss, \
                average behavioral cloning loss, and average auxiliary value loss over all minibatch updates. \
                All three are ``0.0`` when no update is executed (no recorded memory or ``aux_train_epoch == 0``).
        """
        aux_memories = self._aux_memories
        if not aux_memories:
            # Nothing recorded yet: there is nothing to concatenate or train on.
            return 0.0, 0.0, 0.0

        # Gather states and target values of all recorded memories into single tensors.
        weights = []
        for memory in aux_memories:
            if memory['weight'] is None:
                weights.append(torch.ones_like(memory['action']))
            else:
                weights.append(torch.as_tensor(memory['weight']))
        data = {
            'obs': torch.cat([memory['obs'] for memory in aux_memories]),
            'action': torch.cat([memory['action'] for memory in aux_memories]),
            'return_': torch.cat([memory['return_'] for memory in aux_memories]),
            'value': torch.cat([memory['value'] for memory in aux_memories]),
            # ``ones_like(action)`` inherits the action dtype (which may be an integer dtype for discrete
            # actions), so cast the loss weight to float, as ``PPGPolicy.learn_aux`` does.
            'weight': torch.cat(weights).float(),
        }

        # Compute the logit of the current (pre-auxiliary-phase) policy as the behavioral cloning target.
        with torch.no_grad():
            data['logit_old'] = self._model.forward(data['obs'], mode='compute_actor')['logit']

        # Prepare a shuffled dataloader for auxiliary phase training.
        dataloader = create_shuffled_dataloader(data, self._cfg.learn.batch_size)

        # The proposed auxiliary phase training, where the value is distilled into the policy network,
        # while making sure the policy network does not change the action predictions (kl div loss).
        bc_weight = self._aux_bc_weight
        num_updates = 0
        auxiliary_loss_sum = 0.
        behavioral_cloning_loss_sum = 0.
        value_loss_sum = 0.
        for _ in range(self._aux_train_epoch):
            for batch in dataloader:
                policy_output = self._model.forward(batch['obs'], mode='compute_actor_critic')
                # ppg_data fields: 'logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'return_', 'weight'
                data_ppg = ppg_data(
                    policy_output['logit'], batch['logit_old'], batch['action'], policy_output['value'],
                    batch['value'], batch['return_'], batch['weight']
                )
                ppg_joint_loss = ppg_joint_error(data_ppg, self._clip_ratio)
                # The policy network loss is composed of the auxiliary (value distillation) loss and the
                # weighted behavioral cloning (kl div) loss.
                total_loss = ppg_joint_loss.auxiliary_loss + bc_weight * ppg_joint_loss.behavioral_cloning_loss
                self._optimizer_ac.zero_grad()
                total_loss.backward()
                self._optimizer_ac.step()

                # The paper says it is important to train the value network extra during the auxiliary phase.
                # ppo_value_data fields: 'value_new', 'value_old', 'return_', 'weight'
                values = self._model.forward(batch['obs'], mode='compute_critic')['value']
                data_aux = ppo_value_data(values, batch['value'], batch['return_'], batch['weight'])
                value_loss = ppo_value_error(data_aux, self._clip_ratio)
                self._optimizer_aux_critic.zero_grad()
                value_loss.backward()
                self._optimizer_aux_critic.step()

                auxiliary_loss_sum += ppg_joint_loss.auxiliary_loss.item()
                behavioral_cloning_loss_sum += ppg_joint_loss.behavioral_cloning_loss.item()
                value_loss_sum += value_loss.item()
                num_updates += 1

        self._aux_memories = []
        if num_updates == 0:
            # e.g. ``aux_train_epoch == 0``: avoid dividing by zero.
            return 0.0, 0.0, 0.0
        return (
            auxiliary_loss_sum / num_updates,
            behavioral_cloning_loss_sum / num_updates,
            value_loss_sum / num_updates,
        )