Source code for ding.worker.learner.learner_hook
import numbers
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List
import torch
from easydict import EasyDict
import ding
from ding.utils import allreduce, read_file, save_file, get_rank
[docs]class Hook(ABC):
"""
Overview:
Abstract class for hooks.
Interfaces:
__init__, __call__
Property:
name, priority
"""
[docs] def __init__(self, name: str, priority: float, **kwargs) -> None:
"""
Overview:
Init method for hooks. Set name and priority.
Arguments:
- name (:obj:`str`): The name of hook
- priority (:obj:`float`): The priority used in ``call_hook``'s calling sequence. \
Lower value means higher priority.
"""
self._name = name
assert priority >= 0, "invalid priority value: {}".format(priority)
self._priority = priority
@property
def name(self) -> str:
return self._name
@property
def priority(self) -> float:
return self._priority
@abstractmethod
def __call__(self, engine: Any) -> Any:
"""
Overview:
Should be overwritten by subclass.
Arguments:
- engine (:obj:`Any`): For LearnerHook, it should be ``BaseLearner`` or its subclass.
"""
raise NotImplementedError
[docs]class LearnerHook(Hook):
"""
Overview:
Abstract class for hooks used in Learner.
Interfaces:
__init__
Property:
name, priority, position
.. note::
Subclass should implement ``self.__call__``.
"""
positions = ['before_run', 'after_run', 'before_iter', 'after_iter']
[docs] def __init__(self, *args, position: str, **kwargs) -> None:
"""
Overview:
Init LearnerHook.
Arguments:
- position (:obj:`str`): The position to call hook in learner. \
Must be in ['before_run', 'after_run', 'before_iter', 'after_iter'].
"""
super().__init__(*args, **kwargs)
assert position in self.positions
self._position = position
@property
def position(self) -> str:
return self._position
[docs]class LoadCkptHook(LearnerHook):
"""
Overview:
Hook to load checkpoint
Interfaces:
__init__, __call__
Property:
name, priority, position
"""
[docs] def __init__(self, *args, ext_args: EasyDict = EasyDict(), **kwargs) -> None:
"""
Overview:
Init LoadCkptHook.
Arguments:
- ext_args (:obj:`EasyDict`): Extended arguments. Use ``ext_args.freq`` to set ``load_ckpt_freq``.
"""
super().__init__(*args, **kwargs)
self._load_path = ext_args['load_path']
[docs] def __call__(self, engine: 'BaseLearner') -> None: # noqa
"""
Overview:
Load checkpoint to learner. Checkpoint info includes policy state_dict and iter num.
Arguments:
- engine (:obj:`BaseLearner`): The BaseLearner to load checkpoint to.
"""
path = self._load_path
if path == '': # not load
return
state_dict = read_file(path)
if 'last_iter' in state_dict:
last_iter = state_dict.pop('last_iter')
engine.last_iter.update(last_iter)
engine.policy.load_state_dict(state_dict)
engine.info('{} load ckpt in {}'.format(engine.instance_name, path))
[docs]class SaveCkptHook(LearnerHook):
"""
Overview:
Hook to save checkpoint
Interfaces:
__init__, __call__
Property:
name, priority, position
"""
[docs] def __init__(self, *args, ext_args: EasyDict = EasyDict(), **kwargs) -> None:
"""
Overview:
init SaveCkptHook
Arguments:
- ext_args (:obj:`EasyDict`): extended_args, use ext_args.freq to set save_ckpt_freq
"""
super().__init__(*args, **kwargs)
if ext_args == {}:
self._freq = 1
else:
self._freq = ext_args.freq
[docs] def __call__(self, engine: 'BaseLearner') -> None: # noqa
"""
Overview:
Save checkpoint in corresponding path.
Checkpoint info includes policy state_dict and iter num.
Arguments:
- engine (:obj:`BaseLearner`): the BaseLearner which needs to save checkpoint
"""
if engine.rank == 0 and engine.last_iter.val % self._freq == 0:
if engine.instance_name == 'learner':
dirname = './{}/ckpt'.format(engine.exp_name)
else:
dirname = './{}/ckpt_{}'.format(engine.exp_name, engine.instance_name)
if not os.path.exists(dirname):
try:
os.makedirs(dirname)
except FileExistsError:
pass
ckpt_name = engine.ckpt_name if engine.ckpt_name else 'iteration_{}.pth.tar'.format(engine.last_iter.val)
path = os.path.join(dirname, ckpt_name)
state_dict = engine.policy.state_dict()
state_dict.update({'last_iter': engine.last_iter.val})
save_file(path, state_dict)
engine.info('{} save ckpt in {}'.format(engine.instance_name, path))
[docs]class LogShowHook(LearnerHook):
"""
Overview:
Hook to show log
Interfaces:
__init__, __call__
Property:
name, priority, position
"""
[docs] def __init__(self, *args, ext_args: EasyDict = EasyDict(), **kwargs) -> None:
"""
Overview:
init LogShowHook
Arguments:
- ext_args (:obj:`EasyDict`): extended_args, use ext_args.freq to set freq
"""
super().__init__(*args, **kwargs)
if ext_args == {}:
self._freq = 1
else:
self._freq = ext_args.freq
[docs] def __call__(self, engine: 'BaseLearner') -> None: # noqa
"""
Overview:
Show log, update record and tb_logger if rank is 0 and at interval iterations,
clear the log buffer for all learners regardless of rank
Arguments:
- engine (:obj:`BaseLearner`): the BaseLearner
"""
# Only show log for rank 0 learner
if engine.rank != 0:
for k in engine.log_buffer:
engine.log_buffer[k].clear()
return
# For 'scalar' type variables: log_buffer -> tick_monitor -> monitor_time.step
for k, v in engine.log_buffer['scalar'].items():
setattr(engine.monitor, k, v)
engine.monitor.time.step()
iters = engine.last_iter.val
if iters % self._freq == 0:
engine.info("=== Training Iteration {} Result ===".format(iters))
# For 'scalar' type variables: tick_monitor -> var_dict -> text_logger & tb_logger
var_dict = {}
log_vars = engine.policy.monitor_vars()
attr = 'avg'
for k in log_vars:
k_attr = k + '_' + attr
var_dict[k_attr] = getattr(engine.monitor, attr)[k]()
engine.logger.info(engine.logger.get_tabulate_vars_hor(var_dict))
for k, v in var_dict.items():
engine.tb_logger.add_scalar('{}_iter/'.format(engine.instance_name) + k, v, iters)
engine.tb_logger.add_scalar('{}_step/'.format(engine.instance_name) + k, v, engine._collector_envstep)
# For 'histogram' type variables: log_buffer -> tb_var_dict -> tb_logger
tb_var_dict = {}
for k in engine.log_buffer['histogram']:
new_k = '{}/'.format(engine.instance_name) + k
tb_var_dict[new_k] = engine.log_buffer['histogram'][k]
for k, v in tb_var_dict.items():
engine.tb_logger.add_histogram(k, v, iters)
for k in engine.log_buffer:
engine.log_buffer[k].clear()
[docs]class LogReduceHook(LearnerHook):
"""
Overview:
Hook to reduce the distributed(multi-gpu) logs
Interfaces:
__init__, __call__
Property:
name, priority, position
"""
[docs] def __init__(self, *args, ext_args: EasyDict = EasyDict(), **kwargs) -> None:
"""
Overview:
init LogReduceHook
Arguments:
- ext_args (:obj:`EasyDict`): extended_args, use ext_args.freq to set log_reduce_freq
"""
super().__init__(*args, **kwargs)
[docs] def __call__(self, engine: 'BaseLearner') -> None: # noqa
"""
Overview:
reduce the logs from distributed(multi-gpu) learners
Arguments:
- engine (:obj:`BaseLearner`): the BaseLearner
"""
def aggregate(data):
r"""
Overview:
aggregate the information from all ranks(usually use sync allreduce)
Arguments:
- data (:obj:`dict`): Data that needs to be reduced. \
Could be dict, torch.Tensor, numbers.Integral or numbers.Real.
Returns:
- new_data (:obj:`dict`): data after reduce
"""
if isinstance(data, dict):
new_data = {k: aggregate(v) for k, v in data.items()}
elif isinstance(data, list) or isinstance(data, tuple):
new_data = [aggregate(t) for t in data]
elif isinstance(data, torch.Tensor):
new_data = data.clone().detach()
if ding.enable_linklink:
allreduce(new_data)
else:
new_data = new_data.to(get_rank())
allreduce(new_data)
new_data = new_data.cpu()
elif isinstance(data, numbers.Integral) or isinstance(data, numbers.Real):
new_data = torch.scalar_tensor(data).reshape([1])
if ding.enable_linklink:
allreduce(new_data)
else:
new_data = new_data.to(get_rank())
allreduce(new_data)
new_data = new_data.cpu()
new_data = new_data.item()
else:
raise TypeError("invalid type in reduce: {}".format(type(data)))
return new_data
engine.log_buffer = aggregate(engine.log_buffer)
hook_mapping = {
'load_ckpt': LoadCkptHook,
'save_ckpt': SaveCkptHook,
'log_show': LogShowHook,
'log_reduce': LogReduceHook,
}
def register_learner_hook(name: str, hook_type: type) -> None:
"""
Overview:
Add a new LearnerHook class to hook_mapping, so you can build one instance with `build_learner_hook_by_cfg`.
Arguments:
- name (:obj:`str`): name of the register hook
- hook_type (:obj:`type`): the register hook_type you implemented that realize LearnerHook
Examples:
>>> class HookToRegister(LearnerHook):
>>> def __init__(*args, **kargs):
>>> ...
>>> ...
>>> def __call__(*args, **kargs):
>>> ...
>>> ...
>>> ...
>>> register_learner_hook('name_of_hook', HookToRegister)
>>> ...
>>> hooks = build_learner_hook_by_cfg(cfg)
"""
assert issubclass(hook_type, LearnerHook)
hook_mapping[name] = hook_type
simplified_hook_mapping = {
'log_show_after_iter': lambda freq: hook_mapping['log_show']
('log_show', 20, position='after_iter', ext_args=EasyDict({'freq': freq})),
'load_ckpt_before_run': lambda path: hook_mapping['load_ckpt']
('load_ckpt', 20, position='before_run', ext_args=EasyDict({'load_path': path})),
'save_ckpt_after_iter': lambda freq: hook_mapping['save_ckpt']
('save_ckpt_after_iter', 20, position='after_iter', ext_args=EasyDict({'freq': freq})),
'save_ckpt_after_run': lambda _: hook_mapping['save_ckpt']('save_ckpt_after_run', 20, position='after_run'),
'log_reduce_after_iter': lambda _: hook_mapping['log_reduce']('log_reduce_after_iter', 10, position='after_iter'),
}
def find_char(s: str, flag: str, num: int, reverse: bool = False) -> int:
assert num > 0, num
count = 0
iterable_obj = reversed(range(len(s))) if reverse else range(len(s))
for i in iterable_obj:
if s[i] == flag:
count += 1
if count == num:
return i
return -1
def build_learner_hook_by_cfg(cfg: EasyDict) -> Dict[str, List[Hook]]:
"""
Overview:
Build the learner hooks in hook_mapping by config.
This function is often used to initialize ``hooks`` according to cfg,
while add_learner_hook() is often used to add an existing LearnerHook to `hooks`.
Arguments:
- cfg (:obj:`EasyDict`): Config dict. Should be like {'hook': xxx}.
Returns:
- hooks (:obj:`Dict[str, List[Hook]`): Keys should be in ['before_run', 'after_run', 'before_iter', \
'after_iter'], each value should be a list containing all hooks in this position.
Note:
Lower value means higher priority.
"""
hooks = {k: [] for k in LearnerHook.positions}
for key, value in cfg.items():
if key in simplified_hook_mapping and not isinstance(value, dict):
pos = key[find_char(key, '_', 2, reverse=True) + 1:]
hook = simplified_hook_mapping[key](value)
priority = hook.priority
else:
priority = value.get('priority', 100)
pos = value.position
ext_args = value.get('ext_args', {})
hook = hook_mapping[value.type](value.name, priority, position=pos, ext_args=ext_args)
idx = 0
for i in reversed(range(len(hooks[pos]))):
if priority >= hooks[pos][i].priority:
idx = i + 1
break
hooks[pos].insert(idx, hook)
return hooks
def add_learner_hook(hooks: Dict[str, List[Hook]], hook: LearnerHook) -> None:
"""
Overview:
Add a learner hook(:obj:`LearnerHook`) to hooks(:obj:`Dict[str, List[Hook]`)
Arguments:
- hooks (:obj:`Dict[str, List[Hook]`): You can refer to ``build_learner_hook_by_cfg``'s return ``hooks``.
- hook (:obj:`LearnerHook`): The LearnerHook which will be added to ``hooks``.
"""
position = hook.position
priority = hook.priority
idx = 0
for i in reversed(range(len(hooks[position]))):
if priority >= hooks[position][i].priority:
idx = i + 1
break
assert isinstance(hook, LearnerHook)
hooks[position].insert(idx, hook)
def merge_hooks(hooks1: Dict[str, List[Hook]], hooks2: Dict[str, List[Hook]]) -> Dict[str, List[Hook]]:
"""
Overview:
Merge two hooks dict, which have the same keys, and each value is sorted by hook priority with stable method.
Arguments:
- hooks1 (:obj:`Dict[str, List[Hook]`): hooks1 to be merged.
- hooks2 (:obj:`Dict[str, List[Hook]`): hooks2 to be merged.
Returns:
- new_hooks (:obj:`Dict[str, List[Hook]`): New merged hooks dict.
Note:
This merge function uses stable sort method without disturbing the same priority hook.
"""
assert set(hooks1.keys()) == set(hooks2.keys())
new_hooks = {}
for k in hooks1.keys():
new_hooks[k] = sorted(hooks1[k] + hooks2[k], key=lambda x: x.priority)
return new_hooks
def show_hooks(hooks: Dict[str, List[Hook]]) -> None:
for k in hooks.keys():
print('{}: {}'.format(k, [x.__class__.__name__ for x in hooks[k]]))