Overview
Implementation of PG (Policy Gradient)
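As a rough reference for the code below, the implemented loss is the Monte-Carlo policy gradient (REINFORCE) objective with an entropy bonus, which can be written as

$$
L(\theta) = -\frac{1}{B} \sum_{i=1}^{B} \log \pi_{\theta}(a_i \mid s_i)\, \hat{R}_i \;-\; \lambda \, \frac{1}{B} \sum_{i=1}^{B} \mathcal{H}\big(\pi_{\theta}(\cdot \mid s_i)\big)
$$

where $\hat{R}_i$ is the sampled return, $\mathcal{H}$ denotes entropy, $s_i$ is the state that produced the $i$-th logit, and $\lambda$ is the entropy weight (entropy_weight below).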
from collections import namedtuple
import torch
pg_data = namedtuple('pg_data', ['logit', 'action', 'return_'])
pg_loss = namedtuple('pg_loss', ['policy_loss', 'entropy_loss'])
def pg_error(data: namedtuple) -> namedtuple:
Unpack data:
    logit, action, return_ = data
Prepare the policy distribution from the logit and get the log probability (see the equivalence note after this snippet).
    dist = torch.distributions.categorical.Categorical(logits=logit)
    log_prob = dist.log_prob(action)
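    # Note (illustrative, not part of the original code): for a Categorical
    # distribution, dist.log_prob(action) is equivalent to gathering the
    # log-softmax of the logits at the chosen action index:
    # torch.log_softmax(logit, dim=-1).gather(-1, action.unsqueeze(-1)).squeeze(-1)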
Policy loss: weight the log probability by the return and negate it, since the optimizer minimizes the loss while we want to maximize expected return.
    policy_loss = -(log_prob * return_).mean()
Entropy bonus: encourage exploration by preventing the policy distribution from collapsing too early.
P.S. the final loss is policy_loss - entropy_weight * entropy_loss (see the usage sketch after the function).
    entropy_loss = dist.entropy().mean()
Return the loss items as a named tuple.
    return pg_loss(policy_loss, entropy_loss)
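A minimal usage sketch of pg_error (the entropy weight value 0.01 is a hypothetical hyper-parameter chosen only for illustration): combine the two returned items into the final training loss as described above.
entropy_weight = 0.01  # hypothetical value, tune per task
logit = torch.randn(4, 32).requires_grad_(True)
action = torch.randint(0, 32, size=(4, ))
return_ = torch.randn(4)
output = pg_error(pg_data(logit, action, return_))
total_loss = output.policy_loss - entropy_weight * output.entropy_loss
total_loss.backward()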
Overview
Test function for PG, covering both the forward and backward passes.
def test_pg():
Set batch size B=4 and the number of discrete actions N=32.
    B, N = 4, 32
Generate logit, action, return_.
    logit = torch.randn(B, N).requires_grad_(True)
    action = torch.randint(0, N, size=(B, ))
    return_ = torch.randn(B) * 2
Compute PG error.
    data = pg_data(logit, action, return_)
    loss = pg_error(data)
Assert the loss items are scalars and that they are differentiable: logit has no gradient before backward and a proper gradient tensor after it.
    assert all([l.shape == tuple() for l in loss])
    assert logit.grad is None
    total_loss = sum(loss)
    total_loss.backward()
    assert isinstance(logit.grad, torch.Tensor)
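For example, the test can be invoked directly when the file is run as a script (a standard Python convention, not part of the original snippet):
if __name__ == "__main__":
    test_pg()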
If you have any questions or advice about this documentation, you can raise an issue on GitHub (https://github.com/opendilab/PPOxFamily) or email us (opendilab@pjlab.org.cn).