Shortcuts

中间件

在大部分强化学习流程中,都存在着环境与智能体之间的「探索-利用」循环 —— 从环境中取得数据,训练智能体,取得更好的数据,周而复始。 我们将在后续的 DI-zoo 章节 中详细介绍各个环境的特性,这里将着重实现智能体的交互策略。

强化学习的复杂策略决定了它很难用对象抽象所有参与交互的实体,随着更好的策略和算法不断出现,新的概念和对象无穷无尽, 所以我们的主意是不做对象抽象,而只封装过程,并且力保这些封装后的代码可重用,可替换。这就产生了 DI-engine 的基础概念 —— 中间件。

../_images/middleware.png

如上图所示,每个中间件(图中绿色部分)仅靠名字即可推测其用途,您仅需在 DI-engine 的 middleware 库中选择合适的方法,将它们组合起来,就完成了整个智能体的交互策略。

with task.start(async_mode=False, ctx=OnlineRLContext()):
    task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
    task.use(eps_greedy_handler(cfg))
    task.use(StepCollector(cfg, policy.collect_mode, collector_env))
    task.use(data_pusher(cfg, buffer_))
    task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
    task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
    task.run(max_step=100000)

熟悉了中间件的使用之后,您会发现原来强化学习的几大流派 —— Onpolicy, Offpolicy, Offline 等等居然在流程上会有这么多可复用部分, 通过简单的取舍,您就能将一个 Offpolicy 策略的交互流程改造为 Onpolicy 策略。

with task.start(async_mode=False, ctx=OnlineRLContext()):
    task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
    task.use(StepCollector(cfg, policy.collect_mode, collector_env))
    task.use(gae_estimator(cfg, policy.collect_mode))
    task.use(multistep_trainer(cfg, policy.learn_mode))
    task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
    task.run(max_step=100000)

上下文对象(Context)

Context 是为中间件之间传递数据的信使,不同的交互策略决定了它们该使用什么类型的 context, 例如 DI-engine 中提供了 OnlineRLContextOfflineRLContext 两种 context。

@dataclasses.dataclass
class OnlineRLContext(Context):

    # common
    total_step: int = 0
    env_step: int = 0
    env_episode: int = 0
    train_iter: int = 0
    train_data: Union[Dict, List] = None
    ...

    def __post_init__(self):
        self.keep('env_step', 'env_episode', 'train_iter', 'last_eval_iter')

OnlineRLContext 上面保存了在线训练所需要的数据,每个中间件的任务就是利用这些数据,并提交新的数据到 context 上面, 例如 OffPolicyLearner 中间件的任务就是利用 ctx.train_data 训练模型,并且将训练结果写回到 ctx.train_iter 上面。

在每个循环开始,context 会初始化为新的实例,这确保中间件只需关注一次循环内的数据流,简化了逻辑,也减少了内存泄漏的风险。

如果您需要保存属性到下一个循环,例如 env_step,train_iter 这类需要累加的数值,可以用 ctx.keep 方法将它设置为保留字段。 使用 ctx.keep 调用的字段将在新一轮迭代,context 初始化为新的实例时保留,而其他的字段将被重新初始化。 注意,理论上 ctx.keep 不需要,也不应该被用来保留那些集合类型的数据,或者比较复杂的类,比如 list,dict,torch.Tensor 或者 torch.nn.Module 等, 而只应该保存 int,float 等类型的数据到下一个迭代,如果需要的话。

注:__post_init__(self) 是在 __init__(self) 后被立刻调用的方法。在我们的 Context 中,这意味着在每一个字段初始化之后调用该方法。 我们将 self.keep 在该函数中调用,是因为我们需要先将每个字段初始化,才能调用 self.keep 来保留那些被选择的变量。

v0.4.2 更新 Context 到 dataclass

v0.4.2 版本 中,我们将 Context 从 dict 类改为 dataclass 类。 这个改动的原因是:

  • 防止在开发过程中随意在 Context 中添加新字段,即 ctx 中的字段必须在定义时明确清楚;

  • 防止使用者使用字符串去访问 Context 中的具体字段,即,禁止 ctx[‘xxx’]。

因为通过 Context 传递数据不同于通过函数的输入和输出传递数据,会有一个强制的约束。 随意在外部定义一个新的属性,或者使用字符串访问 Context 中的字段的话,很容易在阅读代码或者多人协作时造成混乱,在拼接不同中间件时报错。

