The Curious Case of the Validation Loss Mismatch

Original article was published on Deep Learning on Medium


Yesterday I stumbled upon Greg Yang’s (@TheGregYang) Twitter thread on how the cross-entropy loss blows up on held-out data, even though the error rate keeps declining [1]. Having been mystified by this in the past, I read through the different responses hoping to get some answers, but it appears this phenomenon is still not well understood. Since I had some data of my own on this, I decided to share my results. Maybe a community effort will provide us with a better understanding of this phenomenon.

A few months ago I was looking into how the complexity of a deep network affects its optimization trajectory. Since I wanted to run quite a few training sessions, I decided to focus on a simple image classification task, one that would let me run many sessions and still have some compute to spare. Accordingly, I focused on the Fashion-MNIST dataset [2] and trained several versions of ResNet18 [3]. As in many network architectures, a ResNet widens with depth: each successive stage of residual blocks doubles the width (number of feature channels) of the previous one. Thus, one can think of the width of the first layer as a parameter that determines the complexity of the overall network. Table (1) specifies the initial widths of the networks I used, along with each network's overall number of parameters.

Table 1: ResNet18 width of the first convolutional layer and the overall number of parameters in the corresponding network. As the width increases, the complexity of the overall network grows rapidly (roughly quadratically in the width, so each doubling of the width roughly quadruples the number of parameters).
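To give a sense of how quickly the parameter count grows with the first-layer width, here is a rough count for a ResNet18 with a configurable base width. It is a sketch under assumptions not stated in the article: a torchvision-style layout (7×7 stem convolution, four stages of two BasicBlocks whose widths double from stage to stage, bias-free convolutions followed by BatchNorm), a single-channel input for Fashion-MNIST, and two outputs for the one-vs-all head. The author's exact variant (for example a CIFAR-style 3×3 stem) may differ, so these numbers need not match Table 1. The helper names are hypothetical.

```python
def conv_params(c_in: int, c_out: int, k: int) -> int:
    """Weights of a bias-free k x k convolution."""
    return c_in * c_out * k * k


def bn_params(c: int) -> int:
    """Learnable BatchNorm parameters (scale and shift); running statistics are buffers."""
    return 2 * c


def basic_block_params(c_in: int, c_out: int) -> int:
    """Two 3x3 conv + BN pairs, plus a 1x1 conv + BN shortcut when the width changes."""
    total = conv_params(c_in, c_out, 3) + bn_params(c_out)
    total += conv_params(c_out, c_out, 3) + bn_params(c_out)
    if c_in != c_out:  # projection shortcut at the start of stages 2-4
        total += conv_params(c_in, c_out, 1) + bn_params(c_out)
    return total


def resnet18_params(base_width: int, in_channels: int = 1, num_classes: int = 2) -> int:
    """Parameter count of a torchvision-style ResNet18 whose first conv has `base_width` channels."""
    total = conv_params(in_channels, base_width, 7) + bn_params(base_width)  # 7x7 stem
    c_in = base_width
    for stage in range(4):
        c_out = base_width * 2 ** stage  # width doubles from one stage to the next
        total += basic_block_params(c_in, c_out)   # first block (may change width)
        total += basic_block_params(c_out, c_out)  # second block
        c_in = c_out
    total += c_in * num_classes + num_classes  # final linear layer (weights + bias)
    return total


# Sanity check: the standard ImageNet ResNet18 (RGB input, 1000 classes) has 11,689,512 parameters.
assert resnet18_params(64, in_channels=3, num_classes=1000) == 11_689_512

for w in (8, 16, 32, 64, 128):  # illustrative widths, not necessarily those in Table 1
    print(f"first-layer width {w:4d}: {resnet18_params(w):>12,} parameters")
```

The 3×3 convolutions dominate the count, and each of them scales with the product of its input and output widths, which is why doubling the base width roughly quadruples the total.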

Since I originally wanted to examine the optimization trajectories of several losses, I decided to focus on a one-vs-all setting instead of multi-class classification. This would allow a close comparison of the cross-entropy loss with the logistic and hinge losses (for the cross-entropy loss I used a redundant softmax layer with two outputs). From preliminary results I knew that ResNet18 makes the most mistakes on the ‘shirt’ class, so I used it as the target class in all my experiments. For training, I used the Adam optimizer [4] with an adaptive learning-rate decay for 300 epochs. During training I tracked the cross-entropy loss and the error rate (the task loss), both on the training set and on a validation set made up of 20% of the data. Figure (1) depicts the cross-entropy loss and the error rate on both the training and the validation sets. Those who are interested in examining figure (1) in greater detail can follow this link to an interactive version of the figure (to zoom, use shift+mouse scroll).
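The two-output softmax is "redundant" in the sense that only the difference between its two logits matters: its cross-entropy equals the logistic loss of the single score z1 - z0. The sketch below shows this together with the one-vs-all relabelling. It assumes the standard Fashion-MNIST label order, in which ‘Shirt’ is class 6; the helper names are hypothetical and this is not the author's code. The equivalence concerns the loss value only: a two-unit output layer is still parametrized differently from a single sigmoid unit, which may affect the optimization trajectory.

```python
import math

SHIRT = 6  # index of the 'Shirt' class in the Fashion-MNIST label set


def one_vs_all(labels, target=SHIRT):
    """Map 10-class labels to binary labels: 1 for the target class, 0 for all others."""
    return [1 if y == target else 0 for y in labels]


def softmax_ce(z0, z1, y):
    """Cross-entropy of a two-output softmax with logits (z0, z1) and label y in {0, 1}."""
    m = max(z0, z1)  # subtract the max for numerical stability
    log_norm = m + math.log(math.exp(z0 - m) + math.exp(z1 - m))
    return log_norm - (z1 if y == 1 else z0)


def logistic_loss(s, y):
    """Logistic loss log(1 + exp(-margin)) of a single score s, with margin = +s or -s."""
    margin = s if y == 1 else -s
    if margin > 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))  # stable form for negative margins


print(one_vs_all([0, 6, 2, 6, 9]))  # [0, 1, 0, 1, 0]
for z0, z1, y in ((0.3, 1.2, 1), (2.0, -1.0, 0), (-0.5, 0.5, 0)):
    # The two-output softmax only "sees" the logit difference z1 - z0.
    print(softmax_ce(z0, z1, y), logistic_loss(z1 - z0, y))
```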

Figure 1: Cross-entropy loss and error rate (task loss) when training ResNet18 networks of varying complexity. Each panel depicts one of the two losses, measured at the end of every epoch. The panels in the upper row depict the losses measured on the training set, while the panels in the bottom row depict the losses measured on the validation set.

Since the cross-entropy is a surrogate that upper-bounds the error rate (exactly so when measured in bits; in nats, up to a factor of ln 2), the consistency of the two on the training set is far from surprising. On the validation set, however, the error rate keeps improving even while the networks overfit. Specifically, for the networks whose initial width is 8 or more, the validation cross-entropy starts to increase after only a few epochs. Furthermore, once the training cross-entropy reaches zero (top-left panel), the validation cross-entropy increases rather smoothly, only to subside after 200+ epochs. Surprisingly enough, this does not prevent the validation error rates of all the networks from decreasing.
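A toy calculation shows how the two validation curves can move in opposite directions: the error rate only counts on which side of 0.5 a prediction falls, while the cross-entropy heavily penalizes the confidence of the remaining mistakes. The probabilities below are made up for illustration and are not taken from the experiments.

```python
import math


def mean_cross_entropy(p_true):
    """Mean negative log-probability (in nats) assigned to the true class."""
    return sum(-math.log(p) for p in p_true) / len(p_true)


def error_rate(p_true):
    """Binary case: an example is misclassified when its true class gets probability < 0.5."""
    return sum(p < 0.5 for p in p_true) / len(p_true)


# Hypothetical probabilities assigned to the true class of ten validation
# examples at two points in training (illustrative numbers, not measured data).
early = [0.8] * 7 + [0.4] * 3      # three mistakes, none of them confident
late = [0.999] * 8 + [0.001] * 2   # one mistake fewer, but both remaining ones confident

for name, p in (("early", early), ("late", late)):
    # Measured in bits (divide by ln 2), the cross-entropy upper-bounds the error rate.
    assert mean_cross_entropy(p) / math.log(2) >= error_rate(p)
    print(f"{name}: error rate = {error_rate(p):.2f}, cross-entropy = {mean_cross_entropy(p):.3f}")
```

Going from `early` to `late`, the error rate drops from 0.30 to 0.20 while the mean cross-entropy roughly triples (from about 0.43 to about 1.38 nats), because a single confidently wrong example costs -ln(0.001) ≈ 6.9 nats.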

While the overall phenomenon is consistent with [1], there is one noteworthy exception: it appears that reducing the complexity of the network can prevent the validation loss from overshooting. In my experiments, the validation cross-entropy curves clearly form two ‘clusters’, and for the simpler networks the validation cross-entropy does not overshoot.

In addition, two more observations are worth pointing out. (I) After the training loss reaches zero, the validation loss of the complex networks drifts upwards rather smoothly and consistently. Since at this point the gradient is (numerically) almost zero, any update to the parameters probably results from the two running averages that the Adam optimizer maintains. (II) However, this drift does not scale monotonically with the networks’ complexity: the highest validation loss was measured for the network of width 8, and the loss of the network of width 64 converged to a lower value than that of the network of width 128. While these phenomena are of secondary importance, they might hold some clues about the dynamics that govern this validation loss mismatch.
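The explanation offered in point (I) can be illustrated with a scalar version of Adam's update rule, using the default hyper-parameters suggested in the Adam paper (the learning rate and schedule actually used in these experiments are not stated here). Because Adam divides the running average of gradients by the square root of the running average of squared gradients, its step size is roughly independent of the gradient's scale, so tiny gradients still produce steps close to the learning rate; and when the gradient becomes exactly zero, the momentum term keeps moving the parameter for a while. This is only a sketch of that mechanism, not evidence that Adam causes the drift.

```python
import math


def adam_steps(grads, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Return the Adam update applied at each step to a single scalar parameter."""
    m = v = 0.0
    steps = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g        # running average of gradients
        v = beta2 * v + (1 - beta2) * g * g    # running average of squared gradients
        m_hat = m / (1 - beta1 ** t)           # bias corrections
        v_hat = v / (1 - beta2 ** t)
        steps.append(lr * m_hat / (math.sqrt(v_hat) + eps))
    return steps


# Constant gradients of very different magnitudes: Adam's step stays close to lr,
# whereas a plain SGD step (lr * g) shrinks with the gradient.
for g in (1e-1, 1e-4, 1e-6):
    adam = adam_steps([g] * 100)[-1]
    print(f"gradient {g:.0e}: Adam step {adam:.2e}, SGD step {1e-3 * g:.2e}")

# Tiny gradients followed by exactly zero gradients: the momentum keeps the
# parameter moving, with steps that decay only gradually.
tail = adam_steps([1e-6] * 100 + [0.0] * 50)[100:]
print(f"first zero-gradient step {tail[0]:.2e}, last {tail[-1]:.2e}")
```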

Future plans: the research question I started with when performing these experiments morphed into something slightly different, and I never got around to repeating this experiment with the logistic and hinge losses. It would be really interesting to see whether this phenomenon occurs with the latter two, and especially with the hinge loss, as it does not involve any exponential transformation that might have something to do with this oddity. I hope to get to these experiments sometime soon. Drop me a line if you get there first 🙂
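One concrete difference between the hinge loss and the logistic (two-class cross-entropy) loss can be sketched for a single example with margin m, the signed score of the correct class: once m ≥ 1, the hinge loss and its gradient are exactly zero, while the logistic gradient only decays exponentially and keeps pushing correctly classified examples toward ever larger margins, that is, ever more confident predictions. Whether this difference matters for the validation loss mismatch is exactly the open question raised in the paragraph above; the helper names here are hypothetical.

```python
import math


def logistic_grad(margin):
    """Derivative of log(1 + exp(-margin)) w.r.t. the margin: -1 / (1 + exp(margin))."""
    return -1.0 / (1.0 + math.exp(margin))


def hinge_grad(margin):
    """A subgradient of max(0, 1 - margin): -1 below a margin of 1, exactly 0 from there on."""
    return -1.0 if margin < 1.0 else 0.0


for m in (0.5, 2.0, 10.0, 30.0):
    print(f"margin {m:5.1f}: logistic grad {logistic_grad(m):.2e}, hinge grad {hinge_grad(m):.1f}")
```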