Original article was published by Kailash S on Artificial Intelligence on Medium
What Are Generative Adversarial Networks (GANs) All About?
A brief overview of GAN architecture and working mechanism
This article gives you a fair idea of Generative Adversarial Networks (GANs), their architecture, and their working mechanism. Supporting code snippets give a brief overview of how GANs can be implemented programmatically.
Generative Adversarial Networks (GANs) are deep-learning-based generative models that use two neural networks competing against each other (hence "adversarial") to generate new, synthetic instances that resemble the real data. They are widely used in data augmentation, where the data can be images, voice or even videos.
A GAN is made up of two neural networks: the generator and the discriminator.
- The generator is the model that is used to generate realistic (but not real) examples from the problem domain.
- The discriminator is the model that is used to classify examples as real (from the domain) or fake (generated).
The generated images are mixed with a stream of images taken from the actual, ground-truth dataset and sent to the discriminator, which evaluates each image as real or fake. The more often the discriminator classifies fake images as real, the better the generator model is.
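This tug-of-war is commonly written as a minimax objective, following the original GAN formulation by Goodfellow et al. The symbols below are notation introduced here for illustration and do not appear in this article's code: $D(x)$ is the discriminator's estimated probability that $x$ is real, $G(z)$ is the image generated from noise $z$, $p_{\text{data}}$ is the real data distribution and $p_z$ is the noise distribution.

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
```

The discriminator tries to make both terms large (real images scored near 1, fakes near 0), while the generator tries to make the second term small by pushing $D(G(z))$ towards 1.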
Steps involved in one training iteration of the generator model
- Generate a batch of random noise
- From the random noise, generate a realistic image of the problem domain
- Use the discriminator to evaluate the generated image as ‘real’ or ‘fake’
- Compute the loss from the discriminator's output and obtain its gradients
- Use the gradients that flow back through the discriminator to update the generator's weights
Both models are trained in an alternating fashion until the discriminator is fooled about half the time, meaning the generator is producing realistic examples.
The generator is a neural network, and a neural network needs some input to start from. For the generator, that input is random noise (usually drawn from a normal distribution), which the generator turns into a meaningful output. Using new random noise every time during training ensures that a wide variety of images is generated to fool the discriminator. Below is the code snippet to create a simple generator model:
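import tensorflow as tf
from tensorflow import keras

num_features = 100  # length of the random noise vector
generator = keras.models.Sequential([
    keras.layers.Dense(7*7*128, input_shape = [num_features]),
    # Reshape is needed so the flat Dense output becomes a 7x7 feature map with 128 channels
    keras.layers.Reshape([7, 7, 128]),
    keras.layers.Conv2DTranspose(64, (5,5), (2,2), padding = 'same', activation = 'selu'),   # 7x7 -> 14x14
    keras.layers.Conv2DTranspose(1, (5,5), (2,2), padding = 'same', activation = 'tanh' )])  # 14x14 -> 28x28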
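Transposed convolution is the core operation the generator uses to convert a smaller input (noise) into a larger output (a realistic image): each stride-2 layer doubles the height and width, so 7x7 becomes 14x14 and then 28x28. The hidden layer uses the Scaled Exponential Linear Unit (selu), a self-normalizing activation; it is one reasonable choice here rather than a requirement for GANs. The final layer uses tanh, so generated pixel values lie between -1 and 1.
‘same’ padding does not drop any input elements: zeros are added as evenly as possible on the left and right ends, and if the number of columns to add is odd, the extra column goes on the right.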
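The size arithmetic behind ‘same’ padding can be checked by hand. The sketch below is plain Python, not Keras: the helper names (`conv_same_output`, `conv_transpose_same_output`, `same_padding`) are made up for illustration, and the padding formula is the one TensorFlow documents for ‘SAME’ padding of an ordinary (non-transposed) convolution.

```python
import math

def conv_same_output(size, stride):
    """Output length of a strided convolution with 'same' padding."""
    return math.ceil(size / stride)

def conv_transpose_same_output(size, stride):
    """Output length of a strided transposed convolution with 'same' padding."""
    return size * stride

def same_padding(size, kernel, stride):
    """Return (zeros_before, zeros_after) added along one dimension."""
    out = conv_same_output(size, stride)
    total = max((out - 1) * stride + kernel - size, 0)
    before = total // 2
    return before, total - before  # any odd leftover zero goes at the end

# Generator: 7 -> 14 -> 28 through two stride-2 transposed convolutions.
size = 7
for _ in range(2):
    size = conv_transpose_same_output(size, stride=2)
print(size)  # 28

# A 5-wide kernel with stride 2 over 28 columns needs 3 zeros: 1 left, 2 right.
print(same_padding(28, kernel=5, stride=2))  # (1, 2)
```

The second result shows the odd case described above: three zeros cannot be split evenly, so the extra one lands on the right.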
A discriminator model can be considered as a binary classification neural network model where the classes are real and fake.
During training, the discriminator effectively tells the generator how to adjust its current output, i.e., the generated image, to look more realistic. Technically, this is done with gradients: the loss is back-propagated through the discriminator, and the resulting gradients are used by the optimizer to update the weights of each layer in the generator (they are not simply added to the weights), so that the generator learns to produce images that look more realistic in the discriminator's eyes. Below is the code snippet to create a simple discriminator model:
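discriminator = keras.models.Sequential([
    keras.layers.Conv2D(64, (5,5), (2,2), padding = 'same', input_shape = [28,28,1]),  # 28x28 -> 14x14
    keras.layers.Conv2D(128, (5,5), (2,2), padding = 'same'),                           # 14x14 -> 7x7
    # Flatten is needed so the Dense layer outputs one score per image
    keras.layers.Flatten(),
    keras.layers.Dense(1, activation = 'sigmoid')])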
Since the discriminator performs binary classification, sigmoid activation is used in the final layer.
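To make the role of the sigmoid concrete, here is a small plain-Python sketch. It is an illustration only: the function names and the 0.5 threshold are assumptions made here, not part of the Keras model above.

```python
import math

def sigmoid(logit):
    """Numerically stable logistic function: maps any real number into (0, 1)."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    e = math.exp(logit)
    return e / (1.0 + e)

def classify(logit, threshold=0.5):
    """Label an image 'real' if the predicted probability reaches the threshold."""
    return "real" if sigmoid(logit) >= threshold else "fake"

print(sigmoid(0.0))    # 0.5
print(classify(2.0))   # real
print(classify(-2.0))  # fake
```

The single output unit therefore reads as "probability that the image is real", which is what lets a binary cross-entropy loss compare it against the 0/1 labels.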
Why train the generator and discriminator alternately?
While the discriminator is being trained, the generator is kept fixed, since the discriminator is figuring out how to distinguish the real examples from the synthetic data produced by the generator. If the generator were allowed to train simultaneously, the discriminator would struggle to find the distinguishing features between real and fake images, because the generator's output would keep changing underneath it. This would in turn lead to a poor generator model.
Similarly, while the generator is being trained, the discriminator is kept fixed. Since the generator is tuned by back-propagating gradients through the discriminator, letting the discriminator train simultaneously would keep those gradients shifting and make it hard for the generator to converge. Below is the code snippet to illustrate how a GAN is trained:
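from tqdm import tqdm
# Assumption: gan = keras.models.Sequential([generator, discriminator]), and both `discriminator`
# and `gan` are compiled with a binary cross-entropy loss; `dataset` yields float32 batches of batch_size images.
for epoch in tqdm(range(total_epochs)):
    print("Epoch {}/{}".format(epoch + 1, total_epochs))
    for x_batch in dataset:
        # Phase 1: train the discriminator on fake (label 0) and real (label 1) images
        noise = tf.random.normal(shape = [batch_size, num_features])
        gen_images = generator(noise)
        disc_input = tf.concat([gen_images, x_batch], axis = 0)
        disc_labels = tf.constant([[0.]]*batch_size + [[1.]]*batch_size)
        discriminator.trainable = True
        discriminator.train_on_batch(disc_input, disc_labels)
        # Phase 2: train the generator through the frozen discriminator, labelling fakes as real
        noise = tf.random.normal(shape = [batch_size, num_features])
        gen_labels = tf.constant( [[1.]]*batch_size)
        discriminator.trainable = False
        gan.train_on_batch(noise, gen_labels)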
In each epoch, the dataset is processed in batches of the specified size. For each batch, random noise of the specified dimensions is generated and passed through the generator to produce fake images. The input for the discriminator is created by concatenating the fake images with the real images, and the matching labels are defined (0 for fake, 1 for real). The discriminator is trained on this input first; the generator is then trained based on the discriminator's gradients/feedback.
(NOTE: While one model is being trained, the other is kept fixed. The discriminator is frozen during the generator's phase by setting its trainable attribute to False; the generator's weights are not touched during the discriminator's phase because only the discriminator is updated there.)
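The alternating schedule can also be seen without any framework. The sketch below is a deliberately tiny, hypothetical 1-D GAN in plain Python: the "generator" is a line `a*z + b`, the "discriminator" is a logistic unit `sigmoid(w*x + c)`, and the gradients are written out by hand. None of these names or choices come from the Keras code above; they are assumptions made only to show that each step updates one model and merely reads the other.

```python
import math
import random

def sigmoid(t):
    """Numerically stable logistic function."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)

def discriminator_step(disc, gen, real_batch, noise, lr=0.05):
    """Return updated discriminator (w, c); the generator is only read."""
    w, c = disc
    a, b = gen
    fakes = [a * z + b for z in noise]
    # Fakes are labelled 0 and real samples 1, as in the Keras loop.
    samples = [(x, 0.0) for x in fakes] + [(x, 1.0) for x in real_batch]
    grad_w = grad_c = 0.0
    for x, y in samples:
        err = sigmoid(w * x + c) - y  # gradient of cross-entropy w.r.t. the logit
        grad_w += err * x
        grad_c += err
    n = len(samples)
    return (w - lr * grad_w / n, c - lr * grad_c / n)

def generator_step(disc, gen, noise, lr=0.05):
    """Return updated generator (a, b); the discriminator is frozen."""
    w, c = disc
    a, b = gen
    grad_a = grad_b = 0.0
    for z in noise:
        # Target label is 1: the generator wants its fakes to be called real.
        err = sigmoid(w * (a * z + b) + c) - 1.0
        grad_a += err * w * z  # chain rule back through the discriminator
        grad_b += err * w
    n = len(noise)
    return (a - lr * grad_a / n, b - lr * grad_b / n)

random.seed(0)
disc, gen = (0.1, 0.0), (1.0, 0.0)
for step in range(200):
    real = [random.gauss(4.0, 1.0) for _ in range(32)]  # "real" data centred at 4
    noise = [random.gauss(0.0, 1.0) for _ in range(32)]
    disc = discriminator_step(disc, gen, real, noise)   # phase 1
    noise = [random.gauss(0.0, 1.0) for _ in range(32)]
    gen = generator_step(disc, gen, noise)              # phase 2
print(disc, gen)
```

Because each function returns new parameters for one model only, the other model cannot change during that phase, which is the same effect the trainable flag is used for in the Keras version.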
When to stop training a GAN?
As the generator improves with training, the discriminator's performance gets worse, because it is fooled by the generator more often. If the generator succeeds perfectly, the discriminator has 50% accuracy, which is no better than a coin flip.
If the GAN continues training past this point, the discriminator's feedback to the generator becomes less meaningful, and training on it can degrade the generator's performance. In practice, it is therefore not realistic to expect the generator to fool the discriminator on every fake image. For a GAN, convergence is often a fleeting, rather than stable, state.
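One rough way to spot the 50% point is to watch the discriminator's loss. The plain-Python sketch below assumes a binary cross-entropy loss (the usual partner of a sigmoid output) and an idealised discriminator that outputs exactly 0.5 for every image; the helper name and the sample probabilities are made up for illustration.

```python
import math

def binary_cross_entropy(labels, probs, eps=1e-7):
    """Mean binary cross-entropy between 0/1 labels and predicted probabilities."""
    total = 0.0
    for y, p in zip(labels, probs):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
    return total / len(labels)

labels = [1.0, 1.0, 0.0, 0.0]  # two real images, two fakes

confident = binary_cross_entropy(labels, [0.9, 0.8, 0.1, 0.2])
undecided = binary_cross_entropy(labels, [0.5, 0.5, 0.5, 0.5])

print(round(undecided, 4))    # 0.6931, i.e. ln 2
print(confident < undecided)  # True
```

A discriminator loss hovering around ln 2 is consistent with it no longer telling real from fake, though in practice generated samples should also be inspected visually before deciding to stop.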
WRAPPING IT UP…
I hope this blog post has clarified the underlying concepts of GANs. To help you take this forward to implementation, I have added the GitHub repository, where a GAN is implemented on the "Fashion MNIST image dataset" to generate synthetic images of clothing items. Try implementing it yourself to get a more in-depth understanding.
Thanks for reading this post until the very end. If you found this blog to be useful, cheer me up with your claps and leave a comment below if you have any questions.