Differentiable Architecture Search for RNN with fastai

Source: Deep Learning on Medium

By Bao-Tin Hoang

Differentiable Architecture Search (DARTS) by Hanxiao Liu et al. is an algorithm that automates architecture design for neural networks. It was originally implemented in pure PyTorch. This post touches on the key ideas behind DARTS and shows how I reimplemented it using fastai for clarity and ease of use.

The code is at https://github.com/tinhb92/rnn_darts_fastai.


Overview of DARTS — from the paper

Using AWD-LSTM as the backbone, the goal is to find a good RNN cell to fill in the “recurrent” part of the model (RNNModel.rnns). Different configurations are tried during the search, and the most suitable RNN cell is extracted at the end.

DARTS Algorithm — from the paper

The algorithm runs in two phases:

  1. train_search: search for a good rnn cell genotype
  2. train: derive the genotype, train it to convergence and evaluate it on the test set

The first phase, searching for a good rnn cell genotype, consists of a continuous relaxation (i.e. creating a mixed operation) and alternately updating the architecture parameters alpha and the network weights theta using gradient descent.

Instead of using only one operation (tanh, sigmoid …) at a node in the rnn cell, we apply several operations and take a weighted sum of their outputs. In the experiment, there are 5 candidate operations: none, tanh, relu, sigmoid, identity. The weights given to each of these 5 operations at each node are learnable parameters.
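The mixed operation can be sketched in a few lines. This is a minimal illustration in plain Python rather than PyTorch, working on a single scalar; it assumes the raw weights are normalized with a softmax, as in the paper's continuous relaxation. The names OPS, softmax and mixed_op are made up for this example and are not taken from the repository.

```python
import math

# The five candidate operations named in the text, applied to a scalar.
OPS = {
    "none": lambda x: 0.0,
    "tanh": math.tanh,
    "relu": lambda x: max(0.0, x),
    "sigmoid": lambda x: 1.0 / (1.0 + math.exp(-x)),
    "identity": lambda x: x,
}


def softmax(scores):
    """Turn raw architecture scores into positive weights that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def mixed_op(x, alpha):
    """Weighted sum of every candidate operation applied to x."""
    weights = softmax(alpha)
    return sum(w * op(x) for w, op in zip(weights, OPS.values()))


# Equal scores give every operation the same weight of 1/5.
alpha = [0.0] * len(OPS)
y = mixed_op(0.5, alpha)
```

In the real cell the operations act on tensors and alpha holds one row of scores per edge, but the weighting idea is the same: raising one score pushes the mixed output toward that single operation.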

In PyTorch, we use torch.rand() to initialize these weights and set requires_grad = True. The authors call them alpha, or architecture parameters, to distinguish them from theta, the ordinary parameters of the network. Updating theta requires only the usual forward and backward passes.

The gradient for alpha is described in Equation (7) of the paper:

Eq (7) in paper: Gradient for alpha

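ω’ denotes the weights of a one-step forward model, i.e. the network weights after one gradient descent step on the training loss. ω is the paper's symbol for the network weights that this post otherwise calls theta.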
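For readers without the paper at hand, the gradient can be written out as follows. This is my transcription of the paper's notation: ω stands for the network weights (theta in this post), ξ for the learning rate of the one-step weight update, and L_train and L_val for the training and validation losses.

```latex
\nabla_\alpha \mathcal{L}_{val}(\omega', \alpha)
  \;-\; \xi \, \nabla^2_{\alpha,\omega} \mathcal{L}_{train}(\omega, \alpha) \,
        \nabla_{\omega'} \mathcal{L}_{val}(\omega', \alpha),
\qquad
\omega' = \omega - \xi \, \nabla_\omega \mathcal{L}_{train}(\omega, \alpha)
```

The first term is an ordinary gradient. The second contains a matrix-vector product with the mixed second derivative of the training loss, which is the expensive part.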

The first term of Eq. (7) is computed with one forward and one backward pass on validation data using ω’. The second term is computed with a finite difference approximation:

Finite difference approximation
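Written out, the approximation takes roughly this form. Again this is my transcription of the paper's notation: ε is a small scalar, ω is the network weights (theta in this post), and ω’ is the weights one gradient step ahead.

```latex
\nabla^2_{\alpha,\omega} \mathcal{L}_{train}(\omega, \alpha) \,
\nabla_{\omega'} \mathcal{L}_{val}(\omega', \alpha)
\;\approx\;
\frac{\nabla_\alpha \mathcal{L}_{train}(\omega^{+}, \alpha)
      - \nabla_\alpha \mathcal{L}_{train}(\omega^{-}, \alpha)}{2\epsilon},
\qquad
\omega^{\pm} = \omega \pm \epsilon \, \nabla_{\omega'} \mathcal{L}_{val}(\omega', \alpha)
```

The weights are nudged in both directions along the validation gradient, and the difference between the two alpha-gradients stands in for the matrix-vector product.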

Evaluating the finite difference requires only 2 forward passes for the weights and 2 backward passes for alpha.
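To see that a central difference really recovers the matrix-vector product, here is a small numerical check on a toy problem. The loss, the function names and the value of eps are all assumptions chosen for illustration; nothing here comes from the DARTS code. The toy training loss is alpha[0]*w[0]*w[1] + alpha[1]*w[0]**2, so both the alpha-gradient and the exact product can be written by hand.

```python
def grad_alpha_train(w, alpha):
    """Gradient w.r.t. alpha of the toy loss alpha[0]*w[0]*w[1] + alpha[1]*w[0]**2.

    The toy loss is linear in alpha, so the result depends only on w.
    """
    return [w[0] * w[1], w[0] ** 2]


def hessian_vector_fd(w, alpha, v, eps=1e-2):
    """Central finite difference: two alpha-gradients at shifted weights."""
    w_plus = [wi + eps * vi for wi, vi in zip(w, v)]
    w_minus = [wi - eps * vi for wi, vi in zip(w, v)]
    g_plus = grad_alpha_train(w_plus, alpha)
    g_minus = grad_alpha_train(w_minus, alpha)
    return [(gp - gm) / (2 * eps) for gp, gm in zip(g_plus, g_minus)]


def hessian_vector_exact(w, v):
    """Hand-derived product of the mixed second derivative with v."""
    return [w[1] * v[0] + w[0] * v[1], 2 * w[0] * v[0]]


w = [0.3, -1.2]
alpha = [0.5, 0.1]
v = [0.7, 0.4]  # stands in for the validation gradient w.r.t. the weights

approx = hessian_vector_fd(w, alpha, v)
exact = hessian_vector_exact(w, v)
```

Each call to grad_alpha_train corresponds to one forward and one backward pass in a real network, which is where the count of two passes comes from. The second-derivative matrix itself is never built.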

The complexity is reduced from O(|alpha||theta|) to O(|alpha| + |theta|).

After the search, we select the operation with the highest probability (ignoring none) to convert the continuous relaxation into a single discrete specification of the cell. This genotype is then trained to convergence and evaluated on the test set.
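The selection step for a single edge can be sketched as below. This is a simplified illustration in plain Python, assuming softmax-normalized scores; it only picks the operation and leaves out how a node's predecessor is chosen. PRIMITIVES, softmax and pick_operation are names made up for this example.

```python
import math

PRIMITIVES = ["none", "tanh", "relu", "sigmoid", "identity"]


def softmax(scores):
    """Turn raw architecture scores into probabilities."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def pick_operation(alpha_row):
    """Return the most probable operation on one edge, skipping 'none'."""
    probs = softmax(alpha_row)
    candidates = [(p, name) for p, name in zip(probs, PRIMITIVES) if name != "none"]
    return max(candidates)[1]


# 'none' has the largest score here, but it is excluded from the choice.
alpha_row = [2.0, 0.1, 1.5, -0.3, 0.4]
chosen = pick_operation(alpha_row)
```

With these scores the result is relu, the highest-scoring operation once none is set aside.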

Implement with fastai

The 2 phases, train_search and train, have their training loops coded in pure PyTorch in the original implementation. This is fine if the loop is not too complicated, but it can get messy quickly and make the code hard to decipher. fastai has developed a callback system that sorts the components of the training loop into separate parts for clarity.

callbacks at a glance — thanks to Jeremy’s tweet

You can specify what to do at each stage of training: on_train_begin, on_epoch_begin, on_batch_begin, on_loss_begin … and so on without cluttering the code. Find out more by reading the docs!
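The idea behind callbacks can be shown with a stripped-down loop. This is a hypothetical sketch in plain Python, not fastai's API: the hook names are borrowed from the list above, but the Callback, Recorder and fit definitions here are invented for illustration, and fastai's real classes take more arguments.

```python
class Callback:
    """Base class: every hook is a no-op unless a subclass overrides it."""

    def on_train_begin(self):
        pass

    def on_epoch_begin(self, epoch):
        pass

    def on_batch_begin(self, batch):
        pass


class Recorder(Callback):
    """Hypothetical callback that records which hooks fired, in order."""

    def __init__(self):
        self.events = []

    def on_train_begin(self):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch):
        self.events.append(f"epoch_begin:{epoch}")

    def on_batch_begin(self, batch):
        self.events.append(f"batch_begin:{batch}")


def fit(n_epochs, batches, callbacks):
    """Toy training loop that only dispatches to the callbacks."""
    for cb in callbacks:
        cb.on_train_begin()
    for epoch in range(n_epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        for batch in batches:
            for cb in callbacks:
                cb.on_batch_begin(batch)
            # forward pass, loss, backward pass and optimizer step would go here


recorder = Recorder()
fit(n_epochs=2, batches=["a", "b"], callbacks=[recorder])
```

The loop stays the same no matter how many behaviours are attached. Each extra concern lives in its own class and is switched on by adding it to the callbacks list.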

Following the original implementation, I also have the DartsCell and DartsRnn modules (for the train phase) and their respective subclasses DartsCellSearch and DartsRnnSearch (for the train_search phase). The difference is that the other add-ons are put into their own callbacks for ease of understanding.

train_search loop — link

The train_search loop in the figure above is replaced by a learner with callbacks where you can quickly see an overview of what is required.

learner with callbacks — link

train_search has its own ArchParamUpdate callback, which holds all the logic for updating the alpha/architecture parameters described above. The optimizer for alpha (arch_opt; alpha is called arch_p in the code) is part of this callback.

Similarly, the train phase has a separate callback that triggers the ASGD optimizer; the conditions for the trigger live in that callback, separate from the training loop. Other tasks that train and train_search share, such as regularization, hidden state initialization, and saving and resuming training, also have their own callbacks.

For more details, you can check out my code and run the notebooks provided.