Shortcuts

Source code for ding.framework.middleware.functional.collector

from typing import TYPE_CHECKING, Callable, List, Tuple, Any
from functools import reduce
from itertools import accumulate, chain

import treetensor.torch as ttorch
import numpy as np
from ditk import logging

from ding.utils import EasyTimer
from ding.envs import BaseEnvManager
from ding.policy import Policy
from ding.torch_utils import to_ndarray, get_shape0

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext


class TransitionList:
    """
    Overview:
        Per-environment buffer of transitions. Transitions are stored in arrival order for each env, and \
        the position right after every transition whose ``done`` is truthy is recorded, so that the \
        buffer can later be split into finished episodes.
    """

    def __init__(self, env_num: int) -> None:
        self.env_num = env_num
        # One list of transitions per env, indexed by env id.
        self._transitions = [[] for _ in range(env_num)]
        # Per env: exclusive end index (into ``self._transitions[env_id]``) of each finished episode.
        self._done_idx = [[] for _ in range(env_num)]

    def append(self, env_id: int, transition: Any) -> None:
        """
        Overview:
            Append ``transition`` to the buffer of ``env_id``; if ``transition.done`` is truthy, mark the \
            end of an episode right after it.
        Arguments:
            - env_id (:obj:`int`): Index of the env the transition comes from.
            - transition (:obj:`Any`): Transition object; must expose a ``done`` attribute.
        """
        env_transitions = self._transitions[env_id]
        env_transitions.append(transition)
        if transition.done:
            self._done_idx[env_id].append(len(env_transitions))

    def to_trajectories(self) -> Tuple[List[Any], List[int]]:
        """
        Overview:
            Concatenate the buffers of all envs (env 0 first) into one flat list.
        Returns:
            - trajectories (:obj:`List[Any]`): All stored transitions, env by env.
            - trajectory_end_idx (:obj:`List[int]`): For each env, the index in ``trajectories`` of that \
                env's last transition (running total of buffer lengths minus one). An env with an empty \
                buffer repeats the previous value (``-1`` for a leading empty env).
        """
        trajectories = list(chain.from_iterable(self._transitions))
        # Linear-time running total; the previous sum(..., []) / reduce-over-prefixes form was quadratic.
        trajectory_end_idx = [total - 1 for total in accumulate(len(t) for t in self._transitions)]
        return trajectories, trajectory_end_idx

    def to_episodes(self) -> List[List[Any]]:
        """
        Overview:
            Split each env's buffer at the recorded episode ends.
        Returns:
            - episodes (:obj:`List[List[Any]]`): Finished episodes, env by env. Transitions after an \
                env's last ``done`` (an unfinished episode) are not included.
        """
        episodes = []
        for env_transitions, done_indices in zip(self._transitions, self._done_idx):
            start = 0
            for end in done_indices:
                episodes.append(env_transitions[start:end])
                start = end
        return episodes

    def clear(self):
        """
        Overview:
            Empty every env's transition buffer and episode-end markers in place.
        """
        for item in self._transitions:
            item.clear()
        for item in self._done_idx:
            item.clear()


def inferencer(seed: int, policy: Policy, env: BaseEnvManager) -> Callable:
    """
    Overview:
        Build the middleware that runs the policy on the observations the env currently has ready. \
        The env is seeded once, at the moment this middleware is built.
    Arguments:
        - seed (:obj:`int`): Seed handed to ``env.seed``.
        - policy (:obj:`Policy`): Policy whose ``forward`` produces the actions.
        - env (:obj:`BaseEnvManager`): Env manager whose ``ready_obs`` is used as model input; it is \
            launched on demand if it is closed when the middleware runs.
    Returns:
        - _inference (:obj:`Callable`): The middleware function, taking an ``OnlineRLContext``.
    """
    env.seed(seed)

    def _inference(ctx: "OnlineRLContext"):
        """
        Output of ctx:
            - obs: ``ttorch.as_tensor(env.ready_obs)``, stored before the float32 cast.
            - action (:obj:`List[np.ndarray]`): ``to_ndarray`` of the ``'action'`` entry of every value \
                of the policy output, in the output's iteration order.
            - inference_output (:obj:`Dict`): Return value of ``policy.forward``. The policy input is keyed \
                by the positional index (0..N-1) of each observation in ``env.ready_obs``.
        """
        if env.closed:
            env.launch()

        raw_obs = ttorch.as_tensor(env.ready_obs)
        ctx.obs = raw_obs

        # TODO mask necessary rollout
        float_obs = raw_obs.to(dtype=ttorch.float32)
        # One entry per batch position; the positional index doubles as the dict key.
        obs_by_index = {idx: float_obs[idx] for idx in range(get_shape0(float_obs))}  # TBD

        policy_output = policy.forward(obs_by_index, **ctx.collect_kwargs)
        ctx.action = [to_ndarray(out['action']) for out in policy_output.values()]  # TBD
        ctx.inference_output = policy_output

    return _inference
