The math behind GANs (Generative Adversarial Networks)

Source: Deep Learning on Medium


A detailed understanding of the math behind the original GAN, including its limitations

1. Introduction

The Generative Adversarial Network (GAN) comprises two models: a generative model G and a discriminative model D. The generative model can be thought of as a counterfeiter trying to produce fake currency and use it without being caught, whereas the discriminative model plays the role of the police, trying to catch the fake currency. This competition continues until the counterfeiter becomes smart enough to successfully fool the police.

Figure 1: Representation of the generator and discriminator as a counterfeiter and police, respectively. Figure from [1].

In other words,

Discriminator: The role is to distinguish between actual and generated (fake) data.

Generator: The role is to create data in such a way that it can fool the discriminator.

2. Some parameters and variables

Before we go into the derivation, let’s describe some parameters and variables.
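Following the notation of the original paper by Goodfellow et al., the main symbols are:

```latex
x \sim p_{\text{data}}(x): \; \text{a real sample drawn from the (unknown) data distribution} \\
z \sim p_z(z): \; \text{a noise vector drawn from a simple prior, e.g. uniform or Gaussian} \\
G(z): \; \text{the generator's mapping from noise to data space; the distribution it induces is } p_g \\
D(x): \; \text{the discriminator's estimate of the probability that } x \text{ is real (1 = real, 0 = fake)}
```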

3. Derivation of the loss function

The loss function described in the original paper by Ian Goodfellow et al. can be derived from the formula of binary cross-entropy loss. The binary cross-entropy loss can be written as,
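For a single data point with true label $y \in \{0, 1\}$ and predicted probability $\hat{y}$, the binary cross-entropy (written here without the conventional minus sign, so that it is to be *maximized*) is

```latex
L(\hat{y}, y) = y \log \hat{y} + (1 - y)\log(1 - \hat{y})
```

For real data the label is $y = 1$ and the prediction is $\hat{y} = D(x)$, giving

```latex
L\bigl(D(x), 1\bigr) = \log D(x) \tag{1}
```

For generated data the label is $y = 0$ and the prediction is $\hat{y} = D(G(z))$, giving

```latex
L\bigl(D(G(z)), 0\bigr) = \log\bigl(1 - D(G(z))\bigr) \tag{2}
```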

3.1 Discriminator loss

Now, the objective of the discriminator is to correctly classify real and fake data. For this, equations (1) and (2) should be maximized, and the final loss function for the discriminator can be given as,
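Summing the two maximized terms:

```latex
L^{(D)} = \max_{D} \; \Bigl[ \log D(x) + \log\bigl(1 - D(G(z))\bigr) \Bigr] \tag{3}
```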

3.2 Generator loss

Here, the generator is competing against the discriminator. So, it will try to minimize equation (3), and its loss function is given as,
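Note that only the second term depends on G; the first is unaffected by the generator:

```latex
L^{(G)} = \min_{G} \; \Bigl[ \log D(x) + \log\bigl(1 - D(G(z))\bigr) \Bigr] \tag{4}
```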

3.3 Combined loss function

We can combine equations (3) and (4) and write the result as,
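The two objectives act on the same expression, giving the minimax game:

```latex
\min_{G} \, \max_{D} \; \Bigl[ \log D(x) + \log\bigl(1 - D(G(z))\bigr) \Bigr] \tag{5}
```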

Remember that the above loss function is valid only for a single data point. To consider the entire dataset, we need to take the expectation of the above equation as,
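Taking expectations over the data and noise distributions yields the value function:

```latex
\min_{G} \, \max_{D} \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\bigl[\log D(x)\bigr] + \mathbb{E}_{z \sim p_z(z)}\Bigl[\log\bigl(1 - D(G(z))\bigr)\Bigr] \tag{6}
```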

which is the same equation as described in the original paper by Goodfellow et al.

4. Algorithm

Figure 2: Algorithm described in the original paper by Goodfellow et al. Figure from [2].

It can be noticed from the above algorithm that the generator and discriminator are trained separately. In the first step, real and fake data are fed into the discriminator with the correct labels, and training takes place. Gradients are propagated while keeping the generator fixed. Also, we update the discriminator by ascending its stochastic gradient, because for the discriminator we want to maximize the loss function given in equation (6).

On the other hand, we update the generator by keeping the discriminator fixed and passing fake data with fake labels in order to fool the discriminator. Here, we update the generator by descending its stochastic gradient, because for the generator we want to minimize the loss function given in equation (6).
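This alternating scheme can be sketched in a few lines. The following is a toy 1-D illustration of my own (not the paper's experimental setup): the network sizes, learning rates, and the target distribution N(4, 1.5²) are arbitrary choices, and an epsilon is added inside the logs for numerical stability.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy setup: real data are 1-D samples from N(4, 1.5^2);
# G maps 1-D noise z ~ N(0, 1) to 1-D samples.
G = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 1))
D = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 1), nn.Sigmoid())
opt_D = torch.optim.SGD(D.parameters(), lr=0.05)
opt_G = torch.optim.SGD(G.parameters(), lr=0.05)

batch, eps = 64, 1e-8
for step in range(100):
    # 1) Discriminator step (G fixed): ascend log D(x) + log(1 - D(G(z)))
    #    by descending its negation.
    real = 4.0 + 1.5 * torch.randn(batch, 1)
    fake = G(torch.randn(batch, 1)).detach()          # detach() keeps G fixed
    loss_D = -(torch.log(D(real) + eps).mean()
               + torch.log(1 - D(fake) + eps).mean())
    opt_D.zero_grad(); loss_D.backward(); opt_D.step()

    # 2) Generator step (D fixed): descend log(1 - D(G(z))), as in equation (6).
    fake = G(torch.randn(batch, 1))
    loss_G = torch.log(1 - D(fake) + eps).mean()
    opt_G.zero_grad(); loss_G.backward(); opt_G.step()
```

In practice the discriminator step is often repeated k times per generator step, as the algorithm in figure 2 allows.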

5. Global Optimality of Pg = Pdata

The optimal discriminator D for any given generator G can be found by taking the derivative of the loss function (equation (6)) with respect to D(x) and setting it to zero,
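Writing equation (6) as an integral, $V(G, D) = \int_x p_{\text{data}}(x)\log D(x) + p_g(x)\log(1 - D(x))\,dx$; the integrand has the form $a\log y + b\log(1-y)$, which is maximized at $y = a/(a+b)$. Hence,

```latex
D^{*}_{G}(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)} \tag{7}
```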

The above equation is very important mathematically, but in practice you cannot calculate the optimal D, as Pdata(x) is not known. Now, the loss for G when we have the optimal D can be obtained by substituting equation (7) into the loss function as,
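Substituting (7) into (6) gives the generator's cost under the optimal discriminator:

```latex
C(G) = \mathbb{E}_{x \sim p_{\text{data}}}\!\left[\log \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}\right] + \mathbb{E}_{x \sim p_g}\!\left[\log \frac{p_g(x)}{p_{\text{data}}(x) + p_g(x)}\right] \tag{8}
```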

Now, the Kullback-Leibler (KL) and Jensen-Shannon (JS) divergences are given by,
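```latex
KL(P \,\|\, Q) = \mathbb{E}_{x \sim P}\!\left[\log \frac{P(x)}{Q(x)}\right], \qquad
JSD(P \,\|\, Q) = \frac{1}{2}\, KL\!\left(P \,\Big\|\, \frac{P+Q}{2}\right) + \frac{1}{2}\, KL\!\left(Q \,\Big\|\, \frac{P+Q}{2}\right)
```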

Hence,
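Adding and subtracting $\log 2$ inside each expectation in (8) turns the two terms into KL divergences against the mixture $(p_{\text{data}} + p_g)/2$, so

```latex
C(G) = -\log 4 + 2 \cdot JSD\bigl(p_{\text{data}} \,\|\, p_g\bigr) \tag{9}
```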

The above equation reduces to -2log2 as Pg approaches Pdata, because the divergence becomes zero.

6. Limitations

The loss function derived (equation (9)) has some limitations which are described in this section.

6.1 Vanishing Gradient

The aim of optimizing equation (9) is to move Pg towards Pdata (also written Pr) given an optimal D. However, the JS divergence remains constant when there is no overlap between Pr and Pg (figure 3). It can be observed that the JS divergence is constant and its gradient is close to 0 when the distance between the two distributions is more than 5, which means the training process has no influence on G (figure 4). The gradient is non-zero only when Pg and Pr overlap significantly; in other words, when D is close to optimal, G faces the vanishing gradient problem.

Figure 3: Illustration of training progress for a GAN. Two normal distributions are used here for visualization. Given an optimal D, the objective of GANs is to update G in order to move the generated distribution Pg (red) towards the real distribution Pr (blue) (G is updated from left to right in this figure. Left: initial state, middle: during training, right: training converging). However, JS divergence for the left two figures are both 0.693 and the figure on the right is 0.336, indicating that JS divergence does not provide sufficient gradient at the initial state. Figure from [3].
Figure 4: JS divergence and gradient change with the distance between Pr and Pg. The distance is the difference between the two distribution means. Figure from [3].
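The saturation shown in figure 4 can be checked numerically. A minimal sketch (my own illustration, not the computation from [3]): two unit-variance Gaussians are discretized on a grid, and their JS divergence is computed as the distance between their means grows.

```python
import numpy as np

def gaussian_pdf(x, mu, sigma=1.0):
    """Density of N(mu, sigma^2) evaluated on the grid x."""
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

def js_divergence(p, q):
    """JS divergence (in nats) between two densities discretized on the same grid."""
    p, q = p / p.sum(), q / q.sum()        # normalize to discrete distributions
    m = 0.5 * (p + q)
    def kl(a, b):
        mask = a > 0                       # 0 * log 0 is taken as 0
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

x = np.linspace(-20.0, 30.0, 5001)         # grid wide enough to cover both densities
p_r = gaussian_pdf(x, 0.0)                 # "real" distribution Pr, mean 0
for d in [0.0, 1.0, 5.0, 10.0]:
    p_g = gaussian_pdf(x, d)               # "generated" distribution Pg, mean d
    print(f"distance {d:4.1f}  JSD = {js_divergence(p_r, p_g):.4f}")
```

Once the means are more than about 5 apart, the JSD is pinned at log 2 ≈ 0.693 (the value quoted in the caption of figure 3), so its gradient with respect to the distance is essentially zero.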

This problem can be countered by modifying the original loss function of G as,
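This is the "non-saturating" generator loss proposed in the original paper: instead of minimizing $\log(1 - D(G(z)))$, the generator maximizes $\log D(G(z))$, which provides strong gradients early in training when D confidently rejects generated samples:

```latex
L^{(G)} = \max_{G} \; \mathbb{E}_{z \sim p_z(z)}\bigl[\log D(G(z))\bigr]
```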

6.2 Mode Collapse

During training, the generator may get stuck in a setting where it always produces the same output. This is called mode collapse. It happens because the main aim of G is to fool D, not to generate diverse outputs. The math behind this is a bit involved and will be discussed in future articles.

Btw, this is my first story and I hope you enjoyed it.

7. References

[1] Atienza, Rowel. Advanced Deep Learning with Keras: Apply deep learning techniques, autoencoders, GANs, variational autoencoders, deep reinforcement learning, policy gradients, and more. Packt Publishing Ltd, 2018.

[2] Goodfellow, Ian, et al. "Generative adversarial nets." Advances in Neural Information Processing Systems. 2014.

[3] Wang, Zhengwei, Qi She, and Tomas E. Ward. "Generative Adversarial Networks: A Survey and Taxonomy." arXiv preprint arXiv:1906.01529 (2019).