Flower Classification with Transfer Learning using MobileNetV2

Project Overview

This project demonstrates transfer learning using MobileNetV2 to classify flowers from the Oxford Flowers102 dataset: 102 flower categories with only 10 official training images per class, exactly the low-data regime where transfer learning pays off.

Why Transfer Learning?

  • Efficiency: Leverages pre-trained weights from ImageNet (1.4M images)
  • Performance: Achieves good accuracy with limited training data
  • Resource-friendly: MobileNetV2 is optimized for mobile/edge devices

System Architecture

[MobileNetV2 Backbone] → [Feature Extractor] → [Custom Classifier Head (102 units)]

Input: 224×224 RGB images → Output: 102-class probabilities
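
A quick shape check confirms this input/output contract; it uses the initialize_model() function defined later in the post:

model = initialize_model()
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))  # one dummy RGB image
print(out.shape)  # torch.Size([1, 102]) -- log-probabilities, one per class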

Implementation Details

flower_classification.py
"""
Flower Classification System using MobileNetV2
Complete implementation with training and evaluation
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix

# Used by the training/evaluation helpers below, so define it at module scope
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Key Dependencies

  • PyTorch: Core deep learning framework
  • Torchvision: Datasets, pretrained models, and transforms
  • tqdm: Progress bars during training
  • matplotlib/seaborn: Confusion-matrix plotting
  • scikit-learn: Evaluation metrics
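
On a fresh environment, the stack typically installs with pip (versions unpinned here; pick the torch build matching your CUDA setup from pytorch.org):

pip install torch torchvision tqdm matplotlib seaborn scikit-learn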

Data Pipeline

Data Augmentation Strategy

We apply different transformations to training and validation data:

# Training transforms - aggressive augmentation
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),   # Pretrained MobileNetV2 expects 224x224 input
    transforms.RandomHorizontalFlip(p=0.5),  # 50% chance of flip
    transforms.RandomRotation(15),      # Rotate between -15 and +15 degrees
    transforms.ColorJitter(
        brightness=0.2, 
        contrast=0.2,
        saturation=0.2, 
        hue=0.1
    ),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet stats
        std=[0.229, 0.224, 0.225]
    )
])

# Validation transforms - minimal preprocessing
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

Why These Transformations?

  • RandomHorizontalFlip: Simulates different orientations
  • RandomRotation: Accounts for camera angle variations
  • ColorJitter: Makes model robust to lighting changes
  • Normalization: Uses ImageNet stats for pretrained weights
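
The main() function at the end of the post calls a load_datasets() helper that the original code never defines. A minimal sketch using torchvision's built-in Flowers102 dataset and the transforms above (download=True fetches the images on first run):

def load_datasets(root='./data'):
    # Flowers102 ships with official train/val/test splits
    train_set = datasets.Flowers102(
        root=root, split='train', transform=transform_train, download=True)
    val_set = datasets.Flowers102(
        root=root, split='val', transform=transform_test, download=True)
    return train_set, val_set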

Model Architecture

def initialize_model():
    # Load pretrained MobileNetV2
    model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
    
    # Freeze all layers except the final classifier
    for param in model.parameters():
        param.requires_grad = False
    
    # Replace the classifier head; newly created modules default to
    # requires_grad=True, so only these layers receive gradient updates
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = nn.Sequential(
        nn.Dropout(0.2),  # Regularization
        nn.Linear(num_ftrs, 102),  # 102 flower classes
        nn.LogSoftmax(dim=1)  # For NLLLoss
    )
    
    return model

Architecture Decisions

  • Layer Freezing: Only train the classifier head initially
  • Dropout: Added for regularization (prevents overfitting)
  • LogSoftmax: Used with NLLLoss for numerical stability
  • Parameter Count: ~2.4M total after swapping in the 102-class head; only the ~131K classifier parameters are trainable
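
The counts above can be verified directly; this snippet sums parameters from the model built by initialize_model():

model = initialize_model()
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Total params: {total:,} | Trainable: {trainable:,}')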

Training Process

def train_model(model, train_loader, val_loader, epochs=10, lr=1e-3):
    # Loss and optimizer; optimizing only parameters with requires_grad=True
    # lets this same function handle both head-only training and fine-tuning
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr,
        weight_decay=1e-4  # L2 regularization
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=3
    )
    
    best_acc = 0.0
    
    for epoch in range(epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        
        for inputs, labels in tqdm(train_loader, desc=f'Train Epoch {epoch+1}'):
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * inputs.size(0)
        
        # Validation phase
        val_loss, val_acc = evaluate(model, val_loader, criterion)
        scheduler.step(val_loss)
        
        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')
        
        print(f'Epoch {epoch+1}: Train Loss: {running_loss/len(train_loader.dataset):.4f}, '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

Training Optimization

  • Learning Rate Scheduling: ReduceLROnPlateau cuts the LR by 10× when validation loss stalls for 3 epochs
  • Weight Decay: L2 regularization to prevent overfitting
  • Model Checkpointing: Saves the weights with the best validation accuracy
  • Early Stopping: Not built in here; the scheduler only lowers the LR (see the sketch after this list)
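
ReduceLROnPlateau never halts training, it only reduces the learning rate. If you want genuine early stopping, a small counter like this sketch (not part of the original code) can be checked once per epoch:

class EarlyStopper:
    """Signals stop when val loss hasn't improved for `patience` epochs."""
    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')
        self.bad_epochs = 0

    def step(self, val_loss):
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience  # True -> stop training

Inside the epoch loop, `if stopper.step(val_loss): break` would end training once the patience runs out.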

Evaluation Metrics

def evaluate(model, data_loader, criterion):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    return running_loss/len(data_loader.dataset), 100*correct/total

def generate_classification_report(model, data_loader):
    all_preds = []
    all_labels = []
    
    model.eval()
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # Confusion Matrix
    plt.figure(figsize=(15,15))
    sns.heatmap(
        confusion_matrix(all_labels, all_preds),
        annot=False,
        fmt='d',
        cmap='Blues'
    )
    plt.title('Confusion Matrix')
    plt.savefig('confusion_matrix.png')
    plt.close()  # release the figure so repeated calls don't accumulate memory
    
    # Classification Report -- torchvision's Flowers102 exposes no class-name
    # strings (there is no .classes attribute), so report per class index
    print(classification_report(
        all_labels,
        all_preds,
        labels=list(range(102)),
        digits=4
    ))

Complete System Workflow

def main():
    # Data (device is defined at module scope above)
    train_set, val_set = load_datasets()
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_set, batch_size=32, shuffle=False, num_workers=4)
    
    # Model
    model = initialize_model().to(device)
    
    # Train
    train_model(model, train_loader, val_loader, epochs=15)
    
    # Evaluate
    model.load_state_dict(torch.load('best_model.pth', map_location=device))
    generate_classification_report(model, val_loader)
    
    # Optional: full-model fine-tuning at a lower learning rate
    unfreeze_model(model)  # Unfreeze some backbone layers (sketch below)
    train_model(model, train_loader, val_loader, epochs=5, lr=1e-4)

if __name__ == '__main__':
    main()
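
main() calls unfreeze_model(), which the post never defines. A minimal sketch that unfreezes only the last few MobileNetV2 feature blocks (the block count is a tunable assumption, not from the original):

def unfreeze_model(model, num_blocks=4):
    # Unfreeze the last `num_blocks` backbone blocks; earlier blocks keep
    # their frozen ImageNet features to limit overfitting
    for block in model.features[-num_blocks:]:
        for param in block.parameters():
            param.requires_grad = True

Because train_model() builds its optimizer from parameters with requires_grad=True, calling it again after unfreezing picks up the newly trainable backbone weights automatically.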

Expected Performance

  • Initial training: ~85-90% accuracy (classifier only)
  • After fine-tuning: ~92-95% accuracy
  • Inference speed: ~15ms per image on GPU
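
The latency figure depends heavily on the GPU and batch size; a rough way to measure it on your own hardware (assumes a CUDA device and the model/device variables from above):

import time

model.eval()
x = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
    for _ in range(10):       # warm-up iterations
        model(x)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        model(x)
    torch.cuda.synchronize()
print(f'{(time.time() - start) / 100 * 1000:.2f} ms per image')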
