Shortcuts

Source code for ding.torch_utils.network.res_block

from typing import Union

import torch
import torch.nn as nn

from .nn_module import conv2d_block, fc_block


class ResBlock(nn.Module):
    r"""
    Overview:
        Residual Block with 2D convolution layers, including 3 types:

        basic block:
            input channel: C
            x -> 3*3*C -> norm -> act -> 3*3*C -> norm -> (+) -> act -> out
            \__________________________________________/

        bottleneck block:
            x -> 1*1*C -> norm -> act -> 3*3*C -> norm -> act -> 1*1*C -> norm -> (+) -> act -> out
            \____________________________________________________________________/

        downsample block: used in EfficientZero
            input channel: C
            x -> 3*3*C (stride 2) -> norm -> act -> 3*3*C -> norm -> (+) -> act -> out
            \_________________ 3*3*C (stride 2, no norm, no act) ______________/

        In every branch ``C`` is ``out_channels`` (which defaults to ``in_channels``). Unlike the original
        ResNet bottleneck, this implementation does NOT shrink the inner channels to 1/4*C.
        For ``basic`` and ``bottleneck`` the input is added to the branch output without any projection,
        so the input and branch-output shapes must be compatible for that element-wise addition.

    .. note::
        You can refer to `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_
        for more details.

    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            in_channels: int,
            activation: nn.Module = nn.ReLU(),
            norm_type: str = 'BN',
            res_type: str = 'basic',
            bias: bool = True,
            out_channels: Union[int, None] = None,
    ) -> None:
        """
        Overview:
            Init the 2D convolution residual block.
        Arguments:
            - in_channels (:obj:`int`): Number of channels in the input tensor.
            - activation (:obj:`nn.Module`): The activation module. It is used inside the branch and again \
                after the residual addition.
            - norm_type (:obj:`str`): Type of the normalization, default set to 'BN' (Batch Normalization), \
                supports ['BN', 'LN', 'IN', 'GN', 'SyncBN', None].
            - res_type (:obj:`str`): Type of residual block, supports ['basic', 'bottleneck', 'downsample'].
            - bias (:obj:`bool`): Whether to add a learnable bias to the conv2d_block. Default set to True.
            - out_channels (:obj:`int`): Number of channels in the output tensor. Default set to None, \
                which means out_channels = in_channels.
        Raises:
            - AssertionError: If ``res_type`` is not one of the supported types.
        """
        super(ResBlock, self).__init__()
        self.act = activation
        assert res_type in ['basic', 'bottleneck', 'downsample'], \
            'residual type only support basic, bottleneck and downsample, not:{}'.format(res_type)
        self.res_type = res_type
        if out_channels is None:
            out_channels = in_channels
        # NOTE: the construction order conv1 -> conv2 -> conv3 is kept stable so that parameter
        # initialization consumes the RNG in the same order.
        if self.res_type == 'basic':
            self.conv1 = conv2d_block(
                in_channels, out_channels, 3, 1, 1, activation=self.act, norm_type=norm_type, bias=bias
            )
            # No activation here: it is applied after the residual addition in ``forward``.
            self.conv2 = conv2d_block(
                out_channels, out_channels, 3, 1, 1, activation=None, norm_type=norm_type, bias=bias
            )
        elif self.res_type == 'bottleneck':
            self.conv1 = conv2d_block(
                in_channels, out_channels, 1, 1, 0, activation=self.act, norm_type=norm_type, bias=bias
            )
            self.conv2 = conv2d_block(
                out_channels, out_channels, 3, 1, 1, activation=self.act, norm_type=norm_type, bias=bias
            )
            # No activation here: it is applied after the residual addition in ``forward``.
            self.conv3 = conv2d_block(
                out_channels, out_channels, 1, 1, 0, activation=None, norm_type=norm_type, bias=bias
            )
        elif self.res_type == 'downsample':
            # Stride 2 in the first conv reduces the spatial resolution of the main branch.
            self.conv1 = conv2d_block(
                in_channels, out_channels, 3, 2, 1, activation=self.act, norm_type=norm_type, bias=bias
            )
            self.conv2 = conv2d_block(
                out_channels, out_channels, 3, 1, 1, activation=None, norm_type=norm_type, bias=bias
            )
            # Shortcut projection with the same stride so the identity path matches the branch's
            # channel count and spatial size; it has no norm and no activation.
            self.conv3 = conv2d_block(in_channels, out_channels, 3, 2, 1, activation=None, norm_type=None, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Return the residual block output.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        Returns:
            - x (:obj:`torch.Tensor`): The resblock output tensor, i.e. ``act(branch(x) + shortcut(x))``.
        """
        identity = x
        x = self.conv1(x)
        x = self.conv2(x)
        if self.res_type == 'bottleneck':
            x = self.conv3(x)
        elif self.res_type == 'downsample':
            # In the downsample variant conv3 is the shortcut projection, not part of the main branch.
            identity = self.conv3(identity)
        x = self.act(x + identity)
        return x
