Artificial Neural Network (ANN) With PyTorch

Advanced MNIST Digit Classification with PyTorch

This comprehensive tutorial covers multiple approaches to handwritten digit recognition using the MNIST dataset with PyTorch, from basic neural networks to convolutional architectures and advanced techniques.

📌 Key Features

  • Basic ANN implementation
  • CNN architectures
  • Data augmentation
  • Model interpretation
  • Hyperparameter tuning
  • Performance optimization

🎯 Performance

  • Basic ANN: ~97% accuracy
  • CNN models: ~98.5-99.2% accuracy
  • Training time: roughly 1-10 minutes, depending on the model and hardware
  • Hardware: runs on CPU or GPU

1. Environment Setup

Prerequisites

  • Python 3.7+
  • PyTorch 1.8+
  • Matplotlib for visualization

Installation

# Base installation
pip install torch torchvision matplotlib

# With GPU support (CUDA)
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
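
All of the code below shares a handful of imports; collecting them once keeps each snippet short. A minimal set that covers every example in this post (scikit-learn and seaborn are imported later, only for the confusion matrix):

import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms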

2. Enhanced Data Loading

Improved data loading with augmentation and validation split:

def get_data_loaders(batch_size=64, val_ratio=0.1):
    # Define transformations
    train_transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.RandomAffine(0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Load datasets
    full_train = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=train_transform)

    # Split into train and validation
    val_size = int(val_ratio * len(full_train))
    train_size = len(full_train) - val_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_train, [train_size, val_size])

    test_dataset = torchvision.datasets.MNIST(
        root='./data', train=False, download=True, transform=test_transform)

    # Create data loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    return train_loader, val_loader, test_loader
Note: We use MNIST-specific normalization values (mean=0.1307, std=0.3081) rather than generic ones, which helps the network train faster and more stably. The validation split lets us monitor for overfitting during training. One caveat: because random_split returns views of the same underlying dataset, the validation subset also goes through the augmented train_transform; keeping it that way is a simplification.
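
If you would rather verify those constants than take them on faith, they can be recomputed from the raw training images in a couple of lines (assuming the dataset has already been downloaded to ./data):

raw = torchvision.datasets.MNIST(root='./data', train=True, download=True)
pixels = raw.data.float() / 255.0                 # 60000 x 28 x 28 images, scaled to [0, 1]
print(pixels.mean().item(), pixels.std().item())  # roughly 0.1307 and 0.3081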

3. Model Architectures

Option 1: Basic Neural Network (ANN)

class MNIST_ANN(nn.Module):
    def __init__(self, input_size=784, hidden_size=128, num_classes=10):
        super(MNIST_ANN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

Option 2: Convolutional Neural Network (CNN)

class MNIST_CNN(nn.Module):
    def __init__(self):
        super(MNIST_CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return x
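
The 9216 input size of fc1 comes from the feature-map shape: 28x28 images shrink to 26x26 and then 24x24 after the two 3x3 convolutions, and to 12x12 after 2x2 max pooling, so the flattened vector has 64 * 12 * 12 = 9216 elements. A quick sanity check on random input (the variable names here are just for illustration):

cnn = MNIST_CNN()
dummy = torch.randn(4, 1, 28, 28)  # a batch of 4 fake 28x28 grayscale images
print(cnn(dummy).shape)            # torch.Size([4, 10])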

Option 3: Advanced CNN with BatchNorm

class Advanced_CNN(nn.Module):
    def __init__(self):
        super(Advanced_CNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25)
        )
        self.classifier = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
Model          Architecture               Expected Accuracy   Training Time
Basic ANN      2-layer fully connected    ~97%                Fastest
Simple CNN     2 conv + 2 fc layers       ~98.5%              Medium
Advanced CNN   4 conv + batch norm        ~99.2%              Slowest
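
Parameter count is one way to see where these speed differences come from; a small sketch (not part of the original comparison) that prints the size of each model:

for name, net in [("Basic ANN", MNIST_ANN()),
                  ("Simple CNN", MNIST_CNN()),
                  ("Advanced CNN", Advanced_CNN())]:
    n_params = sum(p.numel() for p in net.parameters())
    print(f"{name}: {n_params:,} parameters")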

4. Enhanced Training Loop

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler=None, epochs=10):
    train_losses = []
    val_losses = []
    val_accuracies = []

    for epoch in range(epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)

                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Calculate metrics
        train_loss = running_loss / len(train_loader.dataset)
        val_loss = val_loss / len(val_loader.dataset)
        val_accuracy = 100 * correct / total

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        # Update learning rate if scheduler provided
        if scheduler:
            scheduler.step(val_loss)

        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.2f}%")

    # Plot training history
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train')
    plt.plot(val_losses, label='Validation')
    plt.title('Loss over epochs')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(val_accuracies)
    plt.title('Validation accuracy')
    plt.show()

    return {
        'train_loss': train_losses,
        'val_loss': val_losses,
        'val_acc': val_accuracies
    }

5. Advanced Evaluation

Confusion Matrix

from sklearn.metrics import confusion_matrix
import seaborn as sns

def plot_confusion_matrix(model, test_loader):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=range(10), yticklabels=range(10))
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

Class-wise Accuracy

def class_accuracy(model, test_loader):
    model.eval()
    class_correct = [0] * 10
    class_total = [0] * 10

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            correct = (preds == labels).squeeze()
            for i in range(len(labels)):
                label = labels[i]
                class_correct[label] += correct[i].item()
                class_total[label] += 1

    for i in range(10):
        print(f'Accuracy of {i}: {100 * class_correct[i] / class_total[i]:.2f}%')

6. Model Interpretation

Visualizing Feature Maps

def visualize_feature_maps(model, image):
    # Hook to get feature maps from the first conv layer
    features = None

    def hook(module, input, output):
        nonlocal features
        features = output.detach()

    # Locate the first convolutional layer generically, so this works for
    # MNIST_CNN (model.conv1) as well as Advanced_CNN (inside model.features)
    first_conv = next(m for m in model.modules() if isinstance(m, nn.Conv2d))
    handle = first_conv.register_forward_hook(hook)

    # Forward pass
    model.eval()
    with torch.no_grad():
        _ = model(image.unsqueeze(0).to(device))

    # Remove hook
    handle.remove()

    # Plot feature maps
    plt.figure(figsize=(12, 8))
    for i in range(min(16, features.shape[1])):  # Show first 16 filters
        plt.subplot(4, 4, i+1)
        plt.imshow(features[0, i].cpu(), cmap='viridis')
        plt.axis('off')
    plt.suptitle('Feature Maps from First Conv Layer')
    plt.show()

7. Hyperparameter Tuning

Learning Rate Finder

def find_lr(model, train_loader, criterion, optimizer, init_value=1e-8, end_value=1, beta=0.98):
    num = len(train_loader) - 1
    mult = (end_value / init_value) ** (1 / num)
    lr = init_value
    optimizer.param_groups[0]['lr'] = lr
    avg_loss = 0.
    best_loss = 0.
    losses = []
    log_lrs = []

    for batch_num, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Compute smoothed loss
        avg_loss = beta * avg_loss + (1 - beta) * loss.item()
        smoothed_loss = avg_loss / (1 - beta**(batch_num + 1))

        # Stop if loss explodes
        if batch_num > 0 and smoothed_loss > 4 * best_loss:
            break
        if smoothed_loss < best_loss or batch_num == 0:
            best_loss = smoothed_loss

        losses.append(smoothed_loss)
        log_lrs.append(math.log10(lr))

        loss.backward()
        optimizer.step()

        lr *= mult
        optimizer.param_groups[0]['lr'] = lr

    plt.figure()
    plt.plot(log_lrs, losses)
    plt.xlabel('log10(lr)')
    plt.ylabel('Loss')
    plt.title('Learning Rate Finder')
    plt.show()
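
Because the sweep takes real optimizer steps, it changes the weights of whatever model it is given, so it is usually run on a throwaway copy. A typical invocation, assuming the loaders and device from the earlier sections:

lr_model = MNIST_CNN().to(device)
lr_optimizer = optim.Adam(lr_model.parameters(), lr=1e-8)
find_lr(lr_model, train_loader, nn.CrossEntropyLoss(), lr_optimizer)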

8. Main Program Flow

if __name__ == "__main__":
    # Initialize
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, test_loader = get_data_loaders(batch_size=128)

    # Create model
    model = Advanced_CNN().to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

    # Train
    history = train_model(
        model, train_loader, val_loader,
        criterion, optimizer, scheduler, epochs=15
    )

    # Evaluate
    test_model(model, test_loader)
    plot_confusion_matrix(model, test_loader)
    class_accuracy(model, test_loader)

    # Visualize
    sample_image, _ = next(iter(test_loader))
    visualize_feature_maps(model, sample_image[0])
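
The flow above calls a test_model helper that the earlier sections don't define. A minimal sketch of what it could look like, consistent with the validation loop in Section 4 (define it before the __main__ block):

def test_model(model, test_loader):
    # Report plain accuracy on the held-out test set.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Test accuracy: {100 * correct / total:.2f}%")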

9. Performance Optimization

GPU Acceleration

Use CUDA for faster training:

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

Mixed Precision

Automatic mixed precision (AMP) speeds up training and reduces memory use on supported GPUs:

from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()

for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()

    # Run the forward pass in mixed precision
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # Scale the loss to avoid fp16 gradient underflow, then step and update the scaler
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

10. Deployment

Saving and Loading Models

# Save
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'mnist_model.pth')

# Load
checkpoint = torch.load('mnist_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

ONNX Export

dummy_input = torch.randn(1, 1, 28, 28).to(device)
torch.onnx.export(
    model, dummy_input, "mnist_model.onnx",
    input_names=["input"], output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)
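
To confirm the exported graph behaves like the PyTorch model, you can run it with onnxruntime (an optional extra dependency, not installed above):

import numpy as np
import onnxruntime as ort  # pip install onnxruntime

session = ort.InferenceSession("mnist_model.onnx")
x = np.random.randn(1, 1, 28, 28).astype(np.float32)
(logits,) = session.run(None, {"input": x})
print(logits.shape)  # (1, 10)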

11. Further Improvements

  • Data Augmentation: Add more transformations like random zoom, shear, or elastic deformations
  • Architecture: Try ResNet, EfficientNet, or other advanced architectures
  • Regularization: Experiment with different dropout rates, weight decay, or early stopping
  • Hyperparameter Tuning: Use tools like Optuna or Ray Tune for automated optimization
  • Ensemble Methods: Combine predictions from multiple models (a minimal sketch follows this list)
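
For the ensemble idea, the simplest recipe is to average softmax probabilities from several independently trained models; a minimal sketch, assuming the models are already trained and live on the same device:

def ensemble_predict(models, images):
    # Average class probabilities across models, then take the most likely class.
    probs = None
    with torch.no_grad():
        for m in models:
            m.eval()
            p = F.softmax(m(images.to(device)), dim=1)
            probs = p if probs is None else probs + p
    return (probs / len(models)).argmax(dim=1)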

12. Resources


Written by Your Name | Last updated: 2024 | License: MIT
