Shortcuts

Source code for ding.rl_utils.ppo

from collections import namedtuple
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch.distributions import Independent, Normal
from ding.hpc_rl import hpc_wrapper

# Input/output containers for the PPO loss functions in this module.
#
# ``logit_pretrained`` (the output of a frozen reference policy, only needed for the KL term)
# is the last field of every input container and defaults to ``None``, so callers that do not
# use a reference policy can omit it. All loss functions in this module treat ``None`` as
# "no KL term".
ppo_data = namedtuple(
    'ppo_data',
    ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight', 'logit_pretrained'],
    defaults=(None, ),
)
ppo_data_continuous = namedtuple(
    'ppo_data_continuous',
    [
        'mu_sigma_new', 'mu_sigma_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight',
        'logit_pretrained'
    ],
    defaults=(None, ),
)
ppo_policy_data = namedtuple(
    'ppo_policy_data',
    ['logit_new', 'logit_old', 'action', 'adv', 'weight', 'logit_pretrained'],
    defaults=(None, ),
)
ppo_policy_data_continuous = namedtuple(
    'ppo_policy_data_continuous',
    ['mu_sigma_new', 'mu_sigma_old', 'action', 'adv', 'weight', 'logit_pretrained'],
    defaults=(None, ),
)
# Inputs of the value loss.
ppo_value_data = namedtuple('ppo_value_data', ['value_new', 'value_old', 'return_', 'weight'])
# Loss outputs (full PPO loss and policy-only loss).
ppo_loss = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss', 'kl_div'])
ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss', 'kl_div'])
# Monitoring statistics returned next to the losses.
ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac'])


def calculate_kl_div(log_ratio: torch.Tensor, kl_type: str) -> torch.Tensor:
    """
    Overview:
        Monte-Carlo estimate of KL(q, p) = E_q[log(q/p)] from per-sample log-ratios, with q the current
        policy and p the pretrained policy. The three estimators follow John Schulman's blog post
        "Approximating KL Divergence" (http://joschu.net/blog/kl-approx.html).
    Arguments:
        - log_ratio (:obj:`torch.Tensor`): Per-sample log(q/p), i.e. ``logp_new - logp_pretrained``.
        - kl_type (:obj:`str`): Name of the estimator.
            - 'k1': ``log(q/p)``, unbiased but high-variance.
            - 'k2': ``1/2 * (log(p/q))^2``, biased but low-variance.
            - 'k3': ``(p/q - 1) - log(p/q)``, unbiased and low-variance.
    Returns:
        - kl_div (:obj:`torch.Tensor`): Mean of the selected per-sample estimator over all elements.
    Raises:
        - ValueError: If ``kl_type`` is not one of 'k1', 'k2', 'k3'.
    """
    # Reject unsupported estimator names up front.
    if kl_type not in ('k1', 'k2', 'k3'):
        raise ValueError(f"Unknown kl_type: {kl_type}")

    if kl_type == 'k1':
        per_sample = log_ratio
    elif kl_type == 'k2':
        per_sample = log_ratio.pow(2) / 2
    else:
        # 'k3': exp(-log_ratio) is p/q and +log_ratio is -log(p/q).
        per_sample = torch.exp(-log_ratio) - 1 + log_ratio
    return per_sample.mean()


def shape_fn_ppo(args, kwargs):
    r"""
    Overview:
        Return the shape of ``logit_new`` of the ppo input data, used as ``shape_fn`` of the hpc wrapper.
    Arguments:
        - args (:obj:`tuple`): Positional arguments of the wrapped call; if non-empty, ``args[0]`` is the data.
        - kwargs (:obj:`dict`): Keyword arguments of the wrapped call; ``kwargs['data']`` is used when \
            ``args`` is empty.
    Returns:
        - shape: ``data.logit_new.shape``, documented as [B, N].
    """
    # The data object is either the first positional argument or the ``data`` keyword.
    data = args[0] if len(args) > 0 else kwargs['data']
    return data.logit_new.shape
