# Source code for core.data.carla_benchmark_collector

import os
import numpy as np
from collections import deque
from typing import Any, Dict, List, Optional, Union
from itertools import product
import random

from .base_collector import BaseCollector
from core.data.benchmark import ALL_SUITES
from core.data.benchmark.benchmark_utils import get_suites_list, read_pose_txt, get_benchmark_dir
from ding.envs import BaseEnvManager
from ding.torch_utils.data_helper import to_ndarray


class CarlaBenchmarkCollector(BaseCollector):
    """
    Collector to collect Carla benchmark data with envs. It uses several environments in ``EnvManager`` to
    collect data. It will automatically get params to reset environments. For every suite provided by user,
    collector will find all available reset params from benchmark files and store them in a list. When
    collecting data, the collector will collect each suite in average and store the index of each suite,
    to make sure each reset param is collected once and only once. The collected data are stored in a
    trajectory list, with observations, actions and reset param of the episode.

    Note:
        Env manager must run WITHOUT auto reset.

    :Arguments:
        - cfg (Dict): Config dict.
        - env (BaseEnvManager): Env manager used to collect data.
        - policy (Any): Policy used to collect data. Must have ``forward`` method.

    :Interfaces: reset, collect, close

    :Properties:
        - env (BaseEnvManager): Env manager with several environments used to sample data.
        - policy (Any): Policy instance to interact with envs.
    """

    config = dict(
        benchmark_dir=None,
        # suite name, can be str or list
        suite='FullTown01-v0',
        seed=None,
        # whether make seed of each env different
        dynamic_seed=True,
        # manually set weathers rather than read from suite
        weathers=None,
        # whether apply hard failure judgement in suite
        # by default in benchmark, collided will not cause failure
        nocrash=False,
        # whether shuffle env setting in suite
        shuffle=False,
    )

    # A successful episode is kept only if its trajectory holds MORE than this many transitions.
    _MIN_EPISODE_LEN = 50
    # Once an env has failed this many times in a row on one reset param, it moves on to a new one.
    _MAX_FAIL_TIMES = 5

    def __init__(
            self,
            cfg: Dict,
            env: BaseEnvManager,
            policy: Any,
    ) -> None:
        """
        Read collector settings from the config and build the reset params of every suite.

        :Arguments:
            - cfg (Dict): Config dict.
            - env (BaseEnvManager): Env manager used to collect data.
            - policy (Any): Policy used to collect data.
        """
        # NOTE(review): ``self._env_num`` is read below, so the base constructor is expected
        # to assign ``env`` through the property setter of this class -- confirm in BaseCollector.
        super().__init__(cfg, env, policy)
        self._benchmark_dir = self._cfg.benchmark_dir
        suite = self._cfg.suite
        self._seed = self._cfg.seed
        self._dynamic_seed = self._cfg.dynamic_seed
        self._weathers = self._cfg.weathers
        self._shuffle = self._cfg.shuffle
        if self._benchmark_dir is None:
            self._benchmark_dir = get_benchmark_dir()

        self._collect_suite_list = get_suites_list(suite)
        print('[COLLECTOR] Find suites:', self._collect_suite_list)
        self._suite_num = len(self._collect_suite_list)
        # NOTE(review): ``_close_flag`` is never read in this class (``close`` checks
        # ``_end_flag``); it is kept only so the instance attributes stay unchanged.
        self._close_flag = False
        # suite name -> list of reset param dicts
        self._collect_suite_reset_params = dict()
        # suite name -> index of the next reset param to hand out
        self._collect_suite_index_dict = dict()

        self._traj_cache = {env_id: deque() for env_id in range(self._env_num)}
        self._obs_cache = [None for _ in range(self._env_num)]
        self._actions_cache = [None for _ in range(self._env_num)]

        self._generate_suite_reset_params()

    @property
    def env(self) -> BaseEnvManager:
        """Env manager used to sample data."""
        return self._env_manager

    @env.setter
    def env(self, _env_manager: BaseEnvManager) -> None:
        """Store the env manager, launch it and record its number of envs."""
        assert not _env_manager._auto_reset, "auto reset for env manager should be closed!"
        self._end_flag = False
        self._env_manager = _env_manager
        self._env_manager.launch()
        self._env_num = self._env_manager.env_num

    def close(self) -> None:
        """
        Close collector and env manager if not closed.
        """
        if self._end_flag:
            return
        self._collect_suite_reset_params.clear()
        self._collect_suite_index_dict.clear()
        self._env_manager.close()
        self._end_flag = True

    def _generate_suite_reset_params(self):
        """
        Fill the reset param list of every suite in ``self._collect_suite_list``.

        One param dict is built per (weather, start/end pose pair) combination, on top of
        the remaining suite kwargs. Each suite's index is reset to 0.
        """
        for suite in self._collect_suite_list:
            self._collect_suite_reset_params[suite] = list()
            self._collect_suite_index_dict[suite] = 0
            args, kwargs = ALL_SUITES[suite]
            assert len(args) == 0
            reset_params = kwargs.copy()
            poses_txt = reset_params.pop('poses_txt')
            weathers = reset_params.pop('weathers')
            # weathers set in config take precedence over the suite's own list
            if self._weathers is not None:
                weathers = self._weathers
            pose_pairs = read_pose_txt(self._benchmark_dir, poses_txt)
            for weather, (start, end) in product(weathers, pose_pairs):
                param = reset_params.copy()
                param['start'] = start
                param['end'] = end
                param['weather'] = weather
                if self._cfg.nocrash:
                    param['col_is_failure'] = True
                self._collect_suite_reset_params[suite].append(param)
            if self._shuffle:
                random.shuffle(self._collect_suite_reset_params[suite])

    def _next_reset_param(self, suite: str) -> Dict:
        """
        Return the next reset param of ``suite`` and advance its index, wrapping around
        at the end of the suite's param list.
        """
        suite_params = self._collect_suite_reset_params[suite]
        index = self._collect_suite_index_dict[suite]
        reset_param = suite_params[index]
        self._collect_suite_index_dict[suite] = (index + 1) % len(suite_params)
        return reset_param

    def reset(self, suite: Union[List, str] = None) -> None:
        """
        Reset collector and policies. Clear data cache storing data trajectories. If 'suite' is provided
        in arguments, the collector will change its collected suites and generate reset params again.

        :Arguments:
            - suite (Union[List, str], optional): Collected suites after reset. Defaults to None.
        """
        for env_id in range(self._env_num):
            self._traj_cache[env_id].clear()
        self._policy.reset(list(range(self._env_num)))
        self._end_flag = False

        if suite is not None:
            self._collect_suite_reset_params.clear()
            self._collect_suite_index_dict.clear()
            self._collect_suite_list = get_suites_list(suite)
            print('[COLLECTOR] Find suites:', list(self._collect_suite_list))
            self._suite_num = len(self._collect_suite_list)
            self._generate_suite_reset_params()

    def collect(
            self,
            n_episode: int,
            policy_kwargs: Optional[Dict] = None,
    ) -> List:
        """
        Collect data from policy and env manager. It will collect each benchmark suite in average
        according to 'n_episode'.

        An episode is stored only if ``timestep.info['success']`` is truthy and its trajectory
        has more than ``_MIN_EPISODE_LEN`` transitions. Otherwise the env is reset with the same
        param, and after ``_MAX_FAIL_TIMES`` failures in a row it gets a new param instead.

        :Arguments:
            - n_episode (int): Num of episodes to collect.
            - policy_kwargs (Dict, optional): Additional arguments in policy forward. Defaults to None.

        :Returns:
            List: List of collected data. Each elem is a dict with keys ``'env_param'`` (reset
            param of the episode) and ``'data'`` (list of transitions of the episode).
        """
        if policy_kwargs is None:
            policy_kwargs = dict()

        assert len(self._collect_suite_list) > 0, self._collect_suite_list

        if n_episode < self._env_num:
            print("[WARNING] Number of envs larger than number of episodes. May waste resource")

        for env_id in range(self._env_num):
            self._traj_cache[env_id].clear()
        self._policy.reset(list(range(self._env_num)))

        # Hand out initial reset params round-robin over the suites, one per env, until
        # either every env or every requested episode has one.
        running_env_params = dict()
        max_running = min(self._env_num, n_episode)
        while len(running_env_params) < max_running:
            for suite in self._collect_suite_list:
                if len(running_env_params) >= max_running:
                    break
                env_id = len(running_env_params)
                running_env_params[env_id] = self._next_reset_param(suite)
        running_envs = len(running_env_params)

        if self._seed is not None:
            # dynamic seed: different seed for each env
            # NOTE(review): how a scalar seed is spread over envs is up to the env manager.
            if self._dynamic_seed:
                self._env_manager.seed(self._seed)
            else:
                for env_id in running_env_params:
                    self._env_manager.seed({env_id: self._seed})
        self._env_manager.reset(running_env_params)

        return_data = []
        env_fail_times = {env_id: 0 for env_id in running_env_params}
        # Despite its name this starts at (episodes already launched - 1): it is bumped each
        # time an episode is stored and used to decide whether another one must be launched.
        collected_episodes = running_envs - 1
        collected_samples = 0

        with self._timer:
            while True:
                obs = self._env_manager.ready_obs
                # only keep observations of envs that were given a reset param
                for env_id in list(obs.keys()):
                    if env_id not in running_env_params:
                        obs.pop(env_id)
                if len(obs) == 0:
                    break
                policy_output = self._policy.forward(obs, **policy_kwargs)
                actions = {env_id: output['action'] for env_id, output in policy_output.items()}
                actions = to_ndarray(actions)
                for env_id in actions:
                    self._obs_cache[env_id] = obs[env_id]
                    self._actions_cache[env_id] = actions[env_id]
                timesteps = self._env_manager.step(actions)
                for env_id, timestep in timesteps.items():
                    if timestep.info.get('abnormal', False):
                        # If there is an abnormal timestep, reset all the related variables
                        # (including this env). The policy is reset by env id, like every
                        # other ``policy.reset`` call in this class.
                        self._traj_cache[env_id].clear()
                        self._policy.reset([env_id])
                        self._env_manager.reset(reset_param={env_id: running_env_params[env_id]})
                        print('[COLLECTOR] env_id abnormal step', env_id, timestep.info)
                        continue
                    transition = self._policy.process_transition(
                        self._obs_cache[env_id], self._actions_cache[env_id], timestep
                    )
                    self._traj_cache[env_id].append(transition)
                    if not timestep.done:
                        continue

                    trajectory = self._traj_cache[env_id]
                    if timestep.info['success'] and len(trajectory) > self._MIN_EPISODE_LEN:
                        env_fail_times[env_id] = 0
                        episode_data = {'env_param': running_env_params[env_id], 'data': list(trajectory)}
                        return_data.append(episode_data)
                        collected_samples += len(trajectory)
                        collected_episodes += 1
                        if collected_episodes < n_episode:
                            next_suite = self._collect_suite_list[collected_episodes % self._suite_num]
                            reset_param = self._next_reset_param(next_suite)
                            running_env_params[env_id] = reset_param
                            self._env_manager.reset({env_id: reset_param})
                    else:
                        env_fail_times[env_id] += 1
                        # Log a trimmed copy of the info so ``timestep.info`` itself is untouched.
                        info = {
                            k: v
                            for k, v in timestep.info.items() if 'reward' not in k and k not in ['timestamp']
                        }
                        print('[COLLECTOR] env_id {} not success'.format(env_id), info)
                        if env_fail_times[env_id] < self._MAX_FAIL_TIMES:
                            # not reach max fail times, continue reset param
                            reset_param = running_env_params[env_id]
                        else:
                            # reach max fail times, skip to next reset param
                            env_fail_times[env_id] = 0
                            next_suite = self._collect_suite_list[collected_episodes % self._suite_num]
                            reset_param = self._next_reset_param(next_suite)
                            running_env_params[env_id] = reset_param
                        self._env_manager.reset({env_id: reset_param})
                    # the episode is over either way: drop its cached trajectory and policy state
                    self._traj_cache[env_id].clear()
                    self._policy.reset([env_id])
                if self._env_manager.done:
                    break

        duration = self._timer.value
        print("[COLLECTOR] Finish collection, time cost: {:.2f}s, total frames: {}".format(duration, collected_samples))
        return return_data