
DT (DecisionTransformer)

Overview

Applying reinforcement learning to a specific decision-making domain requires transforming the original problem into a reasonable Markov Decision Process (MDP). However, the effectiveness of conventional reinforcement learning methods can drop significantly if the environment has unfriendly characteristics such as partial observability or non-stationarity. On the other hand, with the development of the data-driven paradigm in recent years, large datasets and pre-trained large models have achieved remarkable results in Computer Vision and Natural Language Processing, e.g. CLIP, DALL·E, and GPT-3, and sequence prediction is one of their core techniques. For decision intelligence, and Reinforcement Learning in particular, progress on large decision-making models has been slow, due to the lack of large datasets comparable to those in CV and NLP and of suitable pre-training tasks.

To promote the development of large decision-making models and enhance the practical value of related techniques, many researchers have shifted their focus to the subfield of Offline RL (Batch RL). Offline RL trains a policy solely from an offline dataset, without any interaction with the environment during training. This raises a natural question: can we borrow research results from CV and NLP, such as sequence prediction techniques?

Consequently, in 2021 a series of works represented by Decision Transformer [3] and Trajectory Transformer [1-2] emerged. These works reframe decision-making as sequence prediction, apply transformer architectures to Reinforcement Learning (RL) tasks, and establish connections with language models such as GPT-x and BERT. Unlike traditional RL, which computes value functions or policy gradients, DT directly outputs the action using a transformer whose self-attention masks subsequent tokens. Conditioned on a specified return-to-go together with the state and action history, it predicts the next action that should achieve the desired return. Impressively, DT matches and in some cases surpasses the performance of state-of-the-art (SOTA) model-free offline RL algorithms in environments such as Atari and D4RL (MuJoCo).
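As a concrete illustration of the return-to-go conditioning mentioned above, the following minimal sketch (plain PyTorch, independent of any particular library API) computes the return-to-go sequence of a trajectory from its rewards:

import torch

def compute_returns_to_go(rewards: torch.Tensor, gamma: float = 1.0) -> torch.Tensor:
    # R_t = r_t + gamma * r_{t+1} + ... ; the DT paper uses gamma = 1 (undiscounted).
    rtg = torch.zeros_like(rewards)
    running = torch.zeros(())
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg

# Example: a 5-step trajectory with reward 1 at every step yields [5, 4, 3, 2, 1].
print(compute_returns_to_go(torch.ones(5)))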

Quick Facts

  1. DT is an offline RL algorithm.

  2. DT supports both discrete and continuous action spaces.

  3. DT uses a transformer to predict actions, but with a modified (causally masked) self-attention architecture.

  4. The structure of the dataset used by DT is dictated by the characteristics of the algorithm, and these requirements must be met during both training and evaluation (see the sketch after this list).
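The exact dataset keys are defined by the configs linked in the benchmark below; as a hedged, generic illustration of the structure mentioned in item 4, a single DT training sample is a fixed-length slice of one trajectory containing states, actions, returns-to-go, timesteps and a validity mask:

import torch

def make_context_window(states, actions, returns_to_go, timesteps, start, context_len):
    # Slice one trajectory into a fixed-length sample; windows at the end of a
    # trajectory would additionally need padding, tracked by the mask below.
    end = start + context_len
    return {
        'states': states[start:end],                # (context_len, state_dim)
        'actions': actions[start:end],              # (context_len,) or (context_len, act_dim)
        'returns_to_go': returns_to_go[start:end],  # (context_len, 1)
        'timesteps': timesteps[start:end],          # (context_len,)
        'mask': torch.ones(end - start, dtype=torch.long),
    }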

Key Equations or Key Graphs

The architecture of DT is as follows:

../_images/DT.png

The diagram shows that when DT predicts the action a_t, it only attends to the current timestep's return-to-go r_t and state s_t, together with the previous tokens r_{t-n}, s_{t-n}, a_{t-n}; it does not attend to any subsequent timesteps. The causal transformer is the module that enforces this masking.
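The masking behaviour described above can be reproduced with a standard lower-triangular (causal) attention mask over the interleaved (r, s, a) token sequence. The sketch below is a simplified illustration of this idea, not the exact attention code of the ding implementation:

import torch

B, T, h_dim = 4, 6, 8
# Interleave return-to-go, state and action embeddings into a 3T-long sequence:
# (r_1, s_1, a_1, r_2, s_2, a_2, ...).
rtg_emb, state_emb, act_emb = (torch.randn(B, T, h_dim) for _ in range(3))
tokens = torch.stack([rtg_emb, state_emb, act_emb], dim=2).reshape(B, 3 * T, h_dim)

# Token i may only attend to tokens j <= i, so the prediction made at the s_t
# token (which is decoded into a_t) never sees a_t or anything that follows it.
causal_mask = torch.tril(torch.ones(3 * T, 3 * T, dtype=torch.bool))

attn = torch.nn.MultiheadAttention(h_dim, num_heads=2, batch_first=True)
out, _ = attn(tokens, tokens, tokens, attn_mask=~causal_mask)  # True entries are blocked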

Pseudo-code

../_images/DT_algo.png
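Since the pseudo-code is shown only as an image, here is a hedged Python sketch of one DT training step following the original paper: the model predicts the action at every position and is trained with a supervised loss on the dataset actions (the batch keys and loss choice here are assumptions, not the exact ding training loop):

import torch.nn.functional as F

def dt_train_step(model, optimizer, batch, discrete_action=True):
    timesteps, states, actions, rtg, mask = (
        batch['timesteps'], batch['states'], batch['actions'],
        batch['returns_to_go'], batch['mask'],
    )
    # Predict an action for every position of the context window.
    _, action_preds, _ = model(timesteps, states, actions, rtg)

    valid = mask > 0  # ignore padded positions
    if discrete_action:
        # Cross-entropy between predicted logits (B, T, act_dim) and dataset actions (B, T).
        loss = F.cross_entropy(action_preds[valid], actions[valid].long())
    else:
        # Mean-squared error for continuous actions (B, T, act_dim).
        loss = F.mse_loss(action_preds[valid], actions[valid])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()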

Implementations

The default config of DTPolicy is defined as follows:

class ding.policy.dt.DTPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]
Overview:

Policy class of Decision Transformer algorithm in discrete environments. Paper link: https://arxiv.org/abs/2106.01345.

The neural network interface used is as follows:

class ding.model.DecisionTransformer(state_dim: int | SequenceType, act_dim: int, n_blocks: int, h_dim: int, context_len: int, n_heads: int, drop_p: float, max_timestep: int = 4096, state_encoder: Module | None = None, continuous: bool = False)[source]
Overview:

The implementation of decision transformer.

Interfaces:

__init__, forward, configure_optimizers

forward(timesteps: Tensor, states: Tensor, actions: Tensor, returns_to_go: Tensor, tar: int | None = None) Tuple[Tensor, Tensor, Tensor][source]
Overview:

Forward computation graph of the decision transformer: input the sequence tensors and return the predicted state, action and return-to-go sequences of the same length.

Arguments:
  • timesteps (torch.Tensor): The timestep for input sequence.

  • states (torch.Tensor): The sequence of states.

  • actions (torch.Tensor): The sequence of actions.

  • returns_to_go (torch.Tensor): The sequence of return-to-go.

  • tar (Optional[int]): Whether to predict action, regardless of index.

Returns:
  • output (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): Output contains three tensors, they are correspondingly the predicted states, predicted actions and predicted return-to-go.

Examples:
>>> import torch
>>> B, T = 4, 6
>>> state_dim = 3
>>> act_dim = 2
>>> DT_model = DecisionTransformer(
...     state_dim=state_dim,
...     act_dim=act_dim,
...     n_blocks=3,
...     h_dim=8,
...     context_len=T,
...     n_heads=2,
...     drop_p=0.1,
... )
>>> timesteps = torch.randint(0, 100, [B, 3 * T - 1, 1], dtype=torch.long)
>>> states = torch.randn([B, T, state_dim])  # B x T x state_dim
>>> actions = torch.randint(0, act_dim, [B, T, 1])
>>> action_target = torch.randint(0, act_dim, [B, T, 1])
>>> returns_to_go = torch.tensor([1, 0.8, 0.6, 0.4, 0.2, 0.]).repeat([B, 1]).unsqueeze(-1).float()
>>> traj_mask = torch.ones([B, T], dtype=torch.long)  # B x T
>>> actions = actions.squeeze(-1)  # B x T
>>> state_preds, action_preds, return_preds = DT_model.forward(
...     timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go
... )
>>> assert state_preds.shape == torch.Size([B, T, state_dim])
>>> assert return_preds.shape == torch.Size([B, T, 1])
>>> assert action_preds.shape == torch.Size([B, T, act_dim])
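At evaluation time, DT is conditioned on a user-specified target return and rolled out autoregressively, decrementing the return-to-go by each received reward. The sketch below illustrates this under the assumptions of a discrete action space, greedy action selection and the classic 4-tuple gym step API; it is a simplified illustration, not the exact ding evaluator:

import torch

@torch.no_grad()
def evaluate_dt(model, env, target_return, context_len, state_dim, max_steps=1000):
    state = env.reset()
    # Histories; actions start with a placeholder that is overwritten once chosen.
    states, actions, rtgs, total_reward = [state], [0], [target_return], 0.0

    for t in range(max_steps):
        # Build batch-size-1 tensors for the most recent `context_len` steps.
        s = torch.tensor(states[-context_len:], dtype=torch.float32).reshape(1, -1, state_dim)
        a = torch.tensor(actions[-context_len:], dtype=torch.long).reshape(1, -1)
        r = torch.tensor(rtgs[-context_len:], dtype=torch.float32).reshape(1, -1, 1)
        ts = torch.arange(max(0, t - context_len + 1), t + 1).reshape(1, -1)

        _, action_preds, _ = model(ts, s, a, r)
        action = action_preds[0, -1].argmax().item()  # greedy discrete action
        actions[-1] = action                          # replace the placeholder

        state, reward, done, _ = env.step(action)
        total_reward += reward
        if done:
            break

        # Append the next state, a new placeholder action, and the decremented return-to-go.
        states.append(state)
        actions.append(0)
        rtgs.append(rtgs[-1] - reward)

    return total_reward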

Benchmark

Benchmark and comparison of DT algorithm

environment | best mean reward (normalized) | evaluation results | config link | comparison
Hopper (Hopper-medium) | 0.753 +- 0.035 | ../_images/hopper_medium_dt.png | link_2_Hopper-medium | DT paper
Hopper (Hopper-expert) | 1.170 +- 0.003 | ../_images/hopper_expert_dt.png | link_2_Hopper-expert | DT paper
Hopper (Hopper-medium-replay) | 0.651 +- 0.096 | ../_images/hopper_medium_replay_dt.png | link_2_Hopper-medium-replay | DT paper
Hopper (Hopper-medium-expert) | 1.150 +- 0.016 | ../_images/hopper_medium_expert_dt.png | link_2_Hopper-medium-expert | DT paper
Walker2d (Walker2d-medium) | 0.829 +- 0.020 | ../_images/walker2d_medium_dt.png | link_2_Walker2d-medium | DT paper
Walker2d (Walker2d-expert) | 1.093 +- 0.004 | ../_images/walker2d_expert_dt.png | link_2_Walker2d-expert | DT paper
Walker2d (Walker2d-medium-replay) | 0.603 +- 0.014 | ../_images/walker2d_medium_replay_dt.png | link_2_Walker2d-medium-replay | DT paper
Walker2d (Walker2d-medium-expert) | 1.091 +- 0.002 | ../_images/walker2d_medium_expert_dt.png | link_2_Walker2d-medium-expert | DT paper
HalfCheetah (HalfCheetah-medium) | 0.433 +- 0.0007 | ../_images/halfcheetah_medium_dt.png | link_2_HalfCheetah-medium | DT paper
HalfCheetah (HalfCheetah-expert) | 0.662 +- 0.057 | ../_images/halfcheetah_expert_dt.png | link_2_HalfCheetah-expert | DT paper
HalfCheetah (HalfCheetah-medium-replay) | 0.401 +- 0.007 | ../_images/halfcheetah_medium_replay_dt.png | link_2_HalfCheetah-medium-replay | DT paper
HalfCheetah (HalfCheetah-medium-expert) | 0.517 +- 0.043 | ../_images/halfcheetah_medium_expert_dt.png | link_2_HalfCheetah-medium-expert | DT paper
Pong (PongNoFrameskip-v4) | 0.956 +- 0.020 | ../_images/pong_dt.png | link_2_Pong | DT paper
Breakout (BreakoutNoFrameskip-v4) | 0.976 +- 0.190 | ../_images/breakout_dt.png | link_2_Breakout | DT paper

P.S.: The above results were obtained by running the same configuration with three different random seeds (123, 213, 321).

Reference