Spatial Transformer Networks (STN) on MNIST With PyTorch

Introduction

This document explains a complete implementation of Spatial Transformer Networks (STN) applied to the MNIST dataset using PyTorch. An STN is a learnable, differentiable module that predicts a spatial transformation (such as rotation, scaling, or translation) for each input and resamples the input accordingly, making the rest of the network more invariant to geometric variations.

STN Architecture

STN architecture overview (Source: PyTorch tutorials)

Setup and Initialization

1. Importing Required Libraries

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import urllib.request

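Key libraries:

  • torch: Core PyTorch library
  • torch.nn / torch.nn.functional: Neural network layers and functional operations
  • torch.optim: Optimizers (SGD is used here)
  • torchvision: Computer vision datasets, transforms, and image-grid utilities
  • matplotlib and numpy: Visualization and array handling
  • urllib.request: Configures the HTTP opener used when downloading MNIST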

2. Configuration

plt.ion()  # interactive mode
torch.manual_seed(1)  # for reproducibility

# Send a browser-like User-Agent; some download servers reject urllib's default one (HTTP 403)
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Key points:

  • plt.ion() enables interactive mode for matplotlib
  • Random seed is set for reproducibility
  • A global urllib opener sends a browser-like User-Agent header, avoiding HTTP 403 errors from servers that block urllib's default agent
  • Device is set to use GPU if available

Data Preparation

1. Data Transformation

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

Transformations:

  • ToTensor(): Converts PIL Image to PyTorch tensor
  • Normalize(): Standardizes pixels using the MNIST training-set mean (0.1307) and standard deviation (0.3081)

2. Data Loaders

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True, transform=transform),
    batch_size=64, shuffle=True, num_workers=0
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False, transform=transform),
    batch_size=64, shuffle=True, num_workers=0
)

Parameters:

  • batch_size=64: Number of samples per batch
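  • shuffle=True: Reshuffles the data every epoch (on the test loader this does not affect accuracy; it only changes which batch visualize_stn() displays)
  • num_workers=0: Loads data in the main process, which avoids multiprocessing problems in some environments (increase it for faster loading)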
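The batching behavior can be checked without downloading anything. The sketch below assumes a synthetic TensorDataset of 200 random 28×28 images in place of MNIST, and also computes how many batches of 64 the real 60,000-image training split and 10,000-image test split produce.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for MNIST (random data, not real digits): 200 images, labels 0-9
images = torch.randn(200, 1, 28, 28)
labels = torch.randint(0, 10, (200,))
loader = DataLoader(TensorDataset(images, labels), batch_size=64, shuffle=True, num_workers=0)

data, target = next(iter(loader))
print(data.shape, target.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])
print(len(loader))               # 4 batches: 64 + 64 + 64 + 8

# Same arithmetic for the real MNIST splits
print(math.ceil(60000 / 64), math.ceil(10000 / 64))  # 938 157
```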

Model Architecture

1. STN Network Components

Localization Network

self.localization = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=7),
    nn.MaxPool2d(2, stride=2),
    nn.ReLU(True),
    nn.Conv2d(8, 10, kernel_size=5),
    nn.MaxPool2d(2, stride=2),
    nn.ReLU(True)
)

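This network extracts the features from which the transformation parameters are regressed:

  • Two convolutional layers, each followed by 2×2 max pooling and ReLU
  • For a 28×28 input, the spatial size shrinks 28 → 22 (7×7 conv) → 11 (pool) → 7 (5×5 conv) → 3 (pool), so it outputs 10 × 3 × 3 = 90 features per image, which is why the regressor below takes 10 * 3 * 3 inputs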
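A quick shape check confirms the 10 × 3 × 3 figure. This sketch rebuilds the same localization layers as a standalone module and passes a dummy batch of zeros through it.

```python
import torch
import torch.nn as nn

# Same layers as self.localization, built outside a class for the check
localization = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=7),   # 28 -> 22
    nn.MaxPool2d(2, stride=2),        # 22 -> 11
    nn.ReLU(True),
    nn.Conv2d(8, 10, kernel_size=5),  # 11 -> 7
    nn.MaxPool2d(2, stride=2),        # 7 -> 3
    nn.ReLU(True),
)

x = torch.zeros(4, 1, 28, 28)  # dummy batch of four 28x28 grayscale images
features = localization(x)
print(features.shape)  # torch.Size([4, 10, 3, 3])
```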

Transformation Regressor

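self.fc_loc = nn.Sequential(
    nn.Linear(10 * 3 * 3, 32),
    nn.ReLU(True),
    nn.Linear(32, 3 * 2)
)

# Initialize the last layer to output the identity transform [[1, 0, 0], [0, 1, 0]],
# so the STN starts by passing images through unchanged
self.fc_loc[2].weight.data.zero_()
self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))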

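Predicts the 6 parameters of the 2×3 affine transformation matrix θ:

  • Initialized to the identity transformation (last layer's weights zeroed, bias set to [1, 0, 0, 0, 1, 0]), so training starts from undistorted images
  • Learns any affine warp: rotation, translation, scaling, and shear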
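The identity initialization can be illustrated in isolation. This sketch assumes the fc_loc layers shown above, zeroes the last layer's weights, sets its bias to the flattened identity matrix, and checks that the resulting θ leaves an image unchanged when passed through affine_grid and grid_sample. The random features and image are placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

fc_loc = nn.Sequential(
    nn.Linear(10 * 3 * 3, 32),
    nn.ReLU(True),
    nn.Linear(32, 3 * 2),
)
# Zero the last layer's weights and set its bias to the flattened identity matrix
with torch.no_grad():
    fc_loc[2].weight.zero_()
    fc_loc[2].bias.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

features = torch.randn(4, 10 * 3 * 3)  # arbitrary localization features
theta = fc_loc(features).view(-1, 2, 3)
print(theta[0].tolist())  # [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]] for every sample

# With the identity theta, resampling returns the input image (up to float error)
x = torch.rand(4, 1, 28, 28)
grid = F.affine_grid(theta, x.size(), align_corners=False)
y = F.grid_sample(x, grid, align_corners=False)
print(torch.allclose(x, y, atol=1e-5))  # True
```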

Main Classification Network

self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)

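Standard CNN that classifies the digit after the spatial transformation. The 320 inputs of fc1 come from 20 channels × 4 × 4: with 2×2 max pooling after each convolution, a 28×28 image shrinks to 24 (conv1), 12 (pool), 8 (conv2), and finally 4 (pool).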
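The following sketch checks the 320 figure. It assumes a 2×2 max pool after each convolution in forward(); the pooling calls are an assumption, since only the layer definitions appear in this section.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv1 = nn.Conv2d(1, 10, kernel_size=5)
conv2 = nn.Conv2d(10, 20, kernel_size=5)

x = torch.zeros(2, 1, 28, 28)      # dummy batch of two transformed images
h = F.max_pool2d(conv1(x), 2)      # assumed pooling: 28 -> 24 -> 12
print(h.shape)                     # torch.Size([2, 10, 12, 12])
h = F.max_pool2d(conv2(h), 2)      # assumed pooling: 12 -> 8 -> 4
print(h.shape)                     # torch.Size([2, 20, 4, 4])
print(h.flatten(1).shape)          # torch.Size([2, 320]), matching fc1's input size
```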

2. STN Transformation Process

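def stn(self, x):
    # Predict one 2x3 affine matrix per image from the localization features
    xs = self.localization(x)
    xs = xs.view(-1, 10 * 3 * 3)
    theta = self.fc_loc(xs)
    theta = theta.view(-1, 2, 3)

    # Build a sampling grid from theta and resample the input (bilinear by default)
    grid = F.affine_grid(theta, x.size(), align_corners=False)
    x = F.grid_sample(x, grid, align_corners=False)
    return x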

Steps:

  1. Pass input through localization network
  2. Flatten features and predict transformation parameters
  3. Reshape parameters into 2×3 affine matrix
  4. Generate sampling grid with affine_grid
  5. Apply transformation with grid_sample
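To see what steps 4 and 5 do with a concrete θ, the sketch below uses a hand-picked (not learned) matrix instead of the localization network's output. With align_corners=False, θ = [[0, -1, 0], [1, 0, 0]] samples each output location (x, y) from input location (-y, x), which for a square image reproduces a 90-degree rotation matching torch.rot90.

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 28, 28)  # one random single-channel 28x28 "image"

# Hand-picked theta: output location (x, y) reads from input location (-y, x)
theta = torch.tensor([[[0.0, -1.0, 0.0],
                       [1.0,  0.0, 0.0]]])

grid = F.affine_grid(theta, x.size(), align_corners=False)  # normalized (x, y) per output pixel
rotated = F.grid_sample(x, grid, align_corners=False)

# This particular theta reproduces a 90-degree rotation
expected = torch.rot90(x, 1, dims=(2, 3))
print(grid.shape)                                    # torch.Size([1, 28, 28, 2])
print(torch.allclose(rotated, expected, atol=1e-5))  # True
```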

Training Process

1. Training Loop

def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 500 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

Key steps:

  • Set model to training mode
  • Transfer data to device (GPU/CPU)
  • Zero gradients, forward pass, loss calculation
  • Backpropagation and optimizer step
  • Periodic progress logging
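F.nll_loss expects log-probabilities, so the model's forward() must end with log_softmax (forward() is not shown in this post, so this is an assumption). The sketch below, using random logits, shows that log_softmax followed by nll_loss gives the same value as F.cross_entropy on raw scores, and that it equals the average negative log-probability of the correct class.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 10)              # raw scores for 5 samples, 10 digit classes
target = torch.tensor([3, 0, 9, 1, 7])

log_probs = F.log_softmax(logits, dim=1)   # what forward() is assumed to return
loss_nll = F.nll_loss(log_probs, target)   # as in train()
loss_ce = F.cross_entropy(logits, target)  # equivalent single call on raw logits
print(torch.allclose(loss_nll, loss_ce))   # True

# nll_loss averages -log p(correct class) over the batch
manual = -log_probs[torch.arange(5), target].mean()
print(torch.allclose(loss_nll, manual))    # True
```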

2. Testing Loop

def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({100. * correct / len(test_loader.dataset):.0f}%)\n')

Key points:

  • Set model to evaluation mode
  • No gradient calculation (torch.no_grad())
  • Calculate loss and accuracy metrics
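The accuracy bookkeeping in test() can be traced by hand on a tiny example. The made-up log-probabilities below cover 4 samples and 3 classes; output.max(1, keepdim=True)[1] returns the index of the highest score in each row, which is the same as argmax.

```python
import torch

# Log-probabilities for 4 samples over 3 classes (made-up numbers)
output = torch.log(torch.tensor([[0.7, 0.2, 0.1],
                                 [0.1, 0.8, 0.1],
                                 [0.3, 0.3, 0.4],
                                 [0.6, 0.3, 0.1]]))
target = torch.tensor([0, 1, 0, 0])

pred = output.max(1, keepdim=True)[1]  # predicted class per row, shape (4, 1)
correct = pred.eq(target.view_as(pred)).sum().item()
print(pred.view(-1).tolist())          # [0, 1, 2, 0]
print(correct, f"{100. * correct / len(target):.0f}%")  # 3 75%

# argmax gives the same predictions
print(torch.equal(pred, output.argmax(dim=1, keepdim=True)))  # True
```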

Visualization

1. Image Conversion

def convert_image_np(inp):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.1307])
    std = np.array([0.3081])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp.squeeze()

Converts a (C, H, W) image tensor to an (H, W, C) NumPy array, reverses the normalization (x · std + mean), and clips values to [0, 1] for display.
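As a sanity check, the sketch below normalizes a random [0, 1] tensor the way Normalize() does and passes it through a copy of convert_image_np (using scalar mean/std instead of one-element arrays, which broadcasts the same way). The restored array matches the original pixels.

```python
import numpy as np
import torch

MEAN, STD = 0.1307, 0.3081

def convert_image_np(inp):
    """Same steps as the helper above: CHW tensor -> HWC array, undo Normalize, clip."""
    inp = inp.numpy().transpose((1, 2, 0))
    inp = STD * inp + MEAN
    inp = np.clip(inp, 0, 1)
    return inp.squeeze()

original = torch.rand(1, 28, 28)       # pixel values in [0, 1], like ToTensor() output
normalized = (original - MEAN) / STD   # what Normalize((0.1307,), (0.3081,)) produces
restored = convert_image_np(normalized)

print(restored.shape)                                                  # (28, 28)
print(np.allclose(restored, original.squeeze(0).numpy(), atol=1e-6))  # True
```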

2. STN Visualization

def visualize_stn():
    with torch.no_grad():
        data = next(iter(test_loader))[0].to(device)
        input_tensor = data.cpu()
        transformed_input_tensor = model.stn(data).cpu()

        in_grid = convert_image_np(torchvision.utils.make_grid(input_tensor))
        out_grid = convert_image_np(torchvision.utils.make_grid(transformed_input_tensor))

        fig, axarr = plt.subplots(1, 2, figsize=(8, 4))
        axarr[0].imshow(in_grid, cmap='gray')
        axarr[0].set_title('Original Images')

        axarr[1].imshow(out_grid, cmap='gray')
        axarr[1].set_title('Transformed Images')

        for ax in axarr:
            ax.axis('off')

Shows original and transformed images side by side to demonstrate the STN's effect.

Execution Flow

1. Model Instantiation

model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
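The fragments above live inside a Net class that the post instantiates but never shows in full. Here is one way to assemble them, assuming a forward() that applies the STN first and then the classification CNN with 2×2 max pooling, dropout, and a final log_softmax. The forward() body is an assumption chosen to match the 320-unit fc1 and the use of F.nll_loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Classification network (layers as listed above)
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Localization network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
        )

        # Regressor for the 2x3 affine matrix, initialized to the identity transform
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2),
        )
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def stn(self, x):
        xs = self.localization(x).view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs).view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

    def forward(self, x):
        # Assumed forward pass: warp the input first, then classify it
        x = self.stn(x)
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  # log-probabilities, as F.nll_loss expects


model = Net()
model.eval()
with torch.no_grad():
    out = model(torch.randn(3, 1, 28, 28))
print(out.shape)  # torch.Size([3, 10])
print(torch.allclose(out.exp().sum(dim=1), torch.ones(3), atol=1e-6))  # True
```

With a class like this defined, model = Net().to(device) and the SGD optimizer line work as written.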

2. Training and Evaluation

for epoch in range(1, 21):
    train(epoch)
    test()

Runs 20 epochs of training with testing after each epoch.

3. Visualization

visualize_stn()
plt.ioff()
plt.show()

Expected Output

During execution, you should see:

  1. Training progress updates every 500 batches (batches 0 and 500 of the 938 batches in each epoch)
  2. Test accuracy after each epoch (typically reaching ~98%)
  3. Final visualization showing original and transformed digits

Standard MNIST digits for reference (Source: Wikimedia Commons)

Troubleshooting

Common Issues

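  • Download errors (e.g., HTTP 403): The User-Agent opener configured in the setup step handles servers that reject urllib's default agent
  • CUDA out of memory: Reduce the batch size
  • Visualization not showing: In a headless environment (no display), use a non-interactive backend such as Agg and save the figure with fig.savefig() instead of relying on plt.show()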
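For headless machines, one option is to render with matplotlib's non-interactive Agg backend and write the figure to a file. The sketch below reuses the plotting code from visualize_stn() with placeholder random arrays instead of real image grids; the output file name is hypothetical.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: renders to files, needs no display
import matplotlib.pyplot as plt
import numpy as np

# Placeholder arrays standing in for the original/transformed grids from visualize_stn()
in_grid = np.random.rand(32, 32)
out_grid = np.random.rand(32, 32)

fig, axarr = plt.subplots(1, 2, figsize=(8, 4))
axarr[0].imshow(in_grid, cmap="gray")
axarr[0].set_title("Original Images")
axarr[1].imshow(out_grid, cmap="gray")
axarr[1].set_title("Transformed Images")
for ax in axarr:
    ax.axis("off")

# Hypothetical output path in the system temp directory
out_path = os.path.join(tempfile.gettempdir(), "stn_visualization.png")
fig.savefig(out_path, bbox_inches="tight")
plt.close(fig)
print(os.path.exists(out_path))  # True
```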

Conclusion

This implementation demonstrates how Spatial Transformer Networks can learn to automatically apply beneficial transformations to input data, improving model performance on tasks requiring spatial invariance like digit recognition.

The key advantages of STN are:

  • End-to-end differentiable
  • Can be inserted into any CNN architecture
  • Learns transformations without explicit supervision
  • Improves model robustness to geometric variations
