Practical Work 4

Training a CNN from Scratch

Build, train, and evaluate a Convolutional Neural Network on the CIFAR-10 dataset

Duration: 2 hours
Difficulty: Intermediate
Session 4 - CNN Architectures

Objectives

By the end of this practical work, you will be able to:

  • Implement a simple CNN architecture using Keras
  • Train a neural network on the CIFAR-10 dataset
  • Visualize training curves (loss and accuracy)
  • Evaluate model performance on a test set

Prerequisites

  • Basic understanding of neural networks and CNNs
  • Familiarity with Python and NumPy
  • Completed previous practical works on image classification

Install required packages:

pip install tensorflow matplotlib scikit-learn seaborn

Instructions

Step 1: Load and Preprocess CIFAR-10

Start by loading the CIFAR-10 dataset and normalizing pixel values to the range [0, 1]:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

# Load the CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Normalize pixel values to the [0, 1] range
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Define class names for visualization
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

print(f"Training samples: {x_train.shape[0]}")
print(f"Test samples: {x_test.shape[0]}")
print(f"Image shape: {x_train.shape[1:]}")

Note: CIFAR-10 contains 60,000 32x32 color images in 10 classes, with 50,000 training images and 10,000 test images.

Step 2: Explore the Dataset

Visualize sample images from each class to understand the data:

# Create a figure for displaying samples
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
axes = axes.flatten()

# Display one sample from each class
for i, class_name in enumerate(class_names):
    idx = np.where(y_train.flatten() == i)[0][0]  # first image of each class
    axes[i].imshow(x_train[idx])
    axes[i].set_title(class_name)
    axes[i].axis('off')

plt.suptitle('CIFAR-10 Sample Images', fontsize=14)
plt.tight_layout()
plt.savefig('cifar10_samples.png', dpi=150)
plt.show()

# Show the class distribution
unique, counts = np.unique(y_train, return_counts=True)
print("Samples per class:")
for name, count in zip(class_names, counts):
    print(f"  {name}: {count}")

Step 3: Build the CNN Architecture

Create a CNN with three convolutional blocks followed by dense layers:

def build_cnn():
    model = keras.Sequential([
        # First convolutional block
        layers.Conv2D(32, (3, 3), padding='same', input_shape=(32, 32, 3)),
        layers.BatchNormalization(),
        layers.Activation('relu'),
        layers.MaxPooling2D((2, 2)),

        # Second convolutional block
        layers.Conv2D(64, (3, 3), padding='same'),
        layers.BatchNormalization(),
        layers.Activation('relu'),
        layers.MaxPooling2D((2, 2)),

        # Third convolutional block
        layers.Conv2D(128, (3, 3), padding='same'),
        layers.BatchNormalization(),
        layers.Activation('relu'),

        # Flatten and dense layers
        layers.Flatten(),
        layers.Dense(256),
        layers.BatchNormalization(),
        layers.Activation('relu'),
        layers.Dropout(0.5),  # dropout for regularization

        # Output layer with 10 classes
        layers.Dense(10, activation='softmax')
    ])
    return model

model = build_cnn()
model.summary()

Architecture Overview:

  • Conv2D(32) -> MaxPooling -> Feature maps: 16x16x32
  • Conv2D(64) -> MaxPooling -> Feature maps: 8x8x64
  • Conv2D(128) -> Feature maps: 8x8x128
  • Flatten -> Dense(256) -> Output(10)
CNN Architecture Visualization: Input (32x32x3) -> Conv+Pool (16x16x32) -> Conv+Pool (8x8x64) -> Conv (8x8x128) -> Flatten (8192) -> Dense (256) -> Output (10). Feature maps shrink spatially but increase in depth through the convolutional layers.

Step 4: Add BatchNormalization and Dropout

The model above already includes BatchNormalization and Dropout. Here is why they are important:

  • BatchNormalization: Normalizes layer inputs, speeds up training, and allows higher learning rates
  • Dropout(0.5): Randomly drops 50% of neurons during training to prevent overfitting

Warning: Without regularization techniques like Dropout and BatchNormalization, your model may overfit quickly on CIFAR-10.
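To see what Dropout(0.5) actually does, here is a minimal NumPy sketch of inverted dropout, the variant Keras uses: during training each unit is zeroed with probability 0.5 and the survivors are scaled by 1/(1 - rate) so the expected activation is unchanged; at inference the layer passes activations through untouched. The array values below are illustrative, not from the model.

```python
import numpy as np

rng = np.random.default_rng(0)
activations = np.ones((1, 8))  # toy layer output
rate = 0.5                     # drop 50% of units, as in the model

# Training: zero out units at random, scale survivors by 1 / (1 - rate)
mask = rng.random(activations.shape) >= rate
train_out = activations * mask / (1.0 - rate)

# Inference: dropout is a no-op; activations pass through unchanged
infer_out = activations

# Every training-time unit is either dropped (0.0) or scaled to 2.0
print(np.all((train_out == 0.0) | (train_out == 2.0)))
print(np.array_equal(infer_out, activations))
```

Because the survivors are rescaled during training, no correction is needed at inference time.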

Step 5: Compile the Model

Configure the model with optimizer, loss function, and metrics:

# Compile with the Adam optimizer
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),  # default Adam learning rate
    loss='sparse_categorical_crossentropy',  # for integer (non-one-hot) labels
    metrics=['accuracy']
)

print("Model compiled successfully!")
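The choice of sparse_categorical_crossentropy matters because CIFAR-10 labels are integers (0-9), not one-hot vectors. The NumPy sketch below, using made-up predicted probabilities, shows that the sparse and standard categorical losses compute the same value; the sparse variant simply indexes the probabilities with the integer label instead of multiplying by a one-hot vector.

```python
import numpy as np

# Hypothetical predicted probabilities for 3 samples over 4 classes
probs = np.array([[0.7, 0.1, 0.1, 0.1],
                  [0.1, 0.8, 0.05, 0.05],
                  [0.2, 0.2, 0.5, 0.1]])
labels = np.array([0, 1, 2])   # integer labels (sparse format)
one_hot = np.eye(4)[labels]    # the same labels, one-hot encoded

# sparse_categorical_crossentropy: index directly with the integer labels
sparse_loss = -np.mean(np.log(probs[np.arange(3), labels]))

# categorical_crossentropy: multiply by the one-hot vectors and sum
categorical_loss = -np.mean(np.sum(one_hot * np.log(probs), axis=1))

print(np.isclose(sparse_loss, categorical_loss))
```

If you one-hot encoded the labels yourself (e.g. with keras.utils.to_categorical), you would use 'categorical_crossentropy' instead.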

Step 6: Set Up Callbacks

Configure callbacks for model checkpointing and early stopping:

import os

# Create a directory for saving models
os.makedirs('models', exist_ok=True)

callbacks = [
    # Save the best model based on validation accuracy
    keras.callbacks.ModelCheckpoint(
        filepath='models/best_cifar10_cnn.keras',
        monitor='val_accuracy',
        save_best_only=True,
        verbose=1
    ),
    # Stop training if validation loss stops improving
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,  # wait 5 epochs without improvement before stopping
        restore_best_weights=True,
        verbose=1
    )
]

print("Callbacks configured!")

Step 7: Train the Model

Train for 30 epochs with a validation split:

# Train with a validation split
history = model.fit(
    x_train, y_train,
    epochs=30,  # maximum number of epochs
    batch_size=64,
    validation_split=0.2,  # hold out 20% of training data for validation
    callbacks=callbacks,
    verbose=1
)

print("\nTraining completed!")
print(f"Best validation accuracy: {max(history.history['val_accuracy']):.4f}")

Tip: Training may take 10-20 minutes depending on your hardware. Using a GPU will significantly speed up training.
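You can check whether TensorFlow actually sees a GPU before starting a long run by listing the physical devices:

```python
import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print(f"Training will use GPU(s): {[g.name for g in gpus]}")
else:
    print("No GPU found; training will fall back to the CPU")
```

If no GPU shows up on a machine that has one, verify your CUDA/cuDNN installation matches your TensorFlow version.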

Step 8: Plot Training Curves

Visualize the loss and accuracy curves for training vs validation:

def plot_training_history(history):
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Plot loss curves
    axes[0].plot(history.history['loss'], label='Training Loss', linewidth=2)
    axes[0].plot(history.history['val_loss'], label='Validation Loss', linewidth=2)
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training vs Validation Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Plot accuracy curves
    axes[1].plot(history.history['accuracy'], label='Training Accuracy', linewidth=2)
    axes[1].plot(history.history['val_accuracy'], label='Validation Accuracy', linewidth=2)
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('Training vs Validation Accuracy')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('training_curves.png', dpi=150)  # save the figure
    plt.show()

plot_training_history(history)
Expected Output: Training Curves

Both loss curves fall from roughly 2.0 toward 0 over the 30 epochs, while accuracy climbs to around 90% on training and around 78% on validation. The gap between training and validation curves indicates some overfitting. Your curves may vary.

Step 9: Evaluate on Test Set

Assess the final model performance on the held-out test set:

# Evaluate on test data
test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=1)

print(f"\n{'='*50}")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")
print(f"{'='*50}")

Step 10: Generate Confusion Matrix

Create a confusion matrix to analyze per-class performance:

from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Get predictions
y_pred = model.predict(x_test)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true = y_test.flatten()

# Compute the confusion matrix
cm = confusion_matrix(y_true, y_pred_classes)

# Plot the confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix - CIFAR-10 CNN')
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=150)  # save the confusion matrix
plt.show()

# Print the classification report
print("\nClassification Report:")
print(classification_report(y_true, y_pred_classes, target_names=class_names))
Expected Output: Confusion Matrix (Simplified 5x5)

         plane   car   cat   dog  truck
plane      820    25    18    12     20
car         15   850     8     5     45
cat         22    10   680    95     15
dog         18     8    88   710     12
truck       28    52    10     8    780

Simplified 5-class view (rows = true label, columns = predicted). The cat/dog and car/truck pairs show the highest confusion. Your values will differ.

Step 11: Show Misclassified Examples

Visualize examples where the model made incorrect predictions:

# Find misclassified indices
misclassified_idx = np.where(y_pred_classes != y_true)[0]
print(f"Total misclassified: {len(misclassified_idx)} out of {len(y_test)}")

# Display some misclassified examples
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.flatten()

for i, idx in enumerate(misclassified_idx[:10]):
    axes[i].imshow(x_test[idx])
    true_label = class_names[y_true[idx]]
    pred_label = class_names[y_pred_classes[idx]]
    confidence = y_pred[idx][y_pred_classes[idx]] * 100
    axes[i].set_title(f'True: {true_label}\nPred: {pred_label}\n({confidence:.1f}%)',
                      fontsize=9, color='red')
    axes[i].axis('off')

plt.suptitle('Misclassified Examples', fontsize=14)
plt.tight_layout()
plt.savefig('misclassified_examples.png', dpi=150)  # save the figure
plt.show()

Analysis Tip: Look for patterns in misclassifications. Are certain classes frequently confused? This can guide architecture improvements.
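One way to quantify those patterns is to rank the off-diagonal cells of the confusion matrix. The sketch below uses a small made-up matrix (the class names and counts are illustrative); with the real cm and class_names from Step 10, the same few lines of NumPy apply directly.

```python
import numpy as np

# Toy confusion matrix (rows = true class, columns = predicted); counts are illustrative
names = ['cat', 'dog', 'car', 'truck']
cm = np.array([[680,  95,  10,  15],
               [ 88, 710,   8,  12],
               [ 22,  10, 850,  45],
               [ 10,   8,  52, 780]])

# Zero the diagonal, then locate the largest off-diagonal cell
off_diag = cm.copy()
np.fill_diagonal(off_diag, 0)
i, j = np.unravel_index(off_diag.argmax(), off_diag.shape)
print(f"Most confused: true '{names[i]}' predicted as '{names[j]}' ({off_diag[i, j]} times)")
```

Sorting all off-diagonal cells the same way gives you a ranked list of confusions to target with augmentation or architecture changes.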

Expected Output: Misclassified Examples

Typical errors pair visually similar classes: cats predicted as dogs (and vice versa), trucks as cars, and birds as planes. Red titles in the figure indicate incorrect predictions.

Expected Output

After completing this practical work, you should achieve approximately:

  • Test Accuracy: ~75% (can vary between 72-78%)
  • Training curves showing convergence with validation tracking
  • Confusion matrix revealing per-class performance

Target: Your model should reach around 75% accuracy on the test set (typically 72-78%). If your accuracy is substantially lower, check your architecture and hyperparameters.

Metric Expected Range
Training Accuracy 85-95%
Validation Accuracy 75-82%
Test Accuracy 72-78%

Deliverables

  • Jupyter Notebook: Complete notebook with all code cells executed and outputs visible
  • Training Curves Plot: training_curves.png showing loss and accuracy over epochs
  • Confusion Matrix: confusion_matrix.png with per-class performance analysis
  • Saved Model: models/best_cifar10_cnn.keras containing the best weights
  • practical_work_4/
    • cnn_training.ipynb
    • training_curves.png
    • confusion_matrix.png
    • misclassified_examples.png
    • cifar10_samples.png
    • models/
      • best_cifar10_cnn.keras

Bonus Challenges

  • Try Different Architectures: Experiment with more/fewer convolutional layers, different filter sizes (5x5), or add more dense layers
  • Add Data Augmentation: Use ImageDataGenerator or tf.keras.layers.RandomFlip, RandomRotation, RandomZoom to augment training data
  • Achieve 80%+ Accuracy: Combine architecture improvements with data augmentation to push accuracy beyond 80%
  • Learning Rate Scheduling: Implement learning rate decay using ReduceLROnPlateau callback
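For the learning-rate-scheduling bonus, ReduceLROnPlateau is a built-in Keras callback. A minimal sketch follows; the factor, patience, and min_lr values are illustrative choices, not prescribed by this handout.

```python
from tensorflow import keras

# Halve the learning rate whenever validation loss stops improving
lr_scheduler = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,    # multiply the learning rate by 0.5 on each trigger
    patience=3,    # wait 3 epochs without improvement first
    min_lr=1e-6,   # lower bound on the learning rate
    verbose=1
)

# Add it to the callbacks list from Step 6 before calling model.fit:
# callbacks.append(lr_scheduler)
```

Because EarlyStopping from Step 6 uses patience=5, a scheduler patience of 3 lets the learning rate drop at least once before training can stop.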

Data Augmentation Example

# Create the data augmentation layers
data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),  # random horizontal flip
    layers.RandomRotation(0.1),  # rotate by up to ±10% of a full turn (~36°)
    layers.RandomZoom(0.1),  # zoom in or out by up to 10%
])

# Apply augmentation inside the model
augmented_model = keras.Sequential([
    data_augmentation,
    # ... rest of your CNN layers
])

Resources