How to migrate your own environment to DI-engine¶
DI-zoo provides users with a large number of commonly used reinforcement learning environments (see supported environments), but in many research and engineering scenarios users still need to implement an environment themselves and want to migrate it quickly to DI-engine so that it meets DI-engine's specifications. Therefore, this section introduces, step by step, how to perform such a migration so that the environment conforms to DI-engine's environment base class BaseEnv and can be used directly in the training pipeline.
The following introduction is divided into Basic and Advanced parts. Basic describes the functions that must be implemented and the details you should pay attention to; Advanced describes some optional, extended features.
Then DingEnvWrapper will be introduced: it is a tool that can quickly convert simple environments such as ClassicControl, Box2d, Atari, Mujoco and GymHybrid into environments that conform to BaseEnv. A Q & A is given at the end.
Basic¶
This section describes the specification constraints that users MUST meet, and the features that must be implemented when migrating environments.
If you want to use an environment in DI-engine, you need to implement a subclass of BaseEnv, e.g. YourEnv. The relationship between YourEnv and your own environment is composition: a YourEnv instance holds an instance of the user's original environment (e.g. a gym-style environment) inside it.
Reinforcement learning environments share some common interfaces that are implemented by most environments, such as reset(), step(), seed(), etc. In DI-engine, BaseEnv further encapsulates these interfaces. In most cases, Atari is used as the example; for the concrete code, please refer to Atari Env and Atari Env Wrapper.
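To make the composition relationship concrete, here is a minimal, hypothetical skeleton of such a subclass (the gym environment id 'YourGymEnv-v0' and the method bodies are placeholders; each method is explained step by step below, and depending on your DI-engine version the three space properties from the Advanced section may also be required):

```python
import gym
import numpy as np
from ding.envs import BaseEnv, BaseEnvTimestep
from ding.torch_utils import to_ndarray


class YourEnv(BaseEnv):

    def __init__(self, cfg: dict) -> None:
        self._cfg = cfg
        self._init_flag = False  # Lazy Init: the original env is created on the first reset

    def reset(self) -> np.ndarray:
        if not self._init_flag:
            self._env = gym.make('YourGymEnv-v0')  # composition: hold the original env inside
            self._init_flag = True
        obs = self._env.reset()
        self._eval_episode_return = 0.
        return to_ndarray(obs)

    def step(self, action: np.ndarray) -> BaseEnvTimestep:
        obs, rew, done, info = self._env.step(action)
        self._eval_episode_return += rew
        if done:
            info['eval_episode_return'] = self._eval_episode_return
        return BaseEnvTimestep(to_ndarray(obs), to_ndarray([rew]), done, info)

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def close(self) -> None:
        if self._init_flag:
            self._env.close()
        self._init_flag = False

    # observation_space / action_space / reward_space properties are also recommended,
    # see the "Three space attributes" part in the Advanced section below.
```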
__init__()

In general, an environment may be instantiated directly in the __init__ method. However, in DI-engine, in order to support "environment vectorization" modules such as EnvManager, environments generally use the Lazy Init mechanism: __init__ does not create the actual original environment instance, but only stores the relevant configuration. The actual environment is initialized the first time the reset method is called.

Take Atari as an example. __init__ does not instantiate the environment; it only stores the configuration in self._cfg and initializes self._init_flag to False (indicating that the environment has not been instantiated yet).

```python
class AtariEnv(BaseEnv):

    def __init__(self, cfg: dict) -> None:
        self._cfg = cfg
        self._init_flag = False
```
seed()

seed is used to set the random seeds of the environment. There are two kinds of seeds to set: the seed of the original environment, and the seeds of the random libraries (e.g. random, np.random) used in the various environment transformations.

For the second kind, setting the seeds of the random libraries is straightforward and is done directly inside the environment's seed method.

For the first kind, the seed of the original environment is only stored in the seed method and is not actually applied there; it is actually applied inside the environment's reset method, right before the original environment's reset is called.

```python
class AtariEnv(BaseEnv):

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)
```
For the seeds of the original environment, DI-engine has the concepts of static seeds and dynamic seeds.
A static seed is used for the evaluation environments (evaluator_env) to ensure that the random seed of every episode is the same, i.e. only the fixed value self._seed is used at every reset. You need to manually pass dynamic_seed=False to the seed method.

A dynamic seed is used for the collecting environments (collector_env), so that the random seed of each episode is usually different: at every reset a random number 100 * np.random.randint(1, 1000) is generated and added to self._seed (the random number generator itself is seeded by the environment's seed method, so the reproducibility of the experiment is preserved). You need to pass dynamic_seed=True to seed (or simply omit it, since True is the default).
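For example, a small usage sketch (the seed value 0 is arbitrary, and cfg is assumed to be a prepared config):

```python
evaluator_env = AtariEnv(cfg)
collector_env = AtariEnv(cfg)

evaluator_env.seed(0, dynamic_seed=False)  # static seed: every episode uses exactly seed 0
collector_env.seed(0, dynamic_seed=True)   # dynamic seed: each reset adds a reproducible random offset
```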
reset()

The Lazy Init mechanism has already been introduced in the __init__ part: the actual environment initialization happens the first time reset is called.

The reset method checks self._init_flag to decide whether the actual environment still needs to be instantiated (if it is False, instantiate it; otherwise it already exists and can be used directly), sets the random seed, then calls the original environment's reset to obtain the observation obs of the initial state and converts it to the np.ndarray data format (explained in detail in the Data Specifications part), and initializes self._eval_episode_return (explained in detail in the self._eval_episode_return part). In Atari, self._eval_episode_return is the cumulative sum of the real rewards obtained over a whole episode; it is used to evaluate the agent's performance on this environment and is not used for training.

```python
class AtariEnv(BaseEnv):

    def __init__(self, cfg: dict) -> None:
        self._cfg = cfg
        self._init_flag = False

    def reset(self) -> np.ndarray:
        if not self._init_flag:
            self._env = self._make_env(only_info=False)
            self._init_flag = True
        if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
            np_seed = 100 * np.random.randint(1, 1000)
            self._env.seed(self._seed + np_seed)
        elif hasattr(self, '_seed'):
            self._env.seed(self._seed)
        obs = self._env.reset()
        obs = to_ndarray(obs)
        self._eval_episode_return = 0.
        return obs
```
step()

The step method receives the action of the current time step and returns the reward of the current time step and the obs of the next time step. In DI-engine it also needs to return: the flag done indicating whether the current episode has ended (done must be of type bool, not np.bool), and a dictionary info carrying other information (which must contain the key eval_episode_return at the end of an episode).

After obtaining reward, obs, done and info, they must be processed and converted into the np.ndarray format to conform to the DI-engine specification. self._eval_episode_return accumulates the real reward obtained at every time step, and the accumulated value is returned through info at the end of the episode (done == True).

Finally, the above four values are packed into BaseEnvTimestep, a namedtuple defined as BaseEnvTimestep = namedtuple('BaseEnvTimestep', ['obs', 'reward', 'done', 'info']), and returned.

```python
from ding.envs import BaseEnvTimestep


class AtariEnv(BaseEnv):

    def step(self, action: np.ndarray) -> BaseEnvTimestep:
        assert isinstance(action, np.ndarray), type(action)
        action = action.item()
        obs, rew, done, info = self._env.step(action)
        self._eval_episode_return += rew
        obs = to_ndarray(obs)
        rew = to_ndarray([rew])  # transformed to an array with shape (1, )
        if done:
            info['eval_episode_return'] = self._eval_episode_return
        return BaseEnvTimestep(obs, rew, done, info)
```
self._eval_episode_return

In the Atari environment, self._eval_episode_return is the cumulative sum of all real rewards of one episode, and its data type must be a native python type, not np.ndarray:

- In the reset method, initialize self._eval_episode_return to 0.
- In the step method, add the real reward obtained at each time step to self._eval_episode_return.
- In the step method, if the current episode has ended (done == True), add it to the returned info dictionary: info['eval_episode_return'] = self._eval_episode_return

However, other environments may not want the reward sum as self._eval_episode_return. For example, SMAC uses the win rate of the current episode, so instead of the simple accumulation in the step method shown above, the game situation should be recorded and the computed win rate returned at the end of the episode, as sketched below.
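The following is a purely illustrative sketch of such a variant (the info key battle_won and the surrounding method bodies are hypothetical, not the actual SMAC env code):

```python
def reset(self) -> np.ndarray:
    # ...
    self._eval_episode_return = 0.  # will hold a win/loss flag instead of a reward sum
    # ...

def step(self, action: np.ndarray) -> BaseEnvTimestep:
    obs, rew, done, info = self._env.step(action)
    if done:
        # report 1. for a won episode and 0. otherwise, instead of the accumulated reward
        self._eval_episode_return = 1. if info.get('battle_won', False) else 0.
        info['eval_episode_return'] = self._eval_episode_return
    return BaseEnvTimestep(to_ndarray(obs), to_ndarray([rew]), done, info)
```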
Data Specifications

DI-engine requires that the input and output data of every environment method be in np.ndarray format, and that the dtype be np.int64 (integer), np.float32 (float) or np.uint8 (image). This includes the following (a conversion sketch follows this list):

- obs returned by the reset method
- action received by the step method
- obs returned by the step method
- reward returned by the step method; reward must also be one-dimensional rather than zero-dimensional, e.g. Atari expands it with rew = to_ndarray([rew])
- done returned by the step method, which must be of type bool, not np.bool
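A small sketch of the typical conversions (the raw_* values are hypothetical stand-ins for what an original gym-style environment returns):

```python
import numpy as np
from ding.torch_utils import to_ndarray

# hypothetical raw values returned by an original gym-style environment
raw_obs, raw_action, raw_rew, raw_done = [0.1, 0.2], 3, 1.0, False

obs = to_ndarray(raw_obs, dtype=np.float32)        # float observation -> np.float32 array
action = to_ndarray([raw_action], dtype=np.int64)  # discrete action -> shape (1, ) np.int64 array
rew = to_ndarray([raw_rew], dtype=np.float32)      # reward -> shape (1, ) np.float32 array, never shape ()
done = bool(raw_done)                              # plain Python bool, not np.bool
```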
Advanced¶
Environment preprocessing wrapper
Many environments used in reinforcement learning training require some preprocessing to increase randomness, normalize the data, and make training easier. Such preprocessing is implemented in the form of wrappers (for an introduction to wrappers, please refer to here).
Each environment preprocessing wrapper is a subclass of gym.Wrapper. For example, NoopResetEnv performs a random number of No-Operation actions at the beginning of an episode; it is one way to increase randomness. It is used as follows:

```python
env = gym.make('Pong-v4')
env = NoopResetEnv(env)
```
Since NoopResetEnv implements the reset method, the corresponding logic in NoopResetEnv is executed when env.reset() is called.

The following env wrappers have already been implemented in DI-engine (in ding/envs/env_wrappers/env_wrappers.py):

- NoopResetEnv: perform a random number of No-Operation actions at the beginning of an episode
- MaxAndSkipEnv: return the maximum value over several consecutive frames, which can be regarded as a kind of max pooling over time steps
- WarpFrame: convert the color code of the original image with cvtColor of the cv2 library and resize it to a fixed size (usually 84x84)
- ScaledFloatFrame: normalize the observation to the interval [0, 1] (keeping the dtype np.float32)
- ClipRewardEnv: pass the reward through a sign function, mapping it to {+1, 0, -1}
- FrameStack: stack a certain number (usually 4) of frames together as the new observation, which helps in POMDP situations, e.g. a single frame cannot reveal the direction of motion
- ObsTransposeWrapper: transpose the observation to put the channel in the first dimension
- ObsNormEnv: use RunningMeanStd to normalize the observation with a sliding window
- RewardNormEnv: use RunningMeanStd to normalize the reward with a sliding window
- RamWrapper: wrap a RAM env into an image-like env
- EpisodicLifeEnv: for environments with multiple built-in lives (e.g. Qbert), treat each life as an episode
- FireResetEnv: execute action 1 (fire) immediately after the environment resets
- GymHybridDictActionWrapper: transform Gym-Hybrid's original gym.spaces.Tuple action space into gym.spaces.Dict
If the above wrappers cannot meet your needs, you can also customize the wrappers yourself.
It is worth mentioning that each wrapper must not only transform the corresponding observation/action/reward values, but also modify the corresponding space whenever the shape, dtype, etc. are changed. How to do this is described in detail in the next part; a custom wrapper combining both points is sketched below.
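For example, a hypothetical custom wrapper that casts observations to np.float32 and updates observation_space to match (a sketch only, assuming the wrapped observation space is a gym.spaces.Box):

```python
import gym
import numpy as np


class FloatObsWrapper(gym.ObservationWrapper):
    """Hypothetical wrapper: cast observations to np.float32 and keep the space consistent."""

    def __init__(self, env):
        super().__init__(env)
        old_space = env.observation_space
        # the dtype of the observation changes, so observation_space must be updated too
        self.observation_space = gym.spaces.Box(
            low=old_space.low.astype(np.float32),
            high=old_space.high.astype(np.float32),
            shape=old_space.shape,
            dtype=np.float32
        )

    def observation(self, observation):
        return observation.astype(np.float32)
```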
Three space attributes
observation/action/reward space

If you want to automatically create a neural network according to the dimensions of the environment, or to use the shared_memory technique in the EnvManager to speed up the transmission of the large tensors returned by the environment, the environment needs to provide the attributes observation_space, action_space and reward_space.

Note
For the sake of code extensibility, we strongly recommend implementing these three space attributes.
These spaces are all instances of subclasses of gym.spaces.Space; the most commonly used ones are Discrete, Box, Tuple, Dict, etc. The shape and dtype must be specified in each space. Most original gym environments already provide observation_space, action_space and reward_range; in DI-engine, reward_range is expanded into a reward_space as well, so that all three attributes are handled consistently.

For example, here are the three attributes of CartPole:
```python
class CartpoleEnv(BaseEnv):

    def __init__(self, cfg: dict = {}) -> None:
        self._observation_space = gym.spaces.Box(
            low=np.array([-4.8, float("-inf"), -0.42, float("-inf")]),
            high=np.array([4.8, float("inf"), 0.42, float("inf")]),
            shape=(4, ),
            dtype=np.float32
        )
        self._action_space = gym.spaces.Discrete(2)
        self._reward_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1, ), dtype=np.float32)

    @property
    def observation_space(self) -> gym.spaces.Space:
        return self._observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        return self._action_space

    @property
    def reward_space(self) -> gym.spaces.Space:
        return self._reward_space
```
Since CartPole does not use any wrapper, its three spaces are fixed. However, for an environment like Atari that is decorated by multiple wrappers, the corresponding space must be updated each time a wrapper wraps the environment. For example, Atari uses ScaledFloatFrameWrapper to normalize the observation to the interval [0, 1], so the wrapper modifies observation_space accordingly:

```python
class ScaledFloatFrameWrapper(gym.ObservationWrapper):

    def __init__(self, env):
        # ...
        self.observation_space = gym.spaces.Box(
            low=0., high=1., shape=env.observation_space.shape, dtype=np.float32
        )
```
enable_save_replay()

DI-engine does not require the render method to be implemented. If you want visualization, we recommend implementing the enable_save_replay method, which saves episodes as videos.

This method is called after the seed method and before the reset method, and it specifies the path where the replay videos are stored. Note that this method does not store the videos itself; it only sets a flag indicating whether to save them. The code and logic that actually stores the videos must be implemented by yourself (since multiple environments may run in parallel and each environment may run multiple episodes, they need to be distinguished in the file names).

Here is an example from DI-engine: the reset method uses the wrapper provided by gym to decorate the environment, giving it the ability to record videos, as shown in the code:

```python
class AtariEnv(BaseEnv):

    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
        if replay_path is None:
            replay_path = './video'
        self._replay_path = replay_path

    def reset(self):
        # ...
        if self._replay_path is not None:
            self._env = gym.wrappers.RecordVideo(
                self._env,
                video_folder=self._replay_path,
                episode_trigger=lambda episode_id: True,
                name_prefix='rl-video-{}'.format(id(self))
            )
        # ...
```
In actual use, the order of calling these methods should be:
```python
atari_env = AtariEnv(easydict_cfg)
atari_env.seed(413)
atari_env.enable_save_replay('./replay_video')
obs = atari_env.reset()
# ...
```
Use different config for training environment and test environment
The environments used for collecting data (collector_env) and the environments used for evaluation (evaluator_env) may need different configurations. You can implement static methods in the environment class that generate the customized configuration for each case. Take Atari as an example:
```python
class AtariEnv(BaseEnv):

    @staticmethod
    def create_collector_env_cfg(cfg: dict) -> List[dict]:
        collector_env_num = cfg.pop('collector_env_num')
        cfg = copy.deepcopy(cfg)
        cfg.is_train = True
        return [cfg for _ in range(collector_env_num)]

    @staticmethod
    def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
        evaluator_env_num = cfg.pop('evaluator_env_num')
        cfg = copy.deepcopy(cfg)
        cfg.is_train = False
        return [cfg for _ in range(evaluator_env_num)]
```
In actual use, the original configuration cfg can be converted into the two versions used for collecting and for evaluation:

```python
# env_fn is an env class
collector_env_cfg = env_fn.create_collector_env_cfg(cfg)
evaluator_env_cfg = env_fn.create_evaluator_env_cfg(cfg)
```
The cfg.is_train field then selects different decorations inside the wrappers. For example, if cfg.is_train == True, the reward is passed through a sign function and mapped to {+1, 0, -1} to facilitate training; if cfg.is_train == False, the original reward value is kept, which is convenient for evaluating the performance of the agent. A sketch of such a switch is shown below.
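A hypothetical sketch of how is_train can switch the wrappers (wrap_with_cfg is an illustrative helper, not actual DI-engine code; the wrapper import path is assumed from the file location given above):

```python
from ding.envs.env_wrappers import WarpFrame, ClipRewardEnv


def wrap_with_cfg(env, cfg):
    # illustrative helper: the real Atari wrapping logic in DI-engine is more elaborate
    env = WarpFrame(env)
    if cfg.is_train:
        # map the reward to {+1, 0, -1} only when collecting training data
        env = ClipRewardEnv(env)
    return env
```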
random_action()

Some off-policy algorithms want to collect some data with a random policy before training starts, in order to fill and initialize the replay buffer. For such a need, DI-engine encourages implementing the random_action method.

Since the environment already implements action_space, you can directly call the Space.sample() method provided by gym to sample a random action. Note, however, that since DI-engine requires all returned actions to be in np.ndarray format, some dtype conversions may be needed: int and dict values are converted to np.ndarray with the to_ndarray function, as shown in the following code:

```python
def random_action(self) -> np.ndarray:
    random_action = self.action_space.sample()
    if isinstance(random_action, np.ndarray):
        pass
    elif isinstance(random_action, int):
        random_action = to_ndarray([random_action], dtype=np.int64)
    elif isinstance(random_action, dict):
        random_action = to_ndarray(random_action)
    else:
        raise TypeError(
            '`random_action` should be either int/np.ndarray or dict of int/np.ndarray, but get {}: {}'.format(
                type(random_action), random_action
            )
        )
    return random_action
```
default_config()

If an environment has some default or commonly used configuration fields, you can consider setting the class variable config as the default config (for convenient external access, you can also implement the class method default_config, which returns this config).

When running an experiment, a user config file is written for that experiment, e.g. dizoo/mujoco/config/ant_ddpg_config.py. In the user config file, the key-value pairs covered by the default config can be omitted; the default config and the user config are merged with deep_merge_dicts (make sure the default config is the first argument and the user config the second, so that the user config has the higher priority). As shown in the following code:

```python
class MujocoEnv(BaseEnv):

    config = dict(
        use_act_scale=False,
        delay_reward_step=0,
    )

    @classmethod
    def default_config(cls: type) -> EasyDict:
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    def __init__(self, cfg) -> None:
        self._cfg = deep_merge_dicts(self.config, cfg)
```
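A small usage sketch of the merge, against the simplified MujocoEnv shown above (the user config value is hypothetical):

```python
from easydict import EasyDict
from ding.utils import deep_merge_dicts

# hypothetical user config: only overrides the fields that differ from the defaults
user_cfg = EasyDict(dict(use_act_scale=True))
cfg = deep_merge_dicts(MujocoEnv.default_config(), user_cfg)  # default first, user second
env = MujocoEnv(cfg)
print(cfg.use_act_scale, cfg.delay_reward_step)  # True 0
```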
Environment implementation correctness check
We provide a set of checking tools for user-implemented environments, which check:

- the data types of observation/action/reward
- the reset/step methods
- whether the observations of two adjacent time steps share an unreasonable identical reference (i.e. deepcopy should be used to avoid identical references)
The checking tools are implemented in ding/envs/env/env_implementation_check.py. For their usage, please refer to test_an_implemented_env in ding/envs/env/tests/test_env_implementation_check.py.
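A usage sketch under the assumption that the module exposes a check_all style entry point; please confirm the exact function names in the files above:

```python
from ding.envs.env.env_implementation_check import check_all  # assumed entry point, verify the name

env = YourEnv(cfg)  # the environment implemented above, with a prepared cfg
check_all(env)      # reports problems if the implementation violates the conventions
```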
DingEnvWrapper¶
DingEnvWrapper can quickly convert simple environments such as ClassicControl, Box2d, Atari, Mujoco and GymHybrid into BaseEnv-compliant environments.
Note: the specific implementation of DingEnvWrapper can be found in ding/envs/env/ding_env_wrapper.py. In addition, you can see the Example for more information.
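A minimal usage sketch with a ClassicControl environment (assuming gym's CartPole-v0 is installed):

```python
import gym
from ding.envs import DingEnvWrapper

# wrap a plain gym environment into a BaseEnv-compliant environment
env = DingEnvWrapper(gym.make('CartPole-v0'))
env.seed(0)
obs = env.reset()
timestep = env.step(env.random_action())
print(obs.shape, timestep.reward, timestep.done)
```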
Q & A¶
How should the MARL environment be migrated?
You can refer to Competitive RL
- If the environment supports single-agent, two-agent or even multi-agent play, consider distinguishing the different modes.
- In a multi-agent environment, the numbers of actions and observations match the number of agents, but reward and done do not necessarily do so; the definition of reward must be clarified (see the sketch after this list).
- Pay attention to how the original environment expects actions and observations to be combined (tuples, lists, dictionaries, stacked arrays, and so on).
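The following sketch is purely illustrative of one possible convention (a hypothetical 3-agent environment with stacked per-agent observations and a single shared team reward); the actual layout depends on the environment:

```python
import numpy as np
from ding.envs import BaseEnvTimestep

agent_num = 3
obs = np.random.rand(agent_num, 16).astype(np.float32)  # one observation per agent, stacked
reward = np.array([1.0], dtype=np.float32)               # a single shared team reward
done = False                                             # episode-level termination flag
timestep = BaseEnvTimestep(obs, reward, done, {})
```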
How should an environment with a hybrid action space be migrated?
You can refer to Gym-Hybrid
- Some discrete actions in Gym-Hybrid (Accelerate, Turn) must be given corresponding 1-dimensional continuous parameters representing the acceleration and the rotation angle, so for similar environments you need to pay close attention to how the action space is defined, as sketched below.
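An illustrative definition in the spirit of Gym-Hybrid (the exact bounds and the Dict layout are hypothetical; see GymHybridDictActionWrapper mentioned above for the conversion DI-engine uses):

```python
import gym
import numpy as np

# discrete action type (e.g. Accelerate / Turn / Break) plus its continuous parameters
hybrid_action_space = gym.spaces.Dict({
    'action_type': gym.spaces.Discrete(3),
    'action_args': gym.spaces.Box(
        low=np.array([0., -1.], dtype=np.float32),
        high=np.array([1., 1.], dtype=np.float32),
        dtype=np.float32
    ),
})
sample = hybrid_action_space.sample()  # e.g. {'action_type': 1, 'action_args': array([...])}
```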