class ResFCBlock(nn.Module):
    r"""
    Overview:
        Residual Block with 2 fully connected layers.
        x -> fc1 -> norm -> act -> fc2 -> norm -> (+) -> act -> [dropout] -> out
        \_________________________________________/
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            in_channels: int,
            activation: nn.Module = nn.ReLU(),
            norm_type: str = 'BN',
            dropout: Union[float, None] = None
    ) -> None:
        """
        Overview:
            Init the fully connected layer residual block.
        Arguments:
            - in_channels (:obj:`int`): The number of channels in the input tensor. Both fc layers keep \
                this width so the residual addition is shape-compatible.
            - activation (:obj:`nn.Module`): The activation module. It is used after fc1 and again after \
                the residual addition.
            - norm_type (:obj:`str`): The type of the normalization, default set to 'BN'.
            - dropout (:obj:`float`): The dropout rate applied to the block output. Default set to None, \
                which disables dropout.
        """
        super(ResFCBlock, self).__init__()
        self.act = activation
        # ``None`` means "no dropout"; ``forward`` checks for it explicitly.
        if dropout is not None:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None
        self.fc1 = fc_block(in_channels, in_channels, activation=self.act, norm_type=norm_type)
        # No activation here: it is applied after the residual addition in ``forward``.
        self.fc2 = fc_block(in_channels, in_channels, activation=None, norm_type=norm_type)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Return the output of the residual block.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        Returns:
            - x (:obj:`torch.Tensor`): The resblock output tensor, i.e. ``dropout(act(fc2(fc1(x)) + x))`` \
                (the dropout step is skipped when no dropout rate was given).
        """
        identity = x
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.act(x + identity)
        if self.dropout is not None:
            x = self.dropout(x)
        return x
class TemporalSpatialResBlock(nn.Module):
    """
    Overview:
        Residual Block using MLP layers for both temporal and spatial input. The time embedding is
        injected by addition after the first dense layer, and a skip path carries ``x`` to the output:

            t -> time_mlp ----------------+
                                          | (+)
            x -> dense1 ------------------+--> h1 -> dense2 -> h2 --(+)--> out
            |                                                      ^
            +------------------> modify_x -------------------------+

        ``modify_x`` is a Linear projection when ``input_dim != output_dim`` and an identity otherwise.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, input_dim, output_dim, t_dim=128, activation=torch.nn.SiLU()):
        """
        Overview:
            Init the temporal spatial residual block.
        Arguments:
            - input_dim (:obj:`int`): The number of channels in the input tensor.
            - output_dim (:obj:`int`): The number of channels in the output tensor.
            - t_dim (:obj:`int`): The dimension of the temporal input.
            - activation (:obj:`nn.Module`): The activation module. The same instance is reused in \
                ``time_mlp``, ``dense1`` and ``dense2``.
        """
        super(TemporalSpatialResBlock, self).__init__()
        # The temporal input is the embedding of time (a Gaussian Fourier Feature tensor of size t_dim).
        # Note the ordering: the activation is applied to t BEFORE its linear projection.
        self.time_mlp = nn.Sequential(activation, nn.Linear(t_dim, output_dim))
        self.dense1 = nn.Sequential(nn.Linear(input_dim, output_dim), activation)
        self.dense2 = nn.Sequential(nn.Linear(output_dim, output_dim), activation)
        # The skip path only needs a learnable projection when the widths differ.
        if input_dim == output_dim:
            self.modify_x = nn.Identity()
        else:
            self.modify_x = nn.Linear(input_dim, output_dim)

    def forward(self, x, t) -> torch.Tensor:
        """
        Overview:
            Return the residual block output.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor whose last dimension is ``input_dim``.
            - t (:obj:`torch.Tensor`): The temporal input tensor whose last dimension is ``t_dim``.
        Returns:
            - out (:obj:`torch.Tensor`): ``dense2(dense1(x) + time_mlp(t)) + modify_x(x)``, whose last \
                dimension is ``output_dim``.
        """
        fused = self.dense1(x) + self.time_mlp(t)
        return self.dense2(fused) + self.modify_x(x)