import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class BasicAutoencoder(nn.Module):
    """Fully connected autoencoder with a deterministic latent code.

    Encoder: one ``Linear -> ReLU -> BatchNorm1d`` stage per width in
    ``hidden_dims``, followed by a bare ``Linear`` projection to
    ``latent_dim``. Decoder: the same kind of stages over ``hidden_dims`` in
    reverse order, followed by a bare ``Linear`` projection back to
    ``input_dim`` (no output activation).

    Args:
        input_dim: Number of input features.
        hidden_dims: Hidden-layer widths, ordered from the input side. It is
            iterated once directly and once through ``reversed()``, so it must
            be a reversible sequence (e.g. a list or tuple).
        latent_dim: Size of the latent code.
    """

    def __init__(self, input_dim, hidden_dims, latent_dim):
        super().__init__()
        # Encoder layers are constructed before any decoder layer, so the
        # order in which Linear weights draw from the RNG is encoder-first.
        trunk, width = self._hidden_stages(input_dim, hidden_dims)
        self.encoder = nn.Sequential(*trunk, nn.Linear(width, latent_dim))

        trunk, width = self._hidden_stages(latent_dim, reversed(hidden_dims))
        self.decoder = nn.Sequential(*trunk, nn.Linear(width, input_dim))

    @staticmethod
    def _hidden_stages(in_features, widths):
        """Build Linear->ReLU->BatchNorm1d stages; return (layers, out_features)."""
        layers = []
        for width in widths:
            layers += [nn.Linear(in_features, width), nn.ReLU(), nn.BatchNorm1d(width)]
            in_features = width
        return layers, in_features

    def forward(self, x):
        """Encode ``x`` and decode the resulting code.

        Returns:
            Tuple ``(x_recon, z)``: the reconstruction and the latent code.
        """
        latent = self.encoder(x)
        return self.decoder(latent), latent
class VariationalAutoencoder(nn.Module):
    """Fully connected VAE with a per-dimension Gaussian latent posterior.

    A shared encoder trunk (one ``Linear -> ReLU -> BatchNorm1d`` stage per
    width in ``hidden_dims``) feeds two linear heads: ``fc_mu`` produces the
    posterior mean and ``fc_var`` its log-variance (``reparameterize`` turns
    it into a standard deviation via ``exp(0.5 * logvar)``). The decoder
    mirrors the hidden widths in reverse order and ends in a bare ``Linear``
    back to ``input_dim``.

    Args:
        input_dim: Number of input features.
        hidden_dims: Hidden-layer widths, ordered from the input side. It is
            iterated once directly and once through ``reversed()``, so it must
            be a reversible sequence (e.g. a list or tuple).
        latent_dim: Size of the latent code.
    """

    def __init__(self, input_dim, hidden_dims, latent_dim):
        super().__init__()
        # Construction order (encoder trunk, mu head, log-var head, decoder)
        # fixes both the RNG draw order for weight init and the submodule order.
        trunk, width = self._hidden_stages(input_dim, hidden_dims)
        self.encoder = nn.Sequential(*trunk)
        self.fc_mu = nn.Linear(width, latent_dim)
        self.fc_var = nn.Linear(width, latent_dim)

        trunk, width = self._hidden_stages(latent_dim, reversed(hidden_dims))
        self.decoder = nn.Sequential(*trunk, nn.Linear(width, input_dim))

    @staticmethod
    def _hidden_stages(in_features, widths):
        """Build Linear->ReLU->BatchNorm1d stages; return (layers, out_features)."""
        layers = []
        for width in widths:
            layers += [nn.Linear(in_features, width), nn.ReLU(), nn.BatchNorm1d(width)]
            in_features = width
        return layers, in_features

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        features = self.encoder(x)
        mu = self.fc_mu(features)
        logvar = self.fc_var(features)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * sigma`` with ``eps ~ N(0, 1)`` elementwise.

        ``sigma = exp(0.5 * logvar)``. A fresh sample is drawn on every call
        (in both train and eval mode) from torch's global RNG.
        """
        sigma = (0.5 * logvar).exp()
        eps = torch.randn_like(sigma)
        return mu + eps * sigma

    def decode(self, z):
        """Map latent codes ``z`` back to input space."""
        reconstruction = self.decoder(z)
        return reconstruction

    def forward(self, x):
        """Encode, sample a latent code, and decode.

        Returns:
            Tuple ``(x_recon, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar
class DenoisingAutoencoder(nn.Module):
    """Autoencoder trained to reconstruct inputs from Gaussian-corrupted copies.

    Wraps a ``BasicAutoencoder``. During training the input is corrupted with
    additive Gaussian noise before encoding. In eval mode the input is passed
    through clean by default, so inference is deterministic and does not
    corrupt the data being encoded.

    Args:
        input_dim: Number of input features.
        hidden_dims: Hidden-layer widths passed to ``BasicAutoencoder``.
        latent_dim: Size of the latent code.
        noise_factor: Scale applied to standard-normal noise, i.e. the
            standard deviation of the corruption (when non-negative).
    """

    def __init__(self, input_dim, hidden_dims, latent_dim, noise_factor=0.3):
        super().__init__()
        self.noise_factor = noise_factor
        self.basic_ae = BasicAutoencoder(input_dim, hidden_dims, latent_dim)

    def add_noise(self, x):
        """Return ``x`` plus standard-normal noise scaled by ``noise_factor``."""
        noise = torch.randn_like(x) * self.noise_factor
        return x + noise

    def forward(self, x, corrupt=None):
        """Optionally corrupt ``x``, then encode and decode it.

        Args:
            x: Input batch.
            corrupt: Whether to add noise before encoding. ``None`` (default)
                corrupts only in training mode (``self.training``). Pass
                ``True`` to evaluate denoising on noisy inputs in eval mode,
                or ``False`` to skip corruption while training.

        Returns:
            Tuple ``(x_recon, z)`` from the wrapped autoencoder.
        """
        if corrupt is None:
            corrupt = self.training
        # Noise is a training-time regulariser; adding it at inference would
        # make outputs random and corrupt the caller's clean input.
        x_in = self.add_noise(x) if corrupt else x
        x_recon, z = self.basic_ae(x_in)
        return x_recon, z