Flower Classification with Transfer Learning using MobileNetV2
Flower Classification with Transfer Learning using MobileNetV2
Project Overview
This project demonstrates transfer learning using MobileNetV2 to classify flowers from the Oxford Flowers102 dataset, which contains 102 different flower categories.
Why Transfer Learning?
- Efficiency: Leverages pre-trained weights from ImageNet (1.4M images)
- Performance: Achieves good accuracy with limited training data
- Resource-friendly: MobileNetV2 is optimized for mobile/edge devices
System Architecture
[MobileNetV2 Backbone] → [Feature Extractor] → [Custom Classifier Head (102 units)]
Input: 224×224 RGB images → Output: 102-class log-probabilities (the head ends in LogSoftmax; apply exp() to recover probabilities)
Implementation Details
flower_classification.py
"""
Flower Classification System using MobileNetV2
Complete implementation with training and evaluation
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
Key Dependencies
- PyTorch: Core deep learning framework
- Torchvision: For datasets, models and transforms
- tqdm: Progress bars for training visualization
- scikit-learn: For evaluation metrics
Data Pipeline
Data Augmentation Strategy
We apply different transformations to training and validation data:
# Values shared by both pipelines, defined once so the training and
# validation preprocessing cannot drift apart.
_INPUT_SIZE = (224, 224)                # spatial size fed to the network
_IMAGENET_MEAN = [0.485, 0.456, 0.406]  # per-channel ImageNet mean
_IMAGENET_STD = [0.229, 0.224, 0.225]   # per-channel ImageNet std

# Training pipeline: resize, then random flip / rotation / colour jitter,
# then tensor conversion and ImageNet normalization.
transform_train = transforms.Compose([
    transforms.Resize(_INPUT_SIZE),
    transforms.RandomHorizontalFlip(p=0.5),  # flip half of the samples
    transforms.RandomRotation(15),           # angle drawn from [-15, +15] degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Validation pipeline: deterministic preprocessing only (no augmentation).
transform_test = transforms.Compose([
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
Why These Transformations?
- RandomHorizontalFlip: Simulates different orientations
- RandomRotation: Accounts for camera angle variations
- ColorJitter: Makes model robust to lighting changes
- Normalization: Uses ImageNet stats for pretrained weights
Model Architecture
def initialize_model(num_classes=102, dropout=0.2):
    """Build an ImageNet-pretrained MobileNetV2 with a fresh classifier head.

    Every pretrained parameter is frozen first; the final classifier layer is
    then replaced by a new Dropout -> Linear -> LogSoftmax stack. Because the
    replacement modules are created after the freeze, they are the only
    parameters left with ``requires_grad=True``.

    Args:
        num_classes: Number of output classes (102 for Oxford Flowers102).
        dropout: Drop probability of the Dropout layer inside the new head.

    Returns:
        The modified model. Its output is log-probabilities, so it must be
        trained with ``nn.NLLLoss`` (not ``nn.CrossEntropyLoss``).

    Raises:
        ValueError: If ``num_classes`` is not a positive integer.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    # Load pretrained MobileNetV2
    model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)

    # Freeze everything that exists right now (backbone and original head).
    for param in model.parameters():
        param.requires_grad = False

    # Replace only classifier[1]; classifier[0] is left untouched.
    # NOTE(review): in torchvision classifier[0] is itself a Dropout layer, so
    # the head ends up with two dropout stages - confirm this is intended.
    # The index is kept so state_dict keys stay compatible with checkpoints
    # saved by earlier runs.
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = nn.Sequential(
        nn.Dropout(dropout),               # regularization
        nn.Linear(num_ftrs, num_classes),  # one logit per class
        nn.LogSoftmax(dim=1),              # log-probabilities for NLLLoss
    )
    return model
Architecture Decisions
- Layer Freezing: Only train the classifier head initially
- Dropout: Added for regularization (prevents overfitting)
- LogSoftmax: Used with NLLLoss for numerical stability
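- Parameter Count: ~2.4M total once the 1000-class ImageNet head is swapped for the 102-class head (stock MobileNetV2 has ~3.5M), of which only ~130K (1280 × 102 weights + 102 biases) are trainable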
Training Process
def train_model(model, train_loader, val_loader, epochs=10, lr=0.001):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    Optimizes every parameter that currently has ``requires_grad=True`` with
    Adam and NLLLoss. With a freshly initialized (frozen-backbone) model that
    is just the classifier head; after backbone layers are unfrozen they are
    picked up as well, so a second call really fine-tunes them.

    Args:
        model: Network whose output is log-probabilities (NLLLoss is used).
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches,
            passed to ``evaluate`` after every epoch.
        epochs: Number of passes over ``train_loader``.
        lr: Initial Adam learning rate; pass a smaller value for fine-tuning.

    Returns:
        The best validation accuracy (in percent) seen during this call.

    Side effects:
        Writes ``best_model.pth`` whenever validation accuracy improves and
        prints one summary line per epoch. ``best_acc`` starts at 0.0 on every
        call, so a later call overwrites the checkpoint after its first epoch
        with non-zero accuracy.
    """
    # Take the device from the model itself instead of relying on a global
    # ``device`` name, which is not defined in this function's scope.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Loss and optimizer
    criterion = nn.NLLLoss()
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(
        trainable_params,
        lr=lr,
        weight_decay=1e-4,  # L2 regularization
    )
    # ``verbose=True`` is deprecated in recent PyTorch releases; LR changes
    # are reported manually after each scheduler step instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=3,
    )

    best_acc = 0.0
    for epoch in range(epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        seen = 0
        for inputs, labels in tqdm(train_loader, desc=f'Train Epoch {epoch+1}'):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # NLLLoss averages over the batch; re-weight by batch size so the
            # epoch figure is a true per-sample mean.
            running_loss += loss.item() * inputs.size(0)
            seen += inputs.size(0)

        # Validation phase
        val_loss, val_acc = evaluate(model, val_loader, criterion)

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f'Epoch {epoch+1}: reducing learning rate to {lr_after:.4e}')

        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')

        # Divide by samples actually processed (guarded for an empty loader).
        print(f'Epoch {epoch+1}: Train Loss: {running_loss/max(seen, 1):.4f}, '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    return best_acc
Training Optimization
- Learning Rate Scheduling: Reduces LR when validation loss plateaus
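- Early Stopping: Not implemented — the scheduler's patience only lowers the learning rate when validation loss plateaus; training always runs for the full number of epochs, and the best checkpoint is kept instead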
- Weight Decay: L2 regularization to prevent overfitting
- Model Checkpointing: Saves best performing model
Evaluation Metrics
def evaluate(model, data_loader, criterion):
    """Compute the mean loss and top-1 accuracy of ``model`` over ``data_loader``.

    Args:
        model: Network to evaluate; it is switched to eval mode and left there.
        data_loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss callable ``criterion(outputs, labels)``.
            Assumed to return the batch mean (the default reduction) - each
            batch loss is multiplied by the batch size before summing.

    Returns:
        ``(mean_loss, accuracy_percent)`` over the samples actually seen, or
        ``(0.0, 0.0)`` when the loader yields no samples.
    """
    # Take the device from the model itself instead of relying on a global
    # ``device`` name, which is not defined in this function's scope.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            batch_size = labels.size(0)
            running_loss += criterion(outputs, labels).item() * batch_size
            # Predicted class = index of the largest score along the class axis.
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        # Nothing to average over; avoid dividing by zero.
        return 0.0, 0.0
    # Normalize by samples seen so samplers / drop_last cannot skew the mean.
    return running_loss / total, 100 * correct / total
def generate_classification_report(model, data_loader):
    """Save a confusion-matrix heatmap and print a per-class report.

    Runs ``model`` over ``data_loader`` in eval mode, writes the confusion
    matrix to ``confusion_matrix.png`` and prints scikit-learn's
    classification report with 4 decimal digits.

    Args:
        model: Network producing one score per class for each sample.
        data_loader: Iterable of ``(inputs, labels)`` batches. If its
            ``dataset`` exposes a ``classes`` attribute, those names label the
            report rows; otherwise numeric class ids are used.
    """
    # Take the device from the model itself instead of relying on a global
    # ``device`` name, which is not defined in this function's scope.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    all_preds = []
    all_labels = []
    model.eval()
    with torch.no_grad():
        for inputs, labels in data_loader:
            outputs = model(inputs.to(device))
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # Confusion Matrix
    figure = plt.figure(figsize=(15, 15))
    sns.heatmap(
        confusion_matrix(all_labels, all_preds),
        annot=False,
        fmt='d',
        cmap='Blues'
    )
    plt.title('Confusion Matrix')
    plt.savefig('confusion_matrix.png')
    plt.close(figure)  # release the figure so repeated calls do not accumulate

    # Classification Report
    report_kwargs = {'digits': 4}
    class_names = getattr(data_loader.dataset, 'classes', None)
    if class_names is not None:
        # Pin the label set to the full class list: without ``labels``,
        # scikit-learn raises when a class is missing from this split and the
        # number of names no longer matches the labels it inferred.
        report_kwargs['target_names'] = [str(name) for name in class_names]
        report_kwargs['labels'] = list(range(len(class_names)))
    print(classification_report(all_labels, all_preds, **report_kwargs))
Complete System Workflow
def main():
    """Run the full workflow: load data, train the head, report, then fine-tune.

    ``load_datasets`` and ``unfreeze_model`` are not visible in this section;
    they are presumably defined elsewhere in the project.
    """
    # Published as a module global so that any helper in this file that looks
    # up a bare ``device`` name can resolve it; a variable local to main()
    # would be invisible to them.
    global device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Data
    train_set, val_set = load_datasets()
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_set, batch_size=32, shuffle=False, num_workers=4)

    # Model
    model = initialize_model().to(device)

    # Train
    train_model(model, train_loader, val_loader, epochs=15)

    # Evaluate the best checkpoint; map_location keeps the load working even
    # if the checkpoint was written from a different device.
    model.load_state_dict(torch.load('best_model.pth', map_location=device))
    generate_classification_report(model, val_loader)

    # Optional: Full model fine-tuning
    unfreeze_model(model)  # Unfreeze some backbone layers
    # NOTE(review): this call uses train_model's default learning rate. A
    # lower LR is normally wanted for fine-tuning - pass one here if
    # train_model accepts it.
    train_model(model, train_loader, val_loader, epochs=5)


if __name__ == '__main__':
    main()
Expected Performance
- Initial training: ~85-90% accuracy (classifier only)
- After fine-tuning: ~92-95% accuracy
- Inference speed: ~15ms per image on GPU
Comments