A practical example of transfer learning with PyTorch

Source: Deep Learning on Medium

Transfer learning with PyTorch

PyTorch offers several pretrained networks, available through the torchvision.models package, that you can download to your computer. The available models are listed in the torchvision documentation. For our purpose, we are going to choose AlexNet.

Each model has its own strengths for a particular type of problem, ranging from image classification to semantic segmentation. Some are faster than others and require less (or more) computation to run. For example, SqueezeNet has 50x fewer parameters than AlexNet while achieving AlexNet-level accuracy on the ImageNet dataset, which makes it a fast, small and accurate architecture (suitable for low-power embedded devices). VGG, on the other hand, achieves better accuracy than AlexNet or SqueezeNet but is heavier both to train and to run at inference time. Below, you can see different network architectures and their sizes as downloaded by PyTorch into its cache directory.

[Figure: different model sizes in the PyTorch cache directory]

So far we have only talked about theory; let’s put the concepts into practice. On the accompanying GitHub page you will find all the code needed to collect your data, train the model and run it in a live demo.

First, let’s import all the necessary packages:

import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

Now we use the ImageFolder dataset class from the torchvision.datasets package, which expects one subfolder per class inside the dataset root. We attach transforms to prepare the data for training and then split the dataset into training and test sets, holding out 50 images for testing.

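dataset = datasets.ImageFolder(
    'dataset',  # assumed root folder containing one subfolder per class
    transforms.Compose([
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),  # random brightness/contrast/saturation/hue
        transforms.Resize((224, 224)),
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]; required before Normalize
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean/std
    ])
)
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len(dataset) - 50, 50])  # 50 test images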

We’ll create two DataLoader instances, which provide utilities for shuffling data, producing batches of images, and loading the samples in parallel with multiple workers.

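# batch_size and num_workers are assumed values; tune them for your hardware
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=16, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=16, shuffle=False, num_workers=2)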
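How a DataLoader splits a dataset into batches can be seen on random stand-in data. This sketch assumes 50 small 3x32x32 tensors in a `TensorDataset` in place of real images; the batch size of 16 is an illustrative choice.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data: 50 small 3x32x32 "images" with labels 0 or 1
images = torch.randn(50, 3, 32, 32)
labels = torch.randint(0, 2, (50,))
dataset = TensorDataset(images, labels)

# shuffle=True reorders samples every epoch; num_workers=0 loads in the main process
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0)

batch_shapes = [tuple(x.shape) for x, _ in loader]
print(len(loader), batch_shapes)
# 4 batches: three of 16 samples and a final partial batch of 2
```

Passing `drop_last=True` would discard that final partial batch, which some training setups prefer.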

Now we define the neural network we’ll be training. AlexNet was originally trained on ImageNet, a dataset with 1000 class labels, but our dataset only has two! We’ll replace the final layer with a new, untrained layer that has only two outputs (👍 and 👎).

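# Download ImageNet-pretrained weights (torchvision < 0.13: models.alexnet(pretrained=True))
model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
# classifier[6] is the last fully connected layer (4096 -> 1000); swap in a 4096 -> 2 layer
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 2)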

Now it’s time to train the neural network, saving the weights whenever the test accuracy improves. Feel free to try different hyperparameters and see how they affect performance.

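NUM_EPOCHS = 30  # assumed value; not given in the original
BEST_MODEL_PATH = 'best_model.pth'
best_accuracy = 0.0

# Train on the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(NUM_EPOCHS):

    # Training pass: one optimizer step per batch
    model.train()
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

    # Evaluation pass: count misclassified test images (no gradients needed)
    model.eval()
    test_error_count = 0.0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            test_error_count += float(torch.sum(labels != outputs.argmax(1)))

    test_accuracy = 1.0 - float(test_error_count) / float(len(test_dataset))
    print('%d: %f' % (epoch, test_accuracy))
    # Save the weights only when test accuracy improves
    if test_accuracy > best_accuracy:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_accuracy = test_accuracy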

That’s all: the best weights are saved in best_model.pth, and our model is now able to classify our images in real time!