Source code for ding.torch_utils.network.popart
"""
Implementation of the ``PopArt`` algorithm for reward rescaling.
<https://arxiv.org/abs/1602.07714>
PopArt is an adaptive normalization algorithm that normalizes the targets used in the learning updates.
Its two main components are:
**ART** (Adaptively Rescaling Targets): update the scale and shift so that the return is appropriately normalized.
**POP** (Preserving Outputs Precisely): preserve the outputs of the unnormalized function when the scale and shift change.
"""
from typing import Optional, Dict
import math
import torch
import torch.nn as nn
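# The POP step keeps the unnormalized outputs unchanged when the statistics
# (mu, sigma) are updated, by rescaling the linear layer in the opposite
# direction (see ``update_parameters`` below):
#   w_new = w_old * sigma_old / sigma_new
#   b_new = (sigma_old * b_old + mu_old - mu_new) / sigma_new
# so that sigma_new * (x @ w_new.T + b_new) + mu_new
#      == sigma_old * (x @ w_old.T + b_old) + mu_old.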
class PopArt(nn.Module):
"""
Overview:
        A linear layer with PopArt normalization. This class implements a linear transformation followed by
        PopArt normalization, a method that automatically adapts the contribution of each task to the agent's
        updates in multi-task learning, as described in the paper <https://arxiv.org/abs/1809.04474>.
Interfaces:
``__init__``, ``reset_parameters``, ``forward``, ``update_parameters``
"""
    def __init__(
self,
        input_features: Optional[int] = None,
        output_features: Optional[int] = None,
beta: float = 0.5
) -> None:
"""
Overview:
Initialize the class with input features, output features, and the beta parameter.
Arguments:
            - input_features (:obj:`Optional[int]`): The size of each input sample.
            - output_features (:obj:`Optional[int]`): The size of each output sample.
            - beta (:obj:`float`): The exponential moving average coefficient used to update the normalization statistics.
"""
super(PopArt, self).__init__()
self.beta = beta
self.input_features = input_features
self.output_features = output_features
# Initialize the linear layer parameters, weight and bias.
self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
self.bias = nn.Parameter(torch.Tensor(output_features))
        # Register buffers for the normalization statistics; buffers are part of the module's state
        # but are not trainable parameters. They store the target values' shift (mu) and scale (sigma).
self.register_buffer('mu', torch.zeros(output_features, requires_grad=False))
self.register_buffer('sigma', torch.ones(output_features, requires_grad=False))
self.register_buffer('v', torch.ones(output_features, requires_grad=False))
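        # ``v`` tracks the running second raw moment E[y^2]; the scale is recovered as
        # sigma = sqrt(v - mu^2) in ``update_parameters``.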
self.reset_parameters()
    def reset_parameters(self):
"""
Overview:
            Reset the weight and bias using ``kaiming_uniform_`` and ``uniform_`` initialization, respectively.
"""
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)
    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
"""
Overview:
            Compute the forward pass of the linear layer and return both the normalized output and the
            unnormalized output.
Arguments:
            - x (:obj:`torch.Tensor`): The input tensor, of shape ``(B, input_features)``.
Returns:
            - output (:obj:`Dict[str, torch.Tensor]`): A dictionary containing 'pred' (the normalized prediction) and 'unnormalized_pred' (the unnormalized prediction).
"""
normalized_output = x.mm(self.weight.t())
normalized_output += self.bias.unsqueeze(0).expand_as(normalized_output)
        # Unnormalize the output with the running statistics; no gradient flows through mu and sigma.
with torch.no_grad():
output = normalized_output * self.sigma + self.mu
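        # For a single-output head (output_features == 1), squeeze dim 1 so predictions have shape (B,).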
return {'pred': normalized_output.squeeze(1), 'unnormalized_pred': output.squeeze(1)}
    def update_parameters(self, value: torch.Tensor) -> Dict[str, torch.Tensor]:
"""
Overview:
            Update the running normalization statistics from the given target values and rescale the
            layer's weight and bias so that the unnormalized outputs are preserved (the POP step), then
            return the new mean and standard deviation.
Arguments:
            - value (:obj:`torch.Tensor`): The target values, of shape ``(B, output_features)``, used to update the statistics.
Returns:
            - update_results (:obj:`Dict[str, torch.Tensor]`): A dictionary containing 'new_mean' and 'new_std'.
"""
        # Move the normalization parameters to the same device as the target values.
self.mu = self.mu.to(value.device)
self.sigma = self.sigma.to(value.device)
self.v = self.v.to(value.device)
old_mu = self.mu
old_std = self.sigma
        # Compute the batch's first and second raw moments (mean and mean of squares) of the target values:
batch_mean = torch.mean(value, 0)
batch_v = torch.mean(torch.pow(value, 2), 0)
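        # Keep the old statistics for any entries where the batch statistics are NaN.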
batch_mean[torch.isnan(batch_mean)] = self.mu[torch.isnan(batch_mean)]
batch_v[torch.isnan(batch_v)] = self.v[torch.isnan(batch_v)]
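        # Exponential moving average update of the running statistics with coefficient beta.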
batch_mean = (1 - self.beta) * self.mu + self.beta * batch_mean
batch_v = (1 - self.beta) * self.v + self.beta * batch_v
batch_std = torch.sqrt(batch_v - (batch_mean ** 2))
        # Clamp the standard deviation to guard against outliers and numerical instability.
batch_std = torch.clamp(batch_std, min=1e-4, max=1e+6)
        # Replace any NaN values with the old standard deviation.
batch_std[torch.isnan(batch_std)] = self.sigma[torch.isnan(batch_std)]
self.mu = batch_mean
self.v = batch_v
self.sigma = batch_std
        # POP: rescale the weight and bias with the old and new statistics so the unnormalized outputs are preserved.
self.weight.data = (self.weight.data.t() * old_std / self.sigma).t()
self.bias.data = (old_std * self.bias.data + old_mu - self.mu) / self.sigma
return {'new_mean': batch_mean, 'new_std': batch_std}
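
# A minimal usage sketch (illustrative, not part of the original module): regress a
# one-dimensional value head onto large-scale targets. The layer sizes, optimizer,
# learning rate, and toy data below are assumptions chosen for demonstration only.
if __name__ == '__main__':
    torch.manual_seed(0)
    head = PopArt(input_features=8, output_features=1, beta=0.5)
    optimizer = torch.optim.SGD(head.parameters(), lr=1e-2)
    for _ in range(10):
        obs = torch.randn(32, 8)
        target = torch.randn(32, 1) * 100.0  # targets on a large scale
        # ART: refresh mu/sigma from the raw targets; POP rescales weight and bias.
        head.update_parameters(target)
        # Train the normalized head against the normalized target.
        normalized_target = ((target - head.mu) / head.sigma).squeeze(1)
        loss = nn.functional.mse_loss(head(obs)['pred'], normalized_target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # POP preserves the unnormalized outputs across a statistics update.
    with torch.no_grad():
        x = torch.randn(4, 8)
        before = head(x)['unnormalized_pred']
        head.update_parameters(torch.randn(4, 1) * 50.0)
        after = head(x)['unnormalized_pred']
        assert torch.allclose(before, after, rtol=1e-4, atol=1e-3)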