Spatial Transformer Networks (STN) on MNIST With PyTorch
Introduction
This document explains a complete implementation of Spatial Transformer Networks (STN) applied to the MNIST dataset using PyTorch. The STN is a learnable module that automatically applies spatial transformations to input data to enhance geometric invariance in neural networks.
STN architecture overview (Source: PyTorch tutorials)
Setup and Initialization
1. Importing Required Libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import urllib.request
Key libraries:
- torch: Core PyTorch library
- torch.nn: Neural network modules
- torchvision: Computer vision datasets and transforms
- matplotlib: Visualization
2. Configuration
plt.ion() # interactive mode
torch.manual_seed(1) # for reproducibility
# Send a browser-like User-Agent so dataset mirrors don't reject urllib's default one
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Key points:
- plt.ion() enables interactive mode for matplotlib
- The random seed is set for reproducibility
- The URL opener sends a browser-like User-Agent to avoid download rejections
- The device is set to the GPU if one is available, otherwise the CPU
Data Preparation
1. Data Transformation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
Transformations:
- ToTensor(): Converts a PIL image to a PyTorch tensor with pixel values scaled to [0, 1]
- Normalize(): Normalizes with the MNIST mean (0.1307) and standard deviation (0.3081)
2. Data Loaders
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True, transform=transform),
    batch_size=64, shuffle=True, num_workers=0
)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False, transform=transform),
    batch_size=64, shuffle=True, num_workers=0
)
Parameters:
- batch_size=64: Number of samples per batch
- shuffle=True: Reshuffles the data every epoch
- num_workers=0: Loads data in the main process, which is safer in some environments; increase it for faster loading
Model Architecture
1. STN Network Components
Localization Network
self.localization = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=7),
    nn.MaxPool2d(2, stride=2),
    nn.ReLU(True),
    nn.Conv2d(8, 10, kernel_size=5),
    nn.MaxPool2d(2, stride=2),
    nn.ReLU(True)
)
This network extracts the features from which the transformation parameters are predicted:
- Two convolutional layers, each followed by max pooling and ReLU
- For a 28×28 input it outputs a 10×3×3 feature map (90 values), which the regressor consumes
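The 10 * 3 * 3 used when flattening follows from the layer sizes on a 28×28 MNIST input: the 7×7 convolution gives 22×22, pooling gives 11×11, the 5×5 convolution gives 7×7, and the second pooling gives 3×3 (pooling rounds down). The sketch below rebuilds the same stack outside the model to check this; the standalone variable name localization and the random batch are illustrative assumptions.

```python
import torch
import torch.nn as nn

# Same layer stack as self.localization, built standalone for inspection
localization = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=7),   # 28x28 -> 22x22
    nn.MaxPool2d(2, stride=2),        # 22x22 -> 11x11
    nn.ReLU(True),
    nn.Conv2d(8, 10, kernel_size=5),  # 11x11 -> 7x7
    nn.MaxPool2d(2, stride=2),        # 7x7 -> 3x3 (floor division)
    nn.ReLU(True),
)

x = torch.randn(64, 1, 28, 28)  # dummy batch shaped like MNIST
features = localization(x)
flat = features.view(-1, 10 * 3 * 3)

print(features.shape)  # torch.Size([64, 10, 3, 3])
print(flat.shape)      # torch.Size([64, 90])
```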
Transformation Regressor
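self.fc_loc = nn.Sequential(
    nn.Linear(10 * 3 * 3, 32),
    nn.ReLU(True),
    nn.Linear(32, 3 * 2)
)
# Initialize the last layer so it outputs the identity transform
self.fc_loc[2].weight.data.zero_()
self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))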
Predicts the 6 parameters of the 2×3 affine transformation matrix:
- Initialized to the identity transformation (zero weights, bias [1, 0, 0, 0, 1, 0]), so training starts with the input unchanged
- Learns to represent rotation, translation, scaling, and shear
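How the six numbers are used: F.affine_grid treats theta as a map from each output pixel's normalized coordinates (x_t, y_t) in [-1, 1] to the input location it samples, (x_s, y_s) = theta · (x_t, y_t, 1), and F.grid_sample then interpolates the input there. The sketch below, using a random dummy image and a hypothetical helper warp(), shows that the identity matrix reproduces the input while a rotation matrix in the left 2×2 block resamples it.

```python
import math
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 28, 28)  # dummy single-channel image


def warp(img, theta):
    """Hypothetical helper: warp img (N, C, H, W) with 2x3 affine matrices theta (N, 2, 3)."""
    grid = F.affine_grid(theta, img.size(), align_corners=False)
    return F.grid_sample(img, grid, align_corners=False)


# Identity: the values the regressor's bias is initialized to
identity = torch.tensor([[[1.0, 0.0, 0.0],
                          [0.0, 1.0, 0.0]]])
same = warp(x, identity)

# 30-degree rotation: cos/sin terms in the left 2x2 block, zero translation
a = math.radians(30)
rotation = torch.tensor([[[math.cos(a), -math.sin(a), 0.0],
                          [math.sin(a), math.cos(a), 0.0]]])
rotated = warp(x, rotation)

print(torch.allclose(same, x, atol=1e-5))  # True
print(rotated.shape)                        # torch.Size([1, 1, 28, 28])
```

Because theta describes where each output pixel samples from, the visible content moves by the inverse of theta.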
Main Classification Network
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
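A standard CNN classifies the digit after the spatial transformation. fc1 takes 320 inputs because two 5×5 convolutions, each followed by 2×2 pooling, reduce the 28×28 input to 4×4 with 20 channels (28 → 24 → 12 → 8 → 4, and 20 × 4 × 4 = 320).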
2. STN Transformation Process
def stn(self, x):
    xs = self.localization(x)
    xs = xs.view(-1, 10 * 3 * 3)
    theta = self.fc_loc(xs)
    theta = theta.view(-1, 2, 3)
    grid = F.affine_grid(theta, x.size(), align_corners=False)
    x = F.grid_sample(x, grid, align_corners=False)
    return x
Steps:
- Pass input through localization network
- Flatten features and predict transformation parameters
- Reshape parameters into 2×3 affine matrix
- Generate sampling grid with affine_grid
- Apply transformation with grid_sample
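The Execution Flow section calls Net(), but the class that ties these components together, including its forward() method, is never shown. The sketch below assembles the layers as listed above. The forward() body (ReLU and 2×2 max pooling after each convolution, dropout before fc2, log_softmax at the end) is an assumption modeled on the official PyTorch STN tutorial. Ending in log_softmax is what makes the F.nll_loss calls in the training and test loops compute a proper cross-entropy.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Main classification network
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Localization network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
        )

        # Regressor for the 2x3 affine matrix, initialized to the identity
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2),
        )
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def stn(self, x):
        xs = self.localization(x).view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs).view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

    def forward(self, x):
        # ASSUMPTION: layout modeled on the PyTorch STN tutorial
        x = self.stn(x)                                              # warp first
        x = F.relu(F.max_pool2d(self.conv1(x), 2))                   # -> 10x12x12
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))  # -> 20x4x4
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  # log-probabilities for F.nll_loss


model = Net()
out = model(torch.randn(4, 1, 28, 28))  # dummy batch
print(out.shape)  # torch.Size([4, 10])
```

At initialization, model.stn(x) returns x essentially unchanged, since the regressor starts at the identity transform.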
Training Process
1. Training Loop
def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 500 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
Key steps:
- Set model to training mode
- Transfer data to device (GPU/CPU)
- Zero gradients, forward pass, loss calculation
- Backpropagation and optimizer step
- Periodic progress logging
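A note on the loss: F.nll_loss expects log-probabilities, so this loop computes a proper cross-entropy only if the model's forward pass ends with F.log_softmax (an assumption here, since the forward pass is not listed above). The sketch below, using random logits and labels as stand-ins, checks that this pairing matches F.cross_entropy applied to raw scores.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 10)           # raw scores: 8 samples, 10 digit classes
target = torch.randint(0, 10, (8,))  # random stand-in labels

# What the training loop computes for a model ending in log_softmax
loss_nll = F.nll_loss(F.log_softmax(logits, dim=1), target)
# The fused equivalent that works directly on raw logits
loss_ce = F.cross_entropy(logits, target)

print(torch.allclose(loss_nll, loss_ce))  # True
```

Either form works; what matters is not applying log_softmax twice or feeding raw logits to F.nll_loss.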
2. Testing Loop
def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum the per-sample losses; averaged over the whole dataset below
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # The index of the highest log-probability is the predicted class
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({100. * correct / len(test_loader.dataset):.0f}%)\n')
Key points:
- Set model to evaluation mode
- No gradient calculation (torch.no_grad())
- Calculate loss and accuracy metrics
Visualization
1. Image Conversion
def convert_image_np(inp):
    """Convert a CHW tensor to an HWC NumPy image and undo the MNIST normalization."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.1307])
    std = np.array([0.3081])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp.squeeze()
Converts a CHW tensor to an HWC NumPy array, reverses the normalization, and clips values to [0, 1] for display.
2. STN Visualization
def visualize_stn():
    with torch.no_grad():
        # Take one batch of test images
        data = next(iter(test_loader))[0].to(device)
        input_tensor = data.cpu()
        transformed_input_tensor = model.stn(data).cpu()

        in_grid = convert_image_np(torchvision.utils.make_grid(input_tensor))
        out_grid = convert_image_np(torchvision.utils.make_grid(transformed_input_tensor))

        # Plot the original and transformed grids side by side
        fig, axarr = plt.subplots(1, 2, figsize=(8, 4))
        axarr[0].imshow(in_grid, cmap='gray')
        axarr[0].set_title('Original Images')
        axarr[1].imshow(out_grid, cmap='gray')
        axarr[1].set_title('Transformed Images')
        for ax in axarr:
            ax.axis('off')
Shows original and transformed images side by side to demonstrate the STN's effect.
Execution Flow
1. Model Instantiation
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
2. Training and Evaluation
for epoch in range(1, 21):
    train(epoch)
    test()
Runs 20 epochs of training with testing after each epoch.
3. Visualization
visualize_stn()
plt.ioff()
plt.show()
Expected Output
During execution, you should see:
- Training progress updates every 500 batches
- Test accuracy after each epoch (typically reaching ~98%)
- Final visualization showing original and transformed digits
Troubleshooting
Common Issues
- Download errors: The User-Agent opener fix handles servers that reject urllib's default user agent
- CUDA out of memory: Reduce batch size if needed
- Visualization not showing: Ensure matplotlib backend is properly configured
Conclusion
This implementation demonstrates how Spatial Transformer Networks can learn to automatically apply beneficial transformations to input data, improving model performance on tasks requiring spatial invariance like digit recognition.
The key advantages of STN are:
- End-to-end differentiable
- Can be inserted into any CNN architecture
- Learns transformations without explicit supervision
- Improves model robustness to geometric variations
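The first advantage, end-to-end differentiability, can be checked directly: F.affine_grid and F.grid_sample both propagate gradients back to theta, which is what lets the localization network learn from the classification loss alone. In this sketch the random images and the toy objective warped.mean() are stand-ins for real data and the real loss.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
img = torch.randn(2, 1, 28, 28)  # dummy batch

# Two identity matrices, tracked by autograd
theta = torch.tensor([[1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0]]).repeat(2, 1, 1).requires_grad_()

grid = F.affine_grid(theta, img.size(), align_corners=False)
warped = F.grid_sample(img, grid, align_corners=False)

# Toy downstream objective; in the STN this would be the classification loss
loss = warped.mean()
loss.backward()

print(theta.grad.shape)  # torch.Size([2, 2, 3])
```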