Training a CNN from Scratch
Build, train, and evaluate a Convolutional Neural Network on the CIFAR-10 dataset
Objectives
By the end of this practical work, you will be able to:
- Implement a simple CNN architecture using Keras
- Train a neural network on the CIFAR-10 dataset
- Visualize training curves (loss and accuracy)
- Evaluate model performance on a test set
Prerequisites
- Basic understanding of neural networks and CNNs
- Familiarity with Python and NumPy
- Completed previous practical works on image classification
Install required packages:
pip install tensorflow matplotlib scikit-learn seaborn
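You can quickly check that the installation worked and which TensorFlow version you have (any recent 2.x release should be fine for this practical):
import tensorflow as tf
print(tf.__version__) # (#1:Any recent 2.x release should work)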
Instructions
Step 1: Load and Preprocess CIFAR-10
Start by loading the CIFAR-10 dataset and normalizing pixel values to the range [0, 1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
# (#1:Load CIFAR-10 dataset)
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
# (#2:Normalize pixel values to 0-1 range)
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
# (#3:Define class names for visualization)
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck']
print(f"Training samples: {x_train.shape[0]}")
print(f"Test samples: {x_test.shape[0]}")
print(f"Image shape: {x_train.shape[1:]}")
Note: CIFAR-10 contains 60,000 32x32 color images in 10 classes, with 50,000 training images and 10,000 test images.
Step 2: Explore the Dataset
Visualize sample images from each class to understand the data:
# (#1:Create figure for displaying samples)
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
axes = axes.flatten()
# (#2:Display one sample from each class)
for i, class_name in enumerate(class_names):
idx = np.where(y_train.flatten() == i)[0][0] # (#3:Find first image of each class)
axes[i].imshow(x_train[idx])
axes[i].set_title(class_name)
axes[i].axis('off')
plt.suptitle('CIFAR-10 Sample Images', fontsize=14)
plt.tight_layout()
plt.savefig('cifar10_samples.png', dpi=150)
plt.show()
# (#4:Show class distribution)
unique, counts = np.unique(y_train, return_counts=True)
print("Samples per class:")
for name, count in zip(class_names, counts):
print(f" {name}: {count}")
Step 3: Build the CNN Architecture
Create a CNN with three convolutional blocks followed by dense layers:
def build_cnn():
model = keras.Sequential([
# (#1:First convolutional block)
layers.Conv2D(32, (3, 3), padding='same', input_shape=(32, 32, 3)),
layers.BatchNormalization(),
layers.Activation('relu'),
layers.MaxPooling2D((2, 2)),
# (#2:Second convolutional block)
layers.Conv2D(64, (3, 3), padding='same'),
layers.BatchNormalization(),
layers.Activation('relu'),
layers.MaxPooling2D((2, 2)),
# (#3:Third convolutional block)
layers.Conv2D(128, (3, 3), padding='same'),
layers.BatchNormalization(),
layers.Activation('relu'),
# (#4:Flatten and dense layers)
layers.Flatten(),
layers.Dense(256),
layers.BatchNormalization(),
layers.Activation('relu'),
layers.Dropout(0.5), # (#5:Dropout for regularization)
# (#6:Output layer with 10 classes)
layers.Dense(10, activation='softmax')
])
return model
model = build_cnn()
model.summary()
Architecture Overview:
- Conv2D(32) -> MaxPooling -> Feature maps: 16x16x32
- Conv2D(64) -> MaxPooling -> Feature maps: 8x8x64
- Conv2D(128) -> Feature maps: 8x8x128
- Flatten -> Dense(256) -> Output(10)
Step 4: Add BatchNormalization and Dropout
The model above already includes BatchNormalization and Dropout. Here is why they are important:
- BatchNormalization: Normalizes layer inputs, speeds up training, and allows higher learning rates
- Dropout(0.5): Randomly drops 50% of neurons during training to prevent overfitting
Warning: Without regularization techniques like Dropout and BatchNormalization, your model may overfit quickly on CIFAR-10.
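Both layers also behave differently during training and inference: Dropout only zeroes activations while training, and BatchNormalization switches from batch statistics to its moving averages at inference time. A tiny sketch to see the Dropout behavior, using an illustrative all-ones input:
# (#1:Dropout is only active when training=True)
demo_dropout = layers.Dropout(0.5)
demo_input = np.ones((1, 8), dtype='float32')
print(demo_dropout(demo_input, training=True)) # (#2:About half the values are zeroed, the rest scaled by 2)
print(demo_dropout(demo_input, training=False)) # (#3:Inference mode: input passes through unchanged)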
Step 5: Compile the Model
Configure the model with optimizer, loss function, and metrics:
# (#1:Compile with Adam optimizer)
model.compile(
optimizer=keras.optimizers.Adam(learning_rate=0.001), # (#2:Default learning rate)
loss='sparse_categorical_crossentropy', # (#3:For integer labels)
metrics=['accuracy']
)
print("Model compiled successfully!")
Step 6: Set Up Callbacks
Configure callbacks for model checkpointing and early stopping:
import os
# (#1:Create directory for saving models)
os.makedirs('models', exist_ok=True)
callbacks = [
# (#2:Save best model based on validation accuracy)
keras.callbacks.ModelCheckpoint(
filepath='models/best_cifar10_cnn.keras',
monitor='val_accuracy',
save_best_only=True,
verbose=1
),
# (#3:Stop training if no improvement)
keras.callbacks.EarlyStopping(
monitor='val_loss',
patience=5, # (#4:Wait 5 epochs before stopping)
restore_best_weights=True,
verbose=1
)
]
print("Callbacks configured!")
Step 7: Train the Model
Train for up to 30 epochs with a 20% validation split (early stopping may end the run sooner):
# (#1:Train with validation split)
history = model.fit(
x_train, y_train,
epochs=30, # (#2:Maximum number of epochs)
batch_size=64,
validation_split=0.2, # (#3:Use 20% of training data for validation)
callbacks=callbacks,
verbose=1
)
print(f"\nTraining completed!")
print(f"Best validation accuracy: {max(history.history['val_accuracy']):.4f}")
Tip: Training may take 10-20 minutes depending on your hardware. Using a GPU will significantly speed up training.
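To check whether TensorFlow actually sees a GPU, you can list the visible devices (an empty list means the fit() call runs on the CPU):
# (#1:List GPUs visible to TensorFlow)
gpus = tf.config.list_physical_devices('GPU')
print(f"GPUs available: {len(gpus)}")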
Step 8: Plot Training Curves
Visualize the loss and accuracy curves for training vs validation:
def plot_training_history(history):
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# (#1:Plot loss curves)
axes[0].plot(history.history['loss'], label='Training Loss', linewidth=2)
axes[0].plot(history.history['val_loss'], label='Validation Loss', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training vs Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# (#2:Plot accuracy curves)
axes[1].plot(history.history['accuracy'], label='Training Accuracy', linewidth=2)
axes[1].plot(history.history['val_accuracy'], label='Validation Accuracy', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Training vs Validation Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('training_curves.png', dpi=150) # (#3:Save figure)
plt.show()
plot_training_history(history)
Step 9: Evaluate on Test Set
Assess the final model performance on the held-out test set:
# (#1:Evaluate on test data)
test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=1)
print(f"\n{'='*50}")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")
print(f"{'='*50}")
Step 10: Generate Confusion Matrix
Create a confusion matrix to analyze per-class performance:
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
# (#1:Get predictions)
y_pred = model.predict(x_test)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true = y_test.flatten()
# (#2:Compute confusion matrix)
cm = confusion_matrix(y_true, y_pred_classes)
# (#3:Plot confusion matrix)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=class_names,
yticklabels=class_names)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix - CIFAR-10 CNN')
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=150) # (#4:Save confusion matrix)
plt.show()
# (#5:Print classification report)
print("\nClassification Report:")
print(classification_report(y_true, y_pred_classes, target_names=class_names))
Step 11: Show Misclassified Examples
Visualize examples where the model made incorrect predictions:
# (#1:Find misclassified indices)
misclassified_idx = np.where(y_pred_classes != y_true)[0]
print(f"Total misclassified: {len(misclassified_idx)} out of {len(y_test)}")
# (#2:Display some misclassified examples)
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.flatten()
for i, idx in enumerate(misclassified_idx[:10]):
axes[i].imshow(x_test[idx])
true_label = class_names[y_true[idx]]
pred_label = class_names[y_pred_classes[idx]]
confidence = y_pred[idx][y_pred_classes[idx]] * 100
axes[i].set_title(f'True: {true_label}\nPred: {pred_label}\n({confidence:.1f}%)',
fontsize=9, color='red')
axes[i].axis('off')
plt.suptitle('Misclassified Examples', fontsize=14)
plt.tight_layout()
plt.savefig('misclassified_examples.png', dpi=150) # (#3:Save figure)
plt.show()
Analysis Tip: Look for patterns in misclassifications. Are certain classes frequently confused? This can guide architecture improvements.
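One way to make these patterns explicit, using the cm and class_names variables from Step 10, is to list the largest off-diagonal cells of the confusion matrix (a small sketch):
# (#1:Zero the diagonal so only errors remain)
errors = cm.copy()
np.fill_diagonal(errors, 0)
# (#2:Report the five most frequent confusions)
top_cells = np.argsort(errors, axis=None)[::-1][:5]
for cell in top_cells:
    true_idx, pred_idx = np.unravel_index(cell, errors.shape)
    print(f"{class_names[true_idx]} predicted as {class_names[pred_idx]}: {errors[true_idx, pred_idx]} times")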
Expected Output
After completing this practical work, you should achieve approximately:
- Test Accuracy: ~75% (can vary between 72-78%)
- Training curves showing convergence with validation tracking
- Confusion matrix revealing per-class performance
Target: Your model should reach roughly 75% accuracy on the test set. If you are well below 72%, check your architecture and hyperparameters.
| Metric | Expected Range |
|---|---|
| Training Accuracy | 85-95% |
| Validation Accuracy | 75-82% |
| Test Accuracy | 72-78% |
Deliverables
- Jupyter Notebook: Complete notebook with all code cells executed and outputs visible
- Training Curves Plot: training_curves.png showing loss and accuracy over epochs
- Confusion Matrix: confusion_matrix.png with per-class performance analysis
- Saved Model: models/best_cifar10_cnn.keras containing the best weights
Your submission folder should be organized as follows:
- practical_work_4/
  - cnn_training.ipynb
  - training_curves.png
  - confusion_matrix.png
  - misclassified_examples.png
  - cifar10_samples.png
  - models/
    - best_cifar10_cnn.keras
Bonus Challenges
- Try Different Architectures: Experiment with more/fewer convolutional layers, different filter sizes (5x5), or add more dense layers
- Add Data Augmentation: Use ImageDataGenerator or tf.keras.layers.RandomFlip, RandomRotation, and RandomZoom to augment the training data (see the example below)
- Achieve 80%+ Accuracy: Combine architecture improvements with data augmentation to push accuracy beyond 80%
- Learning Rate Scheduling: Implement learning rate decay using the ReduceLROnPlateau callback (see the sketch after the data augmentation example)
Data Augmentation Example
# (#1:Create data augmentation layers)
data_augmentation = keras.Sequential([
layers.RandomFlip("horizontal"), # (#2:Horizontal flip)
layers.RandomRotation(0.1), # (#3:Rotate by up to 10% of a full turn, about 36 degrees)
layers.RandomZoom(0.1), # (#4:Zoom up to 10%)
])
# (#5:Apply augmentation in model)
augmented_model = keras.Sequential([
data_augmentation,
# ... rest of your CNN layers
])
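Learning Rate Scheduling Example
For the learning rate scheduling challenge, a ReduceLROnPlateau callback can be added to the callbacks list from Step 6 (a minimal sketch; the factor and patience values below are reasonable starting points, not tuned results):
# (#1:Lower the learning rate when validation loss stops improving)
lr_scheduler = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5, # (#2:Multiply the learning rate by 0.5)
    patience=3, # (#3:After 3 epochs without improvement)
    min_lr=1e-5,
    verbose=1
)
# (#4:Append to the callbacks from Step 6 before calling model.fit)
callbacks.append(lr_scheduler)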