PyTorch Essentials
CHAPTER 12 Intermediate

Transfer Learning with Pretrained Models

Updated: May 16, 2026

1. Introduction

Companies like Google, Microsoft, and OpenAI spend millions of dollars training massive neural networks on large GPU clusters, sometimes for weeks at a time. These networks learn a huge vocabulary of general-purpose features from their training data. As a beginner on a laptop, you cannot compete with that compute power. But you don't have to! Transfer Learning is the process of downloading one of these pre-trained "brains" and adapting it to your specific problem, often by retraining only the final layer. It is one of the most powerful and widely used techniques in modern AI.

2. Learning Objectives

By the end of this chapter, you will be able to:
  • Explain the concept of Transfer Learning.
  • Understand the ImageNet dataset.
  • Import pre-trained models like ResNet from torchvision.models.
  • Freeze base layers using requires_grad = False to prevent destroying pre-trained weights.
  • Replace the final Linear layer to classify a custom dataset.

3. The Concept of Transfer Learning

Imagine a master chef who spent 10 years learning how to expertly chop vegetables, balance spices, and manage a kitchen. If you want them to bake a specific type of pie, you don't need to teach them how to hold a knife from scratch. You just give them the pie recipe. Similarly, a CNN trained on millions of images has already learned useful visual features: its early layers detect simple patterns such as edges and textures, and its deeper layers combine them into parts like fur, eyes, and metal. We simply chop off the final "prediction" layer of the CNN and attach our own layer (e.g., "Is this a Hotdog or Not Hotdog?"). The network already knows what food looks like; it just needs a few minutes of training to learn what a hotdog looks like!

4. ImageNet and Famous Architectures

Most classic pre-trained models in Computer Vision were trained on ImageNet. The full ImageNet dataset contains over 14 million labeled images, but the standard pre-training subset (ImageNet-1k, from the ILSVRC 2012 challenge) has about 1.28 million training images split into 1,000 classes. PyTorch's torchvision library includes dozens of these famous architectures built-in:
  • ResNet (e.g., ResNet18, ResNet50): Introduces "residual connections" (shortcuts that add a block's input to its output), which keep gradients flowing and make very deep networks trainable. The industry workhorse.
  • VGG16: Older, simple architecture, great for learning.
  • EfficientNet: Scales network depth, width, and input resolution together to reach high accuracy at a low computational cost.

5. Mini Project: Custom Image Classifier

Let's use Transfer Learning to build a strong image classifier with only a few lines of code. We will download a pre-trained ResNet18 and modify it to classify just 2 things (Cat vs Dog) instead of 1000.
```python
import torch
import torch.nn as nn
import torchvision.models as models

# 1. Download the pre-trained base model (ResNet18).
# weights=models.ResNet18_Weights.DEFAULT tells torchvision to download the ImageNet-trained weights.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# 2. CRITICAL: Freeze the base model.
# With requires_grad=False, no gradients are computed for these weights, so
# backpropagation cannot overwrite what was learned during weeks of pre-training.
for param in model.parameters():
    param.requires_grad = False

# 3. Inspect the final layer.
# In ResNet, the final classification layer is named 'fc' (Fully Connected).
# In ResNet18 it maps 512 input features to 1000 outputs (one per ImageNet class).
num_ftrs = model.fc.in_features
print(f"Original final layer expects {num_ftrs} inputs.")

# 4. Replace the final layer!
# A newly created nn.Linear has requires_grad=True by default, so after this
# line it is the only trainable part of the model.
# We change the output to 2 (Cat or Dog).
model.fc = nn.Linear(num_ftrs, 2)

# The model is now ready!
# When we run the training loop, ONLY the new 'fc' layer will be updated.
print("Model ready for Transfer Learning.")
```

6. Fine-Tuning (Advanced)

Once you train your custom final layer for a few epochs and get decent accuracy, you can often squeeze out a few more percentage points of accuracy using Fine-Tuning.
  1. Unfreeze the entire model by setting requires_grad = True for all parameters.
  2. Create a new optimizer with a *very small* learning rate (e.g., 1e-5), since the old one only knows about the head's parameters.
  3. Train for a few more epochs.
This allows the pre-trained weights to make micro-adjustments specifically for your unique dataset without destroying their foundational knowledge.

7. Common Mistakes

  • Forgetting to Freeze the Base Model: If you download ResNet and immediately run your training loop without setting requires_grad = False, the large, essentially random gradients produced by your untrained final layer flow back into the base model and can badly damage the pre-trained weights, especially with a normal-sized learning rate.
  • Wrong Preprocessing: Models trained on ImageNet expect input pixels to be scaled to [0, 1] and then normalized per RGB channel with specific values (Mean: [0.485, 0.456, 0.406], Std: [0.229, 0.224, 0.225]). If you forget to apply these exact torchvision.transforms to your custom images before feeding them to ResNet, the inputs no longer match what the model saw during training and its predictions degrade badly.

8. Best Practices

  • Always start with Transfer Learning: For almost any real-world Computer Vision task, you should rarely build a CNN from scratch unless you are researching new architectures. Transfer Learning will save you weeks of time and large cloud (e.g., AWS) computing bills, and it requires significantly less labeled data.

9. Exercises

  1. In PyTorch, what exact line of code do you use to "freeze" a parameter so the Optimizer ignores it during training?
  2. If you are modifying a pre-trained VGG16 model, and its final classification layer is named classifier[6], how would you overwrite it to output 5 classes instead of 1000?

10. MCQ Quiz with Answers

Question 1

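In Transfer Learning, what does "Freezing" a layer mean?

Answer: Setting requires_grad = False on the layer's parameters, so no gradients are computed for them and the optimizer leaves their pre-trained values unchanged.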

Question 2

When downloading models.resnet18(weights=...), what dataset were the default weights trained on?

Answer: ImageNet (the 1,000-class ImageNet-1k subset).

11. Interview Questions

  • Q: Explain the two-step process of Transfer Learning (Feature Extraction followed by Fine-Tuning) and why different learning rates are required for each step.
  • Q: Why does Transfer Learning allow you to train highly accurate models even if you only have a very small custom dataset (e.g., 500 images)?

12. FAQs

Q: Does Transfer Learning work for text (NLP) too? A: Absolutely! In fact, modern NLP is built almost entirely on Transfer Learning. Models like BERT, LLaMA, and GPT are massive pre-trained language models that developers fine-tune for specific textual tasks.

13. Summary

Transfer Learning is the ultimate cheat code in Machine Learning. By downloading proven architectures from torchvision, freezing their feature-extracting brains, and attaching our own nn.Linear prediction heads, we can build accurate image classifiers on a standard laptop, often with only minutes of training.

14. Next Chapter Recommendation

We have conquered Computer Vision. Now, we must tackle human language. How does a neural network, which only understands numbers, learn to read English sentences? In Chapter 13: Natural Language Processing Basics with PyTorch, we will learn how to turn words into math.
