Transfer Learning Classifier
Building a custom image classifier with PyTorch and pre-trained models
Objectives
By the end of this practical work, you will be able to:
- Organize image data for training
- Load and modify a pre-trained model for your task
- Apply data augmentation to improve model robustness
- Train a custom classifier using transfer learning
- Evaluate model performance and save for deployment
Prerequisites
- Python 3.8+ with PyTorch installed
- A CUDA-capable GPU is recommended, but a CPU will work (expect slower training)
- Basic understanding of neural networks
- Dataset: 100+ images per class (we'll use a sample dataset)
Install required packages:
pip install torch torchvision matplotlib tqdm pillow
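To confirm the installation and check whether a GPU is visible, you can run a quick optional check (the file name is just a suggestion):
# check_setup.py -- optional sanity check
import torch
import torchvision
print(f"PyTorch: {torch.__version__}")
print(f"torchvision: {torchvision.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")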
Dataset: For this lab, we'll use a "defect detection" scenario with 3 classes: good, scratch, dent. You can download sample data or use your own images.
Instructions
Step 1: Organize Your Dataset
Create the following folder structure:
data/
├── train/
│   ├── good/
│   │   ├── img001.jpg
│   │   ├── img002.jpg
│   │   └── ...
│   ├── scratch/
│   │   └── ...
│   └── dent/
│       └── ...
└── val/
    ├── good/
    │   └── ...
    ├── scratch/
    │   └── ...
    └── dent/
        └── ...
Split your data: approximately 80% for training, 20% for validation. Aim for at least 50 images per class for training.
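If your images start out in one folder per class (e.g. a raw/good/, raw/scratch/, raw/dent/ layout -- the "raw" folder name here is just an assumption), a small script along these lines can perform the split; adjust the paths, extension, and ratio to your setup:
# split_data.py -- one possible way to split a flat class-folder layout
import random
import shutil
from pathlib import Path

def split_dataset(src_dir: str, dst_dir: str, val_ratio: float = 0.2, seed: int = 42):
    """Copy images from src_dir/<class>/ into dst_dir/{train,val}/<class>/."""
    random.seed(seed)
    for class_dir in Path(src_dir).iterdir():
        if not class_dir.is_dir():
            continue
        images = sorted(class_dir.glob("*.jpg"))
        random.shuffle(images)
        n_val = int(len(images) * val_ratio)
        for split, subset in [("val", images[:n_val]), ("train", images[n_val:])]:
            out = Path(dst_dir) / split / class_dir.name
            out.mkdir(parents=True, exist_ok=True)
            for img in subset:
                shutil.copy2(img, out / img.name)

if __name__ == "__main__":
    split_dataset("raw", "data")  # "raw" is an assumed source folder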
Step 2: Create the Data Pipeline
Set up data loading with augmentation:
# data_pipeline.py
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
def create_dataloaders(data_dir: str, batch_size: int = 32):
"""Create train and validation dataloaders with augmentation."""
# ImageNet normalization values
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
# Training transforms with augmentation
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ColorJitter(brightness=0.2, contrast=0.2),
transforms.ToTensor(),
normalize
])
# Validation transforms (no augmentation)
val_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize
])
# Load datasets
train_dataset = datasets.ImageFolder(
f"{data_dir}/train",
transform=train_transform
)
val_dataset = datasets.ImageFolder(
f"{data_dir}/val",
transform=val_transform
)
# Create dataloaders
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
pin_memory=True
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=4,
pin_memory=True
)
return train_loader, val_loader, train_dataset.classes
if __name__ == "__main__":
train_loader, val_loader, classes = create_dataloaders("data")
print(f"Classes: {classes}")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
Step 3: Build the Transfer Learning Model
Load a pre-trained ResNet and modify it:
# model.py
import torch
import torch.nn as nn
from torchvision import models
def create_model(num_classes: int, freeze_base: bool = True):
"""Create a transfer learning model based on ResNet-50."""
# Load pre-trained ResNet-50
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
# Freeze base layers if specified
if freeze_base:
for param in model.parameters():
param.requires_grad = False
# Replace the final classification layer
num_features = model.fc.in_features
model.fc = nn.Sequential(
nn.Dropout(0.5),
nn.Linear(num_features, num_classes)
)
return model
def count_parameters(model):
"""Count trainable and total parameters."""
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
return trainable, total
if __name__ == "__main__":
model = create_model(num_classes=3)
trainable, total = count_parameters(model)
print(f"Trainable parameters: {trainable:,}")
print(f"Total parameters: {total:,}")
print(f"Percentage trainable: {trainable/total*100:.2f}%")
Step 4: Implement the Training Loop
Create the training and validation functions:
# train.py
import torch
import torch.nn as nn
from tqdm import tqdm
def train_epoch(model, train_loader, criterion, optimizer, device):
"""Train for one epoch."""
model.train()
running_loss = 0.0
correct = 0
total = 0
for inputs, labels in tqdm(train_loader, desc="Training"):
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
epoch_loss = running_loss / total
epoch_acc = correct / total
return epoch_loss, epoch_acc
def validate(model, val_loader, criterion, device):
"""Validate the model."""
model.eval()
running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in tqdm(val_loader, desc="Validating"):
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
loss = criterion(outputs, labels)
running_loss += loss.item() * inputs.size(0)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
epoch_loss = running_loss / total
epoch_acc = correct / total
return epoch_loss, epoch_acc
Step 5: Run the Training
Create the main training script:
# main.py
import torch
import torch.nn as nn
import torch.optim as optim
from data_pipeline import create_dataloaders
from model import create_model
from train import train_epoch, validate
def main():
# Configuration
DATA_DIR = "data"
NUM_EPOCHS = 10
BATCH_SIZE = 32
LEARNING_RATE = 0.001
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Data
train_loader, val_loader, classes = create_dataloaders(DATA_DIR, BATCH_SIZE)
print(f"Classes: {classes}")
# Model
model = create_model(num_classes=len(classes), freeze_base=True)
model = model.to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()),
lr=LEARNING_RATE
)
# Learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
# Training loop
best_acc = 0.0
history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
for epoch in range(NUM_EPOCHS):
print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
print("-" * 40)
train_loss, train_acc = train_epoch(
model, train_loader, criterion, optimizer, device
)
val_loss, val_acc = validate(model, val_loader, criterion, device)
scheduler.step()
# Save history
history["train_loss"].append(train_loss)
history["train_acc"].append(train_acc)
history["val_loss"].append(val_loss)
history["val_acc"].append(val_acc)
print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
# Save best model
if val_acc > best_acc:
best_acc = val_acc
torch.save({
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"val_acc": val_acc,
"classes": classes
}, "best_model.pth")
print(f"Saved best model with accuracy: {val_acc:.4f}")
print(f"\nTraining complete. Best validation accuracy: {best_acc:.4f}")
return history
if __name__ == "__main__":
history = main()
Step 6: Visualize Training Results
Plot the training curves:
# visualize.py
import matplotlib.pyplot as plt
def plot_training_history(history: dict):
"""Plot training and validation curves."""
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
epochs = range(1, len(history["train_loss"]) + 1)
# Loss plot
ax1.plot(epochs, history["train_loss"], "b-", label="Training")
ax1.plot(epochs, history["val_loss"], "r-", label="Validation")
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.set_title("Training and Validation Loss")
ax1.legend()
ax1.grid(True, alpha=0.3)
# Accuracy plot
ax2.plot(epochs, history["train_acc"], "b-", label="Training")
ax2.plot(epochs, history["val_acc"], "r-", label="Validation")
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy")
ax2.set_title("Training and Validation Accuracy")
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("training_curves.png", dpi=150)
plt.show()
if __name__ == "__main__":
# Example usage with dummy data
history = {
"train_loss": [1.0, 0.6, 0.4, 0.3, 0.2],
"train_acc": [0.5, 0.7, 0.8, 0.85, 0.9],
"val_loss": [1.1, 0.7, 0.5, 0.4, 0.35],
"val_acc": [0.45, 0.65, 0.75, 0.8, 0.85]
}
plot_training_history(history)
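To plot the curves from a real run rather than the dummy data above, one possible wiring is to extend the entry point in main.py to hand off the returned history:
# In main.py, after training:
from visualize import plot_training_history

if __name__ == "__main__":
    history = main()
    plot_training_history(history)  # saves training_curves.png and shows the figure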
Step 7: Run Inference on New Images
Create a script to classify new images:
# predict.py
import torch
from torchvision import transforms
from PIL import Image
from model import create_model
def load_model(checkpoint_path: str):
"""Load a trained model from checkpoint."""
checkpoint = torch.load(checkpoint_path, map_location="cpu")
classes = checkpoint["classes"]
model = create_model(num_classes=len(classes), freeze_base=False)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
return model, classes
def predict(model, image_path: str, classes: list):
"""Predict class for a single image."""
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image = Image.open(image_path).convert("RGB")
input_tensor = transform(image).unsqueeze(0)
with torch.no_grad():
outputs = model(input_tensor)
probabilities = torch.softmax(outputs, dim=1)
confidence, predicted = probabilities.max(1)
return {
"class": classes[predicted.item()],
"confidence": confidence.item(),
"all_probabilities": {
cls: prob.item()
for cls, prob in zip(classes, probabilities[0])
}
}
if __name__ == "__main__":
model, classes = load_model("best_model.pth")
result = predict(model, "test_image.jpg", classes)
print(f"Predicted: {result['class']}")
print(f"Confidence: {result['confidence']:.2%}")
print("\nAll probabilities:")
for cls, prob in result["all_probabilities"].items():
print(f" {cls}: {prob:.2%}")
Expected Output
After completing this practical work, you should have:
- Organized dataset in the correct folder structure
- Training output showing loss/accuracy per epoch
- best_model.pth - Saved model checkpoint
- training_curves.png - Visualization of training progress
- Successful predictions on test images
Deliverables
- Complete project folder with all Python files
- Trained model file (best_model.pth)
- Training curves visualization
- Brief report including: final accuracy, training time, observations about model performance
Bonus Challenges
- Challenge 1: Try unfreezing the base layers and fine-tuning with a lower learning rate (a starting-point sketch follows this list). Compare results.
- Challenge 2: Replace ResNet-50 with EfficientNet-B0 and compare performance/speed.
- Challenge 3: Add a confusion matrix visualization to see which classes are confused.
- Challenge 4: Implement early stopping to prevent overfitting.
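For Challenge 1, one possible starting point is to load the saved checkpoint, unfreeze everything, and continue training with much lower learning rates. A hedged sketch; the hyperparameters are illustrative, not tuned:
# finetune.py -- one possible fine-tuning setup for Challenge 1
import torch
import torch.nn as nn
import torch.optim as optim
from model import create_model

checkpoint = torch.load("best_model.pth", map_location="cpu")
model = create_model(num_classes=len(checkpoint["classes"]), freeze_base=False)
model.load_state_dict(checkpoint["model_state_dict"])

# Discriminative learning rates: tiny for the pre-trained backbone,
# larger for the freshly initialized head.
backbone_params = [p for name, p in model.named_parameters() if not name.startswith("fc")]
optimizer = optim.Adam([
    {"params": backbone_params, "lr": 1e-5},
    {"params": model.fc.parameters(), "lr": 1e-4},
])
criterion = nn.CrossEntropyLoss()
# Reuse train_epoch/validate from train.py for a few more epochs,
# then compare validation accuracy against the frozen-base run.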