中间件编写规范¶
文件目录与命名¶
DI-engine 内的中间件可分为两类,一种我们称之为 function
,是原子化的操作方法,用几行代码专注做一件事,例如 train
中间件执行模型的训练;另一种我们称之为 module
,可能组合了多种 function
,执行更复杂的逻辑。这种分类方式参考了 pytorch 的 nn 和 nn.functional 。
本质上它们都属于中间件,用法一致。
目录结构上, module
直接放置在 middleware
目录中,以名词命名; function
放置在 middleware/functional
目录中,以动词或名词命名。
同种类型的多个中间件,可写在一个文件中。
ding/
framework/
middleware/
functional/collect.py # Function
collector.py # Module
类,函数,参数¶
编写 function
时,由于代码简短,建议使用函数式风格代码;编写 module
时,建议使用类。例如:
# Function 写法
def train(model: Model):
def _train(ctx: Context):
...
return _train
# Module 写法
class Trainer:
def __init__(self, model: Model):
self._model = model
def __call__(self, ctx: Context):
...
所有函数建议传递明确命名参数,不推荐使用 dict 作为参数传递。如参数过多,推荐使用 TreeTensor。
构造方法¶
大部分中间件都有两层方法,例如 function
的外层函数和 module
的 __init__
函数,是为了初始化中间件的必要参数和对象。
而 function
的返回函数和 module
的 __call__
方法,则是运行时循环调用的过程,只支持 ctx
一个参数。
建议在外部实例化对象传递给中间件,而不是在中间件内部实例化,以确保中间件的无状态和过程化:
# 正确
def train(model: Model):
def _train(ctx: Context):
...
return _train
model = Model()
train(model)
# 错误
def train():
model = Model()
def _train(ctx: Context):
...
return _train
train()
运行时方法¶
编写 function
的返回函数或 module
的 __call__
方法时,需要注意以下几点:
如果方法中有死循环,确保判断了
task.finish
条件退出:
def runtime(ctx: Context):
while True:
if task.finish: # 确保判断 task.finish
break
sleep(1)
task
支持顺序执行和异步执行两种模式,ctx
传递的数据在两种模式下产生的时间不一定相同,在中间件中需要注意判断,并最好同时支持两种模式:
def runtime(ctx: Context):
if ctx.get("next_obs"): # 在异步模式下,采集到的数据可能不在这轮迭代中使用,而会推迟到下一个迭代
ctx.obs = next_obs
ctx.next_obs = get_obs()
else: # 同步模式下,直接采集数据,给接着的 trainer 使用
ctx.obs = get_obs()
中间件内部不建议再开多进程,以免因为前面过多的实例化对象,或者进程嵌套多层导致难以预料的问题,如果需要利用多进程并行,可以将逻辑拆分为多个中间件,利用 DI-engine 的并行能力执行:
# 正确
def train1(ctx: Context):
...
def train2(ctx: Context):
...
task.use(train1)
task.use(train2)
# 错误
def train(ctx: Context):
p1 = mp.Process(target=...)
p1.start()
p2 = mp.Process(target=...)
p2.start()
p1.join()
p2.join()
事件命名规范¶
在使用 DI-engine 内的事件机制时,我们约定事件按以下规范命名:
以广播数据为目的的事件,使用
发出位置_数据名[_参数名_参数值]
命名,例如:league_job_actor_0
(由 league 向 actor 广播数据,传递 job)以远程调用为目的的事件,使用
接收位置_方法名
命名,例如:league_get_job
(由 actor 调用 league 的 get_job 方法)