Original article was published by Michał Oleszak on Artificial Intelligence on Medium
The Learning Rate Finder
Get to the neighborhood of optimal values quickly without costly searches.
The learning rate is arguably the most important hyperparameter to tune in a neural network. Unfortunately, it is also one of the hardest to tune properly. But don’t despair, for the Learning Rate Finder will get you to pretty decent values quickly! Let’s see how it works and how to implement it in TensorFlow.
Why is it important?
To answer this question, let’s kick off with defining the learning rate. When you train a neural network, an optimization algorithm (typically some flavor of gradient descent) traverses the surface of the loss function seeking to walk down the slope, where the loss is decreasing. The learning rate is basically the size of the step it takes. And it’s pretty important this step size is not too small and not too large.
With too small a learning rate, the algorithm would take ages to reach the minimum, as in the left panel in the picture above. To make things worse, if there are local minima in the loss surface, the optimizer might get stuck in there, unable to get out with only small steps.
If the learning rate is too large, on the other hand, the optimization algorithm might overshoot the minimum and bounce around it, never to converge, and in the worst case, it can even diverge completely, like in the right panel of the picture above. Hence, it’s really vital to get your learning rate just right!
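To make the effect of the step size concrete, here is a minimal sketch that runs plain gradient descent on the toy function f(w) = w². The function, the starting point, and the three learning rates are illustrative assumptions, not values from any real network.

```python
def gradient_descent(learning_rate, start=5.0, steps=50):
    """Minimize f(w) = w**2 with plain gradient descent; return all visited points."""
    w = start
    path = [w]
    for _ in range(steps):
        grad = 2.0 * w                 # derivative of w**2
        w = w - learning_rate * grad   # the learning rate scales the step
        path.append(w)
    return path


too_small = gradient_descent(0.001)   # barely moves away from the start
about_right = gradient_descent(0.1)   # converges towards the minimum at 0
too_large = gradient_descent(1.1)     # overshoots further on every step

for name, path in [("too small", too_small),
                   ("about right", about_right),
                   ("too large", too_large)]:
    print(f"{name:>11}: final |w| = {abs(path[-1]):.6g}")
```

After 50 steps, the small learning rate is still far from the minimum, the moderate one has essentially reached it, and the large one has diverged.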
Why is it hard?
The learning rate is a tricky hyperparameter to tune for a number of reasons:
- In most cases, domain knowledge or previous studies are of little help, for a learning rate that worked well for one problem might not be even half as good for another, even a closely-related one.
- Tuning learning rates via a grid search or a random search is typically costly, both in terms of time and computing power, especially for large networks.
- The optimal learning rate is tightly coupled with other hyperparameters. Hence, each time you change the amount of regularization or the network’s architecture, you should re-tune the learning rate.
Enter the Learning Rate Finder
Looking for the optimal learning rate has long been, to some extent, a game of shooting at random, until a clever yet simple method was proposed by Smith (2017).
He noticed that by monitoring the loss early in the training, enough information is available to tune the learning rate. The idea is to train the network for only one epoch, starting with a very low learning rate and increasing it at every iteration until it reaches a very large value (so that each new mini-batch is trained using a higher learning rate than the previous one). Then, we can plot the loss versus the learning rate (log-scaled) for each iteration. Typically we would obtain a plot similar to this one:
Why such a shape? At the start of training, right after the network’s weights have been randomly initialized, it’s very easy to make progress (decrease the loss). This progress is slow initially as we are making small steps, but as we increase the learning rate, the loss starts decreasing faster and faster. At some point, however, the learning rate becomes too large and the loss diverges.
Based on this plot, a good learning rate can be picked. Then, we simply restart training with the chosen value. What is this optimal value, you might ask? Surprisingly, it’s not the value for which the loss reaches its minimum! The optimal learning rate lies below the loss-minimizing value, typically by about one order of magnitude. Why? There are two reasons for this.
- Firstly, at the minimum, the learning rate is already too large as the loss is on the brink of exploding. A bit less would not make much difference, but a bit more would blow everything up, so it’s better to stay on the safe side.
- Secondly, the loss we record at each iteration is smoothed rather than instantaneous. Keras reports a running average of the mini-batch losses seen so far in the epoch, and an optimizer with momentum (such as the popular Adam) averages the current gradient with past ones. Consequently, when the learning rate starts being too large, the loss plot won’t explode immediately, as the average will be pulled down by the past, small losses. When we do see the loss curve skyrocket, it means the learning rate has been too large for some time already.
Hence, for the example plot above, we would pick the optimal learning rate of 10^-3, which is more or less one order of magnitude smaller than the value at the minimum.
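This rule of thumb is easy to automate. The helper below is a rough sketch: the name suggest_learning_rate, the safety_factor parameter, and the synthetic loss values are all made up for illustration, and reading the value off the plot by eye remains a perfectly good alternative.

```python
import numpy as np


def suggest_learning_rate(learning_rates, losses, safety_factor=10.0):
    """Return the learning rate at the lowest recorded loss, divided by safety_factor."""
    lrs = np.asarray(learning_rates, dtype=float)
    losses = np.asarray(losses, dtype=float)
    # nanargmin skips NaN entries, which can appear once the loss has diverged
    lr_at_minimum = lrs[np.nanargmin(losses)]
    return lr_at_minimum / safety_factor


# Synthetic, hand-written curve: the loss bottoms out at a learning rate of 1e-2
example_lrs = np.logspace(-6, 1, 8)   # 1e-6, 1e-5, ..., 1e1
example_losses = [2.3, 2.3, 2.2, 1.5, 0.8, 1.2, 5.0, 40.0]

suggested = suggest_learning_rate(example_lrs, example_losses)
print(f"suggested learning rate: {suggested:.0e}")
```

With these made-up numbers, the suggestion lands one order of magnitude below the loss-minimizing learning rate, mirroring the choice made for the example plot.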
This method might not give you the single best learning rate, but the value it suggests should be quite close to the optimum. Its greatest appeal, though, is its speed: you only need to train the network once, for a single epoch, with no costly random searches!
Implementing the Learning Rate Finder in TensorFlow
Implementing this approach in TensorFlow is quite easy. We need only four components:
- Something to increase the learning rate at each iteration.
- Something to record the learning rate and the loss at each iteration.
- Code to train the model for one epoch while recording the losses and learning rates.
- A plot-drawer.
The first two can be combined together and implemented as a custom callback.
At the end of processing each batch, it will append the current loss and learning rate to two respective lists set up for storing these values, and then it will multiply the learning rate by a factor, passed as an argument.
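A callback along these lines could look as follows. This is a sketch rather than a definitive implementation: the class name ExponentialLearningRate and its attribute names are my own choices, and it assumes the optimizer was created with a plain float learning rate (not a learning rate schedule).

```python
import tensorflow as tf


class ExponentialLearningRate(tf.keras.callbacks.Callback):
    """Record learning rate and loss after each batch, then scale the learning rate."""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor          # multiplicative increase per iteration
        self.learning_rates = []      # learning rate used for each batch
        self.losses = []              # loss reported by Keras after each batch

    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        optimizer = self.model.optimizer
        # Assumes a float learning rate stored in a variable, not a schedule
        current_lr = float(optimizer.learning_rate.numpy())
        self.learning_rates.append(current_lr)
        self.losses.append(logs.get("loss"))
        # Raise the learning rate for the next batch
        optimizer.learning_rate = current_lr * float(self.factor)
```

Passing an instance of this class in the callbacks argument of model.fit() is enough to have the learning rate grow by the given factor after every batch.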
Now, what factor should we pass? This depends on the initial small value of the learning rate that we want to start with, the large value that we finally want to reach at the end of the epoch, and the number of iterations per epoch.
If we wanted to start with a learning rate of 0.000001 and increase it exponentially in 1000 iterations to finally reach 10, then the appropriate multiplicative factor is given by np.exp(np.log(10 / 10**-6) / 1000). We can verify this easily:
The two plots are the same except for the fact that the left one shows the vertical axis (the learning rate) in the logarithmic scale. As you can see, we indeed go from 0.000001 to 10 in 1000 steps. The log-learning-rate increases linearly, so the learning rate does so exponentially. This way, we explore the small values in more detail than the large values.
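The same check can be done numerically, without any plotting. The snippet below is a small sketch; the variable names are arbitrary.

```python
import numpy as np

min_lr, max_lr, n_iterations = 1e-6, 10.0, 1000

# Multiplying min_lr by this factor n_iterations times should yield max_lr
factor = np.exp(np.log(max_lr / min_lr) / n_iterations)

# Learning rate after 0, 1, ..., n_iterations multiplications
learning_rates = min_lr * factor ** np.arange(n_iterations + 1)

print(f"factor: {factor:.5f}")
print(f"first learning rate: {learning_rates[0]:.1e}")
print(f"last learning rate:  {learning_rates[-1]:.1e}")

# On the log scale, consecutive learning rates are equally spaced
log_steps = np.diff(np.log10(learning_rates))
print(f"constant log10 step: {np.allclose(log_steps, log_steps[0])}")
```

The factor comes out at about 1.016, so each iteration uses a learning rate roughly 1.6% higher than the previous one.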
The last thing to get out of the way is the number of iterations we will have in our one-epoch training — it won’t always be 1000. It is simply the number of examples in the training data divided by whatever batch size we want to use.
Let’s pack it all together in a neat function.
Our function takes a compiled TensorFlow model as input, along with the training data and the batch size. It first computes the learning rate multiplicative factor as we’ve just discussed. We use the floor division operator (//) when computing the number of iterations in order to get an integer. Then, it sets the model’s learning rate to the minimum we start with and trains it for one epoch using the custom callback we have defined before. Finally, it extracts the losses and learning rates from the callback and plots them, yielding the plot we have already seen.
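Put together, the steps described above might be sketched like this. The names find_learning_rate and ExponentialLearningRate, the default bounds, and the plot flag are assumptions made for this example, and the callback is repeated so that the snippet runs on its own. It also assumes the model was compiled with a float learning rate rather than a schedule.

```python
import numpy as np
import tensorflow as tf


class ExponentialLearningRate(tf.keras.callbacks.Callback):
    """Record learning rate and loss after each batch, then scale the learning rate."""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.learning_rates = []
        self.losses = []

    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        current_lr = float(self.model.optimizer.learning_rate.numpy())
        self.learning_rates.append(current_lr)
        self.losses.append(logs.get("loss"))
        self.model.optimizer.learning_rate = current_lr * float(self.factor)


def find_learning_rate(model, X, y, batch_size, min_lr=1e-6, max_lr=10.0, plot=True):
    """Train a compiled model for one epoch with an exponentially growing learning rate."""
    # Floor division gives an integer; fit() runs one extra, smaller batch
    # when len(X) is not a multiple of batch_size
    n_iterations = len(X) // batch_size
    factor = np.exp(np.log(max_lr / min_lr) / n_iterations)

    model.optimizer.learning_rate = min_lr   # start from the smallest value
    callback = ExponentialLearningRate(factor)
    model.fit(X, y, epochs=1, batch_size=batch_size,
              callbacks=[callback], verbose=0)

    if plot:
        import matplotlib.pyplot as plt   # imported lazily, only needed for the plot
        plt.plot(callback.learning_rates, callback.losses)
        plt.xscale("log")
        plt.xlabel("Learning rate (log scale)")
        plt.ylabel("Loss")
        plt.show()
    return callback.learning_rates, callback.losses
```

Note that the one-epoch run modifies the model’s weights, so it makes sense to re-create the model before the real training.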
Let’s try it out on the well-known Fashion-MNIST dataset.
As discussed, judging by the plot, we declare 10^-3 as the best learning rate, re-compile the model, and it’s ready for training with the learning rate tuned!
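For completeness, here is one way the final step could look. The architecture below is a placeholder I picked for illustration (the exact model is not specified here), and build_model and load_fashion_mnist are hypothetical helper names; only tf.keras.datasets.fashion_mnist.load_data() is the actual Keras loader.

```python
import tensorflow as tf


def build_model(learning_rate):
    """Build and compile a small classifier for 28x28 grayscale images (assumed architecture)."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        metrics=["accuracy"],
    )
    return model


def load_fashion_mnist():
    """Download Fashion-MNIST and scale pixel values to the [0, 1] range."""
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
    return X_train / 255.0, y_train, X_test / 255.0, y_test


# Re-create the model with the learning rate read off the finder plot
model = build_model(learning_rate=1e-3)

# Training would then proceed as usual, for example:
# X_train, y_train, X_test, y_test = load_fashion_mnist()
# model.fit(X_train, y_train, epochs=10, validation_data=(X_test, y_test))
```

Building a fresh model, rather than reusing the one from the finder run, ensures that training starts from newly initialized weights.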
The learning rate is a tricky hyperparameter to tune in a neural network. However, a simple solution exists:
- Train your network for only one epoch, increasing the learning rate at each iteration (starting from very small values and finishing at very large values).
- Plot the loss recorded at each iteration against the log-learning-rate. The loss curve will likely initially slope downwards at an increasing speed as the learning rate increases, then it will reach a minimum, and then explode upwards.
- The best learning rate is roughly one order of magnitude smaller than the one that minimizes the loss in the plot.
- This method of tuning the learning rate can be easily implemented in TensorFlow using a custom callback.
Thanks for reading! I hope you have learned something useful that will boost your projects 🚀