Original article was published on Deep Learning on Medium
PyTorch backward() function explained with an Example (Part-1)
Let's understand what the PyTorch backward() function does. First we will work through the calculations with pen and paper to see what is actually going on behind the code, and then we will repeat the same calculations using PyTorch's .backward() functionality.
As an example, we will implement the following equation, which takes a matrix X as input and produces loss as output.
Let's assume that we have the following 2×2 matrix.
Now say we add the number 2 to this matrix. Because of broadcasting, the 2 is added to each of the 4 elements, giving the following result:
Now, following the diagram above, we square the result and multiply it by 3.
Next, a mean operator will add all the elements of the resulting matrix and divide by 4, the number of elements in the matrix.
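The three steps above (add 2, square and multiply by 3, take the mean) can be sketched in plain Python to make the arithmetic concrete. The values in X below are an assumption chosen for illustration and are not necessarily the ones in the article's figure.

```python
# Hypothetical 2x2 input; these values are assumed for illustration only.
X = [[1.0, 2.0],
     [3.0, 4.0]]

# Step 1: add 2 to every element (what broadcasting does with a scalar).
added = [[x + 2 for x in row] for row in X]

# Step 2: square each element and multiply by 3.
scaled = [[3 * a ** 2 for a in row] for row in added]

# Step 3: mean = sum of all elements divided by the number of elements (4).
loss = sum(sum(row) for row in scaled) / 4

print(added)
print(scaled)
print(loss)
```

With these assumed inputs, the intermediate matrices are [[3, 4], [5, 6]] and [[27, 48], [75, 108]], and the loss is 258 / 4 = 64.5.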
The following equation shows what is actually happening behind the scenes.
We can see that this loss is a function of 4 variables:
If we change any of these x variables, the loss changes, which means we can find the partial derivative of loss with respect to (w.r.t.) each xij. When we evaluate the partial derivative w.r.t. xij, we treat all other variables as fixed. For example, if we are differentiating the loss expression w.r.t. x11, we treat x12, x21 and x22 as fixed numbers.
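Written out from the steps described earlier (add 2, square, multiply by 3, take the mean over 4 elements), the loss and its partial derivative take the following form. This is reconstructed from the prose and may differ in notation from the article's own figures.

```latex
\text{loss} = \frac{1}{4} \sum_{i=1}^{2} \sum_{j=1}^{2} 3\,(x_{ij} + 2)^2
\qquad\Longrightarrow\qquad
\frac{\partial\,\text{loss}}{\partial x_{ij}}
  = \frac{1}{4} \cdot 3 \cdot 2\,(x_{ij} + 2)
  = \frac{3}{2}\,(x_{ij} + 2)
```

Only the term containing xij survives the differentiation, because the other three terms are constants with respect to xij.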
Now we represent the loss matrix as follows:
Using Equation (1) and expanding the above equation, we get
After substituting the values of x11, x12, x21 and x22 from matrix X (see Figure 1), we get
Solving it, we get
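A quick way to sanity-check a hand calculation like this is to compare the analytic derivative with a numerical one. The sketch below assumes the loss is mean(3 * (X + 2)^2), as described in the steps above, and reuses a hypothetical X; the helper name loss_fn is made up for this example.

```python
def loss_fn(X):
    """mean(3 * (X + 2)^2) over a 2x2 nested list (hypothetical helper)."""
    return sum(3 * (x + 2) ** 2 for row in X for x in row) / 4

# Hypothetical input values, assumed for illustration.
X = [[1.0, 2.0],
     [3.0, 4.0]]

# Analytic partial derivatives: d(loss)/d(x_ij) = 1.5 * (x_ij + 2).
analytic = [[1.5 * (x + 2) for x in row] for row in X]

# Central finite differences: nudge one x_ij while keeping the others fixed.
eps = 1e-6
numeric = [[0.0, 0.0], [0.0, 0.0]]
for i in range(2):
    for j in range(2):
        plus = [row[:] for row in X]
        minus = [row[:] for row in X]
        plus[i][j] += eps
        minus[i][j] -= eps
        numeric[i][j] = (loss_fn(plus) - loss_fn(minus)) / (2 * eps)

print(analytic)
print(numeric)
```

For the assumed X, the analytic gradient is [[4.5, 6.0], [7.5, 9.0]], and the numerical estimate should agree with it up to floating-point rounding.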
Now let's do the same calculations using PyTorch's built-in backward() function.
The following is the output from my run of the above code:
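A minimal sketch of how this might look in PyTorch is shown here. The input values are again an assumption for illustration, and the variable names x, y, z are chosen for this example rather than taken from the article.

```python
import torch

# Hypothetical 2x2 input; requires_grad=True asks autograd to track it.
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)

y = x + 2          # broadcasting adds 2 to every element
z = 3 * y ** 2     # square, then multiply by 3
loss = z.mean()    # sum of the 4 elements divided by 4

loss.backward()    # populates x.grad with d(loss)/d(x_ij)

print(loss.item())
print(x.grad)
```

After backward() runs, x.grad has the same shape as x and holds one partial derivative per element. For the assumed input, that is 1.5 * (x + 2) evaluated elementwise.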