Source: Deep Learning on Medium
I’ve used PyTorch for implementing them. Yea, I could already hear “boo” from Keras/Tensorflow people. Keep calm and take the opportunity to get comfortable with it, just as I did.
Heard of the MNIST dataset? Yea, we will be using that. The images are grayscale with a shape of 28×28 pixels. Let’s load the dataset from the torchvision.datasets module. The images are flattened and the values are normalized to 0–1 using transforms (ToTensor does the scaling, torch.flatten does the flattening). The DataLoader class helps to slice the training data into batches.
import torch
from torchvision import transforms
from torchvision.datasets import MNIST

# ToTensor scales pixels to [0, 1]; torch.flatten turns each 1x28x28 image into 784 values
trans = transforms.Compose([transforms.ToTensor(), torch.flatten])
traindata = MNIST(root='./data', transform=trans, train=True, download=True)
trainloader = torch.utils.data.DataLoader(traindata, batch_size=6000, shuffle=True)
We will be using a fully connected network with the following configuration. Don’t get perplexed by the layers and neurons; you can remove or add layers as you wish. This configuration follows some best practices, which I will explain later. The input to the Generator is a noise vector of arbitrary size; here, I have chosen 128 values. The output of the Generator has to match the shape of a training sample, which is 784 values. The Discriminator takes those 784 values as input and outputs a single value indicating whether the sample is real (1) or fake (0).
Translating the whole architecture into a network is fairly straightforward. Subclass nn.Module and define the forward function. The best thing about PyTorch is autograd: we don’t have to do the math behind backprop ourselves, because the gradients for all the weights are calculated automatically.
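The class definitions themselves are not reproduced in the text, so here is only a minimal sketch of what such a pair might look like. The 128-value noise input, the 784-value image and the single Discriminator output come from the description above; the hidden layer widths (256 and 512), the LeakyReLU slope and the sigmoid on the Generator output are my own assumptions for illustration.

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    """Maps a 128-value noise vector to a flattened 28x28 image (784 values)."""

    def __init__(self, noise_dim=128, image_dim=784):
        super().__init__()
        # Hidden widths 256 and 512 are illustrative assumptions.
        self.net = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, image_dim),
            nn.Sigmoid(),  # outputs in (0, 1), like ToTensor-scaled pixels
        )

    def forward(self, z):
        return self.net(z)


class Discriminator(nn.Module):
    """Maps a flattened image to one probability: real (1) or fake (0)."""

    def __init__(self, image_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(image_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # BCELoss expects probabilities
        )

    def forward(self, x):
        return self.net(x)


# Quick shape check on a batch of 4 noise vectors.
fake_images = Generator()(torch.randn(4, 128))
scores = Discriminator()(fake_images)
print(fake_images.shape, scores.shape)
```

Note that only forward is written by hand; no backward function is needed because autograd derives it from the operations used in forward.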
Cool. Once both classes are defined, we can just instantiate them. And don’t you forget to define the cost function and the optimization algorithm to be used.
import torch.nn as nn
import torch.optim as optim

discriminator = Discriminator()
generator = Generator()
criterion = nn.BCELoss()
discrim_optim = optim.Adam(discriminator.parameters(), lr=0.0002)
generat_optim = optim.Adam(generator.parameters(), lr=0.0002)
Why this BCELoss and Adam optimization?
Vanilla GAN is trained as a minimax game. The minimax loss is the log loss of the Discriminator’s predicted probabilities on real and generated samples. BCELoss is binary cross-entropy loss, which computes exactly this log loss of probabilities. One can use two different loss functions for training the Generator and the Discriminator, but here we use a single loss function for both. And for weight updates we use the Adam optimizer, because everybody uses it. Haha, jk. You could try other optimizers like SGD or Adagrad.
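To see why one BCELoss can serve both networks, here is a small check. The probabilities are made-up numbers, not outputs of a trained model. It shows that BCELoss with a target of 1 is the mean of -log(p), and with a target of 0 it is the mean of -log(1 - p).

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Made-up discriminator outputs for three samples.
p = torch.tensor([[0.9], [0.6], [0.2]])

loss_real = criterion(p, torch.ones(3, 1))   # targets = 1 (real)
loss_fake = criterion(p, torch.zeros(3, 1))  # targets = 0 (fake)

# The same quantities written out by hand.
manual_real = -torch.log(p).mean()
manual_fake = -torch.log(1 - p).mean()

print(loss_real.item(), manual_real.item())
print(loss_fake.item(), manual_fake.item())
```

The Discriminator uses both targets, one for each kind of sample. The Generator step in the training loop reuses the same criterion with a target of 1 on its own fake samples, so it is rewarded when the Discriminator is fooled.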
So, the model design is complete. Yay. Now, let’s train the model.
If any confusion is going to occur, it occurs here. Let’s go step by step. We have to train the discriminator and the generator together, alternating between the two on every batch. That fundamentally means the following steps.
- Forward pass the discriminator
- Backprop the discriminator error
- Update discriminator weights
- Forward pass the generator
- Backprop the generator error
- Update the generator weights
# Noise input for generator
def noise(x, y):
    return torch.randn(x, y)

for epoch in range(2000):
    for pos_samples in trainloader:
        pos_samples = pos_samples[0]   # keep the images, ignore the labels
        batches = pos_samples.size(0)  # number of samples in this batch
        # Training Discriminator network
        discrim_optim.zero_grad()
        pos_predicted = discriminator(pos_samples)
        pos_error = criterion(pos_predicted, torch.ones(batches, 1))
        neg_samples = generator(noise(batches, 128))
        neg_predicted = discriminator(neg_samples)
        neg_error = criterion(neg_predicted, torch.zeros(batches, 1))
        discriminator_error = pos_error + neg_error
        discriminator_error.backward()
        discrim_optim.step()
        # Training generator network
        generat_optim.zero_grad()
        gen_samples = generator(noise(batches, 128))
        gen_predicted = discriminator(gen_samples)
        generator_error = criterion(gen_predicted, torch.ones(batches, 1))
        generator_error.backward()
        generat_optim.step()
That’s it. Done!
Woah Woah. Slow down. I have lots of questions!
Why do you index only the data, pos_samples? What happened to the labels of the training data?
– Nicely spotted. In Vanilla GAN, we don’t care about the labels. We basically give all the training data to the networks irrespective of the class it is from. So, the Generator network has to fit its weights to reproduce all the variations for different noise inputs. That being said, there are several GAN variants that do take the labels into consideration, like the Auxiliary Classifier GAN (AC-GAN).
What is this zero_grad() for the optimizer?
– Before each backward pass, we want the gradients to be zero so that the newly computed gradients are not mixed with residual gradients from the previous step. Without zero_grad(), the gradients accumulate across backward passes, which is useful in networks like RNNs.
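A tiny experiment makes the accumulation visible. This is a toy sketch with a single made-up weight w and an SGD optimizer chosen only for illustration; it is not part of the GAN code.

```python
import torch

w = torch.tensor([1.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

# Two backward passes without zero_grad(): the gradients add up.
(3 * w).sum().backward()
(3 * w).sum().backward()
accumulated = w.grad.clone()  # 3 + 3 = 6

# Reset, then one backward pass: only the fresh gradient remains.
optimizer.zero_grad()
(3 * w).sum().backward()
fresh = w.grad.clone()        # 3

print(accumulated, fresh)
```

In the training loop the same reset happens once per optimizer, right before that network's error is backpropagated.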
How can you just add both errors and execute the backward()?
– It’s so pythonic, isn’t it? That’s the autograd module in PyTorch. The sum of the two errors is just another node in the computation graph, so a single backward() call propagates both of them.
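Here is a toy check of that claim, with a single made-up weight w and two made-up losses: calling backward() on the sum leaves the sum of the individual gradients in w.grad.

```python
import torch

w = torch.tensor([2.0], requires_grad=True)

loss_a = (w ** 2).sum()  # gradient with respect to w: 2 * w = 4
loss_b = (5 * w).sum()   # gradient with respect to w: 5

# One backward() call on the sum handles both losses.
(loss_a + loss_b).backward()
grad_sum = w.grad.clone()  # 4 + 5 = 9

print(grad_sum)
```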
All right, so now how do you know when the network is trained? Which cost function should I observe?
– As previously stated, convergence is an interesting problem in GANs. Strictly speaking, a GAN is said to be trained when the Discriminator and the Generator reach a Nash equilibrium. A GAN is a minimax problem: while one network tries to maximize the cost function, the other tries to minimize it, and we are training both to improve. A Nash equilibrium is a state in which neither agent can do better by changing its own course of action while the other agent keeps its own. During training each network learns from the other, but once neither the Discriminator nor the Generator can improve against the other, they have reached the Nash equilibrium. In practical terms, given an equal set of real and fake images, the Discriminator can no longer tell them apart and gives every image the same verdict, so its prediction accuracy will be 50%.
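As a rough reference for the numbers to watch, the sketch below assumes the idealized case where the Discriminator outputs exactly 0.5 for every image. That is an assumption for illustration; in practice the two errors tend to oscillate rather than settle, so treat these values as a landmark and not as a stopping rule.

```python
import math

import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Idealized equilibrium: the discriminator says 0.5 for every sample.
half = torch.full((8, 1), 0.5)

pos_error = criterion(half, torch.ones(8, 1))
neg_error = criterion(half, torch.zeros(8, 1))
discriminator_error = pos_error + neg_error          # 2 * ln(2)
generator_error = criterion(half, torch.ones(8, 1))  # ln(2)

print(discriminator_error.item(), 2 * math.log(2))
print(generator_error.item(), math.log(2))
```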
There are a few best practices suggested for a better model and faster convergence. I deliberately saved these for the end, as some of this discussion changes the code and would lead to confusion if explained in the main content.
- In the final layer, use a sigmoid activation function for the Discriminator and a tanh function for the Generator. In this case, the output of the Generator will be in the range (-1,1). Thus, we also have to normalise the training data to this range (-1,1).
- Instead of training with targets of 1 for real images and 0 for fake images, train with 0.98 and 0.02 or similar values.
- For a quick and dirty check to see if your GAN setup is working, limit the training data to a single class and check how it performs. On the MNIST dataset with its 10 classes, it might take several hours to get some undistorted images, so it is better to first check that the configuration works well on a confined set.
- Since the Generator needs more time to train than the Discriminator, adding dropout layers to the Discriminator keeps it from overfitting.
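The practices above could be sketched as follows. Everything here runs on stand-in tensors instead of the real MNIST data, and the dropout rate (0.3), the layer widths and the chosen digit (3) are my own assumptions. With torchvision, the rescaling would typically be done by adding transforms.Normalize((0.5,), (0.5,)) after ToTensor, and the single-class filter would use the label tensor of the dataset instead of the toy labels below.

```python
import torch
import torch.nn as nn

# 1. Rescale [0, 1] pixels to [-1, 1] to match a tanh generator output.
pixels = torch.rand(4, 784)       # stand-in for a ToTensor-scaled batch
rescaled = (pixels - 0.5) / 0.5

# 2. Smoothed targets instead of hard 1 / 0.
batches = pixels.size(0)
real_targets = torch.full((batches, 1), 0.98)
fake_targets = torch.full((batches, 1), 0.02)

# 3. Keep a single class for a quick sanity check (toy labels).
labels = torch.tensor([3, 7, 3, 1])
keep = (labels == 3).nonzero().flatten().tolist()
single_class = rescaled[keep]

# 4. Dropout in the discriminator (0.3 is an assumed rate).
discriminator = nn.Sequential(
    nn.Linear(784, 256),
    nn.LeakyReLU(0.2),
    nn.Dropout(0.3),
    nn.Linear(256, 1),
    nn.Sigmoid(),
)

criterion = nn.BCELoss()
pos_error = criterion(discriminator(rescaled), real_targets)
print(keep, single_class.shape, pos_error.item())
```

Remember that dropout is only active in training mode, so call discriminator.eval() before using the network to score images outside the training loop.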