Source code for ding.reward_model.base_reward_model
from abc import ABC, abstractmethod
from typing import Dict
from easydict import EasyDict
from ditk import logging
import os
import copy
from typing import Any
from ding.utils import REWARD_MODEL_REGISTRY, import_module, save_file
class BaseRewardModel(ABC):
    """
    Overview:
        The base class of reward model.
    Interface:
        ``default_config``, ``estimate``, ``train``, ``clear_data``, ``collect_data``, ``load_expert_data``,
        ``reward_deepcopy``, ``state_dict``, ``load_state_dict``, ``save``
    """

    @classmethod
    def default_config(cls: type) -> 'EasyDict':
        """
        Overview:
            Build a default config object from the class attribute ``config``.
            ``config`` is not defined on this base class, so the concrete class is expected to provide it.
        Returns:
            - cfg (:obj:`EasyDict`): A deep copy of ``cls.config`` with an extra ``cfg_type`` field set to \
                the class name followed by ``'Dict'``.
        """
        # Deep copy so that edits on the returned config never leak back into the class attribute.
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    @abstractmethod
    def estimate(self, data: list) -> Any:
        """
        Overview:
            Estimate reward.
        Arguments:
            - data (:obj:`List`): The list of data used for estimation.
        Returns / Effects:
            - This can be a side effect function which updates the reward value.
            - If this function returns, an example returned object can be reward (:obj:`Any`): \
                the estimated reward.
        """
        raise NotImplementedError()

    @abstractmethod
    def train(self, data) -> None:
        """
        Overview:
            Training the reward model.
        Arguments:
            - data (:obj:`Any`): Data used for training.
        Effects:
            - This is mostly a side effect function which updates the reward model.
        """
        raise NotImplementedError()

    @abstractmethod
    def collect_data(self, data) -> None:
        """
        Overview:
            Collecting training data in designated format or with designated transition.
        Arguments:
            - data (:obj:`Any`): Raw training data (e.g. some form of states, actions, obs, etc).
        Returns / Effects:
            - This can be a side effect function which updates the data attribute in ``self``.
        """
        raise NotImplementedError()

    @abstractmethod
    def clear_data(self) -> None:
        """
        Overview:
            Clearing training data.
            This can be a side effect function which clears the data attribute in ``self``.
        """
        raise NotImplementedError()

    def load_expert_data(self, data) -> None:
        """
        Overview:
            Getting the expert data, usually used in inverse RL reward model.
            The base implementation does nothing.
        Arguments:
            - data (:obj:`Any`): Expert data.
        Effects:
            This is mostly a side effect function which updates the expert data attribute
            (e.g. ``self.expert_data``).
        """
        pass

    def reward_deepcopy(self, train_data) -> Any:
        """
        Overview:
            Deep copy the reward part of ``train_data`` while keeping every other part as a shallow copy,
            to avoid the reward part of ``train_data`` in the replay buffer being incorrectly modified.
        Arguments:
            - train_data (:obj:`List`): The list of train data (one dict per sample) in which the value \
                stored under the ``'reward'`` key will be deep copied.
        Returns:
            - train_data_reward_deepcopy (:obj:`List`): A new list of new dicts; only the ``'reward'`` \
                values are copies, all other values are the same objects as in ``train_data``.
        """
        train_data_reward_deepcopy = [
            {k: copy.deepcopy(v) if k == 'reward' else v
             for k, v in sample.items()} for sample in train_data
        ]
        return train_data_reward_deepcopy

    def state_dict(self) -> Dict:
        """
        Overview:
            Return the state to be stored in a checkpoint. The base implementation returns an empty dict;
            this method should be overridden by subclass.
        """
        return {}

    def load_state_dict(self, _state_dict) -> None:
        """
        Overview:
            Restore the state produced by ``state_dict``. The base implementation does nothing;
            this method should be overridden by subclass.
        Arguments:
            - _state_dict (:obj:`Any`): The state to load.
        """
        pass

    def save(self, path: str = None, name: str = 'best') -> None:
        """
        Overview:
            Save ``self.state_dict()`` to ``<path>/reward_model/ckpt/ckpt_<name>.pth.tar``.
        Arguments:
            - path (:obj:`str`): Root directory of the checkpoint. When it is ``None``, \
                ``self.cfg.exp_name`` is used, so the instance must provide ``self.cfg`` in that case.
            - name (:obj:`str`): Tag inserted into the checkpoint file name.
        """
        if path is None:
            path = self.cfg.exp_name
        ckpt_dir = os.path.join(path, 'reward_model', 'ckpt')
        # ``exist_ok=True`` creates the directory without a separate existence check, so another
        # process creating the same directory at the same time cannot make this call fail.
        os.makedirs(ckpt_dir, exist_ok=True)
        ckpt_path = os.path.join(ckpt_dir, 'ckpt_{}.pth.tar'.format(name))
        state_dict = self.state_dict()
        save_file(ckpt_path, state_dict)
        logging.info('Saved reward model ckpt in {}'.format(ckpt_path))
def create_reward_model(cfg: dict, device: str, tb_logger: 'SummaryWriter') -> BaseRewardModel:  # noqa
    """
    Overview:
        Build a reward model from the registry according to the ``type`` field of the config.
        The input ``cfg`` is deep copied first, so the caller's object is never modified.
    Arguments:
        - cfg (:obj:`Dict`): Training config. If it exposes a ``reward_model`` attribute, ``type`` is \
            taken (and removed) from ``cfg.reward_model``; otherwise from ``cfg`` itself. An optional \
            ``import_names`` entry is removed and passed to ``import_module`` before the lookup.
        - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
        - tb_logger (:obj:`SummaryWriter`): Logger forwarded to the reward model
    Returns:
        - reward (:obj:`Any`): The reward model
    """
    config = copy.deepcopy(cfg)
    if 'import_names' in config:
        import_module(config.pop('import_names'))
    # ``type`` lives in the nested ``reward_model`` section when the config exposes one,
    # and at the top level otherwise.
    type_holder = config.reward_model if hasattr(config, 'reward_model') else config
    model_type = type_holder.pop('type')
    return REWARD_MODEL_REGISTRY.build(model_type, config, device=device, tb_logger=tb_logger)
def get_reward_model_cls(cfg: EasyDict) -> type:
    """
    Overview:
        Look up the reward model class registered under ``cfg.type``.
    Arguments:
        - cfg (:obj:`EasyDict`): Config providing ``type`` and, optionally, ``import_names`` which is \
            passed to ``import_module`` before the lookup (an empty list is passed when it is absent).
    Returns:
        - cls (:obj:`type`): The object returned by ``REWARD_MODEL_REGISTRY.get`` for ``cfg.type``.
    """
    extra_modules = cfg.get('import_names', [])
    import_module(extra_modules)
    return REWARD_MODEL_REGISTRY.get(cfg.type)