@hpc_wrapper(
    shape_fn=shape_fn_ppo,
    namedtuple_data=True,
    include_args=[0, 1, 2, 3],
    # NOTE(review): ``kl_type`` is not listed here; confirm whether the hpc path needs to receive it.
    include_kwargs=['data', 'clip_ratio', 'use_value_clip', 'dual_clip']
)
def ppo_error(
        data: namedtuple,
        clip_ratio: float = 0.2,
        use_value_clip: bool = True,
        dual_clip: Optional[float] = None,
        kl_type: str = 'k1'
) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        Implementation of Proximal Policy Optimization (arXiv:1707.06347) with value_clip and dual_clip \
        for discrete actions. The policy part is computed by ``ppo_policy_error`` and the value part by \
        ``ppo_value_error``.
    Arguments:
        - data (:obj:`namedtuple`): the ppo input data with fields shown in ``ppo_data``
        - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
        - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
        - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, must be greater \
            than 1.0; defaults to None, which disables dual clip
        - kl_type (:obj:`str`): which kl estimator to use ('k1', 'k2' or 'k3'), defaults to 'k1'. Only used \
            when ``data.logit_pretrained`` is not None.
    Returns:
        - ppo_loss (:obj:`namedtuple`): the ppo loss item with fields ``policy_loss``, ``value_loss``, \
            ``entropy_loss`` and ``kl_div``; the losses are differentiable 0-dim tensors, ``kl_div`` is the \
            Python int 0 when ``data.logit_pretrained`` is None
        - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
    Raises:
        - AssertionError: If ``dual_clip`` is neither None nor greater than 1.0.
    Shapes:
        - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
        - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`
        - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - logit_pretrained (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, N)`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> action_dim = 4
        >>> data = ppo_data(
        >>>     logit_new=torch.randn(3, action_dim),
        >>>     logit_old=torch.randn(3, action_dim),
        >>>     action=torch.randint(0, action_dim, (3,)),
        >>>     value_new=torch.randn(3),
        >>>     value_old=torch.randn(3),
        >>>     adv=torch.randn(3),
        >>>     return_=torch.randn(3),
        >>>     weight=torch.ones(3),
        >>>     logit_pretrained=None,
        >>> )
        >>> loss, info = ppo_error(data)

    .. note::

        adv is already normalized value (adv - adv.mean()) / (adv.std() + 1e-8), and there are many
        ways to calculate this mean and std, like among data buffer or train batch, so we don't couple
        this part into ppo_error, you can refer to our examples for different ways.
    """
    # Kept as an assert so that the raised exception type stays AssertionError.
    assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
        dual_clip
    )
    logit_new, logit_old, action, value_new, value_old, adv, return_, weight, logit_pretrained = data

    # Policy part: clipped surrogate loss, entropy, optional KL to the pretrained policy.
    policy_data = ppo_policy_data(logit_new, logit_old, action, adv, weight, logit_pretrained)
    policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip, kl_type=kl_type)

    # Value part: (optionally clipped) squared error against the return.
    value_data = ppo_value_data(value_new, value_old, return_, weight)
    value_loss = ppo_value_error(value_data, clip_ratio, use_value_clip)

    return ppo_loss(
        policy_output.policy_loss, value_loss, policy_output.entropy_loss, policy_output.kl_div
    ), policy_info
def ppo_policy_error(
        data: namedtuple,
        clip_ratio: float = 0.2,
        dual_clip: Optional[float] = None,
        entropy_bonus: bool = True,
        kl_type: str = 'k1'
) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        Get PPO policy loss (both for classical RL in control/video games and LLM/VLM RLHF).
    Arguments:
        - data (:obj:`namedtuple`): Ppo input data with fields shown in ``ppo_policy_data``.
        - clip_ratio (:obj:`float`): Clip value for ratio, defaults to 0.2.
        - dual_clip (:obj:`float`): A parameter c mentioned in arXiv:1912.09729 Equ. 5, expected to be \
            greater than 1.0; defaults to None, which disables dual clip. Only applied where ``adv < 0``.
        - entropy_bonus (:obj:`bool`): Whether to use entropy bonus, defaults to True. LLM RLHF usually does \
            not use it. When False, ``entropy_loss`` is a constant zero tensor.
        - kl_type (:obj:`str`): Which kl estimator to use ('k1', 'k2' or 'k3'), defaults to 'k1'. Only used \
            when ``data.logit_pretrained`` is not None.
    Returns:
        - ppo_policy_loss (:obj:`namedtuple`): The ppo policy loss item with fields ``policy_loss``, \
            ``entropy_loss`` and ``kl_div``. ``policy_loss`` is a differentiable 0-dim tensor; ``kl_div`` is \
            a 0-dim tensor when ``logit_pretrained`` is given and the Python int 0 otherwise.
        - ppo_info (:obj:`namedtuple`): The ppo optim information for monitoring (``approx_kl``, \
            ``clipfrac``), all of them are Python scalar.
    Shapes:
        - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
        - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - logit_pretrained (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, N)`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> action_dim = 4
        >>> data = ppo_policy_data(
        >>>     logit_new=torch.randn(3, action_dim),
        >>>     logit_old=torch.randn(3, action_dim),
        >>>     action=torch.randint(0, action_dim, (3,)),
        >>>     adv=torch.randn(3),
        >>>     weight=torch.ones(3),
        >>>     logit_pretrained=None,
        >>> )
        >>> loss, info = ppo_policy_error(data)

    .. note::
        This function can be extended from `B` to more parallel dimensions, like `(B, S)`, where `S` is the
        sequence length in LLM/VLM.

    .. note::
        For the action mask often used in LLM/VLM, users can set the `weight` to the action mask.
    """
    logit_new, logit_old, action, adv, weight, logit_pretrained = data
    if weight is None:
        weight = torch.ones_like(adv)
    dist_new = torch.distributions.categorical.Categorical(logits=logit_new)
    dist_old = torch.distributions.categorical.Categorical(logits=logit_old)
    logp_new = dist_new.log_prob(action)
    logp_old = dist_old.log_prob(action)

    if entropy_bonus:
        dist_new_entropy = dist_new.entropy()
        if dist_new_entropy.shape != weight.shape:  # for the multi-agent rl case
            # Reuse the entropy computed above instead of evaluating it a second time.
            dist_new_entropy = dist_new_entropy.mean(dim=1)
        entropy_loss = (dist_new_entropy * weight).mean()
    else:
        entropy_loss = torch.tensor(0.0)

    # policy_loss
    ratio = torch.exp(logp_new - logp_old)
    if ratio.shape != adv.shape:
        # Same reduction over dim 1 as for the entropy when there is an extra (e.g. agent) dimension.
        ratio = ratio.mean(dim=1)
    surr1 = ratio * adv
    surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
    if dual_clip is not None:
        clip1 = torch.min(surr1, surr2)
        clip2 = torch.max(clip1, dual_clip * adv)
        # only use dual_clip when adv < 0
        policy_loss = -(torch.where(adv < 0, clip2, clip1) * weight).mean()
    else:
        policy_loss = (-torch.min(surr1, surr2) * weight).mean()

    # Monitoring statistics, detached from the graph and converted to Python scalars.
    with torch.no_grad():
        approx_kl = (logp_old - logp_new).mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = torch.as_tensor(clipped).float().mean().item()

    if logit_pretrained is not None:
        dist_pretrained = torch.distributions.categorical.Categorical(logits=logit_pretrained)
        logp_pretrained = dist_pretrained.log_prob(action)
        log_ratio = logp_new - logp_pretrained
        # NOTE(review): the KL estimate is a plain mean and is not multiplied by ``weight``.
        kl_div = calculate_kl_div(log_ratio, kl_type)
    else:
        kl_div = 0

    return ppo_policy_loss(policy_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac)
