Test Time Augmentation (TTA) and how to perform it with Keras

Source: Deep Learning on Medium

Data Augmentation

By Nathan Hubens

There are many ways to improve the results of a neural network by changing the way we train it. In particular, Data Augmentation is a common practice to virtually increase the size of the training dataset, and it is also used as a regularization technique, making the model more robust to slight changes in the input data.

Data Augmentation is the process of randomly applying some operations (rotations, zooms, shifts, flips, …) to the input data. This way, the model never sees exactly the same example twice and has to learn more general features of the classes it has to recognize.

Example of Data Augmentation on the CIFAR10 dataset

However, there are also ways to improve the results of a model by changing the way we test it, and this is where Test Time Augmentation (TTA) comes into play…

What is Test Time Augmentation?

Similar to what Data Augmentation does to the training set, the purpose of Test Time Augmentation is to apply random modifications to the test images. Thus, instead of showing the regular, “clean” images to the trained model only once, we will show it augmented versions of them several times. We will then average the predictions over all the augmented versions of each image and take that average as our final guess.
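The procedure for a single image can be sketched in a few lines of NumPy. This is a minimal, framework-agnostic sketch under stated assumptions: `random_augment`, `tta_single_image` and `dummy_predict` are hypothetical names, the augmentation (a random horizontal flip plus a shift of up to 2 pixels) is only an example choice, and `dummy_predict` is a toy stand-in for a trained model.

```python
import numpy as np

def random_augment(image, rng, max_shift=2):
    """Hypothetical augmentation: random horizontal flip plus a small random shift."""
    if rng.random() < 0.5:
        image = image[:, ::-1]  # mirror left-right (axis 1 is the width of an H x W x C image)
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    # np.roll wraps pixels around the border; real pipelines usually pad or fill instead.
    return np.roll(image, shift=(int(dy), int(dx)), axis=(0, 1))

def tta_single_image(predict_fn, image, n_aug=5, seed=0):
    """Average the class probabilities predicted for n_aug augmented copies of one image."""
    rng = np.random.default_rng(seed)
    batch = np.stack([random_augment(image, rng) for _ in range(n_aug)])
    probs = predict_fn(batch)   # expected shape: (n_aug, n_classes)
    return probs.mean(axis=0)   # shape: (n_classes,)

# Hypothetical stand-in for a trained model: softmax over the first 10 pixel values.
def dummy_predict(batch):
    logits = batch.reshape(len(batch), -1)[:, :10]
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)

image = np.random.default_rng(1).random((32, 32, 3))
avg_probs = tta_single_image(dummy_predict, image, n_aug=5)
print("averaged prediction:", avg_probs.round(3), "-> class", avg_probs.argmax())
```

Any function that maps a batch of images to a batch of probability vectors can be passed as `predict_fn`, which is why the model itself does not need to change.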

But all of this will be clearer with an example. Let’s take a neural network trained on CIFAR10 that is presented with the following test image:

The test image (a boat)

Here is the prediction of the model, expressed as its confidence that the given image belongs to the possible classes (the highest score corresponds to the predicted label):

And here is the true label of the image:

As we can see, the model outputs a wrong answer: it is most confident that the image belongs to the second class (corresponding to cars), but the correct answer is the ninth class (corresponding to boats).

What happens now if we apply Test Time Augmentation? We will present 5 slightly modified versions of the same image and ask the network to predict the class of each of them.

Modified versions of the test image

And here are the corresponding predictions:

Prediction 1
Prediction 2
Prediction 3
Prediction 4
Prediction 5

As you can see, only predictions 1 and 4 are correct with reasonable confidence. Prediction 2 is correct only by a very small margin, while 3 and 5 are incorrect. What happens now if we take the mean of those 5 results?

Average of the 5 predictions

Now the mean gives the correct answer, with reasonable confidence. So instead of the wrong answer we got from the original test image, we now have a correct one.

This works because, by averaging our predictions on randomly modified images, we are also averaging out their errors. The error can be large in a single prediction vector, leading to a wrong answer, but once the vectors are averaged, only the correct answer stands out.
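The error-averaging argument can be checked with a toy simulation. This is an illustrative sketch, not the article's experiment: it assumes that each augmented prediction equals a fixed "true" probability vector plus independent zero-mean Gaussian noise. That is a simplification, since real TTA errors are partly correlated (they come from the same model looking at the same image). `error_rates` is a hypothetical helper and the probability vector is made up for illustration.

```python
import numpy as np

def error_rates(p_true, noise_std=0.15, n_aug=5, n_trials=10_000, seed=0):
    """Argmax error rate of a single noisy prediction vs. the mean of n_aug noisy predictions."""
    rng = np.random.default_rng(seed)
    correct = np.argmax(p_true)
    # Each trial holds n_aug independently perturbed copies of the same prediction vector.
    noisy = p_true + rng.normal(0.0, noise_std, size=(n_trials, n_aug, len(p_true)))
    single_err = np.mean(np.argmax(noisy[:, 0], axis=-1) != correct)
    mean_err = np.mean(np.argmax(noisy.mean(axis=1), axis=-1) != correct)
    return single_err, mean_err

# Made-up 10-class vector: class 8 (0-based, "ship" in CIFAR10) is only slightly
# ahead of class 1 ("automobile"), i.e. an image the model is unsure about.
p = np.full(10, 0.02)
p[8], p[1] = 0.45, 0.39
single, averaged = error_rates(p)
print(f"single prediction error: {single:.3f}, 5-prediction average error: {averaged:.3f}")
```

Averaging n independent noise terms divides their standard deviation by the square root of n, so a small margin in favor of the correct class is much less likely to be overturned after averaging.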

Test Time Augmentation is particularly useful for test images about which the model is quite unsure. Even though those 5 images look very similar to us, the model's predictions show that, for the model, they are very different.

How to use it with Keras?

Test Time Augmentation can easily be used with Keras, even though it is not explicitly mentioned in the documentation.

The first step is to create a very simple Convolutional Neural Network that we will train on CIFAR10 for this demonstration:
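The original architecture is not reproduced here; below is a minimal sketch of a small CNN for 32×32 RGB CIFAR10 images, assuming TensorFlow 2.x (`tf.keras`). The layer sizes are illustrative choices and `build_model` is a hypothetical helper name. Sparse categorical cross-entropy is used because `tf.keras.datasets.cifar10.load_data()` returns integer labels rather than one-hot vectors.

```python
import tensorflow as tf
from tensorflow.keras import layers, models

def build_model(num_classes=10, input_shape=(32, 32, 3)):
    """A deliberately small CNN: two convolutional blocks and a dense classifier."""
    model = models.Sequential([
        layers.Input(shape=input_shape),
        layers.Conv2D(32, 3, padding="same", activation="relu"),
        layers.Conv2D(32, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Dropout(0.25),
        layers.Conv2D(64, 3, padding="same", activation="relu"),
        layers.Conv2D(64, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Dropout(0.25),
        layers.Flatten(),
        layers.Dense(512, activation="relu"),
        layers.Dropout(0.5),
        # Softmax output: one confidence score per class, summing to 1.
        layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model

model = build_model()
model.summary()
```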

We can then define the augmentation we want to perform on the training images by using the ImageDataGenerator class:
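A possible augmentation setup is sketched below. The parameter values are assumptions, not the original ones; they follow the CIFAR10 advice given later in this article (small shifts and rotations plus horizontal flips, but no vertical flips). It assumes TensorFlow 2.x, where `ImageDataGenerator` is available under `tf.keras.preprocessing.image`; note that recent TensorFlow releases mark this class as deprecated in favor of Keras preprocessing layers.

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Pixel values are assumed to be scaled to [0, 1] beforehand, so no `rescale`
# is set here: this lets the same generator be reused on test images later.
train_datagen = ImageDataGenerator(
    rotation_range=15,        # random rotations of up to 15 degrees
    width_shift_range=0.1,    # horizontal shifts of up to 10% of the width
    height_shift_range=0.1,   # vertical shifts of up to 10% of the height
    zoom_range=0.1,           # random zoom factor in [0.9, 1.1]
    horizontal_flip=True,     # a mirrored horse is still a horse
    vertical_flip=False,      # an upside-down ship is not a realistic input
    fill_mode="nearest",      # how pixels uncovered by shifts/rotations are filled
)
```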

We can now train the network for a few epochs:
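Training can then draw freshly augmented batches from the generator. A minimal sketch, assuming a compiled Keras model and an `ImageDataGenerator` like the ones described in the previous steps are passed in; `train_with_augmentation` is a hypothetical helper name and the epoch count and batch size are arbitrary. In TensorFlow 2.x, `model.fit` accepts the iterator returned by `datagen.flow` directly (older Keras versions used `fit_generator` for this).

```python
def train_with_augmentation(model, datagen, x_train, y_train, x_val, y_val,
                            epochs=10, batch_size=64):
    """Fit `model` on randomly augmented training batches; validate on clean images."""
    # The generator draws new random transformations for every batch, so the
    # model practically never sees exactly the same image twice.
    train_flow = datagen.flow(x_train, y_train, batch_size=batch_size)
    return model.fit(train_flow,
                     epochs=epochs,
                     validation_data=(x_val, y_val),
                     verbose=2)
```

For example, after loading CIFAR10 with `tf.keras.datasets.cifar10.load_data()` and dividing the pixel values by 255, one would call `train_with_augmentation(model, train_datagen, x_train, y_train, x_test, y_test)` and read the validation accuracy from the returned history.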

The final accuracy of the model on the validation images is:

To do Test Time Augmentation, we can reuse the same data generator that was used for training and apply it to the validation images.

We can then show the model the randomly modified images several times (10, for example), get the predictions for each pass, and take their average:
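A minimal sketch of this loop, assuming a trained Keras `model`, the generator used for training (called `train_datagen` here, a hypothetical name) and test images scaled the same way as the training data; `tta_predict` and `accuracy` are hypothetical helper names. Setting `shuffle=False` is essential: it keeps the augmented images in the same order as the input, so the predictions from the different passes line up image by image before they are averaged.

```python
import numpy as np

def tta_predict(model, datagen, x, n_aug=10, batch_size=128):
    """Average predictions over n_aug randomly augmented passes over x."""
    all_preds = []
    for _ in range(n_aug):
        # Each pass draws fresh random transformations but preserves the order of x.
        flow = datagen.flow(x, batch_size=batch_size, shuffle=False)
        all_preds.append(model.predict(flow, steps=len(flow), verbose=0))
    # (n_aug, n_images, n_classes) -> (n_images, n_classes)
    return np.mean(all_preds, axis=0)

def accuracy(probs, labels):
    """Fraction of images whose highest-scoring class matches the label."""
    return float(np.mean(np.argmax(probs, axis=-1) == np.ravel(labels)))
```

Usage would look like `accuracy(tta_predict(model, train_datagen, x_test, n_aug=10), y_test)`, to be compared with the plain `accuracy(model.predict(x_test), y_test)`. `np.ravel` is there because CIFAR10 labels come with shape `(n, 1)`.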

And now we get an accuracy of:

The final accuracy has increased by more than 3%, without changing anything in the model!

More data augmentation isn’t always better

While data augmentation is a very effective technique for getting better results, it must be used wisely. If used improperly, it can even hurt the accuracy of your model. Here is why:

Bad data augmentation on MNIST

Can you guess the digits represented in each image?

I’ll give you a hint, there is no 6 in the picture…

The kind of augmentation you should use depends on the data you have! In the case of MNIST, you certainly don’t want random flips or large rotations, as they can completely change the content of the image: a 6 can be turned into a 9 and vice versa, making it very difficult for your model to learn to tell those classes apart.

In a dataset like CIFAR10, however, you definitely want horizontal flips, as they don’t change the meaning of the image: a horse looking to the right or to the left is still a horse. Vertical flips, on the other hand, don’t make sense here, since your model is very unlikely to need to recognize an upside-down ship.

In some cases, such as satellite imagery or crop images, where turning an image upside down doesn’t change its meaning, you can use large rotations and vertical flips as data augmentation.
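One way to encode this advice is a small table of `ImageDataGenerator` keyword arguments per kind of data. The values below are illustrative assumptions, not tuned settings, and `AUGMENTATION_PRESETS` is a hypothetical name. The dictionaries are plain Python, so they can be inspected without TensorFlow and passed to the generator as `ImageDataGenerator(**AUGMENTATION_PRESETS["cifar10"])`.

```python
# Keyword arguments for tf.keras.preprocessing.image.ImageDataGenerator,
# chosen according to which transformations preserve the label.
AUGMENTATION_PRESETS = {
    # Digits: small shifts and rotations only. Flips and large rotations
    # could turn a 6 into a 9 (or make the digit unrecognizable).
    "mnist": dict(rotation_range=10, width_shift_range=0.1,
                  height_shift_range=0.1, zoom_range=0.1,
                  horizontal_flip=False, vertical_flip=False),
    # Natural photos: mirroring left/right keeps the class, turning upside down does not.
    "cifar10": dict(rotation_range=15, width_shift_range=0.1,
                    height_shift_range=0.1, horizontal_flip=True,
                    vertical_flip=False),
    # Overhead imagery: orientation carries no meaning, so any rotation or flip is allowed.
    "satellite": dict(rotation_range=180, width_shift_range=0.1,
                      height_shift_range=0.1, horizontal_flip=True,
                      vertical_flip=True, fill_mode="reflect"),
}
```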

In conclusion, data augmentation can be used to boost the results of your model both at training time and at test time, as long as you use it wisely.

Don’t hesitate to play with the ImageDataGenerator arguments to see how they affect your results!

I hope this article was clear and useful for new Deep Learning practitioners and that it gave you good insight into what Test Time Augmentation is! Feel free to give me feedback or ask me questions if something is not clear enough.

The code is available here: