Shortcuts

如何自定义模型 Wrapper

模型包装器的功能

通常强化学习模型的输出是 V, Q 或者动作 logits。 DI-engine需要定制模型来输出这些的部分或者全部。 为了提高模型的可用性并支持更多 功能, DI-engine 提供模型 wrapper 去通过某个具体的策略(比如RNN或者其他方程)去采样动作。

DI-engine 提供以下模型 wrapper:

  • BaseModelWrapper: 为模型添加重置方法。 在 DI-engine 的策略实现中, 许多策略会调用模型的重置方法(例如 HiddenStateWrapper)。 任何继承了 nn.Module 的模型在使用 model_wrap 函数来 wrap 后将会自动被 BaseModelWrapper wrap。

  • HiddenStateWrapper: 用于需要维护 hidden state 的模型,比如LSTM。

  • SampleWrapper, 包括 ArgmaxSampleWrapper, MultinomialSampleWrapper,EpsGreedySampleWrapper:允许用户通过 argmax、多项式分布或 epsilon 贪心策略采样动作。

  • ActionNoiseWrapper:在输出动作上添加噪声,主要用于连续动作空间环境。

  • TargetNetworkWrapper:为基本模型添加目标网络相关功能,用于需要目标网络的 DQN 等 RL 算法。

举例

用户可以参考以下步骤自定义模型 wrapper:

  1. 像其他 wrapper 一样定义模型 wrapper 类 ding/model/wrappers/model_wrappers.py;

  2. 将 wrapper 的名称添加到 ding/model/wrappers/model_wrappers.py: wrapper_name_map 或使用 wrapper 注册以确保可以通过 ding 检索您的 wrapper。前者的话, 无需额外注册即可通过 wrapper_name 指定调用的模型 wrapper, 后者的话,您需要注册这个 wrapper 。通常格式如下:

@WRAPPER_REGISTRY.register('your_wrapper_name')
class YourWrapper('IModelWrapper'):

    pass
  1. 调用 model_wrap 函数包装你的模型。

wrapped_model = model_wrap(origin_model, wrapper_name='your_wrapper_name', **kwargs)

Note

所有 model wrapper 必须 继承 IModelWrapper

我们将在下面展示 DI-engine 中 HiddenStateWrapper 的实现,以用来解释如何自定义模型 wrapper。

如果我们想在我们的模型中使用 RNN,我们必须在训练过程中保存 hidden state。通过使用 Hidden StateWrapper, 我们可以在不更改政策代码的情况下实现这一点。

HiddenStateWrapper 的结构如下:

class HiddenStateWrapper(IModelWrapper):

    def __init__(
       self, model: Any, state_num: int, save_prev_state: bool = False, init_fn: Callable = lambda: None
       ) -> None:

    """

    Overview:

      Maintain the hidden state for RNN-base model. Each sample in a batch has its own state.
      Init the maintain state and state function; Then wrap the ``model.forward`` method with auto
      saved data ['prev_state'] input, and create the ``model.reset`` method.

    Arguments:

      - model(:obj:`Any`): Wrapped model class, should contain forward method.
      - state_num (:obj:`int`): Number of states to process.
      - save_prev_state (:obj:`bool`): Whether to output the prev state in output['prev_state'].
      - init_fn (:obj:`Callable`): The function which is used to init every hidden state when init and reset.
        Default return None for hidden states.

    """

      ...

  def forward(self, data, **kwargs):
      ...
      return output

  def reset(self, *args, **kwargs):
      ...

  def reset_state(self, state: Optional[list] = None, state_id: Optional[list] = None) -> None:
      ...

  def before_forward(self, data: dict, state_id: Optional[list]) -> Tuple[dict, dict]:
      ...

  def after_forward(self, h: Any, state_info: dict, valid_id: Optional[list] = None) -> None:
      ...
  • __init__: Initialize hidden state as arguments, save it as model property self._state

  • before_forward: Put self._state into model input data, the key is ‘prev_state’

  • after_forward: Save model’s output next_state into self._state

  • reset: Reset wrapper related state, e.g. hidden state in RNN

  • forward: Call before_forward, forward function of model, after_forward in turn

这个过程的数据流如下:

../_images/model_hiddenwrapper_img.png

关于模型 wrapper 的其他示例,您可以在 ding/model/wrappers/model_wrappers.py 找到更多细节。