Effect of learning rate on training convergence

Source: Deep Learning on Medium

Introduction

In this post I am going to share my insights on three things: how different learning rate values influence convergence during training, what we should infer from the learning rate plots while training our model, and how we should update the learning rate based on them.

Overview

We will create sample data and use specific weights to generate target values, then add random noise to those targets. The task is to predict the weights that produced the noisy targets. Along the way we will vary the learning rate and observe its effect on reaching the target. Below are the six steps I will follow:

  1. Creating sample data.
  2. Selecting specific weights.
  3. Obtaining the target with added noise and visualizing it.
  4. Making a first prediction and using mean squared error to see how far we are from the target.
  5. Using gradient descent optimization to reach the target.
  6. Updating the learning rate and observing its effect.

Implementation

We will use PyTorch, Matplotlib and NumPy for the entire implementation.

First, we create sample data of shape [100, 2]. The first column holds values drawn from a uniform distribution on [-1, 1), and the second column is all ones, which is the column the bias multiplies.

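```python
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch import nn

n = 100
x = torch.ones(n, 2)      # sample data of shape [100, 2]; the second column stays 1 for the bias
x[:,0].uniform_(-1., 1)   # fill the first column with values drawn uniformly from [-1, 1)
x[:5]                     # display the first five rows (in a notebook)
```
[Figure: sample data]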

Now we create a tensor in which 3 is the weight and 2 is the bias. Loosely, we will call it the weight matrix. These are the values the algorithm ultimately has to learn.

```python
a = torch.Tensor([3.,2])  # weights to be learned by the model: slope 3, bias 2
```
[Figure: weight matrix]

Multiplying the sample data by this weight matrix gives the noise-free targets. Adding some noise to them gives y, the target label.

[Figure: multiplication of sample data and weight (x@a)]
```python
y = x@a + torch.rand(n)  # target label: torch.rand adds uniform noise from [0, 1), not Gaussian noise
plt.scatter(x[:,0], y);
```
[Figure: target y visualization]
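Because the target is a linear function of x plus noise, ordinary least squares gives the weights that minimize the mean squared error exactly, which is the point gradient descent should converge to. The minimal sketch below assumes a fixed random seed (not in the original) and uses `torch.linalg.lstsq`. It also shows a side effect of the noise choice. `torch.rand` samples uniformly from [0, 1), so the noise has mean 0.5 and the best-fitting bias comes out near 2.5 rather than 2. Zero-mean Gaussian noise from `torch.randn` leaves it near 2. The names `a_true`, `a_ls` and `a_ls_gauss` are illustrative.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed for reproducibility; the post uses none

n = 100
x = torch.ones(n, 2)
x[:, 0].uniform_(-1., 1)
a_true = torch.tensor([3., 2.])

# Same target construction as in the post: uniform noise in [0, 1), mean 0.5.
y = x @ a_true + torch.rand(n)

# Closed-form least-squares fit: the weights that minimize the MSE for this data.
a_ls = torch.linalg.lstsq(x, y.unsqueeze(1)).solution.squeeze(1)
print(a_ls)  # slope near 3, intercept pulled up towards 2.5 by the noise mean

# With zero-mean Gaussian noise the fitted intercept stays near 2.
y_gauss = x @ a_true + torch.randn(n)
a_ls_gauss = torch.linalg.lstsq(x, y_gauss.unsqueeze(1)).solution.squeeze(1)
print(a_ls_gauss)
```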

Next, we make a first prediction with an arbitrarily chosen initial weight matrix of [-1, 1]. We also measure how far this prediction is from the target using the mean squared error.

```python
def mse(y_hat, y):  # mean squared error
    return ((y_hat - y)**2).mean()

a = torch.Tensor([-1.,1])  # arbitrary initial weights
y_hat = x@a                # first target prediction
mse(y_hat, y)              # loss of the first prediction (displayed in a notebook)
plt.scatter(x[:,0], y)
plt.scatter(x[:,0], y_hat);
```
[Figure: plot for first target prediction]
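Written out, the `mse` function computes the loss below. The notation is assumed here for illustration: x_i is the i-th row of x, X is the whole [100, 2] data matrix, a is the weight vector and η is the learning rate (`lr` in the code). The gradient of this loss with respect to a is the quantity gradient descent follows.

```latex
\begin{aligned}
L(a) &= \frac{1}{n}\sum_{i=1}^{n}\left(x_i^\top a - y_i\right)^2 = \frac{1}{n}\lVert Xa - y\rVert^2, \\
\nabla_a L &= \frac{2}{n}X^\top (Xa - y), \\
a &\leftarrow a - \eta\,\nabla_a L .
\end{aligned}
```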

Gradient Descent

We will implement gradient descent in PyTorch so that the gradients are computed automatically. The weight matrix a created above is wrapped in nn.Parameter, which sets requires_grad=True. Autograd then records every operation performed on a, and after loss.backward() the gradient of the loss with respect to a is available in a.grad. In each iteration we predict y, calculate the loss, backpropagate it and update the weight matrix. We repeat this 100 times and expect the predictions to be close to the target labels at the end.

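```python
a = nn.Parameter(a)  # track gradients for a

def update(t):
    y_hat = x@a
    loss = mse(y_hat, y)
    if t % 10 == 0: print(loss.item())  # print the loss every 10 iterations
    loss.backward()                      # compute d(loss)/da into a.grad
    with torch.no_grad():                # update the weights without recording this step
        a.sub_(lr * a.grad)
        a.grad.zero_()                   # reset the gradient for the next iteration

lr = 1e-1
for t in range(100):
    update(t)

plt.scatter(x[:,0], y)
plt.scatter(x[:,0], x@a.detach());
```
[Figure: after 100 iterations, predicted y (orange) and target y (blue)]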
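To see what autograd is doing, the sketch below compares `a.grad` with the closed-form MSE gradient `(2/n) * x.T @ (x @ a - y)` at the starting point [-1, 1]. It then runs the same 100 updates with that formula on a plain tensor. The fixed seed and the names `grad_autograd`, `grad_manual` and `a_manual` are assumptions for illustration.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed, not part of the original post

n = 100
x = torch.ones(n, 2)
x[:, 0].uniform_(-1., 1)
y = x @ torch.tensor([3., 2.]) + torch.rand(n)

def mse(y_hat, y):
    return ((y_hat - y) ** 2).mean()

# Gradient from autograd at the starting point a = [-1, 1].
a = torch.nn.Parameter(torch.tensor([-1., 1.]))
mse(x @ a, y).backward()
grad_autograd = a.grad.clone()

# Same gradient from the closed form (2/n) * X^T (X a - y).
with torch.no_grad():
    grad_manual = 2 / n * x.T @ (x @ a - y)
print(torch.allclose(grad_autograd, grad_manual, atol=1e-5))

# Plain-tensor gradient descent using the hand-derived gradient instead of autograd.
a_manual = torch.tensor([-1., 1.])
lr = 1e-1
for t in range(100):
    a_manual -= lr * 2 / n * x.T @ (x @ a_manual - y)
print(a_manual)
```

Both routes should agree. Autograd simply saves deriving the gradient by hand, which matters once the model is more complex than a line.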

UPDATING LEARNING RATES

We will compare the effect of three learning rates: lr=0.1, lr=0.7 and lr=1.01.
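The comparison can also be reproduced without the animations. The approach is to rerun the training loop from above for each learning rate and record the loss. The `train` helper, the fixed seed, and reporting the loss at step 10 and step 100 are assumptions of this minimal sketch, not part of the original post.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed for a reproducible comparison

n = 100
x = torch.ones(n, 2)
x[:, 0].uniform_(-1., 1)
y = x @ torch.tensor([3., 2.]) + torch.rand(n)

def mse(y_hat, y):
    return ((y_hat - y) ** 2).mean()

def train(lr, steps=100):
    """Run full-batch gradient descent from a = [-1, 1] and return the loss at each step."""
    a = torch.nn.Parameter(torch.tensor([-1., 1.]))
    losses = []
    for _ in range(steps):
        loss = mse(x @ a, y)
        losses.append(loss.item())
        loss.backward()
        with torch.no_grad():
            a.sub_(lr * a.grad)
            a.grad.zero_()
    return losses

histories = {lr: train(lr) for lr in (0.1, 0.7, 1.01)}
for lr, losses in histories.items():
    print(f"lr={lr}: loss at step 10 {losses[10]:.4f}, at step 100 {losses[-1]:.4f}")
```

Plotting each history with `plt.plot` shows the loss curves for the three settings side by side.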

As the left-hand video shows, with a low learning rate of 0.1 the model takes a while to reach the minimum: around 70 epochs (each epoch is a single full-batch update here).

With a higher learning rate of 0.7, on the other hand, the model reaches the minimum in around 8 epochs.

A still higher learning rate of 1.01, however, makes the model diverge.
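Why does 1.01 diverge while 0.7 converges quickly? Consider a quadratic loss such as MSE on a linear model. Each gradient descent step multiplies the error along each eigenvector of the Hessian H = (2/n) * x.T @ x by (1 - lr * λ), where λ is the matching eigenvalue. Training is therefore stable only when lr < 2 / λ_max. The second column of x is all ones, so λ_max is at least 2, which puts the threshold at or just below 1. The minimal sketch below computes this for a freshly sampled x. The seed and the variable names are assumptions.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed; the exact threshold depends on the sample

n = 100
x = torch.ones(n, 2)
x[:, 0].uniform_(-1., 1)

# Hessian of the MSE loss (1/n) * ||x a - y||^2 with respect to a; it does not depend on y.
H = 2 / n * x.T @ x
eigvals = torch.linalg.eigvalsh(H)  # real eigenvalues in ascending order (H is symmetric)
lr_max = 2 / eigvals[-1].item()
print(f"eigenvalues: {eigvals.tolist()}, stable for lr < {lr_max:.4f}")

# Per-step error multiplier along each eigenvector: |1 - lr * eigenvalue|.
# Below 1 the error shrinks; above 1 it grows, i.e. training diverges.
for lr in (0.1, 0.7, 1.01):
    factors = (1 - lr * eigvals).abs()
    print(f"lr={lr}: multipliers {[round(f, 3) for f in factors.tolist()]}")
```

A multiplier close to 1 means a slow approach, which is the case for lr=0.1. When lr * λ is between 1 and 2, the factor (1 - lr * λ) is negative but smaller than 1 in magnitude. The weights then overshoot and alternate around the minimum while still converging. A multiplier above 1 means the error grows every step.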

CONCLUSION

As the left image shows, the loss starts to fluctuate as training approaches convergence. The image below illustrates why: the loss curve is flat near its bottom, which causes these fluctuations in the loss function.