Source: Deep Learning on Medium

In a previous article, "Fast Inference for Deep Learning Models", we discussed augmenting deep neural network models with branches for fast inference.

To recap, one can add exit branches to an existing model so that inference can terminate early, which makes it faster. In the figure, the main branch is the original model, and the early-exit branch is the extra layers added to the original model.

To run fast inference with this augmented model, we pass a sample through to Exit 1 and obtain a confidence value. To decide whether the sample should exit at that branch, we look at the confidence of the branch's softmax output.

If the confidence value is high enough according to a set threshold, the sample exits at that branch; otherwise, it continues to the main branch exit, and the final result is the average of the outputs from both exits.
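As a minimal sketch of this exit rule for a single sample, here is a plain-Python version. It rests on two assumptions of ours: the confidence is the largest softmax probability, and "the average of the outputs" means averaging the two exits' softmax probabilities. The names `early_exit_predict` and `run_main_branch` are hypothetical; the callback stands in for running the rest of the network and is only invoked when Exit 1 is not confident enough.

```python
import math
from typing import Callable, List, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def confidence(probs: List[float]) -> float:
    """Assumed confidence measure: the largest class probability."""
    return max(probs)


def early_exit_predict(
    exit1_logits: List[float],
    run_main_branch: Callable[[], List[float]],
    threshold: float,
) -> Tuple[int, str]:
    """Return (predicted class, exit used) for one sample."""
    p1 = softmax(exit1_logits)
    if confidence(p1) > threshold:
        # Confident enough: stop here and skip the rest of the network.
        return p1.index(max(p1)), "exit1"
    # Not confident: run the main branch, then average both exits' probabilities.
    p_main = softmax(run_main_branch())
    avg = [(a + b) / 2 for a, b in zip(p1, p_main)]
    return avg.index(max(avg)), "main"


# Exit 1 is confident (max probability ≈ 0.96 > 0.9), so the main branch never runs.
print(early_exit_predict([4.0, 0.0, 0.0], lambda: [0.0, 3.0, 0.0], threshold=0.9))  # (0, 'exit1')
```

Passing the main branch as a callback keeps the expensive computation lazy, which is the whole point of exiting early.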

In short, a sample exits at a branch when its confidence is higher than a threshold T. The remaining issue is how to pick this threshold, or a set of thresholds when there are multiple augmented branches. In this article, we discuss a method for finding such a threshold or set of thresholds.

Our goal is to minimize the error rate of each exit branch while maintaining the overall accuracy of the augmented model.

Let E be the number of samples that exit at a branch, O of which are predicted incorrectly by that branch. Given a set of N test samples, I of which are predicted incorrectly by the augmented model, we then want to minimize

(O / E) + (I / N) * (1 - E / N)
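To make the objective concrete, here is a small Python helper that evaluates it from the four counts. The function name `exit_objective` and the handling of E = 0 (a threshold at which no sample exits, where O / E is undefined) are our own assumptions, and the counts in the usage line are made up for illustration.

```python
import math


def exit_objective(E: int, O: int, I: int, N: int) -> float:
    """Evaluate (O / E) + (I / N) * (1 - E / N) for one exit branch.

    E: number of test samples that exit at the branch
    O: how many of those E samples the branch predicts incorrectly
    I: how many of all N test samples the augmented model predicts incorrectly
    N: total number of test samples
    """
    if E == 0:
        # Assumption: O / E is undefined when nothing exits; rank such a
        # threshold as the worst possible choice, since it gives no speed-up.
        return math.inf
    branch_error = O / E      # error rate among samples that exit here
    not_exited = 1 - E / N    # fraction of samples that continue past the branch
    return branch_error + (I / N) * not_exited


# Hypothetical counts: 600 of 1,000 samples exit, 12 of them wrongly,
# and the augmented model gets 80 samples wrong overall.
print(exit_objective(E=600, O=12, I=80, N=1000))  # ≈ 0.052
```

Read term by term, the first part is the branch's own error rate, and the second is the overall error rate weighted by the fraction of samples that do not exit at this branch.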

We want samples to exit as early as possible to achieve fast inference, so we optimize the branches in order, starting with the earliest: we find the threshold for that branch, freeze it, and then optimize the later branches.
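The branch-by-branch procedure can be sketched in plain Python as follows. This is a minimal illustration under several assumptions of ours: each test sample comes with precomputed softmax outputs at every exit (main exit last), confidence is the largest softmax probability, a sample that no branch exits averages the outputs of all exits (generalizing the two-exit rule above), branches not yet optimized get an unreachable threshold so they never exit, and thresholds are picked from a fixed candidate grid. The names `predict`, `evaluate`, `objective`, and `greedy_thresholds` are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Probs = List[float]
# One test sample: (softmax outputs at every exit, main exit last; true label).
Sample = Tuple[List[Probs], int]


def argmax(p: Sequence[float]) -> int:
    return max(range(len(p)), key=lambda i: p[i])


def predict(outputs: List[Probs], thresholds: List[float]) -> Tuple[int, int]:
    """Return (exit index, predicted class) for one sample."""
    for k, t in enumerate(thresholds):
        if max(outputs[k]) > t:  # assumed confidence: largest softmax probability
            return k, argmax(outputs[k])
    # No branch was confident: average all exits' outputs (assumed generalization).
    n_classes = len(outputs[0])
    avg = [sum(o[c] for o in outputs) / len(outputs) for c in range(n_classes)]
    return len(outputs) - 1, argmax(avg)


def evaluate(samples: List[Sample], thresholds: List[float]) -> Tuple[List[int], List[int]]:
    """Count, per exit, how many samples leave there and how many of those are wrong."""
    n_exits = len(samples[0][0])
    exits, wrong = [0] * n_exits, [0] * n_exits
    for outputs, label in samples:
        k, pred = predict(outputs, thresholds)
        exits[k] += 1
        wrong[k] += int(pred != label)
    return exits, wrong


def objective(E: int, O: int, I: int, N: int) -> float:
    """(O / E) + (I / N) * (1 - E / N); infinite when nothing exits (our assumption)."""
    return math.inf if E == 0 else O / E + (I / N) * (1 - E / N)


def greedy_thresholds(samples: List[Sample], candidates: Sequence[float]) -> List[float]:
    """Tune branches earliest-first, freezing each threshold before moving on."""
    n_branches = len(samples[0][0]) - 1
    thresholds = [math.inf] * n_branches  # not-yet-tuned branches never exit
    N = len(samples)
    for k in range(n_branches):
        best_t, best_score = thresholds[k], math.inf
        for t in candidates:
            trial = thresholds[:k] + [t] + thresholds[k + 1:]
            exits, wrong = evaluate(samples, trial)
            # I = all mistakes of the augmented model under the trial thresholds.
            score = objective(exits[k], wrong[k], sum(wrong), N)
            if score < best_score:
                best_t, best_score = t, score
        thresholds[k] = best_t  # freeze this branch
    return thresholds
```

With a candidate grid such as `[0.90, 0.95, 0.98]`, `greedy_thresholds` returns one threshold per augmented branch. Because earlier thresholds are frozen, each later branch is tuned only on the samples that the earlier branches let through.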

This procedure gives the following result.

The figure above shows the result of this optimization for an 8-layer AlexNet-style CNN with 2 augmented branches on the FashionMNIST dataset. The blue brute-force curve is obtained by trying many different sets of thresholds. The red star marks the set of thresholds found by our method: 0.98 for the first branch and 0.96 for the second. The purple "Original" dot is the original model before branch augmentation. With branch augmentation and a method for finding exit thresholds, we can cut the average inference time in half.
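A brute-force baseline like the blue curve can be approximated by sweeping every combination of candidate thresholds and recording cost and accuracy for each one. The sketch below, again plain Python under our own assumptions, reuses the same assumed exit rule, uses a hypothetical `exit_costs` list (the relative cost of running the network up to each exit) as a stand-in for inference time, and keeps only the Pareto-optimal points, i.e. combinations that no other combination beats on both cost and accuracy. All names here are hypothetical.

```python
import itertools
from typing import List, Sequence, Tuple

Probs = List[float]
Sample = Tuple[List[Probs], int]  # (softmax outputs at every exit, main exit last; true label)
Point = Tuple[Tuple[float, ...], float, float]  # (thresholds, average cost, accuracy)


def argmax(p: Sequence[float]) -> int:
    return max(range(len(p)), key=lambda i: p[i])


def predict(outputs: List[Probs], thresholds: Sequence[float]) -> Tuple[int, int]:
    """Assumed exit rule: the first branch whose max probability exceeds its
    threshold wins; otherwise average the outputs of all exits."""
    for k, t in enumerate(thresholds):
        if max(outputs[k]) > t:
            return k, argmax(outputs[k])
    n_classes = len(outputs[0])
    avg = [sum(o[c] for o in outputs) / len(outputs) for c in range(n_classes)]
    return len(outputs) - 1, argmax(avg)


def brute_force_curve(samples: List[Sample], candidates: Sequence[float],
                      exit_costs: Sequence[float]) -> List[Point]:
    """Evaluate every combination of per-branch thresholds.

    exit_costs[k] is a hypothetical cumulative cost of running the network
    up to exit k (main exit last), used as a proxy for inference time.
    """
    n_branches = len(exit_costs) - 1
    N = len(samples)
    points = []
    for combo in itertools.product(candidates, repeat=n_branches):
        cost, correct = 0.0, 0
        for outputs, label in samples:
            k, pred = predict(outputs, combo)
            cost += exit_costs[k]
            correct += int(pred == label)
        points.append((combo, cost / N, correct / N))
    return points


def pareto_front(points: List[Point]) -> List[Point]:
    """Keep points not dominated by any other point (no worse on both cost
    and accuracy, strictly better on at least one), sorted by cost."""
    def dominates(q: Point, p: Point) -> bool:
        return q[1] <= p[1] and q[2] >= p[2] and (q[1] < p[1] or q[2] > p[2])
    return sorted((p for p in points if not any(dominates(q, p) for q in points)),
                  key=lambda p: p[1])
```

The grid grows exponentially with the number of branches (`len(candidates) ** n_branches` combinations), which is why a sequential, branch-by-branch search is attractive when there are many exits.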

We will soon release a PyTorch library that lets you augment your own model and find a suitable set of thresholds for it. Please stay tuned, and let us know what you think or if you have ideas to share. Thank you!