Shrinking Variational Autoencoder Bottlenecks On-the-Fly

Original article can be found here (source): Deep Learning on Medium


Variational Autoencoders, or VAEs for short, are a specific family of powerful deep neural network architectures. They actually consist of two neural networks that are jointly optimized: the encoder and the decoder. As the name autoencoder implies, the encoder projects input data into a low-dimensional latent space, after which the decoder uses this projection to reconstruct the original input data itself (‘auto’) as well as possible. This results in a diabolo-style architecture: the input and output have the same dimensionality, but the intermediate projection, i.e. the data bottleneck, has a lower dimensionality.

In the image above, the encoder transforms the input data x into a latent vector z, and the decoder reconstructs x from z as well as possible. When designing a neural network for this task, you have to specify which types of layers will be used, along with the dimensionality of each layer. For (variational) autoencoders in particular, you need to fix the dimensionality of the bottleneck upfront. This is often a matter of guesswork and gut feeling, so it is quite possible that you pick too few or too many dimensions for your particular application.

You might wonder: what does it mean to have “too few” or “too many” dimensions? The answer depends on your viewpoint, but it always comes down to defining when the number of dimensions is “just right”. In our work, we say that the number of dimensions is optimal if it is the smallest number that still achieves a predefined reconstruction error. For example, if we do not want the reconstruction quality to degrade beyond an MSE of τ, then τ is the maximum reconstruction error we are willing to accept. This τ is our threshold: bottlenecks that are too narrow will not be able to reach it, while bottlenecks that are too wide will end up (way) below it. The minimum number of dimensions needed to reach the threshold is the optimal number.
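
To make this definition concrete: the brute-force way to find the optimum is to train one VAE per candidate bottleneck size and keep the smallest one whose MSE stays at or below τ. The sketch below assumes those MSEs are already available (the numbers are invented for illustration) and only implements the selection rule; `optimal_num_dims` is a hypothetical helper, not code from the paper.

```python
from typing import Dict, Optional


def optimal_num_dims(mse_by_dims: Dict[int, float], tau: float) -> Optional[int]:
    """Smallest bottleneck size whose reconstruction MSE does not exceed tau.

    `mse_by_dims` maps a latent dimensionality to the MSE reached by a VAE
    with that bottleneck. Returns None if no candidate meets the threshold.
    """
    feasible = [d for d, mse in mse_by_dims.items() if mse <= tau]
    return min(feasible) if feasible else None


# Made-up MSE values, purely for illustration.
mse_by_dims = {2: 0.40, 4: 0.22, 8: 0.11, 16: 0.09, 32: 0.085}
print(optimal_num_dims(mse_by_dims, tau=0.10))  # 16
print(optimal_num_dims(mse_by_dims, tau=0.25))  # 4
```

Training a separate VAE for every candidate size is the cost that shrinking the bottleneck during a single training run aims to avoid.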

More latent dimensions lead to higher-quality decoder outputs with a smaller reconstruction loss. As we decrease the number of dimensions, the quality of the reconstructions gradually drops.

To achieve our goal, we essentially need two ingredients: a mechanism to remove uninformative dimensions from the bottleneck, and an optimization algorithm that takes our threshold τ on the reconstruction error into account. For the former, we introduce a gating mechanism, which essentially associates an on-off switch with each latent dimension. If the gate of a particular dimension is closed, that dimension is multiplied by exactly 0, so that no information can pass through. Conversely, if the gate is open, the dimension is multiplied by 1.

A gate vector ν (with only 0s and 1s) is multiplied element-wise with the latent vector z.
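
A minimal sketch of the gate itself, assuming plain Python lists for z and ν (the helpers `apply_gates` and `num_open_gates` are hypothetical). A closed gate zeroes its dimension, an open gate passes it through unchanged, and the number of open gates is the effective bottleneck size.

```python
from typing import List, Sequence


def apply_gates(z: Sequence[float], nu: Sequence[int]) -> List[float]:
    """Element-wise product of the latent vector z with the binary gate vector nu."""
    if len(z) != len(nu):
        raise ValueError("z and nu must have the same length")
    if any(g not in (0, 1) for g in nu):
        raise ValueError("gates must be 0 (closed) or 1 (open)")
    return [zi * gi for zi, gi in zip(z, nu)]


def num_open_gates(nu: Sequence[int]) -> int:
    """Number of open gates, i.e. the number of non-zero entries of nu."""
    return sum(1 for g in nu if g != 0)


z = [0.8, 1.3, 0.05, -2.1]
nu = [1, 0, 0, 1]
print(apply_gates(z, nu))  # [0.8, 0.0, 0.0, -2.1]
print(num_open_gates(nu))  # 2
```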

Of course, such an on-off characteristic is not easily learned with gradient-based optimization, since the derivative of the Heaviside step function is zero everywhere except at the step itself, where it is undefined. We therefore use the ARM gradient estimator (Augment-REINFORCE-Merge), which was published by Yin and Zhou at ICLR 2019 and was used by Li and Ji to sparsify neural networks, as published at ECML 2019. Similar to the straight-through estimator, the discrete on-off function is replaced by a smooth version, e.g. a sigmoid function: each gate ν is governed by the sigmoid of an associated gate parameter γ, i.e. σ(γ).

Unlike straight-through estimation, we draw a uniform random number between 0 and 1 at each forward pass to determine whether a gate should be open or closed. For example, suppose that we sample u=0.34 while σ(γ)=0.52. Since u<σ(γ), we open the gate, i.e. ν is set to 1. Had we sampled u=0.67, then u>σ(γ) and the gate would be closed. In short, we treat σ(γ) as the probability that the gate is open: the lower σ(γ), the less likely the gate is to be opened. To estimate the gradient of the loss w.r.t. γ, there exists a remarkably simple closed-form expression:
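
Written out in the notation of this post, following Yin and Zhou's ARM estimator (our transcription; d is the number of latent dimensions, γ and u are d-dimensional vectors, the indicator functions 𝟙[·] act element-wise, and ν = 𝟙[u < σ(γ)] is the gate vector described above):

```latex
\nabla_{\gamma}\, \mathbb{E}_{\nu \sim \mathrm{Bernoulli}(\sigma(\gamma))}\big[\mathcal{F}(\nu)\big]
  = \mathbb{E}_{u \sim \mathcal{U}(0,1)^d}\Big[\big(\mathcal{F}(\mathbb{1}[u > \sigma(-\gamma)]) - \mathcal{F}(\mathbb{1}[u < \sigma(\gamma)])\big)\big(u - \tfrac{1}{2}\big)\Big]
```

Because σ(−γ) = 1 − σ(γ), the vector 𝟙[u > σ(−γ)] is also a valid Bernoulli(σ(γ)) sample; it is simply built from the same u in the opposite direction.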

In this formula, the calligraphic function ℱ(⋅) stands for a forward pass through the autoencoder; it is shorthand for the reconstruction loss given a specific gate vector. So, what do we do in practice, during training? We feed some data x through the encoder and calculate the latent vector z. We then sample a uniform vector u and perform two forward passes through the decoder, one for each of the two gate vectors in the formula above.
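
A minimal pure-Python sketch of one ARM step, assuming ℱ is any function that maps a binary gate vector to a scalar loss. In a VAE, `loss_fn` would run the decoder on the gated latent vector and return the reconstruction loss; here a hypothetical toy linear loss stands in for it, so that the estimator can be checked against the exact gradient by averaging many single-sample estimates. All names (`arm_gradient`, `toy_loss`, the weights) are illustrative, not taken from the paper's code.

```python
import math
import random
from typing import Callable, List, Sequence


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def arm_gradient(
    loss_fn: Callable[[List[int]], float],
    gamma: Sequence[float],
    rng: random.Random,
) -> List[float]:
    """Single-sample ARM estimate of the gradient of E[loss_fn(nu)] w.r.t. gamma,
    where each gate nu_j ~ Bernoulli(sigmoid(gamma_j)) independently."""
    u = [rng.random() for _ in gamma]  # one shared uniform vector
    # Gate vector used in the forward pass: open where u_j < sigmoid(gamma_j).
    nu = [1 if uj < sigmoid(gj) else 0 for uj, gj in zip(u, gamma)]
    # Second gate vector from the same u: open where u_j > sigmoid(-gamma_j).
    nu_anti = [1 if uj > sigmoid(-gj) else 0 for uj, gj in zip(u, gamma)]
    # Two forward passes; one scalar loss difference is shared by all gates.
    delta = loss_fn(nu_anti) - loss_fn(nu)
    return [delta * (uj - 0.5) for uj in u]


# Toy check with a hypothetical linear loss F(nu) = sum_j w_j * nu_j (not a VAE).
# Its exact gradient is w_j * sigmoid(gamma_j) * (1 - sigmoid(gamma_j)).
weights = [2.0, -1.0, 0.5]
gamma = [0.3, -0.8, 1.5]


def toy_loss(nu: List[int]) -> float:
    return sum(w * n for w, n in zip(weights, nu))


rng = random.Random(0)
n_samples = 100_000
totals = [0.0] * len(gamma)
for _ in range(n_samples):
    for j, g in enumerate(arm_gradient(toy_loss, gamma, rng)):
        totals[j] += g
estimate = [t / n_samples for t in totals]
exact = [w * sigmoid(g) * (1.0 - sigmoid(g)) for w, g in zip(weights, gamma)]
print("ARM estimate:", [round(e, 3) for e in estimate])
print("exact:       ", [round(x, 3) for x in exact])
```

During training, a single sample per step is used (no averaging loop); the loop here only serves to show that the estimator is unbiased.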

At this point in the story we are able to optimize the gate parameters γ, but we haven’t yet told the optimizer that the gates should be closed as much as possible. This is fairly easy, actually: we just include a regularizer that pushes the gate parameters down. After all, if γ is very negative, then σ(γ) is close to zero, and the gate will almost always be closed. If we introduce a regularization parameter β, the gradient w.r.t. γ therefore becomes:
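
One plausible form of this gradient, assuming (as in Li and Ji's L0-ARM) that the regularizer is β times the expected number of open gates, β Σ_k σ(γ_k). Its derivative β σ(γ_j)(1 − σ(γ_j)) then simply adds to the ARM estimate of the reconstruction-loss gradient. With d latent dimensions and indicators applied element-wise:

```latex
\mathcal{L}(\gamma) = \mathbb{E}_{\nu}\big[\mathcal{F}(\nu)\big] + \beta \sum_{k=1}^{d} \sigma(\gamma_k),
\qquad
\frac{\partial \mathcal{L}}{\partial \gamma_j}
  = \mathbb{E}_{u \sim \mathcal{U}(0,1)^d}\Big[\big(\mathcal{F}(\mathbb{1}[u > \sigma(-\gamma)]) - \mathcal{F}(\mathbb{1}[u < \sigma(\gamma)])\big)\big(u_j - \tfrac{1}{2}\big)\Big]
  + \beta\,\sigma(\gamma_j)\big(1 - \sigma(\gamma_j)\big)
```

Because the β term is always positive, gradient descent keeps pushing every γ_j down; a gate only stays open if the reconstruction term pushes back hard enough.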

Alright, that was the first ingredient. Now, of course, we cannot let all the gates close, since then we would violate the threshold τ that we set as a maximum on the reconstruction error! So how do we prevent that? This is where GECO comes into play, short for Generalized ELBO with Constrained Optimization. GECO was introduced by Rezende and Viola at the Bayesian Deep Learning workshop at NIPS 2018, and it is absolutely fantastic. It uses classical Lagrange multipliers to optimize under inequality constraints, in our case MSE < τ. Our general loss function then takes the form Loss = Regularization + λ ⋅ (MSE − τ), where λ is the Lagrange multiplier. In our case, the regularization part consists of the standard KL divergence from the ELBO and the extra term we introduced above to lower the gate parameters. Let’s take a look at a single training run:
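
A simplified sketch of how the multiplier λ can be adapted in the spirit of GECO: keep a moving average of the constraint MSE − τ and update λ multiplicatively, so that λ grows while the constraint is violated and shrinks once it is satisfied. The class name and the hyperparameters `alpha` and `step` are hypothetical, and the exact update schedule in the GECO paper may differ in its details.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class LagrangeMultiplier:
    """Multiplicative update of a Lagrange multiplier, in the spirit of GECO.

    Hypothetical hyperparameters:
      alpha: smoothing factor of the exponential moving average of the constraint
      step:  scale of each update applied to log(lambda)
    """
    lam: float = 1.0
    alpha: float = 0.99
    step: float = 0.01
    c_ma: Optional[float] = None  # moving average of MSE - tau

    def update(self, mse: float, tau: float) -> float:
        c = mse - tau  # > 0 means the constraint MSE <= tau is violated
        if self.c_ma is None:
            self.c_ma = c
        else:
            self.c_ma = self.alpha * self.c_ma + (1.0 - self.alpha) * c
        # lambda stays positive; it grows while violated, shrinks while satisfied
        self.lam *= math.exp(self.step * self.c_ma)
        return self.lam


def total_loss(regularization: float, mse: float, tau: float, lam: float) -> float:
    """Loss = Regularization + lambda * (MSE - tau); lambda is treated as a
    constant (no gradient flows into it) when differentiating this loss."""
    return regularization + lam * (mse - tau)


geco = LagrangeMultiplier()
for mse in [30.0, 28.0, 25.0]:  # hypothetical MSEs above tau = 20
    geco.update(mse, tau=20.0)
print(geco.lam > 1.0)  # True: the constraint is violated, so lambda has grown
```

Under this scheme, a small λ (MSE comfortably below τ) lets the regularizer dominate so gates can close, while a growing λ (MSE above τ) makes the reconstruction term dominate and pulls gates back open.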

At the start of training, the MSE is above τ, but once it dives under this threshold, we can start closing gates one by one. This is shown in the right graph. Each plateau in this staircase graph corresponds to an integer number of open gates: at 21K batches, 6 gates are open, and at around 27K batches this is lowered to 5. Notice that the MSE starts rising as soon as we start closing gates, and that it becomes harder to close more gates once the MSE is flirting with the threshold. After a while, the number of open gates even starts to oscillate: the optimizer tries to bring it down to 4, but is constantly pulled back by GECO, because closing another gate would heavily violate the constraint.

We train a VAE on the dSprites dataset with thresholds τ=20 and τ=35. In the first case, we can distinguish the shapes along with their position and rotation. In the latter case, the reconstructions are just blurred circles and therefore contain no shape or rotation information. Indeed, for τ=20 we need 5 dimensions to encode the information in the images, while for τ=35 we only need 3, i.e. x-position, y-position and size.

Top row: original images; middle row: reconstructions with τ=20; bottom row: reconstructions with τ=35.

Finally, we show that our approach comes close to the “optimal” number of dimensions on a selection of five different datasets. We first train regular VAEs with n=5, 10 and 30 dimensions, setting τ=0 so that we focus entirely on reconstruction quality. We record the lowest MSEs that we can obtain and use them as thresholds to train new VAEs with twice as many dimensions: 10, 20 and 60. If we are able to bring the dimensions back down to 5, 10 and 30 during training, our method is close to optimal.

The original VAEs are trained for 200K batches, and we train the new VAEs for 200K, 300K and 400K batches. In the table, ‖ν‖₀ stands for the L0 “norm” of the gate vector, i.e. the number of open gates. For most datasets, we indeed come close to the original 5, 10 and 30 dimensions. In some cases, especially for n=30, we can lower the number of dimensions far below 30, which indicates that the original VAE had an overprovisioned bottleneck.

Thanks for reading. This was a short write-up of our latest paper. Check out the paper for more in-depth information at