← Back to Course
Practical Work 4

Transfer Learning Classifier

Building a custom image classifier with PyTorch and pre-trained models

Duration 2.5 hours
Difficulty Intermediate
Session 4 - Custom Models

Objectives

By the end of this practical work, you will be able to:

  • Organize image data for training
  • Load and modify a pre-trained model for your task
  • Apply data augmentation to improve model robustness
  • Train a custom classifier using transfer learning
  • Evaluate model performance and save for deployment

Prerequisites

  • Python 3.8+ with PyTorch installed
  • A CUDA-capable GPU is recommended; training also works on CPU, just more slowly
  • Basic understanding of neural networks
  • Dataset: 100+ images per class (we'll use a sample dataset)

Install required packages:

pip install torch torchvision matplotlib tqdm pillow

Dataset: For this lab, we'll use a "defect detection" scenario with 3 classes: good, scratch, dent. You can download sample data or use your own images.

Instructions

Step 1: Organize Your Dataset

Create the following folder structure:

data/
├── train/
│   ├── good/
│   │   ├── img001.jpg
│   │   ├── img002.jpg
│   │   └── ...
│   ├── scratch/
│   │   └── ...
│   └── dent/
│       └── ...
└── val/
    ├── good/
    │   └── ...
    ├── scratch/
    │   └── ...
    └── dent/
        └── ...

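Split your data approximately 80% for training and 20% for validation, and aim for at least 50 training images per class. Use exactly the same class folder names under train/ and val/: ImageFolder derives each label index from the sorted folder names, so a missing or misspelled folder shifts the labels.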

Step 2: Create the Data Pipeline

Set up data loading with augmentation:

# data_pipeline.py
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def create_dataloaders(data_dir: str, batch_size: int = 32, num_workers: int = 4):
    """Create train and validation dataloaders with augmentation.

    Expects ``data_dir`` to contain ``train/`` and ``val/`` subfolders laid
    out for ``torchvision.datasets.ImageFolder`` (one subfolder per class).

    Args:
        data_dir: Root directory holding the ``train`` and ``val`` folders.
        batch_size: Number of images per batch for both loaders.
        num_workers: Worker processes used by each ``DataLoader``; use 0 to
            load data in the main process (useful when debugging).

    Returns:
        ``(train_loader, val_loader, classes)`` where ``classes`` is the list
        of class names in label-index order.

    Raises:
        ValueError: If ``train/`` and ``val/`` do not contain the same class
            folders, which would make their label indices disagree.
    """
    # ImageNet normalization values (the pre-trained backbone was trained
    # with inputs normalized this way)
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )

    # Training transforms with augmentation.
    # NOTE(review): RandomResizedCrop's default scale range can crop away
    # small defects; consider a larger minimum scale if labels depend on
    # localized details.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        normalize
    ])

    # Validation transforms (deterministic, no augmentation)
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])

    train_dataset = datasets.ImageFolder(
        f"{data_dir}/train",
        transform=train_transform
    )

    val_dataset = datasets.ImageFolder(
        f"{data_dir}/val",
        transform=val_transform
    )

    # Each ImageFolder assigns label indices from its own sorted folder
    # names, so a missing or extra class folder in val/ would silently shift
    # every label after it. Fail loudly instead.
    if val_dataset.classes != train_dataset.classes:
        raise ValueError(
            f"Class folders differ between train ({train_dataset.classes}) "
            f"and val ({val_dataset.classes})"
        )

    # Pinned memory only speeds up host-to-GPU copies; it brings no benefit
    # on CPU-only machines.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    return train_loader, val_loader, train_dataset.classes


if __name__ == "__main__":
    # Smoke test: load the dataset from ./data and report what was found.
    train_loader, val_loader, classes = create_dataloaders("data")
    print(
        f"Classes: {classes}",
        f"Training batches: {len(train_loader)}",
        f"Validation batches: {len(val_loader)}",
        sep="\n",
    )

Step 3: Build the Transfer Learning Model

Load a pre-trained ResNet-50 and replace its final classification layer with one sized for your classes:

# model.py
import torch
import torch.nn as nn
from torchvision import models

def create_model(num_classes: int, freeze_base: bool = True, pretrained: bool = True):
    """Create a transfer learning model based on ResNet-50.

    Args:
        num_classes: Number of outputs of the new classification head.
        freeze_base: If True, disable gradients for every pre-existing layer
            so that only the new head is trained.
        pretrained: If True (default), initialize the backbone with the
            ImageNet ``IMAGENET1K_V2`` weights. Pass False when the weights
            will be overwritten anyway (e.g. right before ``load_state_dict``
            on a saved checkpoint) to skip loading/downloading them.

    Returns:
        A ResNet-50 whose ``fc`` layer is replaced by
        ``Dropout(0.5) -> Linear(in_features, num_classes)``.
    """
    weights = models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
    model = models.resnet50(weights=weights)

    # Freeze *before* replacing ``fc``: the freshly created head modules get
    # requires_grad=True by default, so they remain trainable.
    # NOTE: requires_grad=False does not stop BatchNorm layers from updating
    # their running statistics while the model is in train() mode.
    if freeze_base:
        for param in model.parameters():
            param.requires_grad = False

    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_features, num_classes)
    )

    return model


def count_parameters(model):
    """Return ``(trainable, total)`` parameter counts for ``model``.

    A parameter is counted as trainable when its ``requires_grad`` flag is
    set; ``total`` counts every parameter element regardless of the flag.
    """
    trainable = 0
    total = 0
    for tensor in model.parameters():
        size = tensor.numel()
        total += size
        if tensor.requires_grad:
            trainable += size
    return trainable, total


if __name__ == "__main__":
    # Sanity check: build a 3-class model and report how much of it will
    # actually be trained with the backbone frozen.
    model = create_model(num_classes=3)
    trainable, total = count_parameters(model)
    print(
        f"Trainable parameters: {trainable:,}",
        f"Total parameters: {total:,}",
        f"Percentage trainable: {trainable / total * 100:.2f}%",
        sep="\n",
    )

Step 4: Implement the Training Loop

Create the training and validation functions:

# train.py
import torch
import torch.nn as nn
from tqdm import tqdm

def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimization pass over ``train_loader``.

    Puts the model in training mode and performs a forward pass, loss,
    backward pass and optimizer step for every batch.

    Returns:
        ``(mean_loss, accuracy)``, where the loss is weighted by batch size
        and both values are averaged over every sample seen.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch_x, batch_y in tqdm(train_loader, desc="Training"):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # criterion returns a per-batch mean, so re-weight by batch size
        loss_sum += batch_loss.item() * batch_x.size(0)
        n_seen += batch_y.size(0)
        n_correct += (logits.max(dim=1).indices == batch_y).sum().item()

    return loss_sum / n_seen, n_correct / n_seen


def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    The model is switched to eval mode and gradient tracking is disabled for
    the whole pass.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all validation samples, with
        the loss weighted by batch size.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_x, batch_y in tqdm(val_loader, desc="Validating"):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)

            loss_sum += batch_loss.item() * batch_x.size(0)
            n_seen += batch_y.size(0)
            n_correct += (logits.max(dim=1).indices == batch_y).sum().item()

    return loss_sum / n_seen, n_correct / n_seen

Step 5: Run the Training

Create the main training script:

# main.py
import torch
import torch.nn as nn
import torch.optim as optim
from data_pipeline import create_dataloaders
from model import create_model
from train import train_epoch, validate

def main(
    data_dir: str = "data",
    num_epochs: int = 10,
    batch_size: int = 32,
    learning_rate: float = 0.001,
    checkpoint_path: str = "best_model.pth",
):
    """Train the transfer-learning classifier and keep the best checkpoint.

    Args:
        data_dir: Dataset root containing ``train/`` and ``val/`` folders.
        num_epochs: Number of passes over the training set.
        batch_size: Images per batch.
        learning_rate: Initial Adam learning rate (divided by 10 every
            5 epochs by the StepLR scheduler).
        checkpoint_path: File the best-so-far checkpoint is written to.

    Returns:
        Dict with per-epoch ``train_loss``, ``train_acc``, ``val_loss`` and
        ``val_acc`` lists.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_loader, val_loader, classes = create_dataloaders(data_dir, batch_size)
    print(f"Classes: {classes}")

    # Frozen ImageNet backbone; only the new head has requires_grad=True.
    model = create_model(num_classes=len(classes), freeze_base=True)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    # Hand Adam only the parameters that still require gradients.
    optimizer = optim.Adam(
        (p for p in model.parameters() if p.requires_grad),
        lr=learning_rate
    )

    # Divide the learning rate by 10 every 5 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    # Start below any achievable accuracy so the first epoch always writes a
    # checkpoint. With 0.0 and a strict ">", a run whose validation accuracy
    # never rises above 0 would leave no checkpoint (or a stale one from an
    # earlier run) for predict.py to load.
    best_acc = -1.0
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 40)

        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device
        )
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        # Scheduler steps once per epoch, after the optimizer steps.
        scheduler.step()

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_acc": val_acc,
                "classes": classes
            }, checkpoint_path)
            print(f"Saved best model with accuracy: {val_acc:.4f}")

    if best_acc < 0:
        print("\nNo epochs were run; no checkpoint was saved.")
    else:
        print(f"\nTraining complete. Best validation accuracy: {best_acc:.4f}")
    return history


if __name__ == "__main__":
    # Run training with the default configuration. The returned per-epoch
    # metrics dict has the keys that plot_training_history() in visualize.py
    # reads.
    history = main()

Step 6: Visualize Training Results

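Plot the training curves. The example below uses dummy values; to plot your real run, pass the history dictionary returned by main() in main.py to plot_training_history():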

# visualize.py
import matplotlib.pyplot as plt

def plot_training_history(history: dict, save_path: str = "training_curves.png",
                          show: bool = True):
    """Plot training and validation loss/accuracy curves side by side.

    Args:
        history: Mapping with equal-length ``train_loss``, ``val_loss``,
            ``train_acc`` and ``val_acc`` lists, one entry per epoch.
        save_path: Output image path (format inferred from the extension),
            written at 150 dpi.
        show: If True, display the figure with ``plt.show()`` after saving.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    epochs = range(1, len(history["train_loss"]) + 1)

    # One panel per metric; each overlays the training and validation series.
    panels = (("loss", "Loss"), ("acc", "Accuracy"))
    for ax, (metric, label) in zip(axes, panels):
        ax.plot(epochs, history[f"train_{metric}"], "b-", label="Training")
        ax.plot(epochs, history[f"val_{metric}"], "r-", label="Validation")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(label)
        ax.set_title(f"Training and Validation {label}")
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    if show:
        plt.show()
    # Release the figure so repeated calls (or non-interactive backends,
    # where show() returns immediately) do not accumulate open figures.
    plt.close(fig)


if __name__ == "__main__":
    # Demo with made-up numbers for five epochs; substitute the history
    # returned by main() to plot a real run.
    history = dict(
        train_loss=[1.0, 0.6, 0.4, 0.3, 0.2],
        train_acc=[0.5, 0.7, 0.8, 0.85, 0.9],
        val_loss=[1.1, 0.7, 0.5, 0.4, 0.35],
        val_acc=[0.45, 0.65, 0.75, 0.8, 0.85],
    )
    plot_training_history(history)

Step 7: Run Inference on New Images

Create a script to classify new images:

# predict.py
import torch
from torchvision import transforms
from PIL import Image
from model import create_model

def load_model(checkpoint_path: str, device: str = "cpu"):
    """Rebuild the classifier from a checkpoint written by main.py.

    Args:
        checkpoint_path: Path to a checkpoint containing at least the
            ``"classes"`` and ``"model_state_dict"`` entries.
        device: Device the checkpoint tensors are mapped onto and the model
            is moved to (e.g. ``"cpu"`` or ``"cuda"``).

    Returns:
        ``(model, classes)`` with the model in eval mode.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    classes = checkpoint["classes"]

    # freeze_base only toggles requires_grad, which is irrelevant here.
    # NOTE: create_model loads the ImageNet weights, which load_state_dict
    # immediately overwrites.
    model = create_model(num_classes=len(classes), freeze_base=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    return model, classes


def predict(model, image_path: str, classes: list):
    """Classify a single image file.

    Applies the same resize / center-crop / ImageNet normalization as the
    validation pipeline, runs ``model`` without gradients, and converts the
    logits to probabilities with softmax.

    Args:
        model: Trained classifier; expected to already be in eval mode.
        image_path: Path to the image; it is converted to RGB before use.
        classes: Class names in label-index order.

    Returns:
        Dict with ``"class"`` (top-1 class name), ``"confidence"`` (its
        probability) and ``"all_probabilities"`` (name -> probability).
    """
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # convert() returns an independent image, so the file can be closed
    # deterministically right after decoding.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Send the input to wherever the model's weights live, so a model that
    # was moved to a GPU works too (falls back to CPU for parameterless models).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    input_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(input_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        confidence, predicted = probabilities.max(1)

    return {
        "class": classes[predicted.item()],
        "confidence": confidence.item(),
        "all_probabilities": {
            cls: prob.item()
            for cls, prob in zip(classes, probabilities[0])
        }
    }


if __name__ == "__main__":
    # Classify test_image.jpg with the best checkpoint and print a report.
    model, classes = load_model("best_model.pth")
    result = predict(model, "test_image.jpg", classes)

    report = [
        f"Predicted: {result['class']}",
        f"Confidence: {result['confidence']:.2%}",
        "\nAll probabilities:",
    ]
    report.extend(
        f"  {cls}: {prob:.2%}"
        for cls, prob in result["all_probabilities"].items()
    )
    print("\n".join(report))

Expected Output

After completing this practical work, you should have:

  • Organized dataset in the correct folder structure
  • Training output showing loss/accuracy per epoch
  • best_model.pth - Saved model checkpoint
  • training_curves.png - Visualization of training progress
  • Successful predictions on test images

Deliverables

  • Complete project folder with all Python files
  • Trained model file (best_model.pth)
  • Training curves visualization
  • Brief report including: final accuracy, training time, observations about model performance

Bonus Challenges

  • Challenge 1: Try unfreezing the base layers and fine-tuning with a lower learning rate. Compare results.
  • Challenge 2: Replace ResNet-50 with EfficientNet-B0 and compare performance/speed.
  • Challenge 3: Add a confusion matrix visualization to see which classes are confused.
  • Challenge 4: Implement early stopping to prevent overfitting.