The reparameterization trick with code example

Original article was published on Deep Learning on Medium

The first time I heard about this (well, actually, the first time I read about it…) I had no idea what it was, but hey! It sounded pretty cool!

Very likely, if you are reading this, you already have some idea about this topic: it’s a way to use backpropagation when there is random sampling in the middle of the network.

That’s the thing: you can’t backpropagate through picking random numbers. Randomness on a computer usually means a pseudo-random generator seeded from the internal clock, or reading the voltage of some component… some interaction with the world. There is no way to calculate the derivative of that with respect to the parameters of the distribution we are sampling from.

This happens with variational autoencoders, for example, and maybe with Bayesian networks? Things like that. It happens when:

  • We obtain a probability distribution. Another, less typical, example: noisy layers for DQN.
  • We sample from the distribution.
  • We use the sampled value as the input for another layer.

Then at some point we want to backpropagate. How do we propagate through a random number? ¯\_(ツ)_/¯
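To make the problem concrete, here is a small sketch assuming PyTorch (nothing above names a framework). `torch.distributions.Normal.sample()` draws its value without tracking gradients, so the sample is disconnected from `mu` and `sigma`, and a loss built on it (here a made-up squared error against 3.0) has nothing to backpropagate into:

```python
import torch
from torch.distributions import Normal

# Distribution parameters we would like to learn (e.g. the output of an encoder).
mu = torch.tensor(10.0, requires_grad=True)
sigma = torch.tensor(5.0, requires_grad=True)

# Plain sampling: PyTorch draws this value inside torch.no_grad(),
# so the sample carries no link back to mu or sigma.
z = Normal(mu, sigma).sample()
print(z.requires_grad)  # False

loss = (z - 3.0) ** 2  # hypothetical loss on the sampled value
try:
    loss.backward()
except RuntimeError as err:
    # There is no graph to walk back through: the random draw broke it.
    print("cannot backpropagate:", err)
```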

De-standardization of the distribution

Say we have a normal distribution N(10, 5), meaning mean 10 and standard deviation 5 (I prefer to read it as location 10 and scale 5). Sampling from it is the same as doing:

N(0, 1) * 5 + 10

Don’t believe me? You can check it with numpy.
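A minimal sketch of that check, assuming only numpy (the sample size and seed are arbitrary choices): draw directly from N(10, 5), then draw from N(0, 1), multiply by 5 and add 10, and compare the mean and standard deviation of both.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1_000_000

# Sampling directly from N(10, 5): mean 10, standard deviation 5.
direct = rng.normal(loc=10.0, scale=5.0, size=n)

# Sampling from N(0, 1), then scaling by 5 and shifting by 10.
eps = rng.standard_normal(size=n)
destandardized = eps * 5 + 10

# Both pairs should be close to (10, 5).
print(direct.mean(), direct.std())
print(destandardized.mean(), destandardized.std())
```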

We just de-standardized the data. Instead of making it zero mean and unit standard deviation, we did the opposite: we shifted and scaled it somewhere else. Do you know who is very good at discovering where that somewhere else is? Neural networks, among others.

So the trick is as follows (for a normal distribution):

  • Obtain new mean.
  • Obtain new std.
  • Apply those to a batch of random numbers from a N(0, 1) distribution.

The network decides on its own how to shift and scale the distribution. Because the randomness now lives only in the N(0, 1) batch, which doesn’t depend on any parameter, you can propagate the gradient through the number acting as the mean and the number acting as the std. Of course this has several implications that I’m not smart enough to even try to explain, so I’m not going to.
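Here is a minimal numpy sketch of why this works, using a hypothetical squared-error loss against a made-up target of 3.0. Once the N(0, 1) batch `eps` is drawn, z = mu + sigma * eps is an ordinary function of `mu` and `sigma`, so its gradients (dz/dmu = 1, dz/dsigma = eps) can be written by hand and checked against finite differences:

```python
import numpy as np

rng = np.random.default_rng(42)


def loss_fn(mu, sigma, eps, target=3.0):
    # Reparameterized sample: all the randomness lives in eps ~ N(0, 1).
    z = mu + sigma * eps
    return np.mean((z - target) ** 2)


def grads(mu, sigma, eps, target=3.0):
    # Chain rule through z = mu + sigma * eps, with eps treated as a constant:
    # dz/dmu = 1 and dz/dsigma = eps.
    z = mu + sigma * eps
    dloss_dz = 2.0 * (z - target) / eps.size
    return dloss_dz.sum(), (dloss_dz * eps).sum()


mu, sigma = 10.0, 5.0
eps = rng.standard_normal(1000)  # batch of N(0, 1) samples

g_mu, g_sigma = grads(mu, sigma, eps)

# Finite-difference check, reusing the SAME eps so the loss is deterministic.
h = 1e-6
fd_mu = (loss_fn(mu + h, sigma, eps) - loss_fn(mu - h, sigma, eps)) / (2 * h)
fd_sigma = (loss_fn(mu, sigma + h, eps) - loss_fn(mu, sigma - h, eps)) / (2 * h)
print(g_mu, fd_mu)
print(g_sigma, fd_sigma)
```

The key design point is that `eps` is drawn once and then held fixed for the backward pass; only `mu` and `sigma` are differentiated.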

Figure: examples of how reparameterization works on a standard normal distribution.

Code example

An example: from a linear layer we obtain the scale and location, and apply those to a batch of random samples from a N(0, 1) distribution.
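A minimal sketch of such a layer, assuming PyTorch. The class name `ReparamLayer` is hypothetical, and using `softplus` to keep the scale positive is my own choice (predicting a log-std and applying `exp` is another common option):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ReparamLayer(nn.Module):
    """Hypothetical layer: predicts location and scale, then samples via the trick."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.loc = nn.Linear(in_features, out_features)
        self.scale = nn.Linear(in_features, out_features)

    def forward(self, x):
        loc = self.loc(x)
        # softplus keeps the scale positive; the small constant avoids exactly zero.
        scale = F.softplus(self.scale(x)) + 1e-6
        # The only random part: a batch of N(0, 1) numbers with no parameters.
        eps = torch.randn_like(loc)
        return loc + scale * eps


torch.manual_seed(0)
layer = ReparamLayer(4, 2)
x = torch.randn(8, 4)
z = layer(x)
z.sum().backward()

print(z.shape)  # torch.Size([8, 2])
print(layer.loc.weight.grad is not None, layer.scale.weight.grad is not None)
```

The backward call reaches both linear layers because `eps` is just an input with no parameters; the sampling step itself never has to be differentiated.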

Working example

You should try it without BatchNorm, as I think BatchNorm might be able to predict the sin function on its own because of its internal affine transformation (its learnable scale and shift parameters).

NOTE: I had to scale the input to the 0–1 range; I couldn’t make it work without that.
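A sketch of a working example along those lines, assuming PyTorch. Everything specific here is an assumption rather than the original code: the name `SinNet`, the layer sizes, Adam with lr=1e-2, 2000 steps, and plain MSE between the reparameterized sample and sin(x). It keeps a BatchNorm layer and scales the input to 0–1 as described above:

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Inputs cover one period of sin; they are rescaled to [0, 1] before entering the net.
x_raw = torch.linspace(0, 2 * math.pi, 256).unsqueeze(1)
x = x_raw / (2 * math.pi)
y = torch.sin(x_raw)


class SinNet(nn.Module):
    """Hypothetical network: hidden layers, then a location head and a scale head."""

    def __init__(self, hidden=64):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(1, hidden), nn.BatchNorm1d(hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.loc = nn.Linear(hidden, 1)
        self.scale = nn.Linear(hidden, 1)

    def forward(self, x):
        h = self.body(x)
        loc = self.loc(h)
        scale = F.softplus(self.scale(h)) + 1e-6  # keep the scale positive
        eps = torch.randn_like(loc)               # the only source of randomness
        return loc + scale * eps, loc, scale      # reparameterized sample first


model = SinNet()
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

for step in range(2000):
    z, _, _ = model(x)
    loss = F.mse_loss(z, y)  # gradients reach both heads through z = loc + scale * eps
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % 500 == 0:
        print(f"step {step}: loss {loss.item():.4f}")

# Evaluation: BatchNorm switches to its running statistics.
model.eval()
with torch.no_grad():
    _, loc, scale = model(x)
print("MSE of loc vs sin:", F.mse_loss(loc, y).item())
print("mean predicted scale:", scale.mean().item())
```

With an MSE loss on the sample, the best the network can do is predict sin(x) as the location and shrink the scale toward zero, so the noise is simply learned away. If you want the scale to mean something, a likelihood loss such as `torch.nn.GaussianNLLLoss` on the location and the variance (scale squared) is the usual choice.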

Does this work for other distributions?

In theory, yes, according to this Stack Overflow answer, although it doesn’t look straightforward at all.
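One case that does stay simple is the exponential distribution, via the inverse CDF. This is a numpy sketch of that one case (my choice of example, not taken from the answer above): with u ~ Uniform(0, 1), x = -log(1 - u) / rate is exponentially distributed, and with u held fixed its derivative is dx/drate = -x / rate. Averaging that pathwise derivative should match the derivative of E[x] = 1/rate, which is -1/rate².

```python
import numpy as np

rng = np.random.default_rng(0)


def sample_exponential(rate, u):
    # Inverse-CDF reparameterization: x = -log(1 - u) / rate, with u ~ Uniform(0, 1).
    return -np.log1p(-u) / rate


def dsample_drate(rate, u):
    # Derivative of the sample w.r.t. rate, with u held fixed: dx/drate = -x / rate.
    return -sample_exponential(rate, u) / rate


rate = 2.0
u = rng.uniform(size=1_000_000)
x = sample_exponential(rate, u)
print(x.mean())  # close to 1 / rate = 0.5

# d/drate of E[x] = 1/rate is -1/rate**2 = -0.25; the pathwise estimate should agree.
print(dsample_drate(rate, u).mean())
```

PyTorch exposes the same idea as `Distribution.rsample()`, available for distributions whose `has_rsample` is True (for example `Normal` and `Exponential`).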

Variational autoencoders paper: