Overview
The definition of the Pop-Art layer, i.e., a linear layer with Pop-Art normalization, which should be used as the last layer of a network.
For more information, you can refer to the paper: Related Link
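In brief (a compact summary of the mechanism implemented below; the notation here is ours): the layer keeps a running shift $\mu$ and scale $\sigma$ of the target values, predicts in the normalized space, and rescales its own weight and bias whenever $(\mu, \sigma)$ change, so that the unnormalized prediction is preserved:

$\hat{y} = W x + b \ \text{(normalized output)}, \qquad y = \sigma \odot \hat{y} + \mu \ \text{(unnormalized output)}$

$W' = \frac{\sigma}{\sigma'} W, \qquad b' = \frac{\sigma b + \mu - \mu'}{\sigma'} \quad \Longrightarrow \quad \sigma' (W' x + b') + \mu' = \sigma (W x + b) + \mu$

where primes denote the updated quantities.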
import pickle
import math
import torch
import torch.nn as nn
import treetensor.torch as ttorch
from torch.optim import AdamW
from torch.utils.data import DataLoader
class PopArt(nn.Module):
PopArt subclasses nn.Module, as PyTorch requires: any custom layer or network should extend this class.
def __init__(self, input_features: int, output_features: int, beta: float = 0.5) -> None:
    # NOTE: the default value of beta is assumed here; some default must exist, since PopArt(32, 1) is used below.
    super(PopArt, self).__init__()
Define the soft-update coefficient beta, which controls the exponential moving average of the normalization parameters.
self.beta = beta
Define the input and output feature dimensions of the linear layer.
self.input_features = input_features
self.output_features = output_features
Initialize the linear layer parameters, weight and bias.
self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
self.bias = nn.Parameter(torch.Tensor(output_features))
Register buffers for the normalization parameters, which should not be treated as learnable model parameters. Tensors registered as buffers are excluded from gradient propagation but are still saved in the state_dict (a short check after __init__ below illustrates this).
These buffers will later store the shift (mu) and scale (sigma) of the target values, together with their second moment (v).
self.register_buffer('mu', torch.zeros(output_features, requires_grad=False))
self.register_buffer('sigma', torch.ones(output_features, requires_grad=False))
self.register_buffer('v', torch.ones(output_features, requires_grad=False))
Reset the learned parameters, i.e., weight and bias.
self.reset_parameters()
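A quick, illustrative check of the buffer behaviour described above (not part of the layer itself): the buffers appear in state_dict() but not in parameters(), so they are saved and loaded with the model yet never touched by the optimizer.

layer = PopArt(32, 1)
print([name for name, _ in layer.named_parameters()])  # ['weight', 'bias']
print(sorted(layer.state_dict().keys()))  # ['bias', 'mu', 'sigma', 'v', 'weight']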
Overview
Parameter initialization of the linear layer (i.e., weight and bias).
def reset_parameters(self) -> None:
Kaiming initialization keeps the variance of activations roughly constant across layers, which helps avoid the vanishing and exploding gradient problems in deep models.
Specifically, Kaiming uniform initialization samples the weights as:
$W \sim \mathcal{U}(-\text{bound}, \text{bound}), \qquad \text{bound} = \sqrt{\dfrac{6}{(1 + a^2) \cdot \text{fan\_in}}}$
where $a$ is the negative slope of the rectifier used after this layer (0 for ReLU by default), and fan_in is the number of input dimensions.
For more information about Kaiming initialization, you can refer to the paper:
Related Link
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
    bound = 1 / math.sqrt(fan_in)
    nn.init.uniform_(self.bias, -bound, bound)
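As a small illustrative check (not part of the layer): with $a = \sqrt{5}$, the Kaiming bound $\sqrt{6 / ((1 + a^2)\,\text{fan\_in})}$ reduces to $1 / \sqrt{\text{fan\_in}}$, i.e., the same bound used for the bias above; this mirrors the default initialization of nn.Linear.

fan_in = 32  # illustrative value
a = math.sqrt(5)
gain = math.sqrt(2.0 / (1.0 + a ** 2))  # leaky_relu gain used by kaiming_uniform_
weight_bound = gain * math.sqrt(3.0 / fan_in)  # equals sqrt(6 / ((1 + a^2) * fan_in))
assert abs(weight_bound - 1.0 / math.sqrt(fan_in)) < 1e-12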
Overview
The computation graph of the linear layer with the Pop-Art mechanism, which returns both the unnormalized output and the normalized output of the layer.
def forward(self, x: torch.Tensor) -> ttorch.Tensor:
Execute the linear layer computation; note that here the bias is added explicitly via unsqueeze and expand (broadcasting).
normalized_output = x.mm(self.weight.t())
normalized_output += self.bias.unsqueeze(0).expand_as(normalized_output)
Unnormalize the output for more convenient downstream usage. This is done under torch.no_grad, so gradients only flow through the normalized output.
with torch.no_grad():
    output = normalized_output * self.sigma + self.mu
return ttorch.as_tensor({'output': output, 'normalized_output': normalized_output})
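A minimal usage sketch (the shapes are illustrative): the returned treetensor exposes both fields as attributes, and only normalized_output carries gradients because the unnormalization above runs under torch.no_grad.

layer = PopArt(32, 1)
feature = torch.randn(4, 32)
out = layer(feature)
print(out.output.shape, out.normalized_output.shape)  # torch.Size([4, 1]) torch.Size([4, 1])
print(out.output.requires_grad, out.normalized_output.requires_grad)  # False True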
Overview
The parameter update defined in Pop-Art, which refreshes the normalization parameters (mu, sigma, v) from a batch of target values and rescales weight and bias so that the unnormalized outputs are preserved.
def update_parameters(self, value: torch.Tensor) -> ttorch.Tensor:
Move the normalization parameters to the same device as the target value tensor.
self.mu = self.mu.to(value.device)
self.sigma = self.sigma.to(value.device)
self.v = self.v.to(value.device)
Store the old normalization parameters for later usage.
old_mu = self.mu
old_std = self.sigma
Calculate the first and second moments (mean and mean of squares) of the target value:
$\mu_{\text{batch}} = \frac{1}{N}\sum_{i=1}^{N} y_i, \qquad \nu_{\text{batch}} = \frac{1}{N}\sum_{i=1}^{N} y_i^2$

batch_mean = torch.mean(value, 0)
batch_v = torch.mean(torch.pow(value, 2), 0)
Replace NaN values with the old values for more stable training.
batch_mean[torch.isnan(batch_mean)] = self.mu[torch.isnan(batch_mean)]
batch_v[torch.isnan(batch_v)] = self.v[torch.isnan(batch_v)]
Soft-update the normalization parameters according to:
$\mu \leftarrow (1 - \beta)\,\mu + \beta\,\mu_{\text{batch}}, \qquad \nu \leftarrow (1 - \beta)\,\nu + \beta\,\nu_{\text{batch}}$

batch_mean = (1 - self.beta) * self.mu + self.beta * batch_mean
batch_v = (1 - self.beta) * self.v + self.beta * batch_v
Calculate the standard deviation from the updated first and second moments:
$\sigma = \sqrt{\nu - \mu^2}$

batch_std = torch.sqrt(batch_v - (batch_mean ** 2))
Clamp the standard deviation into a reasonable range to reject outlier values and avoid numerical issues.
batch_std = torch.clamp(batch_std, min=1e-4, max=1e+6)
Replace NaN values with the old values.

batch_std[torch.isnan(batch_std)] = self.sigma[torch.isnan(batch_std)]
Update the normalization parameters.
self.mu = batch_mean
self.v = batch_v
self.sigma = batch_std
Update weight and bias using the old and new statistics so that the unnormalized outputs are preserved:
$W \leftarrow \dfrac{\sigma_{\text{old}}}{\sigma_{\text{new}}}\,W, \qquad b \leftarrow \dfrac{\sigma_{\text{old}}\,b + \mu_{\text{old}} - \mu_{\text{new}}}{\sigma_{\text{new}}}$

self.weight.data = (self.weight.t() * old_std / self.sigma).t()
self.bias.data = (old_std * self.bias + old_mu - self.mu) / self.sigma
Return treetensor-type statistics.
return ttorch.as_tensor({'new_mean': batch_mean, 'new_std': batch_std})
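A sanity-check sketch (illustrative only): after update_parameters the normalization parameters and the linear weights change, but the unnormalized output for the same input stays numerically unchanged.

layer = PopArt(32, 1)
x = torch.randn(8, 32)
targets = torch.randn(8, 1) * 100 + 50  # target values on an arbitrary, large scale
before = layer(x).output
layer.update_parameters(targets)
after = layer(x).output
print(torch.allclose(before, after, atol=1e-4))  # expected: True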
Overview
An MLP network with a Pop-Art layer as the final layer.
Input: observations and actions.
Output: estimated Q value.
Architecture: cat(obs, actions) -> encoder -> popart.
class MLP(nn.Module):

    def __init__(self, obs_shape: int, action_shape: int) -> None:
        # Initialize the base nn.Module before registering submodules.
        super(MLP, self).__init__()
Define the encoder and the Pop-Art layer.
Here the encoder is a two-layer MLP with ReLU activations, and the final layer is the Pop-Art layer.
self.encoder = nn.Sequential(
    nn.Linear(obs_shape + action_shape, 16),
    nn.ReLU(),
    nn.Linear(16, 32),
    nn.ReLU(),
)
self.popart = PopArt(32, 1)
Overview
Forward computation of the MLP network with the Pop-Art layer.
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> ttorch.Tensor:
The encoder first concatenates the observation vectors and actions, then maps the input to an embedding vector.
x = torch.cat((obs, actions), 1)
x = self.encoder(x)
The Pop-Art layer maps the embedding vector to the normalized and unnormalized value predictions.
x = self.popart(x)
return x
Overview
Example training function that uses the MLP network with a Pop-Art layer for Q-value approximation with fixed targets (regression to precomputed returns).
def train(obs_shape: int, action_shape: int, NUM_EPOCH: int, train_data):
Define the MLP network, the optimizer, and the loss function.
model = MLP(obs_shape, action_shape)
optimizer = AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
MSEloss = nn.MSELoss()
Load the preprocessed data generated by a trained agent on the LunarLander environment.
Each sample in the dataset should be a dict with the following format:
{'observations': ..., 'actions': ..., 'returns': ...}
where returns is the discounted return from the current state.
train_data = DataLoader(train_data, batch_size=64, shuffle=True)
For loop 1: train the MLP network for NUM_EPOCH epochs.
running_loss = 0.0
for epoch in range(NUM_EPOCH):
For loop 2: Inside each epoch, split the entire dataset into mini-batches, then train on each mini-batch.
for idx, data in enumerate(train_data):
Compute the unnormalized output and the normalized output, and read the current normalization parameters.
output = model(data['observations'], data['actions'])
mu = model.popart.mu
sigma = model.popart.sigma
Normalize the target return to align with the normalized Q value.
with torch.no_grad():
    normalized_return = (data['returns'] - mu) / sigma
The loss is calculated as the MSE between the normalized Q value and the normalized target return.
loss = MSEloss(output.normalized_output, normalized_return)
Loss backward and optimizer update step.
optimizer.zero_grad()
loss.backward()
optimizer.step()
After the model parameters are updated with the gradients, the normalization parameters are refreshed and the weight and bias are rescaled to preserve the unnormalized outputs.
model.popart.update_parameters(data['returns'])
Use the item method to get the loss as a pure Python scalar, then add it to running_loss.
running_loss += loss.item()
Print the accumulated loss, averaged over the last 100 epochs, then reset it.
if epoch % 100 == 99:
    print('Epoch [%d] loss: %.6f' % (epoch + 1, running_loss / 100))
    running_loss = 0.0
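Finally, a hypothetical end-to-end usage sketch (the shapes, sizes, and synthetic data below are ours, purely for illustration; the original tutorial loads real LunarLander data instead). A list of dict samples is enough for DataLoader's default collate function.

obs_shape, action_shape = 8, 2  # illustrative dimensions only
fake_data = [
    {
        'observations': torch.randn(obs_shape),
        'actions': torch.randn(action_shape),
        'returns': torch.randn(1) * 200.0,  # discounted return on an arbitrary scale
    } for _ in range(256)
]
train(obs_shape, action_shape, NUM_EPOCH=100, train_data=fake_data)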
If you have any questions or advice about this documentation, you can raise issues on GitHub (https://github.com/opendilab/PPOxFamily) or email us (opendilab@pjlab.org.cn).