Hyperparameter Tuning for Deep Learning Models with Ax

The original article was published by Mengliu Zhao on Deep Learning on Medium.

First, let’s load all the necessary modules.

# Original Code here:
# https://github.com/pytorch/examples/blob/master/mnist/main.py
import os
import argparse
from filelock import FileLock
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

# Notebook-only shell command; from a terminal, run `pip install ax-platform` instead.
!pip install ax-platform
from ax.service.managed_loop import optimize
from ax.utils.notebook.plotting import render
from ax.utils.tutorials.cnn_utils import train, evaluate

The main part of the classification network is unchanged from my last article, except for the train_mnist function, which is renamed evaluate_mnist here to show that it returns an evaluation result:

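def evaluate_mnist(parameters):
    # get_data_loaders, ConvNet, train_fun and test_fun come from the previous article.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    train_loader, test_loader = get_data_loaders()
    model = ConvNet().to(device)

    # SGD with the learning rate and momentum proposed by Ax for this trial
    # (the defaults are only used if a parameter is missing).
    optimizer = optim.SGD(
        model.parameters(), lr=parameters.get("lr", 0.001), momentum=parameters.get("momentum", 0.95))

    for epoch in range(50):
        train_fun(model, optimizer, train_loader, device)

    # Test-set accuracy is the objective that Ax will maximize.
    acc = test_fun(model, test_loader, device)
    return acc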

Then we show how to use the AxClient class:

  1. Create the experiment with the search parameters and the objective (set minimize=False, since we want to maximize the classification accuracy).
  2. Run 50 trials (you need to specify parallelization in the initialization manually).
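from ax.service.ax_client import AxClient

ax_client = AxClient()
# objective_name/minimize is the older Ax API; newer releases use
# objectives={"accuracy": ObjectiveProperties(minimize=False)} instead.
ax_client.create_experiment(
    parameters=[{"name": "lr", "type": "range", "bounds": [1e-4, 1e-2], "log_scale": True},
                {"name": "momentum", "type": "range", "bounds": [0.1, 0.9]}],
    objective_name="accuracy",
    minimize=False,  # maximize classification accuracy
)

for _ in range(50):
    # Ask Ax for the next (lr, momentum) pair, train and evaluate, then report the accuracy.
    parameters, trial_index = ax_client.get_next_trial()
    ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate_mnist(parameters))

best_parameters, metrics = ax_client.get_best_parameters()
print(best_parameters)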
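To make the ask/tell control flow of the loop above concrete without a GPU or Ax installed, here is a toy sketch. RandomSearchClient is a hypothetical class that only mirrors the method names get_next_trial, complete_trial and get_best_parameters; it uses plain random search rather than the Bayesian optimization Ax uses by default, and its get_best_parameters returns (parameters, value) rather than Ax's (parameters, model predictions). It also illustrates what "log_scale": True means: lr is drawn uniformly in log space, so [1e-4, 1e-3] and [1e-3, 1e-2] are equally likely. The objective is made up and stands in for evaluate_mnist.

```python
import math
import random


class RandomSearchClient:
    """Toy ask/tell optimizer using random search; NOT part of Ax, it only mimics the loop shape."""

    def __init__(self, parameters, minimize=False, seed=0):
        self.parameters = parameters  # same list-of-dicts format as the Ax example
        self.minimize = minimize
        self.rng = random.Random(seed)
        self.trials = {}  # trial_index -> [parameterization, result or None]

    def _sample(self, spec):
        low, high = spec["bounds"]
        if spec.get("log_scale", False):
            # Uniform in log space: every decade of the range is equally likely.
            value = math.exp(self.rng.uniform(math.log(low), math.log(high)))
        else:
            value = self.rng.uniform(low, high)
        return min(max(value, low), high)  # guard against float round-off at the bounds

    def get_next_trial(self):
        # "Ask": propose a parameterization and register it under a new index.
        params = {spec["name"]: self._sample(spec) for spec in self.parameters}
        trial_index = len(self.trials)
        self.trials[trial_index] = [params, None]
        return params, trial_index

    def complete_trial(self, trial_index, raw_data):
        # "Tell": record the objective value observed for that trial.
        self.trials[trial_index][1] = raw_data

    def get_best_parameters(self):
        completed = [(p, r) for p, r in self.trials.values() if r is not None]
        pick = min if self.minimize else max
        return pick(completed, key=lambda pair: pair[1])


def toy_objective(params):
    # Made-up stand-in for evaluate_mnist, peaking at lr=1e-3 and momentum=0.9.
    return -((math.log10(params["lr"]) + 3) ** 2) - (params["momentum"] - 0.9) ** 2


client = RandomSearchClient(
    [{"name": "lr", "type": "range", "bounds": [1e-4, 1e-2], "log_scale": True},
     {"name": "momentum", "type": "range", "bounds": [0.1, 0.9]}],
    minimize=False,
)
for _ in range(50):
    params, trial_index = client.get_next_trial()
    client.complete_trial(trial_index=trial_index, raw_data=toy_objective(params))

best_parameters, best_value = client.get_best_parameters()
print(best_parameters, best_value)
```

The point of the ask/tell style is that your code, not the optimizer, owns the loop: you decide when to request a trial, how to run it, and when to report back, which is what makes AxClient flexible compared with a fully managed loop.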


Printing best_parameters gives:

{'lr': 0.007584330389670517, 'momentum': 0.9}
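A quick sanity check on a result like this is to see where each value falls inside its search range, measuring lr in log10 space because it was searched on a log scale. The sketch below hard-codes the best parameters and bounds shown above; the helper name relative_position and the 5% "near a bound" threshold are illustrative choices, not part of Ax.

```python
import math

# Best parameters reported above and the search space defined earlier.
best = {"lr": 0.007584330389670517, "momentum": 0.9}
space = {"lr": (1e-4, 1e-2, True), "momentum": (0.1, 0.9, False)}  # (low, high, log_scale)


def relative_position(value, low, high, log_scale):
    """Where value sits in [low, high]: 0.0 at low, 1.0 at high (in log10 space if log_scale)."""
    if log_scale:
        value, low, high = math.log10(value), math.log10(low), math.log10(high)
    return (value - low) / (high - low)


for name, (low, high, log_scale) in space.items():
    pos = relative_position(best[name], low, high, log_scale)
    note = "  (at or near a bound)" if pos <= 0.05 or pos >= 0.95 else ""
    print(f"{name}: {pos:.2f}{note}")
```

With these numbers, momentum lands exactly on its upper bound (1.00) and lr at roughly 0.94 of its log-scaled range. An optimum pinned to a bound often hints that extending that range could be worth a try.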