Shortcuts

Source code for ding.framework.middleware.functional.enhancer

from itertools import chain
from typing import TYPE_CHECKING, Callable

from easydict import EasyDict
from ditk import logging
import torch

from ding.framework import task

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext
    from ding.reward_model import BaseRewardModel, HerRewardModel
    from ding.data import Buffer


def reward_estimator(cfg: EasyDict, reward_model: "BaseRewardModel") -> Callable:
    """
    Overview:
        Build a middleware that re-estimates the reward of ``ctx.train_data`` with ``reward_model``. \
            The training batch is modified in place.
    Arguments:
        - cfg (:obj:`EasyDict`): Config (not read by this middleware).
        - reward_model (:obj:`BaseRewardModel`): Reward model whose ``estimate`` is applied to the training batch.
    Returns:
        - enhance (:obj:`Callable`): The middleware, or ``task.void()`` when the router is active \
            and this node does not have the learner role.
    """
    skip_on_this_node = task.router.is_active and not task.has_role(task.role.LEARNER)
    if skip_on_this_node:
        return task.void()

    def _enhance(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - train_data (:obj:`List`): The list of data used for estimation; modified in place.
        """
        batch = ctx.train_data
        # ``estimate`` modifies ``batch`` in place; its return value is deliberately ignored.
        reward_model.estimate(batch)

    return _enhance
def her_data_enhancer(cfg: EasyDict, buffer_: "Buffer", her_reward_model: "HerRewardModel") -> Callable:
    """
    Overview:
        Fetch a batch of data/episode from `buffer_`, \
            then use `her_reward_model` to get HER processed episodes from original episodes.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys \
            if her_reward_model.episode_size is None: `cfg.policy.learn.batch_size`.
        - buffer\\_ (:obj:`Buffer`): Buffer to sample data from.
        - her_reward_model (:obj:`HerRewardModel`): Hindsight Experience Replay (HER) model \
            which is used to process episodes.
    Returns:
        - fetch_and_enhance (:obj:`Callable`): The middleware, or ``task.void()`` when the router is \
            active and this node does not have the learner role.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()

    def _fetch_and_enhance(ctx: "OnlineRLContext"):
        """
        Output of ctx:
            - train_data (:obj:`List[treetensor.torch.Tensor]`): The HER processed episodes, \
                or ``None`` when the buffer cannot provide enough data yet.
        """
        # An explicit episode_size on the HER model takes precedence over the learner batch size.
        if her_reward_model.episode_size is None:
            size = cfg.policy.learn.batch_size
        else:
            size = her_reward_model.episode_size
        try:
            buffered_episode = buffer_.sample(size)
            train_episode = [d.data for d in buffered_episode]
        except (ValueError, AssertionError):
            # You can modify data collect config to avoid this warning, e.g. increasing n_sample, n_episode.
            logging.warning(
                "Replay buffer's data is not enough to support training, so skip this training for waiting more data."
            )
            ctx.train_data = None
            return

        # The result is flattened twice: presumably ``estimate(e)`` returns a list of HER episodes, each a
        # list of transitions -- TODO confirm against HerRewardModel. ``chain.from_iterable`` flattens in
        # linear time, whereas ``sum(..., [])`` re-copies the growing accumulator on every step.
        her_episode = chain.from_iterable(her_reward_model.estimate(e) for e in train_episode)
        ctx.train_data = list(chain.from_iterable(her_episode))

    return _fetch_and_enhance
def nstep_reward_enhancer(cfg: EasyDict) -> Callable:
    """
    Overview:
        Build a middleware that replaces the one-step ``reward`` of every transition in \
            ``ctx.trajectories`` with its zero-padded n-step reward sequence and attaches the \
            matching ``value_gamma`` (``discount_factor ** number_of_real_steps``).
        A window never crosses an episode boundary: it ends right after the first transition whose \
            ``done`` flag is set, so the terminal reward is kept and the next episode's rewards are excluded.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain `cfg.policy.nstep` and `cfg.policy.discount_factor`.
    Returns:
        - enhance (:obj:`Callable`): The middleware, or ``task.void()`` when the router is active and \
            this node has neither the learner nor the collector role.
    """
    if task.router.is_active and (not task.has_role(task.role.LEARNER) and not task.has_role(task.role.COLLECTOR)):
        return task.void()

    def _enhance(ctx: "OnlineRLContext"):
        """
        Input/Output of ctx:
            - trajectories: Transitions exposing ``reward`` and ``done``; ``reward`` is overwritten \
                with the n-step sequence and ``value_gamma`` is added, in place.
        """
        trajectories = ctx.trajectories
        if not trajectories:
            # Nothing collected: no windows to build (and no reward template to pad with).
            return
        nstep = cfg.policy.nstep
        gamma = cfg.policy.discount_factor
        total = len(trajectories)
        reward_template = trajectories[0].reward
        nstep_rewards = []
        value_gamma = []
        for i in range(total):
            valid = min(nstep, total - i)
            # Cut the window right after the first terminal transition inside it.
            # The terminal transition's own reward stays in the window.
            for j in range(valid - 1):
                if trajectories[i + j].done:
                    valid = j + 1
                    break
            value_gamma.append(torch.FloatTensor([gamma ** valid]))
            window = [trajectories[k].reward for k in range(i, i + valid)]
            # Zero-pad truncated windows so every transition carries exactly ``nstep`` rewards.
            window.extend(torch.zeros_like(reward_template) for _ in range(nstep - valid))
            nstep_rewards.append(torch.cat(window))  # (nstep, ) when each reward has shape (1, )
        # Assign only after all windows are built so every window reads the original one-step rewards.
        # NOTE(review): ``done``/``next_obs`` are not rewritten to n-step values here -- confirm consumers expect that.
        for trajectory, reward, gamma_n in zip(trajectories, nstep_rewards, value_gamma):
            trajectory.reward = reward
            trajectory.value_gamma = gamma_n

    return _enhance


# TODO MBPO
# TODO SIL
# TODO TD3 VAE