Overview
The definition of the discrete action policy network used in PPO. It is mainly composed of two parts: an encoder and a head.
from typing import List
import torch
import torch.nn as nn
class DiscretePolicyNetwork(nn.Module):
def __init__(self, obs_shape: int, action_shape: int) -> None:
PyTorch requires custom networks to extend nn.Module, so our network subclasses it here.
super(DiscretePolicyNetwork, self).__init__()
Define the encoder module, which maps the raw state into an embedding vector.
Its architecture can vary with the state type, e.g. a Convolutional Neural Network (CNN) for image states and a Multilayer Perceptron (MLP) for vector states (a CNN encoder sketch is shown after this class).
Here we use a one-layer MLP for vector states, i.e.
$$y = \max(Wx + b, 0)$$
self.encoder = nn.Sequential(
nn.Linear(obs_shape, 32),
nn.ReLU(),
)
Define the discrete action logit output head, just a one-layer fully-connected (FC) layer, i.e.
$$y=Wx+b$$
self.head = nn.Linear(32, action_shape)
Overview
The computation graph of the discrete action policy network used in PPO:
x -> encoder -> head -> logit.
def forward(self, x: torch.Tensor) -> torch.Tensor:
Transform the original state into an embedding vector, i.e. $$(B, *) \rightarrow (B, N)$$
x = self.encoder(x)
Calculate the logit for each possible discrete action choice, i.e. $$(B, N) \rightarrow (B, A)$$
logit = self.head(x)
return logit
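As mentioned above, the encoder can be replaced by a CNN when the observation is an image. Below is a minimal sketch of such an encoder producing the same 32-dim embedding; the layer sizes and the assumed input shape (B, 3, 84, 84) are illustrative assumptions, not part of the original code.
# A CNN-encoder sketch for image observations (illustrative assumption).
# Input shape (B, 3, 84, 84); output embedding shape (B, 32).
image_encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=8, stride=4),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=4, stride=2),
    nn.ReLU(),
    nn.Flatten(),
    # LazyLinear infers the flattened feature size on the first forward pass.
    nn.LazyLinear(32),
    nn.ReLU(),
)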
Overview
The definition of the multi-discrete action policy network used in PPO, which uses multiple discrete heads.
class MultiDiscretePolicyNetwork(nn.Module):
def __init__(self, obs_shape: int, action_shape: List[int]) -> None:
PyTorch requires custom networks to extend nn.Module, so our network subclasses it here.
super(MultiDiscretePolicyNetwork, self).__init__()
Define the encoder module, which maps the raw state into an embedding vector.
Its architecture can vary with the state type, e.g. a Convolutional Neural Network (CNN) for image states and a Multilayer Perceptron (MLP) for vector states.
As above, we use a one-layer MLP for vector states, i.e.
$$y = \max(Wx + b, 0)$$
self.encoder = nn.Sequential(
nn.Linear(obs_shape, 32),
nn.ReLU(),
)
Define multiple discrete heads, one for each sub-action size.
self.head = nn.ModuleList()
for size in action_shape:
self.head.append(nn.Linear(32, size))
Overview
The computation graph of the multi-discrete action policy network used in PPO:
x -> encoder -> multiple heads -> multiple logits.
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
Transform the original state into an embedding vector, i.e. $$(B, *) \rightarrow (B, N)$$
x = self.encoder(x)
Calculate a logit tensor for each sub-action space, i.e. $$(B, N) \rightarrow [(B, A_1), ..., (B, A_K)]$$ where K is the number of sub-action spaces.
logit = [h(x) for h in self.head]
return logit
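In PPO training, the multiple logits are usually interpreted as independent categorical distributions over the sub-actions, so the joint log-probability is the sum of the per-head log-probabilities. Below is a minimal sketch under this independence assumption; the helper name multi_discrete_log_prob is illustrative, not part of the original code.
def multi_discrete_log_prob(logit: List[torch.Tensor], action: List[torch.Tensor]) -> torch.Tensor:
    # Each head defines an independent categorical distribution over one sub-action.
    log_probs = [
        torch.distributions.Categorical(logits=l).log_prob(a)
        for l, a in zip(logit, action)
    ]
    # Sum the per-head log-probabilities to get the joint log-probability, shape (B, ).
    return torch.stack(log_probs, dim=0).sum(dim=0)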
Overview
The function for sampling a discrete action; input shape = (B, action_shape), output shape = (B, ).
In this example, the distribution's shapes are:
batch_shape = (B, ), event_shape = (), sample_shape = ().
def sample_action(logit: torch.Tensor) -> torch.Tensor:
Transform the logit (the raw output of the policy network, i.e. the output of the last fully-connected layer) into a probability.
$$\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$$
prob = torch.softmax(logit, dim=-1)
Construct categorical distribution. The probability mass function is: $$f(x=i|\boldsymbol{p})=p_i$$
dist = torch.distributions.Categorical(probs=prob)
Sample one discrete action per sample (state input) and return it.
return dist.sample()
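Besides sampling, PPO also needs the log-probability of the chosen action (for the clipped surrogate objective) and the entropy of the distribution (for the entropy bonus). Below is a minimal sketch reusing the same Categorical distribution; the function name sample_action_with_log_prob is an illustrative assumption.
def sample_action_with_log_prob(logit: torch.Tensor):
    # Categorical also accepts raw logits directly, which avoids an explicit softmax.
    dist = torch.distributions.Categorical(logits=logit)
    action = dist.sample()            # shape (B, )
    log_prob = dist.log_prob(action)  # shape (B, )
    entropy = dist.entropy()          # shape (B, )
    return action, log_prob, entropy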
Overview
The function for testing discrete action sampling. Construct a naive policy and sample a batch of actions.
def test_sample_discrete_action():
Set batch_size = 4, obs_shape = 10, action_shape = 6.
B, obs_shape, action_shape = 4, 10, 6
Generate state data from a uniform distribution on [0, 1].
state = torch.rand(B, obs_shape)
Define the policy network with an encoder and a head.
policy_network = DiscretePolicyNetwork(obs_shape, action_shape)
The policy network's forward procedure: input the state and output the logit, i.e.
$$\text{logit} = \pi(a|s)$$
logit = policy_network(state)
assert logit.shape == (B, action_shape)
Sample an action according to the corresponding logit.
action = sample_action(logit)
assert action.shape == (B, )
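During evaluation, one often selects the greedy action instead of sampling from the distribution. A minimal sketch of this variant (the name greedy_action is an illustrative assumption, not part of the original code):
def greedy_action(logit: torch.Tensor) -> torch.Tensor:
    # Pick the action with the highest logit for each batch sample, shape (B, ).
    return logit.argmax(dim=-1)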
Overview
The function for testing multi-discrete action sampling. Construct a naive policy and sample a batch of multi-discrete actions.
def test_sample_multi_discrete_action():
Set batch_size = 4, obs_shape = 10, action_shape = [4, 5, 6].
B, obs_shape, action_shape = 4, 10, [4, 5, 6]
Generate state data from a uniform distribution on [0, 1].
state = torch.rand(B, obs_shape)
Define the policy network with an encoder and multiple heads.
policy_network = MultiDiscretePolicyNetwork(obs_shape, action_shape)
The policy network's forward procedure: input the state and output multiple logits, i.e.
$$\text{logit} = \pi(a|s)$$
logit = policy_network(state)
for i in range(len(logit)):
assert logit[i].shape == (B, action_shape[i])
Sample actions one by one according to the corresponding logits.
for i in range(len(logit)):
action_i = sample_action(logit[i])
assert action_i.shape == (B, )
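A minimal entry point (an assumption, not part of the original file) to run both tests directly:
if __name__ == "__main__":
    test_sample_discrete_action()
    test_sample_multi_discrete_action()
    print("All sampling tests passed.")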
If you have any questions or suggestions about this documentation, you can raise an issue on GitHub (https://github.com/opendilab/PPOxFamily) or email us (opendilab@pjlab.org.cn).