What If Only Batch Normalization Layers Were Trained?


In sum, all three explanations focus on the normalization aspect of Batch Normalization. In contrast, we shall look at the shift-and-scale side of BN, realized by the γ and β parameters.

Reproducing the Paper

If an idea is any good, it should be robust to the implementation details and the choice of hyperparameters. In my code, I recreated the main experiment as minimally as possible, using TensorFlow 2 and my own choice of hyperparameters. In more detail, I tested the following proposition:

ResNet models can achieve decent results on the CIFAR-10 dataset with all weights locked, except for the batch normalization parameters.

Thus, I will be using Keras’ CIFAR-10 and ResNet modules, along with the usual recipe for the CIFAR-10 dataset: the categorical cross-entropy loss and a softmax output. My code downloads the dataset and a randomly initialized ResNet model, freezes the unwanted layers, and trains for 50 epochs using a batch size of 1024 images. You can inspect the code below:
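The original notebook is not embedded here, so what follows is a minimal sketch of the experiment as described above, using standard TensorFlow 2 / Keras calls (the ResNet-50 variant is used for illustration):

```python
import tensorflow as tf

# Load CIFAR-10 and one-hot encode the labels
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

# Randomly initialized ResNet-50 with a 10-way softmax head
model = tf.keras.applications.ResNet50(
    weights=None, input_shape=(32, 32, 3), classes=10)

# Freeze everything except the Batch Normalization layers
for layer in model.layers:
    layer.trainable = isinstance(layer, tf.keras.layers.BatchNormalization)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    loss='categorical_crossentropy',
    metrics=['accuracy'])

model.fit(x_train, y_train, epochs=50, batch_size=1024,
          validation_data=(x_test, y_test))
```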

A couple of things should be noted in the above code:

  1. The Keras API only has the ResNet-50, 101, and 152 models. To keep it simple, I have only used those. If you want to go deeper, refer to this guide for a custom implementation of the entire ResNet architecture.
  2. The ResNet model uses the ‘ones’ initialization strategy for the γ parameter. In our limited training scenario, this is too symmetric to be trained effectively by gradient descent. Instead, as suggested in the paper, the ‘he_normal’ initialization is used. For this, we re-initialize the Batch Normalization weights manually before training (see the sketch right after this list).
  3. The authors trained for 160 epochs, using a batch size of 128 images and an SGD optimizer with a momentum of 0.9. The learning rate was initially set to 0.01 and dropped to 0.001 and 0.0001 at epochs 80 and 120. For such a simple idea, I found this setup too specific. Instead, I used 50 epochs, a batch size of 1024, vanilla Adam, and a fixed learning rate of 0.01. If the idea is any good, this shouldn’t be a problem.
  4. The authors also used Data Augmentation, whereas I did not. Again, if the idea is any good, none of these changes should be a significant problem.
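Since the original snippet is not embedded here, the manual re-initialization mentioned in note 2 can be sketched roughly as follows, assuming the `model` object from the snippet above:

```python
import tensorflow as tf

he_normal = tf.keras.initializers.HeNormal()

# Replace the default 'ones' initialization of gamma with 'he_normal',
# as suggested in the paper; beta keeps its default 'zeros' init.
for layer in model.layers:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.gamma.assign(he_normal(shape=layer.gamma.shape))
```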

Results

Here are the results I have obtained with the above code:

Training accuracy for the ResNet models, training only the Batch Normalization layers
Validation accuracy for the ResNet models, training only the Batch Normalization layers

Numerically, the three models (ResNet-50, 101, and 152) achieved 50%, 60%, and 62% training accuracy and 45%, 52%, and 50% validation accuracy, respectively.

To get a good sense of how well a model is performing, we should always compare it to random guessing. The CIFAR-10 dataset has ten classes, so a random guess is right 10% of the time. The above models are about five times better than that, so we can consider them to have decent performance.

Interestingly, the validation accuracy took ten epochs to start increasing, a clear sign that, for the first ten epochs, the network was just overfitting the training data as best as it could. Later on, the validation performance rose substantially. However, it varied greatly every five epochs, which shows the model is not very stable.

In the paper, Figure 2 shows that they achieved validation accuracies of ~70%, ~75%, and ~77%. Considering that the authors did some tuning, used a custom training schedule, and employed Data Augmentation, this seems pretty reasonable and consistent with my findings, confirming the hypothesis.

Using an 866-layer ResNet, the authors reached ~85% accuracy, only a few percentage points below the ~91% achievable by training the whole architecture. Furthermore, they tested different initialization schemes and architectures, and tried unfreezing the last layer and the skip connections, which yielded some additional performance gains.

Besides accuracy, the authors also investigated the histograms of the γ and β parameters, finding that the network learned to suppress about a third of all activations in each BN layer by setting γ to near-zero values.
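As a rough way to check this in the reproduction, one could collect the learned γ values and count how many end up near zero. The snippet below is a sketch of that idea, again assuming the `model` from the earlier code and an arbitrary threshold of 0.05:

```python
import numpy as np
import tensorflow as tf

# Gather all learned gamma values across the BN layers
gammas = np.concatenate([
    layer.gamma.numpy().ravel()
    for layer in model.layers
    if isinstance(layer, tf.keras.layers.BatchNormalization)
])

# Fraction of channels that are effectively suppressed
# (the 0.05 threshold is an arbitrary choice for illustration)
suppressed = np.mean(np.abs(gammas) < 0.05)
print(f"{suppressed:.1%} of channels have a near-zero gamma")
```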

Discussion

At this point, you might ask: why all this? First of all, it’s fun 🙂 Second, BN layers are commonplace, yet we still have only a superficial understanding of their role; what we mostly know are their benefits. Third, this kind of investigation is what leads us to a more in-depth understanding of how our models operate.

I don’t believe this has practical applications by itself; no one will freeze their layers and leave it all to the BNs. However, it might inspire different training schedules. Training the network for a few epochs like this and only then training all weights might lead to superior performance. Conversely, the technique might prove useful for fine-tuning pre-trained models. I can also see this idea being leveraged for pruning the weights of big networks.

What puzzles me most about this study is how much we have all been ignoring these two parameters. I, at least, never paid much attention to either of them. I recall seeing only one discussion about them, which argued that it is good to initialize γ with ‘zeros’ in ResNet blocks to force the back-propagation algorithm to use the skip connections more in the early epochs.
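For reference, that trick usually amounts to something like the following when building a residual block by hand (a sketch under my own assumptions, not code from that discussion):

```python
import tensorflow as tf
from tensorflow.keras import layers

def residual_block(x, filters):
    """Basic residual block whose last BN starts with gamma = 0, so the
    residual branch contributes nothing at first and early gradients
    flow mostly through the skip connection."""
    shortcut = x
    y = layers.Conv2D(filters, 3, padding='same', use_bias=False)(x)
    y = layers.BatchNormalization()(y)
    y = layers.ReLU()(y)
    y = layers.Conv2D(filters, 3, padding='same', use_bias=False)(y)
    y = layers.BatchNormalization(gamma_initializer='zeros')(y)
    return layers.ReLU()(layers.Add()([shortcut, y]))
```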

My second question regards the SELU and SERLU activation functions, which have the self-normalizing property. Both functions make Batch Normalization layers obsolete, as they naturally normalize their outputs during training. Now, I wonder whether this captures the entirety of what a Batch Normalization layer does.

Finally, the hypothesis is still a bit primitive: it only considers the CIFAR-10 dataset and significantly deep networks. It remains open whether this can scale to other datasets or to different tasks, such as a BatchNorm-only GAN. Also, I would find it interesting to see a follow-up article on the role of γ and β in fully trained networks.