Deep Double Descent: when more data and bigger models are a bad thing

Source: Deep Learning on Medium


Resolving a fundamental conflict between classical statistics and modern ML advice


I recently came across a very interesting paper written at OpenAI on the topic of deep double descent. The paper touches on the very nature of training machine learning systems and model complexity. In this post, I hope to summarize the points made in the paper in an approachable way, and to advance the discussion of the tradeoffs between model size, data quantity, and regularization.

The Problem

There’s a fundamental conflict between classical statistical learning and modern ML practice. Classical statistics says that overly large models are bad, because complicated models are more prone to overfitting. One powerful principle frequently invoked in classical statistics is Occam’s razor, which in essence states that the simplest explanation is usually the right one.

This can be clearly explained with a visualization.

The green line is an example of a model overfitting the training data, while the black line is a simpler model that approximates the true distribution of the data.
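The classical side of this tradeoff is easy to reproduce. The sketch below (my own toy example, not from the paper) fits a low-degree and a high-degree polynomial to the same noisy samples of a quadratic: the high-degree fit interpolates every noisy point, like the green line, and pays for it with higher error against the true curve.

```python
import numpy as np

rng = np.random.default_rng(0)

# Noisy samples from a simple underlying function (a quadratic).
x_train = np.linspace(-1, 1, 10)
y_train = x_train**2 + rng.normal(scale=0.1, size=x_train.shape)
x_test = np.linspace(-1, 1, 100)
y_test = x_test**2  # noise-free ground truth

def fit_and_test_error(degree):
    # Fit a polynomial of the given degree to the noisy training data,
    # then measure mean squared error against the true curve.
    coeffs = np.polyfit(x_train, y_train, degree)
    preds = np.polyval(coeffs, x_test)
    return float(np.mean((preds - y_test) ** 2))

simple_err = fit_and_test_error(2)   # matches the true complexity
complex_err = fit_and_test_error(9)  # interpolates all 10 noisy points
```

With the degree-2 fit, the noise averages out; the degree-9 fit chases every noisy point and generalizes worse, which is exactly the classical picture.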

Despite this seemingly obvious tradeoff between complexity and generalizability, modern ML practice often holds that bigger models are better. The interesting thing about this claim is that, for the most part, it seems to be working. Research from some of the top AI research teams in the world, including teams at Google and Microsoft, indicates that deeper models are not saturating. In fact, with careful regularization and early stopping, it seems that often the best way to improve a model’s performance by a few points is simply to add more layers or collect more training data.

Deep Double Descent

The OpenAI paper provides a practical investigation of this contradiction between classical statistics and modern ML practice.

Empirical evidence shows that the truth of how modern machine learning systems work is a mixture of both classical statistics and modern theory.

Deep double descent is the phenomenon where test performance first improves, then gets worse as the model begins to overfit, and then finally improves again with increasing model size, data size, or training time. The graph above illustrates this behavior.

The deep double descent phenomenon has several implications regarding the complexity of models, quantity of data and training time.

Sometimes, bigger models are worse

Before the model reaches the interpolation threshold, there is a classical bias-variance tradeoff. Afterwards, the current wisdom of “larger models are better” applies.

When experimenting with ResNet18, the OpenAI researchers found an interesting tradeoff between bias and variance. Before a model’s complexity passed the interpolation threshold, or the point at which the model is just large enough to fit the training set, larger models had higher test error. However, once the model’s complexity allowed it to fit the entire training set, larger models with more data started performing far better.

It seems that there is a region of complexity where models are more prone to overfitting, but if enough complexity is captured in the model, the larger the better.
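The model-wise curve can be reproduced in miniature without deep networks. A common toy stand-in for this phenomenon (not the paper’s ResNet setup) is minimum-norm least squares on random ReLU features: sweeping the feature count past the number of training samples produces the same rise-then-fall in test error. The sketch below, with my own assumed dimensions and noise level, compares a small, a critical, and a large model.

```python
import numpy as np

rng = np.random.default_rng(1)
n_train, n_test, d = 40, 500, 10

# A simple linear teacher with label noise on the training set.
w_star = rng.normal(size=d)
X_train = rng.normal(size=(n_train, d))
y_train = X_train @ w_star + rng.normal(scale=0.5, size=n_train)
X_test = rng.normal(size=(n_test, d))
y_test = X_test @ w_star

def avg_test_error(width, trials=10):
    # Random ReLU features of the given width, fitted with the
    # minimum-norm least-squares solution (np.linalg.pinv),
    # averaged over several random feature draws.
    errs = []
    for _ in range(trials):
        W = rng.normal(size=(d, width)) / np.sqrt(d)
        F_train = np.maximum(X_train @ W, 0)
        F_test = np.maximum(X_test @ W, 0)
        beta = np.linalg.pinv(F_train) @ y_train  # min-norm fit
        errs.append(np.mean((F_test @ beta - y_test) ** 2))
    return float(np.mean(errs))

# width == n_train is the interpolation threshold, where the fit is
# just barely able to match the noisy labels and test error spikes.
small, critical, large = (avg_test_error(p) for p in (5, 40, 400))
```

The error typically peaks at the critical width and descends again well past it, mirroring the model-wise curve from the paper on a problem small enough to run in seconds.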

Sometimes, more samples are worse

There’s a point where models with more data actually perform worse on the test set. Again, however, there is a point near the interpolation threshold at which this reverses.

Interestingly, for models near or below the interpolation threshold, more training data can actually produce worse performance on the test set. However, as the model becomes more complex relative to the data, this tradeoff reverses, and the modern wisdom of “more data is better” begins to apply again.

One working hypothesis is that a model that is not complicated enough cannot capture everything in a too-large training set, and therefore cannot generalize well to unseen data. As the model becomes complex enough, however, it is able to overcome this limitation.
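The sample-wise effect shows up in the same random-features toy model (again, my own stand-in for the paper’s setup, with assumed dimensions): holding the model size fixed and growing the training set moves the problem into the critical regime where the number of samples matches the number of features, and test error gets worse before more data helps again.

```python
import numpy as np

rng = np.random.default_rng(2)
d, width = 10, 30

# Fixed linear teacher and a fixed random ReLU feature map.
w_star = rng.normal(size=d)
W = rng.normal(size=(d, width)) / np.sqrt(d)

def avg_test_error(n_train, trials=10):
    # Average test error of the minimum-norm least-squares fit
    # for a given training-set size, over several data draws.
    errs = []
    for _ in range(trials):
        X_tr = rng.normal(size=(n_train, d))
        y_tr = X_tr @ w_star + rng.normal(scale=0.5, size=n_train)
        X_te = rng.normal(size=(500, d))
        F_tr = np.maximum(X_tr @ W, 0)
        F_te = np.maximum(X_te @ W, 0)
        beta = np.linalg.pinv(F_tr) @ y_tr  # min-norm fit
        errs.append(np.mean((F_te @ beta - X_te @ w_star) ** 2))
    return float(np.mean(errs))

# n_train == width is the interpolation threshold for this fixed model:
# more data than n_train=10, yet typically worse test error.
few, critical, many = (avg_test_error(n) for n in (10, 30, 200))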

Sometimes, training longer undoes overfitting

In the graphs above, which plot error against the number of epochs trained, training and test error first decrease sharply as the number of epochs increases. Eventually, test error starts increasing as the model begins to overfit. Finally, test error decreases again as the overfitting is, somewhat miraculously, undone.

In the paper, the researchers call this epoch-wise double descent. They also note that this peak in test error occurs right at the interpolation threshold. The intuition here is that if a model is not very complex, there is only one model that fits the train data best; if that model fits noisy data, its performance will dip dramatically. However, if a model is complex enough to pass the interpolation threshold, there are many models that fit the train set, and training longer allows you to approximate one that also performs well on the test set. Why this occurs is an open research question, and one that is fundamentally important to the future of training deep neural networks.

If this research question is something that interests you, take a look at the paper and its accompanying summary, which inspired this post.