def ppo_value_error(
        data: namedtuple,
        clip_ratio: float = 0.2,
        use_value_clip: bool = True,
) -> torch.Tensor:
    """
    Overview:
        Get PPO value loss
    Arguments:
        - data (:obj:`namedtuple`): ppo input data with fields shown in ``ppo_value_data``
        - clip_ratio (:obj:`float`): maximum absolute change of ``value_new`` relative to ``value_old`` that \
            is kept by the value clip, defaults to 0.2
        - use_value_clip (:obj:`bool`): whether use value clip
    Returns:
        - value_loss (:obj:`torch.FloatTensor`): the ppo value loss item, a differentiable 0-dim tensor
    Shapes:
        - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size
        - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
    Examples:
        >>> data = ppo_value_data(
        >>>     value_new=torch.randn(3),
        >>>     value_old=torch.randn(3),
        >>>     return_=torch.randn(3),
        >>>     weight=torch.ones(3),
        >>> )
        >>> loss = ppo_value_error(data)
    """
    value_new, value_old, return_, weight = data
    if weight is None:
        weight = torch.ones_like(value_old)

    # value_loss
    if use_value_clip:
        # Limit how far the new value may move away from the old one, then take the
        # element-wise larger of the clipped and unclipped squared errors.
        value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
        v1 = (return_ - value_new).pow(2)
        v2 = (return_ - value_clip).pow(2)
        value_loss = 0.5 * (torch.max(v1, v2) * weight).mean()
    else:
        value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()
    return value_loss
