Source: Deep Learning on Medium
Differentiable Architecture Search (DARTS) by Hanxiao Liu et al. is an algorithm that automates architecture design for neural networks. It was originally implemented in pure PyTorch. This post touches on the key ideas behind DARTS and shows how I reimplemented it using fastai for clarity and ease of use.
The code is at https://github.com/tinhb92/rnn_darts_fastai.
Using AWD-LSTM as the backbone, the goal is to find a good RNN cell to fill in the “recurrent” part of the model (RNNModel.rnns). Different configurations are tried, and the most suitable RNN cell is extracted at the end.
The algorithm has two phases:
- train_search: search for a good RNN cell genotype
- train: derive the genotype, train it to convergence and evaluate it on the test set
The first phase, searching for a good RNN cell genotype, consists of a continuous relaxation (i.e. creating a mixed operation) and alternately updating the architecture parameters alpha and the weights theta using gradient descent.
Instead of using only one operation (tanh, sigmoid …) at a node in the RNN cell, we apply several operations and take a weighted sum of their outputs. In the experiment, there are 5 operations: none, tanh, relu, sigmoid, identity. The weights given to each of these 5 operations at each node are learnable parameters.
In PyTorch, we use torch.rand() to initialize these weights and set requires_grad = True. The authors call them alpha, or architecture parameters, to distinguish them from theta, the normal parameters of the network. Updating theta requires the usual forward and backward passes.
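The mixed operation can be illustrated with a minimal sketch in plain Python (no framework, scalar inputs only). It assumes the architecture parameters are turned into weights with a softmax, as in the DARTS paper; the names OPS, softmax and mixed_op are made up for this example and are not taken from the repository.

```python
import math

# Candidate operations named in the post; each maps a float to a float.
OPS = {
    "none": lambda x: 0.0,
    "tanh": math.tanh,
    "relu": lambda x: max(0.0, x),
    "sigmoid": lambda x: 1.0 / (1.0 + math.exp(-x)),
    "identity": lambda x: x,
}

def softmax(logits):
    """Turn raw architecture parameters into weights that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def mixed_op(x, alpha):
    """Weighted sum of every candidate operation applied to x."""
    weights = softmax(alpha)
    return sum(w * op(x) for w, op in zip(weights, OPS.values()))

# Equal logits give every operation the same weight (1/5 each).
alpha = [0.0] * len(OPS)
y = mixed_op(0.5, alpha)
```

In a real search cell the same idea is applied to tensors, and alpha is a learnable tensor updated by its own optimizer.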
The gradient for alpha is described in Equation (7) of the paper:
ω’ denotes the weights for a one-step forward model.
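For reference, the expression has the following form, with ω standing for the network weights (called theta in this post) and ξ for the learning rate of the virtual one-step update. The equation number may differ between versions of the paper, so check it against your copy.

```latex
\nabla_{\alpha} \mathcal{L}_{val}(\omega', \alpha)
  \;-\; \xi \, \nabla^{2}_{\alpha, \omega} \mathcal{L}_{train}(\omega, \alpha) \,
        \nabla_{\omega'} \mathcal{L}_{val}(\omega', \alpha),
\qquad
\omega' = \omega - \xi \, \nabla_{\omega} \mathcal{L}_{train}(\omega, \alpha)
```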
The first part of Eq(7) is calculated with 1 forward and 1 backward pass on validation data using ω’. The 2nd part of Eq(7) is calculated using finite difference approximation:
Evaluating the finite difference requires only 2 forward passes for the weights and 2 backward passes for alpha.
The complexity is reduced from O(|alpha||theta|) to O(|alpha| + |theta|).
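The finite difference trick can be checked on a toy problem. The sketch below approximates the product of the mixed second derivative of the training loss with a vector v by evaluating the gradient with respect to alpha at two perturbed weight vectors, w + eps·v and w − eps·v, and dividing the difference by 2·eps. The toy loss, the function names and the step size scaling (r divided by the norm of v) are assumptions chosen for illustration; v stands in for the validation gradient with respect to ω’.

```python
import math

def grad_alpha_train(w, alpha):
    """Analytic gradient of the toy training loss with respect to alpha.

    Toy loss (an assumption for this sketch):
        L_train(w, alpha) = alpha[0] * (w[0]**2 + w[1]*w[2]) + alpha[1] * sin(w[2])
    It is linear in alpha, so the gradient does not depend on alpha itself.
    """
    return [w[0] ** 2 + w[1] * w[2], math.sin(w[2])]

def hessian_vector_fd(w, alpha, v, r=1e-2):
    """Approximate (d^2 L_train / d alpha d w) @ v with a central difference."""
    eps = r / math.sqrt(sum(x * x for x in v))   # step scaled by ||v||
    w_plus = [wi + eps * vi for wi, vi in zip(w, v)]
    w_minus = [wi - eps * vi for wi, vi in zip(w, v)]
    g_plus = grad_alpha_train(w_plus, alpha)     # gradient w.r.t. alpha at w+
    g_minus = grad_alpha_train(w_minus, alpha)   # gradient w.r.t. alpha at w-
    return [(gp - gm) / (2 * eps) for gp, gm in zip(g_plus, g_minus)]

def hessian_vector_exact(w, v):
    """Exact product for the toy loss, used only to check the approximation."""
    return [2 * w[0] * v[0] + w[2] * v[1] + w[1] * v[2], math.cos(w[2]) * v[2]]

w = [0.3, -1.2, 0.7]
alpha = [0.5, -0.1]
v = [1.0, 2.0, -0.5]   # stands in for the validation gradient w.r.t. the weights
approx = hessian_vector_fd(w, alpha, v)
exact = hessian_vector_exact(w, v)
```

Only two gradient evaluations are needed, regardless of how many weights there are, which is where the saving over forming the full mixed second derivative comes from.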
After training, we select the operation with the highest probability (ignoring none) to convert the continuous relaxation into a single discrete specification of the cell. This genotype is then trained to convergence and evaluated on the test set.
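The selection step can be sketched as follows, assuming the probabilities come from a softmax over each row of alpha. The alpha values are invented for the example, and the sketch only covers picking an operation; the original derivation also decides which predecessor node feeds each node, which is not shown here.

```python
import math

PRIMITIVES = ["none", "tanh", "relu", "sigmoid", "identity"]

def softmax(logits):
    """Turn raw architecture parameters into probabilities."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def pick_operation(alpha_row):
    """Return (name, probability) of the most probable operation, skipping 'none'."""
    probs = softmax(alpha_row)
    candidates = [(p, name) for p, name in zip(probs, PRIMITIVES) if name != "none"]
    best_p, best_name = max(candidates)
    return best_name, best_p

# Hypothetical learned logits for two mixed operations.
alpha = [
    [2.0, 0.1, 1.5, -0.3, 0.4],   # 'none' has the largest logit but is ignored
    [-1.0, 0.9, 0.2, 0.3, 0.1],
]
genotype = [pick_operation(row)[0] for row in alpha]
```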
Implementing with fastai
The 2 phases, train_search and train, have their training loops coded in pure PyTorch in the original implementation. This is fine if the loop is not too complicated, but it can get messy quickly and make the code hard to decipher. fastai provides a callback system that splits the components of the training loop into separate parts for clarity.
You can specify what to do at each stage of training: on_train_begin, on_epoch_begin, on_batch_begin, on_loss_begin … and so on without cluttering the code. Find out more by reading the docs!
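To show the idea without depending on any library, here is a toy callback dispatcher in plain Python. It is not fastai's API: the hook names mirror the ones listed above, while the Callback base class, the state dictionary and the fit function are hypothetical stand-ins written for this sketch.

```python
class Callback:
    """Base class: every hook is a no-op unless a subclass overrides it."""
    def on_train_begin(self, state): pass
    def on_epoch_begin(self, state): pass
    def on_batch_begin(self, state): pass
    def on_batch_end(self, state): pass
    def on_train_end(self, state): pass

class Recorder(Callback):
    """Example add-on: logs which hooks fire, without touching the loop."""
    def __init__(self):
        self.events = []
    def on_train_begin(self, state):
        self.events.append("train_begin")
    def on_epoch_begin(self, state):
        self.events.append(f"epoch_{state['epoch']}")
    def on_batch_end(self, state):
        self.events.append(f"batch_{state['batch']}")

def fit(n_epochs, batches, callbacks):
    """Bare training loop that only dispatches to callbacks."""
    state = {"epoch": 0, "batch": 0}

    def call(hook):
        for cb in callbacks:
            getattr(cb, hook)(state)

    call("on_train_begin")
    for epoch in range(n_epochs):
        state["epoch"] = epoch
        call("on_epoch_begin")
        for i, _ in enumerate(batches):
            state["batch"] = i
            call("on_batch_begin")
            # the forward pass, loss and optimizer step would go here
            call("on_batch_end")
    call("on_train_end")

rec = Recorder()
fit(n_epochs=2, batches=["b0", "b1"], callbacks=[rec])
```

The loop itself never changes; adding or removing behaviour means adding or removing a callback from the list.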
Following the original implementation, I also have the DartsCell and DartsRnn modules (for the train phase) and their respective subclasses DartsCellSearch and DartsRnnSearch (for the train_search phase). The difference is that the other add-ons are put into their own callbacks for ease of understanding.
The train_search loop in the figure above is replaced by a learner with callbacks where you can quickly see an overview of what is required.
The train_search phase has its own ArchParamUpdate callback, which holds all the logic for updating the alpha/architecture parameters described above.
The optimizer (arch_opt) for alpha (arch_p in the code) is part of this callback.
Similarly, the train phase has a separate callback that triggers the ASGD optimizer; the conditions for the trigger live in that callback, separate from the training loop. Other tasks that train and train_search share, such as regularization, hidden state initialization, and saving and resuming training, also have their own callbacks.
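As an illustration of what such a trigger condition can look like, the sketch below uses a non-monotonic criterion in the style of AWD-LSTM: switch once the current validation loss is worse than the best loss recorded before the last nonmono epochs. This is an assumption about the kind of rule involved, not a copy of the callback in the repository; the function name and the default value of nonmono are hypothetical.

```python
def should_switch_to_asgd(history, current, nonmono=5):
    """Decide whether to switch from SGD to ASGD.

    history: validation losses of previous epochs, oldest first
    current: validation loss of the epoch that just finished
    Returns True when `current` is worse than the best loss recorded
    before the last `nonmono` epochs, i.e. progress has stalled.
    """
    if len(history) <= nonmono:
        return False                      # not enough epochs to judge yet
    return current > min(history[:-nonmono])

# Loss has stopped improving for several epochs: switch.
stalled = should_switch_to_asgd([4.0, 4.1, 4.2, 4.1, 4.3, 4.2], 4.15)
# Loss is still going down: keep using SGD.
improving = should_switch_to_asgd([5.0, 4.8, 4.6, 4.4, 4.2, 4.0], 3.9)
```

Keeping this check inside a callback means the training loop does not need to know that an optimizer switch can happen at all.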
For more details, you can check out my code and run the notebooks provided.