Source: Deep Learning on Medium
The math behind GANs (Generative Adversarial Networks)
A detailed understanding of the math behind original GANs including their limitations
A Generative Adversarial Network (GAN) comprises two models: a generative model G and a discriminative model D. The generative model can be thought of as a counterfeiter trying to produce fake currency and spend it without being caught, while the discriminative model plays the role of the police, trying to detect the fake currency. This competition continues until the counterfeiter becomes skilled enough to consistently fool the police.
In other words,
Discriminator: The role is to distinguish between actual and generated (fake) data.
Generator: The role is to create data in such a way that it can fool the discriminator.
2. Some parameters and variables
Before we go into the derivation, let’s describe some parameters and variables.
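Based on the notation in Goodfellow et al., the key symbols used throughout the derivation are:

```latex
x \sim p_{data}(x) \quad \text{— a sample from the real data distribution} \\
z \sim p_z(z) \quad \text{— a noise sample drawn from the prior (e.g. a Gaussian)} \\
G(z; \theta_g) \quad \text{— the generator's output; its samples follow the distribution } p_g \\
D(x; \theta_d) \quad \text{— the discriminator's estimated probability that } x \text{ is real}
```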
3. Derivation of the loss function
The loss function described in the original paper by Ian Goodfellow et al. can be derived from the binary cross-entropy loss, which can be written as,
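Following Goodfellow et al., for a true label y and a predicted probability ŷ, the binary cross-entropy is L(ŷ, y) = y·log ŷ + (1 − y)·log(1 − ŷ), written here without the usual leading minus sign so that it is a quantity to maximize. Substituting a real sample (y = 1, ŷ = D(x)) and a generated sample (y = 0, ŷ = D(G(z))) gives the two terms:

```latex
L = \log D(x) \tag{1}
```

```latex
L = \log\big(1 - D(G(z))\big) \tag{2}
```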
3.1 Discriminator loss
Now, the objective of the discriminator is to correctly classify real and fake data. To do this, equations (1) and (2) should be maximized, and the final loss function for the discriminator can be written as,
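Summing the two maximized terms gives the discriminator objective, as in Goodfellow et al.:

```latex
L_D = \max_D \; \Big[ \log D(x) + \log\big(1 - D(G(z))\big) \Big] \tag{3}
```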
3.2 Generator loss
Here, the generator competes against the discriminator, so it will try to minimize equation (3), and its loss function is given as,
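The generator minimizes the same expression the discriminator maximizes; note that the first term does not depend on G, so in practice only the second term matters to the generator:

```latex
L_G = \min_G \; \Big[ \log D(x) + \log\big(1 - D(G(z))\big) \Big] \tag{4}
```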
3.3 Combined loss function
We can combine equations (3) and (4) and write as,
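Since (3) and (4) are a max and a min over the same expression, the combined single-point objective is the familiar minimax game:

```latex
\min_G \max_D \; \Big[ \log D(x) + \log\big(1 - D(G(z))\big) \Big] \tag{5}
```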
Remember that the above loss function is valid only for a single data point; to consider the entire dataset, we need to take the expectation of the above equation:
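Taking expectations over the data and noise distributions gives the value function exactly as stated in Goodfellow et al.:

```latex
\min_G \max_D V(D, G) =
\mathbb{E}_{x \sim p_{data}(x)}\big[\log D(x)\big]
+ \mathbb{E}_{z \sim p_z(z)}\Big[\log\big(1 - D(G(z))\big)\Big] \tag{6}
```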
which is the same equation as described in the original paper by Goodfellow et al.
It can be noticed from the training algorithm in the original paper that the generator and discriminator are trained separately. In the first phase, real and fake data are fed into the discriminator with the correct labels, and training takes place with the generator kept fixed. The discriminator is updated by ascending its stochastic gradient, because for the discriminator we want to maximize the loss function given in equation (6).
On the other hand, we update the generator by keeping the discriminator fixed and passing fake data through it, with the aim of fooling the discriminator. Here, we update the generator by descending its stochastic gradient, because for the generator we want to minimize the loss function given in equation (6).
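The alternating updates can be sketched with a toy one-dimensional example. Everything here is a hypothetical choice for illustration: the real data is N(2, 0.5), the generator is a simple shift G(z) = z + θ, the discriminator is logistic D(x) = sigmoid(a·x + b), and the gradients are worked out by hand from equation (6):

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))

# Toy setup: real data ~ N(2, 0.5); generator shifts noise: G(z) = z + theta
a, b = 0.0, 0.0          # discriminator parameters, D(x) = sigmoid(a*x + b)
theta = 0.0              # generator parameter
lr_d, lr_g, batch = 0.1, 0.02, 128

for step in range(4000):
    x = 2.0 + 0.5 * rng.standard_normal(batch)   # real samples
    z = rng.standard_normal(batch)
    g = z + theta                                # fake samples G(z)

    # Discriminator update: ASCEND gradient of log D(x) + log(1 - D(G(z)))
    d_real, d_fake = sigmoid(a * x + b), sigmoid(a * g + b)
    grad_a = np.mean((1 - d_real) * x) - np.mean(d_fake * g)
    grad_b = np.mean(1 - d_real) - np.mean(d_fake)
    a += lr_d * grad_a
    b += lr_d * grad_b

    # Generator update: DESCEND gradient of log(1 - D(G(z))), D held fixed
    d_fake = sigmoid(a * g + b)
    grad_theta = np.mean(-d_fake * a)            # d/dtheta of log(1 - D(z + theta))
    theta -= lr_g * grad_theta

print(round(theta, 2))   # theta drifts toward the real mean of 2
```

The generator parameter climbs toward the real data's mean and then oscillates around it, which is the back-and-forth dynamic the minimax game in equation (6) describes.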
5. Global Optimality of Pg = Pdata
The optimal discriminator D for any given generator G can be found by taking the derivative of the loss function (equation (6)) with respect to D and setting it to zero,
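For a fixed G, the value function can be maximized pointwise: for each x, the integrand p_data(x)·log D(x) + p_g(x)·log(1 − D(x)) is maximized over D(x). Setting the derivative to zero gives, as in Goodfellow et al.:

```latex
\frac{\partial}{\partial D(x)}\Big[ p_{data}(x)\log D(x) + p_g(x)\log\big(1 - D(x)\big) \Big] = 0
\;\;\Rightarrow\;\;
D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)} \tag{7}
```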
The above equation is important mathematically, but in practice the optimal D cannot be computed because Pdata(x) is unknown. Now, the loss for G under the optimal D can be obtained by substituting equation (7) into the loss function:
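Substituting D* from equation (7) into equation (6) gives the generator's loss at the optimal discriminator:

```latex
C(G) = \mathbb{E}_{x \sim p_{data}}\!\left[\log \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\right]
+ \mathbb{E}_{x \sim p_g}\!\left[\log \frac{p_g(x)}{p_{data}(x) + p_g(x)}\right] \tag{8}
```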
Now, the Kullback-Leibler (KL) and Jensen-Shannon (JS) divergences are given by,
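The two divergences are defined as:

```latex
KL(p \,\|\, q) = \mathbb{E}_{x \sim p}\!\left[\log \frac{p(x)}{q(x)}\right]
\qquad
JSD(p \,\|\, q) = \frac{1}{2} KL\!\left(p \,\Big\|\, \frac{p+q}{2}\right)
+ \frac{1}{2} KL\!\left(q \,\Big\|\, \frac{p+q}{2}\right)
```

Rewriting equation (8) in terms of the JS divergence then gives, as in Goodfellow et al.:

```latex
C(G) = -\log 4 + 2 \cdot JSD\big(p_{data} \,\|\, p_g\big) \tag{9}
```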
The above equation reduces to −2 log 2 as Pg approaches Pdata, because the JS divergence becomes zero.
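This can be checked numerically. In this sketch, `p_data` and `p_g` are hypothetical discrete 4-bin distributions standing in for the continuous densities:

```python
import numpy as np

def kl(p, q):
    """KL divergence for discrete distributions: sum p * log(p/q)."""
    return float(np.sum(p * np.log(p / q)))

def jsd(p, q):
    """Jensen-Shannon divergence: average KL to the mixture m = (p+q)/2."""
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

# Hypothetical discrete stand-ins for p_data and p_g
p_data = np.array([0.10, 0.40, 0.40, 0.10])
p_g    = np.array([0.25, 0.25, 0.25, 0.25])

c_g = -np.log(4) + 2 * jsd(p_data, p_g)            # equation (9)
c_g_optimal = -np.log(4) + 2 * jsd(p_data, p_data)  # Pg = Pdata, JSD = 0

print(round(float(c_g_optimal), 4))  # -1.3863, i.e. -2*log(2)
```

When the two distributions differ, C(G) sits strictly above −2 log 2, confirming that Pg = Pdata is the global optimum.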
6. Limitations
The loss function derived above (equation (9)) has some limitations, which are described in this section.
6.1 Vanishing Gradient
The aim of optimizing equation (9) is to move Pg towards Pdata (also written Pr, for the real distribution) under an optimal D. However, the JS divergence remains constant when there is no overlap between Pr and Pg (figure 3). It can be observed that the JS divergence is flat and its gradient is close to 0 once the distance between the distributions exceeds 5, meaning the training process provides almost no signal to G (figure 4). The gradient is non-zero only when Pg and Pr overlap significantly; in other words, when D is close to optimal, G faces a vanishing gradient problem.
This problem can be countered by modifying the original loss function of G as,
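The modification proposed in the original paper flips the generator objective from minimizing log(1 − D(G(z))) to maximizing log D(G(z)), i.e. L_G = −E_z[log D(G(z))], the so-called non-saturating loss. A small numpy sketch (with a hypothetical logit value `s` for a generated sample) shows why this helps early in training, when the discriminator confidently rejects fakes:

```python
import numpy as np

def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))

# s is the discriminator's logit for a generated sample. Early in training,
# D confidently rejects fakes, so D(G(z)) is near 0, i.e. s is very negative.
s = -8.0
d = sigmoid(s)  # ~0.000335

# Gradient w.r.t. s of the original generator loss log(1 - D(s)): equals -D
grad_saturating = -d
# Gradient w.r.t. s of the modified loss -log(D(s)): equals -(1 - D)
grad_nonsaturating = -(1.0 - d)

print(grad_saturating)      # ~ -0.000335: almost no learning signal
print(grad_nonsaturating)   # ~ -0.9997: strong signal even when D wins
```

With the original loss, the generator's gradient vanishes exactly when it needs to learn the most; the modified loss keeps the gradient large in that regime.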
6.2 Mode Collapse
During training, the generator may get stuck in a state where it always produces the same output. This is called mode collapse. It happens because the main aim of G is to fool D, not to generate diverse output. The math behind this is a bit involved and will be discussed in future articles.
Btw, this is my first story and I hope you enjoyed it.
References
[1] Atienza, Rowel. Advanced Deep Learning with Keras: Apply deep learning techniques, autoencoders, GANs, variational autoencoders, deep reinforcement learning, policy gradients, and more. Packt Publishing Ltd, 2018.
[2] Goodfellow, Ian, et al. "Generative adversarial nets." Advances in Neural Information Processing Systems. 2014.
[3] Wang, Zhengwei, Qi She, and Tomas E. Ward. "Generative Adversarial Networks: A Survey and Taxonomy." arXiv preprint arXiv:1906.01529 (2019).