"""Example implementations of a basic GAN and a Wasserstein GAN (WGAN) using PyTorch."""
import torch
import torch.nn as nn
class Generator(nn.Module):
    """MLP generator: maps latent vectors to data-space samples in [-1, 1].

    Architecture: Linear -> BatchNorm1d -> ReLU per hidden width, then a
    final Linear projection squashed by Tanh.
    """

    def __init__(self, latent_dim, hidden_dims, output_dim):
        super(Generator, self).__init__()
        blocks = []
        in_features = latent_dim
        for width in hidden_dims:
            blocks.append(nn.Linear(in_features, width))
            blocks.append(nn.BatchNorm1d(width))
            blocks.append(nn.ReLU(inplace=True))
            in_features = width
        # Project to data space; Tanh keeps samples in [-1, 1].
        blocks.append(nn.Linear(in_features, output_dim))
        blocks.append(nn.Tanh())
        self.model = nn.Sequential(*blocks)

    def forward(self, z):
        """Generate samples from latent codes ``z`` of shape (batch, latent_dim)."""
        return self.model(z)
class Discriminator(nn.Module):
    """MLP discriminator / critic.

    Architecture: Linear -> LeakyReLU(0.2) -> Dropout(0.3) per hidden width,
    then a Linear projection to a single score.

    Args:
        input_dim: dimensionality of input samples.
        hidden_dims: widths of the hidden layers.
        use_sigmoid: if True (default, original behavior) append a final
            Sigmoid so the output is a probability in (0, 1) for a standard
            GAN. If False, return the raw unbounded score, as required for a
            Wasserstein (WGAN) critic.
    """

    def __init__(self, input_dim, hidden_dims, use_sigmoid=True):
        super(Discriminator, self).__init__()
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(0.3),
            ])
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, 1))
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Score a batch ``x`` of shape (batch, input_dim); returns (batch, 1)."""
        return self.model(x)
class WGAN(nn.Module):
    """Wasserstein GAN trained with weight clipping (Arjovsky et al., 2017).

    ``train_step`` performs ``critic_iters`` critic updates (with weight
    clipping after each) followed by one generator update.

    Args:
        latent_dim: size of the generator's noise input.
        hidden_dims: hidden widths shared by generator and critic.
        output_dim: dimensionality of data samples.
        critic_iters: critic updates per generator update.
        clip_value: weights are clamped to [-clip_value, clip_value].
    """

    def __init__(self, latent_dim, hidden_dims, output_dim, critic_iters=5, clip_value=0.01):
        super(WGAN, self).__init__()
        self.generator = Generator(latent_dim, hidden_dims, output_dim)
        # BUGFIX: a WGAN critic must output *unbounded* scores, but the shared
        # Discriminator ends in a Sigmoid that squashes its output into (0, 1),
        # which breaks the Wasserstein loss. Strip the final Sigmoid if present.
        self.critic = Discriminator(output_dim, hidden_dims)
        if isinstance(self.critic.model[-1], nn.Sigmoid):
            self.critic.model = nn.Sequential(*list(self.critic.model.children())[:-1])
        self.latent_dim = latent_dim
        self.critic_iters = critic_iters
        self.clip_value = clip_value

    def clip_critic_weights(self):
        """Clamp every critic parameter to [-clip_value, clip_value].

        This is the crude Lipschitz-constraint enforcement of the original
        WGAN paper (as opposed to a gradient penalty).
        """
        for p in self.critic.parameters():
            p.data.clamp_(-self.clip_value, self.clip_value)

    def train_step(self, real_data, optimizer_g, optimizer_c):
        """Run one full WGAN training step.

        Args:
            real_data: real samples of shape (batch, output_dim).
            optimizer_g: optimizer over ``self.generator`` parameters.
            optimizer_c: optimizer over ``self.critic`` parameters.

        Returns:
            dict with 'critic_loss' (from the last critic iteration) and
            'generator_loss', both as Python floats.
        """
        batch_size = real_data.size(0)
        # BUGFIX: sample noise on the same device as the data; the original
        # torch.randn(...) always allocated on CPU and crashed for GPU inputs.
        device = real_data.device
        # Tensor zero so .item() below is safe even when critic_iters == 0.
        critic_loss = torch.zeros((), device=device)
        for _ in range(self.critic_iters):
            z = torch.randn(batch_size, self.latent_dim, device=device)
            fake_data = self.generator(z)
            critic_real = self.critic(real_data)
            # detach(): the critic step must not backprop into the generator.
            critic_fake = self.critic(fake_data.detach())
            # Critic maximizes E[D(real)] - E[D(fake)]; minimize the negation.
            critic_loss = -(torch.mean(critic_real) - torch.mean(critic_fake))
            optimizer_c.zero_grad()
            critic_loss.backward()
            optimizer_c.step()
            self.clip_critic_weights()
        # Generator update: maximize E[D(G(z))], i.e. minimize its negation.
        z = torch.randn(batch_size, self.latent_dim, device=device)
        fake_data = self.generator(z)
        generator_loss = -torch.mean(self.critic(fake_data))
        optimizer_g.zero_grad()
        generator_loss.backward()
        optimizer_g.step()
        return {
            'critic_loss': critic_loss.item(),
            'generator_loss': generator_loss.item(),
        }