Using PyTorch to Generate Images of Malaria-Infected Cells

Source: Deep Learning on Medium

Variational Autoencoders & Different Initialization Schemes

Daniel Bojar

On my quest to master machine learning & deep learning (in order to make exciting biological discoveries), I tackled variational autoencoders (VAEs) as the next target on my list. Seemingly simple, a VAE consists of an encoder, a decoder and, of course, an appropriate loss function. Yet examples of creative applications of VAEs abound, including the generation of images, music and parts of proteins. And we'll see in a second why they are so popular for these tasks! People seemingly love to generate images of digits and faces with VAEs for some reason, so I thought I'd try my luck with something different. After looking around for a while, I came across a dataset of human blood cells (healthy as well as infected by the causative agent of malaria, Plasmodium falciparum) on Kaggle (great for datasets, by the way). As both categories contained around 14,000 images each, which is not too shabby, I decided to go with that and build VAEs for both categories to generate images of healthy & infected cells. Is that useful? I don't know (probably not too much), but it's a great learning experience and certainly pleasing to look at!

Examples of input images (not particularly pretty, but hey, you take what you can get, right?). The upper row shows three healthy blood cells; the lower row shows three blood cells infected by the malaria pathogen P. falciparum (easily recognized by the purple spots).

But before we throw ourselves into the fun of deep learning, some background. The nominal task of a VAE consists largely of predicting its own input, which is referred to as reconstruction. While this might sound trivial, VAEs, like all autoencoders, have a bottleneck structure: the layers at the edges of the architecture have considerably more nodes than the layers in the middle. So a wide encoder narrows down to the bottleneck layer, and a corresponding decoder picks up from there and brings the input back to its original dimensions. This prevents the VAE from simply applying an identity transformation to its inputs and being done with it.

Structure of a standard autoencoder. Variational autoencoders replace the ‘z’ layer with vectors of means and variances to create sampling distributions. Source

Therefore, the network has to learn the essential characteristics of, say, an image in order to recreate it. This constrained representation, enforced by the bottleneck, is also the main reason why autoencoders in general are used for denoising: since ideally only the true characteristics of the images are captured in the latent features, noise is removed upon reconstruction. By comparing the original image with the reconstructed image and minimizing their difference (read: minimizing the reconstruction error, for instance the mean-squared error), we already have our loss function. Or we would have, for normal autoencoders.

Variational autoencoders, however, also have a generative component. Effectively, they allow you to get tweaked variants of your input images instead of trite replication, which is why they are so popular in creative fields such as music or the visual arts. The difference in architecture is that, instead of mapping each input to a single fixed point in the bottleneck, the VAE encoder outputs a vector of means and a vector of standard deviations (in practice often log-variances, for numerical stability). Together, these define one normal distribution per latent dimension. From every distribution we then sample a value, and these values populate the bottleneck layer. This is where the stochasticity resulting in the variations comes into play: since we sample from distributions, the output will be different every time we run the VAE, even if all the weights stay the same.
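One detail the description glosses over: to train the encoder with backpropagation, the sampling step is usually written via the so-called reparameterization trick, drawing ε from a standard normal distribution and computing z = μ + σ·ε, so that gradients can flow through μ and σ. The plain-Python sketch below (no PyTorch, so it runs anywhere) assumes, as in the PyTorch VAE example, that the encoder outputs log-variances; the function name and the example numbers are hypothetical.

```python
import math
import random


def reparameterize(mu, logvar, rng=random):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1), one value per latent dimension.

    Assumes logvar holds log(sigma^2), so sigma = exp(0.5 * logvar).
    """
    z = []
    for m, lv in zip(mu, logvar):
        sigma = math.exp(0.5 * lv)
        eps = rng.gauss(0.0, 1.0)  # the only source of randomness
        z.append(m + sigma * eps)
    return z


# Hypothetical encoder output for a 3-dimensional latent space
mu = [0.5, -1.0, 0.0]
logvar = [0.0, -2.0, 1.0]

rng = random.Random(0)
print(reparameterize(mu, logvar, rng))
print(reparameterize(mu, logvar, rng))  # different sample, same "weights"
```

The two print statements yield two different latent vectors from identical inputs, which is exactly the stochasticity described above. In PyTorch, the same idea is typically expressed with torch.randn_like applied to the standard-deviation tensor.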

So far, so good. If we change the architecture as indicated, are we done? Not quite. We still need to adapt our loss function. For the reconstruction error, we will use binary cross-entropy. Yet we also need another term in the loss function, namely the Kullback–Leibler divergence (KL loss). Simplistically, it measures how much one probability distribution differs from another. In our case, it will be the sum of the divergences between the VAE-generated probability distributions and the standard normal distribution. By minimizing this, we force the distributions in latent space to cluster around the center (the standard normal distribution, with mean 0 and variance 1). This results in overlapping distributions in the latent space and improved image generation. Otherwise, the VAE might 'memorize' inputs by creating clearly separated probability distributions for each input type. By summing up reconstruction error & KL loss, we get our final loss function, which we try to minimize during training.
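Spelled out, the loss is the summed binary cross-entropy between reconstructed and original pixel values plus the KL divergence of each latent Gaussian N(μ, σ²) from N(0, 1), which has the closed form ½ Σ (μ² + σ² − 1 − log σ²). The following plain-Python sketch assumes pixel values scaled to [0, 1], log-variance outputs from the encoder and sum reduction (the choice made in the PyTorch VAE example, which uses F.binary_cross_entropy with reduction='sum'); the function names are illustrative, not taken from the author's notebook.

```python
import math


def binary_cross_entropy(recon, target, eps=1e-12):
    """Summed BCE between reconstructed pixels (in (0, 1)) and target pixels (in [0, 1])."""
    total = 0.0
    for p, t in zip(recon, target):
        p = min(max(p, eps), 1.0 - eps)  # guard against log(0)
        total -= t * math.log(p) + (1.0 - t) * math.log(1.0 - p)
    return total


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL( N(mu, sigma^2) || N(0, 1) ), summed over latent dimensions.

    Assumes logvar = log(sigma^2), so sigma^2 = exp(logvar).
    """
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, logvar))


def vae_loss(recon, target, mu, logvar):
    """Reconstruction error plus KL term, summed into one scalar to minimize."""
    return binary_cross_entropy(recon, target) + kl_to_standard_normal(mu, logvar)


# A latent distribution that already is N(0, 1) costs nothing in the KL term ...
print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))  # 0.0
# ... while one whose mean is shifted away from the center is penalized.
print(kl_to_standard_normal([2.0, 0.0], [0.0, 0.0]))  # 2.0
```

The KL term vanishes only when a latent distribution matches the standard normal exactly; any shift of the mean (or change of variance) is penalized, which is what pulls the latent distributions toward the center.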

Now we can build! I have the whole thing in a Jupyter notebook here if you want to try it or just steal the code for your own projects. The VAE is implemented in PyTorch, the deep learning framework that even life science people such as myself find comfortable enough to work with. I started out by shamelessly stealing the VAE code from the PyTorch repo and then tweaking & modifying it until it fit my needs. I typically run these things on Google Colaboratory because it gives you a Tesla K80 GPU for free, which is absolutely great! You can simply use it with Jupyter notebooks and run them continuously for up to 12 hours.

After getting the images, you have to convert them into a PyTorch dataset. The implementation of this in my notebook is quite bare-bones, but in general a PyTorch dataset object needs an indexer (__getitem__) and a length (__len__). We can use this opportunity to also apply some transforms to the images, such as resizing them all to the same size and normalizing them. Then we can feed the dataset to a data loader, which breaks it up into minibatches that we can use for training. Easy, right?

For the VAE itself, we instantiate the layers in the __init__ method and define how they interact in the forward method. Basically, it's just a bunch of linear layers with the occasional ReLU activation function. As VAEs can be a bit finicky in training (think vanishing and exploding gradients, yay!), I also added two batch normalization layers to the encoder and decoder. By reducing internal covariate shift (making each layer less sensitive to changes in the distribution of the previous layer's outputs), they make training more stable and even have a slight regularization effect, because each minibatch is normalized with its own noisy statistics. That's also why batch normalization behaves differently once you switch your model to eval mode: it stops using per-batch statistics and instead normalizes with the running averages collected during training. At the end, we need a sigmoid activation function for the binary cross-entropy loss so that all output values are between 0 and 1.
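To make the train/eval distinction concrete, here is a stripped-down, single-feature batch normalization in plain Python. It is only a sketch of the mechanism behind torch.nn.BatchNorm1d: the class name is hypothetical, the learnable scale and shift are omitted, and the defaults (momentum 0.1, eps 1e-5, biased variance for normalizing, unbiased variance for the running estimate) follow the conventions described in the PyTorch documentation.

```python
import math


class BatchNorm1dSketch:
    """Hypothetical single-feature batch norm without learnable scale/shift."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.running_mean = 0.0
        self.running_var = 1.0
        self.training = True

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def __call__(self, batch):
        if self.training:
            # Normalize with the statistics of this minibatch (biased variance) ...
            n = len(batch)
            mean = sum(batch) / n
            var = sum((x - mean) ** 2 for x in batch) / n
            # ... and update the running estimates (unbiased variance).
            unbiased = var * n / (n - 1) if n > 1 else var
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * unbiased
        else:
            # Eval mode: reuse the running estimates; nothing is updated.
            mean, var = self.running_mean, self.running_var
        return [(x - mean) / math.sqrt(var + self.eps) for x in batch]


bn = BatchNorm1dSketch()
batch = [1.0, 2.0, 3.0, 4.0]
print(bn(batch))         # normalized with this batch's mean and variance
print(bn.eval()(batch))  # same inputs, now normalized with the running averages
```

In training mode the output for a sample depends on which other samples happen to share its minibatch; after eval() the same input always maps to the same output, so that source of noise (and regularization) disappears.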

The last point that significantly increases training stability is layer initialization. The choice of initial weights can make a huge difference in training stability, especially for deep neural networks. I started out with Xavier initialization for all my linear layers, which finally allowed me to train the VAE without anything blowing up. This approach samples initial weights from a uniform distribution whose range depends on the number of incoming & outgoing connections of a given layer. But then I recently stumbled upon an excellent blog post on initialization schemes, including Kaiming initialization, so I decided to try that one as well and compare it to the Xavier-initialized VAE. Apparently, this one is best suited to ReLU-like activation functions; it draws the weight tensor from a standard normal distribution and scales it by sqrt(2 / fan_in), i.e. a factor inversely proportional to the square root of the number of incoming connections (fan_in) to the layer.
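Both schemes boil down to choosing the scale of the initial weights from the layer's fan-in (incoming connections) and fan-out (outgoing connections). The plain-Python sketch below implements the uniform variant of Xavier initialization, U(−a, a) with a = sqrt(6 / (fan_in + fan_out)), and the normal variant of Kaiming initialization with std = sqrt(2 / fan_in), then compares the empirical standard deviations with theory. The layer size is a hypothetical example; in PyTorch you would instead call torch.nn.init.xavier_uniform_ or torch.nn.init.kaiming_normal_(weight, nonlinearity='relu') on each linear layer's weight.

```python
import math
import random


def xavier_uniform(fan_in, fan_out, rng=random):
    """Glorot/Xavier uniform: U(-a, a) with a = sqrt(6 / (fan_in + fan_out)).

    Returns a fan_out x fan_in matrix, matching the (out, in) layout of a linear layer's weight.
    """
    a = math.sqrt(6.0 / (fan_in + fan_out))
    return [[rng.uniform(-a, a) for _ in range(fan_in)] for _ in range(fan_out)]


def kaiming_normal(fan_in, fan_out, rng=random):
    """He/Kaiming normal for ReLU: standard normal samples scaled by sqrt(2 / fan_in)."""
    std = math.sqrt(2.0 / fan_in)
    return [[rng.gauss(0.0, 1.0) * std for _ in range(fan_in)] for _ in range(fan_out)]


def empirical_std(weights):
    """Standard deviation of all entries in a weight matrix."""
    flat = [x for row in weights for x in row]
    mean = sum(flat) / len(flat)
    return math.sqrt(sum((x - mean) ** 2 for x in flat) / len(flat))


rng = random.Random(42)
fan_in, fan_out = 512, 256  # hypothetical linear layer size
print("Xavier std:", empirical_std(xavier_uniform(fan_in, fan_out, rng)),
      "theory:", math.sqrt(2.0 / (fan_in + fan_out)))
print("Kaiming std:", empirical_std(kaiming_normal(fan_in, fan_out, rng)),
      "theory:", math.sqrt(2.0 / fan_in))
```

The extra factor of 2 in Kaiming initialization compensates for ReLU setting roughly half of its inputs to zero, which would otherwise shrink the variance of the signal from layer to layer.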

VAE-generated images of blood cells, either healthy (left panel) or infected with malaria (right panel). Here, the VAE was initialized with Xavier initialization.

Additionally, I added a decaying learning rate (halving it after every epoch) to improve performance. After training for 10 epochs, I started to generate images with the trained VAEs. For this, it's sufficient to sample random values from the standard normal distribution, feed them into the bottleneck layer of the trained VAE (skipping the encoder entirely) and decode them into generated images. If we have a look at the images generated by the Xavier-initialized VAE, we can clearly see that the VAE did indeed learn something from the images. Immediately obvious is the color difference between uninfected (yellow) and infected cells (purplish). Then, if you look closer, uninfected cells seem to be rounder and more evenly shaped than infected cells. While you can see some granularity in the infected cells, it doesn't form the same clear-cut clusters as in the input images. For images generated by the Kaiming-initialized VAE, we can also observe the clear color difference, yet here the granularity seems to be even less pronounced. Additionally, the images seem quite hazy. In fact, VAE-generated images are known to be somewhat blurry. Sometimes this isn't necessarily bad, though: if you remember the ragged edges of the input images, a bit of blurring around the edges at least makes the generated images more aesthetically pleasing.
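Both steps in this paragraph are simple to express in code. The plain-Python sketch below shows a learning rate that halves after every epoch (in PyTorch, torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5) with one scheduler.step() call per epoch gives the same schedule) and the generation step: sampling latent vectors from N(0, I) and passing them straight to the decoder. The decoder here is a hypothetical one-layer stand-in with random weights, not a trained model, so the "images" are just lists of pixel values in [0, 1].

```python
import math
import random


def halving_lr(base_lr, epoch):
    """Learning rate for a given 0-based epoch when it is halved after every epoch."""
    return base_lr * 0.5 ** epoch


def sigmoid(x):
    """Numerically stable logistic function; output lies in [0, 1]."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def toy_decoder(z, weights, biases):
    """Hypothetical stand-in for a trained decoder: one linear layer + sigmoid."""
    return [sigmoid(sum(w * zi for w, zi in zip(w_row, z)) + b)
            for w_row, b in zip(weights, biases)]


def generate(n_images, latent_dim, weights, biases, rng=random):
    """Sample z ~ N(0, I) and decode it; the encoder is not needed for generation."""
    images = []
    for _ in range(n_images):
        z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
        images.append(toy_decoder(z, weights, biases))
    return images


print([halving_lr(1e-3, e) for e in range(4)])  # learning rates for epochs 0..3

rng = random.Random(0)
latent_dim, n_pixels = 4, 6  # hypothetical sizes
weights = [[rng.uniform(-0.5, 0.5) for _ in range(latent_dim)] for _ in range(n_pixels)]
biases = [0.0] * n_pixels
for img in generate(2, latent_dim, weights, biases, rng):
    print([round(p, 3) for p in img])
```

Because generation only needs the decoder, the encoder and its mean and variance outputs play no role once training is done.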

VAE-generated images of blood cells, either healthy (left panel) or infected with malaria (right panel). Here, the VAE was initialized with Kaiming initialization.

Where to go from here? Generative adversarial networks (GANs) have been reported to create sharper, higher-resolution images than VAEs, so if that is a factor, GANs might be attractive. Also, especially for images, linear layers are likely inferior to convolutional layers, which exploit the spatial structure of the pixels. Building a CNN-VAE might yield considerably better generated images, so go try that if you feel like it! In principle, using an autoencoder for this kind of uninfected / infected cell setup could give you insights into the characteristics of the respective cell states (by investigating the parameters in the constructed latent space) and might help in the automated diagnosis of malaria. Anyway, I definitely enjoyed working on this little thing and learned a bit more about deep learning & PyTorch. Looking forward to the next project!

Bonus: If you set the parameters right, you can force your VAE to generate cell images that look like gems or colorful pebbles. Because we all deserve beautiful images!

Can you figure out how to get those beautiful cells from a different universe?