Shortcuts

Source code for ding.utils.default_helper

from typing import Union, Mapping, List, NamedTuple, Tuple, Callable, Optional, Any, Dict
import copy
from ditk import logging
import random
from functools import lru_cache  # in python3.9, we can change to cache
import numpy as np
import torch
import treetensor.torch as ttorch


def get_shape0(data: Union[List, Dict, torch.Tensor, ttorch.Tensor]) -> int:
    """
    Overview:
        Return the length of the first dimension of the first tensor found in ``data``.
    Arguments:
        - data (:obj:`Union[List,Dict,torch.Tensor,ttorch.Tensor]`): Data to inspect. For a list or tuple only \
            element 0 is examined; for a dict only the first value in iteration order is examined.
    Returns:
        - shape[0] (:obj:`int`): First dimension length of the inspected tensor, usually the batch size. \
            An empty dict produces ``None`` because there is no value to inspect.
    Raises:
        - TypeError: If ``data`` is none of the supported types.
    """
    if isinstance(data, (list, tuple)):
        return get_shape0(data[0])

    if isinstance(data, dict):
        for first_value in data.values():
            # Return from inside the loop: only the very first value matters.
            return get_shape0(first_value)
        return None

    if isinstance(data, torch.Tensor):
        return data.shape[0]

    if isinstance(data, ttorch.Tensor):
        # Walk down the first entry of each level of ``data.shape`` until an entry
        # whose element 0 is a plain scalar is reached.
        node = data.shape
        while True:
            first_entry = list(node.values())[0]
            if np.isscalar(first_entry[0]):
                return first_entry[0]
            node = first_entry

    raise TypeError("Error in getting shape0, not support type: {}".format(data))
def lists_to_dicts(
        data: Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]],
        recursive: bool = False,
) -> Union[Mapping[object, object], NamedTuple]:
    """
    Overview:
        Transform a list of dicts (or namedtuples) into a dict (or namedtuple) of lists.
    Arguments:
        - data (:obj:`Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]]`): A non-empty \
            list (or tuple) of dicts or namedtuples to be transformed. The keys of the first element decide \
            which keys are collected.
        - recursive (:obj:`bool`): Whether dict-valued entries are transformed recursively, at every nesting \
            level. The key ``'prev_state'`` is never recursed into and is always collected as a plain list.
    Returns:
        - newdata (:obj:`Union[Mapping[object, object], NamedTuple]`): A dict of lists, or a namedtuple whose \
            fields are tuples, as a result.
    Raises:
        - ValueError: If ``data`` is empty.
        - TypeError: If the first element is neither a dict nor a namedtuple.
    Example:
        >>> from ding.utils import *
        >>> lists_to_dicts([{1: 1, 10: 3}, {1: 2, 10: 4}])
        {1: [1, 2], 10: [3, 4]}
    """
    if len(data) == 0:
        raise ValueError("empty data")

    first = data[0]
    if isinstance(first, dict):
        new_data = {}
        for k in first:
            column = [item[k] for item in data]
            if recursive and isinstance(first[k], dict) and k != 'prev_state':
                # Propagate ``recursive`` so that dicts nested deeper than one level
                # are transposed as well.
                new_data[k] = lists_to_dicts(column, recursive=True)
            else:
                new_data[k] = column
        return new_data

    if isinstance(first, tuple) and hasattr(first, '_fields'):  # namedtuple
        return type(first)(*zip(*data))

    raise TypeError("not support element type: {}".format(type(first)))
def dicts_to_lists(data: Mapping[object, List[object]]) -> List[Mapping[object, object]]:
    """
    Overview:
        Transform a dict of lists into a list of dicts.
    Arguments:
        - data (:obj:`Mapping[object, list]`): A dict of lists to be transformed.
    Returns:
        - newdata (:obj:`List[Mapping[object, object]]`): A list of dicts as a result. Its length equals the \
            length of the shortest list in ``data``.
    Example:
        >>> from ding.utils import *
        >>> dicts_to_lists({1: [1, 2], 10: [3, 4]})
        [{1: 1, 10: 3}, {1: 2, 10: 4}]
    """
    keys = list(data.keys())
    columns = list(data.values())
    # zip(*columns) yields one tuple per row; pair every row with the keys again.
    return [dict(zip(keys, row)) for row in zip(*columns)]
