PyTorch implementation of Proximal Policy Optimization (PPO)
PPO is one of the most widely used algorithms in reinforcement learning, combining Actor-Critic methods with the ideas of Trust Region Policy Optimization (TRPO).
For the policy part, PPO updates the policy with a clipped surrogate objective together with a pessimistic (minimum) bound. For the value function part, PPO usually relies on classical temporal-difference techniques such as Generalized Advantage Estimation (GAE).
The final objective function is formulated as:
$$\min\left(\frac{\pi_{\theta}(a_{t}|s_{t})}{\pi_{\theta_k}(a_{t}|s_{t})}A^{\theta_k}(s_{t},a_{t}),\ \text{clip}\left(\frac{\pi_{\theta}(a_{t}|s_{t})}{\pi_{\theta_k}(a_{t}|s_{t})},\ 1-\epsilon,\ 1+\epsilon\right)A^{\theta_k}(s_{t},a_{t})\right)$$
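As a minimal PyTorch sketch of this objective (the names `logp_new`, `logp_old`, `adv` and the default `clip_ratio=0.2` are illustrative assumptions, not the exact interface of the repository):

```python
import torch

def ppo_clip_objective(logp_new: torch.Tensor, logp_old: torch.Tensor,
                       adv: torch.Tensor, clip_ratio: float = 0.2) -> torch.Tensor:
    # r(theta) = pi_theta(a|s) / pi_theta_k(a|s), computed in log space for stability
    ratio = torch.exp(logp_new - logp_old)
    surr1 = ratio * adv                                                # unclipped surrogate
    surr2 = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv   # clipped surrogate
    # pessimistic (minimum) bound, averaged over the batch
    return torch.min(surr1, surr2).mean()
```

Maximizing this objective (or minimizing its negative) improves the policy while keeping it close to the data-collecting policy.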
This document mainly includes:
- Implementation of the PPO error (loss) computation.
- Main function (test function).

Overview
Implementation of Proximal Policy Optimization (PPO) (see the related link) with entropy bonus, value_clip and dual_clip.

Unpack data: $\langle \pi_{new}(a|s), \pi_{old}(a|s), a, A^{\pi_{old}}(s, a), w \rangle$
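For illustration only, such a tuple could be carried as a namedtuple; the field names below are assumptions and may differ from the actual function signature:

```python
from collections import namedtuple
import torch

# hypothetical container mirroring <pi_new(a|s), pi_old(a|s), a, A^{pi_old}(s, a), w>
PPOData = namedtuple('PPOData', ['logit_new', 'logit_old', 'action', 'adv', 'weight'])

B, N = 4, 32  # batch size and number of discrete actions
data = PPOData(torch.randn(B, N), torch.randn(B, N),
               torch.randint(0, N, (B,)), torch.randn(B), None)
logit_new, logit_old, action, adv, weight = data  # unpack the five fields
```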

Prepare the sample-wise weight for the default case (no weight provided).
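Continuing the illustrative variables above, a natural default is a vector of ones so that every sample contributes equally:

```python
if weight is None:
    weight = torch.ones_like(adv)  # uniform sample-wise weight by default
```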

Prepare the policy distribution from the logit and get the log probability.
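Continuing the sketch, a categorical distribution built from the logits gives the per-sample log probabilities (assuming a discrete action space):

```python
dist_new = torch.distributions.Categorical(logits=logit_new)
dist_old = torch.distributions.Categorical(logits=logit_old)
logp_new = dist_new.log_prob(action)  # log pi_new(a|s), shape (B,)
logp_old = dist_old.log_prob(action)  # log pi_old(a|s), shape (B,)
```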

Entropy bonus: $-\frac{1}{N}\sum_{n=1}^{N}\sum_{a^n}\pi_{new}(a^n|s^n)\log\pi_{new}(a^n|s^n)$
P.S. the final loss is policy_loss - entropy_weight * entropy_loss.
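With the distribution above, the entropy term can be computed directly; `entropy_weight` is a hyperparameter (a value such as 0.01 is a common but assumed choice):

```python
entropy_loss = dist_new.entropy().mean()  # average policy entropy over the batch
# later combined as: total_loss = policy_loss - entropy_weight * entropy_loss
```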

Importance sampling weight: $r(\theta) = \frac{\pi_{new}(a|s)}{\pi_{old}(a|s)}$
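In code, the ratio is usually computed in log space and exponentiated, which is numerically more stable than dividing probabilities (a sketch continuing the variables above):

```python
ratio = torch.exp(logp_new - logp_old)  # r(theta) = pi_new(a|s) / pi_old(a|s)
```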

Original surrogate objective: $r(\theta)\, A^{\pi_{old}}(s, a)$

Clipped surrogate objective: $\text{clip}(r(\theta), 1-\epsilon, 1+\epsilon)\, A^{\pi_{old}}(s, a)$
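Continuing the sketch, the two surrogates differ only in whether the ratio is clamped to $[1-\epsilon, 1+\epsilon]$; `clip_ratio` below is an assumed hyperparameter value:

```python
clip_ratio = 0.2
surr1 = ratio * adv                                               # original surrogate
surr2 = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv  # clipped surrogate
```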

Dual clip, proposed in the related link, is only applied when the advantage is negative (adv < 0).
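A sketch of dual clip, assuming a constant `dual_clip` > 1 (3.0 here is an illustrative value): the extra lower bound of c * A only changes the result when the advantage is negative:

```python
dual_clip = 3.0                                # hypothetical dual-clip constant c > 1
clip1 = torch.min(surr1, surr2)                # standard PPO-clip term
clip2 = torch.max(clip1, dual_clip * adv)      # lower-bounded term, relevant for adv < 0
surrogate = torch.where(adv < 0, clip2, clip1)
```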

PPO-Clipped Loss: $\min\left(r(\theta)\, A^{\pi_{old}}(s, a),\ \text{clip}(r(\theta), 1-\epsilon, 1+\epsilon)\, A^{\pi_{old}}(s, a)\right)$
Multiply by the sample-wise weight and reduce with a mean over the batch dimension.
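Continuing the sketch, the per-sample surrogate is negated (optimizers minimize a loss), weighted, and averaged:

```python
policy_loss = (-surrogate * weight).mean()  # weighted mean over the batch dimension
```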

Add some visualization metrics to monitor optimization status.
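Commonly monitored quantities include an approximate KL divergence and the fraction of clipped samples; the exact metrics tracked by the repository may differ, so treat the following as an assumption:

```python
with torch.no_grad():
    approx_kl = (logp_old - logp_new).mean()                       # rough KL estimate
    clip_frac = ((ratio - 1.0).abs() > clip_ratio).float().mean()  # share of clipped samples
```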

Return final loss items and information.

Overview
Test function of PPO, covering both the forward and backward passes.

Set batch size = 4 and the action dimension (number of discrete actions) = 32.

Generate logit_new, logit_old, action, adv.

Compute PPO error.

Assert the loss is differentiable.
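A minimal test sketch under these assumptions, reusing the hypothetical `ppo_clip_objective` helper from the first code block:

```python
import torch

def test_ppo(batch_size: int = 4, action_dim: int = 32) -> None:
    # random stand-ins for network outputs and advantage estimates
    logit_new = torch.randn(batch_size, action_dim, requires_grad=True)
    logit_old = logit_new.detach() + 0.1 * torch.randn(batch_size, action_dim)
    action = torch.randint(0, action_dim, (batch_size,))
    adv = torch.randn(batch_size)

    logp_new = torch.distributions.Categorical(logits=logit_new).log_prob(action)
    logp_old = torch.distributions.Categorical(logits=logit_old).log_prob(action)
    loss = -ppo_clip_objective(logp_new, logp_old, adv)  # forward pass

    assert logit_new.grad is None
    loss.backward()                                      # backward pass
    assert isinstance(logit_new.grad, torch.Tensor)      # the loss is differentiable

test_ppo()
```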

If you have any questions or advice about this documentation, you can raise an issue on GitHub (https://github.com/opendilab/PPOxFamily) or email us (opendilab@pjlab.org.cn).