Buffer User Guide
Buffer Getting Started
Basic Concepts of Buffer
In off-policy RL algorithms, Experience Replay is commonly used to improve sample efficiency and to reduce the correlation between samples from different time steps. DI-engine implements the common features of an experience replay pool, such as data insertion and sampling, in DequeBuffer. Users can create a DequeBuffer object with the following code:
from ding.data import DequeBuffer
buffer = DequeBuffer(size=10)
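The size argument bounds how many samples the pool can hold. The sketch below illustrates the bounded-storage idea with collections.deque, assuming (as the name DequeBuffer suggests, but not confirmed here) that the oldest samples are evicted first once the pool is full. ToyDequeBuffer is a hypothetical stand-in, not DI-engine's class.

```python
from collections import deque


class ToyDequeBuffer:
    """Hypothetical bounded replay pool; not DI-engine's DequeBuffer."""

    def __init__(self, size: int):
        # deque(maxlen=...) drops the oldest item when a new one is appended to a full deque
        self.storage = deque(maxlen=size)

    def push(self, data):
        self.storage.append(data)

    def count(self) -> int:
        return len(self.storage)


buf = ToyDequeBuffer(size=3)
for step in range(5):
    buf.push(step)

print(buf.count())        # 3
print(list(buf.storage))  # [2, 3, 4]
```

Bounding the pool keeps memory fixed and gradually replaces stale experience with fresh samples.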
In DI-engine, the buffer and several other components use dataclasses as structured carriers for their data. A dataclass is a Python 3 feature (available since Python 3.7) that keeps data neat and consistent by declaring the type of each field (class attribute), which a plain Dict does not do. Like a namedtuple, a dataclass can give fields default values at initialization; unlike a namedtuple, its fields can be reassigned freely after creation. The following walks through the basic buffer operations.
# Data is pushed into the buffer one sample at a time.
# In DI-engine middleware, each cached sample is usually a Dict recording the obs, next_obs, action, reward, etc. of a transition.
for _ in range(10):
    # Each stored item is a BufferedData object with three fields: data, index and meta.
    # "data" is the payload to cache and "meta" is its optional meta information (defaults to None); both are passed in through push().
    # "index" is the logical storage address of the item in the buffer; the buffer generates it automatically, so users never set it manually.
    buffer.push('a', meta={})
# Sampling draws several samples at once, so the number of samples must be given explicitly.
# "replace" controls whether sampling is done with replacement; it defaults to False.
# Each sampled item is a BufferedData object, e.g. BufferedData(data='a', index='67bdfadcd', meta={})
buffered_data = buffer.sample(3, replace=False)
data = [d.data for d in buffered_data]
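To make the roles of data, index and meta and the replace flag concrete, here is a minimal pure-Python sketch. ToyBufferedData and ToyBuffer are hypothetical names that only mirror the fields described above; they are not DI-engine's BufferedData or DequeBuffer, and the 9-character index format is an assumption modelled on the example index in the comments.

```python
import random
import uuid
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class ToyBufferedData:
    """Hypothetical record with the three fields described above."""
    data: Any
    index: str
    meta: Optional[dict] = None  # default used when no meta is supplied


class ToyBuffer:
    """Hypothetical list-backed buffer, not DI-engine's DequeBuffer."""

    def __init__(self, size: int):
        self.size = size
        self.records: List[ToyBufferedData] = []

    def push(self, data: Any, meta: Optional[dict] = None) -> ToyBufferedData:
        # The buffer, not the caller, generates the index.
        record = ToyBufferedData(data=data, index=uuid.uuid4().hex[:9], meta=meta)
        self.records.append(record)
        if len(self.records) > self.size:
            self.records.pop(0)  # evict the oldest record
        return record

    def sample(self, size: int, replace: bool = False) -> List[ToyBufferedData]:
        if replace:
            # With replacement: the same record may be drawn more than once.
            return random.choices(self.records, k=size)
        # Without replacement: all returned records are distinct.
        return random.sample(self.records, size)


buffer = ToyBuffer(size=10)
for i in range(10):
    buffer.push(i, meta={})
buffered_data = buffer.sample(3, replace=False)
data = [d.data for d in buffered_data]
# Dataclass fields can be reassigned after creation.
buffered_data[0].meta = {"priority": 1.0}
```

The default on meta and the reassignment on the last line are the two dataclass properties mentioned earlier.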
Using Buffer to Complete Online Training
The previous subsection introduced how data is actually stored in the buffer, along with the most basic push and sample operations. In most tasks, however, users do not need these low-level atomic operations. We recommend accessing the buffer through DI-engine's wrapped middleware to complete training.
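from ding.framework import task
from ding.framework.middleware import data_pusher, OffPolicyLearner
# cfg (the experiment config) and policy are assumed to have been created earlier in the pipeline.
# data_pusher writes newly collected samples into the buffer; OffPolicyLearner samples training batches from it.
task.use(data_pusher(cfg, buffer))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer))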
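Conceptually, the two middleware share the buffer: one writes to it and the other reads from it, in registration order. The sketch below models that flow with a hypothetical ToyTask whose middleware are plain functions over a shared context dict. It is an illustration under that assumption, not DI-engine's task implementation; toy_data_pusher, toy_learner and the "trajectories"/"train_batch" keys are invented for this example.

```python
import random
from typing import Callable, Dict, List


class ToyTask:
    """Hypothetical pipeline: middleware run in the order they were registered."""

    def __init__(self):
        self.middleware: List[Callable[[Dict], None]] = []

    def use(self, fn: Callable[[Dict], None]) -> None:
        self.middleware.append(fn)

    def run_once(self, ctx: Dict) -> Dict:
        for fn in self.middleware:
            fn(ctx)
        return ctx


def toy_data_pusher(buffer: List):
    def _push(ctx: Dict) -> None:
        # Move freshly collected transitions from the context into the buffer.
        buffer.extend(ctx.pop("trajectories", []))
    return _push


def toy_learner(buffer: List, batch_size: int):
    def _learn(ctx: Dict) -> None:
        # Only build a training batch once enough data has accumulated.
        if len(buffer) >= batch_size:
            ctx["train_batch"] = random.sample(buffer, batch_size)
    return _learn


buffer: List = []
task = ToyTask()
task.use(toy_data_pusher(buffer))
task.use(toy_learner(buffer, batch_size=4))
ctx = task.run_once({"trajectories": list(range(6))})
```

Because the pusher is registered first, the learner always sees the samples collected in the same iteration.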
Using Buffer to Load Expert Data
In imitation learning tasks such as SQIL and DQfD, some expert experience must be loaded before training. Users can hold the expert data in a separate buffer. Taking SQIL as an example (the complete code can be found at ./ding/example/sqil.py):
from ding.framework.middleware import sqil_data_pusher
buffer = DequeBuffer(size=10)
expert_buffer = DequeBuffer(size=10)
task.use(sqil_data_pusher(cfg, buffer_=buffer, expert=False))
task.use(sqil_data_pusher(cfg, buffer_=expert_buffer, expert=True))
task.use(OffPolicyLearner(cfg, policy.learn_mode, [(buffer, 0.5), (expert_buffer, 0.5)]))
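The learner above receives a list of (buffer, ratio) pairs. A plausible reading, assumed here rather than confirmed by this guide, is that each ratio sets the share of a training batch drawn from that buffer. The hypothetical mixed_sample function sketches that behaviour with plain lists.

```python
import random
from typing import Any, List, Sequence, Tuple


def mixed_sample(sources: Sequence[Tuple[List[Any], float]], batch_size: int) -> List[Any]:
    """Hypothetical helper: draw int(batch_size * ratio) items from each source."""
    batch: List[Any] = []
    for data, ratio in sources:
        n = int(batch_size * ratio)  # this source's share of the batch
        batch.extend(random.sample(data, n))
    return batch


agent_data = [("agent", i) for i in range(10)]
expert_data = [("expert", i) for i in range(10)]
batch = mixed_sample([(agent_data, 0.5), (expert_data, 0.5)], batch_size=8)
```

With ratios of 0.5 and 0.5, every batch of 8 holds 4 agent samples and 4 expert samples, so the expert data keeps a fixed weight in training however many agent samples accumulate.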
Buffer Advanced
In the previous section, we provided a basic application scenario for buffer. Next, we will show you a more comprehensive view of the buffer’s capabilities.
Priority Experience Replay (PER)
Some algorithms require priority experience replay. In DI-engine, you can enable it by applying the PriorityExperienceReplay middleware to the buffer. Once the middleware is enabled, the priority of each sample must be passed explicitly through its meta information when the sample is pushed, as shown below. Note that priority sampling makes sampling slower.
from ding.data.buffer.middleware import PriorityExperienceReplay
buffer = DequeBuffer(size=10)
buffer.use(PriorityExperienceReplay(buffer, IS_weight=True))
for _ in range(10):
    # meta is a Dict that supplements the description of the sample; it is empty by default.
    buffer.push('a', meta={"priority": 2.0})
buffered_data = buffer.sample(3)
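For intuition about what priority sampling and IS_weight=True compute, the sketch below implements the standard proportional scheme from the PER literature: P(i) = p_i^alpha / sum_j p_j^alpha, with importance-sampling weights w_i = (N * P(i))^(-beta) normalised by the largest weight. The function per_sample and the parameters alpha and beta are hypothetical; whether PriorityExperienceReplay uses exactly this formula or these values is an assumption.

```python
import random
from typing import List, Tuple


def per_sample(priorities: List[float], k: int, alpha: float = 0.6,
               beta: float = 0.4) -> Tuple[List[int], List[float]]:
    """Hypothetical proportional PER: returns sampled indices and their IS weights."""
    n = len(priorities)
    # P(i) = p_i^alpha / sum_j p_j^alpha
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    probs = [s / total for s in scaled]
    # Draw k indices with replacement, proportionally to P(i).
    indices = random.choices(range(n), weights=probs, k=k)
    # w_i = (N * P(i))^(-beta), divided by the largest weight in the buffer so all weights are <= 1.
    max_w = max((n * p) ** (-beta) for p in probs)
    weights = [((n * probs[i]) ** (-beta)) / max_w for i in indices]
    return indices, weights


# Equal priorities (as in the example above) give uniform sampling and unit weights.
indices, weights = per_sample([2.0] * 10, k=3)
```

The weights down-scale the loss of frequently drawn, high-priority samples, which offsets the bias that non-uniform sampling introduces. Computing them on every draw is part of why priority sampling is slower.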
Sample Cloning
By default, for mutable objects stored in the buffer (such as list, np.array and torch.tensor), sampling returns a reference to the stored object rather than a copy. If the user later modifies the returned object, the corresponding data in the buffer changes as well. When the data in the buffer must stay unchanged, the clone_object middleware can be used so that sampling returns a copy of each object; modifications to the copy then leave the original data in the buffer untouched. Note that sample cloning significantly increases sampling time.
from ding.data.buffer.middleware import clone_object
buffer = DequeBuffer(size=10)
buffer.use(clone_object())
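The difference between returning a reference and returning a copy can be seen with plain Python lists. This sketch assumes the copy is a deep copy made with copy.deepcopy; the exact copying mechanism inside clone_object is not stated in this guide.

```python
import copy
import random

storage = [[0, 0, 0]]  # a toy buffer holding one mutable sample


def sample_by_reference():
    return random.choice(storage)


def sample_by_copy():
    # Assumed cloning behaviour: hand back an independent deep copy.
    return copy.deepcopy(random.choice(storage))


view = sample_by_reference()
view[0] = 99            # also changes the stored sample
print(storage[0])       # [99, 0, 0]

clone = sample_by_copy()
clone[1] = -1           # the stored sample is unaffected
print(storage[0])       # [99, 0, 0]
```

A deep copy has to walk and duplicate the whole object on every sample, which is why cloning adds noticeable sampling cost for large arrays or tensors.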
Group Sampling
In some environments or algorithms, users may wish to collect, store and process samples episode by episode. For example, in chess, Go or card games, where players are rewarded only at the end of the game, algorithms often want to process an entire game at once, and algorithms such as Hindsight Experience Replay (HER) need to sample complete episodes and process them as episodic units. Group sampling serves this purpose.
Custom Implementation via Atomic Operations
These needs can be met, with more room for customization, by combining atomic operations. For example, when storing samples, you can add "episode" information to the meta to specify which episode each sample belongs to, and when sampling, you can set groupby="episode" to enable group sampling keyed on the episode. Note that group sampling can significantly increase sampling time.
buffer = DequeBuffer(size=10)
# When storing data, the user needs to add grouping information to the meta, e.g. "episode" as the grouping key, whose value identifies the specific group
buffer.push("a", {"episode": 1})
buffer.push("b", {"episode": 2})
buffer.push("c", {"episode": 2})
# Grouping by the key "episode" requires that the buffer contain at least as many distinct groups as the number of samples requested
grouped_data = buffer.sample(2, groupby="episode")
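The grouping rule can be sketched in a few lines: bucket samples by the meta key, then pick distinct groups. toy_groupby_sample is a hypothetical helper written for illustration; that each sampled group returns its samples in insertion order is an assumption of this sketch.

```python
import random
from collections import defaultdict
from typing import Any, Dict, List, Tuple


def toy_groupby_sample(items: List[Tuple[Any, Dict]], n: int, groupby: str) -> List[List[Any]]:
    """Hypothetical group sampling: return n distinct groups keyed by meta[groupby]."""
    groups: Dict[Any, List[Any]] = defaultdict(list)
    for data, meta in items:
        groups[meta[groupby]].append(data)  # keeps insertion order within each group
    if len(groups) < n:
        # Mirrors the constraint above: need at least n distinct groups.
        raise ValueError(f"need at least {n} distinct groups, found {len(groups)}")
    chosen = random.sample(list(groups), n)  # distinct group keys
    return [groups[key] for key in chosen]


items = [("a", {"episode": 1}), ("b", {"episode": 2}), ("c", {"episode": 2})]
grouped = toy_groupby_sample(items, 2, groupby="episode")
```

Each element of the result is a whole group, so requesting 2 samples here returns both episodes rather than 2 individual transitions.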
Implementation through Middleware
DI-engine also builds group sampling into the data_pusher middleware. Take R2D2 as an example: it feeds sequences of samples from an episode through an LSTM network. During data collection, each env instance produces its own unique trajectory, so we recommend using env_id as the key that distinguishes episodes. The R2D2 code using group sampling is shown below; the full version can be found at ./ding/example/r2d2.py:
buffer = DequeBuffer(size=10)
# Here 'env' is used as the grouping key, so samples with the same env_id are placed in the same group when sampling.
task.use(data_pusher(cfg, buffer, group_by_env=True))
(Optional) On top of group sampling, you can also use the group_sample middleware to post-process the sampled groups, for example to choose whether to shuffle data within each group and to set the maximum length of each group.
from ding.data.buffer.middleware import group_sample
buffer = DequeBuffer(size=10)
# The maximum length of each group of data is set to 3, keeping the original order within the group
buffer.use(group_sample(size_in_group=3, ordered_in_group=True))
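One way such post-processing could work is sketched below with a hypothetical slice_group helper: when order matters, keep a random contiguous window of size_in_group samples so the temporal sequence survives (useful for LSTM-based methods like R2D2); otherwise draw samples in random order. This is an assumed rule for illustration, not a description of group_sample's exact behaviour.

```python
import random
from typing import Any, List


def slice_group(group: List[Any], size_in_group: int, ordered_in_group: bool = True) -> List[Any]:
    """Hypothetical post-processing of one sampled group."""
    if len(group) <= size_in_group:
        # Short group: keep everything, shuffled only if order does not matter.
        return list(group) if ordered_in_group else random.sample(group, len(group))
    if ordered_in_group:
        # Keep a random contiguous window so the original order is preserved.
        start = random.randrange(len(group) - size_in_group + 1)
        return group[start:start + size_in_group]
    # Order not required: draw size_in_group distinct elements in random order.
    return random.sample(group, size_in_group)


episode = list(range(10))
window = slice_group(episode, size_in_group=3, ordered_in_group=True)
shuffled = slice_group(episode, size_in_group=3, ordered_in_group=False)
```

Capping the group length also bounds the sequence length fed to the network, regardless of how long the underlying episode is.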
Deleting Overused Samples
By default, a DequeBuffer may return the same sample across many different sample calls. Left unchecked, this can degrade training performance, because the model fits a subset of the samples too many times. To avoid this, use the use_time_check middleware to cap the number of times each sample can be used.
from ding.data.buffer.middleware import use_time_check
buffer = DequeBuffer(size=10)
# Set the maximum number of times a single sample can be used to 2
buffer.use(use_time_check(buffer, max_use=2))
The middleware keeps a counter of how many times each sample has been picked and writes it to the use_count field of the sample's meta. Each time a sample is picked, its count increases by 1; once the count reaches the maximum set by max_use, the sample is deleted from the buffer.
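The counting rule can be traced with a small hypothetical buffer. CountingBuffer and sample_all are invented names, and sample_all deterministically returns every stored sample so the counts are easy to follow. Deleting a sample as soon as use_count reaches max_use follows the "maximum number of times a single sample can be used" reading above, and is an assumption about the exact comparison the middleware uses.

```python
from typing import Any, Dict, List


class CountingBuffer:
    """Hypothetical buffer that deletes a sample after max_use picks."""

    def __init__(self, max_use: int):
        self.max_use = max_use
        self.items: List[Dict[str, Any]] = []

    def push(self, data: Any) -> None:
        self.items.append({"data": data, "meta": {}})

    def sample_all(self) -> List[Any]:
        batch = list(self.items)
        for item in batch:
            meta = item["meta"]
            meta["use_count"] = meta.get("use_count", 0) + 1  # record one more use
        # Keep only samples that have not yet reached the limit.
        self.items = [it for it in self.items if it["meta"]["use_count"] < self.max_use]
        return [it["data"] for it in batch]


buf = CountingBuffer(max_use=2)
buf.push("a")
first = buf.sample_all()   # use_count becomes 1, sample kept
second = buf.sample_all()  # use_count becomes 2, sample deleted
third = buf.sample_all()   # buffer is now empty
```

With max_use=2, the sample is returned exactly twice and never again.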