def override(cls: type) -> Callable[[Callable], Callable]:
    """
    Overview:
        Decorator factory that documents (and checks) that a method overrides one of ``cls``.
    Arguments:
        - cls (:obj:`type`): The superclass that provides the overridden method. If ``cls`` does not \
            expose an attribute with the decorated method's name, an error is raised.
    Returns:
        - check_override (:obj:`Callable`): A decorator that returns the method unchanged after the check.
    Raises:
        - NameError: Raised at decoration time when the method name is not found in ``dir(cls)``.
    """

    def check_override(method: Callable) -> Callable:
        known_names = dir(cls)
        if method.__name__ in known_names:
            return method
        raise NameError("{} does not override any method of {}".format(method, cls))

    return check_override
def squeeze(data: object) -> object:
    """
    Overview:
        Unwrap a single-element tuple, list or dict into the element itself.
    Arguments:
        - data (:obj:`object`): Data to be squeezed.
    Returns:
        - data (:obj:`object`): For a tuple/list of length 1, its only element; for any other tuple/list, \
            a tuple with the same elements. For a dict of length 1, its only value. Everything else \
            (including dicts of other lengths) is returned unchanged.
    Example:
        >>> a = (4, )
        >>> a = squeeze(a)
        >>> print(a)
        4
    """
    if isinstance(data, (tuple, list)):
        return data[0] if len(data) == 1 else tuple(data)
    if isinstance(data, dict) and len(data) == 1:
        return next(iter(data.values()))
    return data
# Names for which the "use default value" warning has already been emitted.
default_get_set = set()


def default_get(
        data: dict,
        name: str,
        default_value: Optional[Any] = None,
        default_fn: Optional[Callable] = None,
        judge_fn: Optional[Callable] = None
) -> Any:
    """
    Overview:
        Look up ``name`` in ``data``. If it is missing, fall back to a default produced by ``default_fn`` \
        (preferred when given) or ``default_value``, optionally validate it with ``judge_fn``, and log a \
        warning the first time a default is used for this ``name``.
    Arguments:
        - data (:obj:`dict`): Data input dictionary.
        - name (:obj:`str`): Key name.
        - default_value (:obj:`Optional[Any]`): Fallback value used when ``default_fn`` is ``None``.
        - default_fn (:obj:`Optional[Callable]`): Zero-argument callable producing the fallback value.
        - judge_fn (:obj:`Optional[Callable]`): Predicate the fallback value must satisfy.
    Returns:
        - value (:obj:`Any`): ``data[name]`` if present, otherwise the fallback value.
    Raises:
        - AssertionError: If ``name`` is missing and neither ``default_value`` nor ``default_fn`` is given, \
            or if ``judge_fn`` rejects the fallback value.
    """
    if name in data:
        return data[name]

    assert default_value is not None or default_fn is not None
    if default_fn is not None:
        value = default_fn()
    else:
        value = default_value

    if judge_fn:
        assert judge_fn(value), "defalut value({}) is not accepted by judge_fn".format(type(value))

    already_warned = name in default_get_set
    if not already_warned:
        logging.warning("{} use default value {}".format(name, value))
        default_get_set.add(name)
    return value
def list_split(data: list, step: int) -> List[list]:
    """
    Overview:
        Cut ``data`` into consecutive chunks of exactly ``step`` elements.
    Arguments:
        - data (:obj:`list`): List of data to split.
        - step (:obj:`int`): Number of elements per chunk.
    Returns:
        - ret (:obj:`list`): List of the full-size chunks.
        - residual (:obj:`list`): The trailing elements that do not fill a chunk. ``None`` when ``step`` \
            divides ``len(data)``; ``data`` itself when ``data`` is shorter than ``step``.
    Example:
        >>> list_split([1,2,3,4],2)
        ([[1, 2], [3, 4]], None)
        >>> list_split([1,2,3,4],3)
        ([[1, 2, 3]], [4])
    """
    total = len(data)
    if total < step:
        # Not even one full chunk: everything is residual.
        return [], data

    chunk_count = total // step
    chunks = [data[idx * step:(idx + 1) * step] for idx in range(chunk_count)]
    covered = chunk_count * step
    leftover = data[covered:] if covered < total else None
    return chunks, leftover
def error_wrapper(fn, default_ret, warning_msg=""):
    """
    Overview:
        Wrap ``fn`` so that any ``Exception`` raised inside it is caught and ``default_ret`` is returned instead.
    Arguments:
        - fn (:obj:`Callable`): The function to be wrapped.
        - default_ret (:obj:`obj`): The value returned when an ``Exception`` occurs in ``fn``.
        - warning_msg (:obj:`str`): If not empty, a warning made of this message, ``default_ret`` and the \
            caught error is emitted through ``one_time_warning`` when ``fn`` fails.
    Returns:
        - wrapper (:obj:`Callable`): The wrapped function.
    Examples:
        >>> # Used to checkfor Fakelink (Refer to utils.linklink_dist_helper.py)
        >>> def get_rank():  # Get the rank of linklink model, return 0 if use FakeLink.
        >>>     if is_fake_link:
        >>>         return 0
        >>>     return error_wrapper(link.get_rank, 0)()
    """
    from functools import wraps

    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            if warning_msg != "":
                # ``one_time_warning`` accepts a single message string, so the details
                # are merged into one string instead of being passed as a second argument
                # (which would raise a TypeError out of this handler).
                one_time_warning(
                    "{}\ndefault_ret = {}\terror = {}".format(warning_msg, default_ret, e)
                )
            return default_ret

    return wrapper
class LimitedSpaceContainer:
    """
    Overview:
        A simple counter that simulates a container with a limited number of pieces of space.
    Interfaces:
        ``__init__``, ``get_residual_space``, ``acquire_space``, ``release_space``, \
        ``increase_space``, ``decrease_space``
    """

    def __init__(self, min_val: int, max_val: int) -> None:
        """
        Overview:
            Store ``min_val`` and ``max_val`` and start with ``cur`` equal to ``min_val``.
        Arguments:
            - min_val (:obj:`int`): Min volume of the container, usually 0.
            - max_val (:obj:`int`): Max volume of the container; must not be smaller than ``min_val``.
        """
        self.min_val = min_val
        self.max_val = max_val
        assert (max_val >= min_val)
        self.cur = self.min_val

    def get_residual_space(self) -> int:
        """
        Overview:
            Take all remaining pieces of space at once; afterwards ``cur`` equals ``max_val``.
        Returns:
            - ret (:obj:`int`): Residual space, calculated by ``max_val`` - ``cur`` before the call.
        """
        remaining = self.max_val - self.cur
        self.cur = self.max_val
        return remaining

    def acquire_space(self) -> bool:
        """
        Overview:
            Try to take one piece of space.
        Returns:
            - flag (:obj:`bool`): ``True`` if a piece was available (``cur`` is incremented), \
                ``False`` otherwise (``cur`` is left unchanged).
        """
        if self.cur >= self.max_val:
            return False
        self.cur += 1
        return True

    def release_space(self) -> None:
        """
        Overview:
            Give back one piece of space. ``cur`` is decremented but never drops below ``min_val``.
        """
        lowered = self.cur - 1
        if lowered > self.min_val:
            self.cur = lowered
        else:
            self.cur = self.min_val

    def increase_space(self) -> None:
        """
        Overview:
            Enlarge the container by one piece, i.e. increment ``max_val``.
        """
        self.max_val += 1

    def decrease_space(self) -> None:
        """
        Overview:
            Shrink the container by one piece, i.e. decrement ``max_val``. ``cur`` is not adjusted.
        """
        self.max_val -= 1
def deep_merge_dicts(original: dict, new_dict: dict) -> dict:
    """
    Overview:
        Merge two dicts into a new one by deep-copying ``original`` and applying ``deep_update`` with \
        ``new_dict`` on the copy (new keys allowed).
    Arguments:
        - original (:obj:`dict`): Base dict; ``None`` or empty is treated as ``{}``. It is not modified.
        - new_dict (:obj:`dict`): Dict whose values take precedence; ``None`` or empty means no update.
    Returns:
        - merged_dict (:obj:`dict`): A new dict that is ``original`` and ``new_dict`` deeply merged.
    """
    base = original if original else {}
    result = copy.deepcopy(base)
    if new_dict:  # skip both None and the empty dict
        deep_update(result, new_dict, True, [])
    return result