def ppo_error_continuous(
        data: namedtuple,
        clip_ratio: float = 0.2,
        use_value_clip: bool = True,
        dual_clip: Optional[float] = None,
        kl_type: str = 'k1'
) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        Implementation of Proximal Policy Optimization (arXiv:1707.06347) with value_clip and dual_clip \
        for continuous actions modelled by an independent Normal distribution per action dimension.
    Arguments:
        - data (:obj:`namedtuple`): the ppo input data with fields shown in ``ppo_data_continuous``
        - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
        - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
        - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, must be greater \
            than 1.0; defaults to None, which disables dual clip. Only applied where ``adv < 0``.
        - kl_type (:obj:`str`): which kl estimator to use ('k1', 'k2' or 'k3'), defaults to 'k1'. Only used \
            when ``data.logit_pretrained`` is not None.
    Returns:
        - ppo_loss (:obj:`namedtuple`): the ppo loss item with fields ``policy_loss``, ``value_loss``, \
            ``entropy_loss`` and ``kl_div``; the losses are differentiable 0-dim tensors, ``kl_div`` is the \
            Python int 0 when ``data.logit_pretrained`` is None
        - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
    Raises:
        - AssertionError: If ``dual_clip`` is neither None nor greater than 1.0.
    Shapes:
        - mu_sigma_new (:obj:`dict`): keys ``'mu'`` and ``'sigma'``, each :math:`(B, N)`, where B is batch \
            size and N is action dim
        - mu_sigma_old (:obj:`dict`): keys ``'mu'`` and ``'sigma'``, each :math:`(B, N)`; 1-dim tensors get \
            a trailing dimension appended
        - action (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`
        - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - logit_pretrained (:obj:`dict` or :obj:`None`): keys ``'mu'`` and ``'sigma'`` of the pretrained policy
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> action_dim = 4
        >>> data = ppo_data_continuous(
        >>>     mu_sigma_new=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
        >>>     mu_sigma_old=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
        >>>     action=torch.randn(3, action_dim),
        >>>     value_new=torch.randn(3),
        >>>     value_old=torch.randn(3),
        >>>     adv=torch.randn(3),
        >>>     return_=torch.randn(3),
        >>>     weight=torch.ones(3),
        >>>     logit_pretrained=None,
        >>> )
        >>> loss, info = ppo_error_continuous(data)

    .. note::

        adv is already normalized value (adv - adv.mean()) / (adv.std() + 1e-8), and there are many
        ways to calculate this mean and std, like among data buffer or train batch, so we don't couple
        this part into ppo_error, you can refer to our examples for different ways.
    """
    # Kept as an assert so that the raised exception type stays AssertionError.
    assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
        dual_clip
    )
    mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight, logit_pretrained = data
    if weight is None:
        weight = torch.ones_like(adv)

    # Independent(..., 1) treats the last dimension as the event, so log_prob/entropy are per sample.
    dist_new = Independent(Normal(mu_sigma_new['mu'], mu_sigma_new['sigma']), 1)
    if len(mu_sigma_old['mu'].shape) == 1:
        dist_old = Independent(Normal(mu_sigma_old['mu'].unsqueeze(-1), mu_sigma_old['sigma'].unsqueeze(-1)), 1)
    else:
        dist_old = Independent(Normal(mu_sigma_old['mu'], mu_sigma_old['sigma']), 1)
    logp_new = dist_new.log_prob(action)
    logp_old = dist_old.log_prob(action)
    entropy_loss = (dist_new.entropy() * weight).mean()

    # policy_loss
    ratio = torch.exp(logp_new - logp_old)
    surr1 = ratio * adv
    surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
    if dual_clip is not None:
        clip1 = torch.min(surr1, surr2)
        clip2 = torch.max(clip1, dual_clip * adv)
        # Only use dual_clip when adv < 0. Applying the bound to adv > 0 would replace the surrogate by
        # the constant dual_clip * adv whenever dual_clip > 1 + clip_ratio and remove its gradient.
        policy_loss = -(torch.where(adv < 0, clip2, clip1) * weight).mean()
    else:
        policy_loss = (-torch.min(surr1, surr2) * weight).mean()

    # Monitoring statistics, detached from the graph and converted to Python scalars.
    with torch.no_grad():
        approx_kl = (logp_old - logp_new).mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = torch.as_tensor(clipped).float().mean().item()

    # value_loss
    if use_value_clip:
        value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
        v1 = (return_ - value_new).pow(2)
        v2 = (return_ - value_clip).pow(2)
        value_loss = 0.5 * (torch.max(v1, v2) * weight).mean()
    else:
        value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()

    if logit_pretrained is not None:
        dist_pretrained = Independent(Normal(logit_pretrained['mu'], logit_pretrained['sigma']), 1)
        logp_pretrained = dist_pretrained.log_prob(action)
        log_ratio = logp_new - logp_pretrained
        kl_div = calculate_kl_div(log_ratio, kl_type)
    else:
        kl_div = 0

    return ppo_loss(policy_loss, value_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac)
def ppo_policy_error_continuous(
        data: namedtuple,
        clip_ratio: float = 0.2,
        dual_clip: Optional[float] = None,
        kl_type: str = 'k1'
) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        Policy part of Proximal Policy Optimization (arXiv:1707.06347) with dual_clip for continuous \
        actions modelled by an independent Normal distribution per action dimension.
    Arguments:
        - data (:obj:`namedtuple`): the ppo input data with fields shown in ``ppo_policy_data_continuous``
        - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
        - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, must be greater \
            than 1.0; defaults to None, which disables dual clip. Only applied where ``adv < 0``.
        - kl_type (:obj:`str`): which kl estimator to use ('k1', 'k2' or 'k3'), defaults to 'k1'. Only used \
            when ``data.logit_pretrained`` is not None.
    Returns:
        - ppo_policy_loss (:obj:`namedtuple`): the ppo policy loss item with fields ``policy_loss``, \
            ``entropy_loss`` and ``kl_div``; the losses are differentiable 0-dim tensors, ``kl_div`` is the \
            Python int 0 when ``data.logit_pretrained`` is None
        - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
    Raises:
        - AssertionError: If ``dual_clip`` is neither None nor greater than 1.0.
    Shapes:
        - mu_sigma_new (:obj:`dict`): keys ``'mu'`` and ``'sigma'``, each :math:`(B, N)`, where B is batch \
            size and N is action dim
        - mu_sigma_old (:obj:`dict`): keys ``'mu'`` and ``'sigma'``, each :math:`(B, N)`; 1-dim tensors get \
            a trailing dimension appended
        - action (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - logit_pretrained (:obj:`dict` or :obj:`None`): keys ``'mu'`` and ``'sigma'`` of the pretrained policy
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> action_dim = 4
        >>> data = ppo_policy_data_continuous(
        >>>     mu_sigma_new=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
        >>>     mu_sigma_old=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
        >>>     action=torch.randn(3, action_dim),
        >>>     adv=torch.randn(3),
        >>>     weight=torch.ones(3),
        >>>     logit_pretrained=None,
        >>> )
        >>> loss, info = ppo_policy_error_continuous(data)
    """
    # Kept as an assert so that the raised exception type stays AssertionError.
    assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
        dual_clip
    )
    mu_sigma_new, mu_sigma_old, action, adv, weight, logit_pretrained = data
    if weight is None:
        weight = torch.ones_like(adv)

    # Independent(..., 1) treats the last dimension as the event, so log_prob/entropy are per sample.
    dist_new = Independent(Normal(mu_sigma_new['mu'], mu_sigma_new['sigma']), 1)
    if len(mu_sigma_old['mu'].shape) == 1:
        dist_old = Independent(Normal(mu_sigma_old['mu'].unsqueeze(-1), mu_sigma_old['sigma'].unsqueeze(-1)), 1)
    else:
        dist_old = Independent(Normal(mu_sigma_old['mu'], mu_sigma_old['sigma']), 1)
    logp_new = dist_new.log_prob(action)
    logp_old = dist_old.log_prob(action)
    entropy_loss = (dist_new.entropy() * weight).mean()

    # policy_loss
    ratio = torch.exp(logp_new - logp_old)
    surr1 = ratio * adv
    surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
    if dual_clip is not None:
        clip1 = torch.min(surr1, surr2)
        clip2 = torch.max(clip1, dual_clip * adv)
        # Only use dual_clip when adv < 0. Applying the bound to adv > 0 would replace the surrogate by
        # the constant dual_clip * adv whenever dual_clip > 1 + clip_ratio and remove its gradient.
        policy_loss = -(torch.where(adv < 0, clip2, clip1) * weight).mean()
    else:
        policy_loss = (-torch.min(surr1, surr2) * weight).mean()

    # Monitoring statistics, detached from the graph and converted to Python scalars.
    with torch.no_grad():
        approx_kl = (logp_old - logp_new).mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = torch.as_tensor(clipped).float().mean().item()

    if logit_pretrained is not None:
        dist_pretrained = Independent(Normal(logit_pretrained['mu'], logit_pretrained['sigma']), 1)
        logp_pretrained = dist_pretrained.log_prob(action)
        log_ratio = logp_new - logp_pretrained
        kl_div = calculate_kl_div(log_ratio, kl_type)
    else:
        kl_div = 0

    return ppo_policy_loss(policy_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac)