Original article was published by Michael Deyzel on Deep Learning on Medium
Quick Tutorial: Using Bayesian optimization to tune your hyperparameters in PyTorch
A faster way to design your neural networks
Hyperparameters are the parameters that determine a model’s architecture, learning speed and scope, and regularization.
The search for optimal hyperparameters requires some expertise and patience, and you’ll often find people using exhaustive methods like grid search and random search to find the hyperparameters that work best for their neural networks.
A quick tutorial
I’m going to show you how to implement Bayesian optimization to automatically find the optimal parameterization for your neural network in PyTorch using Ax.
We’ll be building a simple CIFAR-10 classifier using transfer learning. Most of this code is from the official PyTorch beginner tutorial for a CIFAR-10 classifier.
Firstly, the usual
Install Ax using:
pip install ax-platform
Import all the necessary libraries:
Download the datasets and construct the data loaders (I would advise adjusting the training batch size to 32 or 64 later):
Let’s take a look at the CIFAR-10 dataset by creating some helper functions:
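For example (a sketch; the class-name tuple and the un-normalization step mirror the official PyTorch tutorial):

```python
import numpy as np
import torch
import matplotlib.pyplot as plt

classes = ("plane", "car", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck")

def imshow(img):
    # Undo the (0.5, 0.5, 0.5) normalization, then move channels last
    img = img / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
```

Calling `imshow(torchvision.utils.make_grid(images))` on a batch pulled from the train loader displays a grid of sample images with their labels.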
Training and evaluation functions
Ax requires a function that returns a trained model, and another that evaluates a model and returns a performance metric like accuracy or F1 score. We’re only building the training function here, and we’ll use Ax’s own evaluate tutorial function to test our model performance. You can check out their API to model your function after theirs, if you’d like.
Next, we’re writing an init_net() function that initializes the model and returns the network ready to train. There are many opportunities for hyperparameter tuning here. You’ll notice the parameterization argument, which is a dictionary containing the hyperparameters.
Then comes the train_evaluate() function that the Bayesian optimizer calls on every run. The optimizer generates a new set of hyperparameters in parameterization, passes them to this function, and then analyses the returned evaluation results.
Now, just specify the hyperparameters you want to sweep across and pass that to Ax’s