worker.learner¶
learner_hook¶
Please refer to ding/worker/learner/learner_hook.py for usage.
Hook¶
LearnerHook¶
LoadCkptHook¶
- class ding.worker.learner.learner_hook.LoadCkptHook(*args, ext_args: EasyDict = {}, **kwargs)[source]¶
- Overview:
Hook to load checkpoint
- Interfaces:
__init__, __call__
- Property:
name, priority, position
- __call__(engine: BaseLearner) None [source]¶
- Overview:
Load checkpoint to learner. Checkpoint info includes policy state_dict and iter num.
- Arguments:
engine (BaseLearner): The BaseLearner to load the checkpoint into.
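To make the hook contract concrete, here is a minimal, self-contained sketch of a checkpoint-loading hook in the style described above. The `MiniLoadCkptHook` and `MiniLearner` classes and the in-memory checkpoint dict are hypothetical stand-ins, not ding's actual implementation (which reads a real checkpoint file into the policy).

```python
# Illustrative sketch (not ding's real code): a hook exposes name,
# priority and position, and __call__ receives the learner as "engine".

class MiniLoadCkptHook:
    def __init__(self, name='load_ckpt', priority=20, position='before_run'):
        self.name = name
        self.priority = priority
        self.position = position

    def __call__(self, engine):
        # Restore both the policy state and the iteration counter,
        # mirroring the "policy state_dict and iter num" described above.
        ckpt = engine.load_path_contents
        engine.policy_state = ckpt['model']
        engine.last_iter = ckpt['last_iter']


class MiniLearner:
    def __init__(self, ckpt):
        self.load_path_contents = ckpt  # stands in for a checkpoint file
        self.policy_state = None
        self.last_iter = 0


learner = MiniLearner({'model': {'w': [0.1, 0.2]}, 'last_iter': 500})
MiniLoadCkptHook()(learner)
```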
SaveCkptHook¶
- class ding.worker.learner.learner_hook.SaveCkptHook(*args, ext_args: EasyDict = {}, **kwargs)[source]¶
- Overview:
Hook to save checkpoint
- Interfaces:
__init__, __call__
- Property:
name, priority, position
- __call__(engine: BaseLearner) None [source]¶
- Overview:
Save checkpoint in corresponding path. Checkpoint info includes policy state_dict and iter num.
- Arguments:
engine (BaseLearner): The BaseLearner which needs to save a checkpoint.
LogShowHook¶
- class ding.worker.learner.learner_hook.LogShowHook(*args, ext_args: EasyDict = {}, **kwargs)[source]¶
- Overview:
Hook to show log
- Interfaces:
__init__, __call__
- Property:
name, priority, position
- __call__(engine: BaseLearner) None [source]¶
- Overview:
Show the log, and update the record and tb_logger if rank is 0 and the current iteration hits the display interval; clear the log buffer for all learners regardless of rank.
- Arguments:
engine (BaseLearner): The BaseLearner.
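The gating logic above (rank-0-only display, interval check, unconditional buffer clearing) can be sketched as follows. The `MiniLearner` fields and the `freq` parameter are hypothetical stand-ins for the learner attributes and config the real hook reads.

```python
# Illustrative sketch of LogShowHook-style gating (not ding's real code).

class MiniLogShowHook:
    def __init__(self, freq=100):
        self._freq = freq  # display interval, in iterations

    def __call__(self, engine):
        # Only rank 0 emits the log, and only at interval iterations...
        if engine.rank == 0 and engine.last_iter % self._freq == 0:
            engine.shown.append(engine.last_iter)
        # ...but every rank clears its log buffer.
        engine.log_buffer.clear()


class MiniLearner:
    def __init__(self, rank, last_iter):
        self.rank = rank
        self.last_iter = last_iter
        self.log_buffer = {'cur_lr': 1e-3}
        self.shown = []


hook = MiniLogShowHook(freq=100)
rank0 = MiniLearner(rank=0, last_iter=200)
rank1 = MiniLearner(rank=1, last_iter=200)
hook(rank0)
hook(rank1)
```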
LogReduceHook¶
- class ding.worker.learner.learner_hook.LogReduceHook(*args, ext_args: EasyDict = {}, **kwargs)[source]¶
- Overview:
Hook to reduce the distributed (multi-GPU) logs
- Interfaces:
__init__, __call__
- Property:
name, priority, position
- __call__(engine: BaseLearner) None [source]¶
- Overview:
Reduce the logs from distributed (multi-GPU) learners.
- Arguments:
engine (BaseLearner): The BaseLearner.
register_learner_hook¶
- Overview:
Add a new LearnerHook class to hook_mapping, so you can build one instance with build_learner_hook_by_cfg.
- Arguments:
name (str): Name of the hook to register.
hook_type (type): The hook type you implemented, which should be a subclass of LearnerHook.
- Examples:
>>> class HookToRegister(LearnerHook):
>>>     def __init__(self, *args, **kwargs):
>>>         ...
>>>     def __call__(self, *args, **kwargs):
>>>         ...
>>> register_learner_hook('name_of_hook', HookToRegister)
>>> hooks = build_learner_hook_by_cfg(cfg)
build_learner_hook_by_cfg¶
- Overview:
Build the learner hooks in hook_mapping according to config. This function is often used to initialize hooks according to cfg, while add_learner_hook() is often used to add an existing LearnerHook to hooks.
- Arguments:
cfg (EasyDict): Config dict. Should be like {'hook': xxx}.
- Returns:
hooks (Dict[str, List[Hook]]): Keys should be in ['before_run', 'after_run', 'before_iter', 'after_iter']; each value should be a list containing all hooks at that position.
- Note:
Lower value means higher priority.
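The priority rule above (lower value runs first) can be sketched with a minimal, self-contained hook builder. The config shape, `StubHook` class, and `build_hooks_by_cfg` function are hypothetical illustrations, not ding's actual `build_learner_hook_by_cfg`.

```python
# Illustrative sketch: build a {position: [hooks]} dict from a config,
# sorting each position's list so lower priority values run first.

POSITIONS = ['before_run', 'after_run', 'before_iter', 'after_iter']


class StubHook:
    def __init__(self, name, position, priority):
        self.name, self.position, self.priority = name, position, priority


def build_hooks_by_cfg(cfg):
    hooks = {p: [] for p in POSITIONS}
    for item in cfg['hook']:
        h = StubHook(item['name'], item['position'], item['priority'])
        hooks[h.position].append(h)
    for p in POSITIONS:
        hooks[p].sort(key=lambda h: h.priority)  # lower value = higher priority
    return hooks


cfg = {'hook': [
    {'name': 'log_show', 'position': 'after_iter', 'priority': 20},
    {'name': 'load_ckpt', 'position': 'before_run', 'priority': 20},
    {'name': 'log_reduce', 'position': 'after_iter', 'priority': 10},
]}
hooks = build_hooks_by_cfg(cfg)
```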
merge_hooks¶
- Overview:
Merge two hook dicts, which have the same keys; each merged value is sorted by hook priority with a stable sort.
- Arguments:
hooks1 (Dict[str, List[Hook]]): The first hook dict to be merged.
hooks2 (Dict[str, List[Hook]]): The second hook dict to be merged.
- Returns:
new_hooks (Dict[str, List[Hook]]): The new merged hook dict.
- Note:
This merge function uses a stable sort, so hooks with the same priority keep their original relative order.
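The stable-merge behavior can be sketched in a few lines; Python's built-in `sorted` is guaranteed stable, so equal-priority hooks keep their relative order. This is an illustrative sketch, not ding's actual `merge_hooks` (which operates on Hook objects rather than dicts).

```python
# Illustrative sketch: concatenate the two lists per key, then
# stable-sort by priority (sorted() preserves equal-key order).

def merge_hooks(hooks1, hooks2):
    assert set(hooks1) == set(hooks2), "both dicts must have the same keys"
    return {
        k: sorted(hooks1[k] + hooks2[k], key=lambda h: h['priority'])
        for k in hooks1
    }


h1 = {'after_iter': [{'name': 'a', 'priority': 10},
                     {'name': 'b', 'priority': 10}]}
h2 = {'after_iter': [{'name': 'c', 'priority': 5}]}
merged = merge_hooks(h1, h2)
```

Note that `a` stays before `b` in the merged result even though they share priority 10, which is exactly the "without disturbing the same priority hook" guarantee.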
base_learner¶
Please refer to ding/worker/learner/base_learner.py for usage.
BaseLearner¶
- class ding.worker.learner.base_learner.BaseLearner(cfg: EasyDict, policy: namedtuple = None, tb_logger: SummaryWriter | None = None, dist_info: Tuple[int, int] = None, exp_name: str | None = 'default_experiment', instance_name: str | None = 'learner')[source]¶
- Overview:
Base class for policy learning.
- Interface:
train, call_hook, register_hook, save_checkpoint, start, setup_dataloader, close
- Property:
learn_info, priority_info, last_iter, train_iter, rank, world_size, policy monitor, log_buffer, logger, tb_logger, ckpt_name, exp_name, instance_name
- __init__(cfg: EasyDict, policy: namedtuple = None, tb_logger: SummaryWriter | None = None, dist_info: Tuple[int, int] = None, exp_name: str | None = 'default_experiment', instance_name: str | None = 'learner') None [source]¶
- Overview:
Initialization method, build common learner components according to cfg, such as hook, wrapper and so on.
- Arguments:
cfg (EasyDict): Learner config; you can refer to cls.config for details.
policy (namedtuple): A collection of policy functions of learn mode. The policy can also be initialized at runtime.
tb_logger (SummaryWriter): Tensorboard summary writer.
dist_info (Tuple[int, int]): Multi-GPU distributed training information.
exp_name (str): Experiment name, which is used to indicate the output directory.
instance_name (str): Instance name, which should be unique among different learners.
- Notes:
If you want to debug in sync CUDA mode, please add the following code at the beginning of __init__:
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"  # for debugging async CUDA errors
- _setup_hook() None [source]¶
- Overview:
Setup hooks for base_learner. Hooks are the way to implement some functions at specific time points in base_learner. You can refer to learner_hook.py.
- _setup_wrapper() None [source]¶
- Overview:
Use _time_wrapper to get train_time.
- Note:
data_time is wrapped in setup_dataloader.
- call_hook(name: str) None [source]¶
- Overview:
Call the corresponding hook plugins according to position name.
- Arguments:
name (str): The position at which to call the hooks; should be in ['before_run', 'after_run', 'before_iter', 'after_iter'].
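The dispatch described above can be sketched as follows. The `MiniLearner` class, its `register_hook` signature, and the lambda hook are hypothetical stand-ins used only to show the position-keyed lookup; ding's real learner registers LearnerHook objects.

```python
# Illustrative sketch of call_hook-style dispatch (not ding's real code):
# look up the hook list for the given position name and call each hook
# with the learner itself as the "engine" argument.

VALID_POSITIONS = ('before_run', 'after_run', 'before_iter', 'after_iter')


class MiniLearner:
    def __init__(self):
        self._hooks = {p: [] for p in VALID_POSITIONS}
        self.trace = []

    def register_hook(self, position, hook):
        self._hooks[position].append(hook)

    def call_hook(self, name):
        if name not in VALID_POSITIONS:
            raise KeyError(name)
        for hook in self._hooks[name]:
            hook(self)  # each hook receives the learner as "engine"


learner = MiniLearner()
learner.register_hook('before_iter', lambda eng: eng.trace.append('tick'))
learner.call_hook('before_iter')
learner.call_hook('after_iter')  # no hooks registered here: a no-op
```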
- close() None [source]¶
- Overview:
[Only Used In Parallel Mode] Close the related resources, e.g. dataloader, tensorboard logger, etc.
- register_hook(hook: LearnerHook) None [source]¶
- Overview:
Add a new learner hook.
- Arguments:
hook (LearnerHook): The hook to be added.
- save_checkpoint(ckpt_name: str | None = None) None [source]¶
- Overview:
Directly call the save_ckpt_after_run hook to save a checkpoint.
- Note:
Must guarantee that "save_ckpt_after_run" is registered in the "after_run" hooks. This method is called in:
auto_checkpoint (torch_utils/checkpoint_helper.py): designed to save a checkpoint whenever an exception is raised.
serial_pipeline (entry/serial_entry.py): used to save a checkpoint when reaching a new highest episode return.
- setup_dataloader() None [source]¶
- Overview:
[Only Used In Parallel Mode] Setup learner’s dataloader.
- Note:
Only in parallel mode will the attributes get_data and _dataloader be used to fetch data from the file system; in serial mode, data can be fetched from memory directly. In parallel mode, get_data is set by LearnerCommHelper and should be callable. Users don't need to know the related details if not necessary.
- train(data: dict, envstep: int = -1, policy_kwargs: dict | None = None) None [source]¶
- Overview:
Given training data, implement a network update for one iteration and update related variables. This is the learner's API for serial entry; it is also called in start for each iteration's training.
- Arguments:
data (dict): Training data retrieved from the replay buffer.
- Note:
_policy must be set before calling this method. The _policy.forward method contains: forward, backward, grad sync (if in multi-GPU mode) and parameter update. The before_iter and after_iter hooks are called at the beginning and the end.
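The control flow of one training iteration described above can be sketched as follows. The `MiniLearner`, its `call_hook` logging, and the `policy_forward` callable are hypothetical stand-ins; ding's real train() delegates to `_policy.forward` and the registered hook objects.

```python
# Illustrative sketch (not ding's real implementation): call before_iter
# hooks, run one policy update, bump the iteration counter, then call
# after_iter hooks.

class MiniLearner:
    def __init__(self, policy_forward):
        self._policy_forward = policy_forward
        self.train_iter = 0
        self.log = []

    def call_hook(self, position):
        self.log.append(position)  # record hook positions as they fire

    def train(self, data):
        self.call_hook('before_iter')
        self._policy_forward(data)  # forward, backward, parameter update
        self.train_iter += 1
        self.call_hook('after_iter')


updates = []
learner = MiniLearner(lambda batch: updates.append(len(batch)))
learner.train([{'obs': 1}, {'obs': 2}])
```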
create_learner¶
- Overview:
Given the key (learner_name), create a new learner instance if it is in learner_mapping's values; otherwise raise a KeyError. In other words, a derived learner must first be registered, then create_learner can be called to get the instance.
- Arguments:
cfg (EasyDict): Learner config. Necessary keys: [learner.import_module, learner.learner_type].
- Returns:
learner (BaseLearner): The created new learner, which should be an instance of one of learner_mapping's values.
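The register-then-create pattern described above can be sketched with a plain dict registry. The `register_learner` helper, the error message, and the `MyLearner` class are hypothetical illustrations of the mechanism, not ding's actual registry internals.

```python
# Illustrative sketch: a factory that refuses unregistered keys with a
# KeyError, so derived learners must register before create_learner works.

learner_mapping = {}


def register_learner(name, learner_type):
    learner_mapping[name] = learner_type


def create_learner(cfg):
    learner_type = cfg['learner']['learner_type']
    if learner_type not in learner_mapping:
        raise KeyError('unsupported learner type: {}'.format(learner_type))
    return learner_mapping[learner_type](cfg)


class MyLearner:
    def __init__(self, cfg):
        self.cfg = cfg


register_learner('my_learner', MyLearner)
learner = create_learner({'learner': {'learner_type': 'my_learner'}})
```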