"""
Overview:
This Python file provides a collection of reusable model templates designed to streamline the development
process for various custom algorithms. By utilizing these pre-built model templates, users can quickly adapt and
customize their algorithms, ensuring efficient and effective development.
Users can refer to the unit tests of these model templates to learn how to use them.
"""
import math
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Sequence
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from ditk import logging
# Assuming these imports are valid in the user's environment.
# If they are not, they should be replaced with the correct ones.
from ding.torch_utils import MLP, ResBlock
from ding.torch_utils.network.normalization import build_normalization
from ding.utils import SequenceType, get_rank, get_world_size, set_pkg_seed
from transformers import AutoModelForCausalLM, AutoTokenizer
def MLP_V2(
in_channels: int,
hidden_channels: List[int],
out_channels: int,
layer_fn: Callable = nn.Linear,
activation: Optional[nn.Module] = None,
norm_type: Optional[str] = None,
use_dropout: bool = False,
dropout_probability: float = 0.5,
output_activation: bool = True,
output_norm: bool = True,
last_linear_layer_init_zero: bool = False,
) -> nn.Sequential:
"""
Overview:
Creates a multi-layer perceptron (MLP) using a list of hidden dimensions. Each layer consists of a fully
connected block with optional activation, normalization, and dropout. The final layer is configurable
to include or exclude activation and normalization.
Arguments:
- in_channels (:obj:`int`): Number of input channels (dimensionality of the input tensor).
- hidden_channels (:obj:`List[int]`): A list specifying the number of channels for each hidden layer.
- out_channels (:obj:`int`): Number of output channels (dimensionality of the output tensor).
- layer_fn (:obj:`Callable`): The function to construct layers, defaults to `nn.Linear`.
- activation (:obj:`Optional[nn.Module]`): Activation function to use after each layer, defaults to None.
- norm_type (:obj:`Optional[str]`): Type of normalization to apply. If None, no normalization is applied.
- use_dropout (:obj:`bool`): Whether to apply dropout after each layer, defaults to False.
- dropout_probability (:obj:`float`): The probability for dropout, defaults to 0.5.
- output_activation (:obj:`bool`): Whether to apply activation to the output layer, defaults to True.
- output_norm (:obj:`bool`): Whether to apply normalization to the output layer, defaults to True.
- last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last linear layer's weights and biases to zero.
Returns:
- block (:obj:`nn.Sequential`): A PyTorch `nn.Sequential` object containing the layers of the MLP.
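    Examples:
        >>> # A minimal usage sketch; the shapes and hyper-parameters below are illustrative only.
        >>> mlp = MLP_V2(in_channels=16, hidden_channels=[64, 64], out_channels=8,
        ...              activation=nn.ReLU(inplace=True), norm_type='LN',
        ...              output_activation=False, output_norm=False)
        >>> y = mlp(torch.randn(4, 16))  # y.shape == (4, 8)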
"""
if not hidden_channels:
logging.warning("hidden_channels is empty, creating a single-layer MLP.")
layers = []
all_channels = [in_channels] + hidden_channels + [out_channels]
num_layers = len(all_channels) - 1
for i in range(num_layers):
is_last_layer = (i == num_layers - 1)
layers.append(layer_fn(all_channels[i], all_channels[i+1]))
if not is_last_layer:
# Intermediate layers
if norm_type:
layers.append(build_normalization(norm_type, dim=1)(all_channels[i+1]))
if activation:
layers.append(activation)
if use_dropout:
layers.append(nn.Dropout(dropout_probability))
else:
# Last layer
if output_norm and norm_type:
layers.append(build_normalization(norm_type, dim=1)(all_channels[i+1]))
if output_activation and activation:
layers.append(activation)
# Note: Dropout on the final output is usually not recommended unless for specific regularization purposes.
# The original logic applied it, so we keep it for consistency.
if use_dropout:
layers.append(nn.Dropout(dropout_probability))
# Initialize the last linear layer to zero if specified
if last_linear_layer_init_zero:
for layer in reversed(layers):
if isinstance(layer, nn.Linear):
nn.init.zeros_(layer.weight)
nn.init.zeros_(layer.bias)
break
return nn.Sequential(*layers)
# --- Data-structures for Network Outputs ---
@dataclass
class MZRNNNetworkOutput:
"""
Overview:
Data structure for the output of the MuZeroRNN model.
"""
value: torch.Tensor
value_prefix: torch.Tensor
policy_logits: torch.Tensor
latent_state: torch.Tensor
predict_next_latent_state: torch.Tensor
reward_hidden_state: Tuple[torch.Tensor, torch.Tensor]
@dataclass
class EZNetworkOutput:
"""
Overview:
Data structure for the output of the EfficientZero model.
"""
value: torch.Tensor
value_prefix: torch.Tensor
policy_logits: torch.Tensor
latent_state: torch.Tensor
reward_hidden_state: Tuple[torch.Tensor, torch.Tensor]
@dataclass
class MZNetworkOutput:
"""
Overview:
Data structure for the output of the MuZero model.
"""
value: torch.Tensor
reward: torch.Tensor
policy_logits: torch.Tensor
latent_state: torch.Tensor
# --- Core Network Components ---
class SimNorm(nn.Module):
"""
Overview:
Implements Simplicial Normalization as described in the paper: https://arxiv.org/abs/2204.00616.
It groups features and applies softmax to each group.
"""
    def __init__(self, simnorm_dim: int) -> None:
"""
Arguments:
- simnorm_dim (:obj:`int`): The size of each group (simplex) to apply softmax over.
"""
super().__init__()
self.dim = simnorm_dim
    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Overview:
Forward pass for SimNorm.
Arguments:
- x (:obj:`torch.Tensor`): The input tensor.
Returns:
- (:obj:`torch.Tensor`): The tensor after applying Simplicial Normalization.
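        Examples:
            >>> # Illustrative only: 8-dim features split into 2 groups of 4, each softmax-normalized.
            >>> layer = SimNorm(simnorm_dim=4)
            >>> layer(torch.randn(2, 8)).shape
            torch.Size([2, 8])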
"""
if x.shape[1] == 0:
return x
# Reshape to (batch, groups, dim)
x_reshaped = x.view(*x.shape[:-1], -1, self.dim)
# Apply softmax over the last dimension (the simplex)
x_softmax = F.softmax(x_reshaped, dim=-1)
# Reshape back to the original tensor shape
return x_softmax.view(*x.shape)
def __repr__(self) -> str:
return f"SimNorm(dim={self.dim})"
def AvgL1Norm(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
"""
Overview:
Normalizes a tensor by the mean of its absolute values (L1 norm) along the last dimension.
Arguments:
- x (:obj:`torch.Tensor`): The input tensor to normalize.
- eps (:obj:`float`): A small epsilon value to prevent division by zero.
Returns:
- (:obj:`torch.Tensor`): The normalized tensor.
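    Examples:
        >>> # Illustrative only: each row is rescaled so its mean absolute value is approximately 1.
        >>> out = AvgL1Norm(torch.randn(4, 32))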
"""
return x / (x.abs().mean(dim=-1, keepdim=True) + eps)
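class L2Norm(nn.Module):
    """
    Overview:
        Normalizes the input tensor to unit L2 norm along the last dimension.
        NOTE: This is a minimal sketch added because ``RepresentationNetworkUniZero`` references an ``L2Norm``
        module for its 'L2Norm' final-norm option; if your codebase already provides an L2Norm implementation,
        prefer that one.
    """

    def __init__(self, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x / (x.norm(p=2, dim=-1, keepdim=True) + self.eps)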
class FeatureAndGradientHook:
"""
Overview:
A utility class to capture and analyze features and gradients of a specific module during
the forward and backward passes. This is useful for debugging and understanding model dynamics.
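    Examples:
        >>> # A usage sketch: attach to any sub-module, run forward/backward, then inspect the average norms.
        >>> layer = nn.Linear(8, 8)
        >>> hook = FeatureAndGradientHook(layer)
        >>> layer(torch.randn(2, 8, requires_grad=True)).sum().backward()
        >>> f_before, f_after, g_before, g_after = hook.analyze()
        >>> hook.remove_hooks()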
"""
    def __init__(self, module: nn.Module):
"""
Arguments:
- module (:obj:`nn.Module`): The PyTorch module to attach the hooks to.
"""
self.features_before = []
self.features_after = []
self.grads_before = []
self.grads_after = []
self.forward_handler = module.register_forward_hook(self._forward_hook)
self.backward_handler = module.register_full_backward_hook(self._backward_hook)
    def _forward_hook(self, module: nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor) -> None:
"""Hook to capture input and output features during the forward pass."""
with torch.no_grad():
self.features_before.append(inputs[0].clone().detach())
self.features_after.append(output.clone().detach())
    def _backward_hook(self, module: nn.Module, grad_inputs: Tuple[torch.Tensor], grad_outputs: Tuple[torch.Tensor]) -> None:
"""Hook to capture input and output gradients during the backward pass."""
with torch.no_grad():
self.grads_before.append(grad_inputs[0].clone().detach() if grad_inputs[0] is not None else None)
self.grads_after.append(grad_outputs[0].clone().detach() if grad_outputs[0] is not None else None)
    def analyze(self) -> Tuple[float, float, float, float]:
"""
Overview:
Analyzes the captured features and gradients by computing their average L2 norms.
This method clears the stored data after analysis to free memory.
Returns:
- (:obj:`Tuple[float, float, float, float]`): A tuple containing the L2 norms of
(features_before, features_after, grads_before, grads_after).
"""
if not self.features_before:
return 0.0, 0.0, 0.0, 0.0
l2_norm_before = torch.mean(torch.stack([torch.norm(f, p=2) for f in self.features_before])).item()
l2_norm_after = torch.mean(torch.stack([torch.norm(f, p=2) for f in self.features_after])).item()
valid_grads_before = [g for g in self.grads_before if g is not None]
grad_norm_before = torch.mean(torch.stack([torch.norm(g, p=2) for g in valid_grads_before])).item() if valid_grads_before else 0.0
valid_grads_after = [g for g in self.grads_after if g is not None]
grad_norm_after = torch.mean(torch.stack([torch.norm(g, p=2) for g in valid_grads_after])).item() if valid_grads_after else 0.0
self.clear_data()
return l2_norm_before, l2_norm_after, grad_norm_before, grad_norm_after
    def clear_data(self) -> None:
"""Clears all stored feature and gradient tensors to free up memory."""
self.features_before.clear()
self.features_after.clear()
self.grads_before.clear()
self.grads_after.clear()
if torch.cuda.is_available():
torch.cuda.empty_cache()
    def remove_hooks(self) -> None:
"""Removes the registered forward and backward hooks."""
self.forward_handler.remove()
self.backward_handler.remove()
class DownSample(nn.Module):
"""
Overview:
A convolutional network for downsampling image-based observations, commonly used in Atari environments.
It consists of a series of convolutional, normalization, and residual blocks.
"""
    def __init__(
self,
observation_shape: Sequence[int],
out_channels: int,
activation: nn.Module = nn.ReLU(inplace=True),
norm_type: str = 'BN',
num_resblocks: int = 1,
) -> None:
"""
Arguments:
- observation_shape (:obj:`Sequence[int]`): The shape of the input observation, e.g., (C, H, W).
- out_channels (:obj:`int`): The number of output channels.
- activation (:obj:`nn.Module`): The activation function to use.
- norm_type (:obj:`str`): The type of normalization ('BN' or 'LN').
- num_resblocks (:obj:`int`): The number of residual blocks in each stage.
"""
super().__init__()
if norm_type not in ['BN', 'LN']:
raise ValueError(f"Unsupported norm_type: {norm_type}. Must be 'BN' or 'LN'.")
# The original design was fixed to 1 resblock per stage.
if num_resblocks != 1:
logging.warning(f"DownSample is designed for num_resblocks=1, but got {num_resblocks}.")
self.observation_shape = observation_shape
self.activation = activation
# Initial convolution: stride 2
self.conv1 = nn.Conv2d(observation_shape[0], out_channels // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.norm1 = build_normalization(norm_type, dim=2)(out_channels // 2)
# Stage 1 with residual blocks
self.resblocks1 = nn.ModuleList([
ResBlock(in_channels=out_channels // 2, activation=activation, norm_type=norm_type, res_type='basic', bias=False)
for _ in range(num_resblocks)
])
# Downsample block: stride 2
self.downsample_block = ResBlock(in_channels=out_channels // 2, out_channels=out_channels, activation=activation, norm_type=norm_type, res_type='downsample', bias=False)
# Stage 2 with residual blocks
self.resblocks2 = nn.ModuleList([
ResBlock(in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False)
for _ in range(num_resblocks)
])
# Pooling 1: stride 2
self.pooling1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
# Stage 3 with residual blocks
self.resblocks3 = nn.ModuleList([
ResBlock(in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False)
for _ in range(num_resblocks)
])
# Final pooling for specific input sizes
self.pooling2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Shapes:
- x (:obj:`torch.Tensor`): (B, C_in, H, W)
- output (:obj:`torch.Tensor`): (B, C_out, H_out, W_out)
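        Examples:
            >>> # Illustrative Atari-like input; for 96x96 observations the spatial size is expected to shrink to 6x6.
            >>> net = DownSample(observation_shape=(3, 96, 96), out_channels=64)
            >>> out = net(torch.randn(2, 3, 96, 96))  # expected shape: (2, 64, 6, 6)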
"""
x = self.conv1(x)
x = self.norm1(x)
x = self.activation(x)
for block in self.resblocks1:
x = block(x)
x = self.downsample_block(x)
for block in self.resblocks2:
x = block(x)
x = self.pooling1(x)
for block in self.resblocks3:
x = block(x)
# This part handles specific Atari resolutions. A more general approach might be desirable,
# but we maintain original behavior.
obs_height = self.observation_shape[1]
if obs_height == 64:
return x
elif obs_height in [84, 96]:
return self.pooling2(x)
else:
raise NotImplementedError(
f"DownSample for observation height {obs_height} is not implemented. "
f"Supported heights are 64, 84, 96."
)
class QwenNetwork(nn.Module):
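    """
    Overview:
        A language encoder that wraps a frozen, pretrained Qwen causal LM. ``encode`` projects the hidden state
        of the last non-padding token into a fixed-size latent state via ``embedding_head``, while ``decode``
        generates text from (detached) latent embeddings, which is mainly useful for inspection and debugging.
    """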
def __init__(self,
model_path: str = 'Qwen/Qwen3-1.7B',
embedding_size: int = 768,
final_norm_option_in_encoder: str = "layernorm",
group_size: int = 8,
tokenizer=None):
super().__init__()
logging.info(f"Loading Qwen model from: {model_path}")
local_rank = get_rank()
if local_rank == 0:
self.pretrained_model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype="auto",
device_map={"": local_rank},
attn_implementation="flash_attention_2"
)
if get_world_size() > 1:
torch.distributed.barrier()
if local_rank != 0:
self.pretrained_model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype="auto",
device_map={"": local_rank},
attn_implementation="flash_attention_2"
)
for p in self.pretrained_model.parameters():
p.requires_grad = False
if tokenizer is None:
if local_rank == 0:
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
if get_world_size() > 1:
torch.distributed.barrier()
if local_rank != 0:
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
else:
self.tokenizer = tokenizer
qwen_hidden_size = self.pretrained_model.config.hidden_size
self.embedding_head = nn.Sequential(
nn.Linear(qwen_hidden_size, embedding_size),
self._create_norm_layer(final_norm_option_in_encoder, embedding_size, group_size)
)
def _create_norm_layer(self, norm_option, embedding_size, group_size):
if norm_option.lower() == "simnorm":
return SimNorm(simnorm_dim=group_size)
elif norm_option.lower() == "layernorm":
return nn.LayerNorm(embedding_size)
else:
raise NotImplementedError(f"Normalization type '{norm_option}' is not implemented.")
def encode(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
"""
Overview:
Encode the input token sequence `x` into a latent representation
using a pretrained language model backbone followed by a projection head.
Arguments:
- x (:obj:`torch.Tensor`): Input token ids of shape (B, L)
- no_grad (:obj:`bool`, optional, default=True): If True, encoding is performed under `torch.no_grad()` to save memory and computation (no gradient tracking).
Returns:
- latent (:obj:`torch.Tensor`): Encoded latent state of shape (B, D).
"""
pad_id = self.tokenizer.pad_token_id
attention_mask = (x != pad_id).long().to(x.device)
context = {'input_ids': x.long(), 'attention_mask': attention_mask}
if no_grad:
with torch.no_grad():
outputs = self.pretrained_model(**context, output_hidden_states=True, return_dict=True)
else:
outputs = self.pretrained_model(**context, output_hidden_states=True, return_dict=True)
last_hidden = outputs.hidden_states[-1]
B, L, H = last_hidden.size()
lengths = attention_mask.sum(dim=1) # [B]
positions = torch.clamp(lengths - 1, min=0) # [B]
batch_idx = torch.arange(B, device=last_hidden.device)
selected = last_hidden[batch_idx, positions] # [B, H]
latent = self.embedding_head(selected.to(self.embedding_head[0].weight.dtype))
return latent
def decode(self, embeddings: torch.Tensor, max_length: int = 512) -> str:
"""
        Overview:
            Decodes latent embeddings back into text using the pretrained language model's ``generate`` method.
"""
embeddings_detached = embeddings.detach()
self.pretrained_model.eval()
# Directly generate using provided embeddings
with torch.no_grad():
param = next(self.pretrained_model.parameters())
embeddings = embeddings_detached.to(device=param.device, dtype=param.dtype)
gen_ids = self.pretrained_model.generate(
inputs_embeds=embeddings,
max_length=max_length
)
texts = self.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
self.pretrained_model.train()
return texts[0] if len(texts) == 1 else texts
def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
return self.encode(x, no_grad=no_grad)
class HFLanguageRepresentationNetwork(nn.Module):
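    """
    Overview:
        A language representation network built on a pretrained Hugging Face encoder (e.g. BERT). The [CLS]
        token embedding from the last hidden layer is projected to ``embedding_size`` and then normalized.
    """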
def __init__(self,
model_path: str = 'google-bert/bert-base-uncased',
embedding_size: int = 768,
group_size: int = 8,
final_norm_option_in_encoder: str = "layernorm",
tokenizer=None):
"""
Arguments:
- model_path (str): The path to the pretrained Hugging Face model. Default is 'google-bert/bert-base-uncased'.
- embedding_size (int): The dimension of the output embeddings. Default is 768.
- group_size (int): The group size for SimNorm when using normalization.
- final_norm_option_in_encoder (str): The type of normalization to use ("simnorm" or "layernorm"). Default is "layernorm".
- tokenizer (Optional): An instance of a tokenizer. If None, the tokenizer will be loaded from the pretrained model.
"""
super().__init__()
from transformers import AutoModel, AutoTokenizer
        # In distributed settings, ensure only rank 0 downloads the model/tokenizer;
        # the other ranks then load them from the local cache.
        if get_rank() == 0:
            self.pretrained_model = AutoModel.from_pretrained(model_path)
        if get_world_size() > 1:
            # Wait for rank 0 to finish downloading the model.
            torch.distributed.barrier()
        if get_rank() != 0:
            logging.info(f"Worker process is loading model from cache: {model_path}")
            self.pretrained_model = AutoModel.from_pretrained(model_path)
        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        else:
            self.tokenizer = tokenizer
self.embedding_size = embedding_size
self.embed_proj_head = nn.Linear(self.pretrained_model.config.hidden_size, self.embedding_size)
# Select the normalization method based on the final_norm_option_in_encoder parameter.
if final_norm_option_in_encoder.lower() == "simnorm":
self.norm = SimNorm(simnorm_dim=group_size)
elif final_norm_option_in_encoder.lower() == "layernorm":
self.norm = nn.LayerNorm(embedding_size)
else:
raise NotImplementedError(f"Normalization type '{final_norm_option_in_encoder}' is not implemented. "
f"Choose 'simnorm' or 'layernorm'.")
def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
"""
Overview:
Computes language representation from input token IDs.
Arguments:
- x (:obj:`torch.Tensor`): Input token sequence of shape (B, seq_len).
- no_grad (:obj:`bool`): If True, run the transformer model in `torch.no_grad()` context.
Returns:
- (:obj:`torch.Tensor`): The final language embedding of shape (B, embedding_size).
"""
# Construct the attention mask to exclude padding tokens.
attention_mask = x != self.tokenizer.pad_token_id
if no_grad:
with torch.no_grad():
x = x.long() # Ensure the input tensor is of type long.
outputs = self.pretrained_model(x, attention_mask=attention_mask)
# Get the hidden state from the last layer and select the output corresponding to the [CLS] token.
cls_embedding = outputs.last_hidden_state[:, 0, :]
else:
x = x.long()
outputs = self.pretrained_model(x, attention_mask=attention_mask)
cls_embedding = outputs.last_hidden_state[:, 0, :]
cls_embedding = self.embed_proj_head(cls_embedding)
cls_embedding = self.norm(cls_embedding)
return cls_embedding
class RepresentationNetworkUniZero(nn.Module):
    def __init__(
self,
observation_shape: SequenceType = (3, 64, 64),
num_res_blocks: int = 1,
num_channels: int = 64,
downsample: bool = True,
activation: nn.Module = nn.GELU(approximate='tanh'),
norm_type: str = 'BN',
embedding_dim: int = 256,
group_size: int = 8,
final_norm_option_in_encoder: str = 'LayerNorm', # TODO
) -> None:
"""
Overview:
            Representation network used in UniZero. It encodes a 2D image observation into a 1D latent state.
            Currently, the network supports square observations with a height and width of 64, 84, or 96.
Arguments:
- observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64]
for video games like atari, RGB 3 channel.
- num_res_blocks (:obj:`int`): The number of residual blocks.
- num_channels (:obj:`int`): The channel of output hidden state.
- downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \
defaults to True. This option is often used in video games like Atari. In board games like go, \
we don't need this module.
            - activation (:obj:`nn.Module`): The activation function used in the network, \
                defaults to nn.GELU(approximate='tanh').
            - norm_type (:obj:`str`): The type of normalization used in the network, defaults to 'BN'.
- embedding_dim (:obj:`int`): The dimension of the latent state.
- group_size (:obj:`int`): The dimension for simplicial normalization.
            - final_norm_option_in_encoder (:obj:`str`): The normalization applied to the final latent state, \
                defaults to 'LayerNorm'. Options are 'LayerNorm', 'LayerNorm_Tanh', 'LayerNormNoAffine', 'SimNorm', 'L2Norm', or None.
"""
super().__init__()
        assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']"
        # Only log from rank 0 to avoid excessive output in distributed training.
if get_rank() == 0:
logging.info(f"Using norm type: {norm_type}")
logging.info(f"Using activation type: {activation}")
self.observation_shape = observation_shape
self.downsample = downsample
if self.downsample:
self.downsample_net = DownSample(
observation_shape,
num_channels,
activation=activation,
norm_type=norm_type,
)
else:
self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False)
if norm_type == 'BN':
self.norm = nn.BatchNorm2d(num_channels)
elif norm_type == 'LN':
if downsample:
self.norm = nn.LayerNorm(
[num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)],
eps=1e-5)
else:
self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5)
self.resblocks = nn.ModuleList(
[
ResBlock(
in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
) for _ in range(num_res_blocks)
]
)
self.activation = activation
self.embedding_dim = embedding_dim
        # The flattened feature size depends on the downsampled spatial resolution:
        # 64x64 observations are reduced to an 8x8 feature map, while 84x84 and 96x96 are reduced to 6x6.
        if self.observation_shape[1] == 64:
            self.last_linear = nn.Linear(num_channels * 8 * 8, self.embedding_dim, bias=False)
        elif self.observation_shape[1] in [84, 96]:
            self.last_linear = nn.Linear(num_channels * 6 * 6, self.embedding_dim, bias=False)
        self.final_norm_option_in_encoder = final_norm_option_in_encoder
# Initialize final_norm uniformly in __init__
if self.final_norm_option_in_encoder in ['LayerNorm', 'LayerNorm_Tanh']:
self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5)
elif self.final_norm_option_in_encoder == 'LayerNormNoAffine':
self.final_norm = nn.LayerNorm(
self.embedding_dim, eps=1e-5, elementwise_affine=False
)
elif self.final_norm_option_in_encoder == 'SimNorm':
# Ensure SimNorm is defined
self.final_norm = SimNorm(simnorm_dim=group_size)
elif self.final_norm_option_in_encoder == 'L2Norm':
# Directly instantiate our custom L2Norm module
self.final_norm = L2Norm(eps=1e-6)
elif self.final_norm_option_in_encoder is None:
# If no normalization is needed, set to nn.Identity() or None
self.final_norm = nn.Identity()
else:
raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}")
    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, C_{in}, H, W)`, where B is batch size, C_{in} is the number of \
                input channels, H is height and W is width.
            - output (:obj:`torch.Tensor`): :math:`(B, embedding\_dim)`, the encoded 1D latent state.
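        Examples:
            >>> # Illustrative only: encode a batch of 64x64 RGB observations into 256-dim latent states.
            >>> net = RepresentationNetworkUniZero(observation_shape=(3, 64, 64), embedding_dim=256)
            >>> net(torch.randn(2, 3, 64, 64)).shape
            torch.Size([2, 256])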
"""
if self.downsample:
x = self.downsample_net(x)
else:
x = self.conv(x)
x = self.norm(x)
x = self.activation(x)
for block in self.resblocks:
x = block(x)
# Important: Transform the output feature plane to the latent state.
# For example, for an Atari feature plane of shape (64, 8, 8),
# flattening results in a size of 4096, which is then transformed to 768.
x = self.last_linear(x.view(x.size(0), -1))
x = x.view(-1, self.embedding_dim)
        # NOTE: the final normalization is very important for training stability.
        if self.final_norm is not None:
            x = self.final_norm(x)
# Special handling for LayerNorm_Tanh
if self.final_norm_option_in_encoder == 'LayerNorm_Tanh':
x = torch.tanh(x)
return x
class RepresentationNetwork(nn.Module):
"""
Overview:
The standard representation network used in MuZero. It encodes a 2D image observation
into a latent state, which retains its spatial dimensions.
"""
    def __init__(
self,
observation_shape: Sequence[int] = (4, 96, 96),
num_res_blocks: int = 1,
num_channels: int = 64,
downsample: bool = True,
activation: nn.Module = nn.ReLU(inplace=True),
norm_type: str = 'BN',
use_sim_norm: bool = False,
group_size: int = 8,
) -> None:
"""
Arguments:
- observation_shape (:obj:`Sequence[int]`): Shape of the input observation (C, H, W).
- num_res_blocks (:obj:`int`): The number of residual blocks.
- num_channels (:obj:`int`): The number of channels in the convolutional layers.
- downsample (:obj:`bool`): Whether to use the `DownSample` module.
- activation (:obj:`nn.Module`): The activation function to use.
- norm_type (:obj:`str`): Normalization type ('BN' or 'LN').
- use_sim_norm (:obj:`bool`): Whether to apply a final `SimNorm` layer.
- group_size (:obj:`int`): Group size for `SimNorm`.
"""
super().__init__()
if norm_type not in ['BN', 'LN']:
raise ValueError(f"Unsupported norm_type: {norm_type}. Must be 'BN' or 'LN'.")
self.downsample = downsample
self.activation = activation
if self.downsample:
self.downsample_net = DownSample(observation_shape, num_channels, activation, norm_type)
else:
self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.norm = build_normalization(norm_type, dim=2)(num_channels)
self.resblocks = nn.ModuleList([
ResBlock(in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False)
for _ in range(num_res_blocks)
])
self.use_sim_norm = use_sim_norm
if self.use_sim_norm:
self.sim_norm = SimNorm(simnorm_dim=group_size)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Shapes:
- x (:obj:`torch.Tensor`): (B, C_in, H, W)
- output (:obj:`torch.Tensor`): (B, C_out, H_out, W_out)
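        Examples:
            >>> # Illustrative only: a 4-frame 96x96 stack; the latent plane is expected to be (B, 64, 6, 6).
            >>> net = RepresentationNetwork(observation_shape=(4, 96, 96), num_channels=64)
            >>> out = net(torch.randn(2, 4, 96, 96))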
"""
if self.downsample:
x = self.downsample_net(x)
else:
x = self.conv(x)
x = self.norm(x)
x = self.activation(x)
for block in self.resblocks:
x = block(x)
if self.use_sim_norm:
# Flatten the spatial dimensions, apply SimNorm, and then reshape back.
b, c, h, w = x.shape
x_flat = x.view(b, c * h * w)
x_norm = self.sim_norm(x_flat)
x = x_norm.view(b, c, h, w)
return x
class RepresentationNetworkMLP(nn.Module):
"""
Overview:
An MLP-based representation network for encoding vector observations into a latent state.
"""
    def __init__(
self,
observation_dim: int,
hidden_channels: int = 64,
num_layers: int = 2,
activation: nn.Module = nn.GELU(approximate='tanh'),
norm_type: Optional[str] = 'BN',
group_size: int = 8,
final_norm_option_in_encoder: str = 'LayerNorm', # TODO
    ) -> None:
"""
Arguments:
- observation_dim (:obj:`int`): The dimension of the input vector observation.
- hidden_channels (:obj:`int`): The number of neurons in the hidden and output layers.
- num_layers (:obj:`int`): The total number of layers in the MLP.
- activation (:obj:`nn.Module`): The activation function to use.
- norm_type (:obj:`Optional[str]`): The type of normalization ('BN', 'LN', or None).
            - group_size (:obj:`int`): The group size for `SimNorm` when it is used as the final normalization.
            - final_norm_option_in_encoder (:obj:`str`): The final normalization applied to the latent state, \
                either 'simnorm' or 'layernorm'. Defaults to 'LayerNorm'.
"""
super().__init__()
# Creating hidden layers list for MLP_V2
hidden_layers = [hidden_channels] * (num_layers - 1) if num_layers > 1 else []
self.fc_representation = MLP_V2(
in_channels=observation_dim,
hidden_channels=hidden_layers,
out_channels=hidden_channels,
activation=activation,
norm_type=norm_type,
output_activation=False,
output_norm=False,
last_linear_layer_init_zero=True,
)
# Select the normalization method based on the final_norm_option_in_encoder parameter.
if final_norm_option_in_encoder.lower() == "simnorm":
self.norm = SimNorm(simnorm_dim=group_size)
elif final_norm_option_in_encoder.lower() == "layernorm":
self.norm = nn.LayerNorm(hidden_channels)
else:
raise NotImplementedError(f"Normalization type '{final_norm_option_in_encoder}' is not implemented. "
f"Choose 'simnorm' or 'layernorm'.")
    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Shapes:
- x (:obj:`torch.Tensor`): (B, observation_dim)
- output (:obj:`torch.Tensor`): (B, hidden_channels)
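        Examples:
            >>> # Illustrative only: encode 8-dim vector observations into 64-dim latent states.
            >>> net = RepresentationNetworkMLP(observation_dim=8, hidden_channels=64)
            >>> net(torch.randn(2, 8)).shape
            torch.Size([2, 64])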
"""
x = self.fc_representation(x)
x = self.norm(x)
return x
class LatentDecoder(nn.Module):
"""
Overview:
A decoder network that reconstructs a 2D image from a 1D latent embedding.
It acts as the inverse of a representation network like `RepresentationNetworkUniZero`.
"""
    def __init__(
self,
embedding_dim: int,
output_shape: Tuple[int, int, int],
num_channels: int = 64,
activation: nn.Module = nn.GELU(approximate='tanh')
):
"""
Arguments:
- embedding_dim (:obj:`int`): The dimension of the input latent embedding.
- output_shape (:obj:`Tuple[int, int, int]`): The shape of the target output image (C, H, W).
- num_channels (:obj:`int`): The base number of channels for the initial upsampling stage.
- activation (:obj:`nn.Module`): The activation function to use.
"""
super().__init__()
self.embedding_dim = embedding_dim
self.output_shape = output_shape
# This should match the spatial size of the encoder's feature map before flattening.
# Assuming a total downsampling factor of 8 (e.g., for a 64x64 -> 8x8 encoder).
self.initial_h = output_shape[1] // 8
self.initial_w = output_shape[2] // 8
self.initial_size = (num_channels, self.initial_h, self.initial_w)
self.fc = nn.Linear(embedding_dim, np.prod(self.initial_size))
self.deconv_blocks = nn.Sequential(
# Block 1: (C, H/8, W/8) -> (C/2, H/4, W/4)
nn.ConvTranspose2d(num_channels, num_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1),
activation,
nn.BatchNorm2d(num_channels // 2),
# Block 2: (C/2, H/4, W/4) -> (C/4, H/2, W/2)
nn.ConvTranspose2d(num_channels // 2, num_channels // 4, kernel_size=3, stride=2, padding=1, output_padding=1),
activation,
nn.BatchNorm2d(num_channels // 4),
# Block 3: (C/4, H/2, W/2) -> (output_C, H, W)
nn.ConvTranspose2d(num_channels // 4, output_shape[0], kernel_size=3, stride=2, padding=1, output_padding=1),
# A final activation like Sigmoid or Tanh is often used if pixel values are in a fixed range [0,1] or [-1,1].
# We omit it here to maintain consistency with the original code.
)
    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
"""
Shapes:
- embeddings (:obj:`torch.Tensor`): (B, embedding_dim)
- output (:obj:`torch.Tensor`): (B, C, H, W)
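        Examples:
            >>> # Illustrative only: reconstruct 3x64x64 images from 256-dim latent embeddings.
            >>> decoder = LatentDecoder(embedding_dim=256, output_shape=(3, 64, 64))
            >>> decoder(torch.randn(2, 256)).shape
            torch.Size([2, 3, 64, 64])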
"""
x = self.fc(embeddings)
x = x.view(-1, *self.initial_size)
x = self.deconv_blocks(x)
return x
# --- Networks for MemoryEnv ---
# --- Prediction Networks ---
class PredictionNetwork(nn.Module):
    def __init__(
self,
observation_shape: SequenceType,
action_space_size: int,
num_res_blocks: int,
num_channels: int,
value_head_channels: int,
policy_head_channels: int,
        value_head_hidden_channels: SequenceType,
        policy_head_hidden_channels: SequenceType,
output_support_size: int,
flatten_input_size_for_value_head: int,
flatten_input_size_for_policy_head: int,
downsample: bool = False,
last_linear_layer_init_zero: bool = True,
activation: nn.Module = nn.ReLU(inplace=True),
norm_type: Optional[str] = 'BN',
) -> None:
"""
Overview:
The definition of policy and value prediction network, which is used to predict value and policy by the
given latent state.
Arguments:
- observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image.
- action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space.
- num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model.
- num_channels (:obj:`int`): The channels of hidden states.
- value_head_channels (:obj:`int`): The channels of value head.
- policy_head_channels (:obj:`int`): The channels of policy head.
            - value_head_hidden_channels (:obj:`SequenceType`): The sizes of the hidden layers in the value head (MLP head).
            - policy_head_hidden_channels (:obj:`SequenceType`): The sizes of the hidden layers in the policy head (MLP head).
            - output_support_size (:obj:`int`): The size of the categorical value output.
            - flatten_input_size_for_value_head (:obj:`int`): The size of the flattened hidden state, i.e. the input size \
                of the value head.
- flatten_input_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
of the policy head.
- downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``.
- last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \
dynamics/prediction mlp, default sets it to True.
            - activation (:obj:`Optional[nn.Module]`): The activation function used in the network, which often uses \
                in-place operations to speed up computation, e.g. ReLU(inplace=True).
            - norm_type (:obj:`str`): The type of normalization used in the network, defaults to 'BN'.
"""
super(PredictionNetwork, self).__init__()
assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']"
self.resblocks = nn.ModuleList(
[
ResBlock(
in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
) for _ in range(num_res_blocks)
]
)
self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1)
self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1)
        if observation_shape[1] == 96:
            latent_shape = (observation_shape[1] // 16, observation_shape[2] // 16)
        elif observation_shape[1] == 64:
            latent_shape = (observation_shape[1] // 8, observation_shape[2] // 8)
if norm_type == 'BN':
self.norm_value = nn.BatchNorm2d(value_head_channels)
self.norm_policy = nn.BatchNorm2d(policy_head_channels)
elif norm_type == 'LN':
if downsample:
self.norm_value = nn.LayerNorm(
[value_head_channels, *latent_shape],
eps=1e-5)
self.norm_policy = nn.LayerNorm([policy_head_channels, *latent_shape], eps=1e-5)
else:
self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]],
eps=1e-5)
self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]],
eps=1e-5)
self.flatten_input_size_for_value_head = flatten_input_size_for_value_head
self.flatten_input_size_for_policy_head = flatten_input_size_for_policy_head
self.activation = activation
self.fc_value = MLP_V2(
in_channels=self.flatten_input_size_for_value_head,
hidden_channels=value_head_hidden_channels,
out_channels=output_support_size,
activation=self.activation,
norm_type=norm_type,
output_activation=False,
output_norm=False,
# last_linear_layer_init_zero=True is beneficial for convergence speed.
last_linear_layer_init_zero=last_linear_layer_init_zero
)
self.fc_policy = MLP_V2(
in_channels=self.flatten_input_size_for_policy_head,
hidden_channels=policy_head_hidden_channels,
out_channels=action_space_size,
activation=self.activation,
norm_type=norm_type,
output_activation=False,
output_norm=False,
# last_linear_layer_init_zero=True is beneficial for convergence speed.
last_linear_layer_init_zero=last_linear_layer_init_zero
)
    def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Overview:
Forward computation of the prediction network.
Arguments:
            - latent_state (:obj:`torch.Tensor`): The 2D latent state with shape (B, num_channels, H, W).
Returns:
- policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size).
- value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size).
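        Examples:
            >>> # An illustrative configuration for an 8x8 latent plane with 64 channels.
            >>> net = PredictionNetwork(
            ...     observation_shape=(3, 64, 64), action_space_size=6, num_res_blocks=1, num_channels=64,
            ...     value_head_channels=16, policy_head_channels=16, value_head_hidden_channels=[32],
            ...     policy_head_hidden_channels=[32], output_support_size=601,
            ...     flatten_input_size_for_value_head=16 * 8 * 8, flatten_input_size_for_policy_head=16 * 8 * 8,
            ...     downsample=True,
            ... )
            >>> policy, value = net(torch.randn(2, 64, 8, 8))
            >>> policy.shape, value.shape
            (torch.Size([2, 6]), torch.Size([2, 601]))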
"""
for res_block in self.resblocks:
latent_state = res_block(latent_state)
value = self.conv1x1_value(latent_state)
value = self.norm_value(value)
value = self.activation(value)
policy = self.conv1x1_policy(latent_state)
policy = self.norm_policy(policy)
policy = self.activation(policy)
value = value.reshape(-1, self.flatten_input_size_for_value_head)
policy = policy.reshape(-1, self.flatten_input_size_for_policy_head)
value = self.fc_value(value)
policy = self.fc_policy(policy)
return policy, value
class PredictionNetworkMLP(nn.Module):
"""
Overview:
An MLP-based prediction network that predicts policy and value from a 1D latent state.
"""
    def __init__(
self,
action_space_size: int,
num_channels: int,
common_layer_num: int = 2,
value_head_hidden_channels: List[int] = [32],
policy_head_hidden_channels: List[int] = [32],
output_support_size: int = 601,
last_linear_layer_init_zero: bool = True,
activation: nn.Module = nn.ReLU(inplace=True),
norm_type: Optional[str] = 'BN',
):
"""
Arguments:
- action_space_size: (:obj:`int`): The size of the action space.
- num_channels (:obj:`int`): The dimension of the input latent state.
- common_layer_num (:obj:`int`): Number of layers in the shared backbone MLP.
- value_head_hidden_channels (:obj:`List[int]`): Hidden layer sizes for the value MLP head.
- policy_head_hidden_channels (:obj:`List[int]`): Hidden layer sizes for the policy MLP head.
- output_support_size (:obj:`int`): The size of the categorical value distribution.
- last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last layer of heads to zero.
- activation (:obj:`nn.Module`): The activation function.
- norm_type (:obj:`Optional[str]`): The normalization type.
"""
super().__init__()
common_hidden = [num_channels] * (common_layer_num - 1) if common_layer_num > 1 else []
self.fc_prediction_common = MLP_V2(
in_channels=num_channels,
hidden_channels=common_hidden,
out_channels=num_channels,
activation=activation,
norm_type=norm_type,
output_activation=True,
output_norm=True,
last_linear_layer_init_zero=False,
)
self.fc_value_head = MLP_V2(
in_channels=num_channels,
hidden_channels=value_head_hidden_channels,
out_channels=output_support_size,
activation=activation,
norm_type=norm_type,
output_activation=False,
output_norm=False,
last_linear_layer_init_zero=last_linear_layer_init_zero
)
self.fc_policy_head = MLP_V2(
in_channels=num_channels,
hidden_channels=policy_head_hidden_channels,
out_channels=action_space_size,
activation=activation,
norm_type=norm_type,
output_activation=False,
output_norm=False,
last_linear_layer_init_zero=last_linear_layer_init_zero
)
    def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Shapes:
- latent_state (:obj:`torch.Tensor`): (B, num_channels)
- policy_logits (:obj:`torch.Tensor`): (B, action_space_size)
- value (:obj:`torch.Tensor`): (B, output_support_size)
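        Examples:
            >>> # Illustrative only: a 64-dim latent state, 6 discrete actions, 601-way categorical value.
            >>> net = PredictionNetworkMLP(action_space_size=6, num_channels=64)
            >>> policy_logits, value = net(torch.randn(2, 64))
            >>> policy_logits.shape, value.shape
            (torch.Size([2, 6]), torch.Size([2, 601]))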
"""
x = self.fc_prediction_common(latent_state)
value = self.fc_value_head(x)
policy_logits = self.fc_policy_head(x)
return policy_logits, value