Shortcuts

TD3

概述

Twin Delayed DDPG (TD3) 首次在2018年发表的论文 Addressing Function Approximation Error in Actor-Critic Methods 中被提出,它是一种考虑了策略和值更新中函数逼近误差之间相互作用的算法。 TD3 是一种基于 deep deterministic policy gradient (DDPG)无模型(model-free) 算法,属于 演员—评委(actor-critic) 类算法中的一员。此算法可以解决高估偏差,时间差分方法中的误差积累以及连续动作空间中对超参数的高敏感性的问题。具体来说,TD3通过引入以下三个关键技巧来解决这些问题:

  1. 截断双 Q 学习(Clipped Double-Q Learning):在计算Bellman误差损失函数中的目标时,TD3 学习两个 Q 函数而不是一个,并使用较小的 Q 值。

  2. 延迟的策略更新(Delayed Policy Updates): TD3更新策略(和目标网络)的频率低于 Q 函数的更新频率。在本文中,作者建议在对 Q 函数更新两次后进行一次策略更新。在我们的实现中,TD3 仅在对 critic 网络更新一定次数 \(d\) 后,才对策略和目标网络进行一次更新。我们通过配置参数 learn.actor_update_freq 来实现策略更新延迟。

  3. 目标策略平滑(Target Policy Smoothing):通过沿动作变化平滑 Q 值,TD3 为目标动作引入噪声,使策略更加难以利用 Q 函数的预测错误。

核心要点

  1. TD3 仅支持 连续动作空间 (例如: MuJoCo).

  2. TD3 是一种 异策略(off-policy) 算法.

  3. TD3 是一种 无模型(model-free)演员—评委(actor-critic) 的强化学习算法,它会分别优化策略网络和Q网络。

关键方程或关键框图

TD3 提出了一个截断双 Q 学习变体(Clipped Double-Q Learning),它利用了这样一个概念,即遭受高估偏差的值估计可以用作真实值估计的近似上限。结合下式计算 \(Q_{\theta_1}\) 的 target,当 \(Q_{\theta_2} \textless Q_{\theta_1}\) 时,我们认为 \(Q_{\theta_1}\) 高估了,并将其当作真实值估计的近似上限,取较小的 \(Q_{\theta_2}\) 计算 \(y_1\) 以减少过估计。

作为原始版本双 Q 学习的一种拓展,此扩展的动机是,如果目标和当前网络过于相似,例如在actor-critic框架中使用缓慢变化的策略,原始版本的双 Q 学习有时是无效的。

TD3表明,目标网络是深度 Q 学习方法中的一种常见方法,通过减少误差积累来减少目标的方差是至关重要的。

首先,为了解决动作价值估计和策略提升的耦合问题,TD3建议延迟策略更新,直到动作价值估计值尽可能小。因此,TD3只在固定数量次数的 critic 网络更新后再更新策略和目标网络。 我们通过配置参数 learn.actor_update_freq 来实现策略更新延迟。

其次,截断双 Q 学习(Clipped Double Q-learning)算法的目标更新如下:

\[y_{1}=r+\gamma \min _{i=1,2} Q_{\theta_{i}^{\prime}}\left(s^{\prime}, \pi_{\phi_{1}}\left(s^{\prime}\right)\right)\]

在实现中,我们可以通过使用单一的 actor 来优化 \(Q_{\theta_1}\) 以减少计算开销。由于 TD target 计算过程中使用了同样的策略,因此对于 \(Q_{\theta_2}\) 的优化目标, \(y_2= y_1\)

最后,确定性策略的一个问题是,由于以神经网络参数化的 Q 函数对 buffer 中动作的价值估计存在突然激增的尖峰(narrow peaks),这会导致策略网络过拟合到这些动作上。并且当更新 critic 网络时,使用确定性策略的学习目标极易受到函数近似误差引起的不准确性的影响,从而增加了目标的方差。 TD3 引入了一种用于深度价值学习的正则化策略,即目标策略平滑,它模仿了SARSA的学习更新。具体来说,TD3通过在目标策略中添加少量随机噪声并在多次计算以下数值后,取平均值来近似此期望:

\[\begin{split}\begin{array}{l} y=r+\gamma Q_{\theta^{\prime}}\left(s^{\prime}, \pi_{\phi^{\prime}}\left(s^{\prime}\right)+\epsilon\right) \\ \epsilon \sim \operatorname{clip}(\mathcal{N}(0, \sigma),-c, c) \end{array}\end{split}\]