def deep_update(
        original: dict,
        new_dict: dict,
        new_keys_allowed: bool = False,
        whitelist: Optional[List[str]] = None,
        override_all_if_type_changes: Optional[List[str]] = None
):
    """
    Overview:
        Recursively update ``original`` in place with the values from ``new_dict``.
    Arguments:
        - original (:obj:`dict`): Dictionary with default values; it is modified in place.
        - new_dict (:obj:`dict`): Dictionary with values to be updated.
        - new_keys_allowed (:obj:`bool`): Whether keys missing from ``original`` may be introduced.
        - whitelist (:obj:`Optional[List[str]]`): Keys whose dict values may receive new subkeys. \
            Only applied at the top level (it is not passed on to recursive calls).
        - override_all_if_type_changes (:obj:`Optional[List[str]]`): Top level keys with dict values that \
            are replaced as a whole when both the old and the new dict carry a ``"type"`` entry and the \
            two entries differ.
    Returns:
        - original (:obj:`dict`): The same ``original`` object after the update.
    Raises:
        - RuntimeError: If ``new_dict`` holds a key unknown to ``original`` while new keys are not allowed.

    .. note::
        When both the old and the new value are dicts they are merged recursively; in every other case \
        the new value simply replaces the old one.
    """
    if not whitelist:
        whitelist = []
    if not override_all_if_type_changes:
        override_all_if_type_changes = []

    for key, incoming in new_dict.items():
        if not new_keys_allowed and key not in original:
            raise RuntimeError("Unknown config parameter `{}`. Base config have: {}.".format(key, original.keys()))

        current = original.get(key)
        if not (isinstance(current, dict) and isinstance(incoming, dict)):
            # At least one side is not a dict: plain replacement.
            original[key] = incoming
            continue

        type_changed = (
            key in override_all_if_type_changes and "type" in incoming and "type" in current
            and incoming["type"] != current["type"]
        )
        if type_changed:
            # The "type" entry differs: replace the whole sub-dict instead of merging.
            original[key] = incoming
        else:
            # Whitelisted keys may always gain new subkeys; others inherit the flag.
            allow_sub_keys = True if key in whitelist else new_keys_allowed
            deep_update(current, incoming, allow_sub_keys)
    return original
def flatten_dict(data: dict, delimiter: str = "/") -> dict:
    """
    Overview:
        Flatten a nested dict into a single-level dict whose keys are the nested keys joined by ``delimiter``.
    Arguments:
        - data (:obj:`dict`): Original nested dict; it is deep-copied and left untouched.
        - delimiter (str): Delimiter of the keys of the new dict.
    Returns:
        - data (:obj:`dict`): Flattened dict. Empty sub-dicts disappear from the result.
    Example:
        >>> a
        {'a': {'b': 100}}
        >>> flatten_dict(a)
        {'a/b': 100}
    """
    flat = copy.deepcopy(data)
    while True:
        # Expand one nesting level per pass until no dict value is left.
        nested_keys = [key for key, value in flat.items() if isinstance(value, dict)]
        if not nested_keys:
            break
        expanded = {}
        for parent in nested_keys:
            for child, leaf in flat[parent].items():
                expanded[delimiter.join([parent, child])] = leaf
        flat.update(expanded)
        for parent in nested_keys:
            del flat[parent]
    return flat
