Source code for core.models.vae_model

import torch
import torch.nn as nn
import numpy as np
from typing import Dict, Tuple, List, Any
from torch.nn import functional as F


class VanillaVAE(nn.Module):
    """
    Vanilla Variational Auto Encoder model.

    The encoder is a stack of stride-2 ``Conv2d -> BatchNorm2d -> LeakyReLU`` blocks
    followed by two linear heads producing the mean and log-variance of the latent
    Gaussian. The decoder mirrors the encoder with stride-2 ``ConvTranspose2d`` blocks
    and ends with a ``Sigmoid``, so reconstructions lie in ``(0, 1)``.

    The linear layers are sized for a ``6 x 6`` encoder feature map, so inputs must be
    ``6 * 2 ** len(hidden_dims)`` pixels on each side (192 with the default
    ``hidden_dims``); reconstructions then have the same spatial size as the input.

    :Interfaces: encode, decode, reparameterize, forward, loss_function, sample, generate

    :Arguments:
        - in_channels (int): the channel number of input
        - latent_dim (int): the latent dimension of the middle representation
        - hidden_dims (List): the hidden dimensions of each encoder layer, ordered from \
            input to bottleneck; the decoder uses them in reverse. The list passed in is \
            not modified.
        - kld_weight (float): the weight of KLD loss
        - out_channels (int): the channel number of the reconstruction (default 7)
    """

    # Side length of the square feature map the encoder must produce; the linear
    # layers and ``decode`` are sized for it.
    _ENC_FEATURE_SIZE = 6

    def __init__(
        self,
        in_channels: int,
        latent_dim: int,
        hidden_dims: List[int] = None,
        kld_weight: float = 0.1,
        out_channels: int = 7,
    ) -> None:
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        # Copy so the caller's list is never mutated and ``self.hidden_dims`` keeps
        # the encoder (input -> bottleneck) order.
        hidden_dims = list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one layer width")
        self.latent_dim = latent_dim
        self.hidden_dims = hidden_dims
        self.kld_weight = kld_weight

        flat_dim = hidden_dims[-1] * self._ENC_FEATURE_SIZE ** 2

        # Build Encoder: each block halves the spatial size.
        modules = []
        channels = in_channels
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            channels = h_dim
        self.encoder = nn.Sequential(*modules)

        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_var = nn.Linear(flat_dim, latent_dim)

        # Build Decoder: mirrors the encoder; each block doubles the spatial size.
        self.decoder_input = nn.Linear(latent_dim, flat_dim)
        reversed_dims = hidden_dims[::-1]
        modules = []
        for c_in, c_out in zip(reversed_dims, reversed_dims[1:]):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1
                    ),
                    nn.BatchNorm2d(c_out),
                    nn.LeakyReLU(),
                )
            )
        self.decoder = nn.Sequential(*modules)

        # Last upsampling step back to the input resolution, then project to
        # ``out_channels`` and squash to (0, 1).
        first_dim = reversed_dims[-1]
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(
                first_dim, first_dim, kernel_size=3, stride=2, padding=1, output_padding=1
            ),
            nn.BatchNorm2d(first_dim),
            nn.LeakyReLU(),
            nn.Conv2d(first_dim, out_channels=out_channels, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def encode(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encodes the input through the encoder network into latent Gaussian parameters.

        :Arguments:
            - input (Tensor): Input tensor to encode [N x C x H x W]
        :Returns:
            Tuple[Tensor, Tensor]: ``(mu, log_var)``, each [N x latent_dim]
        """
        result = torch.flatten(self.encoder(input), start_dim=1)
        # Two heads on the same features: mean and log-variance.
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return mu, log_var

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Maps the given latent codes onto the image space.

        :Arguments:
            - z (Tensor): [B x D]
        :Returns:
            Tensor: Output decode tensor [B x C x H x W]
        """
        result = self.decoder_input(z)
        # Reshape to the bottleneck layout; the channel count is derived from the
        # projection layer itself so it always matches the encoder's last width.
        side = self._ENC_FEATURE_SIZE
        channels = self.decoder_input.out_features // (side * side)
        result = result.view(-1, channels, side, side)
        result = self.decoder(result)
        return self.final_layer(result)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Reparameterization trick to sample from N(mu, var) using N(0, 1) noise.

        :Arguments:
            - mu (Tensor): Mean of the latent Gaussian [B x D]
            - logvar (Tensor): Log-variance of the latent Gaussian [B x D]
        :Returns:
            Tensor: ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``, [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: torch.Tensor, **kwargs) -> List[torch.Tensor]:
        """
        Encodes, samples a latent code, and decodes it.

        :Arguments:
            - input (torch.Tensor): Input tensor
        :Returns:
            List[torch.Tensor]: ``[reconstruction, input, mu, log_var]``
        """
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]

    def loss_function(self, *args, **kwargs) -> Dict:
        r"""
        Computes the VAE loss function.

        :math:`KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}`

        :Arguments:
            - args: ``(recons, input, mu, log_var)`` as returned by ``forward``; \
                further positional arguments are ignored.
        :Returns:
            Dict: ``loss`` (total), ``reconstruction_Loss`` and ``KLD``. Note that \
                ``KLD`` holds the negation of the KL term added to ``loss``.
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        kld_weight = self.kld_weight

        # Heuristic: MSE while the reconstruction error is large, switching to L1
        # once the MSE drops below 0.05.
        recons_loss = F.mse_loss(recons, input)
        if recons_loss < 0.05:
            recons_loss = F.l1_loss(recons, input)

        # Closed-form KL to N(0, 1): summed over latent dims, averaged over the batch.
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'reconstruction_Loss': recons_loss, 'KLD': -kld_loss}

    def sample(self, num_samples: int, current_device: int, **kwargs) -> torch.Tensor:
        """
        Samples from the latent space and return the corresponding image space map.

        :Arguments:
            - num_samples(Int): Number of samples.
            - current_device(Int): Device to run the model.
        :Returns:
            Tensor: Sampled decode tensor.
        """
        # Draw on the default device first, then move (preserves the RNG stream used).
        z = torch.randn(num_samples, self.latent_dim)
        z = z.to(current_device)
        return self.decode(z)

    def generate(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Given an input image x, returns the reconstructed image.

        :Arguments:
            - x(Tensor): [B x C x H x W]
        :Returns:
            Tensor: [B x C x H x W]
        """
        return self.forward(x)[0]