Residual Network (ResNet) CIFAR-10 PyTorch

Residual Networks Classification with PyTorch - Complete Guide

This guide provides a comprehensive explanation of implementing Residual Networks (ResNet) for image classification using PyTorch, covering both custom ResNet implementation and transfer learning approaches.

1. Introduction to Residual Networks

Residual Networks (ResNets) were introduced by Microsoft Research in 2015 to address the degradation problem in deep neural networks: as plain networks get deeper, accuracy saturates and then degrades rapidly, even on the training set. ResNets address this with "skip connections" (shortcuts) that add a block's input to its output, so each block only has to learn a residual correction and gradients get a direct path through the network.

Key Features of ResNets:
  • Skip Connections: Allow the network to learn identity functions, making deeper networks easier to train
  • Residual Blocks: Basic building blocks that include convolutional layers with skip connections
  • Batch Normalization: Used after each convolutional layer to stabilize training
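To make the skip-connection idea concrete, here is a toy, framework-free sketch. Plain Python lists stand in for tensors, and the function names are hypothetical (they are not PyTorch APIs). It computes ReLU(F(x) + x), the form used by the residual block later in this guide, and shows that when the residual branch F outputs zeros, the block passes a non-negative input through unchanged. That is why an extra block is easy to "switch off".

```python
def relu(values):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, v) for v in values]


def residual_block(x, branch):
    """Toy residual block: ReLU(F(x) + x), with F supplied as `branch`."""
    fx = branch(x)
    return relu([f + xi for f, xi in zip(fx, x)])


def zero_branch(x):
    """A residual branch whose weights have been driven to zero."""
    return [0.0 for _ in x]


def scale_branch(x):
    """A non-trivial branch: F(x) = 0.5 * x."""
    return [0.5 * v for v in x]


x = [1.0, 2.0, 3.0]  # non-negative, like the output of a previous ReLU
print(residual_block(x, zero_branch))   # [1.0, 2.0, 3.0]  (identity)
print(residual_block(x, scale_branch))  # [1.5, 3.0, 4.5]
```

Because the ReLU comes after the addition, the identity property holds only for non-negative inputs. In this architecture that condition is met, since each block receives the output of a ReLU.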

2. Complete Code Implementation

The following code implements a custom ResNet and, as an alternative, adapts a pre-trained ResNet-18 via transfer learning for CIFAR-10 classification.

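2.1 Imports and Device Setup

# %% Imports and device setup
import torch, torchvision, torch.nn as nn, torch.optim as optim
from torchvision import transforms, models
from tqdm import tqdm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Explanation: The imports bring in PyTorch, torchvision (datasets, transforms, and pre-trained models), and tqdm for progress bars. The device line checks whether CUDA (GPU support) is available; PyTorch uses GPU acceleration if it is and falls back to the CPU otherwise.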

2.2 Data Loading with Augmentation

# %% Data loading with augmentation
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

Explanation:

  • Data Augmentation: For training data, we apply random horizontal flips and random crops to increase dataset variability and prevent overfitting.
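  • Normalization: ToTensor scales pixels to [0, 1]; Normalize then subtracts 0.5 and divides by 0.5 for each channel (RGB), mapping values to [-1, 1].
  • CIFAR-10: A dataset of 60,000 32x32 color images in 10 classes (6,000 images per class), split into 50,000 training and 10,000 test images.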
  • DataLoaders: Create iterators that provide batches of images and labels during training/testing.

Note: The test set transformations don't include augmentation since we want to evaluate on the original images.
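The arithmetic behind the two key transforms can be checked without PyTorch. This plain-Python sketch assumes, as ToTensor does, that pixel values are already scaled to [0, 1]. The hypothetical helper `normalize` mirrors the per-channel formula (x - mean) / std that transforms.Normalize applies. The second helper, `crop_positions`, counts how many distinct 32x32 windows RandomCrop(32, padding=4) can choose from the padded 40x40 image.

```python
def normalize(value, mean=0.5, std=0.5):
    """Per-channel formula used by transforms.Normalize: (x - mean) / std."""
    return (value - mean) / std


def crop_positions(image_size=32, crop_size=32, padding=4):
    """Number of distinct top-left corners for a random crop after padding."""
    padded = image_size + 2 * padding          # 32 + 8 = 40
    per_axis = padded - crop_size + 1          # 40 - 32 + 1 = 9
    return per_axis * per_axis


# ToTensor maps 0..255 to 0.0..1.0, so Normalize(0.5, 0.5) maps that to -1.0..1.0.
print(normalize(0.0), normalize(0.5), normalize(1.0))  # -1.0 0.0 1.0
print(crop_positions())  # 81
```

Combined with the random horizontal flip, each training image can appear in up to 2 x 81 = 162 geometric variants.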

2.3 Residual Block Definition

# %% Residual block definition
class ResidualBlock(nn.Module): 
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        
    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out) 
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        return out

Explanation: This is the fundamental building block of ResNet.

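  • Two Convolutional Layers: Each uses a 3x3 kernel followed by batch normalization; a ReLU follows the first, and the second ReLU is applied after the skip-connection addition.
  • Skip Connection: The input (identity) is added to the output of the second convolution.
  • Downsample: An optional 1x1 convolution plus batch norm applied to the identity when the stride or channel count changes, so its shape matches the block output.
  • Bias=False: Batch normalization subtracts the mean and adds its own learnable shift, so a bias in the preceding convolution would be redundant.

The residual block computes ReLU(F(x) + x), where F(x) is the conv-BN-ReLU-conv-BN branch and x is the (possibly downsampled) identity.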
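The downsample branch exists because F(x) and x must have the same shape before they can be added. The standard convolution output-size formula, out = floor((in + 2 * padding - kernel) / stride) + 1, shows how this works. In this plain-Python sketch (the helper name is hypothetical), the 3x3 stride-2 main path and the 1x1 stride-2 shortcut that _make_layer builds both map a 32x32 feature map to 16x16, so the addition lines up. The 1x1 convolution also changes the channel count to out_channels.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d with dilation=1."""
    return (size + 2 * padding - kernel) // stride + 1


# Main path of the first block in a stride-2 stage:
# conv1 (3x3, stride 2, padding 1), then conv2 (3x3, stride 1, padding 1).
main = conv_out(conv_out(32, kernel=3, stride=2, padding=1), kernel=3, stride=1, padding=1)
# Shortcut: the 1x1, stride-2 convolution in the downsample branch.
shortcut = conv_out(32, kernel=1, stride=2, padding=0)
print(main, shortcut)  # 16 16
```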

2.4 Custom ResNet Definition

# %% Custom ResNet definition
class ResNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)  # 3x3, stride-1 stem for 32x32 CIFAR-10 inputs
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.layer1 = self._make_layer(64, 64, 2)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)
        self.layer4 = self._make_layer(256, 512, 2, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        # The in_channels argument is unused; self.in_channels tracks the current channel count.
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            # 1x1 projection so the shortcut matches the block output's shape
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        layers = [ResidualBlock(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels  # later blocks and the next stage take out_channels as input
        for _ in range(1, blocks):
            layers.append(ResidualBlock(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

Explanation: This implements a complete ResNet architecture.

  • Initial Convolution: Processes the input images (3 channels) to 64 feature maps.
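  • Residual Layers: Four stages with increasing channels (64→128→256→512); stages 2 to 4 begin with a stride-2 block that halves the spatial size (32→16→8→4).
  • _make_layer: Helper method that creates a sequence of residual blocks, adding a 1x1 projection shortcut where the shape changes.
  • Global Average Pooling: Reduces spatial dimensions to 1x1 before the fully connected layer.
  • Adapted for CIFAR-10: The stem is a single 3x3 convolution with stride 1 and no max pooling, whereas the ImageNet ResNets use a 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling.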

The architecture follows the ResNet-18 pattern but is adapted for the smaller CIFAR-10 images (32x32 vs. 224x224 in ImageNet).
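The shape flow can be traced with the same output-size arithmetic. This plain-Python sketch assumes the configuration above: a stride-1 stem and four stages of two blocks with strides 1, 2, 2, 2. The helper names are hypothetical. It also counts the weighted layers that give ResNet-18 its name: the stem convolution, two 3x3 convolutions per block, and the final linear layer. The 1x1 shortcut convolutions are conventionally not counted.

```python
def trace_shapes(size=32, channels=(64, 128, 256, 512), strides=(1, 2, 2, 2)):
    """Return (channels, height, width) after the stem and after each stage."""
    shapes = [(64, size, size)]  # stem: 3x3 conv, stride 1, padding 1 keeps 32x32
    for c, s in zip(channels, strides):
        size = (size + 2 * 1 - 3) // s + 1  # first 3x3 conv of the stage sets the size
        shapes.append((c, size, size))
    return shapes


def count_weighted_layers(blocks_per_stage=(2, 2, 2, 2), convs_per_block=2):
    """Stem conv + 3x3 convs in all residual blocks + final fully connected layer."""
    return 1 + sum(blocks_per_stage) * convs_per_block + 1


for shape in trace_shapes():
    print(shape)
# (64, 32, 32)
# (64, 32, 32)
# (128, 16, 16)
# (256, 8, 8)
# (512, 4, 4)
print(count_weighted_layers())  # 18
```

AdaptiveAvgPool2d((1, 1)) then reduces the final 512x4x4 map to 512 values, which is why self.fc takes 512 input features.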

2.5 Model Selection (Custom or Transfer Learning)

# %% Model selection (custom or transfer learning)
use_custom_resnet = True

if use_custom_resnet:
    model = ResNet().to(device)
else:
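    # torchvision >= 0.13 API; older releases used models.resnet18(pretrained=True)
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)  # CIFAR-10 stem (randomly initialized)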
    model.maxpool = nn.Identity()  # Remove maxpool to retain resolution
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(num_ftrs, 256),
        nn.ReLU(),
        nn.Linear(256, 10)
    )
    model = model.to(device)

Explanation: This section provides two approaches:

  • Custom ResNet: Uses our implementation from scratch.
  • Transfer Learning: Uses a pre-trained ResNet-18 with modifications:
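    • First convolution replaced with a 3x3, stride-1 convolution suited to 32x32 CIFAR-10 images (this layer is newly initialized, so its pretrained weights are discarded)
    • Max pooling layer removed so small images are not downsampled too early
    • Final fully connected layer replaced with a custom two-layer classifier for 10 classes

Transfer learning is particularly useful when you have limited training data, as the pre-trained model has already learned useful feature detectors from ImageNet. Note that torchvision's ImageNet weights were trained on inputs normalized with mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225]; using those statistics in transform_train and transform_test usually suits the pre-trained model better than 0.5.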
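One consequence of these modifications is that part of the network no longer starts from pretrained weights. This plain-Python sketch counts the parameters in the two replaced pieces. The helper names are hypothetical, and the sketch assumes nn.Linear's default bias=True and the bias=False setting on the new conv1. Everything else in ResNet-18 keeps its ImageNet weights.

```python
def linear_params(in_features, out_features, bias=True):
    """Weights plus optional bias of an nn.Linear layer."""
    return in_features * out_features + (out_features if bias else 0)


def conv_params(in_channels, out_channels, kernel, bias=False):
    """Weights plus optional bias of an nn.Conv2d layer with a square kernel."""
    return in_channels * out_channels * kernel * kernel + (out_channels if bias else 0)


head = linear_params(512, 256) + linear_params(256, 10)  # the new classifier
stem = conv_params(3, 64, kernel=3)                       # the replacement conv1
print(head, stem, head + stem)  # 133898 1728 135626
```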

2.6 Loss Function and Optimizer

# %% Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

Explanation:

  • CrossEntropyLoss: Standard loss function for multi-class classification.
  • Adam Optimizer: Adaptive learning rate optimization algorithm that's often a good default choice.
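nn.CrossEntropyLoss expects raw, unnormalized scores (logits) and applies log-softmax internally, which is why ResNet.forward ends at self.fc without a softmax. This standard-library sketch reproduces the per-sample computation, -log(softmax(logits)[target]), using the log-sum-exp trick for numerical stability. The default reduction='mean' then averages over the batch. The function names are hypothetical.

```python
import math


def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed stably via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def mean_cross_entropy(batch_logits, targets):
    """Average over the batch, like reduction='mean' (the default)."""
    losses = [cross_entropy(z, t) for z, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


# Ten equal logits: the model is maximally unsure, so the loss is log(10).
print(round(cross_entropy([0.0] * 10, 3), 4))  # 2.3026
```

This also gives a rough sanity check: with small initial logits, a 10-class model often starts with a loss near log(10) ≈ 2.30.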

2.7 Training Loop

# %% Training loop
num_epochs = 1

for epoch in range(num_epochs):
    running_loss = 0.0
    model.train()
    for images, labels in tqdm(trainloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(trainloader):.4f}")

Explanation: Standard PyTorch training loop with:

  • Forward Pass: Compute model predictions
  • Loss Calculation: Compare predictions with true labels
  • Backward Pass: Compute gradients
  • Optimizer Step: Update model weights
  • Progress Tracking: tqdm displays a progress bar, and the loop prints the average batch loss for each epoch

In practice, you would use more epochs (typically 50-200 for CIFAR-10). The example uses 1 epoch for demonstration.
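The printed value is running_loss / len(trainloader), which is the mean of the per-batch mean losses. With drop_last left at its default of False, len(trainloader) equals ceil(50,000 / 64), and the final batch is smaller than the others. This plain-Python sketch (the helper name is hypothetical) works out those numbers and the number of optimizer steps in a longer run.

```python
import math


def loader_stats(num_samples=50_000, batch_size=64):
    """Batches per epoch (drop_last=False) and the size of the final batch."""
    num_batches = math.ceil(num_samples / batch_size)
    last_batch = num_samples - (num_batches - 1) * batch_size
    return num_batches, last_batch


batches, last = loader_stats()
print(batches, last)  # 782 16
print(batches * 50)   # optimizer steps for 50 epochs: 39100
```

Because the 16-image batch counts as a full entry in the average, the reported epoch loss weights those images slightly more than the rest. For a true per-sample average, you could accumulate loss.item() * images.size(0) and divide by len(trainset).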

2.8 Testing Loop

# %% Testing loop
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"Accuracy of the model on the test set: {accuracy:.2f}%")

Explanation: Evaluation phase with:

  • model.eval(): Switches batch normalization to its running statistics (and would disable dropout, if the model had any)
  • torch.no_grad(): Disables gradient computation for efficiency
  • Accuracy Calculation: Compares predicted classes with true labels
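torch.max(outputs, 1) returns a (values, indices) pair, and the loop keeps the indices, which are the arg-max class for each row. This standard-library sketch mirrors that step and the accuracy formula on a tiny hand-made batch. The logits are illustrative, not real model output, and the helper names are hypothetical.

```python
def argmax_rows(batch_logits):
    """Index of the largest logit in each row (ties go to the first index here)."""
    return [max(range(len(row)), key=row.__getitem__) for row in batch_logits]


def accuracy(batch_logits, labels):
    """Percentage of rows whose arg-max matches the label."""
    predicted = argmax_rows(batch_logits)
    correct = sum(int(p == y) for p, y in zip(predicted, labels))
    return 100 * correct / len(labels)


outputs = [
    [2.0, 0.1, -1.0],  # predicts class 0
    [0.3, 0.2, 1.5],   # predicts class 2
    [0.0, 0.9, 0.4],   # predicts class 1
    [1.1, 1.0, 0.9],   # predicts class 0
]
labels = [0, 2, 2, 0]
print(argmax_rows(outputs))       # [0, 2, 1, 0]
print(accuracy(outputs, labels))  # 75.0
```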

3. Key Concepts Explained

3.1 Skip Connections

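The fundamental innovation in ResNets is the skip connection that bypasses one or more layers. Because a block only has to learn the residual F(x), an identity mapping is as easy as driving F(x) toward zero, and the shortcut gives gradients a direct path to earlier layers. Together, these properties ease the degradation problem in very deep networks and help gradients propagate during training.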
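A toy scalar model gives some intuition for the gradient argument. For illustration only, assume each layer is linear with local derivative w, and ignore ReLU and batch normalization. Stacking n plain layers y = w * x gives a derivative of w**n. Stacking n residual units y = x + w * x gives (1 + w)**n, which cannot shrink below 1 for positive w. The real block applies a ReLU after the addition, so this is a simplification, and the function names are hypothetical.

```python
def plain_chain_grad(w, depth):
    """d(output)/d(input) for `depth` stacked layers y = w * x."""
    grad = 1.0
    for _ in range(depth):
        grad *= w          # chain rule: multiply the local derivatives
    return grad


def residual_chain_grad(w, depth):
    """d(output)/d(input) for `depth` stacked units y = x + w * x."""
    grad = 1.0
    for _ in range(depth):
        grad *= 1.0 + w    # the identity path contributes the "+ 1"
    return grad


for depth in (10, 50):
    print(depth, plain_chain_grad(0.01, depth), residual_chain_grad(0.01, depth))
```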

3.2 Batch Normalization

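Batch normalization normalizes each channel's activations over the current mini-batch to zero mean and unit variance, then applies a learnable scale and shift. This makes training deeper networks more stable and allows higher learning rates. At evaluation time, it uses running averages collected during training instead of batch statistics.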
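In training mode, nn.BatchNorm2d normalizes each channel using the mean and biased variance computed over the batch and spatial positions. It then applies a learnable scale gamma and shift beta, which are initialized to 1 and 0. This plain-Python sketch applies that formula to one channel's values. The default eps=1e-5 matches PyTorch's default, the running statistics used in eval mode are omitted, and the function name is hypothetical.

```python
import math


def batch_norm_channel(values, gamma=1.0, beta=0.0, eps=1e-5):
    """Training-mode batch norm for the values of a single channel."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n   # biased variance
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in values]


x = [2.0, 4.0, 6.0, 8.0]
y = batch_norm_channel(x)
mean_y = sum(y) / len(y)
var_y = sum((v - mean_y) ** 2 for v in y) / len(y)
print(mean_y, var_y)  # mean close to 0, variance just under 1 because of eps
```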

3.3 Transfer Learning vs. Training from Scratch

Transfer learning uses a model pre-trained on a large dataset (like ImageNet) and fine-tunes it for a new task. This is often more efficient than training from scratch, especially with limited data.

4. Potential Improvements

  • Learning Rate Scheduling: Add learning rate decay for better convergence
  • More Augmentations: Add color jitter, rotation, or other transformations
  • Regularization: Add dropout or weight decay to prevent overfitting
  • Different Architectures: Try deeper variants such as ResNet-34 (basic blocks) or ResNet-50 (bottleneck blocks)
  • Advanced Optimizers: Experiment with SGD with momentum or AdamW
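Cosine annealing is one example of learning rate scheduling. PyTorch's torch.optim.lr_scheduler.CosineAnnealingLR follows this closed form (without restarts): lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2. This plain-Python sketch evaluates that curve for the Adam learning rate of 0.001 used above. T_max = 50 epochs is an illustrative assumption, not a tuned value, and the function name is hypothetical.

```python
import math


def cosine_lr(epoch, base_lr=0.001, t_max=50, eta_min=0.0):
    """Closed-form cosine annealing from base_lr down to eta_min over t_max epochs."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


for epoch in (0, 25, 50):
    print(epoch, cosine_lr(epoch))
```

In the training script itself, the equivalent would be scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs), with scheduler.step() called once at the end of each epoch.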

5. Conclusion

This implementation demonstrates how to build and train ResNets in PyTorch, covering both custom implementation and transfer learning approaches. Residual Networks have become a cornerstone of modern computer vision due to their ability to train very deep networks effectively.

The key takeaways are the residual block design, the role of batch normalization, and the importance of skip connections for training deep networks.
