Transfer Learning in Practice
Master transfer learning strategies by fine-tuning EfficientNet on custom datasets and tracking experiments with Weights & Biases
Objectives
By the end of this practical work, you will be able to:
- Fine-tune EfficientNet on a custom image classification dataset
- Compare different transfer learning strategies (feature extraction, full fine-tuning, progressive unfreezing)
- Track experiments and visualize metrics with Weights & Biases
Prerequisites
- Kaggle account (for downloading datasets)
- Weights & Biases account (free tier available at wandb.ai)
- Small custom dataset (e.g., subset of Oxford 102 Flowers or Oxford-IIIT Pets)
- Understanding of CNN architectures and transfer learning concepts
Install required packages:
pip install tensorflow wandb tensorflow-datasets scikit-learn matplotlib seaborn pandas
Instructions
Step 1: Set Up Weights & Biases
Create a Weights & Biases account and configure your API key:
import wandb  # W&B library
from wandb.integration.keras import WandbMetricsLogger  # Keras callback for logging

# Login to W&B (will prompt for API key on first run)
wandb.login()  # Authenticate with W&B

# Initialize a new run
wandb.init(
    project="transfer-learning-comparison",  # Project name in W&B
    config={
        "architecture": "EfficientNetB0",
        "dataset": "oxford_flowers102",
        "epochs": 15,
        "batch_size": 32
    }
)
Note: Your API key can be found at wandb.ai/authorize. Store it securely and never commit it to version control.
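One way to keep the key out of your notebook is to read it from the WANDB_API_KEY environment variable, which wandb itself also checks. A minimal sketch, assuming you have exported the variable in your shell beforehand:

import os

# Assumes WANDB_API_KEY was exported in your shell (e.g. in your shell profile),
# so no key ever appears in the notebook itself
wandb.login(key=os.environ.get("WANDB_API_KEY"))

If the variable is unset, wandb.login() simply falls back to the interactive prompt.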
Step 2: Load Custom Image Classification Dataset
Load the Oxford 102 Flowers dataset (or similar) using TensorFlow Datasets:
import tensorflow as tf
import tensorflow_datasets as tfds
# Load Oxford 102 Flowers dataset
(train_ds, val_ds, test_ds), info = tfds.load(  # Load dataset with splits
    'oxford_flowers102',
    split=['train', 'validation', 'test'],
    with_info=True,
    as_supervised=True  # Returns (image, label) tuples
)

num_classes = info.features['label'].num_classes  # Get number of classes
print(f"Number of classes: {num_classes}")
print(f"Training samples: {info.splits['train'].num_examples}")
# Define image size for EfficientNet
IMG_SIZE = 224
BATCH_SIZE = 32
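Before building the input pipeline, it is worth eyeballing one raw sample to confirm shapes and label encoding. A quick sanity check, nothing more:

# Inspect a single raw example from the training split
for image, label in train_ds.take(1):
    print("Raw image shape:", image.shape)  # varies per image before resizing
    print("Label:", int(label), "->", info.features['label'].int2str(int(label)))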
Step 3: Create Data Generators with Augmentation
Set up preprocessing and data augmentation pipelines:
def preprocess(image, label):
    """Resize and normalize images."""
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])  # Resize to model input size
    image = tf.cast(image, tf.float32)
    image = tf.keras.applications.efficientnet.preprocess_input(image)  # EfficientNet preprocessing
    return image, label

def augment(image, label):
    """Apply data augmentation."""
    image = tf.image.random_flip_left_right(image)   # Random horizontal flip
    image = tf.image.random_brightness(image, 0.2)   # Random brightness adjustment
    image = tf.image.random_contrast(image, 0.8, 1.2)  # Random contrast
    return image, label

# Prepare training data with augmentation
train_data = train_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
train_data = train_data.map(augment, num_parallel_calls=tf.data.AUTOTUNE)  # Apply augmentation
train_data = train_data.shuffle(1000).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# Validation and test data (no augmentation)
val_data = val_ds.map(preprocess).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_data = test_ds.map(preprocess).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
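A one-batch check confirms the pipeline emits what the model expects:

# Pull a single batch and verify shapes
images, labels = next(iter(train_data))
print(images.shape)  # expected: (32, 224, 224, 3)
print(labels.shape)  # expected: (32,)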
Step 4: Load EfficientNetB0 with ImageNet Weights
Load the pre-trained model without the classification head:
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras import layers, Model

def create_model(num_classes, trainable_base=False):
    """Create EfficientNetB0-based classifier."""
    # Load pre-trained EfficientNetB0
    base_model = EfficientNetB0(
        weights='imagenet',   # Load ImageNet weights
        include_top=False,    # Exclude original classifier head
        input_shape=(IMG_SIZE, IMG_SIZE, 3)
    )
    # Freeze or unfreeze base model
    base_model.trainable = trainable_base  # Control which layers train

    # Add custom classification head
    inputs = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
    # training=False keeps BatchNorm layers in inference mode (using their
    # pretrained moving statistics), even when the base is later unfrozen
    x = base_model(inputs, training=False)
    x = layers.GlobalAveragePooling2D()(x)       # Global pooling
    x = layers.Dropout(0.3)(x)                   # Regularization
    x = layers.Dense(256, activation='relu')(x)  # Dense layer
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)  # Output layer

    model = Model(inputs, outputs)
    return model, base_model

# Create model with frozen base
model, base_model = create_model(num_classes, trainable_base=False)
model.summary()
[Architecture diagram: 224x224x3 input → EfficientNetB0 base (ImageNet weights) → GlobalAveragePooling2D → Dense(256) + Dropout → softmax output over 102 classes]
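A quick check that freezing worked as intended (the same counts also appear at the bottom of model.summary()):

# Count trainable vs. frozen parameters in the frozen-base model
trainable = sum(tf.reduce_prod(v.shape).numpy() for v in model.trainable_variables)
frozen = sum(tf.reduce_prod(v.shape).numpy() for v in model.non_trainable_variables)
print(f"Trainable params: {trainable:,}")  # only the new classifier head
print(f"Frozen params:    {frozen:,}")     # the EfficientNetB0 base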
Step 5: Strategy 1 - Feature Extraction
Train only the classifier head while keeping the base model frozen:
# Initialize W&B run for feature extraction
wandb.init(
    project="transfer-learning-comparison",
    name="strategy1-feature-extraction",  # Name this experiment
    config={"strategy": "feature_extraction", "epochs": 5}
)

# Compile model with higher learning rate (only training new layers)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),  # Higher LR is safe: only new layers train
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train for 5 epochs
history_fe = model.fit(
    train_data,
    validation_data=val_data,
    epochs=5,
    callbacks=[WandbMetricsLogger()]  # Log metrics to W&B
)

# Evaluate
fe_results = model.evaluate(test_data)
print(f"Feature Extraction - Test Accuracy: {fe_results[1]:.4f}")
wandb.log({"feature_extraction_test_accuracy": fe_results[1]})  # Log final result
wandb.finish()
Tip: Feature extraction is fast because we only train a small number of parameters in the classifier head.
Step 6: Log Metrics to W&B
Monitor training progress in the Weights & Biases dashboard:
# Custom logging during training
class WandbCallback(tf.keras.callbacks.Callback):
    """Custom callback for detailed W&B logging."""

    def on_epoch_end(self, epoch, logs=None):
        # Log all metrics
        wandb.log({
            "epoch": epoch,
            "train_loss": logs.get('loss'),          # Training loss
            "train_accuracy": logs.get('accuracy'),  # Training accuracy
            "val_loss": logs.get('val_loss'),        # Validation loss
            "val_accuracy": logs.get('val_accuracy'),  # Validation accuracy
            "learning_rate": float(self.model.optimizer.learning_rate)  # Current LR
        })

    def on_train_end(self, logs=None):
        # Log total trainable parameter count
        wandb.log({"trainable_params": sum(
            tf.reduce_prod(var.shape).numpy()
            for var in self.model.trainable_variables
        )})
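To use this callback, pass an instance to model.fit alongside (or instead of) WandbMetricsLogger. Note that the class name shadows wandb's own wandb.keras.WandbCallback, so rename one of them if you ever need both. A usage sketch, assuming an active wandb run:

history = model.fit(
    train_data,
    validation_data=val_data,
    epochs=5,
    callbacks=[WandbCallback()]  # the custom callback defined above
)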
Step 7: Strategy 2 - Full Fine-Tuning
Unfreeze all layers and train with a lower learning rate:
# Create fresh model for fair comparison
model2, base_model2 = create_model(num_classes, trainable_base=True)  # Unfreeze all layers

# Initialize W&B run
wandb.init(
    project="transfer-learning-comparison",
    name="strategy2-full-finetuning",
    config={"strategy": "full_finetuning", "epochs": 10}
)

# Compile with lower learning rate
model2.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),  # Lower LR to avoid destroying pretrained weights
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Learning rate scheduler
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,  # Halve the LR when val_loss plateaus
    patience=2,
    min_lr=1e-7
)

# Train for 10 epochs
history_ft = model2.fit(
    train_data,
    validation_data=val_data,
    epochs=10,
    callbacks=[WandbMetricsLogger(), lr_scheduler]
)

# Evaluate
ft_results = model2.evaluate(test_data)
print(f"Full Fine-tuning - Test Accuracy: {ft_results[1]:.4f}")
wandb.log({"full_finetuning_test_accuracy": ft_results[1]})
wandb.finish()
Warning: Using a high learning rate during fine-tuning can destroy the pretrained weights (catastrophic forgetting). Always use a lower learning rate (1e-5 or lower).
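A related stabilization trick worth experimenting with (a variation on the experiment above, not part of it): keep the BatchNormalization layers frozen even during full fine-tuning, so their learned scale/offset parameters are not disturbed by small batches of new data. A sketch, assuming it is applied before calling model2.fit:

# Variation (assumption: run before training, not after):
# freeze only the BatchNormalization layers inside the unfrozen base
for layer in base_model2.layers:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = False

# Recompile so the changed trainable flags take effect
model2.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)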
Step 8: Compare Validation Accuracies
Create a comparison visualization:
import matplotlib.pyplot as plt

def plot_comparison(history1, history2, title):
    """Plot training comparison between strategies."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle(title)

    # Accuracy comparison
    axes[0].plot(history1.history['val_accuracy'], label='Feature Extraction')  # Strategy 1 accuracy
    axes[0].plot(history2.history['val_accuracy'], label='Full Fine-tuning')    # Strategy 2 accuracy
    axes[0].set_title('Validation Accuracy Comparison')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Accuracy')
    axes[0].legend()
    axes[0].grid(True)

    # Loss comparison
    axes[1].plot(history1.history['val_loss'], label='Feature Extraction')
    axes[1].plot(history2.history['val_loss'], label='Full Fine-tuning')
    axes[1].set_title('Validation Loss Comparison')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Loss')
    axes[1].legend()
    axes[1].grid(True)

    plt.tight_layout()
    # Log to W&B
    wandb.log({"comparison_plot": wandb.Image(fig)})  # Log plot to W&B
    plt.show()
# Start a dedicated run for the comparison (the earlier runs were closed with wandb.finish())
wandb.init(project="transfer-learning-comparison", name="strategy-comparison")
plot_comparison(history_fe, history_ft, "Feature Extraction vs. Full Fine-tuning")

# Create comparison table
results_table = wandb.Table(
    columns=["Strategy", "Test Accuracy", "Training Time"],  # Define columns
    data=[
        ["Feature Extraction", fe_results[1], "~5 min"],
        ["Full Fine-tuning", ft_results[1], "~15 min"]
    ]
)
wandb.log({"results_comparison": results_table})
wandb.finish()
Note: Actual curves may vary. Feature extraction plateaus early; full fine-tuning achieves higher accuracy.
Step 9: Strategy 3 - Progressive Unfreezing
Gradually unfreeze layers in stages for optimal transfer:
# Create fresh model
model3, base_model3 = create_model(num_classes, trainable_base=False)

wandb.init(
    project="transfer-learning-comparison",
    name="strategy3-progressive-unfreezing",
    config={"strategy": "progressive_unfreezing", "epochs": 15}
)

# Stage 1: Train only classifier (3 epochs)
model3.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
print("Stage 1: Training classifier head...")
model3.fit(train_data, validation_data=val_data, epochs=3,
           callbacks=[WandbMetricsLogger()])

# Stage 2: Unfreeze top 50 layers (5 epochs)
base_model3.trainable = True
for layer in base_model3.layers[:-50]:  # Freeze all but the top 50 layers
    layer.trainable = False
model3.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),  # Lower LR
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
print("Stage 2: Fine-tuning top 50 layers...")
model3.fit(train_data, validation_data=val_data, epochs=5,
           callbacks=[WandbMetricsLogger()])

# Stage 3: Unfreeze all layers (7 epochs)
for layer in base_model3.layers:  # Unfreeze all layers
    layer.trainable = True
model3.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),  # Even lower LR
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
print("Stage 3: Fine-tuning all layers...")
history_prog = model3.fit(train_data, validation_data=val_data, epochs=7,
                          callbacks=[WandbMetricsLogger()])

# Evaluate
prog_results = model3.evaluate(test_data)
print(f"Progressive Unfreezing - Test Accuracy: {prog_results[1]:.4f}")
wandb.log({"progressive_unfreezing_test_accuracy": prog_results[1]})
wandb.finish()
Best Practice: Progressive unfreezing often achieves the best results by allowing the model to adapt gradually while preserving pretrained features.
Note: Gradual unfreezing preserves low-level features while adapting high-level representations.
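The three-stage unfreeze-and-recompile pattern above can be folded into a small helper so each stage is a single call. A minimal sketch (the function name and defaults are ours, not part of the original code):

def unfreeze_and_recompile(model, base_model, num_layers=None, lr=1e-5):
    """Unfreeze the top `num_layers` layers of base_model (all layers if
    None), then recompile the model with a fresh, lower learning rate."""
    base_model.trainable = True
    if num_layers is not None:
        for layer in base_model.layers[:-num_layers]:
            layer.trainable = False
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

# Stages 2 and 3 above then become:
# unfreeze_and_recompile(model3, base_model3, num_layers=50, lr=5e-5)
# unfreeze_and_recompile(model3, base_model3, lr=1e-5)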
Step 10: Analyze W&B Dashboard
Navigate to your W&B project dashboard to analyze experiments:
# Create a comprehensive comparison report
import pandas as pd

all_results = {
    "Strategy": ["Feature Extraction", "Full Fine-tuning", "Progressive Unfreezing"],
    "Test Accuracy": [fe_results[1], ft_results[1], prog_results[1]],
    "Trainable Params": [
        sum(tf.reduce_prod(v.shape).numpy() for v in model.trainable_variables),
        sum(tf.reduce_prod(v.shape).numpy() for v in model2.trainable_variables),
        sum(tf.reduce_prod(v.shape).numpy() for v in model3.trainable_variables)
    ]
}
results_df = pd.DataFrame(all_results)

print("\n" + "="*60)
print("TRANSFER LEARNING STRATEGY COMPARISON")
print("="*60)
print(results_df.to_string(index=False))  # Display comparison table

# Log final comparison to W&B
wandb.init(project="transfer-learning-comparison", name="final-comparison")
wandb.log({"final_results": wandb.Table(dataframe=results_df)})  # Log to W&B

# Print W&B dashboard URL
print(f"\nView full comparison at: https://wandb.ai/{wandb.run.entity}/{wandb.run.project}")  # Dashboard link
wandb.finish()
Info: In the W&B dashboard, you can compare runs side-by-side, create custom charts, and share reports with your team.
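Beyond the browser dashboard, runs can also be pulled programmatically with the public W&B API, which is handy for building custom reports. A sketch, where "your-entity" is a placeholder for your W&B username or team name:

import wandb

api = wandb.Api()
runs = api.runs("your-entity/transfer-learning-comparison")  # replace with your entity
for run in runs:
    print(run.name, run.url)  # list every experiment in the project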
Step 11: Select Best Model and Evaluate
Choose the best performing model and perform detailed evaluation:
import numpy as np

# Determine best model
accuracies = [fe_results[1], ft_results[1], prog_results[1]]
best_idx = np.argmax(accuracies)
best_model = [model, model2, model3][best_idx]
strategy_names = ["Feature Extraction", "Full Fine-tuning", "Progressive Unfreezing"]

print(f"\nBest Strategy: {strategy_names[best_idx]}")  # Identify best strategy
print(f"Test Accuracy: {accuracies[best_idx]:.4f}")

# Get predictions on test set
y_true = []
y_pred = []
for images, labels in test_data:
    predictions = best_model.predict(images, verbose=0)  # Generate predictions
    y_true.extend(labels.numpy())
    y_pred.extend(np.argmax(predictions, axis=1))

y_true = np.array(y_true)
y_pred = np.array(y_pred)

# Calculate per-class accuracy
unique_classes = np.unique(y_true)
class_accuracies = []
for c in unique_classes:
    mask = y_true == c
    class_acc = np.mean(y_pred[mask] == c)  # Per-class accuracy
    class_accuracies.append(class_acc)

print(f"\nMean per-class accuracy: {np.mean(class_accuracies):.4f}")
print(f"Std per-class accuracy: {np.std(class_accuracies):.4f}")
Step 12: Generate Classification Report
Create a detailed classification report with precision, recall, and F1-score:
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

# Generate classification report
print("\n" + "="*60)
print("CLASSIFICATION REPORT")
print("="*60)
report = classification_report(
    y_true,
    y_pred,
    output_dict=True  # Get report as dictionary
)
print(classification_report(y_true, y_pred))

# Log metrics to W&B
wandb.init(project="transfer-learning-comparison", name="final-evaluation")
wandb.log({
    "macro_precision": report['macro avg']['precision'],  # Macro-averaged precision
    "macro_recall": report['macro avg']['recall'],
    "macro_f1": report['macro avg']['f1-score'],
    "weighted_f1": report['weighted avg']['f1-score']  # Weighted F1
})

# Create confusion matrix visualization
fig = plt.figure(figsize=(12, 10))
cm = confusion_matrix(y_true, y_pred)

# For a large number of classes, show only a subset
if len(unique_classes) > 20:
    print("(Showing subset of confusion matrix due to large number of classes)")
    cm_subset = cm[:20, :20]  # Subset for visualization
    sns.heatmap(cm_subset, annot=True, fmt='d', cmap='Blues')
else:
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
wandb.log({"confusion_matrix": wandb.Image(fig)})  # Log the figure object to W&B
plt.show()

# Save best model
best_model.save('best_transfer_learning_model.h5')  # Save model in HDF5 format
print("\nModel saved to 'best_transfer_learning_model.h5'")
wandb.finish()
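As an optional sanity check, reload the saved file and confirm it reproduces the test accuracy:

# Reload the saved model and re-evaluate on the test set
reloaded = tf.keras.models.load_model('best_transfer_learning_model.h5')
reloaded_results = reloaded.evaluate(test_data, verbose=0)
print(f"Reloaded model test accuracy: {reloaded_results[1]:.4f}")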
Expected Output
After completing this practical work, you should achieve:
- Feature Extraction: ~70-75% accuracy (quick baseline)
- Full Fine-tuning: ~80-85% accuracy
- Progressive Unfreezing: ~85%+ accuracy (target)
Target: Achieve 85%+ validation accuracy with the progressive unfreezing strategy on the Oxford 102 Flowers dataset.
Your W&B dashboard should show:
- Training curves for all three strategies
- Side-by-side comparison of validation metrics
- Final test accuracies and confusion matrices
Note: Actual accuracy may vary based on dataset, hyperparameters, and training conditions.
Deliverables
- Jupyter Notebook: Complete notebook with all code, outputs, and analysis
- W&B Dashboard Link: Share the link to your experiment dashboard showing all three strategies
- Best Model: Saved model file (best_transfer_learning_model.h5)
- Comparison Report: Brief summary (in notebook) discussing:
- Which strategy performed best and why
- Trade-offs between training time and accuracy
- Observations from the W&B dashboard
Bonus Challenges
- Challenge 1: Use the timm library with PyTorch to implement the same experiments: import timm; model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=102)
- Challenge 2: Compare EfficientNet-B0 vs ResNet50 - which performs better on your dataset?
- Challenge 3: Implement learning rate warmup and cosine annealing schedule
- Challenge 4: Try MixUp or CutMix augmentation to boost accuracy further
- Challenge 5: Export your best model to TensorFlow Lite for mobile deployment
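For Challenge 5, a minimal starting point for the TensorFlow Lite export (the quantization setting is one option among several, and the output filename is our choice):

# Convert the best Keras model to TensorFlow Lite
converter = tf.lite.TFLiteConverter.from_keras_model(best_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # optional dynamic-range quantization
tflite_model = converter.convert()

# Write the flatbuffer to disk for deployment
with open('best_model.tflite', 'wb') as f:
    f.write(tflite_model)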