Overview
Implementation of the A2C (Advantage Actor-Critic) error function.
from collections import namedtuple
import torch
import torch.nn.functional as F
a2c_data = namedtuple('a2c_data', ['logit', 'action', 'value', 'adv', 'return_', 'weight'])
a2c_loss = namedtuple('a2c_loss', ['policy_loss', 'value_loss', 'entropy_loss'])
def a2c_error(data: namedtuple) -> namedtuple:
Unpack data: $$\langle \pi(a|s), a, V(s), A^{\pi}(s, a), G_t, w \rangle$$
logit, action, value, adv, return_, weight = data
Prepare the default weight (uniform weighting) when it is not specified.
if weight is None:
weight = torch.ones_like(value)
Prepare the policy distribution from the logit and get the log probability.
dist = torch.distributions.categorical.Categorical(logits=logit)
logp = dist.log_prob(action)
Policy loss: $$- \frac 1 N \sum_{n=1}^{N} \log \pi(a^n|s^n) A^{\pi}(s^n, a^n)$$
policy_loss = -(logp * adv * weight).mean()
Value loss: $$\frac 1 N \sum_{n=1}^{N} (G_t^n - V(s^n))^2$$
value_loss = (F.mse_loss(return_, value, reduction='none') * weight).mean()
Entropy bonus: $$- \frac 1 N \sum_{n=1}^{N} \sum_{a^n} \pi(a^n|s^n) \log \pi(a^n|s^n)$$
P.S.: the final loss is policy_loss + value_weight * value_loss - entropy_weight * entropy_loss.
entropy_loss = (dist.entropy() * weight).mean()
Return the concrete loss items.
return a2c_loss(policy_loss, value_loss, entropy_loss)
Overview
Test function of A2C, for both forward and backward operations.
def test_a2c():
Set batch size = 4 and the number of discrete actions = 32.
B, N = 4, 32
Generate logit, action, value, adv, return_.
logit = torch.randn(B, N).requires_grad_(True)
action = torch.randint(0, N, size=(B, ))
value = torch.randn(B).requires_grad_(True)
adv = torch.rand(B)
return_ = torch.randn(B) * 2
data = a2c_data(logit, action, value, adv, return_, None)
Compute A2C error.
loss = a2c_error(data)
Assert the loss is differentiable: gradients are None before backward and become tensors afterwards.
assert logit.grad is None
assert value.grad is None
total_loss = sum(loss)
total_loss.backward()
assert isinstance(logit.grad, torch.Tensor)
assert isinstance(value.grad, torch.Tensor)
If you have any questions or advice about this documentation, you can raise issues on GitHub (https://github.com/opendilab/PPOxFamily) or email us (opendilab@pjlab.org.cn).