我们通过配置 learn.noiselearn.noise_sigmalearn.noise_range 来实现目标策略平滑。

伪代码

\[ \begin{align}\begin{aligned}:nowrap:\\\begin{split}\begin{algorithm}[H] \caption{Twin Delayed DDPG} \label{alg1} \begin{algorithmic}[1] \STATE Input: initial policy parameters $\theta$, Q-function parameters $\phi_1$, $\phi_2$, empty replay buffer $\mathcal{D}$ \STATE Set target parameters equal to main parameters $\theta_{\text{targ}} \leftarrow \theta$, $\phi_{\text{targ},1} \leftarrow \phi_1$, $\phi_{\text{targ},2} \leftarrow \phi_2$ \REPEAT \STATE Observe state $s$ and select action $a = \text{clip}(\mu_{\theta}(s) + \epsilon, a_{Low}, a_{High})$, where $\epsilon \sim \mathcal{N}$ \STATE Execute $a$ in the environment \STATE Observe next state $s'$, reward $r$, and done signal $d$ to indicate whether $s'$ is terminal \STATE Store $(s,a,r,s',d)$ in replay buffer $\mathcal{D}$ \STATE If $s'$ is terminal, reset environment state. \IF{it's time to update} \FOR{$j$ in range(however many updates)} \STATE Randomly sample a batch of transitions, $B = \{ (s,a,r,s',d) \}$ from $\mathcal{D}$ \STATE Compute target actions \begin{equation*} a'(s') = \text{clip}\left(\mu_{\theta_{\text{targ}}}(s') + \text{clip}(\epsilon,-c,c), a_{Low}, a_{High}\right), \;\;\;\;\; \epsilon \sim \mathcal{N}(0, \sigma) \end{equation*} \STATE Compute targets \begin{equation*} y(r,s',d) = r + \gamma (1-d) \min_{i=1,2} Q_{\phi_{\text{targ},i}}(s', a'(s')) \end{equation*} \STATE Update Q-functions by one step of gradient descent using \begin{align*} & \nabla_{\phi_i} \frac{1}{|B|}\sum_{(s,a,r,s',d) \in B} \left( Q_{\phi_i}(s,a) - y(r,s',d) \right)^2 && \text{for } i=1,2 \end{align*} \IF{ $j \mod$ \texttt{policy\_delay} $ = 0$} \STATE Update policy by one step of gradient ascent using \begin{equation*} \nabla_{\theta} \frac{1}{|B|}\sum_{s \in B}Q_{\phi_1}(s, \mu_{\theta}(s)) \end{equation*} \STATE Update target networks with \begin{align*} \phi_{\text{targ},i} &\leftarrow \rho \phi_{\text{targ}, i} + (1-\rho) \phi_i && \text{for } i=1,2\\ \theta_{\text{targ}} &\leftarrow \rho \theta_{\text{targ}} + (1-\rho) \theta \end{align*} \ENDIF \ENDFOR \ENDIF \UNTIL{convergence} \end{algorithmic} \end{algorithm}\end{split}\end{aligned}\end{align} \]
../_images/TD3.png

扩展

TD3 可以与以下技术相结合使用:

  • 遵循随机策略的经验回放池初始采集

    在优化模型参数前,我们需要让经验回放池存有足够数目的遵循随机策略的 transition 数据,从而确保在算法初期模型不会对经验回放池数据过拟合。 DDPG/TD3 的 random-collect-size 默认设置为25000, SAC 为10000。 我们只是简单地遵循 SpinningUp 默认设置,并使用随机策略来收集初始化数据。 我们通过配置 random-collect-size 来控制初始经验回放池中的 transition 数目。

实现

默认配置定义如下:

class ding.policy.td3.TD3Policy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]
Overview:

Policy class of TD3 algorithm. Since DDPG and TD3 share many common things, we can easily derive this TD3 class from DDPG class by changing _actor_update_freq, _twin_critic and noise in model wrapper. Paper link: https://arxiv.org/pdf/1802.09477.pdf

Config:

ID

Symbol

Type

Default Value

Description

Other(Shape)

1

type

str

td3

RL policy register name, refer
to registry POLICY_REGISTRY
this arg is optional,
a placeholder

2

cuda

bool

False

Whether to use cuda for network

3

random_
collect_size

int

25000

Number of randomly collected
training samples in replay
buffer when training starts.
Default to 25000 for
DDPG/TD3, 10000 for
sac.

4

model.twin_
critic


bool

True

Whether to use two critic
networks or only one.


Default True for TD3,
Clipped Double
Q-learning method in
TD3 paper.

5

learn.learning
_rate_actor

float

1e-3

Learning rate for actor
network(aka. policy).


6

learn.learning
_rate_critic

float

1e-3

Learning rates for critic
network (aka. Q-network).


7

learn.actor_
update_freq


int

2

When critic network updates
once, how many times will actor
network update.

Default 2 for TD3, 1
for DDPG. Delayed
Policy Updates method
in TD3 paper.

8

learn.noise




bool

True

Whether to add noise on target
network’s action.



Default True for TD3,
False for DDPG.
Target Policy Smoo-
thing Regularization
in TD3 paper.

9

learn.noise_
range

dict

dict(min=-0.5,
max=0.5,)

Limit for range of target
policy smoothing noise,
aka. noise_clip.



10

learn.-
ignore_done

bool

False

Determine whether to ignore
done flag.
Use ignore_done only
in halfcheetah env.

11

learn.-
target_theta


float

0.005

Used for soft update of the
target network.


aka. Interpolation
factor in polyak aver
-aging for target
networks.

12

collect.-
noise_sigma



float

0.1

Used for add noise during co-
llection, through controlling
the sigma of distribution


Sample noise from dis
-tribution, Ornstein-
Uhlenbeck process in
DDPG paper, Gaussian
process in ours.
  1. 模型

    在这里,我们提供了 ContinuousQAC 模型作为 TD3 的默认模型的示例。

    class ding.model.ContinuousQAC(obs_shape: int | SequenceType, action_shape: int | SequenceType | EasyDict, action_space: str, twin_critic: bool = False, actor_head_hidden_size: int = 64, actor_head_layer_num: int = 1, critic_head_hidden_size: int = 64, critic_head_layer_num: int = 1, activation: Module | None = ReLU(), norm_type: str | None = None, encoder_hidden_size_list: SequenceType | None = None, share_encoder: bool | None = False)[source]
    Overview:

    The neural network and computation graph of algorithms related to Q-value Actor-Critic (QAC), such as DDPG/TD3/SAC. This model now supports continuous and hybrid action space. The ContinuousQAC is composed of four parts: actor_encoder, critic_encoder, actor_head and critic_head. Encoders are used to extract the feature from various observation. Heads are used to predict corresponding Q-value or action logit. In high-dimensional observation space like 2D image, we often use a shared encoder for both actor_encoder and critic_encoder. In low-dimensional observation space like 1D vector, we often use different encoders.

    Interfaces:

    __init__, forward, compute_actor, compute_critic

    compute_actor(obs: Tensor) Dict[str, Tensor | Dict[str, Tensor]][source]
    Overview:

    QAC forward computation graph for actor part, input observation tensor to predict action or action logit.

    Arguments:
    • x (torch.Tensor): The input observation tensor data.

    Returns:
    • outputs (Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]): Actor output dict varying from action_space: regression, reparameterization, hybrid.

    ReturnsKeys (regression):
    • action (torch.Tensor): Continuous action with same size as action_shape, usually in DDPG/TD3.

    ReturnsKeys (reparameterization):
    • logit (Dict[str, torch.Tensor]): The predictd reparameterization action logit, usually in SAC. It is a list containing two tensors: mu and sigma. The former is the mean of the gaussian distribution, the latter is the standard deviation of the gaussian distribution.

    ReturnsKeys (hybrid):
    • logit (torch.Tensor): The predicted discrete action type logit, it will be the same dimension as action_type_shape, i.e., all the possible discrete action types.

    • action_args (torch.Tensor): Continuous action arguments with same size as action_args_shape.

    Shapes:
    • obs (torch.Tensor): \((B, N0)\), B is batch size and N0 corresponds to obs_shape.

    • action (torch.Tensor): \((B, N1)\), B is batch size and N1 corresponds to action_shape.

    • logit.mu (torch.Tensor): \((B, N1)\), B is batch size and N1 corresponds to action_shape.

    • logit.sigma (torch.Tensor): \((B, N1)\), B is batch size.

    • logit (torch.Tensor): \((B, N2)\), B is batch size and N2 corresponds to action_shape.action_type_shape.

    • action_args (torch.Tensor): \((B, N3)\), B is batch size and N3 corresponds to action_shape.action_args_shape.

    Examples:
    >>> # Regression mode
    >>> model = ContinuousQAC(64, 6, 'regression')
    >>> obs = torch.randn(4, 64)
    >>> actor_outputs = model(obs,'compute_actor')
    >>> assert actor_outputs['action'].shape == torch.Size([4, 6])
    >>> # Reparameterization Mode
    >>> model = ContinuousQAC(64, 6, 'reparameterization')
    >>> obs = torch.randn(4, 64)
    >>> actor_outputs = model(obs,'compute_actor')
    >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 6])  # mu
    >>> actor_outputs['logit'][1].shape == torch.Size([4, 6]) # sigma
    
    compute_critic(inputs: Dict[str, Tensor]) Dict[str, Tensor][source]
    Overview:

    QAC forward computation graph for critic part, input observation and action tensor to predict Q-value.

    Arguments:
    • inputs (Dict[str, torch.Tensor]): The dict of input data, including obs and action tensor, also contains logit and action_args tensor in hybrid action_space.

    ArgumentsKeys:
    • obs: (torch.Tensor): Observation tensor data, now supports a batch of 1-dim vector data.

    • action (Union[torch.Tensor, Dict]): Continuous action with same size as action_shape.

    • logit (torch.Tensor): Discrete action logit, only in hybrid action_space.

    • action_args (torch.Tensor): Continuous action arguments, only in hybrid action_space.

    Returns:
    • outputs (Dict[str, torch.Tensor]): The output dict of QAC’s forward computation graph for critic, including q_value.

    ReturnKeys:
    • q_value (torch.Tensor): Q value tensor with same size as batch size.

    Shapes:
    • obs (torch.Tensor): \((B, N1)\), where B is batch size and N1 is obs_shape.

    • logit (torch.Tensor): \((B, N2)\), B is batch size and N2 corresponds to action_shape.action_type_shape.

    • action_args (torch.Tensor): \((B, N3)\), B is batch size and N3 corresponds to action_shape.action_args_shape.

    • action (torch.Tensor): \((B, N4)\), where B is batch size and N4 is action_shape.

    • q_value (torch.Tensor): \((B, )\), where B is batch size.

    Examples:
    >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
    >>> model = ContinuousQAC(obs_shape=(8, ),action_shape=1, action_space='regression')
    >>> assert model(inputs, mode='compute_critic')['q_value'].shape == (4, )  # q value
    
    forward(inputs: Tensor | Dict[str, Tensor], mode: str) Dict[str, Tensor][source]
    Overview:

    QAC forward computation graph, input observation tensor to predict Q-value or action logit. Different mode will forward with different network modules to get different outputs and save computation.

    Arguments:
    • inputs (Union[torch.Tensor, Dict[str, torch.Tensor]]): The input data for forward computation graph, for compute_actor, it is the observation tensor, for compute_critic, it is the dict data including obs and action tensor.

    • mode (str): The forward mode, all the modes are defined in the beginning of this class.

    Returns:
    • output (Dict[str, torch.Tensor]): The output dict of QAC forward computation graph, whose key-values vary in different forward modes.

    Examples (Actor):
    >>> # Regression mode
    >>> model = ContinuousQAC(64, 6, 'regression')
    >>> obs = torch.randn(4, 64)
    >>> actor_outputs = model(obs,'compute_actor')
    >>> assert actor_outputs['action'].shape == torch.Size([4, 6])
    >>> # Reparameterization Mode
    >>> model = ContinuousQAC(64, 6, 'reparameterization')
    >>> obs = torch.randn(4, 64)
    >>> actor_outputs = model(obs,'compute_actor')
    >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 6])  # mu
    >>> actor_outputs['logit'][1].shape == torch.Size([4, 6]) # sigma
    
    Examples (Critic):
    >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
    >>> model = ContinuousQAC(obs_shape=(8, ),action_shape=1, action_space='regression')
    >>> assert model(inputs, mode='compute_critic')['q_value'].shape == (4, )  # q value
    
  2. 训练 actor-critic 模型

    首先,我们在 _init_learn 中分别初始化 actor 和 critic 优化器。 设置两个独立的优化器可以保证我们在计算 actor 损失时只更新 actor 网络参数而不更新 critic 网络,反之亦然。

    # actor and critic optimizer
    self._optimizer_actor = Adam(
        self._model.actor.parameters(),
        lr=self._cfg.learn.learning_rate_actor,
        weight_decay=self._cfg.learn.weight_decay
    )
    self._optimizer_critic = Adam(
        self._model.critic.parameters(),
        lr=self._cfg.learn.learning_rate_critic,
        weight_decay=self._cfg.learn.weight_decay
    )
    
    _forward_learn 中,我们通过计算 critic 损失、更新 critic 网络、计算 actor 损失和更新 actor 网络来更新 actor-critic 策略。
    1. critic loss computation

      • 计算当前值和目标值

      # current q value
      q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
      q_value_dict = {}
      if self._twin_critic:
          q_value_dict['q_value'] = q_value[0].mean()
          q_value_dict['q_value_twin'] = q_value[1].mean()
      else:
          q_value_dict['q_value'] = q_value.mean()
      # target q value. SARSA: first predict next action, then calculate next q value
      with torch.no_grad():
          next_action = self._target_model.forward(next_obs, mode='compute_actor')['action']
          next_data = {'obs': next_obs, 'action': next_action}
          target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
      
      • Q 网络目标(Clipped Double-Q Learning)和损失计算

      if self._twin_critic:
          # TD3: two critic networks
          target_q_value = torch.min(target_q_value[0], target_q_value[1])  # find min one as target q value
          # network1
          td_data = v_1step_td_data(q_value[0], target_q_value, reward, data['done'], data['weight'])
          critic_loss, td_error_per_sample1 = v_1step_td_error(td_data, self._gamma)
          loss_dict['critic_loss'] = critic_loss
          # network2(twin network)
          td_data_twin = v_1step_td_data(q_value[1], target_q_value, reward, data['done'], data['weight'])
          critic_twin_loss, td_error_per_sample2 = v_1step_td_error(td_data_twin, self._gamma)
          loss_dict['critic_twin_loss'] = critic_twin_loss
          td_error_per_sample = (td_error_per_sample1 + td_error_per_sample2) / 2
      else:
          # DDPG: single critic network
          td_data = v_1step_td_data(q_value, target_q_value, reward, data['done'], data['weight'])
          critic_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
          loss_dict['critic_loss'] = critic_loss
      
    2. critic network update

      self._optimizer_critic.zero_grad()
      for k in loss_dict:
          if 'critic' in k:
              loss_dict[k].backward()
      self._optimizer_critic.step()
      
    3. actor loss computationactor network update 取决于策略更新延迟(delaying the policy updates)的程度。

      if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
          actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')
          actor_data['obs'] = data['obs']
          if self._twin_critic:
              actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'][0].mean()
          else:
              actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'].mean()
      
          loss_dict['actor_loss'] = actor_loss
          # actor update
          self._optimizer_actor.zero_grad()
          actor_loss.backward()
          self._optimizer_actor.step()
      
  3. 目标网络(Target Network)

    我们通过 _init_learn 中的 self._target_model 初始化来实现目标网络。 我们配置 learn.target_theta 来控制平均中的插值因子。

    # main and target models
    self._target_model = copy.deepcopy(self._model)
    self._target_model = model_wrap(
        self._target_model,
        wrapper_name='target',
        update_type='momentum',
        update_kwargs={'theta': self._cfg.learn.target_theta}
    )
    
  4. 目标策略平滑正则(Target Policy Smoothing Regularization)

    我们通过 _init_learn 中的目标模型初始化来实现目标策略平滑正则。 我们通过配置 learn.noiselearn.noise_sigmalearn.noise_range 来控制引入的噪声,通过对噪声进行截断使所选动作不会太过偏离原始动作。

    if self._cfg.learn.noise:
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='action_noise',
            noise_type='gauss',
            noise_kwargs={
                'mu': 0.0,
                'sigma': self._cfg.learn.noise_sigma
            },
            noise_range=self._cfg.learn.noise_range
        )
    

基准

environment

best mean reward

evaluation results

config link

comparison

HalfCheetah

(HalfCheetah-v3)

11148

../_images/halfcheetah_td3.png

config_link_p

Tianshou(10201) Spinning-up(9750) Sb3(9656)

Hopper

(Hopper-v2)

3720

../_images/hopper_td3.png

config_link_q

Tianshou(3472) Spinning-up(3982) sb3(3606 for Hopper-v3)

Walker2d

(Walker2d-v2)

4386

../_images/walker2d_td3.png

config_link_s

Tianshou(3982) Spinning-up(3472) sb3(4718 for Walker2d-v2)

P.S.:

  1. 上述结果是通过在五个不同的随机种子(0,1,2,3,4)上运行相同的配置获得的。

参考文献

Scott Fujimoto, Herke van Hoof, David Meger: “Addressing Function Approximation Error in Actor-Critic Methods”, 2018; [http://arxiv.org/abs/1802.09477 arXiv:1802.09477].

其他公开的实现