I’ve been trying to figure out ResNet, and I kept reading articles that all explain it in the same way, and I just didn’t quite get it! As such, I’ve decided to try to explain it using basic drawings made on my laptop trackpad (I didn’t actually use crayons, sorry).
Before I begin, let me briefly explain the concept of a ResNet block. In a traditional convolutional network, you might have a long string of convolutional layers (with batch norms and pooling layers mixed in), but with ResNet, we cluster our network into blocks of a few convolutional layers each (two or three in the original ResNet designs; I’ll use three here). What makes these blocks unique is that we add the block’s input (i.e. the output of the previous block) onto the result of the current block.
For example, suppose we want a block to produce the result Y. To solve this using a traditional conv net, we would have something like this:
F(x) = Y
Where F(x) could be a single layer or a set of layers. For a ResNet block, we would do this:
R(x) + x = Y
Where R(x) is the set of convolutional layers inside the ResNet block (three in this example). This adding of x forces R(x) to model the difference, or residual, between Y and x:
R(x) = Y – x
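To make the two formulations concrete, here is a rough sketch in plain Python. The names plain_block and residual_block, and the stand-in layers that simply halve their input, are made up for illustration; a real ResNet block would use convolutions, batch norm and activations instead.

```python
def apply_layers(x, layers):
    """Run a list of values through a stack of functions, in order."""
    for layer in layers:
        x = [layer(v) for v in x]
    return x


def plain_block(x, layers):
    """Traditional formulation: F(x) = Y, the layers produce Y directly."""
    return apply_layers(x, layers)


def residual_block(x, layers):
    """ResNet formulation: R(x) + x = Y, the layers only produce the residual."""
    r = apply_layers(x, layers)
    return [ri + xi for ri, xi in zip(r, x)]


# Hypothetical stand-in for the block's three layers: each one halves its input.
layers = [lambda v: 0.5 * v, lambda v: 0.5 * v, lambda v: 0.5 * v]

x = [8.0, -4.0, 2.0]
y = residual_block(x, layers)

# Subtracting the input from the block's output recovers what the layers
# themselves produced, i.e. R(x) = Y - x.
residual = [yi - xi for yi, xi in zip(y, x)]
```

Here the layers inside the block only ever output the small correction in residual, while the block as a whole still outputs the full Y.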
For a more technical and in-depth explanation of ResNet construction, check out this blog post by Michael Dietz.
So why is the ResNet block useful? As promised, I’ll try to explain it with drawings:
Suppose you are trying to model (draw) the following shape:
And your initial state (i.e. the output of the previous block/layer) is:
With a traditional conv net your problem might look something like this:
So we need to write some function, or in my example draw some picture, which matches the “model” image on the right-hand side as best we can.
If we were to model the same function using a ResNet block, our function would look something like this:
Subtracting our starting image (the first half of the squiggle) from both sides gives us:
Which simplifies to:
Why is this better than the conv function? It’s better because the ResNet function is simpler: we only need to draw half the shape rather than replicating the entire thing. You might be wondering why that’s such a big deal since, in the non-ResNet function, we still had knowledge of how to draw the first half of the squiggle as a starting point. The difference is that with the non-ResNet function we still had to draw both halves of the image, rather than just the missing segment, which exposes us to more human error or, if you were actually training a network, to the approximation error that accompanies any attempt to fit a complex pattern. To use another analogy, we all have the knowledge necessary to perform basic arithmetic, but we’d still rather avoid it because every time we do it in our heads, there’s a small chance that we’ll get it wrong.
Just as we might make mistakes every time we attempt to solve a problem, our neural network may make mistakes when trying to fit an underlying pattern. The more complex that pattern is, the tougher a time the network will have in fitting it.
In this way, each ResNet block only models whatever differs between the target pattern/image and the starting point. To take this to an extreme, suppose that our starting activations were a perfect match for our target activations, i.e. Y = x. In a traditional conv net with F(x) = Y, we would have to hope that F(x) is trained to be an identity operation. With the ResNet block, we already add on the initial state x in R(x) + x = Y, so rather than training R(x) to be an identity, we can just zero it out, which is a much easier pattern to fit.
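A tiny sketch of that extreme case, assuming a hypothetical one-input, one-output linear layer in place of real convolutions (the function names here are illustrative only):

```python
def linear(x, weight, bias):
    """A one-input, one-output linear layer: weight * x + bias."""
    return weight * x + bias


def plain_block(x, weight, bias):
    """F(x): the layer has to produce the whole target on its own."""
    return linear(x, weight, bias)


def residual_block(x, weight, bias):
    """R(x) + x: the layer only has to produce the difference from x."""
    return linear(x, weight, bias) + x


x = 3.0

# With every parameter zeroed out, the residual block is already an identity.
residual_out = residual_block(x, weight=0.0, bias=0.0)

# The plain block with the same zeroed parameters throws the input away...
plain_zeroed = plain_block(x, weight=0.0, bias=0.0)

# ...and only acts as an identity if its weight is trained to exactly 1.
plain_identity = plain_block(x, weight=1.0, bias=0.0)
```

Pushing parameters towards zero is something training does readily, whereas landing on an exact identity mapping means hitting one specific non-zero setting.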
In reality we probably won’t have many cases where our starting point is a perfect match for the target point, but in cases where we have some partial match, we may be able to zero out some weights rather than trying to train them to perfectly match whatever segments of our pattern we’ve already figured out.
The Vanishing Gradient Problem
Although the analogy above explains how ResNet might have an easier time finding fits than a traditional convolutional network, it’s not actually the reason (or at least not the primary reason) why ResNet was designed! ResNet was created to tackle the difficulties in training very deep neural networks by solving the vanishing gradient problem. To get an understanding of the vanishing gradient problem, I’d check out this Quora post. The three-sentence version is that some activation functions (e.g. sigmoid) squash their inputs into a narrow range (0, 1) and have small derivatives (never more than 0.25 for sigmoid), so each layer F(x) scales the gradient down. As the gradient is propagated backwards through the network, it shrinks a little at every layer until it is “vanishingly small” by the time it arrives at the earlier layers of the network. This can be mitigated by using different activation functions, such as ReLU, or by using the ResNet architecture.
To understand how ResNet solves this problem, think about how each ResNet block passes its input straight through to its end. The effect of this is that, rather than our gradients being shoved through a layer/block F(x) which may shrink their magnitude, they are sent through R(x) + x. Yes, R(x) will contribute a shrunken version of the gradient, but passing the gradient through x doesn’t shrink it at all, which allows backpropagation to maintain the original “signal” strength/magnitude.
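Here is a small numerical sketch of that effect, under some loud assumptions: every layer is a single scalar sigmoid unit, there are 20 of them, and they all share a made-up weight of 0.5. The function input_gradient is hypothetical and simply multiplies the per-layer derivatives together, which is what the chain rule in backpropagation amounts to for a chain of scalars.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sigmoid_grad(z):
    """Derivative of the sigmoid; never larger than 0.25."""
    s = sigmoid(z)
    return s * (1.0 - s)


def input_gradient(x, weights, skip):
    """Return d(output)/d(input) for a stack of scalar sigmoid layers.

    Each layer computes sigmoid(w * h); with skip=True it computes
    sigmoid(w * h) + h instead, mimicking R(x) + x.
    """
    h, grad = x, 1.0
    for w in weights:
        local = w * sigmoid_grad(w * h)  # derivative of sigmoid(w * h) w.r.t. h
        out = sigmoid(w * h)
        if skip:
            local += 1.0                 # the "+ x" path contributes exactly 1
            out += h
        grad *= local                    # chain rule across layers
        h = out
    return grad


weights = [0.5] * 20  # assumption: 20 layers, all with weight 0.5
plain_grad = input_gradient(1.0, weights, skip=False)
skip_grad = input_gradient(1.0, weights, skip=True)
```

In the plain stack every factor is at most 0.5 * 0.25 = 0.125, so plain_grad collapses towards zero after 20 layers. With the skip path every factor is at least 1, so skip_grad cannot shrink below 1 in this toy setup.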
To use another analogy, suppose you have a stick of crayon which you need to sharpen, but you’re using a great sword to do it. Each time you shear off a chunk of crayon, it’s going to shrink a little (or a lot) and after many clumsy swipes of your sword, you may only have a tiny nub of crayon left.
Now imagine that, after each of your sword swipes, you’re given a chunk of crayon as large as the one you started with, meaning that no matter how much you mangle the crayon you’re sharpening, you don’t lose out on total crayon mass. At this point the analogy sort of falls apart, since you’d have to somehow combine the supplementary crayons with your sharpened crayon in order for it to be of any use, but that’s what I get for trying to bring crayons back into the picture.
If this vanishing gradient solution isn’t making any sense to you, check out this great explanation by Hugh Perkins on Stack Exchange.
Does this help? Does it hurt? Let me know in the comments!
Source: Deep Learning on Medium