Shortcuts

Source code for ding.data.buffer.buffer

from abc import abstractmethod, ABC
from typing import Any, List, Optional, Union, Callable
import copy
from dataclasses import dataclass
from functools import wraps
from ding.utils import fastcopy


def apply_middleware(func_name: str):
    """
    Overview:
        Decorator factory that routes a buffer method through the buffer's middleware chain.
        Every middleware in ``buffer._middleware`` is called, in list order, as
        ``middleware(func_name, chain, *args, **kwargs)``; ``chain`` runs the next middleware,
        and once no middleware is left the decorated method itself is called.
    Arguments:
        - func_name (:obj:`str`): Action name handed to each middleware as its first argument,
            e.g. ``"push"``.
    Returns:
        - decorator (:obj:`Callable`): A decorator for methods of the form
            ``base_func(buffer, *args, **kwargs)``.
    """

    def wrap_func(base_func: Callable):

        @wraps(base_func)
        def handler(buffer, *args, **kwargs):
            """
            Overview:
                The real processing starts here, we apply the middleware one by one,
                each middleware will receive next `chained` function, which is an executor of next
                middleware. You can change the input arguments to the next `chained` middleware, and you
                also can get the return value from the next middleware, so you have the
                maximum freedom to choose at what stage to implement your method.
            """
            middleware = buffer._middleware

            def make_step(position: int) -> Callable:
                # Build the callable that executes the chain starting at ``position``.
                # The position is captured by the closure instead of being passed as a
                # parameter, so no internal parameter name can collide with a keyword
                # argument the caller passes to the decorated method. Indexing (rather
                # than slicing) also avoids copying the middleware list at every step.
                def step(*step_args, **step_kwargs):
                    if position >= len(middleware):
                        # Past the last middleware: run the wrapped buffer method.
                        return base_func(buffer, *step_args, **step_kwargs)
                    return middleware[position](func_name, make_step(position + 1), *step_args, **step_kwargs)

                return step

            return make_step(0)(*args, **kwargs)

        return handler

    return wrap_func


@dataclass
class BufferedData:
    """
    Overview:
        A single record held by a buffer: the payload together with its index and meta information.
    """
    # The payload stored in the buffer.
    data: Any
    # String identifier of this record within the buffer.
    index: str
    # Meta information attached to the payload, e.g. priority, count, staleness.
    meta: dict


# Register a BufferedData handler on fastcopy here (rather than inside ding.utils)
# to avoid circular references between the two modules.
def _copy_buffereddata(d: BufferedData) -> BufferedData:
    """
    Overview:
        fastcopy handler for ``BufferedData``: copies ``data`` and ``meta`` with
        ``fastcopy.copy`` and reuses ``index`` as-is.
    Arguments:
        - d (:obj:`BufferedData`): The record to copy.
    Returns:
        - copied (:obj:`BufferedData`): A new record built from the copied fields.
    """
    copied_data = fastcopy.copy(d.data)
    copied_meta = fastcopy.copy(d.meta)
    return BufferedData(data=copied_data, index=d.index, meta=copied_meta)


fastcopy.dispatch[BufferedData] = _copy_buffereddata


