Generative Adversarial Networks (GANs) Image Generation With PyTorch

Complete Guide to Generative Adversarial Networks (GANs) with PyTorch

What You'll Learn: This comprehensive guide covers everything from basic GAN implementation to advanced techniques, including DCGANs, WGANs, and strategies for stable training.

Understanding GAN Fundamentals

The GAN Framework

Generative Adversarial Networks consist of two neural networks engaged in a minimax game:

Generator (G)

  • Maps random noise to data space
  • Tries to produce realistic samples
  • Typically starts with poor quality outputs
  • Improves through adversarial training

Discriminator (D)

  • Acts as a binary classifier
  • Distinguishes real from generated samples
  • Provides training signal to generator
  • Must be neither too strong nor too weak relative to the generator

Mathematical Foundation

The GAN objective can be expressed as:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

Where:

  • D(x) is the discriminator's estimate of the probability that x is real
  • G(z) is the generator's output given noise z
  • p_data is the distribution of the real data
  • p_z is the noise distribution
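
To connect the objective to code, the following minimal sketch shows that the discriminator's binary cross-entropy loss over a real batch and a fake batch is the negative of V(D, G) estimated on those batches. The tensors `d_real` and `d_fake` are made-up discriminator outputs chosen for illustration, not values from a trained model.

```python
import torch
import torch.nn.functional as F

# Hypothetical discriminator outputs (probabilities), for illustration only
d_real = torch.tensor([0.9, 0.8, 0.7])   # D(x) on real samples
d_fake = torch.tensor([0.2, 0.1, 0.4])   # D(G(z)) on generated samples

# Monte Carlo estimate of V(D, G) from the two batches
value = torch.log(d_real).mean() + torch.log(1 - d_fake).mean()

# Binary cross-entropy with target 1 for real and target 0 for fake
bce_real = F.binary_cross_entropy(d_real, torch.ones_like(d_real))
bce_fake = F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake))
d_loss = bce_real + bce_fake

# Minimizing d_loss over D is the same as maximizing V(D, G) over D
matches = torch.allclose(d_loss, -value)
```

This is why the discriminator step in a standard GAN can be written with an ordinary BCE loss and 1/0 labels.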

Complete PyTorch Implementation

1. Enhanced Dataset Preparation

import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid, save_image

# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
image_size = 28 * 28  # Pixels per flattened MNIST image
latent_dim = 100  # Size of noise vector

# Directory where the training loop saves sample grids
os.makedirs("gan_samples", exist_ok=True)

# Transformations: keep the native 28x28 size, which the networks below expect.
# Horizontal flips are left out because mirrored digits are not valid MNIST samples.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Scale pixels to [-1, 1] to match the generator's Tanh output
])

# Load dataset
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
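
The `Normalize((0.5,), (0.5,))` step matters because the generator ends in Tanh, so real and generated pixels should share the range [-1, 1]. The sketch below reproduces that arithmetic with plain tensors; `normalize` and `denormalize` are hypothetical helper names introduced only for this illustration.

```python
import torch

def normalize(x, mean=0.5, std=0.5):
    """Same arithmetic as transforms.Normalize((0.5,), (0.5,)) for one channel."""
    return (x - mean) / std

def denormalize(x, mean=0.5, std=0.5):
    """Inverse mapping, useful before plotting generated samples."""
    return (x * std + mean).clamp(0.0, 1.0)

pixels = torch.tensor([0.0, 0.5, 1.0])   # ToTensor() yields values in [0, 1]
scaled = normalize(pixels)               # now in [-1, 1], the same range as Tanh
restored = denormalize(scaled)           # back to [0, 1] for display
```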
    

2. Advanced Generator Architecture

class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        
        self.main = nn.Sequential(
            # Project the latent vector to 128 feature maps of size 7x7
            nn.Linear(latent_dim, 128 * 7 * 7),
            nn.BatchNorm1d(128 * 7 * 7),
            nn.LeakyReLU(0.2, inplace=True),
            
            # Reshape into feature maps
            nn.Unflatten(1, (128, 7, 7)),
            
            # Upsample to 14x14
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            
            # Upsample to 28x28
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )
        
    def forward(self, input):
        return self.main(input)
    
Architecture Note: This generator uses transposed convolutions (sometimes called "deconvolutions") to upsample the noise vector into an image. Batch normalization and leaky ReLU help with training stability.
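
To see why each transposed convolution doubles the spatial size, the output size can be worked out by hand. The sketch below assumes dilation of 1 and no output padding (the PyTorch defaults); `conv_transpose_out` is a hypothetical helper written for this check.

```python
import torch
import torch.nn as nn

def conv_transpose_out(size, kernel, stride, padding):
    """Output size of ConvTranspose2d with dilation=1 and output_padding=0."""
    return (size - 1) * stride - 2 * padding + kernel

# Same hyperparameters as the generator's upsampling layers: kernel 4, stride 2, padding 1
first = conv_transpose_out(7, kernel=4, stride=2, padding=1)       # 7 -> 14
second = conv_transpose_out(first, kernel=4, stride=2, padding=1)  # 14 -> 28

# Cross-check the formula against an actual layer
layer = nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False)
y = layer(torch.randn(2, 128, 7, 7))
```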

3. Enhanced Discriminator with Spectral Normalization

def add_spectral_norm(layer):
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        return nn.utils.spectral_norm(layer)
    return layer

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        layers = [
            # Input is 1x28x28
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 64x14x14
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 128x7x7
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1)
        ]
        
        # Apply spectral normalization for stability
        self.main = nn.Sequential(*[add_spectral_norm(layer) for layer in layers])
        
    def forward(self, input):
        return torch.sigmoid(self.main(input))
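
Spectral normalization rescales a layer's weight so that its largest singular value is approximately 1, which limits how sharply the discriminator's output can change. The following sketch checks this on a single small linear layer. The layer sizes are arbitrary, and `n_power_iterations` is raised above its default only so the estimate settles quickly in this demo.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Wrap a small layer; sizes are arbitrary and chosen for illustration
layer = nn.utils.spectral_norm(nn.Linear(16, 8), n_power_iterations=50)

# In training mode each forward pass refreshes the normalized weight
layer.train()
for _ in range(5):
    _ = layer(torch.randn(4, 16))

# Largest singular value of the weight actually used in the forward pass
sigma_max = torch.linalg.svdvals(layer.weight.detach())[0].item()
print(f"largest singular value: {sigma_max:.3f}")
```

The unnormalized parameter is kept as `weight_orig`, and that is the tensor the optimizer updates.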
    

4. Comprehensive Training Loop

def train_gan(generator, discriminator, dataloader, num_epochs=50):
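    # Loss function and optimizers
    # The discriminator already applies a sigmoid, so use BCELoss
    # (BCEWithLogitsLoss would apply a second sigmoid to the probabilities)
    criterion = nn.BCELoss()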
    lr = 0.0002
    beta1 = 0.5
    beta2 = 0.999
    
    optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))
    
    # Labels for real and fake images
    real_label = 1.0
    fake_label = 0.0
    
    # Training statistics
    G_losses = []
    D_losses = []
    D_x_values = []  # D(x) - average discriminator output on real
    D_G_z_values = []  # D(G(z)) - average discriminator output on fake
    
    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            # Move to device
            real_images = real_images.to(device)
            batch_size = real_images.size(0)
            
            # ---------------------
            # Train Discriminator
            # ---------------------
            
            # Zero gradients
            discriminator.zero_grad()
            
            # Forward pass real images
            output_real = discriminator(real_images).view(-1)
            label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
            errD_real = criterion(output_real, label)
            errD_real.backward()
            D_x = output_real.mean().item()
            
            # Generate fake images
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)
            
            # Forward pass fake images
            output_fake = discriminator(fake_images.detach()).view(-1)
            label.fill_(fake_label)
            errD_fake = criterion(output_fake, label)
            errD_fake.backward()
            D_G_z1 = output_fake.mean().item()
            
            # Total discriminator loss
            errD = errD_real + errD_fake
            optimizer_D.step()
            
            # -----------------
            # Train Generator
            # -----------------
            
            generator.zero_grad()
            label.fill_(real_label)  # Generator wants discriminator to think fakes are real
            output = discriminator(fake_images).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizer_G.step()
            
            # Save statistics
            G_losses.append(errG.item())
            D_losses.append(errD.item())
            D_x_values.append(D_x)
            D_G_z_values.append(D_G_z1)
            
            # Print training stats
            if i % 100 == 0:
                print(f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}] '
                      f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} '
                      f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}')
        
        # Generate sample images after each epoch
        with torch.no_grad():
            sample_noise = torch.randn(64, latent_dim, device=device)
            generated_images = generator(sample_noise).cpu()
            save_image(generated_images, f'gan_samples/epoch_{epoch}.png', nrow=8, normalize=True)

    # Return the collected statistics so they can be plotted afterwards
    return G_losses, D_losses, D_x_values, D_G_z_values
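
The `fake_images.detach()` call in the discriminator step keeps that update from sending gradients into the generator. A minimal sketch with tiny stand-in networks makes the effect visible. The names `G` and `D` and their sizes are hypothetical, and this toy `D` outputs raw logits, so it is paired with `BCEWithLogitsLoss`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in networks, only to make gradient flow visible
G = nn.Linear(4, 8)
D = nn.Linear(8, 1)
criterion = nn.BCEWithLogitsLoss()  # this toy D returns logits

z = torch.randn(16, 4)
fake = G(z)

# Discriminator step: detach() blocks gradients from reaching G
loss_d = criterion(D(fake.detach()).view(-1), torch.zeros(16))
loss_d.backward()
g_grad_after_d_step = G.weight.grad          # still None
d_grad_after_d_step = D.weight.grad.clone()  # populated

# Generator step: no detach, so gradients flow through D into G
D.zero_grad()
loss_g = criterion(D(fake).view(-1), torch.ones(16))
loss_g.backward()
g_grad_after_g_step = G.weight.grad          # now populated
```

Detaching also avoids a wasted backward pass through the generator during the discriminator update.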
    

Advanced GAN Architectures

Deep Convolutional GAN (DCGAN)

DCGANs use convolutional layers in both generator and discriminator, following these architectural guidelines:

  • Replace pooling layers with strided convolutions (discriminator) and fractional-strided convolutions (generator)
  • Use batch normalization in both networks
  • Remove fully connected hidden layers
  • Use ReLU activation in generator (except output layer which uses Tanh)
  • Use LeakyReLU in discriminator
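
As a rough illustration of these guidelines, here is a sketch of a fully convolutional generator for 1x28x28 images. The class name `DCGANGenerator` and the channel widths are assumptions made for this example, not the layout from the DCGAN paper, which targets 64x64 images.

```python
import torch
import torch.nn as nn

class DCGANGenerator(nn.Module):
    """Fully convolutional generator for 1x28x28 images (illustrative sizes)."""
    def __init__(self, latent_dim=100):
        super().__init__()
        self.main = nn.Sequential(
            # latent_dim x 1 x 1 -> 128 x 7 x 7 (no fully connected layer)
            nn.ConvTranspose2d(latent_dim, 128, 7, 1, 0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 128 x 7 x 7 -> 64 x 14 x 14
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # 64 x 14 x 14 -> 1 x 28 x 28, Tanh on the output layer
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        # Treat the noise vector as a 1x1 feature map with latent_dim channels
        return self.main(z.view(z.size(0), -1, 1, 1))

g = DCGANGenerator()
samples = g(torch.randn(8, 100))
```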

Wasserstein GAN (WGAN)

WGANs improve training stability through:

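# Key differences from standard GAN:
# 1. Use Wasserstein loss (remove the sigmoid so the discriminator, called a "critic", outputs unbounded scores)
# 2. Clip critic weights to the range [-0.01, 0.01] after each update
# 3. Train the critic more often than the generator (typically 5 critic steps per generator step)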

# WGAN loss functions
def discriminator_loss(real_output, fake_output):
    return torch.mean(fake_output) - torch.mean(real_output)

def generator_loss(fake_output):
    return -torch.mean(fake_output)
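
The three differences listed above can be combined into a single training iteration. This is a minimal sketch on toy 2-D data: the `critic` and `gen` networks, the function `wgan_iteration`, and the reuse of one real batch for all critic steps are simplifications assumed for illustration. RMSprop with a small learning rate follows the WGAN paper's setup.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Toy critic and generator on 2-D data (hypothetical stand-ins)
critic = nn.Sequential(nn.Linear(2, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
gen = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))

opt_c = optim.RMSprop(critic.parameters(), lr=5e-5)
opt_g = optim.RMSprop(gen.parameters(), lr=5e-5)
n_critic, clip_value = 5, 0.01

def wgan_iteration(real_batch):
    # 1) Several critic updates per generator update
    #    (a real pipeline would draw a fresh real batch each time)
    for _ in range(n_critic):
        fake = gen(torch.randn(real_batch.size(0), 4)).detach()
        loss_c = torch.mean(critic(fake)) - torch.mean(critic(real_batch))
        opt_c.zero_grad()
        loss_c.backward()
        opt_c.step()
        # 2) Clip every critic parameter into [-clip_value, clip_value]
        for p in critic.parameters():
            p.data.clamp_(-clip_value, clip_value)
    # 3) One generator update using the critic's scores on fresh fakes
    loss_g = -torch.mean(critic(gen(torch.randn(real_batch.size(0), 4))))
    opt_g.zero_grad()
    loss_g.backward()
    opt_g.step()
    return loss_c.item(), loss_g.item()

loss_c, loss_g = wgan_iteration(torch.randn(32, 2) + 3.0)
```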
    

GAN Training: Challenges and Solutions

| Challenge | Symptoms | Solutions |
|---|---|---|
| Mode Collapse | Generator produces limited variety of samples | Use minibatch discrimination, unrolled GANs, or WGAN |
| Vanishing Gradients | Generator stops improving | Use Wasserstein loss, gradient penalty, or different architectures |
| Oscillations | Losses fluctuate without convergence | Adjust learning rates, use TTUR (Two Time-scale Update Rule) |
| Overfitting | Discriminator reaches 100% accuracy | Add dropout, reduce discriminator capacity, or add noise |
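
Of the remedies listed, the gradient penalty is the easiest to get subtly wrong, so here is a minimal sketch of it in the WGAN-GP style. The function name `gradient_penalty` and the linear critics used for the sanity check are assumptions for illustration. In training, the penalty is usually added to the critic loss with a weight (10 is a commonly used value).

```python
import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake):
    """WGAN-GP style penalty: mean of (||grad_x critic(x_hat)||_2 - 1)^2."""
    # One random interpolation weight per sample, broadcast over other dims
    eps = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1 - eps) * fake).detach().requires_grad_(True)
    scores = critic(x_hat)
    grads, = torch.autograd.grad(
        outputs=scores, inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty can be backpropagated
    )
    grad_norm = grads.flatten(start_dim=1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()

torch.manual_seed(0)
real, fake = torch.randn(8, 3), torch.randn(8, 3)

# Linear critic whose input gradient has norm 1 everywhere: no penalty
unit_critic = nn.Linear(3, 1, bias=False)
with torch.no_grad():
    unit_critic.weight.copy_(torch.tensor([[1.0, 0.0, 0.0]]))
gp_unit = gradient_penalty(unit_critic, real, fake)

# Linear critic with gradient norm 2: penalty is (2 - 1)^2
steep_critic = nn.Linear(3, 1, bias=False)
with torch.no_grad():
    steep_critic.weight.copy_(torch.tensor([[2.0, 0.0, 0.0]]))
gp_steep = gradient_penalty(steep_critic, real, fake)
```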

Monitoring GAN Training

Effective metrics to track during training:

  1. Inception Score (IS): Measures both quality and diversity of generated images
  2. Frechet Inception Distance (FID): Compares statistics of real and generated images
  3. Visual Inspection: Regularly generate sample images to check quality
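  4. Discriminator Metrics: Track D(x) and D(G(z)); at the theoretical equilibrium both approach 0.5, while values pinned near 1 and 0 suggest the discriminator is overpowering the generator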
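
The core of FID is the Fréchet distance between two Gaussians fitted to feature vectors. The sketch below implements only that distance and runs it on random stand-in features. A real FID computation extracts the features from a pretrained Inception network, which is omitted here, and `frechet_distance` is a hypothetical helper name.

```python
import torch

def frechet_distance(feats_a, feats_b):
    """Frechet distance between Gaussians fitted to two (N, D) feature sets."""
    a, b = feats_a.double(), feats_b.double()
    mu_a, mu_b = a.mean(dim=0), b.mean(dim=0)
    ca, cb = a - mu_a, b - mu_b
    cov_a = ca.T @ ca / (a.size(0) - 1)
    cov_b = cb.T @ cb / (b.size(0) - 1)
    # Tr(sqrt(cov_a cov_b)) is the sum of square roots of the product's eigenvalues
    eigvals = torch.linalg.eigvals(cov_a @ cov_b).real.clamp(min=0)
    trace_sqrt = eigvals.sqrt().sum()
    return ((mu_a - mu_b) ** 2).sum() + cov_a.trace() + cov_b.trace() - 2 * trace_sqrt

torch.manual_seed(0)
feats = torch.randn(500, 4)                        # stand-in features, not Inception activations
shift = torch.tensor([1.0, 2.0, 0.0, 0.0])

same = frechet_distance(feats, feats)              # identical sets: distance near 0
shifted = frechet_distance(feats, feats + shift)   # same covariance, mean moved by `shift`
```

With identical covariances the distance reduces to the squared distance between the means, which here is 1² + 2² = 5.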

Practical Applications of GANs

Image Synthesis

  • Photorealistic image generation
  • Art creation
  • Data augmentation

Image-to-Image Translation

  • Style transfer
  • Colorization
  • Super-resolution

Domain Adaptation

  • Sim-to-real transfer
  • Medical imaging
  • Anime generation

Extending Your GAN Implementation

Conditional GANs

Add class information to both generator and discriminator to control output categories:

class ConditionalGenerator(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, num_classes)
        self.main = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 256),
            nn.LeakyReLU(0.2),
            # ... rest of generator layers
        )
    
    def forward(self, noise, labels):
        label_embedding = self.label_embedding(labels)
        x = torch.cat((noise, label_embedding), dim=1)
        return self.main(x)
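
The discriminator needs the same label information. A matching sketch might look like the following. The class `ConditionalDiscriminator`, its MLP layout, and the layer sizes are assumptions for illustration, not part of the implementation above.

```python
import torch
import torch.nn as nn

class ConditionalDiscriminator(nn.Module):
    """MLP discriminator conditioned on a class label (illustrative sizes)."""
    def __init__(self, image_dim=28 * 28, num_classes=10):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, num_classes)
        self.main = nn.Sequential(
            nn.Linear(image_dim + num_classes, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, images, labels):
        flat = images.view(images.size(0), -1)  # flatten each image
        x = torch.cat((flat, self.label_embedding(labels)), dim=1)
        return self.main(x)

d = ConditionalDiscriminator()
images = torch.randn(5, 1, 28, 28)
labels = torch.randint(0, 10, (5,))
probs = d(images, labels)  # probability that each (image, label) pair is real
```

During training, real images are paired with their true labels and generated images with the labels that were fed to the generator.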
    

Progressive Growing of GANs

Start with low-resolution images and progressively increase resolution during training:

  1. Begin training with 4x4 images
  2. Gradually add layers to increase resolution
  3. Smoothly fade in new layers
  4. Tends to give more stable training and higher-quality high-resolution outputs than training the full-resolution network from the start
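
The fade-in step can be sketched as a linear blend between the upsampled output of the existing stage and the output of the newly added stage, with a weight `alpha` that is raised from 0 to 1 over training. The helper `fade_in` and the tensor shapes are hypothetical, and nearest-neighbor upsampling is one common choice rather than a requirement.

```python
import torch
import torch.nn.functional as F

def fade_in(alpha, low_res_image, high_res_image):
    """Blend the upsampled old output with the new layer's output."""
    upsampled = F.interpolate(low_res_image, scale_factor=2, mode="nearest")
    return alpha * high_res_image + (1 - alpha) * upsampled

torch.manual_seed(0)
low = torch.randn(2, 3, 4, 4)    # output of the existing 4x4 stage
high = torch.randn(2, 3, 8, 8)   # output of the newly added 8x8 stage

start = fade_in(0.0, low, high)  # new layer contributes nothing yet
middle = fade_in(0.5, low, high) # equal mix of both paths
end = fade_in(1.0, low, high)    # new layer fully active
```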

Resources and Next Steps

Datasets to Try

  • CIFAR-10
  • CelebA
  • LSUN Bedrooms
  • FFHQ (Flickr Faces HQ)

Advanced Techniques

  • Self-Attention GANs
  • StyleGAN/StyleGAN2
  • BigGAN
  • CycleGAN for unpaired translation
Final Tip: When working with GANs, be patient and experiment with different architectures and hyperparameters. The field evolves rapidly, so stay updated with the latest research papers and techniques.
