Complete Guide to Generative Adversarial Networks (GANs) with PyTorch
What You'll Learn: This comprehensive guide covers everything from basic GAN implementation to advanced techniques, including DCGANs, WGANs, and strategies for stable training.
Understanding GAN Fundamentals
The GAN Framework
Generative Adversarial Networks consist of two neural networks engaged in a minimax game:
Generator (G)
- Maps random noise to data space
- Tries to produce realistic samples
- Typically starts with poor quality outputs
- Improves through adversarial training
Discriminator (D)
- Acts as a binary classifier
- Distinguishes real from generated samples
- Provides training signal to generator
- Must be neither too strong nor too weak relative to the generator
Mathematical Foundation
The GAN objective can be expressed as:
min_G max_D V(D,G) = 𝔼_{x∼p_data}[log D(x)] + 𝔼_{z∼p_z}[log(1 − D(G(z)))]
Where:
- D(x) is the discriminator's estimate of the probability that x is real
- G(z) is generator's output given noise z
- p_data is the data generating distribution
- p_z is the noise distribution
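For a fixed generator, the inner maximization has a known closed form (a standard result from the original GAN paper): the optimal discriminator is

D*(x) = p_data(x) / (p_data(x) + p_g(x))

where p_g is the distribution of generator samples. When the generator matches the data distribution, D*(x) = 0.5 everywhere, which is why D(x) and D(G(z)) hovering near 0.5 during training is a sign of approaching equilibrium.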
Complete PyTorch Implementation
1. Enhanced Dataset Preparation
```python
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
import numpy as np
import matplotlib.pyplot as plt

# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
image_size = 28 * 28  # MNIST dimensions
latent_dim = 100      # Size of noise vector

# Transformations (images stay 28x28 to match the networks below)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # Scale pixels to [-1, 1] for the Tanh output
    # transforms.RandomHorizontalFlip(p=0.5)  # Skipped here: flipping mirrors digits
])

# Load dataset
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=4, pin_memory=True)

# Directory for sample images saved during training
os.makedirs('gan_samples', exist_ok=True)
```
2. Advanced Generator Architecture
```python
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # Project the latent vector and reshape into feature maps
            nn.Linear(latent_dim, 128 * 7 * 7),
            nn.BatchNorm1d(128 * 7 * 7),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Unflatten(1, (128, 7, 7)),
            # Upsample to 14x14
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # Upsample to 28x28
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)
```
Architecture Note: This generator uses transposed convolutions (sometimes called "deconvolutions") to upsample the noise vector into an image. Batch normalization and leaky ReLU help with training stability.
3. Enhanced Discriminator with Spectral Normalization
```python
def add_spectral_norm(layer):
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        return nn.utils.spectral_norm(layer)
    return layer

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        layers = [
            # Input is 1x28x28
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 64x14x14
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 128x7x7
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1)
        ]
        # Apply spectral normalization for stability
        self.main = nn.Sequential(*[add_spectral_norm(layer) for layer in layers])

    def forward(self, input):
        # Return raw logits; the training loop below uses BCEWithLogitsLoss,
        # which applies the sigmoid internally (applying it twice is a common bug)
        return self.main(input)
```
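Before moving on to training, a quick forward pass catches shape mismatches early. A minimal sanity-check sketch, using the Generator, Discriminator, and configuration defined above:

```python
# Sanity check: confirm generator and discriminator shapes line up
G = Generator(latent_dim).to(device)
D = Discriminator().to(device)

z = torch.randn(16, latent_dim, device=device)
fake = G(z)
print(fake.shape)     # expected: torch.Size([16, 1, 28, 28])
print(D(fake).shape)  # expected: torch.Size([16, 1]) -- raw logits
```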
4. Comprehensive Training Loop
```python
def train_gan(generator, discriminator, dataloader, num_epochs=50):
    # Loss function and optimizers
    criterion = nn.BCEWithLogitsLoss()
    lr = 0.0002
    beta1, beta2 = 0.5, 0.999
    optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

    # Labels for real and fake images
    real_label = 1.0
    fake_label = 0.0

    # Training statistics
    G_losses = []
    D_losses = []
    D_x_values = []    # D(x): average discriminator output on real images
    D_G_z_values = []  # D(G(z)): average discriminator output on fakes

    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            real_images = real_images.to(device)
            batch_size = real_images.size(0)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            discriminator.zero_grad()

            # Forward pass real images
            output_real = discriminator(real_images).view(-1)
            label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
            errD_real = criterion(output_real, label)
            errD_real.backward()
            D_x = torch.sigmoid(output_real).mean().item()  # logits -> probability

            # Generate fake images
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)

            # Forward pass fake images (detach so gradients skip the generator)
            output_fake = discriminator(fake_images.detach()).view(-1)
            label.fill_(fake_label)
            errD_fake = criterion(output_fake, label)
            errD_fake.backward()
            D_G_z1 = torch.sigmoid(output_fake).mean().item()

            # Total discriminator loss
            errD = errD_real + errD_fake
            optimizer_D.step()

            # -----------------
            #  Train Generator
            # -----------------
            generator.zero_grad()
            label.fill_(real_label)  # generator wants the discriminator to label fakes real
            output = discriminator(fake_images).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = torch.sigmoid(output).mean().item()
            optimizer_G.step()

            # Save statistics
            G_losses.append(errG.item())
            D_losses.append(errD.item())
            D_x_values.append(D_x)
            D_G_z_values.append(D_G_z1)

            # Print training stats
            if i % 100 == 0:
                print(f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}] '
                      f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} '
                      f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}')

        # Generate sample images after each epoch
        with torch.no_grad():
            sample_noise = torch.randn(64, latent_dim, device=device)
            generated_images = generator(sample_noise).detach().cpu()
        save_image(generated_images, f'gan_samples/epoch_{epoch}.png',
                   nrow=8, normalize=True)

    return G_losses, D_losses, D_x_values, D_G_z_values
```
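A minimal end-to-end sketch tying the pieces together, using the models, dataloader, and training function above (the epoch count here is just an example):

```python
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

# Train and collect the statistics returned by train_gan
G_losses, D_losses, D_x_values, D_G_z_values = train_gan(
    generator, discriminator, dataloader, num_epochs=25)

# Save the generator weights for later sampling
torch.save(generator.state_dict(), 'generator.pth')
```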
Advanced GAN Architectures
Deep Convolutional GAN (DCGAN)
DCGANs use convolutional layers in both the generator and the discriminator, following these architectural guidelines (a standard weight-initialization helper is sketched after the list):
- Replace pooling layers with strided convolutions (discriminator) and fractional-strided convolutions (generator)
- Use batch normalization in both networks
- Remove fully connected hidden layers
- Use ReLU activation in generator (except output layer which uses Tanh)
- Use LeakyReLU in discriminator
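The DCGAN paper also specifies initializing all weights from a zero-mean normal distribution with standard deviation 0.02. A common helper, applied to a model with `model.apply(weights_init)`:

```python
def weights_init(m):
    # DCGAN-style initialization: N(0, 0.02) for conv weights,
    # N(1, 0.02) for batch-norm scale, zeros for batch-norm bias
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
```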
Wasserstein GAN (WGAN)
WGANs improve training stability through:
```python
# Key differences from standard GAN:
# 1. Use Wasserstein loss (remove sigmoid from the critic output)
# 2. Clip critic weights to (-0.01, 0.01)
# 3. Train the critic more often than the generator (typically 5:1)

# WGAN loss functions
def discriminator_loss(real_output, fake_output):
    return torch.mean(fake_output) - torch.mean(real_output)

def generator_loss(fake_output):
    return -torch.mean(fake_output)
```
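A minimal sketch of the corresponding training iteration, with weight clipping and the 5:1 critic-to-generator update ratio. It assumes a `critic` network like the discriminator above (which already returns raw scores) and optimizers set up as before; note the original WGAN paper uses RMSprop rather than Adam:

```python
n_critic = 5       # critic updates per generator update
clip_value = 0.01  # weight-clipping range from the WGAN paper

for i, (real_images, _) in enumerate(dataloader):
    real_images = real_images.to(device)

    # --- Critic update (every iteration) ---
    optimizer_D.zero_grad()
    noise = torch.randn(real_images.size(0), latent_dim, device=device)
    fake_images = generator(noise)
    loss_D = discriminator_loss(critic(real_images), critic(fake_images.detach()))
    loss_D.backward()
    optimizer_D.step()

    # Clip critic weights to (crudely) enforce the Lipschitz constraint
    for p in critic.parameters():
        p.data.clamp_(-clip_value, clip_value)

    # --- Generator update (every n_critic iterations) ---
    if i % n_critic == 0:
        optimizer_G.zero_grad()
        loss_G = generator_loss(critic(fake_images))
        loss_G.backward()
        optimizer_G.step()
```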
GAN Training: Challenges and Solutions
| Challenge | Symptoms | Solutions |
|---|---|---|
| Mode Collapse | Generator produces a limited variety of samples | Use minibatch discrimination, unrolled GANs, or WGAN |
| Vanishing Gradients | Generator stops improving | Use Wasserstein loss, gradient penalty (sketched below), or different architectures |
| Oscillations | Losses fluctuate without converging | Adjust learning rates, use TTUR (Two Time-scale Update Rule) |
| Overfitting | Discriminator reaches 100% accuracy | Add dropout, reduce discriminator capacity, or add noise |
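The gradient penalty from the table (WGAN-GP) replaces weight clipping: the critic is penalized when its gradient norm on random interpolations between real and fake samples deviates from 1. A minimal sketch:

```python
def gradient_penalty(critic, real, fake, device, lambda_gp=10.0):
    # Random interpolation between real and fake samples
    eps = torch.rand(real.size(0), 1, 1, 1, device=device)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(interp)

    # Gradient of the critic's scores w.r.t. the interpolated inputs
    grads = torch.autograd.grad(outputs=scores, inputs=interp,
                                grad_outputs=torch.ones_like(scores),
                                create_graph=True)[0]
    grads = grads.view(grads.size(0), -1)

    # Penalize deviation of the per-sample gradient norm from 1
    return lambda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()
```

This term is added to the critic loss in place of the clipping step.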
Monitoring GAN Training
Effective metrics to track during training (a plotting sketch follows the list):
- Inception Score (IS): Measures both quality and diversity of generated images
- Fréchet Inception Distance (FID): Compares feature statistics of real and generated images
- Visual Inspection: Regularly generate sample images to check quality
- Discriminator Metrics: Track D(x) and D(G(z)) values (should converge to ~0.5)
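A minimal plotting sketch for the statistics returned by train_gan above:

```python
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(G_losses, label='G loss')
ax1.plot(D_losses, label='D loss')
ax1.set_xlabel('Iteration')
ax1.legend()

ax2.plot(D_x_values, label='D(x)')
ax2.plot(D_G_z_values, label='D(G(z))')
ax2.axhline(0.5, linestyle='--', color='gray')  # equilibrium reference
ax2.set_xlabel('Iteration')
ax2.legend()

plt.tight_layout()
plt.show()
```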
Practical Applications of GANs
Image Synthesis
- Photorealistic image generation
- Art creation
- Data augmentation
Image-to-Image Translation
- Style transfer
- Colorization
- Super-resolution
Domain Adaptation
- Sim-to-real transfer
- Medical imaging
- Anime generation
Extending Your GAN Implementation
Conditional GANs
Add class information to both generator and discriminator to control output categories:
```python
class ConditionalGenerator(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, num_classes)
        self.main = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 256),
            nn.LeakyReLU(0.2),
            # ... rest of generator layers
        )

    def forward(self, noise, labels):
        label_embedding = self.label_embedding(labels)
        x = torch.cat((noise, label_embedding), dim=1)
        return self.main(x)
```
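Once the remaining layers are filled in, sampling class-conditional digits would look like this (a sketch assuming the MNIST setup above, with 10 classes):

```python
cond_G = ConditionalGenerator(latent_dim, num_classes=10).to(device)

# Generate one sample of each digit class 0-9
noise = torch.randn(10, latent_dim, device=device)
labels = torch.arange(10, device=device)
fake_digits = cond_G(noise, labels)
```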
Progressive Growing of GANs
Start with low-resolution images and progressively increase resolution during training:
- Begin training with 4x4 images
- Gradually add layers to increase resolution
- Smoothly fade in new layers
- Results in higher quality outputs than fixed architectures
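The fade-in is the heart of the technique: the output of a newly added resolution block is blended with an upsampled copy of the old output, weighted by an alpha that ramps from 0 to 1 over training. A minimal sketch, assuming the new block preserves the channel count (names here are illustrative, not from a specific library):

```python
import torch.nn.functional as F

def faded_forward(old_features, new_block, alpha):
    # Upsample the previous resolution's features by 2x
    upsampled = F.interpolate(old_features, scale_factor=2, mode='nearest')
    # New high-resolution block, assumed to keep the same channel count
    new_out = new_block(upsampled)
    # alpha ramps 0 -> 1, smoothly handing off to the new block
    return alpha * new_out + (1 - alpha) * upsampled
```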
Resources and Next Steps
Recommended Papers
- Goodfellow et al., "Generative Adversarial Networks" (2014)
- Radford et al., "Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks" (2015)
- Arjovsky et al., "Wasserstein GAN" (2017)
- Gulrajani et al., "Improved Training of Wasserstein GANs" (2017)
- Miyato et al., "Spectral Normalization for Generative Adversarial Networks" (2018)
- Karras et al., "Progressive Growing of GANs for Improved Quality, Stability, and Variation" (2017)
Datasets to Try
- CIFAR-10
- CelebA
- LSUN Bedrooms
- FFHQ (Flickr Faces HQ)
Advanced Techniques
- Self-Attention GANs
- StyleGAN/StyleGAN2
- BigGAN
- CycleGAN for unpaired translation
Final Tip: When working with GANs, be patient and experiment with different architectures and hyperparameters. The field evolves rapidly, so stay updated with the latest research papers and techniques.