class Buffer(ABC):
    """
    Overview:
        Buffer is an abstraction of device storage, third-party services or data structures,
        for example, memory queue, sum-tree, redis, or di-store.
        Concrete subclasses implement the abstract storage operations; behavior can be extended
        with middleware registered through ``use``.
    """

    def __init__(self, size: int) -> None:
        """
        Overview:
            Initialize the buffer with an empty middleware chain.
        Arguments:
            - size (:obj:`int`): Buffer size, stored as ``self.size``; how it is enforced is up to
                the concrete subclass.
        """
        # Middleware handlers in registration order; consumed by ``apply_middleware``.
        self._middleware: List[Callable] = []
        self.size = size

    @abstractmethod
    def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
        """
        Overview:
            Push data and its meta information into the buffer.
        Arguments:
            - data (:obj:`Any`): The data which will be pushed into buffer.
            - meta (:obj:`Optional[dict]`): Meta information, e.g. priority, count, staleness.
        Returns:
            - buffered_data (:obj:`BufferedData`): The pushed data.
        """
        raise NotImplementedError

    @abstractmethod
    def sample(
            self,
            size: Optional[int] = None,
            indices: Optional[List[str]] = None,
            replace: bool = False,
            sample_range: Optional[slice] = None,
            ignore_insufficient: bool = False,
            groupby: Optional[str] = None,
            unroll_len: Optional[int] = None
    ) -> Union[List[BufferedData], List[List[BufferedData]]]:
        """
        Overview:
            Sample data with length ``size``.
        Arguments:
            - size (:obj:`Optional[int]`): The number of the data that will be sampled.
            - indices (:obj:`Optional[List[str]]`): Sample with multiple indices.
            - replace (:obj:`bool`): If replace is true, you may receive duplicated data from the buffer.
            - sample_range (:obj:`Optional[slice]`): Sample range slice.
            - ignore_insufficient (:obj:`bool`): If ignore_insufficient is true, sampling more than
                buffer size with no repetition will not cause an exception.
            - groupby (:obj:`Optional[str]`): Groupby key in meta, i.e. groupby="episode".
            - unroll_len (:obj:`Optional[int]`): Number of consecutive frames within a group.
        Returns:
            - sample_data (:obj:`Union[List[BufferedData], List[List[BufferedData]]]`):
                A list of data with length ``size``, may be nested if groupby is set.
        """
        raise NotImplementedError

    @abstractmethod
    def update(self, index: str, data: Optional[Any] = None, meta: Optional[dict] = None) -> bool:
        """
        Overview:
            Update data and meta by index.
        Arguments:
            - index (:obj:`str`): Index of data.
            - data (:obj:`Optional[Any]`): Pure data.
            - meta (:obj:`Optional[dict]`): Meta information.
        Returns:
            - success (:obj:`bool`): Success or not; if no data with the index exists in the
                buffer, return False.
        """
        raise NotImplementedError

    @abstractmethod
    def delete(self, index: str):
        """
        Overview:
            Delete one data sample by index.
        Arguments:
            - index (:obj:`str`): Index of the data to delete.
        """
        raise NotImplementedError

    @abstractmethod
    def save_data(self, file_name: str):
        """
        Overview:
            Save buffer data into a file.
        Arguments:
            - file_name (:obj:`str`): File name of buffer data.
        """
        raise NotImplementedError

    @abstractmethod
    def load_data(self, file_name: str):
        """
        Overview:
            Load buffer data from a file.
        Arguments:
            - file_name (:obj:`str`): File name of buffer data.
        """
        raise NotImplementedError

    @abstractmethod
    def count(self) -> int:
        """
        Overview:
            Number of data items in the buffer; ``len(buffer)`` returns this value.
        Returns:
            - count (:obj:`int`): The item count.
        """
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        """
        Overview:
            Clear the data stored in the buffer; implemented by concrete subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def get(self, idx: int) -> BufferedData:
        """
        Overview:
            Get item by subscript index; ``buffer[idx]`` delegates here.
        Arguments:
            - idx (:obj:`int`): Subscript index.
        Returns:
            - buffered_data (:obj:`BufferedData`): Item from buffer.
        """
        raise NotImplementedError

    def use(self, func: Callable) -> "Buffer":
        """
        Overview:
            Use algorithm middleware to modify the behavior of the buffer. Middleware is appended
            to ``self._middleware``, and methods decorated with ``apply_middleware`` run the
            chain from the first registered handler to the last. Each middleware is called as
            ``func(action, chain, *args, **kwargs)``:

            1. ``action`` is the name given to ``apply_middleware`` by the decorated method
               (e.g. ``"push"``), so the middleware can decide what to do per action.
            2. ``chain`` is a callable that runs the next middleware, or the original buffer
               method once no middleware is left. A middleware may call it with modified
               arguments, post-process its return value, or not call it at all to stop the chain.
            3. ``*args`` / ``**kwargs`` are the arguments passed to the decorated method (or
               forwarded by the previous middleware through its ``chain``).

            The value a middleware returns becomes the result at its stage of the chain. The
            buffer instance itself is not passed in; a middleware that needs it must capture
            it, e.g. in a closure.
        Arguments:
            - func (:obj:`Callable`): The middleware handler.
        Returns:
            - buffer (:obj:`Buffer`): The instance self, so calls can be chained.
        """
        self._middleware.append(func)
        return self

    def view(self) -> "Buffer":
        r"""
        Overview:
            Create a view of this buffer via ``copy.copy(self)``, which dispatches to
            ``__copy__``. The base-class ``__copy__`` raises ``NotImplementedError``, so concrete
            subclasses must implement it. By design a view is a new buffer instance that shares
            the storage with this one while every other property is copied; how that is done is
            decided by the subclass's ``__copy__``.
        Returns:
            - buffer (:obj:`Buffer`): The new buffer returned by ``copy.copy(self)`` (not ``self``).
        """
        return copy.copy(self)

    def __copy__(self) -> "Buffer":
        """
        Overview:
            Hook used by ``view`` (through ``copy.copy``); must be implemented by subclasses.
        """
        raise NotImplementedError

    def __len__(self) -> int:
        return self.count()

    def __getitem__(self, idx: int) -> BufferedData:
        return self.get(idx)