Shortcuts

Source code for ding.reward_model.pwil_irl_model

from typing import Dict, List
import math
import random
import pickle
import torch

from ding.utils import REWARD_MODEL_REGISTRY
from .base_reward_model import BaseRewardModel


def collect_state_action_pairs(iterator):
    """
    Overview:
        Gather the ``(obs, action)`` pair of every item yielded by the input iterator. \
        The state and action are kept as two separate tuple members; they are not concatenated.
    Arguments:
        - iterator (:obj:`Iterable`): Iterable of mappings with at least ``obs`` and ``action`` keys.
    Returns:
        - res (:obj:`list`): List of ``(obs, action)`` tuples, in iteration order.
    """
    return [(sample['obs'], sample['action']) for sample in iterator]


@REWARD_MODEL_REGISTRY.register('pwil')
class PwilRewardModel(BaseRewardModel):
    """
    Overview:
        The Pwil reward model class (https://arxiv.org/pdf/2006.04678.pdf)
    Interface:
        ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``,
        ``__init__``, ``_train``, ``_get_state_distance``, ``_get_action_distance``
    Config:
        == ================== ===== ============= ======================================= =======================
        ID Symbol             Type  Default Value Description                             Other(Shape)
        == ================== ===== ============= ======================================= =======================
        1  ``type``           str   pwil          | Reward model register name, refer     |
                                                  | to registry ``REWARD_MODEL_REGISTRY`` |
        2  | ``expert_data_`` str   expert_data.  | Path to the expert dataset            | Should be a '.pkl'
           | ``path``               .pkl          |                                       | file
        3  | ``sample_size``  int   1000          | sample data from expert dataset       |
                                                  | with fixed size                       |
        4  | ``alpha``        int   5             | factor alpha                          |
        5  | ``beta``         int   5             | factor beta                           |
        6  | ``s_size``       int   4             | state size                            |
        7  | ``a_size``       int   2             | action size                           |
        8  | ``clear_buffer`` int   1             | clear buffer per fixed iters          | make sure replay
           | ``_per_iters``                       |                                       | buffer's data count
                                                  |                                       | isn't too few.
                                                  |                                       | (code work in entry)
        == ================== ===== ============= ======================================= =======================

        ``expert_data_path``, ``s_size`` and ``a_size`` are commented out in the default ``config`` below,
        yet they are read by ``load_expert_data`` / ``collect_data``; the user config must provide them.
    Properties:
        - reward_table (:obj: `Dict`): In this algorithm, reward model is a dictionary.
    """
    config = dict(
        # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
        type='pwil',
        # (str) Path to the expert dataset.
        # expert_data_path='expert_data.pkl',
        # (int) Sample data from expert dataset with fixed size.
        sample_size=1000,
        # r = alpha * exp((-beta*T/sqrt(|s_size|+ |a_size|))*c_i)
        # key idea for this reward is to minimize.
        # the Wasserstein distance between the state-action distribution.
        # (int) Factor alpha.
        alpha=5,
        # (int) Factor beta.
        beta=5,
        # (int) State size.
        # s_size=4,
        # (int) Action size.
        # a_size=2,
        # (int) Clear buffer per fixed iters.
        clear_buffer_per_iters=1,
    )

    def __init__(self, config: Dict, device: str, tb_logger: 'SummaryWriter') -> None:  # noqa
        """
        Overview:
            Initialize ``self.`` See ``help(type(self))`` for accurate signature. \
            Expert data is loaded immediately as part of construction.
        Arguments:
            - config (:obj:`Dict`): Training config; accessed by attribute (e.g. ``config.sample_size``).
            - device (:obj:`str`): Device usage, i.e. "cpu" or any string containing "cuda".
            - tb_logger (:obj:`SummaryWriter`): Logger; accepted but not used inside this method.
        """
        super(PwilRewardModel, self).__init__()
        self.cfg: Dict = config
        assert device in ["cpu", "cuda"] or "cuda" in device
        self.device = device
        self.expert_data: List[tuple] = []
        self.train_data: List[tuple] = []
        # In this algo, model is a dict
        self.reward_table: Dict = {}
        # Number of collected training pairs; refreshed by ``collect_data``.
        self.T: int = 0

        self.load_expert_data()

    def load_expert_data(self) -> None:
        """
        Overview:
            Getting the expert data from ``config['expert_data_path']`` attribute in self.
        Effects:
            This is a side effect function which updates the expert data attribute \
            (e.g. ``self.expert_data``, a list of ``(obs, action)`` tuples sub-sampled to at most \
            ``cfg.sample_size`` entries); in this algorithm, also the ``self.expert_s``, \
            ``self.expert_a`` for states and actions are updated.
        Raises:
            - ValueError: If no expert demonstration remains after loading and sampling.
        """
        with open(self.cfg.expert_data_path, 'rb') as f:
            # NOTE(security): ``pickle.load`` can execute arbitrary code; only point
            # ``expert_data_path`` at files from a trusted source.
            self.expert_data = pickle.load(f)
        print("the data size is:", len(self.expert_data))
        sample_size = min(self.cfg.sample_size, len(self.expert_data))
        self.expert_data = random.sample(self.expert_data, sample_size)
        self.expert_data = [(item['obs'], item['action']) for item in self.expert_data]
        if not self.expert_data:
            # Fail with a clear message instead of an opaque tuple-unpacking error below.
            raise ValueError(
                "PwilRewardModel got no expert demonstrations from '{}'".format(self.cfg.expert_data_path)
            )
        self.expert_s, self.expert_a = list(zip(*self.expert_data))
        print('the expert data demonstrations is:', len(self.expert_data))

    def collect_data(self, data: list) -> None:
        """
        Overview:
            Collecting training data formatted by ``fn:collect_state_action_pairs``.
        Arguments:
            - data (:obj:`list`): Raw training data, each item with at least ``obs`` and ``action`` keys.
        Effects:
            - This is a side effect function which extends ``self.train_data``, sets ``self.T`` to the \
            total number of collected pairs, and recomputes ``self.reward_factor`` as \
            ``-beta * T / sqrt(s_size + a_size)``. ``s_size``, ``a_size`` and ``beta`` are only read \
            from ``self.cfg``; they are not modified.
        """
        self.train_data.extend(collect_state_action_pairs(data))
        self.T = len(self.train_data)

        s_size = self.cfg.s_size
        a_size = self.cfg.a_size
        beta = self.cfg.beta
        self.reward_factor = -beta * self.T / math.sqrt(s_size + a_size)

    def train(self) -> None:
        """
        Overview:
            Training the Pwil reward model on all currently collected training data.
        """
        self._train(self.train_data)

    def estimate(self, data: list) -> List[Dict]:
        """
        Overview:
            Estimate reward by rewriting the reward key in each row of the data.
        Arguments:
            - data (:obj:`list`): the list of data used for estimation, \
            with at least ``obs``, ``action`` and ``reward`` keys.
        Returns:
            - train_data_augmented (:obj:`List[Dict]`): Copy of ``data`` produced by \
            ``self.reward_deepcopy`` whose ``reward`` entries are replaced by the value stored in \
            ``reward_table`` for ``(obs, action)``, or by zeros when the pair has not been trained.
        """
        # NOTE: deepcopy reward part of data is very important,
        # otherwise the reward of data in the replay buffer will be incorrectly modified.
        train_data_augmented = self.reward_deepcopy(data)
        for item in train_data_augmented:
            # NOTE(review): torch tensors hash by object identity, so this lookup presumably only
            # hits for the very same obs/action objects that went through ``collect_data`` - confirm.
            key = (item['obs'], item['action'])
            if key in self.reward_table:
                item['reward'] = self.reward_table[key]
            else:
                # when (s, a) pair is not trained, set the reward value to default value(e.g.: 0)
                item['reward'] = torch.zeros_like(item['reward'])
        return train_data_augmented

    @staticmethod
    def _pairwise_mean_squared_distance(x1: list, x2: list) -> torch.Tensor:
        """
        Overview:
            Shared helper for ``_get_state_distance`` and ``_get_action_distance``: mean squared \
            difference between every element of ``x1`` and every element of ``x2``.
        Arguments:
            - x1 (:obj:`torch.Tensor list`): the 1st tensors' list of size M
            - x2 (:obj:`torch.Tensor list`): the 2nd tensors' list of size N
        Returns:
            - distance (:obj:`torch.Tensor`): Tensor of size M x N.
        """
        # Stack, cast to float and flatten every element to a vector: (M, n) and (N, n).
        x1 = torch.stack(x1).float()
        x2 = torch.stack(x2).float()
        x1 = x1.view(x1.shape[0], -1)
        x2 = x2.view(x2.shape[0], -1)
        # (M, 1, n) - (1, N, n) broadcasts to (M, N, n) without materialising repeated copies.
        return ((x1.unsqueeze(1) - x2.unsqueeze(0)) ** 2).mean(dim=-1)

    def _get_state_distance(self, s1: list, s2: list) -> torch.Tensor:
        """
        Overview:
            Getting distances of states given 2 state lists. Each state is flattened to a vector \
            before comparison.
        Arguments:
            - s1 (:obj:`torch.Tensor list`): the 1st states' list of size M
            - s2 (:obj:`torch.Tensor list`): the 2nd states' list of size N
        Returns:
            - distance (:obj:`torch.Tensor`): Mean squared difference (not a square-rooted \
            Euclidean norm) between the state tensor lists, of size M x N.
        """
        return self._pairwise_mean_squared_distance(s1, s2)

    def _get_action_distance(self, a1: list, a2: list) -> torch.Tensor:
        # TODO the metric of action distance maybe different from envs
        """
        Overview:
            Getting distances of actions given 2 action lists. Each action is flattened to a vector \
            before comparison.
        Arguments:
            - a1 (:obj:`torch.Tensor list`): the 1st actions' list of size M
            - a2 (:obj:`torch.Tensor list`): the 2nd actions' list of size N
        Returns:
            - distance (:obj:`torch.Tensor`): Mean squared difference (not a square-rooted \
            Euclidean norm) between the action tensor lists, of size M x N.
        """
        return self._pairwise_mean_squared_distance(a1, a2)

    def _train(self, data: list):
        """
        Overview:
            Helper function for ``train``. For every ``(s, a)`` pair, greedily match its weight \
            ``1 / T`` against the nearest expert pairs that still have weight left, accumulate the \
            weighted distance ``c`` and store ``alpha * exp(reward_factor * c)`` as its reward.
        Arguments:
            - data (:obj:`list`): List of ``(obs, action)`` tuples as stored by ``collect_data``.
        Effects:
            - This is a side effect function which updates the ``reward_table`` attribute in ``self`` .
        """
        if not data:
            # Nothing collected yet: avoid an opaque unpacking error on ``zip(*data)``.
            return
        batch_s, batch_a = list(zip(*data))
        s_distance_matrix = self._get_state_distance(batch_s, self.expert_s)
        a_distance_matrix = self._get_action_distance(batch_a, self.expert_a)
        distance_matrix = s_distance_matrix + a_distance_matrix

        num_expert = len(self.expert_data)
        w_e_list = [1 / num_expert] * num_expert
        # Expert atoms that still carry weight. Kept across samples so that an atom which has been
        # fully consumed is not offered again to later samples (PWIL pops it for the whole rollout).
        remaining_idx = list(range(num_expert))
        for i, (s, a) in enumerate(data):
            w_pi = 1 / self.T
            c = 0
            # ``remaining_idx`` can run empty while a round-off residue of ``w_pi`` is still positive
            # (e.g. 1 - 1/3 - 1/3 - 1/3 > 0 in floating point); stop instead of reducing an empty tensor.
            while w_pi > 0 and remaining_idx:
                selected_dist = distance_matrix[i, remaining_idx]
                nearest_pos = selected_dist.argmin().item()
                nearest_distance = selected_dist[nearest_pos].item()
                nearest_index = remaining_idx[nearest_pos]
                w_e = w_e_list[nearest_index]
                if w_pi >= w_e:
                    # The expert atom is used up entirely: retire it for the rest of this pass.
                    c = c + nearest_distance * w_e
                    w_pi = w_pi - w_e
                    w_e_list[nearest_index] = 0
                    remaining_idx.pop(nearest_pos)
                else:
                    # The sample's weight is used up; the expert atom keeps the remainder.
                    c = c + w_pi * nearest_distance
                    w_e_list[nearest_index] = w_e - w_pi
                    w_pi = 0
            reward = self.cfg.alpha * math.exp(self.reward_factor * c)
            self.reward_table[(s, a)] = torch.FloatTensor([reward])

    def clear_data(self) -> None:
        """
        Overview:
            Clearing training data. \
            This is a side effect function which empties ``self.train_data`` in place; \
            ``self.T`` and ``self.reward_table`` are left untouched.
        """
        self.train_data.clear()