Original article was published by Stan Kriventsov on Deep Learning on Medium
An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale (brief review of the ICLR 2021 paper)
In this post I would like to explain, without going into too many technical details, the significance of the new paper “An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale” submitted (so far anonymously due to the double-blind review requirements) to the 2021 ICLR conference. In my next post, I will provide an example of using this new model (called the Vision Transformer) with PyTorch to make predictions on the standard MNIST dataset.
Deep learning (machine learning using neural networks with more than one hidden layer) has been around since the 1960s, but it truly came to the forefront in 2012 when AlexNet, a convolutional network (in simple terms, a network that looks for smaller patterns in each part of the image and then tries to combine them into an overall picture) designed by Alex Krizhevsky, won the annual ImageNet image classification competition by a large margin.
Over the following years, deep computer vision techniques experienced a true revolution, with new convolutional architectures (GoogLeNet, ResNet, DenseNet, EfficientNet, etc.) appearing every year to claim a new accuracy record on ImageNet and other benchmark datasets (e.g. CIFAR-10, CIFAR-100). The image below shows the progress of the top-1 accuracy of machine learning models on the ImageNet dataset (the accuracy of predicting correctly what the image contains on the first try) since 2011.
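Top-1 accuracy is simple to compute: for each image, take the class with the highest predicted score and check whether it matches the true label. Here is a minimal plain-Python sketch. The function name `top1_accuracy` and the toy scores are illustrative, not taken from any benchmark.

```python
def top1_accuracy(scores, labels):
    """Fraction of images whose highest-scoring class equals the true label.

    scores: one list of per-class scores for each image
    labels: the true class index for each image
    """
    if len(scores) != len(labels):
        raise ValueError("need exactly one label per image")
    correct = 0
    for image_scores, label in zip(scores, labels):
        # The model's "first try" is the class with the highest score.
        predicted = max(range(len(image_scores)), key=image_scores.__getitem__)
        correct += predicted == label
    return correct / len(labels)


# Three toy images and four classes; the second prediction is wrong.
scores = [
    [0.1, 0.7, 0.1, 0.1],  # predicts class 1
    [0.5, 0.2, 0.2, 0.1],  # predicts class 0
    [0.0, 0.1, 0.1, 0.8],  # predicts class 3
]
labels = [1, 2, 3]
print(top1_accuracy(scores, labels))  # 2 of 3 correct
```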
In the last few years, however, the most interesting developments in deep learning have been happening not in the image domain, but in natural language processing (NLP) due to the Transformer model, first presented in the 2017 paper “Attention is All You Need” by Ashish Vaswani et al. The idea of attention, referring to trainable weights modeling the importance of each connection between different parts of the input sentence, had a similar effect on NLP to that of convolutional networks on computer vision, greatly improving the results of machine learning models on various linguistic tasks such as natural language understanding and machine translation.
The reason attention works particularly well on language data is that understanding human language often requires keeping track of long-range dependencies. For example, the first sentence of a paragraph may be “We arrived in New York”, and then a few (or more than a few) sentences later the author may say “The weather in the city was great”. To any human reader, it should be clear that “the city” in the last sentence refers to “New York”, but for a model that only finds patterns in nearby data (such as a convolutional network) this connection may be impossible to detect.
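A quick way to see the locality problem: with plain stride-1 convolutions (no pooling, striding, or dilation, which is an assumption of this sketch), each layer with kernel size k widens the window that one output position can see by k − 1 positions. The number of layers needed before two words can interact therefore grows linearly with their distance, while a single self-attention layer connects any two positions directly. The function names and the 40-word distance below are hypothetical.

```python
def receptive_field(num_layers, kernel_size=3):
    """Number of consecutive input positions that can influence one output
    position after stacking `num_layers` stride-1 convolutions."""
    return 1 + num_layers * (kernel_size - 1)


def layers_needed(distance, kernel_size=3):
    """Fewest stacked convolution layers before two words `distance`
    positions apart can influence the same output position."""
    layers = 0
    while receptive_field(layers, kernel_size) <= distance:
        layers += 1
    return layers


# "New York" ... "the city", assumed here to be 40 words apart.
print(layers_needed(40))                 # 20 layers with kernel size 3
print(layers_needed(40, kernel_size=5))  # 10 layers with kernel size 5
```

Real convolutional networks grow their receptive field faster with pooling or dilation, but the window still has to be built up layer by layer rather than being available from the start.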
The problem of long-range dependencies can be addressed by recurrent networks such as LSTMs, which were in fact the top models in NLP before the arrival of Transformers, but even they struggle to link specific words that are far apart, since information about earlier words has to be carried forward step by step through a fixed-size hidden state.
Global attention models in Transformers weigh the importance of each connection between ANY two words of the text, which explains their superior performance. For sequential data where attention is less important (for example, time-domain data such as daily sales numbers or stock prices) recurrent networks are still very much competitive and may be the best choice.
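The core computation can be sketched in a few lines of plain Python. This is scaled dot-product attention as defined in “Attention is All You Need”, with one simplification flagged in the code: the token vectors are used directly as queries, keys and values, instead of first being passed through learned projections. The 2-dimensional toy embeddings are made up; tokens 0 and 3 stand in for “New York” and “the city”.

```python
import math


def softmax(xs):
    m = max(xs)  # subtract the maximum for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(x):
    """Scaled dot-product self-attention over a list of token vectors.

    Simplification: the tokens are used directly as queries, keys and values;
    a real Transformer first applies learned linear projections to each.
    Returns (outputs, weights), where weights[i][j] says how much token i
    attends to token j. Every pair (i, j) gets a weight, however far apart.
    """
    d = len(x[0])
    outputs, weights = [], []
    for q in x:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in x]
        w = softmax(scores)
        weights.append(w)
        # Each output is a weighted average of all the value vectors.
        outputs.append([sum(wj * v[c] for wj, v in zip(w, x)) for c in range(d)])
    return outputs, weights


# Toy embeddings: tokens 0 and 3 are identical ("New York" / "the city"),
# with two unrelated tokens in between.
tokens = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]
outputs, weights = self_attention(tokens)
print([round(w, 3) for w in weights[0]])  # [0.335, 0.165, 0.165, 0.335]
```

Token 0 attends to the distant token 3 exactly as strongly as to itself, because attention compares contents rather than positions.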
While dependencies between distant elements may be of particular significance in sequence tasks such as NLP, they certainly cannot be neglected in image tasks either: to form a whole picture, it is often essential to have knowledge of all parts of the image.
The reason attention models haven’t been doing better in computer vision until now lies both in the difficulty of scaling them (the number of attention weights grows as N², where N is the number of inputs; a 1000×1000 image has a million pixels, so a full set of pairwise attention weights between them would have a trillion terms) and, perhaps more importantly, in the fact that, unlike words in a text, individual pixels in a picture are not very meaningful by themselves, so connecting them via attention does not accomplish much.
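The quadratic growth is easy to check with a back-of-the-envelope calculation. This sketch assumes one 4-byte (float32) weight per entry and counts a single attention matrix, i.e. one head in one layer; the function name `pixel_attention_cost` is made up for illustration.

```python
def pixel_attention_cost(height, width, bytes_per_weight=4):
    """Entries in ONE full pixel-to-pixel attention matrix (one head, one
    layer) and its memory footprint, assuming 4-byte float32 weights."""
    n = height * width  # every pixel becomes one input token
    entries = n * n     # one weight for every (query pixel, key pixel) pair
    return entries, entries * bytes_per_weight


entries, size = pixel_attention_cost(1000, 1000)
print(f"{entries:,} weights, {size / 1e12:.0f} TB")
# 1,000,000,000,000 weights, 4 TB
```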
The new paper suggests the approach of using attention not on pixels, but instead on small patches of the image (perhaps 16×16 as in the title, although the optimal patch size would really depend on the dimensions and the contents of the images to which the model is applied).
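To see how much patches help, here is a short sketch counting tokens and attention weights for a 224×224 input (a common ImageNet resolution) split into 16×16 patches. The function name, and the 7×7 patch size for 28×28 MNIST digits, are illustrative assumptions rather than values from the paper.

```python
def num_patches(height, width, patch_size):
    """How many non-overlapping patch_size x patch_size patches tile the image."""
    if height % patch_size or width % patch_size:
        raise ValueError("image sides must be divisible by the patch size")
    return (height // patch_size) * (width // patch_size)


n = num_patches(224, 224, 16)
print(n)                 # 196 patches (a 14 x 14 grid)
print(n ** 2)            # 38416 patch-to-patch attention weights
print((224 * 224) ** 2)  # 2517630976 if every pixel were a token

# A 28 x 28 MNIST digit with (hypothetical) 7 x 7 patches gives 16 tokens.
print(num_patches(28, 28, 7))  # 16
```

Going from pixels to 16×16 patches shrinks the attention matrix by a factor of 256² = 65,536, which is what makes global attention affordable.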
The picture above (taken from the paper) shows a schematic of how the Vision Transformer operates. Each patch of the image is flattened into a vector, which a trainable linear projection maps to the model's embedding dimension, and a positional embedding (a learnable vector containing information about where the patch originally was in the image) is added to it. This is necessary because attention by itself treats its inputs as an unordered set, so without this positional information the model could not tell where each patch came from. An extra learnable class token (labeled 0 in the picture) is prepended to the sequence of patch embeddings; its state at the output of the encoder serves as a summary of the whole image, from which the class is predicted.
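Here is a minimal sketch of that input pipeline in plain Python (no deep learning library), assuming single-channel images such as MNIST digits. The function names, the 7×7 patch size and the 8-dimensional embedding are illustrative choices, and the parameters are randomly initialized instead of trained. As in the paper, the positional embeddings are added to the patch embeddings, and the class token sits at position 0.

```python
import random


def image_to_patches(image, patch_size):
    """Split a 2D grayscale image (list of rows) into flattened patches,
    ordered left to right, top to bottom."""
    h, w = len(image), len(image[0])
    if h % patch_size or w % patch_size:
        raise ValueError("image sides must be divisible by the patch size")
    return [
        [image[r][c]
         for r in range(top, top + patch_size)
         for c in range(left, left + patch_size)]
        for top in range(0, h, patch_size)
        for left in range(0, w, patch_size)
    ]


def linear(vector, weight):
    """Multiply a length-d_in vector by a d_in x d_out weight matrix."""
    return [sum(v * weight[i][j] for i, v in enumerate(vector))
            for j in range(len(weight[0]))]


def embed_image(image, patch_size, d_model, rng):
    """Patch embedding + class token + positional embedding.

    The projection, class token and positional embeddings are trainable in a
    real model; here they are just randomly initialized for illustration.
    """
    patches = image_to_patches(image, patch_size)
    d_in = patch_size * patch_size
    projection = [[rng.gauss(0.0, 0.02) for _ in range(d_model)] for _ in range(d_in)]
    class_token = [rng.gauss(0.0, 0.02) for _ in range(d_model)]
    positions = [[rng.gauss(0.0, 0.02) for _ in range(d_model)]
                 for _ in range(len(patches) + 1)]  # +1 for the class token
    # The class token goes first (position 0), then the projected patches.
    tokens = [class_token] + [linear(p, projection) for p in patches]
    # Positional embeddings are ADDED element-wise, not concatenated.
    return [[t + p for t, p in zip(tok, pos)] for tok, pos in zip(tokens, positions)]


rng = random.Random(0)
image = [[(r * 28 + c) / 783 for c in range(28)] for r in range(28)]  # toy 28 x 28 image
tokens = embed_image(image, patch_size=7, d_model=8, rng=rng)
print(len(tokens), len(tokens[0]))  # 17 8  (16 patches + 1 class token, width 8)
```

For RGB images, each flattened patch would simply contain patch_size × patch_size × 3 values instead.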
The Transformer encoder, similarly to the original 2017 version, consists of multiple blocks of attention, normalization, and fully-connected layers with residual (skip) connections, as shown in the right part of the picture. In each attention block, multiple heads can capture different patterns of connectivity. If you are interested in learning more about Transformers, I would recommend reading this excellent article by Jay Alammar.
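To make the block structure concrete, here is a minimal plain-Python sketch of one encoder block with random weights. It follows the pre-norm ordering used in ViT (LayerNorm before the attention and MLP sub-layers, each wrapped in a residual connection), whereas the original 2017 Transformer normalized after each residual addition. To stay short it omits biases, LayerNorm's learned scale and shift, dropout, and the attention output projection; the parameter names (`wq`, `wk`, `wv`, `w1`, `w2`) and sizes are hypothetical. Each head attends using its own slice of the projected features, so different heads can form different attention patterns.

```python
import math
import random


def layer_norm(x, eps=1e-6):
    """Normalize a vector to zero mean and unit variance (no learned gain/bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def matvec(x, w):
    """Multiply a length-d_in vector by a d_in x d_out matrix."""
    return [sum(xi * w[i][j] for i, xi in enumerate(x)) for j in range(len(w[0]))]


def softmax(xs):
    m = max(xs)
    exps = [math.exp(v - m) for v in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gelu(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def multi_head_attention(tokens, wq, wk, wv, num_heads):
    d = len(tokens[0])
    if d % num_heads:
        raise ValueError("model width must be divisible by the number of heads")
    hd = d // num_heads  # width of each head
    q = [matvec(t, wq) for t in tokens]
    k = [matvec(t, wk) for t in tokens]
    v = [matvec(t, wv) for t in tokens]
    out = [[0.0] * d for _ in tokens]
    for h in range(num_heads):
        lo, hi = h * hd, (h + 1) * hd  # this head's slice of the features
        for i, qi in enumerate(q):
            scores = [sum(a * b for a, b in zip(qi[lo:hi], kj[lo:hi])) / math.sqrt(hd)
                      for kj in k]
            w = softmax(scores)  # this head's own attention pattern
            for c in range(lo, hi):
                out[i][c] = sum(wj * vj[c] for wj, vj in zip(w, v))
    return out


def encoder_block(tokens, p, num_heads):
    # Sub-layer 1: LayerNorm -> multi-head attention -> residual (skip) add.
    attn = multi_head_attention([layer_norm(t) for t in tokens],
                                p["wq"], p["wk"], p["wv"], num_heads)
    tokens = [[a + b for a, b in zip(t, o)] for t, o in zip(tokens, attn)]
    # Sub-layer 2: LayerNorm -> two-layer MLP with GELU -> residual add.
    out = []
    for t in tokens:
        hidden = [gelu(x) for x in matvec(layer_norm(t), p["w1"])]
        out.append([a + b for a, b in zip(t, matvec(hidden, p["w2"]))])
    return out


rng = random.Random(0)


def rand_matrix(rows, cols):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


d, mlp_dim, heads = 8, 16, 2
params = {"wq": rand_matrix(d, d), "wk": rand_matrix(d, d), "wv": rand_matrix(d, d),
          "w1": rand_matrix(d, mlp_dim), "w2": rand_matrix(mlp_dim, d)}
tokens = rand_matrix(5, d)  # e.g. a class token plus 4 patch embeddings
out = encoder_block(tokens, params, heads)
print(len(out), len(out[0]))  # 5 8: the block preserves the sequence shape
```

Because every block maps a sequence of tokens to a sequence of the same shape, blocks can be stacked as deep as desired; the residual connections let each block start out close to the identity.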
The fully connected MLP head, applied to the encoder's output for the class token, provides the desired class prediction. Of course, as always nowadays, the main model can be pre-trained on a large dataset of images and then fine-tuned for a specific task via the standard transfer learning approach, typically by replacing the MLP head with a new one sized for the target classes.
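The sketch below illustrates the two points above in plain Python: only the class token's final state (position 0) feeds the head, and for transfer learning the pre-trained head is swapped for a new one with as many outputs as the target task has classes. The ViT paper describes attaching a zero-initialized linear layer for fine-tuning; for simplicity the head here is that single linear layer rather than an MLP. The function names and toy encoder outputs are hypothetical.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(v - m) for v in xs]
    total = sum(exps)
    return [e / total for e in exps]


def classify(encoded_tokens, head):
    """Class probabilities from the encoder output.

    Only the final state of the class token (position 0) is fed to the head;
    the patch tokens' outputs are not used for the prediction.
    """
    cls = encoded_tokens[0]
    logits = [sum(x * head[i][j] for i, x in enumerate(cls)) for j in range(len(head[0]))]
    return softmax(logits)


def fresh_head(d_model, num_classes):
    """A new, zero-initialized linear head sized for the downstream task,
    which replaces the pre-trained head before fine-tuning."""
    return [[0.0] * num_classes for _ in range(d_model)]


# Encoder output for a class token plus 4 patches (toy values, width 4).
encoded = [[0.5, -1.0, 2.0, 0.0]] + [[0.1] * 4 for _ in range(4)]
probs = classify(encoded, fresh_head(d_model=4, num_classes=10))  # e.g. 10 MNIST digits
print(len(probs), probs[0])  # 10 0.1: a zero head starts out uniform
```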
According to the paper, the new model reaches the same accuracy as state-of-the-art convolutional networks with less pre-training computation, and its performance keeps improving as it is trained on more and more data, more so than that of the other models (conversely, when trained only on mid-sized datasets such as ImageNet, it falls a few percentage points short of ResNets of comparable size). The authors trained the Vision Transformer on Google's private JFT-300M dataset, which contains about 300 million (!) images, and achieved state-of-the-art accuracy on a number of benchmarks. One can hope that this pre-trained model will soon be released to the public so that we can all try it out.
It’s definitely exciting to see this new application of neural attention to the computer vision domain. Hopefully, a lot more progress will be achieved in the coming years based on this development!
In my next post (coming in a day or two) I will show how you can train the Vision Transformer on your computer from scratch (using PyTorch) to make predictions on the MNIST dataset of handwritten digits.