habib's blog

ResNet, So Simple Your Grandma Could Understand

Before going deep into the architecture, granny, let us recap the "vanishing gradient" problem. When training a network using backprop and gradient descent, we calculate gradients (derivatives) of the loss with respect to the weights. These gradients tell us how much we need to adjust each weight so that our loss gets as small as possible. But sometimes these gradients become so small that they barely update the weights at all, and training becomes slow as hell. We won't dive deep into the maths, but think of it like this: during backprop the chain rule gets applied, so the derivative of the loss with respect to an early weight is a long product of per-layer derivatives. If those derivatives are less than 1 and we keep multiplying them again and again, the result shrinks exponentially towards 0. That is how the gradients vanish as they move backward through the network.
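Here is a tiny sketch of that shrinking product. The per-layer derivative of 0.8 is just a made-up number to show the effect:

```python
# Toy illustration of vanishing gradients: the gradient reaching an early
# layer is (roughly) a product of one derivative per layer. If each of
# those is a bit below 1, the product collapses towards zero with depth.
d = 0.8  # assumed per-layer derivative (made up for illustration)

for depth in (10, 30, 50, 100):
    print(depth, d ** depth)

# 10  ~0.107
# 30  ~0.0012
# 50  ~1.4e-05
# 100 ~2.0e-09  -> essentially no signal left for the earliest layers
```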

Researchers found many ways to tackle this problem, and better normalization and initialization techniques largely kept gradients healthy, but then a new problem emerged: degradation. As the number of layers increased, the network's performance got worse, and not because of overfitting either, since even the training error went up for the deeper plain networks.

So what was the eureka moment for the authors of the ResNet paper? Instead of asking a set of layers to learn the complete mapping from scratch, what if we only made them learn the change, the small adjustment to the input? This small change is termed the residual. Think of it like this: H(x) is the entire desired mapping we want our layers to learn, and we rewrite it as

H(x) = F(x) + x, where F(x) is the residual function (the part we actually ask the layers to learn) and x is the original input that was passed in. Rearranged, F(x) = H(x) - x: literally the difference between the desired output and the input.

So the layers are only made to learn the small difference between the input and the desired output.
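A quick toy example of that decomposition (the numbers are made up, purely to show the idea):

```python
import numpy as np

# Residual decomposition H(x) = F(x) + x.
# Suppose the desired mapping barely changes the input:
x   = np.array([1.00, 2.00, 3.00])   # input to the block
H_x = np.array([1.05, 1.90, 3.10])   # desired output of the block

# The residual the layers actually have to learn:
F_x = H_x - x
print(F_x)  # roughly [ 0.05 -0.1   0.1 ] -> small values, close to zero
```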

Now how does this help sort our problem? Earlier our network had to learn the entire mapping H(x); now it only learns the small change that occurred to the input. And if no adjustment is needed at all, the layers can simply push the residual towards zero, which is much easier than learning to pass the input through a bunch of complex deep layers perfectly unchanged.

How to build this into a neural network? Well, simple as fuck ngl: just add a shortcut that skips over the layers in the block, a so-called skip connection. It ensures that the information (x) is always passed through, so the layers in the block only need to learn how to modify x, not how to create the entire output from scratch.
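Here is a minimal sketch of such a block in PyTorch. It is simplified compared to the blocks in the actual paper (no batch norm, no downsampling, and the channel count is just a placeholder):

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Two conv layers plus a skip connection: output = F(x) + x."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = self.relu(self.conv1(x))  # F(x): the small adjustment the layers learn
        residual = self.conv2(residual)
        return self.relu(residual + x)       # the skip connection adds x back in

block = ResidualBlock(channels=64)
x = torch.randn(1, 64, 32, 32)
print(block(x).shape)  # torch.Size([1, 64, 32, 32]) -- same shape, so x can be added directly
```

Notice that if conv1 and conv2 learned all-zero weights, the block would just pass x straight through (for non-negative inputs), which is exactly that "do nothing" case being so easy to represent.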

Now, to create a full network, we simply stack these residual blocks on top of each other to build a very deep neural network.
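A sketch of what that stacking could look like, reusing the ResidualBlock from above (the channel count, block count and 10-class head are arbitrary placeholders; real ResNets also change channels and spatial size between stages):

```python
class TinyResNet(nn.Module):
    """A small stack of residual blocks: stem -> blocks -> classifier head."""
    def __init__(self, channels=64, num_blocks=8, num_classes=10):
        super().__init__()
        self.stem = nn.Conv2d(3, channels, kernel_size=3, padding=1)
        self.blocks = nn.Sequential(*[ResidualBlock(channels) for _ in range(num_blocks)])
        self.head = nn.Linear(channels, num_classes)

    def forward(self, x):
        x = self.blocks(self.stem(x))  # every block only adds its small residual on top of x
        x = x.mean(dim=(2, 3))         # global average pooling over the spatial dims
        return self.head(x)

model = TinyResNet()
imgs = torch.randn(2, 3, 32, 32)  # a fake batch of 2 RGB images
print(model(imgs).shape)          # torch.Size([2, 10])
```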