Original article was published on Artificial Intelligence on Medium
GAN (Generative Adversarial Network)
The idea behind GANs is that you have two networks, a generator G and a discriminator D, competing against each other. The generator makes “fake” data to pass to the discriminator. The discriminator also sees real training data and predicts if the data it’s received is real or fake.
The generator is trained to fool the discriminator: it wants to output data that looks as close as possible to the real training data.
The discriminator is a classifier that is trained to figure out which data is real and which is fake.
What ends up happening is that the generator learns to make data that is indistinguishable from real data to the discriminator.
The general structure of a GAN is shown in the diagram above, using MNIST images as data. The latent sample is a random vector that the generator uses to construct its fake images. This is often called a latent vector and that vector space is called latent space. As the generator trains, it figures out how to map latent vectors to recognizable images that can fool the discriminator.
If you’re interested in generating only new images, you can throw out the discriminator after training. In this notebook, I’ll show you how to define and train these adversarial networks in PyTorch and generate new images!
First, we'll load the MNIST training data and visualize a batch of it:
Define the Model
A GAN is comprised of two adversarial networks, a discriminator and a generator.
The discriminator network is going to be a pretty typical linear classifier. To make this network a universal function approximator, we’ll need at least one hidden layer, and these hidden layers should have one key attribute:
All hidden layers will have a leaky ReLU activation function applied to their outputs.
We should use a leaky ReLU to allow gradients to flow backward through the layer unimpeded. A leaky ReLU is like a normal ReLU, except that there is a small non-zero output for negative input values.
We’ll also take the approach of using a more numerically stable loss function on the outputs. Recall that we want the discriminator to output a value 0–1 indicating whether an image is real or fake.
We will ultimately use BCEWithLogitsLoss, which combines a sigmoid activation function and binary cross-entropy loss in one function.
So, our final output layer should not have any activation function applied to it.
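Putting those pieces together, a discriminator might look like the sketch below. The layer sizes and the leaky-ReLU slope of 0.2 are illustrative choices on my part, not values mandated by the article; what matters is the leaky-ReLU hidden layers and the raw-logit output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Discriminator(nn.Module):
    def __init__(self, input_size=784, hidden_dim=256, output_size=1):
        super().__init__()
        # fully connected layers; sizes are illustrative
        self.fc1 = nn.Linear(input_size, hidden_dim * 2)
        self.fc2 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_size)

    def forward(self, x):
        x = x.view(-1, 28 * 28)             # flatten the image
        x = F.leaky_relu(self.fc1(x), 0.2)  # leaky ReLU on hidden layers
        x = F.leaky_relu(self.fc2(x), 0.2)
        return self.fc3(x)                  # raw logit: no final activation
```

Note the last layer returns a raw score, as required by BCEWithLogitsLoss.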
The generator network will be almost exactly the same as the discriminator network, except that we’re applying a tanh activation function to our output layer.
Generators have been found to perform best with a tanh output, which scales the output to be between -1 and 1 instead of 0 and 1.
Recall that we also want these outputs to be comparable to the real input pixel values, which are read in as normalized values between 0 and 1.
So, we’ll also have to scale our real input images to have pixel values between -1 and 1 when we train the discriminator.
I’ll do this in the training loop, later on.
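Here's a generator sketch mirroring the discriminator above; the latent-vector size of 100 and the layer sizes are my assumptions, chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    def __init__(self, z_size=100, hidden_dim=256, output_size=784):
        super().__init__()
        # mirror image of the discriminator: small latent vector in,
        # flattened 28x28 image out
        self.fc1 = nn.Linear(z_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim * 2)
        self.fc3 = nn.Linear(hidden_dim * 2, output_size)

    def forward(self, z):
        z = F.leaky_relu(self.fc1(z), 0.2)
        z = F.leaky_relu(self.fc2(z), 0.2)
        return torch.tanh(self.fc3(z))   # outputs scaled to (-1, 1)
```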
Discriminator and Generator Losses
Now we need to calculate the losses.
For the discriminator, the total loss is the sum of the losses for real and fake images,
d_loss = d_real_loss + d_fake_loss.
Remember that we want the discriminator to output 1 for real images and 0 for fake images, so we need to set up the losses to reflect that.
The losses will be binary cross-entropy loss with logits, which we can get with BCEWithLogitsLoss. This combines a sigmoid activation function and binary cross-entropy loss in one function.
For the real images, we want
D(real_images) = 1. That is, we want the discriminator to classify the real images with a label = 1, indicating that these are real. To help the discriminator generalize better, the labels are reduced a bit from 1.0 to 0.9. For this, we’ll use the parameter
smooth; if True, then we should smooth our labels. In PyTorch, this looks like
labels = torch.ones(size) * 0.9
The discriminator loss for the fake data is similar. We want
D(fake_images) = 0, where the fake images are the generator output,
fake_images = G(z).
The generator loss will look similar only with flipped labels. The generator’s goal is to get
D(fake_images) = 1. In this case, the labels are flipped to represent that the generator is trying to fool the discriminator into thinking that the images it generates (fakes) are real!
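The two loss helpers described above might be sketched like this; the names `real_loss` and `fake_loss` follow the article, while the bodies are my reconstruction of the behavior it describes (label smoothing via `smooth`, and BCEWithLogitsLoss on raw discriminator logits).

```python
import torch
import torch.nn as nn

# sigmoid + binary cross-entropy in one numerically stable op
criterion = nn.BCEWithLogitsLoss()

def real_loss(D_out, smooth=False):
    batch_size = D_out.size(0)
    # label smoothing: use 0.9 instead of 1.0 to help generalization
    labels = torch.ones(batch_size) * (0.9 if smooth else 1.0)
    return criterion(D_out.squeeze(), labels)

def fake_loss(D_out):
    batch_size = D_out.size(0)
    labels = torch.zeros(batch_size)  # fake images get label 0
    return criterion(D_out.squeeze(), labels)
```

Note that the generator reuses `real_loss` on fake images, which is exactly the label flip described above.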
Training will involve alternating between training the discriminator and the generator. We’ll use our real_loss and fake_loss functions to help us calculate the losses in all of the following cases.
- Compute the discriminator loss on real, training images
- Generate fake images
- Compute the discriminator loss on fake, generated images
- Add up the real and fake loss
- Perform backpropagation + an optimization step to update the discriminator’s weights
- Generate fake images
- Compute the generator loss on the discriminator’s output for fake images, using flipped labels!
- Perform backpropagation + an optimization step to update the generator’s weights
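The steps above can be condensed into a self-contained sketch like the following. The tiny `nn.Sequential` models, the Adam learning rates, and the random stand-in for real MNIST batches are all illustrative assumptions; in practice you'd loop over a real DataLoader for many epochs.

```python
import torch
import torch.nn as nn
import torch.optim as optim

D = nn.Sequential(nn.Linear(784, 256), nn.LeakyReLU(0.2),
                  nn.Linear(256, 1))                # logits out
G = nn.Sequential(nn.Linear(100, 256), nn.LeakyReLU(0.2),
                  nn.Linear(256, 784), nn.Tanh())   # outputs in (-1, 1)

criterion = nn.BCEWithLogitsLoss()
def real_loss(out, smooth=False):
    labels = torch.ones(out.size(0)) * (0.9 if smooth else 1.0)
    return criterion(out.squeeze(), labels)
def fake_loss(out):
    return criterion(out.squeeze(), torch.zeros(out.size(0)))

d_opt = optim.Adam(D.parameters(), lr=0.002)
g_opt = optim.Adam(G.parameters(), lr=0.002)

for step in range(3):                        # a few demo steps only
    real_images = torch.rand(64, 784) * 2 - 1  # rescale [0,1] -> [-1,1]

    # Steps 1-5: discriminator loss on real + fake, then update D
    d_opt.zero_grad()
    d_loss = real_loss(D(real_images), smooth=True)
    fake_images = G(torch.randn(64, 100))
    d_loss = d_loss + fake_loss(D(fake_images.detach()))
    d_loss.backward()
    d_opt.step()

    # Steps 6-8: generator step with flipped labels (wants D(fakes) = 1)
    g_opt.zero_grad()
    g_loss = real_loss(D(G(torch.randn(64, 100))))
    g_loss.backward()
    g_opt.step()
```

Detaching the fake images during the discriminator step keeps the generator's weights from being updated by the discriminator's loss.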
As we train, we’ll also print out some loss statistics and save some generated “fake” samples.
Generator samples from training
Here we can view samples of images from the generator. First we’ll look at the images we saved during training.
These are samples from the final training epoch. You can see the generator is able to reproduce numbers like 1, 7, 3, 2. Since this is just a sample, it isn’t representative of the full range of images this generator can make.
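Viewing samples yourself might look like the sketch below; the untrained `nn.Sequential` generator here is a stand-in for a trained one, and the 4x4 grid size is arbitrary.

```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# stand-in generator; after training, use your trained G instead
G = nn.Sequential(nn.Linear(100, 256), nn.LeakyReLU(0.2),
                  nn.Linear(256, 784), nn.Tanh())

G.eval()
with torch.no_grad():
    samples = G(torch.randn(16, 100))  # 16 latent vectors -> 16 images
samples = (samples + 1) / 2            # rescale (-1, 1) back to (0, 1)

fig, axes = plt.subplots(4, 4, figsize=(4, 4))
for img, ax in zip(samples, axes.flatten()):
    ax.imshow(img.view(28, 28).numpy(), cmap='gray')
    ax.axis('off')
plt.show()
```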
I hope this article helps you get your hands dirty with GANs.