Source: Deep Learning on Medium
Weight Initialization — The Fix
Last time we saw that correctly initializing weights is not about fanciness. Stacking 50 simple linear layers using a harmless-looking normal initialization is catastrophic. But why? What is happening inside those matrix multiplications that is shooting our variances to NaNs? And more importantly, HOW DO WE FIX IT?
Let’s quickly set up our coding environment to do some state-of-the-art experiments and recall what is happening. This time, instead of using MNIST, let’s just generate some random numbers as the input; this way we can generate as many samples as we want and get better numerical approximations.
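The setup cell isn’t reproduced in this export, so here is a minimal sketch of what it might look like: random inputs with 784 features (the MNIST image size, so the numbers stay comparable) pushed through one naively initialized linear layer. The sample count of 10,000 is an arbitrary choice.

```python
import torch

# Random data instead of MNIST: 10,000 samples, 784 features each.
x = torch.randn(10_000, 784)   # mean ~0, variance ~1
w = torch.randn(784, 784)      # the "harmless" normal initialization

a = x @ w                      # one linear layer (bias omitted)

print(x.var().item())          # ~1
print(a.var().item())          # ~784, already exploding
```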
Yeah, after the very first layer our variance is already 783… Do note something very interesting here though: 783 is very close to 784, and 784 is the number of inputs to our layer. This is not a mere coincidence; as someone once said:
God does not play dice with deep learning — Some important guy
Why, I hope you’re asking, why? Soon you’ll understand that all the answers lie in matrix multiplication, so take a minute to forget about this article and appreciate its beauty here.
I’m all about hand-waving math and that’s what I’m going to do here, so if you don’t like that, I’m… sorry?
(fun fact: this article was originally going to be named “Weight initialization for bad mathematicians”)
This little matrix multiplication is all our layer is doing. I’ve simplified our problem by taking a single data sample (left) that is multiplied by a single set of weights that will result in a single activation.
And this is the magical moment. The weights are being multiplied by the input to produce the output, note that this is just an element-wise multiplication followed by a sum.
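In code, a single activation really is just an element-wise multiplication followed by a sum; a quick sketch to confirm they give the same number:

```python
import torch

x = torch.randn(784)        # one data sample
w = torch.randn(784)        # one column of the weight matrix

a_manual = (x * w).sum()    # element-wise multiply, then sum
a_matmul = x @ w            # the same thing via matrix multiplication

print(torch.allclose(a_manual, a_matmul, atol=1e-4))
```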
This sum is what is causing our variance to explode. More specifically, the variance will be equal to the number of elements in the sum*. (This is also the number of inputs of our linear layer).
You see that pesky asterisk there? That is the same asterisk you see in commercials when something sounds too good to be true. Our statement only holds under very specific conditions: both things being multiplied must have mean 0 and variance 1, which fortunately is our case!
The rigorous mathematical proof is of course out of the scope of this article, but that does not stop us from waving our hands a lot to get an intuition about it. If you want to get a deeper understanding of what is happening here, I highly recommend these videos from Khan Academy.
To make things clearer, let’s use the power of code!
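The cell below is an assumed reconstruction of the original one: five samples drawn from a standard normal distribution (your exact numbers will of course differ from the ones quoted in the text).

```python
import torch

x = torch.randn(5)    # five samples from N(0, 1)
print(x)
print(x.mean().item(), x.std().item())
```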
These are our five randomly generated numbers; they are all very close, in fact the furthest one is only 1.41 units away from the mean.
Now, let’s generate a bunch of these numbers and add them all together to see what happens.
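A sketch of that experiment: 10,000 sums, each adding together 100 standard normal numbers (both counts are arbitrary choices), followed by the variance of those sums.

```python
import torch

# 10,000 sums of 100 standard normal numbers each.
sums = torch.randn(10_000, 100).sum(dim=1)

print(sums.var().item())   # ~100: the variance equals the number of terms in the sum
```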
They’re all over the place: one is -13.21, another is 10.38, both really far from the mean. Recall that standard deviation describes exactly that: “On average, how far away are my numbers from the mean?” And indeed, both the standard deviation and the variance (which is just the standard deviation squared) are much higher now.
I hope this is straightforward to understand. Try running sum([torch.randn(1) for _ in range(100)]) a bunch of times and observe the numbers you get. Sometimes you get big numbers because you happened to sample more large positive values, sometimes you get small ones because you sampled more negative values, and this increase in the variety of outcomes is exactly an increase in variance.
So how can we decrease the variance? To start, we can try dividing all values of x by 10 and see what happens.
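A sketch of that attempt, reusing the sums-of-100 setup from before:

```python
import torch

x = torch.randn(10_000, 100)

var_before = x.sum(dim=1).var()
var_after = (x / 10).sum(dim=1).var()

print(var_before.item())   # ~100
print(var_after.item())    # ~1: down by a factor of 100, not 10
```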
Almost there! Variance went down as we expected, but it went down by a factor of 100 instead of 10. Why? Why does the variance scale with the square of the number we divide by?
This is the formula of variance (stolen from Wikipedia). As said before, variance is just a measure of how far apart our numbers are from the mean. Yeah, all of that, but squared. Why?
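For reference, the (population) variance formula in question, where μ is the mean of the n values:

```latex
\mathrm{Var}(X) = \sigma^2 = \frac{1}{n}\sum_{i=1}^{n}\left(x_i - \mu\right)^2
```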
Just try to calculate it without having the square, any thoughts on what is going to happen?
Zero… I’m too lazy to write LaTeX to prove that mathematically right now, but don’t be as lazy as I am: take out pen and paper and try to prove it. It’s very simple, just write x = [a, b, c] and calculate the mean and the non-squared variance from there (spoiler alert: everything will cancel out).
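If pen and paper isn’t your thing either, code can cheat its way to the same conclusion:

```python
import torch

x = torch.randn(1_000)
mu = x.mean()

# Without the square, deviations above and below the mean cancel out.
no_square = (x - mu).sum()
with_square = ((x - mu) ** 2).mean()

print(no_square.item())    # ~0 (up to float error)
print(with_square.item())  # strictly positive: the actual variance
```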
So what is the genius idea brilliant minds of the past had to solve this issue? SQUARE IT, lol. This way all terms of the sum become positive and bye bye zero.
After all of that, we now understand that if we want the variance of x to be 10 times smaller, we have to divide all values of x by √10.
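A quick check of that claim:

```python
import math
import torch

x = torch.randn(100_000)
scaled = x / math.sqrt(10)   # divide by the square root of the target factor

print(x.var().item())        # ~1
print(scaled.var().item())   # ~0.1, i.e. 10 times smaller
```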
At this point I forgot our original problem. You probably didn’t, because it takes you 4 minutes of reading to get here but it took me 5 hours (maybe more, a lot more) of writing. Here it is again:
We now know how to manipulate the variance, we know how to scale the weights to achieve any target variance we like, but what would be the ideal target?
Remember that asterisk? One of the conditions was that the input variance should be one. By keeping it that way, all of our layers’ activations should (hopefully) have a variance of one, no matter how many layers we stack.
Very simple now huh? Variance is 783 but you would like it to be 1? What do you do? Divide it by √783!
Now we can try the same experiment as before and stack 50 layers. Remember that, as described and observed above, the expected variance of the output of our layer is equal to the number of inputs to that layer; we use this information to scale our weights accordingly and write this handy little function:
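The original function isn’t shown in this export; here is a hypothetical reconstruction (the name scaled_weights is mine) that divides a normal init by the square root of the number of inputs:

```python
import torch

def scaled_weights(n_in, n_out):
    # A plain normal init would give output variance ~n_in,
    # so divide by sqrt(n_in) to bring it back down to ~1.
    return torch.randn(n_in, n_out) / n_in ** 0.5
```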
And now if we stack 50 layers…
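A self-contained sketch of that experiment, with the scaled-init helper inlined so the cell runs on its own:

```python
import torch

def scaled_weights(n_in, n_out):
    # Normal init scaled by 1/sqrt(fan-in).
    return torch.randn(n_in, n_out) / n_in ** 0.5

x = torch.randn(10_000, 784)
for _ in range(50):                   # 50 stacked linear layers
    x = x @ scaled_weights(784, 784)

print(x.var().item())                 # stays in the vicinity of 1
```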
Horaaaay, it works!! The variance is not exactly 1, and it actually varies a little bit each time you run the cell, but it’s very far from exploding! Now let’s just add some ReLU and we are DONE!
(ReLU, the Rectified Linear Unit, is just the “cool kids” name for the “replace negatives with zero” function)
We only need to redefine forward to see what happens.
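A sketch of that change, with everything inlined again; forward is now the linear layer followed by ReLU:

```python
import torch

def scaled_weights(n_in, n_out):
    return torch.randn(n_in, n_out) / n_in ** 0.5

def forward(x, w):
    # Linear layer followed by ReLU (negatives replaced with zero).
    return (x @ w).clamp(min=0)

x = torch.randn(10_000, 784)
for _ in range(50):
    x = forward(x, scaled_weights(784, 784))

print(x.var().item())   # collapses towards 0 instead of exploding
```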
Failure! Our variance is no longer going to NaN, but 0 is not super great either. Let’s see what happens with the activations before and after we use ReLU.
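A sketch of that comparison on a single batch of activations: before ReLU they are centered at 0 with variance 1; clamping the negatives shifts the mean up and shrinks the variance, which compounds layer after layer.

```python
import torch

a = torch.randn(100_000)    # pre-ReLU activations: mean ~0, variance ~1
r = a.clamp(min=0)          # post-ReLU

print(a.mean().item(), a.var().item())   # ~0, ~1
print(r.mean().item(), r.var().item())   # mean shifts up, variance drops
```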