通过将 Context 改为 dataclass 类,我们限定使用属性而不是字符串去访问 Context 中的具体字段,并且防止在外部添加新字段。 如果您需要在 Context 中添加新字段的话, 请在相关 Context 的 初始化阶段 添加。 下面是自定义 Context 的一个具体例子:

@dataclasses.dataclass
class MyContext(Context):

    # common
    total_step: int = 0
    var1: int = 0
    var2: int = 0
    var3: Union[Dict, List] = None
    var4: List = None

    def __post_init__(self):
        self.keep('var1', 'var2')

如果认为某新字段有必要添加到整个项目中的话,请向 DI-engine 的 main 分支提出 PR 并说明具体原因。

Context 字段介绍

注:Updated position 不包含 ctx.attribute = None 的情况。

OnlineRLContext

Attribute

Keeped

Type

Role

Updated position

total_step

True

int

The number of total iteration steps.

In the beginning of each middleware execution loop.

env_step

True

int

The number of environment steps.

rolloutor

env_episode

True

int

The number of environment episodes.

rolloutor

train_iter

True

int

The number of training iterations.

trainer, multistep_trainer

train_data

False

Union[Dict, List]

The fetched data used to be trained.

gae_estimator, offpolicy_data_fetcher, offline_data_fetcher, her_data_enhancer

train_output

False

Union[Dict, List[Dict]]

The training output including logit, action and other info.

OffPolicyLearner, HERLearner(List), trainer, multistep_trainer(Dict)

collect_kwargs

False

dict

The dict include epsilon value.

eps_greedy_handler

obs

False

ttorch.Tensor

The input observations collected from all collector environments.

inferencer

action

False

List

The inferred actions listed by env_id.

inferencer

inference_output

False

Dict[int, Dict]

The dict of which the key is env_id (int), and the value is inference result (Dict).

inferencer

trajectories

False

list

The trajectories collected from environment.

StepCollector, nstep_reward_enhancer

episodes

False

list

The episodes collected from environment.

EpisodeCollector

trajectory_end_idx

False

list

The end index of each trajectory in ctx.trajectories.

StepCollector

eval_value

False

float

The average reward in the current evaluation.

interaction_evaluator, metric_evaluator

last_eval_iter

True

int

The last ctx.train_iter that is evaluated.

interaction_evaluator, metric_evaluator

OfflineRLContext

Attribute

Keeped

Type

Role

Updated position

total_step

True

int

The number of total iteration steps.

In the beginning of each middleware execution loop.

train_epoch

False

int

The count of training epoches.

offline_data_fetcher

train_iter

True

int

The number of training iterations.

trainer, multistep_trainer

train_data

False

Union[Dict, List]

The fetched data used to be trained.

gae_estimator, offpolicy_data_fetcher, offline_data_fetcher, her_data_enhancer

train_output

False

Union[Dict, List[Dict]]

The training output including logit, action and other info.

OffPolicyLearner, HERLearner(List), trainer, multistep_trainer(Dict)

eval_value

False

float

The average reward in the current evaluation.

interaction_evaluator, metric_evaluator

last_eval_iter

True

int

The last ctx.train_iter that is evaluated.

interaction_evaluator, metric_evaluator

使用 task 异步执行任务

Task 是 DI-engine 用来管理强化学习交互任务的全局对象,所有的运行时状态都在 task 内维护,上面也提供了一些语法糖来帮助流程变得更简单。

在分秒必争的训练环境中,异步带来的好处是显而易见的。如果能在训练模型时(GPU 密集工作)采集下一次训练的数据(CPU 密集工作),理论上可以将训练时间缩短一半。 而要实现这种异步,则需要控制复杂的流程,小心翼翼的维护各种状态。现在借助中间件和 task,只需更改一个参数,即可实现各个环节的异步。

# 顺序执行
with task.start(async_mode=False, ctx=OnlineRLContext()):
    ...

# 异步执行
with task.start(async_mode=True, ctx=OnlineRLContext()):
    ...

除了训练和采集,有很多环节都可以利用异步的好处,例如在训练模型时,将下一批数据提前搬到 GPU 上;在训练模型的同时评估历史模型的表现。 实践中不妨多尝试一下通过异步执行来加速整个交互流程。

../_images/async.png

不同阶段的中间件

大部分中间件都可以对应不同的阶段,您可以通过下图查看已有的中间件与阶段的对应关系,以便正确的组合各类中间件:

../_images/pipeline.png