Accelerate Model Training With Batch Normalization

Source: Deep Learning on Medium

The Batch Normalization paper, published in 2015 by Sergey Ioffe and Christian Szegedy, took the deep learning community by storm. It quickly became one of the most widely implemented techniques in deep learning. Notably, its ability to accelerate training and reach the same accuracy in 14 times fewer training steps was a great catch, and that is what earned it the attention it gets today (who doesn’t want to train faster?). Several related techniques followed, such as layer normalization and instance normalization, which try to overcome some of the disadvantages of batch normalization.

Let’s dive in!

What Is Normalization?

Let’s start by understanding what typical input data normalization does. Though deep learning models don’t need much feature engineering, scaling input features to the same range helps the model train faster. Min-Max scaling is one such feature scaling method; it brings all the inputs into the 0 to 1 range. It’s given by the following expression.

$$x' = \frac{x - x_{\min}}{x_{\max} - x_{\min}}$$

Min-Max Normalization

It’s simply subtracting the least value in the data from each data point and dividing it by the range of the values in the data. The range is just the maximum value in the input data distribution minus the least value. You could intuitively try to digest how this improves model training with the following example. Consider a machine learning model that takes the square footage and the number of rooms in a house as input and predicts the price.

While the square footage of houses may be in the hundreds or thousands, the number of rooms would mostly be under 10. Owing to this vast difference in range between the input features, the model struggles to converge. This is because the model weighs features with larger values more heavily than those with smaller ones.

One more thing to notice is that a normalized feature column won’t have large variations in values like the unnormalized version. For example, a mansion may have thousands of square feet of land but a small house may be only a few hundred square feet. While the model should value houses with larger area more, too big a difference will hinder the model’s learning.
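To make the house example concrete, here is a minimal sketch of Min-Max scaling with NumPy. The house values are made up for illustration, and `min_max_scale` is a hypothetical helper name, not a library function. It also assumes no column is constant, since that would cause a division by zero.

```python
import numpy as np

# Hypothetical house data: columns are [square footage, number of rooms].
houses = np.array([
    [850.0, 2.0],
    [1200.0, 3.0],
    [2400.0, 4.0],
    [5000.0, 8.0],
])

def min_max_scale(features):
    """Scale each column to the 0 to 1 range: (x - min) / (max - min)."""
    col_min = features.min(axis=0)
    col_max = features.max(axis=0)
    return (features - col_min) / (col_max - col_min)

scaled = min_max_scale(houses)
print(scaled.min(axis=0))  # smallest value in each scaled column
print(scaled.max(axis=0))  # largest value in each scaled column
```

After scaling, both columns span the same range, so neither feature dominates purely because of its units.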

What’s Batch Normalization Then?

The Short Version

We saw how an ML model can benefit from normalized input features in the previous section. Now, let’s consider a fully connected deep neural network as our model. And, what if we added normalization for the hidden layers too? Technically, each hidden layer’s input is the previous layer’s output. It seems about right to apply normalization to that too, right? Yes. And that’s what batch normalization does.

The Long Version

In plain stochastic gradient descent, we feed one input at a time and update the model using backpropagation. In contrast, with mini-batch gradient descent, we pass a bundle of input examples forward and backward at once. This bundle is what we call the “batch”. Each hidden layer therefore sees batches of activations coming from the layer before. Let’s explore further with an example.

Fully Connected Network For MNIST

Consider a fully connected neural network with 2 hidden layers and let’s assume we’re training it on the MNIST digits dataset. In case you’re not familiar with MNIST, the dataset contains 28×28 sized images of digits from 0 to 9. The digit’s image is given as input to the model by ravelling it into a long 1D vector of size 784. While training in batches, multiple digit image vectors are stacked and fed. Given that, the input shape becomes (batch_size, 784).

Dimensions of input data

The illustration above uses a batch size of 2, and each input has 784 features (pixels). For simplicity, let’s assume that each hidden layer of our model has 50 neurons. The activations of the first hidden layer then have shape (2,50), since 2 is our batch size and 50 is the number of neurons. With a batch size of 64, the input shape would be (64,784) and the activation shape would be (64,50).

The shape of 1st Hidden Layer’s Activations

Next, let’s inspect the activation of a single neuron in the first hidden layer. A single neuron produces just one activation if only one input is given. However, if the inputs are given in batches, it produces activations of shape (batch_size,1): one activation for each input in the batch.
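The shapes described above can be checked with a short NumPy sketch. The input batch is random data standing in for flattened MNIST images, and the weights `W1` and biases `b1` are hypothetical values chosen only to illustrate the dimensions.

```python
import numpy as np

rng = np.random.default_rng(0)

batch_size, n_features, n_hidden = 64, 784, 50

# Stand-in for a batch of flattened 28x28 digit images (random, not real MNIST).
x = rng.random((batch_size, n_features))

# Hypothetical first hidden layer with arbitrary initial weights and zero biases.
W1 = rng.normal(scale=0.01, size=(n_features, n_hidden))
b1 = np.zeros(n_hidden)

z1 = x @ W1 + b1            # values of the first hidden layer for the whole batch
single_neuron = z1[:, :1]   # one column: one neuron's value for every input

print(x.shape)
print(z1.shape)
print(single_neuron.shape)
```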

Adding a Batch Normalization Layer

Normally, each layer in a network applies a non-linear activation function such as sigmoid after multiplying the input by the weights. If we insert a batch normalization layer, we have to add it before we apply the non-linearity. The batch normalization layer normalizes the activations by standardizing them.

$$\hat{x} = \frac{x - \mu}{\sigma}$$

μ is the mean and σ is the standard deviation

It subtracts the mean from the activations and divides the difference by the standard deviation. The standard deviation is just the square root of variance. In detail, we calculate the mean and standard deviation across the batch and subtract and divide respectively.
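The ordering described above (weights first, then normalization, then the non-linearity) can be sketched as follows. This is an illustration under stated assumptions rather than a full layer: the data and weights are random placeholders, `standardize` is a hypothetical helper, and the small constant `eps` is added for numerical stability, as is common in implementations, even though it is not discussed in the text.

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def standardize(z, eps=1e-5):
    """Subtract the batch mean and divide by the batch standard deviation.

    eps is an assumed small constant that guards against division by zero.
    """
    mu = z.mean(axis=0)                    # mean across the batch
    sigma = np.sqrt(z.var(axis=0) + eps)   # standard deviation across the batch
    return (z - mu) / sigma

rng = np.random.default_rng(0)
x = rng.random((64, 784))                    # stand-in batch of inputs
W = rng.normal(scale=0.01, size=(784, 50))   # hypothetical layer weights

z = x @ W                # 1. multiply the input by the weights
z_hat = standardize(z)   # 2. normalize across the batch
a = sigmoid(z_hat)       # 3. apply the non-linearity last
```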

You see, the layers adjust their parameters during backpropagation with the assumption that their input distribution stays the same. But all the layers are updated with backpropagation thereby changing each of their outputs.

This changes the inputs of the layers as each layer gets its inputs from the previous layer. This is called internal covariate shift. Hence, the layers experience a constantly changing input distribution. Importantly, this normalization method standardizes the distribution of the output activations.

Therefore it produces the desired distribution of activations that don’t vary too much when the weights change. In other words, it minimizes the internal covariate shift and aids the model to learn faster.

Although reducing internal covariate shift is often regarded as the main reason for batch normalization’s success, some papers attribute it to other factors. One such paper argues that the batch normalization layer smooths the loss landscape, which helps the model train faster.

Understanding With An Example

Again, let’s consider the activation of a single neuron for clarity and set the batch size to 64. Our two hidden layers have 50 neurons each. As we saw before, the output of a single neuron has shape (64,1). The batch normalization layer computes the mean (μ) and standard deviation (σ) of the 64 values in the batch.

This μ is subtracted from each of the 64 values, and the mean-subtracted activations are then divided by σ. Let’s visualize what the batch normalization layer does.

Note: This explanation covers a single neuron in a layer for simplicity. Batch normalization’s output for all the neurons differs from this in shape, but the process is the same.

The flow of activation of a single neuron through a batch normalization layer

If we consider all the neurons in a layer, the activation shape would be (64,50). The mean and variance are then calculated for each of the 50 neurons, leading to 50 means and 50 variances. Each column, corresponding to a single neuron, is normalized with its own mean and variance.
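As a rough check of the per-neuron statistics, the sketch below normalizes a (64,50) array column by column. The activations are random placeholders with an arbitrary offset and scale, chosen only so that the data starts out clearly unnormalized.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in activations for one hidden layer: 64 inputs in the batch, 50 neurons.
activations = rng.normal(loc=3.0, scale=2.0, size=(64, 50))

mu = activations.mean(axis=0)   # one mean per neuron
var = activations.var(axis=0)   # one variance per neuron

# Broadcasting applies each neuron's own statistics to its column.
normalized = (activations - mu) / np.sqrt(var)

print(mu.shape, var.shape)
print(normalized.shape)
```

Reducing over `axis=0` is what makes the statistics per neuron: the batch dimension is collapsed, and the neuron dimension is kept.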

Effects of Batch Normalization On Activations

Well, we’ve successfully normalized a hidden layer’s activations and all that’s left is applying the nonlinearity, right?

No. It does not end here!

Well, there’s a problem with normalizing the activations of a hidden layer. If we feed the normalized activations to a sigmoid, the expressive power of the model is reduced. To clarify, the normalized activations are brought to a distribution with zero mean and unit variance. Because of that, most of them fall in the nearly linear region of the sigmoid non-linearity. Let’s plot the function using matplotlib and Python.

    import numpy as np
    import matplotlib.pyplot as plt

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    x = np.arange(-10, 11, 1)  # creates a distribution [-10, -9, -8, ..., 8, 9, 10]
    x_sig = sigmoid(x)
    plt.plot(x, x_sig)     # draw the sigmoid curve
    plt.scatter(x, x_sig)  # mark each data point on the curve
    plt.show()

Consider this data distribution as the activations that we are going to normalize. The diagram below shows the sigmoid of the unnormalized data [-10 to 10].

Now, we’ll normalize the same data and see where the data points fall on the curve. And, here’s the Python code to do that.

    import numpy as np
    import matplotlib.pyplot as plt

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    x = np.arange(-10, 11, 1)  # creates a distribution [-10, -9, -8, ..., 8, 9, 10]
    x_sig = sigmoid(x)
    plt.plot(x, x_sig)  # plot the blue, unnormalized sigmoid curve

    x = (x - np.mean(x)) / np.std(x)  # standardize the data
    x_sig = sigmoid(x)
    plt.scatter(x, x_sig, color="red")  # plot the standardized values (red data points)
    plt.show()

The normalized points fall on the linear region of the sigmoid curve

Parameters of Batch Normalization

As we saw, batch normalization on its own doesn’t let the model take advantage of the sigmoid non-linearity. It makes the sigmoid layer almost linear, and we don’t want that! So, the authors of the paper provided a way to restore the representation power of the network. They introduced two parameters that scale and shift the normalized activations.

$$y = \gamma \hat{x} + \beta$$

The gamma (γ) and beta (β) parameters are learned with backpropagation. x_hat ($\hat{x}$) is the normalized activation

These two parameters are introduced so that they can undo the normalization done by the batch normalization layer. If gamma is set to σ and beta to μ, the layer fully restores the original activations. At first, this may seem counter-intuitive. But it lets the model adjust the normalization however it needs to in order to reduce the loss, and it brings back the expressive power that was lost when we normalized the activations.
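The claim that gamma and beta can undo the normalization is easy to verify numerically. In the sketch below the two parameters are set by hand purely for illustration (in a real layer they would be learned), the data is random, and `scale_and_shift` is a hypothetical helper name.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(64, 1))  # one neuron, batch of 64

mu = x.mean(axis=0)
sigma = x.std(axis=0)
x_hat = (x - mu) / sigma   # normalized activation

def scale_and_shift(x_hat, gamma, beta):
    """Apply the learned scale (gamma) and shift (beta) to normalized values."""
    return gamma * x_hat + beta

# gamma = 1 and beta = 0 leave the normalized activations unchanged.
y_identity = scale_and_shift(x_hat, gamma=1.0, beta=0.0)

# gamma = sigma and beta = mu recover the original activations.
y_restored = scale_and_shift(x_hat, gamma=sigma, beta=mu)

print(np.allclose(y_identity, x_hat))
print(np.allclose(y_restored, x))
```

Any setting between these two extremes is also available to the model, which is what gives it the freedom to pick the amount of normalization that reduces the loss.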

One of the key advantages of batch normalization is that it is fully differentiable. This means that you can backpropagate through it and train with gradient descent. As a result, gamma and beta are updated along with the other model parameters.

Batch Normalization For Convolutions

Batch normalization after a convolution layer is a bit different. Normally, the input to a convolution layer is a 4-D tensor of shape (batch, height, width, channels). Here, the batch normalization layer normalizes the tensor across the batch, height and width dimensions: for each channel, the mean and variance are computed across the other three dimensions. That gives C mean values and C variance values, which are then used for normalization.

For each channel in the tensor, the mean and variance are calculated across all the other dimensions

As a result, for each of the 64 channels in the above example (that is, an input of shape (10,200,200,64)), the mean and variance are calculated over 10×200×200 (400,000) values. This gives us 64 means and 64 variances to normalize each channel separately.
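A minimal NumPy sketch of the per-channel computation is shown below. It assumes the channels-last layout described above and uses a smaller, arbitrary spatial size (20×20 instead of 200×200) to keep the array light. The feature map is random placeholder data.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in feature map with shape (batch, height, width, channels).
feature_map = rng.normal(loc=1.0, scale=3.0, size=(10, 20, 20, 64))

# Reduce over batch, height and width, leaving one value per channel.
mu = feature_map.mean(axis=(0, 1, 2))
var = feature_map.var(axis=(0, 1, 2))

# Broadcasting applies each channel's statistics along the last axis.
normalized = (feature_map - mu) / np.sqrt(var)

print(mu.shape, var.shape)
print(normalized.shape)
```

The only difference from the fully connected case is the set of axes being reduced: three axes instead of one, so every spatial position in a channel shares the same mean and variance.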