# Source code for ding.policy.dt
from typing import List, Dict, Any, Tuple, Optional
from collections import namedtuple
import torch.nn.functional as F
import torch
import numpy as np
from ding.torch_utils import to_device
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_decollate
from .base_policy import Policy
@POLICY_REGISTRY.register('dt')
class DTPolicy(Policy):
    """
    Overview:
        Policy class of the Decision Transformer algorithm. The code below has branches for three kinds of
        environments: image-like ("atari") environments, basic discrete environments and continuous-action
        environments.
        Paper link: https://arxiv.org/abs/2106.01345.
    """
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='dt',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether use priority(priority sample, IS weight, update priority)
        priority=False,
        # Observation and action shapes.
        obs_shape=4,
        action_shape=2,
        rtg_scale=1000,  # normalize returns to go
        max_eval_ep_len=1000,  # max len of one episode
        batch_size=64,  # training batch size
        wt_decay=1e-4,  # decay weight in optimizer
        warmup_steps=10000,  # steps for learning rate warmup
        context_len=20,  # length of transformer input
        learning_rate=1e-4,
        # NOTE: ``_init_learn``/``_init_eval`` additionally read ``rtg_target``, ``clip_grad_norm_p``,
        # ``multi_gpu``, ``evaluator_env_num`` and ``model.{context_len, state_dim, act_dim, continuous}``
        # from the config; they have no default here and must be supplied by the user config.
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return the default neural network model setting of this algorithm.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and the model's import_names.
        """
        return 'dt', ['ding.model.template.dt']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of the policy: cache the algorithm-specific arguments read from the
            config (rtg_scale, rtg_target, context_len, ...), then create the optimizer and the warm-up
            learning rate scheduler.
        """
        # rtg_scale: scale of `return to go`
        # rtg_target: max target of `return to go`
        # The goal is to normalize `return to go` to (0, 1), which favours convergence.
        # As a result, rtg_scale is usually set equal to rtg_target.
        self.rtg_scale = self._cfg.rtg_scale  # normalize returns to go
        self.rtg_target = self._cfg.rtg_target  # max target reward_to_go
        self.max_eval_ep_len = self._cfg.max_eval_ep_len  # max len of one episode

        lr = self._cfg.learning_rate  # learning rate
        wt_decay = self._cfg.wt_decay  # weight decay
        warmup_steps = self._cfg.warmup_steps  # warmup steps for lr scheduler

        self.clip_grad_norm_p = self._cfg.clip_grad_norm_p
        self.context_len = self._cfg.model.context_len  # K in decision transformer
        self.state_dim = self._cfg.model.state_dim
        self.act_dim = self._cfg.model.act_dim

        self._learn_model = self._model
        # The environment kind is inferred from the config: no ``state_mean`` key selects the "atari" branch.
        self._atari_env = 'state_mean' not in self._cfg
        self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg

        if self._atari_env:
            # In this branch the model builds its own optimizer.
            self._optimizer = self._learn_model.configure_optimizers(wt_decay, lr)
        else:
            self._optimizer = torch.optim.AdamW(self._learn_model.parameters(), lr=lr, weight_decay=wt_decay)

        # Linear warm-up: the lr factor grows from 1/warmup_steps to 1 and then stays at 1.
        self._scheduler = torch.optim.lr_scheduler.LambdaLR(
            self._optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)
        )

        self.max_env_score = -1.0

    def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode: run the model on one training batch, compute the action
            prediction loss and perform one optimizer and scheduler step.
        Arguments:
            - data (:obj:`List[torch.Tensor]`): The training batch, unpacked in order as
              timesteps, states, actions, returns_to_go, traj_mask.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Training information with keys ``cur_lr``, ``action_loss``
              and ``total_loss`` (the latter two hold the same value).
        """
        self._learn_model.train()

        timesteps, states, actions, returns_to_go, traj_mask = data

        # The shape of `returns_to_go` may differ with different dataset (B x T or B x T x 1),
        # and a 3-dim tensor is needed.
        if len(returns_to_go.shape) == 2:
            returns_to_go = returns_to_go.unsqueeze(-1)

        if self._basic_discrete_env:
            actions = actions.to(torch.long)
            actions = actions.squeeze(-1)
        action_target = torch.clone(actions).detach().to(self._device)

        if self._atari_env:
            _, action_preds, _ = self._learn_model.forward(
                timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, tar=1
            )
            action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1))
        else:
            _, action_preds, _ = self._learn_model.forward(
                timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go
            )
            # Only consider non padded elements.
            valid = traj_mask.view(-1, ) > 0
            action_preds = action_preds.view(-1, self.act_dim)[valid]
            if self._cfg.model.continuous:
                action_target = action_target.view(-1, self.act_dim)[valid]
                action_loss = F.mse_loss(action_preds, action_target)
            else:
                action_target = action_target.view(-1)[valid]
                action_loss = F.cross_entropy(action_preds, action_target)

        self._optimizer.zero_grad()
        action_loss.backward()
        if self._cfg.multi_gpu:
            self.sync_gradients(self._learn_model)
        torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), self.clip_grad_norm_p)
        self._optimizer.step()
        self._scheduler.step()

        loss_value = action_loss.detach().cpu().item()
        return {
            'cur_lr': self._optimizer.state_dict()['param_groups'][0]['lr'],
            'action_loss': loss_value,
            'total_loss': loss_value,
        }

    def _init_eval(self) -> None:
        """
        Overview:
            Initialize the eval mode of the policy. For Decision Transformer it caches the eval model and the
            algorithm-specific parameters (context_len, max_eval_ep_len, ...), then allocates the per-env
            history buffers by calling ``_reset_eval``.
        .. tip::
            Evaluating complete episodes requires keeping historical information for transformer inference.
            These buffers are created here and cleared in ``_reset_eval`` when necessary.
        """
        self._eval_model = self._model
        # init data
        self._device = torch.device(self._device)
        self.rtg_scale = self._cfg.rtg_scale  # normalize returns to go
        self.rtg_target = self._cfg.rtg_target  # max target reward_to_go
        self.state_dim = self._cfg.model.state_dim
        self.act_dim = self._cfg.model.act_dim
        self.eval_batch_size = self._cfg.evaluator_env_num
        self.max_eval_ep_len = self._cfg.max_eval_ep_len
        self.context_len = self._cfg.model.context_len  # K in decision transformer

        self._atari_env = 'state_mean' not in self._cfg
        self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg
        if not self._atari_env:
            # Observation normalization statistics, only used outside the atari branch.
            self.state_mean = torch.from_numpy(np.array(self._cfg.state_mean)).to(self._device)
            self.state_std = torch.from_numpy(np.array(self._cfg.state_std)).to(self._device)

        # Allocate t / timesteps / actions / states / running_rtg / rewards_to_go for every env.
        self._reset_eval()

    def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of eval mode. For every env in ``data`` it stores the new observation and
            the updated return-to-go into the history buffers, builds a context window of length
            ``context_len`` ending at the current step, runs the model and records the chosen action.
        Arguments:
            - data (:obj:`Dict[int, Any]`): Maps environment id to a dict holding at least ``obs`` and
              ``reward`` (``reward`` must support ``/`` and ``.to(device)``).
        Returns:
            - output (:obj:`Dict[int, Any]`): Maps each environment id of ``data`` to a dict with key ``action``.
        .. note::
            Decision Transformer will do different operations for different types of envs in evaluation.
        """
        # save and forward
        data_id = list(data.keys())
        batch, ctx = self.eval_batch_size, self.context_len

        self._eval_model.eval()
        with torch.no_grad():
            # Per-call model inputs; they share trailing shape, dtype and device with the history buffers.
            states = self.states.new_zeros((batch, ctx) + tuple(self.states.shape[2:]))
            actions = self.actions.new_zeros((batch, ctx) + tuple(self.actions.shape[2:]))
            rewards_to_go = self.rewards_to_go.new_zeros((batch, ctx, 1))
            if self._atari_env:
                timesteps = torch.zeros((batch, 1, 1), dtype=torch.long, device=self._device)
            else:
                timesteps = torch.zeros((batch, ctx), dtype=torch.long, device=self._device)

            for i in data_id:
                step = self.t[i]
                obs = data[i]['obs'].to(self._device)
                if self._atari_env:
                    self.states[i, step] = obs
                else:
                    self.states[i, step] = (obs - self.state_mean) / self.state_std
                self.running_rtg[i] = self.running_rtg[i] - (data[i]['reward'] / self.rtg_scale).to(self._device)
                self.rewards_to_go[i, step] = self.running_rtg[i]

                # Window of length ctx that always contains the current step. While step < ctx it is
                # [0, ctx) (future slots are still zero); afterwards it slides to [step - ctx + 1, step + 1).
                # The previous `step <= ctx` test kept [0, ctx) at step == ctx, which dropped the
                # observation that had just been written at index ctx.
                start = max(0, step - ctx + 1)
                stop = start + ctx
                if self._atari_env:
                    timesteps[i] = min(step, self._cfg.model.max_timestep)
                else:
                    timesteps[i] = self.timesteps[i, start:stop]
                states[i] = self.states[i, start:stop]
                actions[i] = self.actions[i, start:stop]
                rewards_to_go[i] = self.rewards_to_go[i, start:stop]

            if self._basic_discrete_env:
                actions = actions.squeeze(-1)
            _, act_preds, _ = self._eval_model.forward(timesteps, states, actions, rewards_to_go)
            del timesteps, states, actions, rewards_to_go

            # NOTE(review): the last window slot is read even while step < ctx - 1, i.e. when that slot is
            # still zero padding rather than the current step - confirm this matches the model's contract.
            logits = act_preds[:, -1, :]
            if not self._cfg.model.continuous:
                if self._atari_env:
                    probs = F.softmax(logits, dim=-1)
                    act = torch.zeros((batch, 1), dtype=torch.long, device=self._device)
                    for i in data_id:
                        act[i] = torch.multinomial(probs[i], num_samples=1)
                else:
                    act = torch.argmax(logits, axis=1).unsqueeze(1)
            else:
                act = logits
            for i in data_id:
                self.actions[i, self.t[i]] = act[i]  # TODO: self.actions[i] should be a queue when exceed max_t
                self.t[i] += 1

        if self._cuda:
            act = to_device(act, 'cpu')
        output = default_decollate({'action': act})
        # ``act`` is indexed by env id over the whole eval batch, so pick each env's own row; zipping
        # ``data_id`` with ``output`` would hand out the wrong rows whenever ``data_id`` is a subset.
        return {i: output[i] for i in data_id}

    def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
        """
        Overview:
            Reset the stateful variables of eval mode, i.e. the history kept for transformer inference.
            If ``data_id`` is None all buffers are re-allocated; otherwise only the rows of the given
            environments are cleared.
        Arguments:
            - data_id (:obj:`Optional[List[int]]`): The ids of the environments whose history should be reset.
        """
        initial_rtg = self.rtg_target if self._atari_env else self.rtg_target / self.rtg_scale

        if data_id is None:
            batch, horizon = self.eval_batch_size, self.max_eval_ep_len
            self.t = [0 for _ in range(batch)]
            self.timesteps = torch.arange(start=0, end=horizon, step=1).repeat(batch, 1).to(self._device)
            if not self._cfg.model.continuous:
                self.actions = torch.zeros((batch, horizon, 1), dtype=torch.long, device=self._device)
            else:
                self.actions = torch.zeros((batch, horizon, self.act_dim), dtype=torch.float32, device=self._device)
            if self._atari_env:
                state_shape = (batch, horizon) + tuple(self.state_dim)
            else:
                state_shape = (batch, horizon, self.state_dim)
            self.states = torch.zeros(state_shape, dtype=torch.float32, device=self._device)
            self.running_rtg = [initial_rtg for _ in range(batch)]
            self.rewards_to_go = torch.zeros((batch, horizon, 1), dtype=torch.float32, device=self._device)
        else:
            for i in data_id:
                self.t[i] = 0
                # Clear this env's rows in place instead of allocating fresh zero tensors.
                self.actions[i].zero_()
                self.states[i].zero_()
                self.rewards_to_go[i].zero_()
                self.running_rtg[i] = initial_rtg
                self.timesteps[i] = torch.arange(start=0, end=self.max_eval_ep_len, step=1).to(self._device)

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the keys of the ``_forward_learn`` return dict that should be logged.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        return ['cur_lr', 'action_loss']

    def _init_collect(self) -> None:
        """
        Overview:
            No-op: this policy defines no collect mode state.
        """
        pass

    def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
        """
        Overview:
            No-op: this policy defines no collect mode forward; returns None.
        """
        pass

    def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Overview:
            No-op: this policy does not build train samples from collected data; returns None.
        """
        pass

    def _process_transition(self, obs: Any, policy_output: Dict[str, Any], timestep: namedtuple) -> Dict[str, Any]:
        """
        Overview:
            No-op: this policy does not process collected transitions; returns None.
        """
        pass