Source: Deep Learning on Medium
Hmmm, how do I start writing this story?
Love you, LeCun.
So far I have been writing stories about a lot of different algorithms, and all of them are discriminative. This story is all about generative models, so let me quickly explain the differences.
Discriminative vs Generative
We all love (x, y) pairs: x being the inputs/features (images, text, speech, etc.) and y being the targets/labels.
(x, y) → (features, labels) / (inputs, targets)
Let’s think about classification in a supervised way.
Discriminative: given the inputs, we want to build a model that classifies them into the corresponding targets as accurately as possible.
E.g. → Given these features, is this mail SPAM or not?
It learns the conditional probability distribution:
$p(y \mid x)$, “the probability of y given x”, should be maximal for the correct label y.
So the model learns to predict the labels from the data; in other words, it learns the decision boundary between classes.
It does not really care about “how the training data is generated/distributed”.
E.g. logistic regression, SVMs.
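To make “learns only p(y|x)” concrete, here is a minimal pure-Python sketch of logistic regression on a single feature. The toy data, learning rate and epoch count are made-up assumptions for illustration. Notice that the fitted parameters describe only the decision boundary and say nothing about how x itself is distributed.

```python
import math

def sigmoid(t):
    # Numerically stable logistic function.
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)

def fit_logistic_regression(xs, ys, lr=0.1, epochs=500):
    """Model p(y=1 | x) = sigmoid(w*x + b) for a single feature x, fitted by
    gradient descent on the mean cross-entropy loss. Nothing about p(x) is learned."""
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(epochs):
        grad_w = grad_b = 0.0
        for x, y in zip(xs, ys):
            err = sigmoid(w * x + b) - y   # d(loss)/d(logit) for cross-entropy
            grad_w += err * x
            grad_b += err
        w -= lr * grad_w / n
        b -= lr * grad_b / n
    return w, b

# Hypothetical toy data: one standardized feature, label 1 = SPAM, 0 = not SPAM.
xs = [-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]
ys = [0, 0, 0, 1, 1, 1]
w, b = fit_logistic_regression(xs, ys)
boundary = -b / w   # the x where p(y=1 | x) = 0.5: the learned decision boundary
print(f"decision boundary at x = {boundary:.3f}")
print(f"p(SPAM | x=2) = {sigmoid(2 * w + b):.3f}")
```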
Generative: given the inputs, we want to build a model that understands how the inputs are produced, so that it can generate similar inputs along with their labels.
E.g. → Assuming this mail is SPAM, what are its features likely to be?
It learns the joint probability distribution:
$p(x, y) = p(y \mid x)\,p(x) = p(x \mid y)\,p(y)$
so the model has to learn the joint distribution p(x, y) itself, not just p(y|x).
It cares about “how the training data is generated/distributed”, i.e. how to get the data x in the first place.
E.g. Naive Bayes.
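For contrast, here is a minimal sketch of a Bernoulli Naive Bayes model for the SPAM example. It assumes three made-up binary features and add-one (Laplace) smoothing. Because it learns p(y) and p(x|y), and therefore the joint p(x, y), the same fitted model can do two things: classify a mail through Bayes’ rule, and generate a plausible feature vector for a given label.

```python
import random

def fit_naive_bayes(X, y, alpha=1.0):
    """Estimate p(y) and p(x_j = 1 | y) for binary features, with add-alpha smoothing.
    Together they define the joint p(x, y) = p(y) * prod_j p(x_j | y)."""
    prior, cond = {}, {}
    for label in set(y):
        rows = [x for x, t in zip(X, y) if t == label]
        prior[label] = len(rows) / len(X)
        cond[label] = [(sum(r[j] for r in rows) + alpha) / (len(rows) + 2 * alpha)
                       for j in range(len(X[0]))]
    return prior, cond

def joint(x, label, prior, cond):
    """p(x, y=label) under the naive (conditional independence) assumption."""
    p = prior[label]
    for xj, pj in zip(x, cond[label]):
        p *= pj if xj == 1 else 1.0 - pj
    return p

def posterior(x, prior, cond):
    """p(y | x) by Bayes' rule: p(x, y) divided by the sum of p(x, y') over all labels."""
    scores = {label: joint(x, label, prior, cond) for label in prior}
    total = sum(scores.values())
    return {label: s / total for label, s in scores.items()}

def sample_features(label, cond, rng):
    """Draw x ~ p(x | y=label): 'assuming this mail is SPAM, what are its features likely to be?'"""
    return [1 if rng.random() < pj else 0 for pj in cond[label]]

# Hypothetical toy data. Features: [contains "free", contains "winner", has attachment]
X = [[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [0, 0, 0], [0, 1, 0]]
y = ["SPAM", "SPAM", "SPAM", "HAM", "HAM", "HAM"]
prior, cond = fit_naive_bayes(X, y)

post = posterior([1, 1, 0], prior, cond)                     # classify: p(y | x)
fake_spam = sample_features("SPAM", cond, random.Random(0))  # generate: x ~ p(x | SPAM)
print(post)
print(fake_spam)
```

With this toy data, p(SPAM | x = [1, 1, 0]) works out to 0.096 / (0.096 + 0.024) = 0.8.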
Okay, I hope you get the idea, so let’s move on.
GANs are generative models that learn to generate samples from the input distribution as realistically as possible.
A GAN’s end goal is to predict features given a label, instead of predicting a label given features.
E.g. if x is a set of cat images, the GAN’s goal is to learn, from the training data x, a model that can produce realistic, believable cat images.
A generative adversarial network (GAN) consists of 2 neural networks:
- A neural network called the “Generator”, which generates new data points from random noise (e.g. drawn from a uniform distribution). Its goal is to produce fake results that look like the real inputs.
- Another neural network called the “Discriminator”, which tries to tell the fake data produced by the Generator apart from the real data.
The main idea of GANs is to train these 2 networks to compete with each other, each with its own objective function.
→ The generator G tries to fool the discriminator into believing that the input it sends is real.
→ The discriminator D gives the generator a slap by identifying that the input is fake.
→ After taking the slap from the discriminator D, the generator G learns to produce outputs that look more like the training data.
→ This process is repeated for a while, or until a Nash equilibrium is found.
This process is called Adversarial Training.
GAN training process, step by step:
- We sample some noise from a random distribution and feed it to the Generator G to produce a fake x with label y = 0, giving an (x, y) input-label pair.
- We take this fake pair and a real pair (x with label y = 1) and feed them to the Discriminator D in turn.
- The discriminator D is a binary classification network, so it calculates the loss for both the fake x and the real x and combines them into the final D loss.
- The generator G calculates its own loss, the G loss, from its noise (passed through G and then D), since each network has a different objective function.
- The two losses go back to their respective networks, which learn from them (adjusting their parameters with respect to the loss).
- Apply any optimization algorithm (gradient descent, Adam, RMSProp, etc.) and repeat this process for a certain number of epochs, or as long as you wish.
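The steps above can be sketched end to end on a hypothetical 1-D toy problem instead of MNIST. The setup is:

- real data drawn from N(3, 0.5);
- a linear “generator” G(z) = a·z + b fed with uniform noise;
- a logistic-regression “discriminator” D(x) = sigmoid(w·x + c);
- hand-derived gradients and plain gradient descent as the optimizer.

These models are far too small to be real GANs. The point is only the bookkeeping: D sees both batches, each loss updates only its own network’s parameters, and the process repeats.

```python
import math
import random

def sigmoid(t):
    # Numerically stable logistic function.
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)

def train_toy_gan(steps=500, batch=64, lr=0.05, seed=0):
    """Adversarial training on hypothetical 1-D data.
    Real data ~ N(3, 0.5); generator G(z) = a*z + b with z ~ U(-1, 1);
    discriminator D(x) = sigmoid(w*x + c). Gradients are derived by hand."""
    rng = random.Random(seed)
    a, b = 1.0, 0.0   # generator parameters
    w, c = 0.0, 0.0   # discriminator parameters
    eps = 1e-12       # keeps log() away from log(0)
    history = []
    for _ in range(steps):
        # Noise -> generator -> fake x (label 0); also draw real x (label 1).
        z = [rng.uniform(-1.0, 1.0) for _ in range(batch)]
        fake = [a * zi + b for zi in z]
        real = [rng.gauss(3.0, 0.5) for _ in range(batch)]

        # D runs on both batches; D loss = -mean log D(real) - mean log(1 - D(fake)).
        d_real = [sigmoid(w * x + c) for x in real]
        d_fake = [sigmoid(w * x + c) for x in fake]
        d_loss = (-sum(math.log(p + eps) for p in d_real)
                  - sum(math.log(1.0 - p + eps) for p in d_fake)) / batch
        # d(loss)/d(logit) is (p - 1) for real samples and p for fake samples.
        gw = (sum((p - 1.0) * x for p, x in zip(d_real, real))
              + sum(p * x for p, x in zip(d_fake, fake))) / batch
        gc = (sum(p - 1.0 for p in d_real) + sum(d_fake)) / batch
        w -= lr * gw      # the D loss updates only D's parameters
        c -= lr * gc

        # G loss = -mean log D(G(z)), i.e. fake samples labelled 1, using the updated D.
        d_fake = [sigmoid(w * x + c) for x in fake]
        g_loss = -sum(math.log(p + eps) for p in d_fake) / batch
        # Chain rule: d(loss)/dx = (p - 1) * w, with dx/da = z and dx/db = 1.
        ga = sum((p - 1.0) * w * zi for p, zi in zip(d_fake, z)) / batch
        gb = sum((p - 1.0) * w for p in d_fake) / batch
        a -= lr * ga      # the G loss updates only G's parameters
        b -= lr * gb
        history.append((d_loss, g_loss, b))
    return (a, b, w, c), history

params, history = train_toy_gan()
print("generator mean b after training:", round(params[1], 3))
```

Because this D is linear, it can only push the generator’s mean b around. So you may see b overshoot and oscillate around the real mean rather than settle.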
Each network has its own goal, so the two networks are pitted against each other during training.
The generator G gets stronger and stronger at generating realistic results, and the discriminator D also gets stronger and stronger at identifying which inputs are real and which are fake.
Well, this idea comes from game theory, and it was introduced by Ian Goodfellow and his folks.
Okay, I hope you now have at least a conceptual idea, so let’s get into the math.
GAN’s objective function
As we know, the discriminator is a binary classifier, so when we feed it data it should produce a high probability for the real data and a low probability for the fake data (the generator’s output).
So let me define the variables and functions: x is a real data sample, z is a noise vector, G(z) is the generator’s output for z, and D(·) is the discriminator’s estimated probability that its input is real.
D(x) and D(G(z)) both give a score between 0 and 1.
We want to build a model (the discriminator) that maximizes the score for real data, D(x), while minimizing the score for fake data, D(G(z)).
G(z) has the same shape as the real input (e.g. if the real input is a 10x10 image, G(z) is also 10x10, though of course it starts out noisy).
We also want to build a model (the generator) that maximizes the discriminator’s score for the fake data, D(G(z)).
Here are the equations to paint a picture.
The discriminator network runs twice (once on real data, once on fake data) before it calculates its final loss, while the generator runs only once.
Once we have these two losses, we calculate the gradients with respect to each network’s parameters and backpropagate through the two networks independently.
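Written out for a batch of m real samples $x^{(i)}$ and m noise vectors $z^{(i)}$, one standard way to express these two losses is the binary cross-entropy form below. This is a sketch of a common formulation. $\mathcal{L}_G$ here is the “non-saturating” generator loss, which is what giving the fake logits 1s as labels computes. The original paper’s minimax form instead has the generator minimize $\log(1 - D(G(z)))$.

$$\mathcal{L}_D = -\frac{1}{m}\sum_{i=1}^{m}\Big[\log D\big(x^{(i)}\big) + \log\Big(1 - D\big(G(z^{(i)})\big)\Big)\Big]$$

$$\mathcal{L}_G = -\frac{1}{m}\sum_{i=1}^{m}\log D\big(G(z^{(i)})\big)$$

$\mathcal{L}_D$ needs two forward passes of D (real and fake), while $\mathcal{L}_G$ needs one pass of G followed by D.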
From Ian Goodfellow’s paper, here is the final equation in terms of expectations.
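D and G play the following two-player minimax game with value function V(G, D):
$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]$$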
Note: I strongly recommend reading the paper right after this.
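The value function V(D, G) = E_x[log D(x)] + E_z[log(1 − D(G(z)))] is just two averages, so it can be estimated from a batch of discriminator outputs. Below is a small sketch, assuming we already have D’s probabilities on real samples and on generated samples (the batches are made up). It also checks a result from the paper: when the generator matches the data distribution, the optimal discriminator outputs 1/2 everywhere and V(D, G) = −log 4.

```python
import math

def value_function(d_real, d_fake):
    """Monte Carlo estimate of V(D, G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))].
    d_real: D's outputs on a batch of real samples x.
    d_fake: D's outputs on a batch of generated samples G(z).
    All outputs must lie strictly between 0 and 1."""
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

# A discriminator that cannot tell real from fake outputs 1/2 everywhere.
v_confused = value_function([0.5] * 8, [0.5] * 8)
# A confident, correct discriminator pushes V up towards its maximum of 0.
v_confident = value_function([0.99] * 8, [0.01] * 8)
print(v_confused, -math.log(4))
print(v_confident)
```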
Well, if there is still any confusion, let’s start coding to understand this better.
So let’s build a simple one-hidden-layer neural network to understand the whole idea.
I will explain the code step by step (first the explanation, then the code snippet).
1. Define placeholders for the real data X and the fake data (noise) Z (here I used the MNIST dataset, so each real input is 784 pixel values, thus X has shape = [None, 784]).
2. Build the one-hidden-layer network for both the generator G and the discriminator D.
3. Pass some random noise to the generator G to produce fake data, and pass that fake data to the discriminator D; separately, pass the real data to D as well.
4. Calculate the losses: D on real data, D on fake data, and G on fake data.
Note: observe the labels we are giving above.
→ For the real data X we give 1s as labels, for the fake data (generated from Z) we give 0s, and we apply the cross-entropy loss to both sets of logits to calculate the final D loss.
→ For the generator G we take the same fake logits, but here we give 1s as labels, the complete opposite of the labels in the D_loss_fake variable.
The snippet below does the same thing as the snippet above.
We can use either the cross-entropy code cell above or this one; it does not really make a difference to the networks.
5. Define the optimizers and the respective parameters each network backpropagates into.
6. Before I start training, I just check what outputs the generator gives.
7. This is the training snippet, with a batch size of 128.
Here is the training video.
8. After 7 minutes of training on my laptop (16 GB, i5 CPU, Ubuntu, HP Pavilion), here are the results the generator has learned.
I have only included the main code snippets here; you can find the full notebook on my GitHub here.
Of course, the results are not very accurate, since we have small networks and few iterations, but this gave me a full intuition for how to train GANs.
Note: I usually don’t focus much on code, which is why I explained it only lightly, but I strongly recommend you play with it (try different things to get a full intuition).
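To make the label bookkeeping in step 4 concrete without TensorFlow, here is a small pure-Python sketch with made-up logits for a tiny batch.

- sigmoid_cross_entropy mirrors the numerically stable formula TensorFlow documents for tf.nn.sigmoid_cross_entropy_with_logits: max(x, 0) − x·z + log(1 + exp(−|x|)).
- The variable names echo D_loss_fake from the text but are otherwise assumptions.
- It computes the losses two ways: with explicit labels, and directly with logs of D’s probabilities (assuming that is what the alternative cell does). The two agree.

```python
import math

def sigmoid(t):
    # Numerically stable logistic function.
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)

def sigmoid_cross_entropy(logit, label):
    # Stable form of -[label*log(sigmoid(logit)) + (1 - label)*log(1 - sigmoid(logit))].
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def mean(values):
    return sum(values) / len(values)

# Hypothetical discriminator logits for a tiny batch of real images and of G(Z) outputs.
D_logit_real = [2.0, 1.5, 3.0]
D_logit_fake = [-1.0, 0.5, -2.0]

# Version 1: cross-entropy with explicit labels (1s for real, 0s for fake, 1s for G).
D_loss_real = mean([sigmoid_cross_entropy(l, 1.0) for l in D_logit_real])
D_loss_fake = mean([sigmoid_cross_entropy(l, 0.0) for l in D_logit_fake])
D_loss = D_loss_real + D_loss_fake
G_loss = mean([sigmoid_cross_entropy(l, 1.0) for l in D_logit_fake])

# Version 2: the same losses written directly with logs of D's probabilities.
D_real = [sigmoid(l) for l in D_logit_real]
D_fake = [sigmoid(l) for l in D_logit_fake]
D_loss_v2 = -mean([math.log(p) for p in D_real]) - mean([math.log(1.0 - p) for p in D_fake])
G_loss_v2 = -mean([math.log(p) for p in D_fake])

print(D_loss, D_loss_v2)
print(G_loss, G_loss_v2)
```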
Okay, so far we understand the idea; now let’s talk about some of the problems you may run into.
Ever since GANs were invented, training them successfully has been hard, and many researchers have explored ways to improve GAN training.
In fact, “improving GAN training is a very hot research topic”.
I am going to mention only the most common ones:
1. Mode collapse
This is a state where the generator G gets stuck producing a limited variety of samples, or even the same sample repeatedly, during or after training.
It happens when, during training, the discriminator does not force enough diversity on the generator G, so the generator fails to learn a representation of the complex real-world data.
Remember, the generator’s goal is to trick the discriminator D into thinking that G’s outputs are real.
If G produces one sample that looks realistic compared to the original data, the discriminator finds it hard to tell the two apart.
At that point G keeps producing the same image over and over. This leads to “complete mode collapse”.
If the generator G learns only a few properties and produces only a few varieties of samples, that is “partial mode collapse”.
Mode collapse happens quite often, and there are some ways to prevent it, which I will discuss shortly.
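In a toy setting where the real data’s modes are known in advance, you can measure the difference between a healthy generator, partial mode collapse and complete mode collapse directly. Just count how many modes the generated samples land near. The helper modes_covered and the simulated “generator outputs” below are hypothetical, for illustration only; real image datasets come with no such list of modes.

```python
import random

def modes_covered(samples, centers, radius):
    """Count how many known modes (centers) have at least one sample within `radius`.
    A hypothetical diagnostic for toy 1-D data whose modes are known in advance."""
    return sum(1 for c in centers if any(abs(s - c) <= radius for s in samples))

rng = random.Random(0)
centers = [-4.0, -2.0, 0.0, 2.0, 4.0]   # real data: a mixture of 5 well-separated bumps

# Simulated generator outputs for three situations.
healthy = [rng.choice(centers) + rng.gauss(0, 0.1) for _ in range(500)]       # all modes
partial = [rng.choice(centers[:2]) + rng.gauss(0, 0.1) for _ in range(500)]   # 2 modes only
complete = [0.0 + rng.gauss(0, 0.1) for _ in range(500)]                      # one "image"

counts = {name: modes_covered(s, centers, radius=0.5)
          for name, s in [("healthy", healthy), ("partial", partial), ("complete", complete)]}
print(counts)
```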
2. Vanishing gradient
This is a common problem in deep neural networks in general, and it gets worse here because the gradient at the discriminator flows back not only into the discriminator network but also into the generator network as feedback.
Because of this, GAN training is unstable.
→ If the discriminator D gets strong too quickly (say D(x) = 1, D(G(z)) = 0), then at the generator G the term $\log(1 - D(G(z))) = \log(1 - 0) = 0$ is flat, because D’s sigmoid output is saturated, so the gradient of the generator’s loss is (close to) 0 and learning stops.
→ If the discriminator D is too weak, the generator G gets no useful feedback, so its loss does not mean much.
Moral: don’t train D too well or too poorly.
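You can see the saturation directly by differentiating with respect to the discriminator’s logit l on a fake sample, where D(G(z)) = sigmoid(l):

- The original minimax term log(1 − D(G(z))) has gradient −sigmoid(l). This vanishes as D becomes sure the sample is fake.
- The “non-saturating” alternative −log D(G(z)) is what giving the fake logits 1s as labels computes in the coding section. Its gradient is sigmoid(l) − 1, which stays close to −1.

A minimal sketch:

```python
import math

def sigmoid(t):
    # Numerically stable logistic function.
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)

def generator_grads(logit):
    """Gradients w.r.t. the discriminator's logit l on a fake sample, D(G(z)) = sigmoid(l).
    saturating:     d/dl  log(1 - sigmoid(l)) = -sigmoid(l)
    non-saturating: d/dl -log(sigmoid(l))      = sigmoid(l) - 1"""
    p = sigmoid(logit)
    return -p, p - 1.0

for l in [0.0, -2.0, -5.0, -10.0]:   # more negative logit = D is more sure the sample is fake
    sat, non_sat = generator_grads(l)
    print(f"D(G(z)) = {sigmoid(l):.5f}  saturating grad = {sat:.5f}  non-saturating grad = {non_sat:.5f}")
```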
3. Hard to find a “Nash equilibrium”
This is the optimal point in the game for both the generator G and the discriminator D.
It’s really hard to find because this is a non-cooperative game in which the two players push each other as hard as possible.
A Nash equilibrium is reached when neither player can do better by changing its own strategy while the other player’s strategy stays fixed.
Here the kid is neither losing nor winning.
In the end, this is what we want (the discriminator D is the kid, the generator G is the sumo wrestler; I don’t know if the analogy makes sense to you).
But this is really hard to achieve.
Here are some points from the paper “Improved Techniques for Training GANs”:
→Training GANs consists in finding a Nash equilibrium to a two-player non-cooperative game.
→Each player wishes to minimize its own cost function, $J^{(D)}(\theta^{(D)}, \theta^{(G)})$ for the discriminator and $J^{(G)}(\theta^{(D)}, \theta^{(G)})$ for the generator.
→A Nash equilibrium is a point $(\theta^{(D)}, \theta^{(G)})$ such that $J^{(D)}$ is at a minimum with respect to $\theta^{(D)}$ and $J^{(G)}$ is at a minimum with respect to $\theta^{(G)}$.
→A modification to $\theta^{(D)}$ that reduces $J^{(D)}$ can increase $J^{(G)}$, and a modification to $\theta^{(G)}$ that reduces $J^{(G)}$ can increase $J^{(D)}$.
Gradient descent thus fails to converge for many games.
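For example, when one player minimizes xy with respect to x and another player minimizes −xy with respect to y,
$f_1(x, y) = xy$ (minimized over $x$), $\quad f_2(x, y) = -xy$ (minimized over $y$)
$\partial f_1 / \partial x = y$ and $\partial f_2 / \partial y = -x$
$x \leftarrow x - \alpha y$ and $y \leftarrow y + \alpha x$ ($\alpha$ is the learning rate)
gradient descent enters a stable orbit, rather than converging to x = y = 0, the desired equilibrium point.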
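The xy game is easy to simulate. Here is a minimal sketch with simultaneous updates, exactly the two update rules above; the starting point, learning rate and step count are arbitrary choices. The distance from the equilibrium (0, 0) never shrinks. With these discrete simultaneous steps it actually grows by a factor of $\sqrt{1 + \alpha^2}$ per step, so the iterates spiral slowly outward. The stable orbit the paper mentions is the small-step limit. Either way, gradient descent does not converge to x = y = 0.

```python
import math

def simulate_xy_game(x=1.0, y=1.0, lr=0.1, steps=100):
    """Simultaneous gradient descent on the two-player game where
    player 1 minimizes f1(x, y) = x*y over x and player 2 minimizes f2(x, y) = -x*y over y.
    Returns the distance from the equilibrium (0, 0) after every step."""
    radii = [math.hypot(x, y)]
    for _ in range(steps):
        grad_x = y          # d f1 / dx
        grad_y = -x         # d f2 / dy
        x, y = x - lr * grad_x, y - lr * grad_y   # both players step at the same time
        radii.append(math.hypot(x, y))
    return radii

radii = simulate_xy_game()
print(f"start: {radii[0]:.3f}  after 100 steps: {radii[-1]:.3f}  closest: {min(radii):.3f}")
```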
4. No proper evaluation metric
As we saw in the code above, we don’t really know when to stop training, because there is no proper evaluation metric for GANs.
Visual inspection is required, and a lot of people rely on it when training GANs.
Unlike in other deep learning algorithms, the losses don’t tell us much in GANs.
Because of this, we often end up without a good GAN model.
Alright, I am gonna stop right here: so far we have discussed only a little, and there is a lot more to cover in GAN research.
Here are the things I will cover in depth in the next stories:
1. Different types of GANs and their concepts, with the math in depth
(InfoGAN, Auxiliary Classifier GAN, DCGAN, CycleGAN, CGAN, SRGAN, Wasserstein GAN, etc.)
2. One story about improving GANs (advanced)
3. Image-to-image and video-to-video translation (supervised and unsupervised)
4. Research ideas and work on GANs
Remember, GANs are a hot research topic and there is so much to explore, so as explorers, let’s explore AI with GANs.
So see you in the next story! Have a great day or night!
More Resources and References
Ian Goodfellow’s talk at Stanford about adversarial examples
O’Reilly’s tutorial: https://www.oreilly.com/learning/generative-adversarial-networks-for-beginners