Overview
The definition of the continuous-action policy network used in PPO, which is mainly composed of three parts: encoder, mu and log_sigma.
from typing import Dict
import torch
import torch.nn as nn
from torch.distributions import Normal, Independent
class ContinuousPolicyNetwork(nn.Module):
def __init__(self, obs_shape: int, action_shape: int) -> None:
Extending nn.Module is a necessary PyTorch requirement for custom networks, so our network also subclasses it.
super(ContinuousPolicyNetwork, self).__init__()
Define the encoder module, which maps the raw state into an embedding vector.
It can differ for different kinds of state, e.g., a Convolutional Neural Network (CNN) for image states and a Multilayer Perceptron (MLP) for vector states.
Here we use a two-layer MLP for vector states (a hedged sketch of a CNN alternative follows the encoder definition below).
self.encoder = nn.Sequential(
nn.Linear(obs_shape, 16),
nn.ReLU(),
nn.Linear(16, 32),
nn.ReLU(),
)
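As a hedged aside (the image shape and layer sizes below are illustrative assumptions, not part of this network), an image state could use a small CNN encoder instead of the MLP above:
# e.g., for an assumed image state of shape (3, 84, 84):
# self.encoder = nn.Sequential(
#     nn.Conv2d(3, 16, kernel_size=8, stride=4),
#     nn.ReLU(),
#     nn.Conv2d(16, 32, kernel_size=4, stride=2),
#     nn.ReLU(),
#     nn.Flatten(),
# )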
Define the mu module, which is a fully-connected (FC) layer that outputs the argument mu (the mean) of the Gaussian distribution.
self.mu = nn.Linear(32, action_shape)
Define the log_sigma module, which is a learnable parameter independent of the state.
Here we parameterize it as log_sigma for the convenience of optimization and usage. You can also adjust its initial value according to your demands (see the note after the next code line).
self.log_sigma = nn.Parameter(torch.zeros(1, action_shape))
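As a hedged note on initialization (the alternative value below is an illustrative assumption, not part of this network): torch.zeros makes the initial sigma equal to exp(0) = 1 in every action dimension, and a smaller initial exploration noise could be obtained, for example, with:
# self.log_sigma = nn.Parameter(torch.full((1, action_shape), -0.5))  # initial sigma = exp(-0.5) ≈ 0.61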
Overview
The computation graph of the continuous-action policy network used in PPO.
x -> encoder -> mu head -> \mu .
log_sigma -> exp -> \sigma .
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
Transform the original state into an embedding vector with the encoder.
x = self.encoder(x)
Output the argument mu based on the embedding vector.
mu = self.mu(x)
Utilize the broadcast mechanism to expand log_sigma to the same shape as mu.
The zeros_like operation does not pass gradients, so no gradient flows back to mu through this addition (a standalone sketch of this broadcast follows the forward method below).
log_sigma = self.log_sigma + torch.zeros_like(mu)
Apply the exponential operation to produce the actual sigma, which guarantees sigma > 0 (e.g., log_sigma = 0 yields sigma = 1).
sigma = torch.exp(log_sigma)
return {'mu': mu, 'sigma': sigma}
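The broadcast-plus-zeros_like trick used above can be checked with a minimal standalone sketch; the shapes B = 2 and action_shape = 3 are assumptions for illustration, and it relies only on the torch import at the top of this file.
log_sigma = torch.zeros(1, 3, requires_grad=True)
mu = torch.randn(2, 3, requires_grad=True)
expanded = log_sigma + torch.zeros_like(mu)  # broadcast to shape (2, 3)
expanded.sum().backward()
assert log_sigma.grad.shape == (1, 3)  # gradient flows back to the learnable parameter
assert mu.grad is None                 # zeros_like(mu) does not pass gradient to mu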
Overview
The function for sampling a continuous action. The input is a dict with two keys, mu and sigma,
both of which have shape = (B, action_shape); the output shape is (B, action_shape).
In this example, the distribution shapes are:
batch_shape = (B, ), event_shape = (action_shape, ), sample_shape = ().
def sample_continuous_action(logit: Dict[str, torch.Tensor]) -> torch.Tensor:
Construct the Gaussian distribution, i.e., \mathcal{N}(\mu, \sigma^2).
Its probability density function is:
f(x) = \frac{1}{\sigma \sqrt{2\pi}} \exp\left( -\frac{(x - \mu)^2}{2\sigma^2} \right)
dist = Normal(logit['mu'], logit['sigma'])
Reinterpret the action_shape independent Gaussian distributions as a single multivariate Gaussian distribution with a diagonal covariance matrix.
Each action dimension remains independent of the others, but they are treated as one joint event (a shape check illustrating this follows the function below).
dist = Independent(dist, 1)
Sample one action of shape (action_shape, ) per sample (state input) and return it.
return dist.sample()
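As a minimal standalone sketch of these shape semantics (the values B = 4 and action_shape = 6 are assumptions for illustration, reusing the Normal and Independent imports at the top of this file):
mu, sigma = torch.zeros(4, 6), torch.ones(4, 6)
base = Normal(mu, sigma)          # batch_shape = (4, 6), event_shape = ()
joint = Independent(base, 1)      # batch_shape = (4, ),  event_shape = (6, )
action = joint.sample()           # sample_shape = (), so action.shape == (4, 6)
assert base.log_prob(action).shape == (4, 6)  # per-dimension log-probabilities
assert joint.log_prob(action).shape == (4,)   # one joint log-probability per sample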
Overview
The function for testing the sampling of continuous actions. Construct a standard continuous-action
policy and sample a group of actions.
def test_sample_continuous_action():
Set batch_size = 4, obs_shape = 10, action_shape = 6.
action_shape has a different meaning for discrete and continuous actions: for a discrete action it is the number of
possible choices, while for a continuous action it is the dimension of the action vector.
B, obs_shape, action_shape = 4, 10, 6
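As a hedged aside (not part of this test): with action_shape = 6, a discrete policy would choose among 6 candidate actions, typically via a Categorical distribution, whereas here the action itself is a 6-dimensional continuous vector.
# discrete_dist = torch.distributions.Categorical(logits=torch.zeros(B, action_shape))
# discrete_dist.sample().shape == (B,)   # one integer index per sample, not a 6-dim vector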
Generate state data from a uniform distribution on [0, 1].
state = torch.rand(B, obs_shape)
Define the continuous-action policy network (which outputs the distribution parameters, similar to the reparameterization approach) with encoder, mu and log_sigma.
policy_network = ContinuousPolicyNetwork(obs_shape, action_shape)
Policy network forward procedure: input the state and output the dict-type logit.
logit = policy_network(state)
assert isinstance(logit, dict)
assert logit['mu'].shape == (B, action_shape)
assert logit['sigma'].shape == (B, action_shape)
Sample actions according to the corresponding logit (i.e., mu and sigma).
action = sample_continuous_action(logit)
assert action.shape == (B, action_shape)
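To run this walkthrough as a script, a conventional entry point can be appended (a usage assumption, not shown in the excerpt above):
if __name__ == "__main__":
    test_sample_continuous_action()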
If you have any questions or advice about this documentation, you can raise issues on GitHub (https://github.com/opendilab/PPOxFamily) or email us (opendilab@pjlab.org.cn).