Mixture of Variational Autoencoders — a Fusion Between MoE and VAE

Source: Deep Learning on Medium


An unsupervised approach to digit classification and generation

By Yoel Zeldes

The Variational Autoencoder (VAE) is a paragon for neural networks that try to learn the shape of the input space. Once trained, the model can be used to generate new samples from the input space.

If we have labels for our input data, it’s also possible to condition the generation process on the label. In the MNIST case, it means we can specify which digit we want to generate an image for.

Let’s take it one step further… Could we condition the generation process on the digit without using labels at all? Could we achieve the same results using an unsupervised approach?

If we wanted to rely on labels, we could do something embarrassingly simple. We could train 10 independent VAE models, each using images of a single digit.
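In pseudocode, that label-based baseline would look something like this (`make_vae()` and the Keras-style `fit` are hypothetical placeholders, not real APIs):

```
# "cheating" baseline: one VAE per digit, routed by the label
experts = {digit: make_vae() for digit in range(10)}
for digit, vae in experts.items():
    vae.fit(x_train[y_train == digit])

# to generate a 7, sample from experts[7]
```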

That would obviously work, but you’re using the labels. That’s cheating!

OK, let’s not use them at all. Let’s train our 10 models, and just, well, look at each image ourselves before passing it to the appropriate model.

Hey, you’re cheating again! While you don’t use the labels per se, you do look at the images in order to route them to the appropriate model.

Fine… If instead of doing the routing ourselves we let another model learn the routing, that wouldn’t be cheating at all, would it?

Right! :)

We can use an architecture of 11 modules as follows:

A manager module routing an input to the appropriate expert module

But how will the manager decide which expert to pass the image to? We could train it to predict the digit of the image, but again — we don’t want to use the labels!

Phew… I thought you were gonna cheat…

So how can we train the manager without using the labels? It reminds me of a different type of model — Mixture of Experts (MoE). Let me take a small detour to explain how MoE works. We’ll need it, since it’s going to be a key component of our solution.


Mixture of Experts explained to non-experts

MoE is a supervised learning framework. You can find a great explanation by Geoffrey Hinton on Coursera and on YouTube. MoE relies on the possibility that the input might be segmented according to the 𝑥→𝑦 mapping. Have a look at this simple function:

The ground truth is defined to be the purple parabola for 𝑥<𝑥’, and the green parabola for 𝑥≥𝑥’. If we were to specify by hand where the split point 𝑥’ is, we could learn the mapping in each input segment independently using two separate models.
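To make this concrete, here is an illustrative version of such a piecewise ground truth (the split point and the two parabolas are assumptions of mine, chosen only to mimic the figure):

```python
import numpy as np

X_SPLIT = 0.0  # illustrative split point x'

def ground_truth(x):
    """Piecewise target: one parabola left of x', another right of it."""
    return np.where(x < X_SPLIT, (x + 1) ** 2, (x - 1) ** 2)
```

A single model fit to all of this data has to smooth over the kink at x'; two models, one per segment, don’t.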

In complex datasets we might not know the split points in advance. One (bad) solution is to segment the input space by clustering the 𝑥’s using K-means. In the two parabolas example, we’d end up with 𝑥’’ as the split point between the two clusters, which differs from the true split point 𝑥’. Thus, the model trained on the 𝑥<𝑥’’ segment would see samples from both parabolas, making it inaccurate.

So how can we train a model that learns the split points and, at the same time, learns the mapping within each segment?

MoE does so using an architecture of multiple subnetworks — one manager and multiple experts:

MoE architecture

The manager maps the input into a soft decision over the experts, which is used in two contexts:

First, the output of the network is a weighted average of the experts’ outputs, where the weights are the manager’s output.

Second, the loss function is

𝐿 = Σᵢ pᵢ ‖y − ȳᵢ‖²

where y is the label, ȳᵢ is the output of the i’th expert, and pᵢ is the i’th entry of the manager’s output. When you differentiate the loss, you get these results (I encourage you to watch the video for more details):

  1. The manager decides for each expert how much it contributes to the loss. In other words, the manager chooses which experts should tune their weights according to their error.
  2. The manager tunes the probabilities it outputs in such a way that the experts that got it right will get higher probabilities than those that didn’t.

This loss function encourages the experts to specialize in different kinds of inputs.
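In code, the loss above might look like the following NumPy sketch (the array shapes are my assumptions):

```python
import numpy as np

def moe_loss(y, expert_outputs, gate_probs):
    """Mixture-of-Experts loss: each expert's squared error,
    weighted by the probability the manager assigned to it.

    y:              (batch,)            labels
    expert_outputs: (batch, n_experts)  per-expert predictions
    gate_probs:     (batch, n_experts)  manager's softmax output
    """
    errors = (y[:, None] - expert_outputs) ** 2
    return np.mean(np.sum(gate_probs * errors, axis=1))
```

An expert with near-zero gate probability contributes almost nothing to the loss, so its weights barely move for that input, which is exactly the specialization effect described above.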


The last piece of the puzzle… is 𝑥

Let’s get back to our challenge! MoE is a framework for supervised learning. Surely we can change 𝑦 to be 𝑥 for the unsupervised case, right? MoE’s power stems from the fact that each expert specializes in a different segment of the input space with a unique mapping 𝑥→𝑦. If we use the mapping 𝑥→𝑥, each expert will specialize in a different segment of the input space with unique patterns in the input itself.

We’ll use VAEs as the experts. Part of the VAE’s loss is the reconstruction loss, where the VAE tries to reconstruct the original input image 𝑥:

MoE architecture where the experts are implemented as VAE
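Replacing the supervised error with each expert’s VAE loss gives something like this sketch (NumPy; I use a squared-error reconstruction term for brevity, and the shapes are my assumptions):

```python
import numpy as np

def mixture_vae_loss(x, reconstructions, kl_terms, gate_probs):
    """Unsupervised MoE loss: the label y is replaced by the input x.

    x:               (batch, dim)             flattened input images
    reconstructions: (batch, n_experts, dim)  each expert VAE's reconstruction
    kl_terms:        (batch, n_experts)       each expert's KL divergence term
    gate_probs:      (batch, n_experts)       manager's softmax output
    """
    recon_errors = np.sum((x[:, None, :] - reconstructions) ** 2, axis=2)
    per_expert_loss = recon_errors + kl_terms   # the usual VAE loss, per expert
    return np.mean(np.sum(gate_probs * per_expert_loss, axis=1))
```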

A cool byproduct of this architecture is that the manager can classify the digit found in an image using its output vector!

One thing we need to be careful about when training this model is that the manager could easily degenerate into outputting a constant vector, regardless of the input at hand. This would result in one VAE specialized in all digits, and nine VAEs specialized in nothing. One way to mitigate this, described in the MoE paper, is to add a balancing term to the loss. It encourages the manager’s outputs, averaged over a batch of inputs, to be balanced across the experts.
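The exact form of the balancing term isn’t reproduced here; one common choice (used in Shazeer et al.’s sparsely-gated MoE work) is the squared coefficient of variation of each expert’s “importance”, i.e. the total gate probability it received over the batch. A minimal sketch:

```python
import numpy as np

def balancing_loss(gate_probs, eps=1e-8):
    """Penalty that is zero when the experts get equal probability mass.

    gate_probs: (batch, n_experts)  manager's softmax outputs for a batch
    """
    importance = gate_probs.sum(axis=0)                # mass per expert
    cv = importance.std() / (importance.mean() + eps)  # coefficient of variation
    return cv ** 2
```

Adding a small multiple of this term to the total loss punishes the degenerate solution where the manager sends every input to the same expert.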

Enough talking — It’s training time!

Images generated by the experts. Each column belongs to a different expert.

In the last figure we see what each expert has learned. After each epoch we used the experts to generate images from the distributions they specialized in. The i’th column contains the images generated by the i’th expert.

We can see that some of the experts easily managed to specialize in a single digit, e.g. the digit 1. Others got a bit confused by similar digits, such as the expert that specialized in both 3 and 5.


An expert specializing in 2

What else?

Using a simple model without a lot of tuning and tweaking, we got reasonable results. Ideally, we’d want each expert to specialize in exactly one digit, thus achieving perfect unsupervised classification via the manager’s output.

Another interesting experiment would be to turn each expert into an MoE of its own! That would let us learn a hierarchy of traits by which the VAEs specialize. For instance, some digits have multiple ways to be drawn: 7 can be drawn with or without a strikethrough line. This source of variation could be modeled by the MoE in the second level of the hierarchy. But I’ll leave something for a future post…


Originally published by me at anotherdatum.com.