How to Train GAN Models in Practice

Original article was published by Abrar Ahmed on Artificial Intelligence on Medium



Conditional GAN:

An extension of the GAN to a conditional model, with car label information provided as input.

Instead of learning distinct features from each image class, the model learned complex new features that were a combination of distinct features from multiple image classes.

This resulted in output cars having color compositions that mixed multiple car labels.

Source: (wwt.com)

Using labels (CGAN)

Many datasets come with labels for the object type of their samples. Training a GAN is already hard, so any extra help in guiding the training can improve performance a lot. Adding the label as part of the latent space z helps the GAN training. Below is the data flow used in a CGAN to take advantage of the labels in the samples.
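The label-conditioning step can be sketched in plain NumPy. This is a hypothetical illustration: the latent dimension, number of classes, batch size, and the one-hot encoding scheme are all illustrative assumptions, and in a full CGAN the labels would also be fed to the discriminator.

```python
import numpy as np

# illustrative dimensions (assumptions, not from the article)
latent_dim, n_classes, n_batch = 100, 10, 16

# sample latent points and random integer class labels
z = np.random.randn(n_batch, latent_dim)
labels = np.random.randint(0, n_classes, n_batch)

# one-hot encode the labels and concatenate them with z,
# so the label becomes part of the generator's input
onehot = np.eye(n_classes)[labels]
conditioned_z = np.concatenate([z, onehot], axis=1)
# conditioned_z.shape -> (16, 110)
```

The generator then maps this conditioned input to an image, so the same latent point can produce different outputs for different labels.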

Source: (towardsdatascience.com)

System requirements

Python 3.5+

TensorFlow

NumPy

SciPy 0.19.1

Pillow

Pandas

The practical implementation of the GAN loss function and model updates is straightforward.

We will look at examples using the Keras library.

We can train the discriminator directly by configuring it to predict a probability of 1 for real images and 0 for fake images, and minimizing the cross-entropy loss, specifically the binary cross-entropy loss.

For example, a snippet of our model definition with Keras for the discriminator might look as follows for the output layer and the compilation of the model with the appropriate loss function.
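A minimal sketch of such a definition, where the 784-dimensional flattened input and the hidden layer size are illustrative assumptions:

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense

# illustrative discriminator: the sigmoid output layer predicts the
# probability that an input sample is real (1) rather than fake (0)
model = Sequential([
    Input(shape=(784,)),             # assumed flattened-image input
    Dense(64, activation='relu'),
    Dense(1, activation='sigmoid'),  # probability of 'real'
])

# minimize binary cross-entropy between predictions and real/fake labels
model.compile(loss='binary_crossentropy', optimizer='adam')
```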

Source: (machinelearningmastery.com)

How to Implement the GAN Training Algorithm

The GAN training algorithm involves training both the discriminator and the generator model in parallel.

The algorithm is summarized below, taken from the original 2014 paper by Goodfellow, et al. titled “Generative Adversarial Networks.”

[Figure: Summary of the Generative Adversarial Network Training Algorithm. Taken from: Generative Adversarial Networks.]

Let’s take some time to unpack and get comfortable with this algorithm.

The outer loop of the algorithm involves iterating over steps to train the models in the architecture. One cycle through this loop is not an epoch: it is a single update comprised of specific batch updates to the discriminator and generator models.

An epoch is defined as one cycle through a training dataset, where the samples in a training dataset are used to update the model weights in mini-batches. For example, a training dataset of 100 samples used to train a model with a mini-batch size of 10 samples would involve 10 mini-batch updates per epoch. The model would be fit for a given number of epochs, such as 500.

This detail is often hidden by the automated training of a model through a call to the fit() function, where the number of epochs and the size of each mini-batch are specified.

In the case of the GAN, the number of training iterations must be defined based on the size of your training dataset and batch size. In the case of a dataset with 100 samples, a batch size of 10, and 500 training epochs, we would first calculate the number of batches per epoch and use this to calculate the total number of training iterations using the number of epochs.

For example:

…
batches_per_epoch = floor(dataset_size / batch_size)
total_iterations = batches_per_epoch * total_epochs

In the case of a dataset of 100 samples, a batch size of 10, and 500 epochs, the GAN would be trained for floor(100 / 10) * 500, or 5,000, total iterations.

Next, we can see that one iteration of training results in possibly multiple updates to the discriminator and one update to the generator, where the number of updates to the discriminator is a hyperparameter that is typically set to 1.

The training process consists of simultaneous SGD. On each step, two minibatches are sampled: a minibatch of x values from the dataset and a minibatch of z values drawn from the model’s prior over latent variables. Then two gradient steps are made simultaneously …

Taken from: NIPS 2016 Tutorial: Generative Adversarial Networks.

We can therefore summarize the training algorithm with Python pseudocode as follows:

# gan training algorithm
def train_gan(dataset, n_epochs, n_batch):
    # calculate the number of batches per epoch
    batches_per_epoch = int(len(dataset) / n_batch)
    # calculate the number of training iterations
    n_steps = batches_per_epoch * n_epochs
    # gan training algorithm
    for i in range(n_steps):
        # update the discriminator model
        # ...
        # update the generator model
        # ...

An alternative approach may involve enumerating the number of training epochs and splitting the training dataset into batches for each epoch.

Updating the discriminator model involves a few steps.

First, a batch of random points from the latent space must be selected for use as input to the generator model to provide the basis for the generated or ‘fake‘ samples. Then a batch of samples from the training dataset must be selected for input to the discriminator as the ‘real‘ samples.

Next, the discriminator model must make predictions for the real and fake samples and the weights of the discriminator must be updated proportional to how correct or incorrect those predictions were. The predictions are probabilities and we will get into the nature of the predictions and the loss function that is minimized in the next section. For now, we can outline what these steps actually look like in practice.

We need a generator and a discriminator model, e.g. Keras models. These can be provided as arguments to the training function.

Next, we must generate points from the latent space and then use the generator model in its current form to generate some fake images. For example:

…
# generate points in the latent space
z = randn(latent_dim * n_batch)
# reshape into a batch of inputs for the network
z = z.reshape(n_batch, latent_dim)
# generate fake images
fake = generator.predict(z)

Note that the size of the latent dimension is also provided as a hyperparameter to the training algorithm.

We then must select a batch of real samples, and this too will be wrapped into a function.

…
# select a batch of random real images
ix = randint(0, len(dataset), n_batch)
# retrieve real images
real = dataset[ix]

The discriminator model must then make a prediction for each of the generated and real images, and the weights must be updated.

# gan training algorithm
def train_gan(generator, discriminator, dataset, latent_dim, n_epochs, n_batch):
    # calculate the number of batches per epoch
    batches_per_epoch = int(len(dataset) / n_batch)
    # calculate the number of training iterations
    n_steps = batches_per_epoch * n_epochs
    # gan training algorithm
    for i in range(n_steps):
        # generate points in the latent space
        z = randn(latent_dim * n_batch)
        # reshape into a batch of inputs for the network
        z = z.reshape(n_batch, latent_dim)
        # generate fake images
        fake = generator.predict(z)
        # select a batch of random real images
        ix = randint(0, len(dataset), n_batch)
        # retrieve real images
        real = dataset[ix]
        # update weights of the discriminator model
        # ...
        # update the generator model
        # ...

Next, the generator model must be updated.

Again, a batch of random points from the latent space must be selected and passed to the generator to generate fake images, and then passed to the discriminator to classify.

…
# generate points in the latent space
z = randn(latent_dim * n_batch)
# reshape into a batch of inputs for the network
z = z.reshape(n_batch, latent_dim)
# generate fake images
fake = generator.predict(z)
# classify as real or fake
result = discriminator.predict(fake)

The response can then be used to update the weights of the generator model.

# gan training algorithm
def train_gan(generator, discriminator, dataset, latent_dim, n_epochs, n_batch):
    # calculate the number of batches per epoch
    batches_per_epoch = int(len(dataset) / n_batch)
    # calculate the number of training iterations
    n_steps = batches_per_epoch * n_epochs
    # gan training algorithm
    for i in range(n_steps):
        # generate points in the latent space
        z = randn(latent_dim * n_batch)
        # reshape into a batch of inputs for the network
        z = z.reshape(n_batch, latent_dim)
        # generate fake images
        fake = generator.predict(z)
        # select a batch of random real images
        ix = randint(0, len(dataset), n_batch)
        # retrieve real images
        real = dataset[ix]
        # update weights of the discriminator model
        # ...
        # generate points in the latent space
        z = randn(latent_dim * n_batch)
        # reshape into a batch of inputs for the network
        z = z.reshape(n_batch, latent_dim)
        # generate fake images
        fake = generator.predict(z)
        # classify as real or fake
        result = discriminator.predict(fake)
        # update weights of the generator model
        # ...

It is interesting that the discriminator is updated with two batches of samples each training iteration whereas the generator is only updated with a single batch of samples per training iteration.

Now that we have defined the training algorithm for the GAN, we need to understand how the model weights are updated. This requires understanding the loss function used to train the GAN.

Source: (machinelearningmastery.com)

How to Train a GAN? Tips and tricks to make GANs work

While research in Generative Adversarial Networks (GANs) continues to improve the fundamental stability of these models, we use a bunch of tricks to train them and make them stable day to day.

Here is a summary of some of the tricks.

If you find a trick that is particularly useful in practice, please open a Pull Request to add it to the document. If we find it to be reasonable and verified, we will merge it in.

Source: (github.com)

Best Practices for Deep Convolutional GANs

Downsample Using Strided Convolutions

Upsample Using Strided Convolutions

Use LeakyReLU

Use Batch Normalization

Use Gaussian Weight Initialization

Use Adam Stochastic Gradient Descent

Scale Images to the Range [-1,1]
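The practices above can be sketched in a single illustrative discriminator. This is a hypothetical example, not a definitive implementation: the 28x28 grayscale input, layer sizes, and filter counts are assumptions, while the Gaussian initialization (stddev 0.02) and the Adam settings (learning rate 0.0002, beta_1 0.5) follow the DCGAN recommendations.

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Input, Conv2D, LeakyReLU,
                                     BatchNormalization, Flatten, Dense)
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.optimizers import Adam

# Gaussian weight initialization: mean 0.0, standard deviation 0.02
init = RandomNormal(mean=0.0, stddev=0.02)

model = Sequential([
    # input images assumed scaled to [-1, 1], e.g. X = (X - 127.5) / 127.5
    Input(shape=(28, 28, 1)),
    # downsample with strided convolutions instead of pooling layers
    Conv2D(64, (3, 3), strides=(2, 2), padding='same', kernel_initializer=init),
    LeakyReLU(0.2),
    Conv2D(64, (3, 3), strides=(2, 2), padding='same', kernel_initializer=init),
    BatchNormalization(),
    LeakyReLU(0.2),
    Flatten(),
    Dense(1, activation='sigmoid'),
])

# Adam stochastic gradient descent with the DCGAN-recommended settings
model.compile(loss='binary_crossentropy',
              optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
```

A generator would mirror this structure, upsampling with strided transpose convolutions (Conv2DTranspose) and using a tanh output to match the [-1, 1] image range.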

Source: (machinelearningmastery.com)

Soumith Chintala’s GAN Hacks

Use a Gaussian Latent Space

Separate Batches of Real and Fake Images

Use Label Smoothing

Use Noisy Labels
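The last two hacks can be sketched in NumPy. The smoothing range [0.8, 1.0] and the 5% flip rate here are illustrative assumptions; the hacks only call for soft and occasionally incorrect labels.

```python
import numpy as np

n_batch = 64  # illustrative batch size

# label smoothing: use soft targets instead of a hard 1.0 for 'real',
# here drawn uniformly from (0.8, 1.0]
y_real = np.ones((n_batch, 1)) - 0.2 * np.random.rand(n_batch, 1)

# noisy labels: flip a small fraction (~5%) of the real labels to fake
flip = np.random.rand(n_batch, 1) < 0.05
y_real_noisy = np.where(flip, 0.0, y_real)
```

These labels would then be passed to the discriminator update in place of the usual hard ones(n_batch, 1) targets.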

Source: (machinelearningmastery.com)

Heuristics for Training Stable GANs

GANs are difficult to train.

At the time of writing, there is no good theoretical foundation as to how to design and train GAN models, but there is established literature of heuristics, or “hacks,” that have been empirically demonstrated to work well in practice.

As such, there are a range of best practices to consider and implement when developing a GAN model.

Perhaps the two most important sources of suggested configuration and training parameters are:

Alec Radford, et al’s 2015 paper that introduced the DCGAN architecture.

Soumith Chintala’s 2016 presentation and associated “GAN Hacks” list.

In this tutorial, we will explore how to implement the most important best practices from these two sources.

Source: (machinelearningmastery.com)