Graduating in GANs: Going from understanding generative adversarial networks to running your own



Read how research on and evaluation of generative adversarial networks (GANs) has developed, then implement your own GAN to generate handwritten digits.

By Cecelia Shao

Generative Adversarial Networks (GANs) have taken over the public imagination, permeating pop culture with AI-generated celebrities and creating art that sells for thousands of dollars at high-brow art auctions.

In this post, we’ll explore:

  • Brief primer on GANs
  • Understanding and Evaluating GANs
  • Running your own GAN

There is a wealth of resources for catching up on GANs, so our focus for this article is to understand how GANs can be evaluated. We’ll also walk you through running your own GAN to generate handwritten digits like those in the MNIST dataset.

Here’s one run of the GAN we’ll show you how to implement later on — see how the handwritten digits it generates become increasingly realistic as training progresses!

Brief primer on GANs

Since its inception in 2014 with Ian Goodfellow’s ‘Generative Adversarial Networks’ paper, progress with GANs has exploded and led to increasingly realistic outputs.

Just three years ago, you could find Ian Goodfellow’s reply on this Reddit thread to a user asking about whether you can use GANs for text:

“GANs have not been applied to NLP because GANs are only defined for real-valued data.

GANs work by training a generator network that outputs synthetic data, then running a discriminator network on the synthetic data. The gradient of the output of the discriminator network with respect to the synthetic data tells you how to slightly change the synthetic data to make it more realistic.

You can make slight changes to the synthetic data only if it is based on continuous numbers. If it is based on discrete numbers, there is no way to make a slight change.

For example, if you output an image with a pixel value of 1.0, you can change that pixel value to 1.0001 on the next step.

If you output the word “penguin”, you can’t change that to “penguin + .001” on the next step, because there is no such word as “penguin + .001”. You have to go all the way from “penguin” to “ostrich”.

Since all NLP is based on discrete values like words, characters, or bytes, no one really knows how to apply GANs to NLP yet.”

Now GANs are being used to create all kinds of content including images, video, audio, and (yup) text. These outputs can be used as synthetic data for training other models or just for spawning interesting side projects like thispersondoesnotexist.com, thisairbnbdoesnotexist.com/, and This Machine Learning Medium post does not exist. 😎

Behind the GAN

A GAN is composed of two neural networks: a generator that synthesizes new samples from scratch, and a discriminator that compares training samples with these generated samples from the generator. The discriminator’s goal is to distinguish between ‘real’ and ‘fake’ inputs (i.e. to classify whether a sample came from the model distribution or the real distribution). As we described, these samples can be images, videos, audio snippets, and text.

Simple GAN overview from Kiran Sudhir

To synthesize these new samples, the generator is given random noise and attempts to generate realistic images from the learnt distribution of the training data.

The gradient of the output of the discriminator network (a convolutional neural network) with respect to the synthetic data informs how to slightly change the synthetic data to make it more realistic. Eventually the generator converges on parameters that reproduce the real data distribution, and the discriminator is unable to detect the difference.
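To make this concrete, here is a minimal sketch of the two networks in Keras. The layer sizes and activations below are deliberately simplified and illustrative only; the tutorial model we run later uses convolutional architectures.

# Minimal sketch of a GAN's two networks in Keras (illustrative dense layers,
# not the convolutional architecture used in the tutorial later in this post).
from keras.models import Sequential
from keras.layers import Dense, LeakyReLU, Reshape, Flatten

latent_dim = 100  # size of the random noise vector fed to the generator

# Generator: maps a noise vector to a 28x28 grayscale image
generator = Sequential([
    Dense(128, input_dim=latent_dim),
    LeakyReLU(0.2),
    Dense(28 * 28, activation='tanh'),
    Reshape((28, 28, 1)),
])

# Discriminator: maps a 28x28 image to the probability that it is real
discriminator = Sequential([
    Flatten(input_shape=(28, 28, 1)),
    Dense(128),
    LeakyReLU(0.2),
    Dense(1, activation='sigmoid'),
])
discriminator.compile(optimizer='rmsprop',
                      loss='binary_crossentropy',
                      metrics=['accuracy'])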

You can see and play with these converging data distributions in GAN Lab:

Here’s a selection of the best guides on GANs:

Understanding and Evaluating GANs

Quantifying the progress of a GAN can feel very subjective — “Does this generated face look realistic enough?” , “Are these generated images diverse enough?” — and GANs can feel like black boxes where it’s not clear which components of the model impact learning or result quality.

To this end, a group from the MIT Computer Science and Artificial Intelligence Laboratory (CSAIL) recently released a paper, ‘GAN Dissection: Visualizing and Understanding Generative Adversarial Networks’, that introduces a method for visualizing GANs and how GAN units relate to objects in an image, as well as the relationships between objects.

Figure 1 from Bau et al. 2019 showing image modification through intervention with certain GAN units.

Using a segmentation-based network dissection method, the paper’s framework allows us to dissect and visualize the inner workings of a generator network. It does this by looking for agreement between a set of GAN units, referred to as neurons, and concepts in the output image such as trees, sky, and clouds. As a result, we’re able to identify the neurons that are responsible for certain objects, such as buildings or clouds.

Having this level of granularity into the neurons allows for edits to existing images (e.g. to add or remove trees as shown in the image) by forcefully activating and deactivating (ablating) the corresponding units for those objects.

However, it’s not clear if the network is able to reason about objects in a scene or if it’s simply memorizing these objects. One way to get closer to an answer for this question was to try to distort the image in unrealistic ways. Perhaps the most impressive part of MIT CSAIL’s interactive web demo of GAN Paint was how the model is seemingly able to limit these edits to ‘photorealistic’ changes. If you try to impose grass onto the sky, here’s what happens:

Even though we’re activating the corresponding neurons, it appears as though the GAN has suppressed the signal in later layers.

Figure 11 from Bau et al. 2019 shows how the local context for an object affects the likelihood of that object being synthesized (in this case, the likelihood of a door being generated on a building versus on a tree or in the sky).

Another interesting way of visualizing GANs is to conduct latent space interpolation (remember, the GAN generates new instances by sampling from the learned latent space). This can be a useful way of seeing how smooth the transitions across generated samples are.
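As a rough sketch of what this looks like in practice (assuming a trained Keras generator, here called net_generator as in the tutorial below, that maps 100-dimensional noise vectors to images):

# Sketch: linear interpolation between two points in the latent space.
# Assumes a trained generator (net_generator, as in the tutorial below) that
# maps 100-dimensional noise vectors to images.
import numpy as np

z_start = np.random.normal(0, 1, size=(1, 100))
z_end = np.random.normal(0, 1, size=(1, 100))

steps = 10
alphas = np.linspace(0.0, 1.0, steps)

# Interpolated latent vectors, shape (steps, 100)
z_interp = np.vstack([(1 - a) * z_start + a * z_end for a in alphas])

# One generated image per interpolation step; smooth visual transitions
# suggest the generator has learned a smooth latent space.
images = net_generator.predict(z_interp)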

These visualizations can help us understand the internal representations of a GAN, but finding quantifiable ways to understand GAN progress and output quality is still an active area of research.

Two commonly used evaluation metrics for image quality and diversity are: the Inception Score and the Fréchet Inception Distance (FID). Most practitioners have shifted from the Inception Score to FID after Shane Barratt and Rishi Sharma released their paper ‘A Note on the Inception Score’ on key shortcomings of the former.

Inception Score

Introduced by Salimans et al. 2016 in ‘Improved Techniques for Training GANs’, the Inception Score is based on the heuristic that realistic samples should be confidently classifiable when passed through a pre-trained network, such as Inception trained on ImageNet. Technically, this means each sample should have a low-entropy softmax prediction vector.

Besides high predictability (low entropy), the Inception Score also evaluates a GAN based on how diverse the generated samples are (e.g. high variance or entropy over the distribution of generated samples). This means that there should not be any dominating classes.

If both of these traits are satisfied, the Inception Score should be large. The two criteria are combined by evaluating the Kullback-Leibler (KL) divergence between the conditional label distribution p(y|x) of each sample and the marginal label distribution p(y) over all generated samples.
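Written out, the score is IS = exp( E_x[ KL(p(y|x) ‖ p(y)) ] ). Here is a rough numpy sketch, assuming preds is an (N, num_classes) array of softmax outputs from a pre-trained Inception network on N generated samples:

# Sketch of the Inception Score, assuming `preds` is an (N, num_classes) array
# of softmax outputs from a pre-trained Inception network on N generated samples.
import numpy as np

def inception_score(preds, eps=1e-16):
    p_y = np.mean(preds, axis=0, keepdims=True)              # marginal p(y)
    kl = preds * (np.log(preds + eps) - np.log(p_y + eps))   # per-sample KL(p(y|x) || p(y)) terms
    return np.exp(np.mean(np.sum(kl, axis=1)))               # exponentiated average KL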

Fréchet Inception Distance

Introduced by Heusel et al. 2017, FID estimates realism by measuring the distance between the generated image distribution and the true distribution. FID embeds a set of generated samples into a feature space given by a specific layer of an Inception network. The embeddings are treated as samples from a continuous multivariate Gaussian, and the mean and covariance are estimated for both the generated data and the real data. The Fréchet distance between these two Gaussians (a.k.a. the Wasserstein-2 distance) is then used to quantify the quality of the generated samples: a lower FID corresponds to more similar real and generated distributions.

An important note is that FID needs a decent sample size to give good results (the suggested size is 50,000 samples). If you use too few samples, you will end up over-estimating your actual FID and the estimates will have a large variance.
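Written out, FID = ‖μ_r − μ_g‖² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^(1/2)), where (μ_r, Σ_r) and (μ_g, Σ_g) are the mean and covariance of the real and generated embeddings. A rough numpy/scipy sketch, assuming real_feats and gen_feats are (N, d) arrays of Inception features for real and generated samples:

# Sketch of the FID computation, assuming `real_feats` and `gen_feats` are
# (N, d) arrays of Inception activations for real and generated samples.
import numpy as np
from scipy.linalg import sqrtm

def frechet_inception_distance(real_feats, gen_feats):
    mu_r, mu_g = real_feats.mean(axis=0), gen_feats.mean(axis=0)
    sigma_r = np.cov(real_feats, rowvar=False)
    sigma_g = np.cov(gen_feats, rowvar=False)

    # Matrix square root of the product of the two covariance matrices
    covmean = sqrtm(sigma_r @ sigma_g)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop tiny imaginary parts from numerical error

    diff = mu_r - mu_g
    return diff @ diff + np.trace(sigma_r + sigma_g - 2.0 * covmean)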

For a comparison of how Inception Scores and FID scores have differed across papers, see Neal Jean’s post here.

Want to see more?

Ali Borji’s paper, ‘Pros and Cons of GAN Evaluation Measures’, includes an excellent table with more exhaustive coverage of GAN evaluation metrics:

Interestingly, other researchers are taking different approaches by using domain-specific evaluation metrics. For text GANs, Guy Tevet and his team proposed using traditional probability-based language model metrics to evaluate the distribution of text generated by a GAN in their paper ‘Evaluating Text GANs as Language Models’.

In ‘How good is my GAN?’, Konstantin Shmelkov and his team use two measures based on image classification, GAN-train and GAN-test, which approximate the recall (diversity) and precision (image quality) of GANs respectively. You can see these evaluation metrics in action in the Google Brain research paper ‘Are GANs Created Equal?’, where the authors used a dataset of triangles to measure the precision and recall of different GAN models.

Running your own GAN

To illustrate GANs, we’ll be adapting this excellent tutorial from Wouter Bulten that uses Keras and the MNIST dataset to generate handwritten digits.

See the full tutorial notebook here.

We’ll be tracking our GAN’s progress not only by visualizing our loss and accuracy curves, but also by checking test outputs using Comet.ml.

This GAN model takes in the MNIST training data and random noise as an input (specifically, random vectors of noise) to generate:

  • images (in this case, images of handwritten digits). Eventually, these generated images will resemble the data distribution of the MNIST dataset.
  • the discriminator’s prediction on the generated images

The Generator and Discriminator models together form the adversarial model; for this example, the generator performs well if the adversarial model classifies the generated images as real for all inputs.
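One common way to put this stacked model together in Keras looks roughly like the sketch below. The exact layer definitions live in the linked notebook; net_generator and net_discriminator here stand in for the models it builds.

# Sketch of the stacked adversarial model: the generator followed by a frozen
# discriminator, so that only the generator's weights are updated when the
# adversarial model trains. net_generator and net_discriminator stand in for
# the models defined in the linked notebook.
from keras.models import Sequential
from keras.optimizers import RMSprop

net_discriminator.trainable = False

net_adversarial = Sequential([net_generator, net_discriminator])
net_adversarial.compile(optimizer=RMSprop(lr=0.0004),
                        loss='binary_crossentropy',
                        metrics=['accuracy'])

# A generator training step: feed noise and label the outputs as 'real' (1),
# pushing the generator toward samples the discriminator accepts.
# noise = np.random.uniform(-1.0, 1.0, size=(batch_size, 100))
# a_stats = net_adversarial.train_on_batch(noise, np.ones((batch_size, 1)))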

See the full code here and the full Comet Experiment with results here

Tracking your model’s progress

We’re able to track the training progress for both our Generator and Discriminator models using Comet.ml.

We’re plotting both the accuracy and loss for our discriminator and adversarial models — the most important metrics to track here are:

  • the discriminator’s loss (the blue line on the right chart): dis_loss
  • the adversarial model’s accuracy (the blue line on the left chart): acc_adv

See the training progression for this experiment here.
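These curves come from logging values inside the training loop. A minimal sketch of that reporting is below; d_stats and a_stats stand for the [loss, accuracy] pairs returned by train_on_batch for the discriminator and adversarial models, and the metric names other than dis_loss and acc_adv are illustrative.

# Sketch of reporting training metrics to Comet.ml at each step i.
# d_stats and a_stats stand for the [loss, accuracy] pairs returned by
# train_on_batch for the discriminator and adversarial models.
experiment.log_metric('dis_loss', d_stats[0], step=i)  # discriminator loss
experiment.log_metric('dis_acc', d_stats[1], step=i)   # discriminator accuracy
experiment.log_metric('adv_loss', a_stats[0], step=i)  # adversarial loss
experiment.log_metric('acc_adv', a_stats[1], step=i)   # adversarial accuracy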

You also want to confirm that your training process is actually using GPUs, which you can check in the Comet System Metrics tab.

You’ll notice that our training for-loop includes code to report images generated from the test vector:

if i % 500 == 0:
    # Visualize the performance of the generator by producing images from the test vector
    images = net_generator.predict(vis_noise)
    # Map back to original range
    # images = (images + 1) * 0.5
    plt.figure(figsize=(10, 10))

    for im in range(images.shape[0]):
        plt.subplot(4, 4, im + 1)
        image = images[im, :, :, :]
        image = np.reshape(image, [28, 28])

        plt.imshow(image, cmap='gray')
        plt.axis('off')

    plt.tight_layout()
    plt.savefig(r'output/mnist-normal/{}.png'.format(i))
    experiment.log_image(r'output/mnist-normal/{}.png'.format(i))
    plt.close('all')

Part of the reason why we want to report generated output every few steps is so that we can visually analyze how our generator and discriminator models are performing: generating realistic handwritten digits and correctly classifying the generated digits as ‘real’ or ‘fake’, respectively.

Let’s take a look at these generated outputs!

See the generated outputs on your own in this Comet Experiment

You can see how the Generator model starts off with fuzzy, grayish output (see 0.png below) that doesn’t really look like the handwritten digits we expect.

As training progresses and our models’ losses decline, the generated digits become clearer and clearer. Check out the generated outputs at:

Step 500:

Step 1000:

Step 1500:

And finally, at Step 10,000, you can see some samples of the GAN-generated digits in the red-outlined boxes below.

Once our GAN model is done training, we can even review our reported outputs as a movie in Comet’s Graphics tab (just press the play button!).

To complete the experiment, make sure you run experiment.end() to see some summary statistics around the model and GPU usage.

Iterating with your model

We could train the model longer to see how that impacts performance, but let’s try iterating with a few different parameters.

Some of the parameters we play around with are:

  • the discriminator’s optimizer
  • the learning rate
  • dropout probability
  • batch size

From Wouter’s original blog post, he mentions his own efforts with testing parameters:

I have tested both SGD, RMSprop and Adam for the optimizer of the discriminator but RMSprop performed best. RMSprop is used with a low learning rate and I clip the values between -1 and 1. A small decay in the learning rate can help with stabilizing.

We’ll try increasing the discriminator’s dropout probability from 0.4 to 0.5, increasing the discriminator’s learning rate (from 0.0008 to 0.0009), and increasing the generator’s learning rate (from 0.0004 to 0.0006). It’s easy to see how these changes can get out of hand and become difficult to track… 🤯

To create a different experiment, simply run the experiment definition cell again and Comet will issue you a new url for your new experiment! It’s nice to keep track of your experiments, so you can compare the differences:
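The experiment definition cell is roughly a sketch like the one below (the API key, project name, and the parameter names and values are placeholders); logging the hyperparameters you’re varying alongside each run makes the comparison that follows much easier:

# Sketch of an experiment definition cell: creating a new Comet.ml Experiment
# and logging the hyperparameters we are varying. The API key, project name,
# and parameter names/values are placeholders.
from comet_ml import Experiment

experiment = Experiment(api_key='YOUR_API_KEY',
                        project_name='mnist-gan')

experiment.log_parameters({
    'discriminator_optimizer': 'RMSprop',
    'discriminator_lr': 0.0009,
    'generator_lr': 0.0006,
    'dropout': 0.5,
})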

See the difference between the two experiments’ hyperparameters. Can you spot the differences in learning rate and dropout probability that we made?

Unfortunately, our adjustments did not improve the model’s performance! In fact, it generated some funky outputs:

That’s it for this tutorial! If you enjoyed this post, feel free to share with a friend who might find it useful 😎

👉🏼Have questions or feedback? Comment below!

👉🏼Want more awesome machine learning content? Follow us on Medium!