Source: Deep Learning on Medium
Deep Double Descent: when more data is a bad thing
Resolving a fundamental conflict between classical statistics and modern ML advice
I recently came across a very interesting paper from OpenAI on the topic of deep double descent. The paper touches on the very nature of training machine learning systems and model complexity. In this post, I hope to summarize the paper’s points in an approachable way, and to advance the discussion of the tradeoffs between model size, data quantity, and regularization.
There’s a fundamental conflict between classical statistics and modern ML practice. Classical statistics says that overly large models are bad, because complicated models are more prone to overfitting. In fact, one powerful principle frequently invoked in classical statistics is Occam’s razor, which in essence states that the simplest explanation consistent with the data is usually the right one.
This can be clearly explained with a visualization.
Despite this seemingly obvious tradeoff between complexity and generalizability, modern ML practice often holds that bigger models are better. The interesting thing about this advice is that, for the most part, it seems to work. Research from some of the top AI research teams in the world, including teams at Google and Microsoft, indicates that deeper models are not saturating. In fact, with careful regularization and early stopping, it seems that often the best way to improve a model’s performance by a few points is simply to add more layers or collect more training data.
Deep Double Descent
The OpenAI paper provides a practical investigation of this contradiction between classical statistics and modern ML practice.
Deep double descent is the phenomenon where performance first improves, then gets worse as the model begins to overfit, and then finally improves again as model size, data size, or training time keeps increasing. The graph above illustrates this behavior.
The deep double descent phenomenon has several implications regarding the complexity of models, quantity of data and training time.
Sometimes, bigger models are worse
When experimenting with ResNet18, the OpenAI researchers observed an interesting twist in the tradeoff between bias and variance. Before a model’s complexity passed the interpolation threshold (the point at which the model is just large enough to fit the training set), larger models had higher test error. However, once the model’s complexity allowed it to fit the entire training set, larger models started performing far better.
It seems that there is a region of complexity where models are especially prone to overfitting, but once a model captures enough complexity, the larger the better.
Sometimes, more samples are worse
Interestingly, for models below the interpolation threshold, it seems that more training data actually produces worse performance on the test set. However, as the model becomes more complex, this tradeoff reverses, and the modern wisdom of “more data is better” begins to apply again.
One working hypothesis is that less complicated models cannot capture everything in a training set that is too large for them, and therefore cannot generalize well to unseen data. Once the model becomes complex enough, however, it is able to overcome this limitation.
Sometimes, training longer undoes overfitting
In the graphs above, which plot error against the number of training epochs, training and test error first decrease sharply as the number of epochs increases. Eventually, test error starts increasing as the model begins to overfit. Finally, test error decreases again as the overfitting is, somewhat miraculously, undone.
In the paper, the researchers call this epoch-wise double descent. They also note that the peak in test error occurs right at the interpolation threshold. The intuition here is that if a model is not very complex, there is only one model that fits the training data best, and if that model fits noisy data, its performance will dip dramatically. However, if a model is complex enough to pass the interpolation threshold, there are several models that fit the train set and still do well on the test set, and training longer allows you to approximate one of these models. Why this occurs is an open research question, and one that is fundamentally important to the future of training deep neural networks.
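A minimal way to look for epoch-wise double descent is simply to log test error after every epoch. The sketch below does this for full-batch gradient descent on an overparameterized linear model with noisy labels; the dimensions, learning rate, and noise level are all assumptions, and whether the logged curve actually shows the descent-rise-descent shape depends on the setup. It is a harness for observing the phenomenon, not a reproduction of the paper’s experiments.

```python
import numpy as np

rng = np.random.default_rng(0)

# Overparameterized linear regression: more parameters than samples,
# with noisy labels (all sizes are illustrative assumptions).
n, p = 20, 200
X = rng.normal(size=(n, p)) / np.sqrt(p)
X_test = rng.normal(size=(500, p)) / np.sqrt(p)
theta_true = rng.normal(size=p)
y = X @ theta_true + 0.5 * rng.normal(size=n)   # noisy training labels
y_test = X_test @ theta_true                    # clean test labels

theta = np.zeros(p)
lr = 0.5
history = []
for epoch in range(2000):
    # Full-batch gradient descent on mean squared error.
    grad = X.T @ (X @ theta - y) / n
    theta -= lr * grad
    history.append(np.mean((X_test @ theta - y_test) ** 2))
# inspect `history` for the descent / rise / (possible) second descent;
# training error itself goes to zero since the model can interpolate
```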