Source code for ding.envs.env.base_env
from abc import ABC, abstractmethod
from typing import Any, List, Tuple
import gym
import copy
from easydict import EasyDict
from collections import namedtuple
from ding.utils import import_module, ENV_REGISTRY
# Record type bundling the four values produced by one env step: obs, reward, done, info.
BaseEnvTimestep = namedtuple(
    typename='BaseEnvTimestep',
    field_names=('obs', 'reward', 'done', 'info'),
)
# ``gym.Env`` and ``ABC`` may each bring their own metaclass; a class inheriting from both
# needs a metaclass derived from both, otherwise Python reports a metaclass conflict.
class FinalMeta(type(ABC), type(gym.Env)):
    """
    Overview:
        Combined metaclass built from the metaclass of ``ABC`` and the metaclass of ``gym.Env``, \
        used by ``BaseEnv`` so that it can inherit from both at once.
    """
class BaseEnv(gym.Env, ABC, metaclass=FinalMeta):
    """
    Overview:
        Basic environment class, extended from ``gym.Env``
    Interface:
        ``__init__``, ``reset``, ``close``, ``step``, ``seed``, ``random_action``, ``create_collector_env_cfg``, \
        ``create_evaluator_env_cfg``, ``enable_save_replay``
    """

    @abstractmethod
    def __init__(self, cfg: dict) -> None:
        """
        Overview:
            Lazy init, only related arguments will be initialized in ``__init__`` method, and the concrete \
            env will be initialized the first time ``reset`` method is called.
        Arguments:
            - cfg (:obj:`dict`): Environment configuration in dict type.
        """
        raise NotImplementedError

    @abstractmethod
    def reset(self) -> Any:
        """
        Overview:
            Reset the env to an initial state and returns an initial observation.
        Returns:
            - obs (:obj:`Any`): Initial observation after reset.
        """
        raise NotImplementedError

    @abstractmethod
    def close(self) -> None:
        """
        Overview:
            Close env and all the related resources, it should be called after the usage of env instance.
        """
        raise NotImplementedError

    @abstractmethod
    def step(self, action: Any) -> 'BaseEnv.timestep':
        """
        Overview:
            Run one timestep of the environment's dynamics/simulation.
        Arguments:
            - action (:obj:`Any`): The ``action`` input to step with.
        Returns:
            - timestep (:obj:`BaseEnv.timestep`): The result timestep of env executing one step \
                (presumably a ``BaseEnvTimestep`` namedtuple of obs, reward, done, info; confirm in subclasses).
        """
        raise NotImplementedError

    @abstractmethod
    def seed(self, seed: int) -> None:
        """
        Overview:
            Set the seed for this env's random number generator(s).
        Arguments:
            - seed (:obj:`int`): Random seed.
        """
        raise NotImplementedError

    @abstractmethod
    def __repr__(self) -> str:
        """
        Overview:
            Return the information string of this env instance.
        Returns:
            - info (:obj:`str`): Information of this env instance, like type and arguments.
        """
        raise NotImplementedError

    @staticmethod
    def create_collector_env_cfg(cfg: dict) -> List[dict]:
        """
        Overview:
            Return a list of all of the environment from input config, used in env manager \
            (a series of vectorized env), and this method is mainly responsible for envs collecting data.
        Arguments:
            - cfg (:obj:`dict`): Original input env config, which needs to be transformed into the type of creating \
                env instance actually and generated the corresponding number of configurations. \
                The key ``collector_env_num`` is removed from ``cfg`` in place.
        Returns:
            - env_cfg_list (:obj:`List[dict]`): List of ``cfg`` including all the config collector envs, \
                one independent deep copy per env.
        Raises:
            - KeyError: If ``cfg`` does not contain ``collector_env_num``.

        .. note::
            Elements(env config) in collector_env_cfg/evaluator_env_cfg can be different, such as server ip and port.
        """
        collector_env_num = cfg.pop('collector_env_num')
        # Each env gets its own copy: returning the same dict N times would make a per-env
        # customisation (e.g. a different port) silently apply to every env and to the caller's cfg.
        return [copy.deepcopy(cfg) for _ in range(collector_env_num)]

    @staticmethod
    def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
        """
        Overview:
            Return a list of all of the environment from input config, used in env manager \
            (a series of vectorized env), and this method is mainly responsible for envs evaluating performance.
        Arguments:
            - cfg (:obj:`dict`): Original input env config, which needs to be transformed into the type of creating \
                env instance actually and generated the corresponding number of configurations. \
                The key ``evaluator_env_num`` is removed from ``cfg`` in place.
        Returns:
            - env_cfg_list (:obj:`List[dict]`): List of ``cfg`` including all the config evaluator envs, \
                one independent deep copy per env.
        Raises:
            - KeyError: If ``cfg`` does not contain ``evaluator_env_num``.
        """
        evaluator_env_num = cfg.pop('evaluator_env_num')
        # Independent copies for the same reason as in ``create_collector_env_cfg``: no aliasing
        # between the per-env configs or with the caller's ``cfg``.
        return [copy.deepcopy(cfg) for _ in range(evaluator_env_num)]

    # optional method
    def enable_save_replay(self, replay_path: str) -> None:
        """
        Overview:
            Save replay file in the given path, and this method need to be self-implemented by each env class.
        Arguments:
            - replay_path (:obj:`str`): The path to save replay file.
        Raises:
            - NotImplementedError: Always, in this base class; env classes supporting replay override it.
        """
        raise NotImplementedError

    # optional method
    def random_action(self) -> Any:
        """
        Overview:
            Return random action generated from the original action space, usually it is convenient for test.
            The base implementation does nothing and returns ``None``.
        Returns:
            - random_action (:obj:`Any`): Action generated randomly.
        """
        return None
def get_vec_env_setting(cfg: dict, collect: bool = True, eval_: bool = True) -> Tuple[type, List[dict], List[dict]]:
    """
    Overview:
        Get vectorized env setting (env_fn, collector_env_cfg, evaluator_env_cfg).
    Arguments:
        - cfg (:obj:`dict`): Original input env config in user config, such as ``cfg.env``. It is read with \
            both ``cfg.get(...)`` and ``cfg.type``, so it must support attribute access.
        - collect (:obj:`bool`): Whether to build the collector env config list.
        - eval_ (:obj:`bool`): Whether to build the evaluator env config list.
    Returns:
        - env_fn (:obj:`type`): Callable object, call it with proper arguments and then get a new env instance.
        - collector_env_cfg (:obj:`List[dict]`): A list contains the config of collecting data envs, \
            or ``None`` when ``collect`` is False.
        - evaluator_env_cfg (:obj:`List[dict]`): A list contains the config of evaluation envs, \
            or ``None`` when ``eval_`` is False.

    .. note::
        Elements (env config) in collector_env_cfg/evaluator_env_cfg can be different, such as server ip and port.
    """
    # Import the user-listed modules first (presumably so the env class gets registered; confirm).
    import_module(cfg.get('import_names', []))
    env_fn = ENV_REGISTRY.get(cfg.type)

    # Collector configs are built before evaluator configs; the same ``cfg`` object is handed to both calls.
    collector_env_cfg = None
    if collect:
        collector_env_cfg = env_fn.create_collector_env_cfg(cfg)

    evaluator_env_cfg = None
    if eval_:
        evaluator_env_cfg = env_fn.create_evaluator_env_cfg(cfg)

    return env_fn, collector_env_cfg, evaluator_env_cfg
def get_env_cls(cfg: EasyDict) -> type:
    """
    Overview:
        Get the env class by the corresponding module of ``cfg`` and return the callable class.
    Arguments:
        - cfg (:obj:`dict`): Original input env config in user config, such as ``cfg.env``.
    Returns:
        - env_cls_type (:obj:`type`): Env module as the corresponding callable class type.
    """
    # ``import_names`` is optional; an empty list is passed to ``import_module`` when it is absent.
    module_names = cfg.get('import_names', [])
    import_module(module_names)
    env_cls = ENV_REGISTRY.get(cfg.type)
    return env_cls
def create_model_env(cfg: EasyDict) -> Any:
    """
    Overview:
        Create model env, which is used in model-based RL.
    Arguments:
        - cfg (:obj:`EasyDict`): Env config. ``type`` selects the env class and the optional ``import_names`` \
            lists modules to import first; every remaining key is passed to the env class as a keyword argument.
    Returns:
        - model_env (:obj:`Any`): The object returned by calling the env class with the remaining config items.
    """
    # Work on a private copy so the caller's config keeps its ``type`` / ``import_names`` entries.
    cfg = copy.deepcopy(cfg)
    model_env_fn = get_env_cls(cfg)
    # ``import_names`` is optional for ``get_env_cls``, so it must be optional here as well.
    # An explicit membership test is used instead of ``pop`` with a default so the result does not
    # depend on how the mapping's ``pop`` override treats a missing key.
    if 'import_names' in cfg:
        cfg.pop('import_names')
    # ``type`` is guaranteed to exist here: ``get_env_cls`` has already read ``cfg.type``.
    cfg.pop('type')
    return model_env_fn(**cfg)