def rolloutor(
        policy: Policy,
        env: BaseEnvManager,
        transitions: TransitionList,
        collect_print_freq=100,
) -> Callable:
    """
    Overview:
        The middleware that executes the transition process in the env.
    Arguments:
        - policy (:obj:`Policy`): The policy to be used during transition.
        - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
            its derivatives are supported.
        - transitions (:obj:`TransitionList`): The transition information which will be filled \
            in this process, including `obs`, `next_obs`, `action`, `logit`, `value`, `reward` \
            and `done`.
        - collect_print_freq (:obj:`int`): Minimum number of training iterations (``ctx.train_iter``) \
            between two collection log outputs.
    Returns:
        - _rollout (:obj:`Callable`): The middleware function, taking an ``OnlineRLContext``.
    """
    # Episode id attached to each env's transitions; fresh ids start after the initial env_num ones.
    env_episode_id = [_ for _ in range(env.env_num)]
    current_id = env.env_num
    timer = EasyTimer()
    last_train_iter = 0
    total_envstep_count = 0
    total_episode_count = 0
    total_train_sample_count = 0
    # Statistics of the episode currently running in each env; reset whenever that episode ends.
    env_info = {env_id: {'time': 0., 'step': 0, 'train_sample': 0} for env_id in range(env.env_num)}
    # Statistics of finished episodes not yet logged (cleared by ``output_log``).
    episode_info = []

    def _rollout(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - action: (:obj:`List[np.ndarray]`): The inferred actions from previous inference process.
            - obs: The observations fed into the transition dict, indexed by the position of each \
                timestep returned by ``env.step``.
            - inference_output (:obj:`Dict[int, Dict]`): The inference results to be fed into the \
                transition dict, indexed the same way as ``obs``.
            - train_iter (:obj:`int`): The train iteration count to be fed into the transition dict.
        Output of ctx:
            - env_step (:obj:`int`): Increased by the number of timesteps returned by ``env.step``.
            - env_episode (:obj:`int`): Increased by 1 for every timestep whose ``done`` is set.
        """
        nonlocal current_id, total_episode_count, total_envstep_count, total_train_sample_count, last_train_iter

        # Time the env interaction itself so it can be attributed to the transitions it produced.
        with timer:
            timesteps = env.step(ctx.action)
        ctx.env_step += len(timesteps)
        timesteps = [t.tensor() for t in timesteps]
        # Spread env.step time evenly over its timesteps; an empty step has nothing to attribute.
        interaction_duration = timer.value / len(timesteps) if timesteps else 0.

        collected_sample = 0
        collected_step = 0
        collected_episode = 0
        for i, timestep in enumerate(timesteps):
            env_id = timestep.env_id.item()
            with timer:
                transition = policy.process_transition(ctx.obs[i], ctx.inference_output[i], timestep)
                transition = ttorch.as_tensor(transition)
                transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
                transition.env_data_id = ttorch.as_tensor([env_episode_id[env_id]])
                transitions.append(env_id, transition)

            sample_num = len(transition.obs)
            collected_step += 1
            collected_sample += sample_num
            stat = env_info[env_id]
            stat['step'] += 1
            stat['train_sample'] += sample_num
            stat['time'] += timer.value + interaction_duration

            if timestep.done:
                episode_info.append(
                    {
                        'reward': timestep.info['eval_episode_return'],
                        'time': stat['time'],
                        'step': stat['step'],
                        'train_sample': stat['train_sample'],
                    }
                )
                # Start the next episode of this env from zero, so stats stay per-episode.
                env_info[env_id] = {'time': 0., 'step': 0, 'train_sample': 0}
                policy.reset([env_id])
                env_episode_id[env_id] = current_id
                collected_episode += 1
                current_id += 1
                ctx.env_episode += 1

        total_envstep_count += collected_step
        total_episode_count += collected_episode
        total_train_sample_count += collected_sample

        if (ctx.train_iter - last_train_iter) >= collect_print_freq and episode_info:
            output_log(episode_info, total_episode_count, total_envstep_count, total_train_sample_count)
            last_train_iter = ctx.train_iter

    return _rollout
def output_log(episode_info, total_episode_count, total_envstep_count, total_train_sample_count) -> None:
    """
    Overview:
        Aggregate the statistics of the finished episodes in ``episode_info``, log them, and clear \
        ``episode_info``. You can refer to the docs of `Best Practice` to understand \
        the training generated logs and tensorboards.
    Arguments:
        - episode_info (:obj:`List[Dict]`): One record per finished episode with keys ``reward``, \
            ``time``, ``step`` and ``train_sample``. ``reward`` may be a tensor / numpy value exposing \
            ``.item()`` or a plain number. The list is cleared in place after logging.
        - total_episode_count (:obj:`int`): Total number of episodes collected so far.
        - total_envstep_count (:obj:`int`): Total number of env steps collected so far.
        - total_train_sample_count (:obj:`int`): Total number of training samples collected so far.
    """
    if not episode_info:
        # Nothing finished since the last log; per-episode averages would divide by zero.
        return

    episode_count = len(episode_info)
    envstep_count = sum(d['step'] for d in episode_info)
    train_sample_count = sum(d['train_sample'] for d in episode_info)
    duration = sum(d['time'] for d in episode_info)
    episode_return = [
        d['reward'].item() if hasattr(d['reward'], 'item') else float(d['reward']) for d in episode_info
    ]

    def _per_sec(count):
        # A non-positive accumulated duration means no usable timing; report the rate as undefined.
        return count / duration if duration > 0 else float('nan')

    info = {
        'episode_count': episode_count,
        'envstep_count': envstep_count,
        'train_sample_count': train_sample_count,
        'avg_envstep_per_episode': envstep_count / episode_count,
        'avg_sample_per_episode': train_sample_count / episode_count,
        'avg_envstep_per_sec': _per_sec(envstep_count),
        'avg_train_sample_per_sec': _per_sec(train_sample_count),
        'avg_episode_per_sec': _per_sec(episode_count),
        'reward_mean': np.mean(episode_return),
        'reward_std': np.std(episode_return),
        'reward_max': np.max(episode_return),
        'reward_min': np.min(episode_return),
        'total_envstep_count': total_envstep_count,
        'total_train_sample_count': total_train_sample_count,
        'total_episode_count': total_episode_count,
    }
    episode_info.clear()
    logging.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))