CIFAR-10 Dataset Classification Using Convolutional Neural Networks (CNNs) With PyTorch

This comprehensive guide demonstrates how to classify images from the CIFAR-10 dataset using Convolutional Neural Networks (CNNs) built with PyTorch. The project covers the entire machine learning pipeline from data loading to model deployment, including advanced techniques for improving performance.

🚀 Key Features

  • Complete PyTorch implementation
  • Data augmentation techniques
  • Multiple CNN architectures
  • Training visualization
  • Model evaluation metrics
  • Hyperparameter tuning
  • Model saving/loading
  • Deployment options

📊 Performance

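  • Baseline model (BasicCNN): ~70% test accuracy
  • Improved model (ImprovedCNN): ~85% test accuracy
  • ResNet implementation: ~90% test accuracy
  • Training time: 5-30 minutes (depending on hardware and number of epochs)

These figures are approximate; actual results vary with the number of epochs, data augmentation, and hyperparameters.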

🧰 Prerequisites & Installation

1. System Requirements

  • Python 3.8 or higher
  • 4GB+ RAM (8GB recommended)
  • GPU with CUDA support (optional but recommended)

2. Installation Options

Basic Installation

pip install torch torchvision torchaudio matplotlib numpy tqdm pandas seaborn scikit-learn

Installation with CUDA (for GPU acceleration; this example targets CUDA 11.3, so change the cu113 suffix to match your CUDA version)

pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

Development Installation (with additional tools)

pip install -r requirements.txt
# requirements.txt contents:
# torch==1.12.1
# torchvision==0.13.1
# matplotlib==3.5.2
# numpy==1.22.4
# tqdm==4.64.0
# pandas==1.4.2
# seaborn==0.11.2
# tensorboard==2.9.1
# ipython==8.4.0
# scikit-learn==1.1.1

Note: For optimal performance, make sure your PyTorch build matches your CUDA version. Check PyTorch's official website for the right installation command for your system.

📦 Dataset: CIFAR-10

The CIFAR-10 dataset consists of 60,000 32x32 color images in 10 classes, with 6,000 images per class. The dataset is divided into 50,000 training images and 10,000 test images.

Class Information

Class ID   Class Name   Description
0          airplane     Airplane images
1          automobile   Car images
2          bird         Bird images
3          cat          Cat images
4          deer         Deer images
5          dog          Dog images
6          frog         Frog images
7          horse        Horse images
8          ship         Ship images
9          truck        Truck images

📥 Step 1: Data Loading and Preprocessing

Basic Data Loading

import torch
import torchvision
import torchvision.transforms as transforms

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

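# Load the full training set (50,000 images)
full_trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
# Hold out 5,000 training images as a validation set (used by train_model in Step 4)
trainset, valset = torch.utils.data.random_split(
    full_trainset, [45000, 5000],
    generator=torch.Generator().manual_seed(42)  # reproducible split
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=32,
    shuffle=True,
    num_workers=2
)
valloader = torch.utils.data.DataLoader(
    valset,
    batch_size=32,
    shuffle=False,
    num_workers=2
)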

# Load test set
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=32,
    shuffle=False,
    num_workers=2
)

# Class names
classes = ('plane', 'car', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck')

Advanced Data Augmentation

To improve model generalization, we can add data augmentation:

# Enhanced transformations with data augmentation
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomAffine(0, translate=(0.1, 0.1)),
    transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Apply train_transform to the training data only; evaluation data uses test_transform

Pro Tip: Data augmentation helps prevent overfitting and improves generalization by artificially expanding the training set with randomly modified versions of existing images.

👀 Step 2: Data Visualization

Sample Images Visualization

import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# Get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# Show images
imshow(torchvision.utils.make_grid(images[:4]))
# Print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))
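The img / 2 + 0.5 line in imshow undoes transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), which computes (x - mean) / std per channel and therefore maps ToTensor's [0, 1] pixel values to [-1, 1]. A plain-Python sketch of that arithmetic (normalize and unnormalize are illustrative helper names, not torchvision functions):

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-channel arithmetic performed by transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std


def unnormalize(y, mean=0.5, std=0.5):
    """Inverse of normalize; with mean = std = 0.5 this is y / 2 + 0.5, as in imshow."""
    return y * std + mean


# ToTensor() produces values in [0, 1]; after Normalize they lie in [-1, 1]
for pixel in (0.0, 0.25, 0.5, 1.0):
    print(pixel, "->", normalize(pixel), "->", unnormalize(normalize(pixel)))
```

If you switch to per-channel dataset statistics for mean and std, imshow's unnormalization has to change accordingly.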

Class Distribution Analysis

from collections import Counter
import seaborn as sns

# Get all labels from the training set
all_labels = []
for _, labels in trainloader:
    all_labels.extend(labels.numpy())
    
# Count occurrences of each class
label_counts = Counter(all_labels)
class_ids = sorted(label_counts)  # keep bars in class-ID order

# Plot class distribution
plt.figure(figsize=(10, 5))
sns.barplot(x=[classes[i] for i in class_ids],
            y=[label_counts[i] for i in class_ids])
plt.title('Class Distribution in Training Set')
plt.xlabel('Class')
plt.ylabel('Count')
plt.xticks(rotation=45)
plt.show()

🧠 Step 3: Model Architectures

Basic CNN Model

import torch.nn as nn
import torch.nn.functional as F

class BasicCNN(nn.Module):
    def __init__(self):
        super(BasicCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.dropout(x)
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
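The 64 * 8 * 8 input size of fc1 follows from the spatial output formula shared by nn.Conv2d and nn.MaxPool2d (dilation 1, default floor rounding): floor((n + 2p - k) / s) + 1. A plain-Python sketch that traces BasicCNN's two conv/pool stages (conv2d_out and trace_basic_cnn are illustrative helpers):

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_basic_cnn(size=32):
    # conv1: 3x3 with padding 1 keeps 32x32; the 2x2 pool with stride 2 halves it
    size = conv2d_out(size, kernel=3, padding=1)
    size = conv2d_out(size, kernel=2, stride=2)
    # conv2 and the second pool repeat the pattern: 16x16 -> 8x8
    size = conv2d_out(size, kernel=3, padding=1)
    size = conv2d_out(size, kernel=2, stride=2)
    return size


side = trace_basic_cnn()
channels = 64  # out_channels of conv2
print(f"feature map: {channels} x {side} x {side} -> fc1 input {channels * side * side}")
```

The same arithmetic with three pooling stages (32 → 16 → 8 → 4) explains the 256 * 4 * 4 used by ImprovedCNN.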

Improved CNN Model

class ImprovedCNN(nn.Module):
    def __init__(self):
        super(ImprovedCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(256 * 4 * 4, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.dropout(x)
        x = x.view(-1, 256 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
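Counting parameters by hand shows where ImprovedCNN's capacity sits. The sketch below mirrors the layer shapes above and should agree with sum(p.numel() for p in model.parameters()), assuming PyTorch's defaults of a bias on every conv/linear layer and a learnable weight and bias per channel in BatchNorm2d (running statistics are buffers, not parameters). The helper names are illustrative:

```python
def conv_params(c_in, c_out, k):
    return c_in * c_out * k * k + c_out  # kernel weights + one bias per output channel


def bn_params(c):
    return 2 * c  # learnable scale (weight) and shift (bias) per channel


def linear_params(n_in, n_out):
    return n_in * n_out + n_out


layers = {
    "conv1": conv_params(3, 64, 3),
    "bn1": bn_params(64),
    "conv2": conv_params(64, 128, 3),
    "bn2": bn_params(128),
    "conv3": conv_params(128, 256, 3),
    "bn3": bn_params(256),
    "fc1": linear_params(256 * 4 * 4, 1024),
    "fc2": linear_params(1024, 512),
    "fc3": linear_params(512, 10),
}

total = sum(layers.values())
for name, count in layers.items():
    print(f"{name:>5}: {count:>9,} ({count / total:.1%})")
print(f"total: {total:,}")
```

fc1 alone holds roughly 82% of the roughly 5.1 million parameters, which is one reason the ResNet below replaces large fully connected layers with global average pooling.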

ResNet Implementation

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, 
            stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels))
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(64, 2, stride=1)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.linear = nn.Linear(256, num_classes)
    
    def _make_layer(self, out_channels, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(ResidualBlock(self.in_channels, out_channels, stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
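To see what _make_layer builds, the following plain-Python sketch replays its loop: only the first block of each stage uses the stage's stride, and a block gets a 1x1 projection shortcut under the same condition as ResidualBlock (stride != 1 or the channel count changes). plan_resnet is a hypothetical helper that tracks channels and spatial size only:

```python
def plan_resnet(stages=((64, 2, 1), (128, 2, 2), (256, 2, 2)), in_channels=64, size=32):
    """Replay ResNet._make_layer: (in, out, stride, shortcut type, output size) per block."""
    blocks = []
    for out_channels, num_blocks, stride in stages:
        for s in [stride] + [1] * (num_blocks - 1):
            # Same test as ResidualBlock: project the shortcut when the shape changes
            projection = s != 1 or in_channels != out_channels
            size = (size + 2 * 1 - 3) // s + 1  # first 3x3 conv, padding 1, stride s
            shortcut = "1x1 conv" if projection else "identity"
            blocks.append((in_channels, out_channels, s, shortcut, size))
            in_channels = out_channels
    return blocks


for c_in, c_out, s, shortcut, size in plan_resnet():
    print(f"in={c_in:3d} out={c_out:3d} stride={s} shortcut={shortcut:8s} output={size}x{size}")
```

The last stage ends at 256 x 8 x 8, so F.avg_pool2d(out, 8) reduces each channel to a single value and the final nn.Linear(256, num_classes) receives 256 features.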

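BasicCNN

  • 2 conv layers (32 and 64 filters)
  • Max pooling
  • Dropout
  • 2 fully connected layers
  • ~70% accuracy

ImprovedCNN

  • 3 conv layers (64, 128, and 256 filters)
  • Batch normalization
  • 3 fully connected layers
  • ~85% accuracy

ResNet

  • 6 residual blocks in 3 stages (64, 128, and 256 channels)
  • Skip connections
  • Batch normalization
  • Global average pooling before a single fully connected layer
  • ~90% accuracy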

⚙️ Step 4: Training Configuration

Loss Function and Optimizer

import torch.optim as optim

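# Use the GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ImprovedCNN().to(device)  # Choose your model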
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Learning rate scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='min', 
    factor=0.1, 
    patience=5, 
    verbose=True
)
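With patience=5 and factor=0.1, the scheduler waits until the monitored loss has gone more than five epochs without improving, then divides the learning rate by ten. The following is a simplified pure-Python model of that rule, based on ReduceLROnPlateau's documented defaults (relative threshold of 1e-4, no cooldown; min_lr and eps are ignored). PlateauSketch is a hypothetical name, not a PyTorch class:

```python
class PlateauSketch:
    """Simplified model of ReduceLROnPlateau(mode='min', threshold_mode='rel', cooldown=0)."""

    def __init__(self, lr, factor=0.1, patience=5, threshold=1e-4):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, metric):
        # Only a value that beats the best by the relative threshold counts as improvement
        if metric < self.best * (1 - self.threshold):
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # Cut the LR once the bad-epoch count exceeds patience, then reset the count
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


sched = PlateauSketch(lr=0.001)
losses = [1.0, 0.9] + [0.95] * 6  # two improvements, then six epochs without one
for epoch, loss in enumerate(losses):
    print(epoch, loss, sched.step(loss))
```

The learning rate drops only on the sixth non-improving epoch, which is why a patience that is too large relative to num_epochs can mean the schedule never fires.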

Training Loop with Progress Tracking

from tqdm import tqdm
import time

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    # Expects trainloader, valloader, and device to exist as globals
    since = time.time()
    best_acc = 0.0
    train_loss_history = []
    val_loss_history = []
    train_acc_history = []
    val_acc_history = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
                dataloader = trainloader
            else:
                model.eval()   # Set model to evaluate mode
                dataloader = valloader

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data
            for inputs, labels in tqdm(dataloader, desc=phase):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            epoch_loss = running_loss / len(dataloader.dataset)
            epoch_acc = running_corrects / len(dataloader.dataset)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Record history
            if phase == 'train':
                train_loss_history.append(epoch_loss)
                train_acc_history.append(epoch_acc)
            else:
                val_loss_history.append(epoch_loss)
                val_acc_history.append(epoch_acc)
                # ReduceLROnPlateau should monitor the validation loss
                scheduler.step(epoch_loss)

                # Save the weights if this is the best validation accuracy so far
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), 'best_model.pth')

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    return {
        'model': model,
        'train_loss': train_loss_history,
        'val_loss': val_loss_history,
        'train_acc': train_acc_history,
        'val_acc': val_acc_history
    }
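train_model multiplies each batch's loss.item() by inputs.size(0) because nn.CrossEntropyLoss returns the mean over the batch by default. Re-weighting by batch size gives the exact per-sample mean even when the dataset size is not a multiple of the batch size (for example, 50,000 images at batch_size=32 leave a final batch of 16). A plain-Python illustration with hypothetical batch losses:

```python
def epoch_mean_loss(batches):
    """batches: list of (mean_loss_of_batch, batch_size) pairs."""
    running_loss = 0.0
    total = 0
    for mean_loss, batch_size in batches:
        running_loss += mean_loss * batch_size  # same idea as loss.item() * inputs.size(0)
        total += batch_size
    return running_loss / total


# Hypothetical values: two full batches of 32 and a final partial batch of 16
batches = [(2.0, 32), (1.0, 32), (4.0, 16)]
naive = sum(loss for loss, _ in batches) / len(batches)
print("per-sample mean:", epoch_mean_loss(batches))
print("naive mean of batch means:", naive)
```

The naive average over-weights the small final batch; dividing the weighted sum by len(dataloader.dataset) avoids that.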

Training Visualization

def plot_training_history(history):
    plt.figure(figsize=(12, 4))
    
    # Loss plot
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='train')
    plt.plot(history['val_loss'], label='val')
    plt.title('Loss over epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    # Accuracy plot
    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='train')
    plt.plot(history['val_acc'], label='val')
    plt.title('Accuracy over epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    
    plt.tight_layout()
    plt.show()

📊 Step 5: Model Evaluation

Test Set Evaluation

def evaluate_model(model, testloader):
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    accuracy = 100 * correct / total
    print(f'Accuracy on test set: {accuracy:.2f}%')
    return all_preds, all_labels, accuracy

Confusion Matrix

from sklearn.metrics import confusion_matrix
import seaborn as sns

def plot_confusion_matrix(all_labels, all_preds, classes):
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=classes, yticklabels=classes)
    plt.title('Confusion Matrix')
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()

Class-wise Accuracy

def class_accuracy(all_labels, all_preds, classes):
    cm = confusion_matrix(all_labels, all_preds)
    class_acc = cm.diagonal() / cm.sum(axis=1)
    
    plt.figure(figsize=(10, 5))
    sns.barplot(x=classes, y=class_acc)
    plt.title('Class-wise Accuracy')
    plt.ylabel('Accuracy')
    plt.xticks(rotation=45)
    plt.ylim(0, 1)
    plt.show()
    
    for i, acc in enumerate(class_acc):
        print(f'{classes[i]}: {acc:.2f}')
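In scikit-learn's confusion_matrix, row i counts samples whose true label is i and column j counts predictions of class j, so cm.diagonal() / cm.sum(axis=1) is per-class recall, which this guide calls class-wise accuracy. The same computation without NumPy, on a hypothetical three-class example (confusion and per_class_accuracy are illustrative helpers):

```python
def confusion(labels, preds, num_classes):
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(labels, preds):
        cm[t][p] += 1  # row = true label, column = predicted label
    return cm


def per_class_accuracy(cm):
    # Correct predictions (diagonal) divided by all samples of that class (row sum);
    # classes with no samples get 0.0 here instead of NaN
    return [row[i] / sum(row) if sum(row) else 0.0 for i, row in enumerate(cm)]


labels = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2]
preds = [0, 0, 0, 1, 1, 1, 2, 2, 2, 2]
cm = confusion(labels, preds, 3)
print(cm)
print(per_class_accuracy(cm))
```

Dividing by column sums instead would give per-class precision, which answers a different question: how often a prediction of class j is correct.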

🔍 Step 6: Model Interpretation

Visualizing Feature Maps

def visualize_feature_maps(model, image):
    # Choose a layer to visualize
    layer = model.conv1
    
    # Hook to get the feature maps
    features = None
    def hook(module, input, output):
        nonlocal features
        features = output.detach()
    
    handle = layer.register_forward_hook(hook)
    
    # Forward pass (move the image to the model's device)
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        _ = model(image.unsqueeze(0).to(device))
    
    # Remove hook
    handle.remove()
    
    # Visualize feature maps
    plt.figure(figsize=(12, 8))
    for i in range(min(16, features.shape[1])):  # Show first 16 filters
        plt.subplot(4, 4, i+1)
        plt.imshow(features[0, i].cpu(), cmap='viridis')
        plt.axis('off')
    plt.suptitle('Feature Maps from First Conv Layer')
    plt.show()

Grad-CAM Visualization

class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradient = None
        self.activation = None

        # Register a forward hook; it also attaches a tensor hook that captures the
        # gradient w.r.t. the layer output (Module.register_backward_hook is deprecated)
        self.handle = target_layer.register_forward_hook(self.save_activation)

    def save_activation(self, module, input, output):
        self.activation = output.detach()
        if output.requires_grad:
            output.register_hook(self.save_gradient)

    def save_gradient(self, grad):
        self.gradient = grad.detach()

    def remove(self):
        # Remove the hook so repeated use doesn't stack hooks on the layer
        self.handle.remove()

    def __call__(self, x, class_idx=None):
        # Eval mode for dropout/batch norm; gradients stay enabled (no torch.no_grad())
        self.model.eval()
        output = self.model(x)

        if class_idx is None:
            class_idx = output.argmax(dim=1).item()

        # Backward pass for specific class
        self.model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0][class_idx] = 1
        output.backward(gradient=one_hot)

        # Channel weights = gradients averaged over the spatial dimensions
        weights = self.gradient.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activation).sum(dim=1, keepdim=True)
        cam = F.relu(cam)
        cam = F.interpolate(cam, x.shape[2:], mode='bilinear', align_corners=False)
        # Scale to [0, 1]; the epsilon avoids dividing by zero for an all-zero map
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)

        return cam.squeeze().cpu().numpy()

def visualize_gradcam(model, image, target_layer):
    # target_layer: e.g. model.conv3 for ImprovedCNN or model.layer3 for ResNet
    grad_cam = GradCAM(model, target_layer)
    device = next(model.parameters()).device
    cam = grad_cam(image.unsqueeze(0).to(device))
    grad_cam.remove()

    img = image.permute(1, 2, 0).cpu().numpy() / 2 + 0.5  # unnormalize for display

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.title('Original Image')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(img)
    plt.imshow(cam, cmap='jet', alpha=0.5)
    plt.title('Grad-CAM')
    plt.axis('off')
    plt.show()
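The core of Grad-CAM is a weighted sum: each channel's weight is its gradient averaged over spatial positions, the weighted activation maps are summed across channels, and negative values are clipped with ReLU before scaling to [0, 1]. A dependency-free sketch of those steps on hypothetical 2x2 maps with two channels (the upsampling to input size is omitted; grad_cam here is an illustrative function, not the class above):

```python
def grad_cam(activations, gradients):
    """activations, gradients: lists of C channels, each an H x W nested list."""
    h, w = len(activations[0]), len(activations[0][0])
    # 1. Channel weights: global-average-pool the gradients
    weights = [sum(sum(row) for row in g) / (h * w) for g in gradients]
    # 2. Weighted sum of activation maps across channels, followed by ReLU
    cam = [[max(0.0, sum(wk * a[i][j] for wk, a in zip(weights, activations)))
            for j in range(w)] for i in range(h)]
    # 3. Min-max scale to [0, 1] (the epsilon guards an all-zero map)
    lo = min(min(row) for row in cam)
    hi = max(max(row) for row in cam)
    return [[(v - lo) / (hi - lo + 1e-8) for v in row] for row in cam]


acts = [[[1.0, 0.0], [0.0, 0.0]],       # channel 0 fires top-left
        [[0.0, 0.0], [0.0, 1.0]]]       # channel 1 fires bottom-right
grads = [[[0.5, 0.5], [0.5, 0.5]],      # positive gradients: channel 0 supports the class
         [[-0.5, -0.5], [-0.5, -0.5]]]  # negative gradients: channel 1 argues against it
print(grad_cam(acts, grads))
```

Only the top-left region lights up: the ReLU discards evidence against the class, so the heatmap highlights where the image supports the prediction.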

💾 Step 7: Model Saving and Loading

Saving the Model

# Save entire model
torch.save(model, 'cifar10_model.pth')

# Save only state dict (recommended)
torch.save(model.state_dict(), 'cifar10_model_state_dict.pth')

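# Save with additional metadata
# (assumes epoch, loss, and accuracy exist from your training/evaluation code)
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
    'loss': float(loss),  # store a plain number rather than a graph-attached tensor
    'accuracy': accuracy
}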
torch.save(checkpoint, 'checkpoint.pth')

Loading the Model

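# Load entire model (the model class must be importable; on PyTorch 2.6+ torch.load
# defaults to weights_only=True, so this call needs weights_only=False, for trusted files only)
model = torch.load('cifar10_model.pth')

# Load state dict (map_location lets GPU-trained weights load on a CPU-only machine)
model = ImprovedCNN()  # Initialize your model first
model.load_state_dict(torch.load('cifar10_model_state_dict.pth', map_location='cpu'))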

# Load checkpoint
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
accuracy = checkpoint['accuracy']

🚀 Step 8: Deployment Options

1. Flask Web Application

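import io

import torch
import torchvision.transforms as transforms
from flask import Flask, request, jsonify
from PIL import Image

app = Flask(__name__)

# Assumes ImprovedCNN and classes are defined or imported from the training code above
model = ImprovedCNN()
model.load_state_dict(torch.load('cifar10_model_state_dict.pth', map_location='cpu'))
model.eval()

# Same normalization as training; Resize makes arbitrary uploads 32x32
preprocess = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return jsonify({'error': 'no file uploaded'}), 400

    file = request.files['file']
    image = Image.open(io.BytesIO(file.read())).convert('RGB')
    batch = preprocess(image).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)  # logits -> probabilities
    confidence, class_idx = probs.max(dim=1)

    return jsonify({
        'class': classes[class_idx.item()],
        'confidence': confidence.item()
    })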

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)

2. FastAPI Application

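import io

import torch
import uvicorn
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image

# Assumes model (in eval mode), preprocess, and classes are set up as in the Flask example.
# UploadFile/File form handling also requires the python-multipart package.
app = FastAPI()

# CORS setup (allow_origins=["*"] is convenient for a demo; restrict it in production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert('RGB')
    batch = preprocess(image).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)  # logits -> probabilities
    confidence, class_idx = probs.max(dim=1)

    return {
        "class": classes[class_idx.item()],
        "confidence": confidence.item()
    }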

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

3. ONNX Export for Production

import torch.onnx

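# Export in eval mode (dropout off, batch norm uses running stats) from the CPU
model = model.cpu().eval()

# Input to the model
dummy_input = torch.randn(1, 3, 32, 32)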

# Export the model
torch.onnx.export(
    model,                     # model being run
    dummy_input,               # model input
    "cifar10_model.onnx",      # where to save the model
    export_params=True,        # store the trained parameter weights
    opset_version=11,          # the ONNX version to export the model to
    do_constant_folding=True,  # whether to execute constant folding
    input_names=['input'],     # model's input names
    output_names=['output'],   # model's output names
    dynamic_axes={
        'input': {0: 'batch_size'},    # variable length axes
        'output': {0: 'batch_size'}
    }
)

📈 Performance Optimization

1. Mixed Precision Training

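from torch.cuda.amp import autocast, GradScaler

# Requires a CUDA device (newer PyTorch releases also offer torch.amp.autocast("cuda"))
scaler = GradScaler()
model.train()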

for inputs, labels in trainloader:
    inputs = inputs.to(device)
    labels = labels.to(device)
    
    optimizer.zero_grad()
    
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, labels)
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

2. Data Loading Optimization

# Use pinned memory for faster host-to-device transfers
# (pair it with inputs.to(device, non_blocking=True) in the training loop)
trainloader = torch.utils.data.DataLoader(
    trainset, 
    batch_size=128, 
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

3. Learning Rate Finder

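import math

# Note: find_lr updates the model's weights while sweeping the learning rate,
# so reload the original weights before the real training run
def find_lr(model, criterion, optimizer, trainloader, init_value=1e-8, end_value=10, beta=0.98):
    model.train()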
    num = len(trainloader) - 1
    mult = (end_value / init_value) ** (1/num)
    lr = init_value
    optimizer.param_groups[0]['lr'] = lr
    avg_loss = 0.
    best_loss = 0.
    losses = []
    log_lrs = []
    
    for batch_num, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        avg_loss = beta * avg_loss + (1-beta) * loss.item()
        smoothed_loss = avg_loss / (1 - beta**(batch_num+1))
        
        if batch_num > 0 and smoothed_loss > 4 * best_loss:
            return log_lrs, losses
        
        if smoothed_loss < best_loss or batch_num == 0:
            best_loss = smoothed_loss
            
        losses.append(smoothed_loss)
        log_lrs.append(math.log10(lr))
        
        loss.backward()
        optimizer.step()
        
        lr *= mult
        optimizer.param_groups[0]['lr'] = lr
    
    return log_lrs, losses
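find_lr grows the learning rate geometrically, multiplying it by mult = (end_value / init_value) ** (1 / num) after every batch so that it reaches end_value on the last batch, and it smooths the noisy loss with an exponential moving average whose bias correction divides by 1 - beta ** (batch_num + 1). Both pieces in isolation, assuming a hypothetical 391-batch epoch (50,000 images at batch_size=128); lr_schedule and smooth are illustrative helpers:

```python
def lr_schedule(init_value=1e-8, end_value=10.0, num_batches=391):
    num = num_batches - 1
    mult = (end_value / init_value) ** (1 / num)
    return [init_value * mult ** k for k in range(num_batches)]


def smooth(losses, beta=0.98):
    avg, out = 0.0, []
    for n, loss in enumerate(losses):
        avg = beta * avg + (1 - beta) * loss
        out.append(avg / (1 - beta ** (n + 1)))  # bias correction for the zero start
    return out


lrs = lr_schedule()
print(f"first lr {lrs[0]:.1e}, last lr {lrs[-1]:.3g}")
print(smooth([2.0, 2.0, 2.0]))
```

Without the correction, the first smoothed value for a constant loss of 2.0 would be only 0.04, which would make the earliest learning rates look misleadingly good.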

📝 Best Practices

✅ Do

  • Use data augmentation to prevent overfitting
  • Normalize your input data
  • Use batch normalization for deeper networks
  • Monitor training with TensorBoard
  • Save checkpoints regularly
  • Experiment with different architectures
  • Use learning rate scheduling

❌ Don't

  • Use too large batch sizes (start with 32-128)
  • Forget to shuffle your training data
  • Use too high learning rates
  • Ignore class imbalance in your dataset
  • Overfit on your test set by tuning too much
  • Forget to set model to eval mode during inference
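CIFAR-10 itself is balanced (5,000 training images per class), but if you train on an imbalanced dataset, one common remedy is to pass per-class weights to nn.CrossEntropyLoss through its weight argument. A sketch of inverse-frequency weights, total / (num_classes * count), the same formula scikit-learn uses for class_weight='balanced', computed in plain Python on hypothetical counts:

```python
from collections import Counter


def inverse_frequency_weights(labels, num_classes):
    """weight_c = N / (num_classes * count_c); a class seen twice as often gets half the weight."""
    counts = Counter(labels)
    total = len(labels)
    return [total / (num_classes * counts[c]) if counts[c] else 0.0
            for c in range(num_classes)]


# Hypothetical imbalanced labels: class 0 is four times as common as class 2
labels = [0] * 400 + [1] * 200 + [2] * 100
weights = inverse_frequency_weights(labels, num_classes=3)
print([round(w, 3) for w in weights])

# In PyTorch these would then be used as (not run here):
# criterion = nn.CrossEntropyLoss(weight=torch.tensor(weights))
```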

❓ Frequently Asked Questions

Why is my model not learning?

Common reasons include:

  • Learning rate is too high or too low
  • Data isn't properly normalized
  • Model architecture is too simple/complex
  • Vanishing or exploding gradients

How can I improve my model's accuracy?

Try these techniques:

  • Add more data or use data augmentation
  • Use a deeper/more complex architecture
  • Add batch normalization
  • Fine-tune hyperparameters

How do I choose batch size?

General guidelines:

  • Start with 32 or 64
  • Larger batch sizes usually need a higher learning rate
  • Smaller batches often generalize better
  • Limited by GPU memory
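One widely used heuristic for raising the learning rate with batch size is the linear scaling rule: scale the learning rate in proportion to the batch size, relative to a configuration known to work. It was proposed for SGD; with Adam, a gentler square-root scaling is sometimes used instead, and either is a starting point to tune rather than a guarantee. A minimal sketch using this guide's Adam settings (lr=0.001 at batch_size=32) as the hypothetical reference:

```python
import math


def linear_scaled_lr(base_lr, base_batch, new_batch):
    # Linear scaling rule: the learning rate grows in proportion to batch size
    return base_lr * new_batch / base_batch


def sqrt_scaled_lr(base_lr, base_batch, new_batch):
    # Gentler square-root variant
    return base_lr * math.sqrt(new_batch / base_batch)


for batch in (32, 64, 128, 256):
    print(batch, linear_scaled_lr(0.001, 32, batch), round(sqrt_scaled_lr(0.001, 32, batch), 6))
```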

When should I stop training?

Consider stopping when:

  • Validation loss stops improving
  • Training loss is very close to zero
  • You see signs of overfitting
  • You've reached your time/compute budget
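The first criterion, stopping once validation loss stops improving, is usually automated as early stopping. The train_model function above does not include it; the following is a minimal sketch with a hypothetical EarlyStopping helper (patience and min_delta are illustrative parameters) that you would call once per epoch with the validation loss, alongside the best-weights saving that train_model already does:

```python
class EarlyStopping:
    """Signal a stop when validation loss hasn't improved by min_delta for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.epochs_without_improvement = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


stopper = EarlyStopping(patience=3)
val_losses = [1.2, 1.0, 0.9, 0.92, 0.91, 0.95, 0.8]  # hypothetical values
for epoch, loss in enumerate(val_losses):
    if stopper.step(loss):
        print(f"stopping after epoch {epoch}; best val loss {stopper.best}")
        break
```

In this run training stops after epoch 5, so the later 0.8 is never seen; a larger patience trades compute for robustness against noisy validation curves.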

📜 License

This project is licensed under the MIT License - see the LICENSE file for details.

🙏 Acknowledgments

  • PyTorch team for the amazing deep learning framework
  • Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton for creating the CIFAR-10 dataset
  • All the researchers who developed CNN architectures
