Super-Convergence with JUST PyTorch

A guide to decreasing training time whilst improving results, using built-in PyTorch functions and classes

Why?

When creating Snaked, my snake classification model, I needed a way to improve results. Super-Convergence was just that: a way to train a model faster whilst getting better results! HOWEVER, I found no guides on how to do it with the built-in PyTorch scheduler.

Imports

import torch
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader

from torch import nn, optim
from torch_lr_finder import LRFinder

import matplotlib.pyplot as plt  # needed later to save the learning-rate finder plot

Setting Hyperparameters

Set transforms

# Named data_transforms so the torchvision.transforms module is not shadowed
data_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=256, scale=(0.8, 1)),
    transforms.RandomRotation(90),
    transforms.ColorJitter(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

Load the data, model and basic hyperparameters

train_loader = DataLoader(datasets.CIFAR10(root="train_data", train=True, download=True, transform=data_transforms))
test_loader = DataLoader(datasets.CIFAR10(root="test_data", train=False, download=True, transform=data_transforms))
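
As written, both loaders fall back to the DataLoader default batch size of 1, which makes training on CIFAR-10 very slow. In practice you will usually want an explicit batch size and a shuffled training set; a minimal sketch, assuming a batch size of 64 (not specified in the original post, adjust to fit your GPU memory):

```python
batch_size = 64  # assumed value, not specified in the original post

train_loader = DataLoader(
    datasets.CIFAR10(root="train_data", train=True, download=True, transform=data_transforms),
    batch_size=batch_size,
    shuffle=True,   # reshuffle training samples every epoch
)
test_loader = DataLoader(
    datasets.CIFAR10(root="test_data", train=False, download=True, transform=data_transforms),
    batch_size=batch_size,
    shuffle=False,
)
```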

model = models.mobilenet_v2(pretrained=True)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters())


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
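
One thing to watch: the pretrained MobileNetV2 head outputs 1,000 ImageNet classes, while CIFAR-10 only has 10. A minimal sketch of swapping the final layer for a 10-class one (and rebuilding the optimizer so it tracks the new parameters), assuming torchvision's MobileNetV2 layout:

```python
# Replace the final linear layer so the model predicts 10 CIFAR-10 classes.
num_classes = 10
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
model = model.to(device)

# Rebuild the optimizer, otherwise AdamW still points at the old head's parameters.
optimizer = optim.AdamW(model.parameters())
```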


Find the learning rate

Note that this step requires a separate library, available [here](https://github.com/davidtvs/pytorch-lr-finder).


```python
lr_finder = LRFinder(model, optimizer, criterion, device)
lr_finder.range_test(train_loader, end_lr=10, num_iter=1000)
lr_finder.plot()
plt.savefig("LRvsLoss.png")
plt.close()
```


Stopping early, the loss has diverged
Learning rate search finished. See the graph with {finder_name}.plot()
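
The range test modifies the model and optimizer while it runs, so restore them before training. The library provides a reset method for exactly this:

```python
# Restore the model and optimizer to the state they were in before the range test.
lr_finder.reset()
```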

Create a scheduler

Use the one cycle learning rate scheduler (for super-convergence).

Note that the scheduler takes the maximum learning rate found from the graph. To choose it, look for the point where the loss is falling most steeply (the steepest downward slope), before it diverges.

The number of epochs to train for and the number of steps per epoch must also be passed in. The steps per epoch is the number of batches in the training loader, i.e. len(train_loader).

scheduler = optim.lr_scheduler.OneCycleLR(optimizer, 2e-3, epochs=50, steps_per_epoch=len(train_loader))
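
If you want to sanity-check the schedule before training, you can step a throwaway copy of the scheduler and record the learning rate at each step (a sketch for illustration only; stepping the real scheduler here would use up its step budget):

```python
# Throwaway parameter/optimizer/scheduler used only to inspect the LR curve;
# nothing here touches the real model or optimizer.
dummy_param = nn.Parameter(torch.zeros(1))
dummy_optimizer = optim.AdamW([dummy_param])
dummy_scheduler = optim.lr_scheduler.OneCycleLR(
    dummy_optimizer, 2e-3, epochs=50, steps_per_epoch=len(train_loader)
)

lrs = []
for _ in range(50 * len(train_loader)):
    dummy_optimizer.step()  # keeps PyTorch from warning about the step order
    dummy_scheduler.step()
    lrs.append(dummy_scheduler.get_last_lr()[0])

plt.plot(lrs)
plt.xlabel("Step")
plt.ylabel("Learning rate")
plt.savefig("OneCycleSchedule.png")
plt.close()
```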

Train model

Train the model for 50 epochs, printing the loss and accuracy after each phase of every epoch.

Different schedulers need to be stepped at different points in the code, and calling step in the wrong place will cause bugs. With the one-cycle policy, make sure scheduler.step() is called straight after each training batch (right after optimizer.step()).

best_acc = 0
epoch_no_change = 0

for epoch in range(50):
    print(f"Epoch {epoch}/49")

    for phase in ["train", "validation"]:
        running_loss = 0.0
        running_corrects = 0

        # Put the model in the right mode and pick the matching data loader
        if phase == "train":
            model.train()
            loader = train_loader
        else:
            model.eval()
            loader = test_loader

        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Only track gradients during the training phase
            with torch.set_grad_enabled(phase == "train"):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                if phase == "train":
                    loss.backward()
                    optimizer.step()
                    # One-cycle policy: step the scheduler after every training batch
                    scheduler.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(loader.dataset)
        epoch_acc = running_corrects.double() / len(loader.dataset)
        print("\nPhase: {}, Loss: {:.4f}, Acc: {:.4f}".format(phase, epoch_loss, epoch_acc))

        # Early stopping: track the best validation accuracy and count
        # the epochs without improvement
        if phase == "validation":
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                epoch_no_change = 0
            else:
                epoch_no_change += 1

    if epoch_no_change > 5:
        break
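
Once training stops, you will usually want to persist the weights so they can be reloaded later. A minimal sketch (the filename is just an example):

```python
# Save the trained weights; reload later with model.load_state_dict(torch.load(...)).
torch.save(model.state_dict(), "mobilenet_v2_cifar10.pt")
```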

Thanks for READING!

I hope this was relatively quick and easy to understand. When I first implemented super-convergence it took me a long time to figure out how to use the scheduler (I couldn't find any code which utilized it). If you liked this blog post, consider checking out other ways to improve your model. If you'd like to see how super-convergence is used in a real project, look no further than my snake classification project.