Artificial Neural Network (ANN) with PyTorch
Advanced MNIST Digit Classification with PyTorch
This comprehensive tutorial covers multiple approaches to handwritten digit recognition using the MNIST dataset with PyTorch, from basic neural networks to convolutional architectures and advanced techniques.
📌 Key Features
- Basic ANN implementation
- CNN architectures
- Data augmentation
- Model interpretation
- Hyperparameter tuning
- Performance optimization
🎯 Performance
- Basic ANN: ~97% accuracy
- Simple CNN: ~98.5% accuracy
- Advanced CNN: ~99.2% accuracy
- Training time: 1-10 minutes
- Hardware: CPU/GPU compatible
1. Environment Setup
Prerequisites
- Python 3.7+
- PyTorch 1.8+
- Matplotlib for visualization
Installation
# Base installation
pip install torch torchvision matplotlib
# With GPU support (CUDA)
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
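The code snippets in the rest of this tutorial assume the following imports (a minimal set; adjust to your own project layout):

import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms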
2. Enhanced Data Loading
Improved data loading with augmentation and validation split:
def get_data_loaders(batch_size=64, val_ratio=0.1):
    # Define transformations
    train_transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.RandomAffine(0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Load datasets
    full_train = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=train_transform)

    # Split into train and validation
    # Note: random_split shares the underlying dataset, so validation samples
    # also receive the training augmentations defined above.
    val_size = int(val_ratio * len(full_train))
    train_size = len(full_train) - val_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_train, [train_size, val_size])

    test_dataset = torchvision.datasets.MNIST(
        root='./data', train=False, download=True, transform=test_transform)

    # Create data loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    return train_loader, val_loader, test_loader
Note: We use MNIST-specific normalization statistics (mean=0.1307, std=0.3081), computed over the training images, so inputs are standardized to roughly zero mean and unit variance.
The validation set helps monitor for overfitting during training.
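A quick sanity check of the loaders (a sketch assuming the function above has run on a machine where the dataset can be downloaded):

# Fetch one batch and confirm the expected MNIST shapes
train_loader, val_loader, test_loader = get_data_loaders(batch_size=64)
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])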
3. Model Architectures
Option 1: Basic Neural Network (ANN)
class MNIST_ANN(nn.Module):
    def __init__(self, input_size=784, hidden_size=128, num_classes=10):
        super(MNIST_ANN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x
Option 2: Convolutional Neural Network (CNN)
class MNIST_CNN(nn.Module):
    def __init__(self):
        super(MNIST_CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout(0.5)   # applied to flattened features, so plain Dropout
        self.fc1 = nn.Linear(9216, 128)   # 64 channels * 12 * 12 spatial positions
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return x
Option 3: Advanced CNN with BatchNorm
class Advanced_CNN(nn.Module):
    def __init__(self):
        super(Advanced_CNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25)
        )
        self.classifier = nn.Sequential(
            nn.Linear(64*7*7, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
| Model | Architecture | Expected Accuracy | Training Time |
|---|---|---|---|
| Basic ANN | 2-layer fully connected | ~97% | Fastest |
| Simple CNN | 2 conv + 2 fc layers | ~98.5% | Medium |
| Advanced CNN | 4 conv + batch norm | ~99.2% | Slowest |
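As a quick, non-authoritative sanity check, the snippet below instantiates each architecture, counts trainable parameters, and verifies the output shape on a dummy input (models are put in eval mode so BatchNorm accepts a single sample):

# Compare model sizes and confirm each produces 10 logits per image
for ModelClass in (MNIST_ANN, MNIST_CNN, Advanced_CNN):
    m = ModelClass().eval()
    n_params = sum(p.numel() for p in m.parameters() if p.requires_grad)
    with torch.no_grad():
        out = m(torch.randn(1, 1, 28, 28))
    print(f"{ModelClass.__name__}: {n_params:,} parameters, output shape {tuple(out.shape)}")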
4. Enhanced Training Loop
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler=None, epochs=10):
    train_losses = []
    val_losses = []
    val_accuracies = []

    for epoch in range(epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Calculate metrics
        train_loss = running_loss / len(train_loader.dataset)
        val_loss = val_loss / len(val_loader.dataset)
        val_accuracy = 100 * correct / total
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        # Update learning rate if scheduler provided
        if scheduler:
            scheduler.step(val_loss)

        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.2f}%")

    # Plot training history
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train')
    plt.plot(val_losses, label='Validation')
    plt.title('Loss over epochs')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(val_accuracies)
    plt.title('Validation accuracy')
    plt.show()

    return {
        'train_loss': train_losses,
        'val_loss': val_losses,
        'val_acc': val_accuracies
    }
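For example, a minimal run with the basic ANN might look like this (a sketch assuming `device` and the data loaders from the earlier sections are already defined):

# Train the basic ANN for a few epochs as a baseline
model = MNIST_ANN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
history = train_model(model, train_loader, val_loader, criterion, optimizer, epochs=5)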
5. Advanced Evaluation
Confusion Matrix
from sklearn.metrics import confusion_matrix
import seaborn as sns
def plot_confusion_matrix(model, test_loader):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=range(10), yticklabels=range(10))
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()
Class-wise Accuracy
def class_accuracy(model, test_loader):
    model.eval()
    class_correct = [0] * 10
    class_total = [0] * 10
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            correct = (preds == labels).squeeze()
            for i in range(len(labels)):
                label = labels[i]
                class_correct[label] += correct[i].item()
                class_total[label] += 1

    for i in range(10):
        print(f'Accuracy of {i}: {100 * class_correct[i] / class_total[i]:.2f}%')
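The main program in Section 8 also calls a `test_model` helper that is not defined elsewhere in this tutorial; a minimal sketch that reports overall test accuracy could look like this:

def test_model(model, test_loader):
    # Report overall accuracy on the held-out test set
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
    print(f'Test accuracy: {100 * correct / total:.2f}%')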
6. Model Interpretation
Visualizing Feature Maps
def visualize_feature_maps(model, image):
    # Hook to capture feature maps from the first convolutional layer
    features = None
    def hook(module, input, output):
        nonlocal features
        features = output.detach()

    # Find the first Conv2d module so this works for both MNIST_CNN and Advanced_CNN
    first_conv = next(m for m in model.modules() if isinstance(m, nn.Conv2d))
    handle = first_conv.register_forward_hook(hook)

    # Forward pass
    model.eval()
    with torch.no_grad():
        _ = model(image.unsqueeze(0).to(device))

    # Remove hook
    handle.remove()

    # Plot feature maps
    plt.figure(figsize=(12, 8))
    for i in range(min(16, features.shape[1])):  # Show first 16 filters
        plt.subplot(4, 4, i+1)
        plt.imshow(features[0, i].cpu(), cmap='viridis')
        plt.axis('off')
    plt.suptitle('Feature Maps from First Conv Layer')
    plt.show()
7. Hyperparameter Tuning
Learning Rate Finder
def find_lr(model, train_loader, criterion, optimizer, init_value=1e-8, end_value=1, beta=0.98):
    num = len(train_loader) - 1
    mult = (end_value / init_value) ** (1/num)
    lr = init_value
    optimizer.param_groups[0]['lr'] = lr
    avg_loss = 0.
    best_loss = 0.
    losses = []
    log_lrs = []

    for batch_num, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Compute smoothed loss
        avg_loss = beta * avg_loss + (1-beta) * loss.item()
        smoothed_loss = avg_loss / (1 - beta**(batch_num+1))

        # Stop if loss explodes
        if batch_num > 0 and smoothed_loss > 4 * best_loss:
            break
        if smoothed_loss < best_loss or batch_num == 0:
            best_loss = smoothed_loss

        losses.append(smoothed_loss)
        log_lrs.append(math.log10(lr))

        loss.backward()
        optimizer.step()

        lr *= mult
        optimizer.param_groups[0]['lr'] = lr

    plt.figure()
    plt.plot(log_lrs, losses)
    plt.xlabel('log10(lr)')
    plt.ylabel('Loss')
    plt.title('Learning Rate Finder')
    plt.show()
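Because the finder updates the weights it trains on, one reasonable pattern (an illustrative sketch, not part of the main flow below) is to run it on a disposable copy of the model:

import copy

# Probe learning rates on a throwaway copy so the real model stays untouched
probe_model = copy.deepcopy(model).to(device)
probe_optimizer = optim.Adam(probe_model.parameters(), lr=1e-8)
find_lr(probe_model, train_loader, nn.CrossEntropyLoss(), probe_optimizer)
# Pick a learning rate somewhat below the point of steepest descent on the plotted curve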
8. Main Program Flow
if __name__ == "__main__":
    # Initialize
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, test_loader = get_data_loaders(batch_size=128)

    # Create model
    model = Advanced_CNN().to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

    # Train
    history = train_model(
        model, train_loader, val_loader,
        criterion, optimizer, scheduler, epochs=15
    )

    # Evaluate
    test_model(model, test_loader)
    plot_confusion_matrix(model, test_loader)
    class_accuracy(model, test_loader)

    # Visualize
    sample_image, _ = next(iter(test_loader))
    visualize_feature_maps(model, sample_image[0])
9. Performance Optimization
GPU Acceleration
Use CUDA for faster training:
# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
Mixed Precision
Faster training with less memory:
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()

for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()
    with autocast():                      # run the forward pass in mixed precision
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()         # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)
    scaler.update()
10. Deployment
Saving and Loading Models
# Save
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'mnist_model.pth')

# Load
checkpoint = torch.load('mnist_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
model.eval()  # switch to inference mode before making predictions
ONNX Export
dummy_input = torch.randn(1, 1, 28, 28).to(device)
torch.onnx.export(
    model, dummy_input, "mnist_model.onnx",
    input_names=["input"], output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)
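To verify the exported graph, you can run it with ONNX Runtime (a separate `pip install onnxruntime`; the input/output names match those passed to the exporter above):

import numpy as np
import onnxruntime as ort

# Run the exported model on a random batch and print the predicted digits
session = ort.InferenceSession("mnist_model.onnx", providers=["CPUExecutionProvider"])
x = np.random.randn(2, 1, 28, 28).astype(np.float32)
logits = session.run(["output"], {"input": x})[0]
print(logits.argmax(axis=1))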
11. Further Improvements
- Data Augmentation: Add more transformations like random zoom, shear, or elastic deformations
- Architecture: Try ResNet, EfficientNet, or other advanced architectures
- Regularization: Experiment with different dropout rates, weight decay, or early stopping
- Hyperparameter Tuning: Use tools like Optuna or Ray Tune for automated optimization
- Ensemble Methods: Combine predictions from multiple models, as sketched after this list
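As an illustration of the last item, a simple averaging ensemble (a sketch assuming several trained models already moved to `device`) might look like:

# Average softmax probabilities from several trained models and take the argmax
def ensemble_predict(models, images):
    probs = torch.zeros(images.size(0), 10, device=device)
    with torch.no_grad():
        for m in models:
            m.eval()
            probs += F.softmax(m(images.to(device)), dim=1)
    return (probs / len(models)).argmax(dim=1)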
12. Resources
- PyTorch Documentation
- MNIST SOTA Results
- Stanford CS231n: CNN for Visual Recognition
- PyTorch MNIST Examples
Written by Your Name | Last updated: 2024 | License: MIT