# Source code for ding.framework.middleware.learner
from typing import TYPE_CHECKING, Callable, List, Tuple, Union, Dict, Optional
from easydict import EasyDict
from collections import deque
from ding.framework import task
from ding.data import Buffer
from .functional import trainer, offpolicy_data_fetcher, reward_estimator, her_data_enhancer
if TYPE_CHECKING:
    from ding.framework import Context, OnlineRLContext
    from ding.policy import Policy
    from ding.reward_model import BaseRewardModel
class OffPolicyLearner:
    """
    Overview:
        The class of the off-policy learner, including data fetching and model training. Use \
            the `__call__` method to execute the whole learning process.
    """

    def __new__(cls, *args, **kwargs):
        # When the task router is active, only a process that has the LEARNER role builds a real
        # learner; any other process receives whatever ``task.void()`` returns instead. (Python runs
        # ``__init__`` only if ``__new__`` returns an instance of ``cls``; presumably the void object
        # is not one, so no learner state is built for it — confirm against ``task.void``.)
        if task.router.is_active and not task.has_role(task.role.LEARNER):
            return task.void()
        return super(OffPolicyLearner, cls).__new__(cls)

    def __init__(
            self,
            cfg: EasyDict,
            policy: 'Policy',
            buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]],
            reward_model: Optional['BaseRewardModel'] = None,
            log_freq: int = 100,
    ) -> None:
        """
        Arguments:
            - cfg (:obj:`EasyDict`): Config. ``cfg.policy.learn.update_per_collect`` is read by \
                ``__call__`` as the maximum number of training iterations per call.
            - policy (:obj:`Policy`): The policy to be trained.
            - buffer\\_ (:obj:`Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]]`): The \
                replay buffer(s) from which training data is fetched.
            - reward_model (:obj:`BaseRewardModel`): Additional reward estimator likes RND, ICM, etc. \
                default to None.
            - log_freq (:obj:`int`): The frequency (iteration) of showing log.
        """
        self.cfg = cfg
        self._fetcher = task.wrap(offpolicy_data_fetcher(cfg, buffer_))
        self._trainer = task.wrap(trainer(cfg, policy, log_freq=log_freq))
        # Optional reward re-estimation stage, run between fetching and training in ``__call__``.
        if reward_model is not None:
            self._reward_estimator = task.wrap(reward_estimator(cfg, reward_model))
        else:
            self._reward_estimator = None

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Overview:
            Run up to ``cfg.policy.learn.update_per_collect`` training iterations. Each iteration \
                fetches data into ``ctx.train_data`` (stopping early when it is ``None``), runs the \
                optional reward estimator, and then runs the trainer.
        Output of ctx:
            - train_output (:obj:`List`): The training output of every completed iteration, in order. \
                Empty when no iteration ran to completion.
        """
        train_output_queue = []
        for _ in range(self.cfg.policy.learn.update_per_collect):
            self._fetcher(ctx)
            # ``None`` here means the fetcher produced no batch, so no further training this call.
            if ctx.train_data is None:
                break
            if self._reward_estimator is not None:
                self._reward_estimator(ctx)
            self._trainer(ctx)
            # The trainer is expected to (re)set ``ctx.train_output`` each iteration; keep every one.
            train_output_queue.append(ctx.train_output)
        ctx.train_output = train_output_queue
class HERLearner:
    """
    Overview:
        The class of the learner with the Hindsight Experience Replay (HER). \
            Use the `__call__` method to execute the data fetching and training \
            process.
    """

    def __init__(
            self,
            cfg: EasyDict,
            policy: 'Policy',
            buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]],
            her_reward_model,
            log_freq: Optional[int] = None,
    ) -> None:
        """
        Arguments:
            - cfg (:obj:`EasyDict`): Config. ``cfg.policy.learn.update_per_collect`` is read by \
                ``__call__`` as the maximum number of training iterations per call.
            - policy (:obj:`Policy`): The policy to be trained.
            - buffer\\_ (:obj:`Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]]`): The \
                replay buffer(s) from which training data is fetched.
            - her_reward_model (:obj:`HerRewardModel`): HER reward model.
            - log_freq (:obj:`Optional[int]`): The frequency (iteration) of showing log. When None \
                (the default), ``log_freq`` is not passed and the trainer's own default applies.
        """
        self.cfg = cfg
        self._fetcher = task.wrap(her_data_enhancer(cfg, buffer_, her_reward_model))
        # Forward ``log_freq`` only when given, so the default call is exactly ``trainer(cfg, policy)``.
        trainer_kwargs = {} if log_freq is None else {'log_freq': log_freq}
        self._trainer = task.wrap(trainer(cfg, policy, **trainer_kwargs))

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Overview:
            Run up to ``cfg.policy.learn.update_per_collect`` training iterations. Each iteration \
                fetches HER-enhanced data into ``ctx.train_data`` (stopping early when it is \
                ``None``) and then runs the trainer.
        Output of ctx:
            - train_output (:obj:`List`): The training output of every completed iteration, in order. \
                Empty when no iteration ran to completion.
        """
        train_output_queue = []
        for _ in range(self.cfg.policy.learn.update_per_collect):
            self._fetcher(ctx)
            # ``None`` here means the fetcher produced no batch, so no further training this call.
            if ctx.train_data is None:
                break
            self._trainer(ctx)
            # The trainer is expected to (re)set ``ctx.train_output`` each iteration; keep every one.
            train_output_queue.append(ctx.train_output)
        ctx.train_output = train_output_queue