Advanced Convolutional Neural Networks#
Mahmood Amintoosi, Spring 2024
Computer Science Dept, Ferdowsi University of Mashhad
I should mention that the original material was from Tomas Beuzen’s course.
Lecture Learning Objectives#
Load image data using torchvision.datasets.ImageFolder() to train a network in PyTorch
Explain what “data augmentation” is and why we might want to do it
Be able to save and re-load a PyTorch model
Tune the hyperparameters of a PyTorch model using Ax
Describe what transfer learning is and the different flavours of it: “out-of-the-box”, “feature extractor”, “fine tuning”
Imports#
# Auto-setup when running on Google Colab
import os
if 'google.colab' in str(get_ipython()) and not os.path.exists('/content/neural-networks'):
    !git clone -q https://github.com/fum-cs/neural-networks.git /content/neural-networks
    !pip --quiet install -r /content/neural-networks/requirements_colab.txt
    %cd neural-networks/notebooks
import warnings
warnings.filterwarnings('ignore')
import json
import numpy as np
import pandas as pd
from collections import OrderedDict
import torch
import torchvision
from torch import nn, optim
from torchvision import transforms, models, datasets
from torchsummary import summary
from PIL import Image
import memory_profiler # conda install -c anaconda memory_profiler
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from utils.plotting import *
import plotly.io as pio
pio.renderers.default = 'notebook'
1. Datasets, Dataloaders, and Transforms#
1.1. Preparing Data#
torch and torchvision provide out-of-the-box functionality for loading in lots of different kinds of data.
The way you create a dataloader depends on the data you have (i.e., do you have numpy arrays, tensors, images, or something else?) and the PyTorch docs can help you out
Loading data into PyTorch is usually a two-step process:
Create a dataset (this is your raw data)
Create a dataloader (this will help you batch your data)
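If your data already lives in arrays or tensors rather than image files, a minimal sketch of the same two steps (with made-up toy data; TensorDataset and DataLoader are standard torch.utils.data classes) looks like this:
from torch.utils.data import TensorDataset, DataLoader

X = torch.randn(100, 5)          # toy data: 100 samples, 5 features each
y = torch.randint(0, 2, (100,))  # toy binary targets

dataset = TensorDataset(X, y)    # step 1: the dataset (your raw data)
loader = DataLoader(dataset, batch_size=16, shuffle=True)  # step 2: the dataloader (batching)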
Working with CNNs and images, you’ll mostly be using torchvision.datasets.ImageFolder() (docs); it’s very easy to use.
It assumes you have a directory structure with sub-directories for each class like this:
data
│
├── class_1
│ ├── image_1.png
│ ├── image_2.png
│ ├── image_3.png
│ └── etc.
└── class_2
├── image_1.png
├── image_2.png
├── image_3.png
└── etc.
For example, consider the training dataset I have in the current directory at lectures/data/bitmoji_rgb:
bitmoji_rgb
└── train
├── not_tom
│ ├── image_1.png
│ ├── image_2.png
│ ├── image_3.png
│ └── etc.
└── tom
├── image_1.png
├── image_2.png
├── image_3.png
└── etc.
TRAIN_DIR = "data/bitmoji_rgb/train/"
mem = memory_profiler.memory_usage()[0]
train_dataset = torchvision.datasets.ImageFolder(root=TRAIN_DIR)
print(f"Memory consumed: {memory_profiler.memory_usage()[0] - mem:.0f} mb")
Memory consumed: 0 mb
Notice how our memory usage is unchanged: we aren’t loading anything in yet, just making PyTorch aware of what kind of data we have and where it is
We can now check various information about our train_dataset:
print(f"Classes: {train_dataset.classes}")
print(f"Class count: {train_dataset.targets.count(0)}, {train_dataset.targets.count(1)}")
print(f"Samples:",len(train_dataset))
print(f"First sample: {train_dataset.samples[0]}")
Classes: ['not_tom', 'tom']
Class count: 857, 857
Samples: 1714
First sample: ('data/bitmoji_rgb/train/not_tom\\bitmoji_10187.png', 0)
Now, we could start working with this dataset directly
For example, here’s the first sample:
img, target = next(iter(train_dataset))
print(f"Class: {train_dataset.classes[target]}")
img
Class: not_tom
But often we want to apply some pre-processing to our data.
For example, ImageFolder loads our data using the PIL package, but we need tensors!
print(f"Image data type: {type(img)}")
print(f" Image size: {img.size}")
Image data type: <class 'PIL.Image.Image'>
Image size: (128, 128)
Any pre-processing we wish to apply to our images is done using torchvision.transforms.
There are a lot of transformation options here - we’ll explore some more later - for now, we’ll Resize() our images and convert them with ToTensor().
We use transforms.Compose() to chain multiple transformations together:
IMAGE_SIZE = 64
data_transforms = transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.ToTensor()
])
train_dataset = torchvision.datasets.ImageFolder(root=TRAIN_DIR,
transform=data_transforms)
img, target = next(iter(train_dataset))
print(f"Image data type: {type(img)}")
print(f" Image size: {img.shape}")
Image data type: <class 'torch.Tensor'>
Image size: torch.Size([3, 64, 64])
Okay cool, but there’s one more issue: we want to work with batches of data, because most of the time, we won’t be able to fit an entire dataset into RAM at once (especially when it comes to image data)
This is where PyTorch’s dataloader comes in.
It allows us to specify how we want to batch our data:
BATCH_SIZE = 64
mem = memory_profiler.memory_usage()[0]
train_loader = torch.utils.data.DataLoader(train_dataset, # our raw data
batch_size=BATCH_SIZE, # the size of batches we want the dataloader to return
shuffle=True, # shuffle our data before batching
drop_last=False) # don't drop the last batch even if it's smaller than batch_size
print(f"Memory consumed: {memory_profiler.memory_usage()[0] - mem:.0f} mb")
Memory consumed: 0 mb
Once again, we aren’t loading anything yet, we just prepared the loader
We can now query the loader to return a batch of data (this will consume memory):
mem = memory_profiler.memory_usage()[0]
imgs, targets = next(iter(train_loader))
print(f" # of batches: {len(train_loader)}")
print(f" Image data type: {type(imgs)}")
print(f" Image batch size: {imgs.shape}") # dimensions are (batch size, image channels, image height, image width)
print(f" Target batch size: {targets.shape}")
print(f" Batch memory: {memory_profiler.memory_usage()[0] - mem:.2f} mb") # memory usage after loading batch
# of batches: 27
Image data type: <class 'torch.Tensor'>
Image batch size: torch.Size([64, 3, 64, 64])
Target batch size: torch.Size([64])
Batch memory: 6.81 mb
1.2. Saving and Loading PyTorch Models#
The PyTorch documentation about saving and loading models is fantastic and the process is very easy.
It’s common PyTorch convention to save models using either a .pt or .pth file extension.
It is recommended that you just save your model’s learned parameters from model.state_dict():
# Save model
PATH = "models/my_model.pt"
torch.save(model.state_dict(), PATH) # save model at PATH
# Load model
model = MyModelClass() # create an instance of the model
model.load_state_dict(torch.load(PATH)) # load model from PATH
If you’re using the model for inference (not training), make sure to switch it to eval mode:
model.eval()
There are other options for saving models, in particular, if you want to save a model and continue training it later, you’ll want to save other necessary information like the optimizer state, the epoch you’re on, etc. This is all documented here in the PyTorch docs.
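A minimal checkpointing sketch, assuming a model, optimizer, and epoch counter already exist from your training loop (the dictionary keys and the models/checkpoint.pt path below are just illustrative):
# Save a resumable checkpoint (assumes `model`, `optimizer`, and `epoch` exist)
torch.save({
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}, "models/checkpoint.pt")

# ...later, to resume training:
checkpoint = torch.load("models/checkpoint.pt")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1
model.train()  # back to training mode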
Let’s see an example of a model I saved earlier:
class bitmoji_CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 8, (5, 5)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(8, 4, (3, 3)),
            nn.ReLU(),
            nn.MaxPool2d((3, 3)),
            nn.Flatten(),
            nn.Linear(324, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        out = self.main(x)
        return out
PATH = "models/bitmoji_cnn.pt"
model = bitmoji_CNN()
model.load_state_dict(torch.load(PATH))
model.eval()
bitmoji_CNN(
  (main): Sequential(
    (0): Conv2d(3, 8, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=(3, 3), stride=(3, 3), padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=324, out_features=128, bias=True)
    (8): ReLU()
    (9): Linear(in_features=128, out_features=1, bias=True)
  )
)
1.3. Data Augmentation#
Data augmentation is used for two main purposes:
Make your CNN more robust to scale/rotation/translation in your images
Increase the size of your training set
Let’s explore point 1 a bit further. Below is a Bitmoji of me; does the CNN I loaded above predict it correctly?
image = Image.open('img/tom-bitmoji.png')
image
IMAGE_SIZE = 64
image_tensor = transforms.functional.to_tensor(image.resize((IMAGE_SIZE, IMAGE_SIZE))).unsqueeze(0)
prediction = int(torch.sigmoid(model(image_tensor)) > 0.5)
print(f"Prediction: {train_dataset.classes[prediction]}")
Prediction: tom
Looks good!
But what happens if I flip my image? You can still tell it’s me, but can my CNN?
image = image.rotate(180)
image
image_tensor = transforms.functional.to_tensor(image.resize((IMAGE_SIZE, IMAGE_SIZE))).unsqueeze(0)
prediction = int(torch.sigmoid(model(image_tensor)) > 0.5)
print(f"Prediction: {train_dataset.classes[prediction]}")
Prediction: not_tom
Well that’s problematic. We’d like our CNN to be robust against these kinds of differences
With data augmentation, we can expose our CNN to flipped images during training so that it learns to predict them correctly
Common image augmentations include:
rotation/flipping
cropping
adding noise
You can view others in the PyTorch docs
We add transforms just like we did previously, using the transform argument of torchvision.datasets.ImageFolder():
data_transforms = transforms.Compose([
transforms.RandomVerticalFlip(p=0.5),
transforms.RandomRotation(degrees=20),
transforms.Resize(IMAGE_SIZE),
transforms.ToTensor()
])
train_dataset = torchvision.datasets.ImageFolder(root=TRAIN_DIR,
transform=data_transforms)
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=False)
sample_batch, target = next(iter(train_loader))
plot_bitmojis(sample_batch, rgb=True)
Here’s a model I trained earlier using the above augmentations (see Appendix D for the full code):
PATH = "models/bitmoji_cnn_augmented.pt"
model = bitmoji_CNN()
model.load_state_dict(torch.load(PATH))
model.eval();
Let’s try it out on the flipped image:
image
image_tensor = transforms.functional.to_tensor(image.resize((IMAGE_SIZE, IMAGE_SIZE))).unsqueeze(0)
prediction = int(torch.sigmoid(model(image_tensor)) > 0.5)
print(f"Prediction: {train_dataset.classes[prediction]}")
Prediction: tom
Now we got it!
1.4. Batch Normalization (Optional)#
Earlier in the course, we saw how normalizing the inputs to our neural network can help our optimization (by making sure the scale of one feature doesn’t overwhelm others)
But what about the hidden layers of our network? They also have data flowing into them and parameters to optimize, can we normalize them too to make optimization better?
Sure can! Batch normalization is the normalization of data in hidden layers
It is usually applied before the activation function of a hidden layer:
\(\hat{z} = \frac{z - \mu}{\sqrt{\sigma^2 + \epsilon}}\)
Where:
\(z\) = the output of your hidden layer before the activation function
\(\mu = \frac{1}{n}\sum_{i=1}^{n}z_i\) (i.e., the mean of \(z\))
\(\sigma^2 = \frac{1}{n}\sum_{i=1}^{n}(z_i-\mu)^2\) (i.e., the variance of \(z\))
\(\epsilon\) = a small constant added for numerical stability
In practice, frameworks also apply a learned scale \(\gamma\) and shift \(\beta\) to \(\hat{z}\) afterwards
Batch normalization can help stabilize and speed up optimization, make your network more invariant to changes in the training distribution, and often has a slight regularization effect
See this video by Andrew Ng if you want to learn more about the details
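In PyTorch, batch normalization is just another layer. As a rough sketch (not the architecture actually used in this lecture), the first block of our bitmoji_CNN with batch normalization added would look like:
conv_block = nn.Sequential(
    nn.Conv2d(3, 8, (5, 5)),
    nn.BatchNorm2d(8),  # normalize the 8 feature maps, before the activation
    nn.ReLU(),
    nn.MaxPool2d((2, 2)),
)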
2. Hyperparameter Tuning#
With neural networks we potentially have a lot of hyperparameters to tune:
Number of layers
Number of nodes in each layer
Activation functions
Regularization
Initialization (starting weights)
Optimization hyperparams (learning rate, momentum, weight decay)
etc.
With so many parameters, a grid-search approach to optimization is not feasible
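For example, tuning just 5 hyperparameters over 10 candidate values each already gives \(10^5\) configurations to train and evaluate.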
Luckily, there are many packages out there that make neural network hyperparameter tuning fast and easy:
We’ll be using Ax, created by Facebook (just like PyTorch):
pip install ax-platform
I’m pretty new to the platform myself, but it’s easy to use!
Below, I’ve adapted a tutorial from their docs:
from ax.service.managed_loop import optimize
from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.utils.notebook.plotting import render, init_notebook_plotting
First, I’ll create some simple training and validation loaders:
TRAIN_DIR = "data/bitmoji_rgb/train/"
VALID_DIR = "data/bitmoji_rgb/valid/"
IMAGE_SIZE = 64
BATCH_SIZE = 128
# Transforms
data_transforms = transforms.Compose([transforms.Resize(IMAGE_SIZE), transforms.ToTensor()])
# Load data and create dataloaders
train_dataset = torchvision.datasets.ImageFolder(root=TRAIN_DIR, transform=data_transforms)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataset = torchvision.datasets.ImageFolder(root=VALID_DIR, transform=data_transforms)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True)
# GPU available?
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using: {device}")
# Model
model = bitmoji_CNN()
model.to(device)
Using: cuda
bitmoji_CNN(
(main): Sequential(
(0): Conv2d(3, 8, kernel_size=(5, 5), stride=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1))
(4): ReLU()
(5): MaxPool2d(kernel_size=(3, 3), stride=(3, 3), padding=0, dilation=1, ceil_mode=False)
(6): Flatten(start_dim=1, end_dim=-1)
(7): Linear(in_features=324, out_features=128, bias=True)
(8): ReLU()
(9): Linear(in_features=128, out_features=1, bias=True)
)
)
Now, we need a training function
This function will be re-run multiple times throughout the hyperparameter optimization process, as we wish to train the model on different hyperparameter configurations
The argument hyperparameters is a dictionary containing the hyperparameters we wish to tune:
def train(model, train_loader, hyperparameters, epochs=20):
    """Training wrapper for PyTorch network."""
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(),
                           lr=hyperparameters.get("lr", 0.0001),
                           betas=(hyperparameters.get("beta1", 0.9), 0.999))
    for epoch in range(epochs):
        for X, y in train_loader:
            if device.type == 'cuda':
                X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            y_hat = model(X).flatten()
            loss = criterion(y_hat, y.type(torch.float32))
            loss.backward()
            optimizer.step()
    return model
We also need an evaluate() function that reports how well our model is doing on some validation data.
This will also be called multiple times during the hyperparameter optimization:
def evaluate(model, valid_loader):
    """Validation wrapper for PyTorch network."""
    model.eval()
    accuracy = 0
    with torch.no_grad():  # this stops pytorch doing computational graph stuff under-the-hood and saves memory and time
        for X, y in valid_loader:
            if device.type == 'cuda':
                X, y = X.to(device), y.to(device)
            y_hat = model(X).flatten()
            y_hat_labels = torch.sigmoid(y_hat) > 0.5
            accuracy += (y_hat_labels == y).type(torch.float32).mean().item()
    accuracy /= len(valid_loader)  # avg accuracy
    return accuracy
Let’s make sure our evaluation function is working:
evaluate(model, valid_loader)
0.513671875
Looks good! The accuracy is bad right now because we haven’t trained our model yet
First, train the model without parameter tuning …
hyperparameters = {
"lr": 0.0001,
"beta1": 0.9
}
model = train(model, train_loader, hyperparameters)
evaluate(model, valid_loader)
0.66796875
That’s a bit better, but tuning should produce better results still.
We then have a wrapper function that puts everything together
Basically each iteration of hyperparameter optimization (i.e., each time we try a new set of hyperparameters), this function is executed. It trains the model using the given hyperparameters, and then evaluates the model’s performance.
def train_evaluate(parameterization):
    model = bitmoji_CNN()
    model.to(device)
    model = train(model, train_loader, hyperparameters=parameterization)
    return evaluate(model, valid_loader)
Finally, we use optimize() to run Bayesian optimization on a hyperparameter dictionary.
I ran this on a GPU (an Nvidia GeForce 3090):
%%time
best_parameters, values, experiment, model = optimize(
parameters=[
{"name": "lr", "type": "range", "bounds": [1e-6, 0.4], "log_scale": True, "value_type": 'float'},
{"name": "beta1", "type": "range", "bounds": [0.3, 0.99], "value_type": 'float'},
],
evaluation_function=train_evaluate,
objective_name='accuracy',
total_trials = 20
)
[INFO 11-07 13:33:35] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='lr', parameter_type=FLOAT, range=[1e-06, 0.4], log_scale=True), RangeParameter(name='beta1', parameter_type=FLOAT, range=[0.3, 0.99])], parameter_constraints=[]).
[INFO 11-07 13:33:35] ax.modelbridge.dispatch_utils: Using Models.BOTORCH_MODULAR since there is at least one ordered parameter and there are no unordered categorical parameters.
[INFO 11-07 13:33:35] ax.modelbridge.dispatch_utils: Calculating the number of remaining initialization trials based on num_initialization_trials=None max_initialization_trials=None num_tunable_parameters=2 num_trials=None use_batch_trials=False
[INFO 11-07 13:33:35] ax.modelbridge.dispatch_utils: calculated num_initialization_trials=5
[INFO 11-07 13:33:35] ax.modelbridge.dispatch_utils: num_completed_initialization_trials=0 num_remaining_initialization_trials=5
[INFO 11-07 13:33:35] ax.modelbridge.dispatch_utils: `verbose`, `disable_progbar`, and `jit_compile` are not yet supported when using `choose_generation_strategy` with ModularBoTorchModel, dropping these arguments.
[INFO 11-07 13:33:35] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 5 trials, BoTorch for subsequent trials]). Iterations after 5 will take longer to generate due to model-fitting.
[INFO 11-07 13:33:35] ax.service.managed_loop: Started full optimization with 20 steps.
[INFO 11-07 13:33:35] ax.service.managed_loop: Running optimization trial 1...
[INFO 11-07 13:34:21] ax.service.managed_loop: Running optimization trial 2...
[INFO 11-07 13:35:05] ax.service.managed_loop: Running optimization trial 3...
[INFO 11-07 13:35:48] ax.service.managed_loop: Running optimization trial 4...
[INFO 11-07 13:36:30] ax.service.managed_loop: Running optimization trial 5...
[INFO 11-07 13:37:12] ax.service.managed_loop: Running optimization trial 6...
[INFO 11-07 13:37:58] ax.service.managed_loop: Running optimization trial 7...
[INFO 11-07 13:38:43] ax.service.managed_loop: Running optimization trial 8...
[INFO 11-07 13:39:29] ax.service.managed_loop: Running optimization trial 9...
[INFO 11-07 13:40:15] ax.service.managed_loop: Running optimization trial 10...
[INFO 11-07 13:40:59] ax.service.managed_loop: Running optimization trial 11...
[INFO 11-07 13:41:42] ax.service.managed_loop: Running optimization trial 12...
[INFO 11-07 13:42:26] ax.service.managed_loop: Running optimization trial 13...
[INFO 11-07 13:43:12] ax.service.managed_loop: Running optimization trial 14...
[INFO 11-07 13:43:58] ax.service.managed_loop: Running optimization trial 15...
[INFO 11-07 13:44:42] ax.service.managed_loop: Running optimization trial 16...
[INFO 11-07 13:45:25] ax.service.managed_loop: Running optimization trial 17...
[INFO 11-07 13:46:08] ax.service.managed_loop: Running optimization trial 18...
[INFO 11-07 13:46:54] ax.service.managed_loop: Running optimization trial 19...
[INFO 11-07 13:46:54] ax.modelbridge.base: Untransformed parameter 0.40000000000000013 greater than upper bound 0.4, clamping
[INFO 11-07 13:47:38] ax.service.managed_loop: Running optimization trial 20...
CPU times: total: 12min 51s
Wall time: 14min 45s
best_parameters
{'lr': 0.0012342512404562014, 'beta1': 0.38676806451855483}
means, covariances = values
print(f"Accuracy: {means['accuracy']*100:.2f}%")
Accuracy: 81.29%
Good :)
render(plot_contour(model=model, param_x='lr', param_y='beta1', metric_name='accuracy'))
Exercises:
Explain the Ax visualizations
Explain Ax’s Bayesian optimization approach, which is based on Gaussian processes
best_objectives = np.array([[trial.objective_mean*100 for trial in experiment.trials.values()]])
best_objective_plot = optimization_trace_single_method(
y=np.maximum.accumulate(best_objectives, axis=1),
title="Model performance vs. # of iterations",
ylabel="Classification Accuracy, %",
)
render(best_objective_plot)
Exercise: Use Ax to find the best hyperparameters for training a model on MNIST
3. Explaining CNNs (Optional)#
CNNs and neural networks in general are primarily used for prediction (i.e., we want the best prediction performance, and we might not care how we get it)
However, interpreting why a model makes certain predictions can be useful
Interpreting neural networks is an active area of research and it is difficult to do
There are a few main options; one is Captum, a library specifically for interpreting PyTorch models - it’s quite new still but has some great functionality!
Let’s use it quickly to find what areas of the following bitmoji are important for the model’s prediction of “tom”:
from captum.attr import GradientShap
from captum.attr import Occlusion
from captum.attr import visualization as viz
PATH = "models/bitmoji_cnn_augmented.pt"
model = bitmoji_CNN()
model.load_state_dict(torch.load(PATH))
model.eval()
bitmoji_CNN(
  (main): Sequential(
    (0): Conv2d(3, 8, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=(3, 3), stride=(3, 3), padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=324, out_features=128, bias=True)
    (8): ReLU()
    (9): Linear(in_features=128, out_features=1, bias=True)
  )
)
image
image_tensor = transforms.functional.to_tensor(image.resize((IMAGE_SIZE, IMAGE_SIZE))).unsqueeze(0)
prediction = int(torch.sigmoid(model(image_tensor)) > 0.5)
print(f"Prediction: {train_dataset.classes[prediction]}")
Prediction: tom
# Occlusion
occlusion = Occlusion(model)
attributions_occ = occlusion.attribute(image_tensor,
strides = (1, 3, 3),
sliding_window_shapes=(1, 10, 10),
baselines=0)
_ = viz.visualize_image_attr_multiple(np.transpose(attributions_occ.squeeze().detach().numpy(), (1,2,0)),
np.transpose(image_tensor.squeeze().detach().numpy(), (1,2,0)),
["original_image", "blended_heat_map"],
["all", "positive"],
titles = ["Original Image", "Occlusion"],
cmap="plasma",
fig_size=(6, 6),
alpha_overlay=0.7
)
# Gradient SHAP
torch.manual_seed(2020); np.random.seed(2020)
gradient_shap = GradientShap(model)
rand_img_dist = torch.cat([image_tensor * 0, image_tensor * 1])
attributions_gs = gradient_shap.attribute(image_tensor,
n_samples=20,
stdevs=0.15,
baselines=rand_img_dist,
target=0)
_ = viz.visualize_image_attr_multiple(np.transpose(attributions_gs.squeeze().detach().numpy(), (1,2,0)),
np.transpose(image_tensor.squeeze().detach().numpy(), (1,2,0)),
["original_image", "blended_heat_map"],
["all", "absolute_value"],
titles = ["Original Image", "Gradient SHAP"],
cmap="plasma",
show_colorbar=True,
fig_size=(6, 6),
alpha_overlay=0.7)
4. Transfer Learning#
Transfer learning is one of the most common techniques used in deep learning
It refers to using a model already trained on one task as a starting point for learning to perform another task
There are many famous deep learning architectures out there that have been very successful across a wide range of problems, e.g.: AlexNet, VGG, ResNet, Inception, MobileNet, etc.
Many of these models have been pre-trained on famous datasets like ImageNet (which contains 1.2 million labelled images with 1000 categories)
So, why not use these famous architectures for our own tasks?!
I like to think of there being three main kinds of transfer learning:
Use a pre-trained network out-of-the-box
Use a pre-trained network as a “feature extractor” and add new layers to it for your own task
Same as 2 but “fine-tune” the weights of the pre-trained network using your own data
We’ll briefly explore these options below
4.1. Out-Of-The-Box#
This is the least common option of transfer learning
You basically download a model that performs the same task as you want to do and just use it to predict your own images
We can easily download famous models using the torchvision.models module. All models are available with pre-trained weights (based on ImageNet’s 224 x 224 images).
For example:
densenet = models.densenet121(pretrained=True)
densenet.eval()
DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      ... (denselayer2 to denselayer6 follow the same pattern)
    )
    ... (transition1, denseblock2, transition2, denseblock3, transition3, and denseblock4 omitted for brevity)
    (norm5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (classifier): Linear(in_features=1024, out_features=1000, bias=True)
)
You can check out densenet’s architecture by printing it to screen, but it’s huge, so the printout above is truncated
The layers can be accessed using the .named_children() method; the last one is the classification layer, a fully-connected layer outputting 1000 values (one for each ImageNet class):
list(densenet.named_children())[-1]
('classifier', Linear(in_features=1024, out_features=1000, bias=True))
The ImageNet class labels are stored in a file in this directory:
classes = json.load(open("data/imagenet_class_index.json"))
idx2label = [classes[str(k)][1] for k in range(len(classes))]
print("First 10 ImageNet classes:")
idx2label[:10]
First 10 ImageNet classes:
['tench',
'goldfish',
'great_white_shark',
'tiger_shark',
'hammerhead',
'electric_ray',
'stingray',
'cock',
'hen',
'ostrich']
Let’s try the model out on some random images:
image = Image.open('img/evie.png')
image
image = transforms.functional.to_tensor(image.resize((224, 224))).unsqueeze(0)
_, idx = torch.softmax(densenet(image), dim=1).topk(5)
print(f"Top 3 predictions: {[idx2label[_.item()] for _ in idx[0]]}")
Top 3 predictions: ['toy_poodle', 'Yorkshire_terrier', 'Maltese_dog', 'miniature_poodle', 'Shih-Tzu']
Not bad! Can we trick it?
image = Image.open('img/panda-tom.png')
image
image = transforms.functional.to_tensor(image).unsqueeze(0)
_, idx = torch.softmax(densenet(image), dim=1).topk(5)
print(f"Top 5 predictions: {[idx2label[_.item()] for _ in idx[0]]}")
Top 5 predictions: ['mask', 'ski_mask', 'teddy', 'giant_panda', 'jersey']
Not bad either!
Anyway, you get the point. This workflow is constrained to the architecture of the model, i.e., we can only predict the ImageNet classes at this point. What if I wanted to make predictions for another problem, say a binary classification problem? Read on.
4.2. Feature Extractor#
In this method, we use a pre-trained model as a “feature extractor” which creates useful features for us that we can use to train some other model
We really have two options here:
Add some extra layers to the pre-trained network to suit our particular task
Pass training data through the network and save the output to use as features for training some other model
Let’s do approach 1 first. Let’s adapt densenet to predict our bitmoji data. I’m going to load the model, and then “freeze” all of its parameters (we don’t want to update them!):
densenet = models.densenet121(pretrained=True)
for param in densenet.parameters():  # Freeze parameters so we don't update them
    param.requires_grad = False
We saw before that the last layer of densenet is a fully-connected linear layer: Linear(in_features=1024, out_features=1000)
We are going to do binary classification, so I’m going to replace this layer with my own layers (I’m using OrderedDict() here so I can name my layers, but you don’t have to do this):
new_layers = nn.Sequential(OrderedDict([
('new1', nn.Linear(1024, 500)),
('relu', nn.ReLU()),
('new2', nn.Linear(500, 1))
]))
densenet.classifier = new_layers
Let’s check that the last layer of our model is updated:
densenet.classifier
Sequential(
  (new1): Linear(in_features=1024, out_features=500, bias=True)
  (relu): ReLU()
  (new2): Linear(in_features=500, out_features=1, bias=True)
)
Looks good!
Now we need to train our new layers:
IMAGE_SIZE = 224
BATCH_SIZE = 128
# Re-create the transforms so the new IMAGE_SIZE actually takes effect
data_transforms = transforms.Compose([transforms.Resize(IMAGE_SIZE), transforms.ToTensor()])
# New dataloaders
train_dataset = torchvision.datasets.ImageFolder(root=TRAIN_DIR, transform=data_transforms)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataset = torchvision.datasets.ImageFolder(root=VALID_DIR, transform=data_transforms)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True)
def trainer(model, criterion, optimizer, train_loader, valid_loader, device, epochs=5, verbose=True):
    """Simple training wrapper for PyTorch network."""
    train_accuracy = []
    valid_accuracy = []
    for epoch in range(epochs):  # for each epoch
        train_batch_loss = 0
        train_batch_acc = 0
        valid_batch_loss = 0
        valid_batch_acc = 0
        # Training
        for X, y in train_loader:
            if device.type == 'cuda':
                X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            y_hat = model(X).flatten()
            y_hat_labels = torch.sigmoid(y_hat) > 0.5
            loss = criterion(y_hat, y.type(torch.float32))
            loss.backward()
            optimizer.step()
            train_batch_loss += loss.item()
            train_batch_acc += (y_hat_labels == y).type(torch.float32).mean().item()
        train_accuracy.append(train_batch_acc / len(train_loader))
        # Validation
        model.eval()
        with torch.no_grad():
            for X, y in valid_loader:
                if device.type == 'cuda':
                    X, y = X.to(device), y.to(device)
                y_hat = model(X).flatten()
                y_hat_labels = torch.sigmoid(y_hat) > 0.5
                loss = criterion(y_hat, y.type(torch.float32))
                valid_batch_loss += loss.item()
                valid_batch_acc += (y_hat_labels == y).type(torch.float32).mean().item()
        valid_accuracy.append(valid_batch_acc / len(valid_loader))
        model.train()
        # Print progress
        if verbose:
            print(f"Epoch {epoch + 1}:",
                  f"Train Accuracy: {train_accuracy[-1]:.2f}.",
                  f"Valid Accuracy: {valid_accuracy[-1]:.2f}.")
    return {"train_accuracy": train_accuracy, "valid_accuracy": valid_accuracy}
%%time
# We have a big model so this will take some time to run! If you have a GPU, things could be much faster!
densenet.to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(densenet.parameters())
results = trainer(densenet, criterion, optimizer, train_loader, valid_loader, device, epochs=10)
Epoch 1: Train Accuracy: 0.61. Valid Accuracy: 0.67.
Epoch 2: Train Accuracy: 0.73. Valid Accuracy: 0.76.
Epoch 3: Train Accuracy: 0.77. Valid Accuracy: 0.75.
Epoch 4: Train Accuracy: 0.80. Valid Accuracy: 0.71.
Epoch 5: Train Accuracy: 0.83. Valid Accuracy: 0.74.
Epoch 6: Train Accuracy: 0.85. Valid Accuracy: 0.75.
Epoch 7: Train Accuracy: 0.88. Valid Accuracy: 0.67.
Epoch 8: Train Accuracy: 0.88. Valid Accuracy: 0.74.
Epoch 9: Train Accuracy: 0.90. Valid Accuracy: 0.76.
Epoch 10: Train Accuracy: 0.90. Valid Accuracy: 0.74.
CPU times: total: 26.8 s
Wall time: 35.8 s
Cool, we leveraged the power of densenet to get a really great model!
Now, you can also use pre-trained models as arbitrary feature extractors: you don’t have to add on layers, you can just extract the output values of the network (well, you can extract values from any layer you like) and use those values as “features” to train another model
Below, I’m going to pass all my bitmoji data through the network and save the outputs:
def get_features(model, train_loader, valid_loader):
    """Extract output of densenet model."""
    with torch.no_grad():  # turn off computational graph stuff
        Z_train = torch.empty((0, 1024))  # Initialize empty tensors
        y_train = torch.empty((0))
        Z_valid = torch.empty((0, 1024))
        y_valid = torch.empty((0))
        for X, y in train_loader:
            Z_train = torch.cat((Z_train, model(X)), dim=0)
            y_train = torch.cat((y_train, y))
        for X, y in valid_loader:
            Z_valid = torch.cat((Z_valid, model(X)), dim=0)
            y_valid = torch.cat((y_valid, y))
    return Z_train.detach(), y_train.detach(), Z_valid.detach(), y_valid.detach()
densenet = models.densenet121(pretrained=True)
densenet.classifier = nn.Identity() # remove that last "classification" layer
Z_train, y_train, Z_valid, y_valid = get_features(densenet, train_loader, valid_loader)
Now we have some extracted features.
Let’s train a classifier on the data, say, a LogisticRegression() model:
%%time
# Let's scale our data
scaler = StandardScaler()
Z_train = scaler.fit_transform(Z_train)
Z_valid = scaler.transform(Z_valid)
# Fit a model
model = LogisticRegression(max_iter=1000)
model.fit(Z_train, y_train)
print(f"Train accuracy: {model.score(Z_train, y_train) * 100:.2f}%")
print(f"Valid accuracy: {model.score(Z_valid, y_valid) * 100:.2f}%")
Train accuracy: 100.00%
Valid accuracy: 77.00%
CPU times: total: 969 ms
Wall time: 415 ms
So what did we just do?
We passed our bitmoji images through densenet and saved all the output values. With the classification layer replaced by nn.Identity(), densenet outputs 1024 values per input. We had 1714 bitmoji images, so we extracted a tensor of shape (1714, 1024) from densenet.
So we now have a dataset of 1024 features and 1714 examples. Our target remains binary ("not_tom", "tom") = (0, 1). We used this data to train a logistic regression model. Cool!
4.3. Fine Tuning#
Okay, this is the final and most common workflow of transfer learning.
Above, we stacked some extra layers onto densenet and just trained those layers (we “froze” all of densenet’s weights).
But we can also “fine tune” densenet’s weights if we like, to make it more suited to our data.
We can choose to “fine tune” all of densenet’s ~8 million parameters, or just some of them.
To do this, we use the same workflow as in 4.2 above, but we unfreeze the layers we wish to “fine-tune”:
# Load (but don't freeze!) the model
densenet = models.densenet121(pretrained=True)
# Replace classification layer
new_layers = nn.Sequential(OrderedDict([
('new1', nn.Linear(1024, 500)),
('relu', nn.ReLU()),
('new2', nn.Linear(500, 1))
]))
densenet.classifier = new_layers
# Move to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
densenet.to(device);
%%time
# Train the model (I did this on a GPU)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(densenet.parameters())
results = trainer(densenet, criterion, optimizer, train_loader, valid_loader, device, epochs=10)
Epoch 1: Train Accuracy: 0.75. Valid Accuracy: 0.74.
Epoch 2: Train Accuracy: 0.95. Valid Accuracy: 0.94.
Epoch 3: Train Accuracy: 0.97. Valid Accuracy: 0.96.
Epoch 4: Train Accuracy: 0.98. Valid Accuracy: 0.97.
Epoch 5: Train Accuracy: 0.97. Valid Accuracy: 0.93.
Epoch 6: Train Accuracy: 0.98. Valid Accuracy: 0.95.
Epoch 7: Train Accuracy: 0.99. Valid Accuracy: 0.98.
Epoch 8: Train Accuracy: 0.98. Valid Accuracy: 0.96.
Epoch 9: Train Accuracy: 0.99. Valid Accuracy: 0.93.
Epoch 10: Train Accuracy: 0.99. Valid Accuracy: 0.95.
CPU times: total: 28.2 s
Wall time: 41.6 s
Wow! By far our best results yet
You could also choose to fine-tune just some layers, for example, below I’ll freeze everything but the last two layers:
# Freeze all but the last two layers
for layer in densenet.features[:-2]:
    for param in layer.parameters():
        param.requires_grad = False
# Now re-train...
In a nutshell: fine-tuning more layers typically leads to better accuracy but at greater computational cost. There’s always a trade-off you have to consider!
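One quick way to see this trade-off, sketched here with standard PyTorch calls, is to count how many parameters are still trainable after freezing:
n_trainable = sum(p.numel() for p in densenet.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in densenet.parameters())
print(f"Trainable parameters: {n_trainable:,} / {n_total:,}")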
4.4. Transfer Learning Summary#
Use a pre-trained model out-of-the-box (good if a model already exists for your problem)
Use a pre-trained model as a “feature extractor” (good if you want to adapt a pre-trained model for a specific problem)
Fine-tune a pre-trained model (same as 2 but generally yields better results, although at more computational cost)
Exercises: Read the following papers#
5. The Lecture in Three Conjectures#
PyTorch makes data loading easy with dataset and dataloader.
Hyperparameter tuning is hard. Use software to help you - I recommend Ax.
Transfer learning is a great way to leverage successful models and datasets for your own problem. I do this all the time.