def set_pkg_seed(seed: int, use_cuda: bool = True) -> None:
    """
    Overview:
        Side effect function that seeds ``random``, ``numpy.random`` and ``torch`` (and, on request, \
        ``torch.cuda`` when CUDA is available). It is usually called in an entry script, in the section \
        that sets the random seed for all packages and instances.
    Arguments:
        - seed (:obj:`int`): The seed to set.
        - use_cuda (:obj:`bool`): Whether to also seed CUDA.
    Examples:
        >>> # ../entry/xxxenv_xxxpolicy_main.py
        >>> ...
        >>> # Set random seed for all package and instance
        >>> collector_env.seed(seed)
        >>> evaluator_env.seed(seed, dynamic_seed=False)
        >>> set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
        >>> ...
        >>> # Set up RL Policy, etc.
        >>> ...
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not use_cuda:
        return
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
@lru_cache(maxsize=128)
def one_time_warning(warning_msg: str) -> None:
    """
    Overview:
        Log a warning message, skipping repeated calls with the same message. The de-duplication relies \
        on ``lru_cache``, so a message can be logged again once it has been evicted from the cache.
    Arguments:
        - warning_msg (:obj:`str`): Warning message.
    """
    logging.warning(warning_msg)
def split_fn(data, indices, start, end):
    """
    Overview:
        Select the entries ``indices[start:end]`` from ``data``, descending into lists and dicts.
    Arguments:
        - data (:obj:`Union[List, Dict, torch.Tensor, ttorch.Tensor]`): Data to be split. ``None`` and \
            strings are returned as they are; lists and dicts are processed element by element.
        - indices (:obj:`np.ndarray`): Indices to split.
        - start (:obj:`int`): Start position inside ``indices``.
        - end (:obj:`int`): End position (exclusive) inside ``indices``.
    Returns:
        - data: A structure of the same layout as ``data`` holding only the selected entries.
    """
    if data is None:
        return None
    if isinstance(data, list):
        return [split_fn(element, indices, start, end) for element in data]
    if isinstance(data, dict):
        return {key: split_fn(value, indices, start, end) for key, value in data.items()}
    if isinstance(data, str):
        return data
    # Leaf value: index it with the chosen slice of indices.
    return data[indices[start:end]]
def split_data_generator(data: dict, split_size: int, shuffle: bool = True) -> dict:
    """
    Overview:
        Generator that splits ``data`` into batches of ``split_size`` entries each.
    Arguments:
        - data (:obj:`dict`): Data to be split; must be a dict.
        - split_size (:obj:`int`): Number of entries per batch, at least 1.
        - shuffle (:obj:`bool`): Whether the entries are randomly permuted before splitting.
    Returns:
        - batch (:obj:`dict`): Each iteration yields one batch with the same layout as ``data`` \
            (produced by ``split_fn``).
    Raises:
        - AssertionError: If ``data`` is not a dict, no value yields a length, or ``split_size`` < 1.

    .. note::
        When the data length is not a multiple of ``split_size``, the last batch is shifted backwards so \
        that it still holds ``split_size`` entries (it overlaps the previous batch). When ``split_size`` \
        exceeds the data length, a single batch with all entries is produced.
    """
    assert isinstance(data, dict), type(data)

    length = []
    for k, v in data.items():
        if v is None:
            continue
        elif k in ['prev_state', 'prev_actor_state', 'prev_critic_state']:
            length.append(len(v))
        elif isinstance(v, (list, tuple)):
            if isinstance(v[0], str):
                # some buffer data contains useless string infos, such as 'buffer_id',
                # which should not be split, so we just skip it
                continue
            length.append(get_shape0(v[0]))
        elif isinstance(v, dict):
            length.append(len(v[list(v.keys())[0]]))
        else:
            length.append(len(v))
    assert len(length) > 0
    # The collected lengths are deliberately not required to agree: with continuous
    # actions data['logit'] is a list of length 2. Only the first length is used.
    length = length[0]
    assert split_size >= 1

    indices = np.random.permutation(length) if shuffle else np.arange(length)
    for i in range(0, length, split_size):
        if i + split_size > length:
            # Shift the last batch back to keep it full-sized, but never before index 0:
            # a negative start would be read as an offset from the end and truncate the batch.
            i = max(0, length - split_size)
        yield split_fn(data, indices, i, i + split_size)
class RunningMeanStd(object):
    """
    Overview:
        Keep a running estimate of mean and variance that is updated batch by batch.
    Interfaces:
        ``__init__``, ``update``, ``reset``, ``new_shape``
    Properties:
        - ``mean``, ``std``, ``_epsilon``, ``_shape``, ``_mean``, ``_var``, ``_count``
    """

    def __init__(self, epsilon=1e-4, shape=(), device=torch.device('cpu')):
        """
        Overview:
            Store the configuration and initialize the statistics through ``reset``.
        Arguments:
            - epsilon (:obj:`Float`): Initial value of the sample count.
            - shape (:obj:`tuple`): Shape of the tracked mean and variance; an empty shape means \
                plain Python scalars are tracked.
            - device (:obj:`torch.device`): Device onto which non-scalar ``mean``/``std`` tensors are moved.
        """
        self._epsilon = epsilon
        self._shape = shape
        self._device = device
        self.reset()

    def update(self, x):
        """
        Overview:
            Fold the statistics of a new batch into the running mean, variance and count.
        Arguments:
            - ``x``: The batch; statistics are taken along axis 0 and ``x.shape[0]`` is the batch count.
        """
        mu_batch = np.mean(x, axis=0)
        var_batch = np.var(x, axis=0)
        n_batch = x.shape[0]

        n_total = n_batch + self._count
        delta = mu_batch - self._mean

        # Combine the two sums of squared deviations plus a correction for the mean shift.
        # this method for calculating new variable might be numerically unstable
        weighted_old = self._var * self._count
        weighted_batch = var_batch * n_batch
        correction = np.square(delta) * self._count * n_batch / n_total
        m2 = weighted_old + weighted_batch + correction

        self._mean = self._mean + delta * n_batch / n_total
        self._var = m2 / n_total
        self._count = n_total

    def reset(self):
        """
        Overview:
            Reset the statistics: ``_mean`` to zero, ``_var`` to one and ``_count`` to ``_epsilon``.
        """
        if len(self._shape) == 0:
            self._mean = 0.
            self._var = 1.
        else:
            self._mean = np.zeros(self._shape, 'float32')
            self._var = np.ones(self._shape, 'float32')
        self._count = self._epsilon

    @property
    def mean(self) -> np.ndarray:
        """
        Overview:
            The running mean: the raw value when ``_mean`` is a scalar, otherwise a \
            ``torch.FloatTensor`` on ``_device``.
        """
        value = self._mean
        if not np.isscalar(value):
            value = torch.FloatTensor(value).to(self._device)
        return value

    @property
    def std(self) -> np.ndarray:
        """
        Overview:
            The running standard deviation, computed as ``sqrt(_var + 1e-8)``: a scalar when the result \
            is a scalar, otherwise a ``torch.FloatTensor`` on ``_device``.
        """
        value = np.sqrt(self._var + 1e-8)
        return value if np.isscalar(value) else torch.FloatTensor(value).to(self._device)

    @staticmethod
    def new_shape(obs_shape, act_shape, rew_shape):
        """
        Overview:
            Get the new shapes of observation, action and reward; they are returned unchanged.
        Arguments:
            obs_shape (:obj:`Any`), act_shape (:obj:`Any`), rew_shape (:obj:`Any`)
        Returns:
            obs_shape (:obj:`Any`), act_shape (:obj:`Any`), rew_shape (:obj:`Any`)
        """
        return obs_shape, act_shape, rew_shape
def make_key_as_identifier(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Overview:
        Turn the keys of a dict into legal python identifier strings so that they are compatible with \
        python magic methods such as ``__getattr``. A key starting with a digit gets a leading ``_`` \
        and every ``.`` is replaced by ``_``.
    Arguments:
        - data (:obj:`Dict[str, Any]`): The original dict data.
    Return:
        - new_data (:obj:`Dict[str, Any]`): The new dict data with the converted keys and the same values.
    """

    def legalization(s: str) -> str:
        prefix = '_' if s[0].isdigit() else ''
        return (prefix + s).replace('.', '_')

    return {legalization(key): value for key, value in data.items()}
def remove_illegal_item(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Overview:
        Remove illegal items from a dict, i.e. items whose value is a ``str``, which is not compatible \
        with Tensor.
    Arguments:
        - data (:obj:`Dict[str, Any]`): The original dict data.
    Return:
        - new_data (:obj:`Dict[str, Any]`): A new dict holding only the items whose value is not a string.
    """
    return {key: value for key, value in data.items() if not isinstance(value, str)}