R2D2

Overview

R2D2 was first proposed in Recurrent experience replay in distributed reinforcement learning. When an RNN is trained with experience replay, RL algorithms usually face two problems: representational drift and recurrent state staleness. R2D2 uses two techniques, stored states and burn-in, to mitigate these effects. The R2D2 agent integrates these findings to achieve significant advances in the state of the art on Atari-57 and to match the state of the art on DMLab-30. The authors claim that Recurrent Replay Distributed DQN (R2D2) is the first agent to achieve this using a single network architecture and a fixed set of hyper-parameters.

Quick Facts

  1. R2D2 is an off-policy, model-free and value-based RL algorithm.

  2. R2D2 is essentially a DQN-based algorithm that uses a distributed framework, double Q networks, a dueling architecture, n-step TD loss, and prioritized experience replay.

  3. R2D2 currently supports only discrete action spaces and uses eps-greedy for exploration, the same as DQN.

  4. R2D2 uses the stored state and burn_in techniques to mitigate the effects of representational drift and recurrent state staleness.

  5. The DI-engine implementation of R2D2 provides a res_link key to support a residual link in the recurrent Q network.
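The stored state and burn-in idea from fact 4 can be sketched in a few lines of PyTorch. This is only an illustration under stated assumptions, not the DI-engine implementation: it uses a bare torch.nn.LSTM, made-up sizes (T, B, N, H), and a zero tensor standing in for the recurrent state that would be stored in replay. The first burnin_step steps only warm up the hidden state without gradients; the remaining steps are the ones that would be used for the loss.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: sequence length, batch size, feature size, hidden size.
T, B, N, H = 10, 4, 8, 16
burnin_step = 2

lstm = nn.LSTM(input_size=N, hidden_size=H)
obs_seq = torch.randn(T, B, N)  # (T, B, N) because batch_first=False by default

# Stand-in for the recurrent state stored in replay at sequence start.
stored_state = (torch.zeros(1, B, H), torch.zeros(1, B, H))

# Burn-in: unroll the first steps only to refresh the hidden state.
# No gradient flows through this part.
with torch.no_grad():
    _, warm_state = lstm(obs_seq[:burnin_step], stored_state)

# Training part: unroll the rest of the sequence from the warmed-up state.
train_out, _ = lstm(obs_seq[burnin_step:], warm_state)
```

Here train_out has T - burnin_step time steps and carries gradients, while warm_state is treated as a constant.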

Key Equations or Key Graphs

The R2D2 agent is most similar to Ape-X: it is built upon prioritized distributed replay and n-step double Q-learning (with n = 5), generates experience with a large number of actors (typically 256), and learns from batches of replayed experience with a single learner. The Q network of R2D2 uses the dueling network architecture and adds an LSTM layer after the convolutional stack.
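As a reminder of what the dueling architecture computes, the sketch below uses the mean-subtracted aggregation from the dueling network paper, \(Q(s, a) = V(s) + A(s, a) - \frac{1}{|A|}\sum_{a'} A(s, a')\). The function name dueling_q is hypothetical, and whether DI-engine's DuelingHead uses exactly this form is an assumption here.

```python
def dueling_q(value, advantages):
    """Combine a scalar state value and per-action advantages into Q-values.

    Subtracting the mean advantage keeps the value and advantage streams
    identifiable (the standard dueling aggregation).
    """
    mean_adv = sum(advantages) / len(advantages)
    return [value + a - mean_adv for a in advantages]


q_values = dueling_q(1.0, [1.0, 2.0, 3.0])
```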

Instead of regular \((s, a, r, s')\) transition tuples, R2D2 stores fixed-length (m = 80) sequences of \((s, a, r)\) in replay, with adjacent sequences overlapping each other by 40 time steps and never crossing episode boundaries. Specifically, the n-step targets used in R2D2 are:

../_images/r2d2_q_targets.png

Here, \(\theta^{-}\) denotes the target network parameters which are copied from the online network parameters \(\theta\) every 2500 learner steps.
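Written out as in the R2D2 paper, the target is \(\hat{y}_t = h\left(\sum_{k=0}^{n-1} r_{t+k}\gamma^k + \gamma^n h^{-1}\left(Q(s_{t+n}, a^*; \theta^-)\right)\right)\) with \(a^* = \arg\max_a Q(s_{t+n}, a; \theta)\), where \(h(x) = \mathrm{sign}(x)(\sqrt{|x| + 1} - 1) + \epsilon x\) is an invertible value rescaling. The following is a minimal scalar sketch of that formula, not the DI-engine code: the function names are hypothetical, \(\epsilon = 10^{-3}\) is taken from the paper, and q_next_target is assumed to be the target-network value already evaluated at the action chosen by the online network.

```python
import math

EPS = 1e-3  # rescaling epsilon; value assumed from the paper


def h(x, eps=EPS):
    """Value rescaling: sign(x) * (sqrt(|x| + 1) - 1) + eps * x."""
    return math.copysign(1.0, x) * (math.sqrt(abs(x) + 1.0) - 1.0) + eps * x


def h_inv(x, eps=EPS):
    """Closed-form inverse of h."""
    root = math.sqrt(1.0 + 4.0 * eps * (abs(x) + 1.0 + eps))
    return math.copysign(1.0, x) * (((root - 1.0) / (2.0 * eps)) ** 2 - 1.0)


def n_step_target(rewards, q_next_target, gamma=0.997):
    """n-step target in rescaled space.

    rewards: the n rewards r_t ... r_{t+n-1}.
    q_next_target: Q(s_{t+n}, a*; theta^-), where a* is the greedy action
        under the online network (double Q-learning).
    """
    n = len(rewards)
    discounted_return = sum(gamma ** k * r for k, r in enumerate(rewards))
    return h(discounted_return + gamma ** n * h_inv(q_next_target))
```

The network is trained to predict rescaled values, which is why the bootstrap term is mapped back with h_inv before the discounted rewards are added.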

R2D2 uses a mixture of the max and mean absolute n-step TD-errors \(\delta_i\) over the sequence as the prioritization metric for prioritized experience replay:

../_images/r2d2_priority.png
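In the paper this mixture is \(p = \eta \max_i \delta_i + (1 - \eta)\bar{\delta}\), where \(\bar{\delta}\) is the mean absolute TD-error over the sequence. A small sketch of that formula follows; the function name sequence_priority is hypothetical, and the default \(\eta = 0.9\) is the value reported in the paper rather than something confirmed for DI-engine.

```python
def sequence_priority(td_errors, eta=0.9):
    """Priority of one replayed sequence from its per-step TD-errors.

    Mixes the max and the mean of the absolute TD-errors; eta weights the max.
    """
    abs_errors = [abs(d) for d in td_errors]
    mean_error = sum(abs_errors) / len(abs_errors)
    return eta * max(abs_errors) + (1.0 - eta) * mean_error


priority = sequence_priority([1.0, -3.0, 2.0])
```

With eta close to 1 a single large error dominates the priority, while eta close to 0 averages it away over a long sequence.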

Note

In our DI-engine implementation, at each unroll step, the input to the LSTM-based Q network is just observation and the last hidden state, excluding reward and one-hot action.

For more details about how to use RNN in DI-engine, users can refer to How to use RNN. For the data arrangement process in R2D2, users can refer to the section data-arrangement. For the burn-in technique in R2D2, users can refer to the section burn-in-in-r2d2.
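The fixed-length, overlapping sequence layout described above (m = 80, overlap 40, no crossing of episode boundaries) can be sketched as follows. This is an illustrative helper, not the DI-engine data arrangement: the name split_into_sequences is hypothetical, and dropping a trailing remainder shorter than seq_len is a simplifying assumption of the sketch.

```python
def split_into_sequences(episode, seq_len=80, overlap=40):
    """Cut one episode into fixed-length sequences that overlap.

    Sequences start every (seq_len - overlap) steps and never extend past
    the end of the episode. A trailing remainder shorter than seq_len is
    dropped in this sketch.
    """
    stride = seq_len - overlap
    last_start = len(episode) - seq_len
    return [episode[i:i + seq_len] for i in range(0, last_start + 1, stride)]


# A single episode of 200 steps, represented by its step indices.
sequences = split_into_sequences(list(range(200)))
```

Because each episode is split separately, no sequence mixes steps from two episodes.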

Extensions

R2D2 can be combined with:

  • Learning from demonstrations

    Users can refer to the R2D3 paper and to the R2D3 doc of our implementation. R2D3 is an agent that makes efficient use of demonstrations to solve hard exploration problems in partially observable environments with highly variable initial conditions.

  • Transformers

    Transformer-based agents take advantage of their powerful attention mechanism to learn better policies in environments where long-term memory is beneficial. Users can refer to the GTrXL paper and to the r2d2_gtrxl doc of our GTrXL implementation.

Implementations

The default config of R2D2Policy is defined as follows:

class ding.policy.r2d2.R2D2Policy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]
Overview:

Policy class of R2D2, from the paper Recurrent Experience Replay in Distributed Reinforcement Learning. R2D2 proposes several tricks to improve upon DRQN, namely recurrent experience replay tricks and the burn-in mechanism for off-policy training.

Config:

| ID | Symbol | Type | Default Value | Description | Other(Shape) |
|----|--------|------|---------------|-------------|--------------|
| 1 | type | str | r2d2 | RL policy register name, refer to registry POLICY_REGISTRY | This arg is optional, a placeholder |
| 2 | cuda | bool | False | Whether to use cuda for network | This arg can be different from modes |
| 3 | on_policy | bool | False | Whether the RL algorithm is on-policy or off-policy | |
| 4 | priority | bool | False | Whether to use priority (PER) | Priority sample, update priority |
| 5 | priority_IS_weight | bool | False | Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True. | |
| 6 | discount_factor | float | 0.997, [0.95, 0.999] | Reward’s future discount factor, aka. gamma | May be 1 when sparse reward env |
| 7 | nstep | int | 3, [3, 5] | N-step reward discount sum for target q_value estimation | |
| 8 | burnin_step | int | 2 | The timestep of the burnin operation, which is designed to mitigate the RNN hidden state difference caused by off-policy training | |
| 9 | learn.update_per_collect | int | 1 | How many updates (iterations) to train after collector’s one collection. Only valid in serial training | This arg can vary from envs. Bigger val means more off-policy |
| 10 | learn.batch_size | int | 64 | The number of samples of an iteration | |
| 11 | learn.learning_rate | float | 0.001 | Gradient step length of an iteration. | |
| 12 | learn.value_rescale | bool | True | Whether to use value_rescale function for predicted value | |
| 13 | learn.target_update_freq | int | 100 | Frequency of target network update. | Hard(assign) update |
| 14 | learn.ignore_done | bool | False | Whether to ignore done for target value calculation. | Enable it for some fake termination env |
| 15 | collect.n_sample | int | [8, 128] | The number of training samples of a call of collector. | It varies from different envs |
| 16 | collect.unroll_len | int | 1 | Unroll length of an iteration | In RNN, unroll_len>1 |
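For orientation, the documented keys can be assembled into a nested config like the one below. This is a sketch only: it uses a plain dict rather than EasyDict, the nesting simply mirrors the dotted symbols in the table, and the n_sample and unroll_len values are illustrative picks (the table gives only a range for the former and requires unroll_len > 1 for RNNs). The actual default config shipped with DI-engine may be laid out differently.

```python
# Hypothetical R2D2 policy config built from the documented keys.
r2d2_config = dict(
    type='r2d2',
    cuda=False,
    on_policy=False,
    priority=False,
    priority_IS_weight=False,  # if True, priority must also be True
    discount_factor=0.997,
    nstep=3,
    burnin_step=2,
    learn=dict(
        update_per_collect=1,
        batch_size=64,
        learning_rate=0.001,
        value_rescale=True,
        target_update_freq=100,  # hard (assign) update
        ignore_done=False,
    ),
    collect=dict(
        n_sample=32,    # illustrative value inside the documented [8, 128] range
        unroll_len=40,  # illustrative; must be > 1 for an RNN policy
    ),
)

# Consistency checks implied by the table.
assert not r2d2_config['priority_IS_weight'] or r2d2_config['priority']
assert r2d2_config['collect']['unroll_len'] > r2d2_config['burnin_step']
```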

The network interface R2D2 used is defined as follows:

class ding.model.template.q_learning.DRQN(obs_shape: int | SequenceType, action_shape: int | SequenceType, encoder_hidden_size_list: SequenceType = [128, 128, 64], dueling: bool = True, head_hidden_size: int | None = None, head_layer_num: int = 1, lstm_type: str | None = 'normal', activation: Module | None = ReLU(), norm_type: str | None = None, res_link: bool = False)[source]
Overview:

The neural network structure and computation graph of the DRQN (DQN + RNN = DRQN) algorithm, which is the most common DQN variant for sequential data and partially observable environments. The DRQN is composed of three parts: encoder, head and rnn. The encoder extracts features from various observations, the rnn processes the sequential observations and other data, and the head computes the Q value of each action dimension.

Interfaces:

__init__, forward.

Note

Current DRQN supports two types of encoder: FCEncoder and ConvEncoder, two types of head: DiscreteHead and DuelingHead, three types of rnn: normal (LSTM with LayerNorm), pytorch and gru. You can customize your own encoder, rnn or head by inheriting this class.

forward(inputs: Dict, inference: bool = False, saved_state_timesteps: list | None = None) Dict[source]
Overview:

DRQN forward computation graph, input observation tensor to predict q_value.

Arguments:
  • inputs (Dict): The dict of input data, including observation and previous rnn state.

  • inference (bool): Whether to enable inference forward mode. If True, we unroll a single timestep transition; otherwise, we unroll the entire sequence of transitions.

  • saved_state_timesteps (Optional[list]): When inference is False, we unroll the sequence of transitions and use this list to indicate how to save and return the hidden state.

ArgumentsKeys:
  • obs (torch.Tensor): The raw observation tensor.

  • prev_state (list): The previous rnn state tensor, whose structure depends on lstm_type.

Returns:
  • outputs (Dict): The output of DRQN’s forward, including logit (q_value) and next state.

ReturnsKeys:
  • logit (torch.Tensor): Discrete Q-value output of each possible action dimension.

  • next_state (list): The next rnn state tensor, whose structure depends on lstm_type.

Shapes:
  • obs (torch.Tensor): \((B, N)\), where B is batch size and N is obs_shape

  • logit (torch.Tensor): \((B, M)\), where B is batch size and M is action_shape

Examples:
>>> import torch
>>> from ding.model.template.q_learning import DRQN
>>> # Init input's Keys:
>>> prev_state = [[torch.randn(1, 1, 64) for __ in range(2)] for _ in range(4)] # B=4
>>> obs = torch.randn(4, 64)
>>> model = DRQN(64, 64) # arguments: 'obs_shape' and 'action_shape'
>>> outputs = model({'obs': obs, 'prev_state': prev_state}, inference=True)
>>> # Check outputs's Keys
>>> assert isinstance(outputs, dict)
>>> assert outputs['logit'].shape == (4, 64)
>>> assert len(outputs['next_state']) == 4
>>> assert all([len(t) == 2 for t in outputs['next_state']])
>>> assert all([t[0].shape == (1, 1, 64) for t in outputs['next_state']])

Benchmark

Benchmark and comparison of R2D2 algorithm

| environment | best mean reward | evaluation results | config link | comparison |
|-------------|------------------|--------------------|-------------|------------|
| Pong (PongNoFrameskip-v4) | 20 | ../_images/pong_r2d2.png | config_link_p | |
| Qbert (QbertNoFrameskip-v4) | 6000 | ../_images/qbert_r2d2_cfg2.png | config_link_q | |
| SpaceInvaders (SpaceInvadersNoFrameskip-v4) | 1400 | ../_images/spaceinvaders_r2d2.png | config_link_s | |

References

  • Kapturowski S, Ostrovski G, Quan J, et al. Recurrent experience replay in distributed reinforcement learning[C]//International conference on learning representations. 2018.

  • Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Alex Graves, Ioannis Antonoglou, Daan Wierstra, Martin Riedmiller: “Playing Atari with Deep Reinforcement Learning”, 2013; arXiv:1312.5602.

  • Schaul, T., Quan, J., Antonoglou, I., & Silver, D. (2015). Prioritized experience replay. arXiv preprint arXiv:1511.05952.

  • Van Hasselt, H., Guez, A., & Silver, D. (2016, March). Deep reinforcement learning with double q-learning. In Proceedings of the AAAI conference on artificial intelligence (Vol. 30, No. 1).

  • Wang, Z., Schaul, T., Hessel, M., Hasselt, H., Lanctot, M., & Freitas, N. (2016, June). Dueling network architectures for deep reinforcement learning. In International conference on machine learning (pp. 1995-2003). PMLR.

  • Horgan D, Quan J, Budden D, et al. Distributed prioritized experience replay[J]. arXiv preprint arXiv:1803.00933, 2018.

Other Public Implementations

seed_rl

ray