CIFAR-10 Dataset Classification Using Convolutional Neural Networks (CNNs) With PyTorch
CIFAR-10 Image Classification Using CNN with PyTorch
This comprehensive guide demonstrates how to classify images from the CIFAR-10 dataset using Convolutional Neural Networks (CNNs) built with PyTorch. The project covers the entire machine learning pipeline from data loading to model deployment, including advanced techniques for improving performance.
🚀 Key Features
- Complete PyTorch implementation
- Data augmentation techniques
- Multiple CNN architectures
- Training visualization
- Model evaluation metrics
- Hyperparameter tuning
- Model saving/loading
- Deployment options
📊 Performance
- Baseline model: ~70% accuracy
- Improved model: ~85% accuracy
- ResNet implementation: ~90% accuracy
- Training time: 5-30 minutes (depending on hardware)
🧰 Prerequisites & Installation
1. System Requirements
- Python 3.8 or higher
- 4GB+ RAM (8GB recommended)
- GPU with CUDA support (optional but recommended)
2. Installation Options
Basic Installation
pip install torch torchvision torchaudio matplotlib numpy tqdm pandas seaborn scikit-learn
Installation with CUDA (for GPU acceleration)
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
Development Installation (with additional tools)
pip install -r requirements.txt
# requirements.txt contents:
# torch==1.12.1
# torchvision==0.13.1
# matplotlib==3.5.2
# numpy==1.22.4
# tqdm==4.64.0
# pandas==1.4.2
# seaborn==0.11.2
# tensorboard==2.9.1
# ipython==8.4.0
Note: For optimal performance, ensure you have the correct version of PyTorch for your CUDA version. Check PyTorch's official website for the right installation command for your system.
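A quick way to confirm which build is installed, and to define the `device` object that later snippets refer to, might look like the following. This is a minimal sketch; the variable name `device` is an assumption chosen to match its later use in the training and evaluation code.

```python
import torch

# Report the installed PyTorch build and the CUDA version it was compiled for.
print(f"PyTorch version: {torch.__version__}")
print(f"Built with CUDA: {torch.version.cuda}")  # None for CPU-only builds

# Pick the GPU when one is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
```

If the second line prints `None`, the installed wheel is CPU-only and training will run on the CPU regardless of the hardware present.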
📦 Dataset: CIFAR-10
The CIFAR-10 dataset consists of 60,000 32x32 color images in 10 classes, with 6,000 images per class. The dataset is divided into 50,000 training images and 10,000 test images.
Class Information
Class ID | Class Name | Description |
---|---|---|
0 | airplane | Airplane images |
1 | automobile | Car images |
2 | bird | Bird images |
3 | cat | Cat images |
4 | deer | Deer images |
5 | dog | Dog images |
6 | frog | Frog images |
7 | horse | Horse images |
8 | ship | Ship images |
9 | truck | Truck images |
📥 Step 1: Data Loading and Preprocessing
Basic Data Loading
import torch
import torchvision
import torchvision.transforms as transforms
# Define transformations
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Load training set
trainset = torchvision.datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform
)
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size=32,
shuffle=True,
num_workers=2
)
# Load test set
testset = torchvision.datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=transform
)
testloader = torch.utils.data.DataLoader(
testset,
batch_size=32,
shuffle=False,
num_workers=2
)
# Class names
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
Advanced Data Augmentation
To improve model generalization, we can add data augmentation:
# Enhanced transformations with data augmentation
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.RandomAffine(0, translate=(0.1, 0.1)),
transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
Pro Tip: Data augmentation helps prevent overfitting and improves model generalization by artificially expanding your training dataset with modified versions of existing images.
👀 Step 2: Data Visualization
Sample Images Visualization
import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    # Reorder from (channels, height, width) to (height, width, channels)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# Get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)
# Show images
imshow(torchvision.utils.make_grid(images[:4]))
# Print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))
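The `img / 2 + 0.5` step in `imshow` is the inverse of `Normalize` with mean 0.5 and standard deviation 0.5. A small round-trip check, using made-up pixel values, illustrates the arithmetic:

```python
import torch

# Pixel intensities as produced by ToTensor(), i.e. in [0, 1].
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])

# Normalize(mean=0.5, std=0.5) computes (x - mean) / std, mapping [0, 1] to [-1, 1].
normalized = (pixels - 0.5) / 0.5

# The "unnormalize" step in imshow: x = y * std + mean = y / 2 + 0.5.
restored = normalized / 2 + 0.5

print(normalized)
print(restored)
```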
Class Distribution Analysis
from collections import Counter
import seaborn as sns

# Get all labels from the training set
all_labels = []
for _, labels in trainloader:
    all_labels.extend(labels.tolist())
# Count occurrences of each class
label_counts = Counter(all_labels)
# Plot class distribution in class-ID order (the loader is shuffled,
# so the Counter's own key order is arbitrary)
class_ids = sorted(label_counts)
plt.figure(figsize=(10, 5))
sns.barplot(x=[classes[i] for i in class_ids],
            y=[label_counts[i] for i in class_ids])
plt.title('Class Distribution in Training Set')
plt.xlabel('Class')
plt.ylabel('Count')
plt.xticks(rotation=45)
plt.show()
🧠 Step 3: Model Architectures
Basic CNN Model
import torch.nn as nn
import torch.nn.functional as F

class BasicCNN(nn.Module):
    def __init__(self):
        super(BasicCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 32x32 -> 16x16
        x = self.pool(F.relu(self.conv2(x)))  # 16x16 -> 8x8
        x = self.dropout(x)
        x = x.view(-1, 64 * 8 * 8)  # flatten to (batch, 4096)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
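The `64 * 8 * 8` input size of `fc1` follows from the spatial size of the feature maps after two conv and pool stages. The arithmetic can be sketched with the standard output-size formula; `conv2d_out` is a hypothetical helper written for this illustration, not part of PyTorch.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 32                                      # CIFAR-10 images are 32x32
size = conv2d_out(size, kernel=3, padding=1)   # conv1 keeps the size
size = conv2d_out(size, kernel=2, stride=2)    # first 2x2 max pool halves it
size = conv2d_out(size, kernel=3, padding=1)   # conv2 keeps the size
size = conv2d_out(size, kernel=2, stride=2)    # second 2x2 max pool halves it

flat_features = 64 * size * size               # 64 channels out of conv2
print(size, flat_features)
```

The same reasoning with three pooling stages gives 4x4 maps, which is where the `256 * 4 * 4` in the improved model comes from.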
Improved CNN Model
class ImprovedCNN(nn.Module):
    def __init__(self):
        super(ImprovedCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(256 * 4 * 4, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))  # 32x32 -> 16x16
        x = self.pool(F.relu(self.bn2(self.conv2(x))))  # 16x16 -> 8x8
        x = self.pool(F.relu(self.bn3(self.conv3(x))))  # 8x8 -> 4x4
        x = self.dropout(x)
        x = x.view(-1, 256 * 4 * 4)  # flatten to (batch, 4096)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
ResNet Implementation
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Identity shortcut by default; use a 1x1 projection when the
        # shape changes so the two branches can be added
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(64, 2, stride=1)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.linear = nn.Linear(256, num_classes)

    def _make_layer(self, out_channels, num_blocks, stride):
        # Only the first block of a stage downsamples
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(ResidualBlock(self.in_channels, out_channels, stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)       # 64 channels, 32x32
        out = self.layer2(out)       # 128 channels, 16x16
        out = self.layer3(out)       # 256 channels, 8x8
        out = F.avg_pool2d(out, 8)   # global average pool -> 1x1
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
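The shortcut branch in `ResidualBlock` becomes a 1x1 convolution whenever the stride or channel count changes, because the element-wise addition needs both branches to have the same shape. A stripped-down sketch of that situation (only the first conv of the main path is kept, for brevity):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 32, 32)  # one feature map with 64 channels

# Simplified main path: a strided 3x3 conv that doubles the channels.
main_path = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
)

# Projection shortcut: a 1x1 conv with the same stride and output channels.
projection = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128),
)

# Both branches now produce the same shape, so they can be summed.
out = main_path(x) + projection(x)
print(out.shape)
```

Adding `main_path(x) + x` directly would fail here, since `x` still has 64 channels at 32x32 resolution.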
Model | Key Components | Approx. Accuracy |
---|---|---|
BasicCNN | 2 conv layers, max pooling, dropout, 2 FC layers | ~70% |
ImprovedCNN | 3 conv layers, BatchNorm, more filters, deeper FC | ~85% |
ResNet | Residual blocks, skip connections, BatchNorm, deep architecture | ~90% |
⚙️ Step 4: Training Configuration
Loss Function and Optimizer
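import torch.optim as optim
# Device used by the training and evaluation code below
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ImprovedCNN().to(device)  # Choose your model
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)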
# Learning rate scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode='min',
factor=0.1,
patience=5,
verbose=True
)
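With `patience=5` and `factor=0.1`, the scheduler tolerates five epochs without improvement and cuts the learning rate to a tenth on the next one. The sketch below feeds it a loss that never improves to make that visible; the single dummy parameter is an assumption used only so the optimizer has something to hold.

```python
import torch
import torch.optim as optim

# Dummy parameter so the optimizer can be constructed (illustration only).
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.Adam([param], lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=5)

lrs = []
for epoch in range(8):
    scheduler.step(1.0)  # report a loss that never improves
    lrs.append(optimizer.param_groups[0]['lr'])

for epoch, lr in enumerate(lrs):
    print(f'epoch {epoch}: lr = {lr:g}')
```

The first call records the baseline loss, the next five count as "bad" epochs, and the reduction happens once the count exceeds the patience.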
Training Loop with Progress Tracking
from tqdm import tqdm
import time

# Assumes `trainloader`, `valloader`, and `device` exist at module level.
# `valloader` is not created by the data loading step: hold out part of
# the training set for it rather than reusing the test set.
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()
    best_acc = 0.0
    train_loss_history = []
    val_loss_history = []
    train_acc_history = []
    val_acc_history = []
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
                dataloader = trainloader
            else:
                model.eval()  # Set model to evaluate mode
                dataloader = valloader
            running_loss = 0.0
            running_corrects = 0
            # Iterate over data
            for inputs, labels in tqdm(dataloader, desc=phase):
                inputs = inputs.to(device)
                labels = labels.to(device)
                # Zero the parameter gradients
                optimizer.zero_grad()
                # Forward; gradients are tracked only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # Backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # Statistics (kept as plain Python numbers so they can be plotted)
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()
            epoch_loss = running_loss / len(dataloader.dataset)
            epoch_acc = running_corrects / len(dataloader.dataset)
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            # Record history
            if phase == 'train':
                train_loss_history.append(epoch_loss)
                train_acc_history.append(epoch_acc)
                # Update learning rate (stepped on the training loss here)
                scheduler.step(epoch_loss)
            else:
                val_loss_history.append(epoch_loss)
                val_acc_history.append(epoch_acc)
                # Save the weights if this is the best validation accuracy so far
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), 'best_model.pth')
        print()
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')
    return {
        'model': model,
        'train_loss': train_loss_history,
        'val_loss': val_loss_history,
        'train_acc': train_acc_history,
        'val_acc': val_acc_history
    }
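The training loop reads from a `valloader`. One common way to obtain it is to hold out a slice of the training set with `random_split`. The sketch below uses a small random `TensorDataset` as a stand-in for `trainset` (an assumption, so it runs without the download); the 90/10 ratio and the seed are arbitrary choices.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Stand-in for the CIFAR-10 `trainset` (assumption): 1000 fake images.
full_train = TensorDataset(
    torch.randn(1000, 3, 32, 32),
    torch.randint(0, 10, (1000,)))

# Hold out 10% for validation; a fixed seed keeps the split reproducible.
val_size = len(full_train) // 10
train_size = len(full_train) - val_size
generator = torch.Generator().manual_seed(42)
train_subset, val_subset = random_split(
    full_train, [train_size, val_size], generator=generator)

trainloader = DataLoader(train_subset, batch_size=32, shuffle=True)
valloader = DataLoader(val_subset, batch_size=32, shuffle=False)

print(len(trainloader.dataset), len(valloader.dataset))
```

Note that both subsets share the parent dataset's transform, so if the parent uses random augmentation the validation images are augmented too. Building two `CIFAR10` objects with different transforms and indexing them with the same split avoids that.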
Training Visualization
def plot_training_history(history):
    plt.figure(figsize=(12, 4))
    # Loss plot
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='train')
    plt.plot(history['val_loss'], label='val')
    plt.title('Loss over epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    # Accuracy plot
    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='train')
    plt.plot(history['val_acc'], label='val')
    plt.title('Accuracy over epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.tight_layout()
    plt.show()
📊 Step 5: Model Evaluation
Test Set Evaluation
def evaluate_model(model, testloader):
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    accuracy = 100 * correct / total
    print(f'Accuracy on test set: {accuracy:.2f}%')
    return all_preds, all_labels, accuracy
Confusion Matrix
from sklearn.metrics import confusion_matrix
import seaborn as sns

def plot_confusion_matrix(all_labels, all_preds, classes):
    # Rows are true labels, columns are predicted labels
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes)
    plt.title('Confusion Matrix')
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()
Class-wise Accuracy
def class_accuracy(all_labels, all_preds, classes):
    cm = confusion_matrix(all_labels, all_preds)
    # Correct predictions per class divided by the number of true samples
    class_acc = cm.diagonal() / cm.sum(axis=1)
    plt.figure(figsize=(10, 5))
    sns.barplot(x=classes, y=class_acc)
    plt.title('Class-wise Accuracy')
    plt.ylabel('Accuracy')
    plt.xticks(rotation=45)
    plt.ylim(0, 1)
    plt.show()
    for i, acc in enumerate(class_acc):
        print(f'{classes[i]}: {acc:.2f}')
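The expression `cm.diagonal() / cm.sum(axis=1)` divides each class's correct predictions by its number of true samples, which is per-class recall. A tiny worked example with three made-up classes shows the computation; the confusion matrix is built by hand with NumPy here so the snippet has no scikit-learn dependency.

```python
import numpy as np

# Made-up ground truth and predictions for three classes.
labels = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])
preds = np.array([0, 0, 1, 1, 1, 2, 2, 0, 2])

# Confusion matrix with true labels as rows and predictions as columns.
num_classes = 3
cm = np.zeros((num_classes, num_classes), dtype=int)
np.add.at(cm, (labels, preds), 1)

class_acc = cm.diagonal() / cm.sum(axis=1)   # per-class recall
overall_acc = cm.diagonal().sum() / cm.sum()

print(cm)
print(class_acc)
print(overall_acc)
```

Class 0 has two of its three samples right, class 1 both, and class 2 three of four, so a single overall accuracy figure can hide which classes are weak.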
🔍 Step 6: Model Interpretation
Visualizing Feature Maps
def visualize_feature_maps(model, image):
    # Choose a layer to visualize
    layer = model.conv1
    # Hook to get the feature maps
    features = None

    def hook(module, input, output):
        nonlocal features
        features = output.detach()

    handle = layer.register_forward_hook(hook)
    # Forward pass
    model.eval()
    with torch.no_grad():
        _ = model(image.unsqueeze(0))
    # Remove hook
    handle.remove()
    # Visualize feature maps
    plt.figure(figsize=(12, 8))
    for i in range(min(16, features.shape[1])):  # Show first 16 filters
        plt.subplot(4, 4, i + 1)
        plt.imshow(features[0, i].cpu(), cmap='viridis')
        plt.axis('off')
    plt.suptitle('Feature Maps from First Conv Layer')
    plt.show()
Grad-CAM Visualization
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradient = None
        self.activation = None
        # Register hooks (register_backward_hook is deprecated in favor of
        # register_full_backward_hook)
        target_layer.register_forward_hook(self.save_activation)
        target_layer.register_full_backward_hook(self.save_gradient)

    def save_activation(self, module, input, output):
        self.activation = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        self.gradient = grad_output[0].detach()

    def __call__(self, x, class_idx=None):
        # Forward pass
        output = self.model(x)
        if class_idx is None:
            class_idx = output.argmax(dim=1).item()
        # Backward pass for specific class
        self.model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0][class_idx] = 1
        output.backward(gradient=one_hot)
        # Calculate weights: average the gradient over each feature map
        weights = self.gradient.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activation).sum(dim=1, keepdim=True)
        cam = F.relu(cam)
        cam = F.interpolate(cam, x.shape[2:], mode='bilinear', align_corners=False)
        # Rescale to [0, 1]; the epsilon avoids dividing by zero for a flat map
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)
        return cam.squeeze().cpu().numpy()

def visualize_gradcam(model, image, target_layer):
    model.eval()  # Use inference behavior for dropout and batch norm
    grad_cam = GradCAM(model, target_layer)
    cam = grad_cam(image.unsqueeze(0))
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(image.permute(1, 2, 0).cpu().numpy() / 2 + 0.5)
    plt.title('Original Image')
    plt.axis('off')
    plt.subplot(1, 2, 2)
    plt.imshow(image.permute(1, 2, 0).cpu().numpy() / 2 + 0.5)
    plt.imshow(cam, cmap='jet', alpha=0.5)
    plt.title('Grad-CAM')
    plt.axis('off')
    plt.show()
💾 Step 7: Model Saving and Loading
Saving the Model
# Save entire model
torch.save(model, 'cifar10_model.pth')
# Save only state dict (recommended)
torch.save(model.state_dict(), 'cifar10_model_state_dict.pth')
# Save with additional metadata
checkpoint = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch,
'loss': loss,
'accuracy': accuracy
}
torch.save(checkpoint, 'checkpoint.pth')
Loading the Model
# Load entire model
model = torch.load('cifar10_model.pth')
# Load state dict
model = ImprovedCNN() # Initialize your model first
model.load_state_dict(torch.load('cifar10_model_state_dict.pth'))
# Load checkpoint
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
accuracy = checkpoint['accuracy']
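A quick way to check that a saved state dict restores the same weights is a save/load round trip. The sketch below uses a tiny `nn.Linear` as a stand-in for the CNN and a temporary directory for the file; both are assumptions made to keep the example small.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # tiny stand-in for the trained network
model.eval()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'state_dict.pth')
    torch.save(model.state_dict(), path)

    # Build a fresh model of the same architecture, then load the weights.
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load(path, map_location='cpu'))
    restored.eval()

# Identical weights must give identical outputs on the same input.
x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(model(x), restored(x))
print(same_output)
```

Passing `map_location='cpu'` lets a checkpoint written on a GPU machine load on a CPU-only one.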
🚀 Step 8: Deployment Options
1. Flask Web Application
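import io

import torch
import torch.nn.functional as F
from flask import Flask, request, jsonify
from PIL import Image

app = Flask(__name__)
model = load_model()  # Your model loading function
model.eval()  # Inference mode: no dropout, batch norm uses running statistics

@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return jsonify({'error': 'no file uploaded'}), 400
    file = request.files['file']
    image = Image.open(io.BytesIO(file.read()))
    # Your preprocessing function; it should return a (1, 3, 32, 32) tensor
    image = preprocess(image)
    with torch.no_grad():
        logits = model(image)
    # Convert raw logits to probabilities so the confidence lies in [0, 1]
    probs = F.softmax(logits, dim=1)[0]
    confidence, class_idx = probs.max(dim=0)
    return jsonify({
        'class': classes[class_idx.item()],
        'confidence': confidence.item()
    })

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)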
2. FastAPI Application
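import io

import torch
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image

# Assumes `model`, `classes`, and `preprocess` are defined as in the Flask example
app = FastAPI()
# CORS setup (wide open here; restrict the origins for a real deployment)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    contents = await file.read()
    image = Image.open(io.BytesIO(contents))
    image = preprocess(image)
    with torch.no_grad():
        logits = model(image)
    # Convert raw logits to probabilities so the confidence lies in [0, 1]
    probs = F.softmax(logits, dim=1)[0]
    confidence, class_idx = probs.max(dim=0)
    return {
        "class": classes[class_idx.item()],
        "confidence": confidence.item()
    }

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)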
3. ONNX Export for Production
import torch.onnx
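# Switch to inference mode so dropout and batch norm are exported deterministically
model.eval()
# Example input to the model, created on the same device as its weights
dummy_input = torch.randn(1, 3, 32, 32, device=next(model.parameters()).device)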
# Export the model
torch.onnx.export(
model, # model being run
dummy_input, # model input
"cifar10_model.onnx", # where to save the model
export_params=True, # store the trained parameter weights
opset_version=11, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding
input_names=['input'], # model's input names
output_names=['output'], # model's output names
dynamic_axes={
'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}
}
)
📈 Performance Optimization
1. Mixed Precision Training
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for inputs, labels in trainloader:
    inputs = inputs.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()
    # Run the forward pass in mixed precision
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, labels)
    # Scale the loss to avoid underflow in half-precision gradients
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
2. Data Loading Optimization
# Use pinned memory for faster host to device transfers
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size=128,
shuffle=True,
num_workers=4,
pin_memory=True
)
3. Learning Rate Finder
import math

def find_lr(model, criterion, optimizer, trainloader, init_value=1e-8, end_value=10, beta=0.98):
    num = len(trainloader) - 1
    # Multiplier that takes lr from init_value to end_value in `num` steps
    mult = (end_value / init_value) ** (1 / num)
    lr = init_value
    optimizer.param_groups[0]['lr'] = lr
    avg_loss = 0.
    best_loss = 0.
    losses = []
    log_lrs = []
    for batch_num, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        # Exponential moving average of the loss with bias correction
        avg_loss = beta * avg_loss + (1 - beta) * loss.item()
        smoothed_loss = avg_loss / (1 - beta ** (batch_num + 1))
        # Stop once the loss has clearly diverged
        if batch_num > 0 and smoothed_loss > 4 * best_loss:
            return log_lrs, losses
        if smoothed_loss < best_loss or batch_num == 0:
            best_loss = smoothed_loss
        losses.append(smoothed_loss)
        log_lrs.append(math.log10(lr))
        loss.backward()
        optimizer.step()
        # Increase the learning rate for the next batch
        lr *= mult
        optimizer.param_groups[0]['lr'] = lr
    return log_lrs, losses
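Two pieces of arithmetic drive the finder: a geometric learning-rate sweep and a bias-corrected moving average of the loss. The sketch below isolates both with plain numbers; the batch count of 100 and the constant loss of 2.0 are made-up values for illustration.

```python
import math

# Geometric sweep: multiplying by `mult` once per batch moves the
# learning rate from init_value to end_value in exactly `num` steps.
init_value, end_value = 1e-8, 10
num = 99  # len(trainloader) - 1 for a loader with 100 batches
mult = (end_value / init_value) ** (1 / num)

lr = init_value
for _ in range(num):
    lr *= mult
print(f'final lr = {lr:.6f}')

# Bias-corrected exponential moving average: dividing by (1 - beta^n)
# removes the pull toward the zero initial value in the first few steps.
beta = 0.98
avg_loss = 0.0
for batch_num, loss in enumerate([2.0, 2.0, 2.0]):
    avg_loss = beta * avg_loss + (1 - beta) * loss
    smoothed = avg_loss / (1 - beta ** (batch_num + 1))
print(f'raw average = {avg_loss:.4f}, corrected = {smoothed:.4f}')
```

Without the correction the raw average after three batches is far below the true loss of 2.0, which would make the early part of the curve misleading.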
📝 Best Practices
✅ Do
- Use data augmentation to prevent overfitting
- Normalize your input data
- Use batch normalization for deeper networks
- Monitor training with TensorBoard
- Save checkpoints regularly
- Experiment with different architectures
- Use learning rate scheduling
❌ Don't
- Use too large batch sizes (start with 32-128)
- Forget to shuffle your training data
- Use too high learning rates
- Ignore class imbalance in your dataset
- Overfit on your test set by tuning too much
- Forget to set model to eval mode during inference
📚 Resources & Further Reading
Research Papers
- Deep Residual Learning for Image Recognition (ResNet)
- Very Deep Convolutional Networks (VGG)
- MobileNets: Efficient Convolutional Neural Networks
Online Courses
- Practical Deep Learning for Coders (fast.ai)
- Stanford CS231n: Convolutional Neural Networks for Visual Recognition
❓ Frequently Asked Questions
Why is my model not learning?
Common reasons include:
- Learning rate is too high or too low
- Data isn't properly normalized
- Model architecture is too simple/complex
- Gradient vanishing/exploding
How can I improve my model's accuracy?
Try these techniques:
- Add more data or use data augmentation
- Use a deeper/more complex architecture
- Add batch normalization
- Fine-tune hyperparameters
How do I choose batch size?
General guidelines:
- Start with 32 or 64
- Larger batch sizes need higher learning rates
- Smaller batches often generalize better
- Limited by GPU memory
When should I stop training?
Consider stopping when:
- Validation loss stops improving
- Training loss is very close to zero
- You see signs of overfitting
- You've reached your time/compute budget
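The first criterion, validation loss no longer improving, is often automated as early stopping. A minimal sketch follows; the `EarlyStopping` class, its `patience` and `min_delta` parameters, and the loss values are all hypothetical and not part of PyTorch.

```python
class EarlyStopping:
    """Hypothetical helper: signal a stop after `patience` epochs without improvement."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses: improvement stalls after the third epoch.
stopper = EarlyStopping(patience=3)
val_losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.69]
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best_loss)
```

In a real training loop, `stopper.step(epoch_loss)` would be called once per epoch with the validation loss, alongside saving the best weights so the stalled epochs can be discarded.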
📜 License
This project is licensed under the MIT License - see the LICENSE file for details.
🙏 Acknowledgments
- PyTorch team for the amazing deep learning framework
- Alex Krizhevsky for creating the CIFAR-10 dataset
- All the researchers who developed CNN architectures