Proposing a new effect of learning rate decay — Network Stability

Uncovering learning rate as a form of regularisation in stochastic gradient descent

Abstract

Modern literature suggests the learning rate is the most important hyper-parameter to tune for a deep neural network. [1] With too low a learning rate, gradient descent can be painfully slow moving down a long, steep valley or past a saddle point. With too high a learning rate, gradient descent risks overshooting the minimum. [2] Adaptive learning rate algorithms have been developed that take momentum and the accumulated gradient into account, making them more robust to these situations in non-convex optimisation problems. [3]

A common theme is that decaying the learning rate after a certain number of epochs can help models converge to better minima by allowing the weights to settle into more exact, sharp minima. The idea is that at a fixed learning rate you may continually miss the actual minimum by stepping back and forth across it; by decaying the learning rate, we allow our weights to settle into these sharp minima.
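As a minimal sketch of what such a step-decay schedule looks like in practice, assuming a PyTorch setup (the model, decay interval, and decay factor below are arbitrary choices for illustration):

```python
import torch
from torch.optim.lr_scheduler import StepLR

# Placeholder model and optimiser, purely for illustration.
model = torch.nn.Linear(32, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Multiply the learning rate by 0.1 every 30 epochs.
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(90):
    # One dummy training step on random data; a real loop would
    # iterate over the batches of the training set.
    inputs = torch.randn(8, 32)
    targets = torch.randint(0, 10, (8,))
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # decay the learning rate on schedule
```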

As the deep learning community continues to build new ways of adjusting the learning rate, such as cyclical learning rate schedules, I believe it is important that we understand more deeply what effects the learning rate has on our optimisation problem.

In this study, I would like to propose an overlooked effect of learning rate decay: network stability. I will experiment with learning rate decay in different settings, show how network stability arises from it, and show how network stability can benefit subnetworks during the training process.

I believe learning rate decay has two effects on the optimisation problem:

  1. Allow weights to settle into deeper minima
  2. Provide stability of the forward-propagated activations and back-propagated loss signals during the training process

First, I would like to establish what I mean by network stability through some mathematical notation.

Then, I will proceed to show how both of these effects take place through experiments on the ResNet-18 architecture on the CIFAR-10 dataset.
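For context, the kind of scaffolding involved, assuming a PyTorch/torchvision setup, might look roughly like the sketch below; the hyper-parameters, decay milestones, and data augmentation are illustrative assumptions, not the exact values used in the experiments.

```python
import torch
import torchvision
from torchvision import transforms

# CIFAR-10 with simple normalisation; augmentation choices here are illustrative.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
train_set = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=128, shuffle=True, num_workers=2)

# ResNet-18 with a 10-class output head for CIFAR-10.
model = torchvision.models.resnet18(num_classes=10)

# SGD with a step decay of the learning rate at assumed milestones.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[100, 150], gamma=0.1)
```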

Theory

What is network stability? More importantly, how does it affect the loss landscape?

Let us view our loss landscape probabilistically.

For a single example x, we can calculate the loss L, given the current weights of the network w,
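$$L = L(x; w)$$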

An example x is picked from a distribution D over the space of all images.

We can then view the loss landscape probabilistically as
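$$P\big(L(x; w) = \ell \mid w\big), \qquad x \sim D$$

where $\ell$ denotes a particular value of the loss.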

The loss function calculated for a single image x, for a given set of weights w, produces an array of gradients for each weight through back propagation, namely,
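$$g(x; w) = \nabla_w L(x; w) = \left(\frac{\partial L(x; w)}{\partial w_1}, \dots, \frac{\partial L(x; w)}{\partial w_m}\right)$$

where $m$ is the number of weights in the network.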

There is a corresponding probability that we observe this gradient from the probabilistic loss landscape above as
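$$P\big(g(x; w) = g \mid w\big), \qquad x \sim D$$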

The process of performing stochastic gradient descent across a given dataset, in batches of samples, can be seen as navigating this probabilistic loss field through point estimates.

Over a single batch of images, the weights are kept constant. Each update at time step t moves through the following conditional landscape,
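$$P\big(L(x; w) = \ell \mid w = w_t\big)$$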

By averaging over a large enough batch of size n, we hope that
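$$\hat{g}_t = \frac{1}{n} \sum_{i=1}^{n} \nabla_w L(x_i; w_t) \;\approx\; \mathbb{E}_{x \sim D}\big[\nabla_w L(x; w_t)\big],$$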

so that the gradient step we take is a likely step for the given set of weights.

Each step t is then taken over the following conditional:
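$$w_{t+1} = w_t - \eta \, \hat{g}_t, \qquad \hat{g}_t \sim P\big(g \mid w = w_t\big)$$

where $\eta$ is the learning rate.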

By decreasing the learning rate, we decrease the size of the change in w at each step:
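$$\lVert w_{t+1} - w_t \rVert = \eta \, \lVert \hat{g}_t \rVert$$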

Subsequent iterations of batch gradient descent now operate over probabilistic loss fields that are more similar to those of the previous iteration. This enforces some level of stability in our network during training, since for small enough changes in w,
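$$P\big(g \mid w = w_{t+1}\big) \approx P\big(g \mid w = w_t\big)$$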

By decreasing the learning rate, we force stochastic gradient descent to operate over a tighter conditional probability over the weights, rather than jumping between weight regimes and their corresponding conditional probability fields. This is my notion of network stability.

The classical view of learning rate decay as enabling convergence to sharp minima is an oversimplification of a stochastic process, and treats the loss landscape as a constant.