Original article was published on Artificial Intelligence on Medium
Getting Started with GANs Using PyTorch
We will see how a GAN can generate new images, an ability that makes GANs look a little bit “magic” at first sight.
A generative adversarial network (GAN) is a class of machine learning frameworks conceived in 2014 by Ian Goodfellow and his colleagues. Two neural networks (a Generator and a Discriminator) compete with each other, as in a game. Given a training set, this technique learns to generate new data with the same statistics as the training set.
In this blog post, we’ll train a GAN to generate images of cats’ faces. We will use the Cats Faces Dataset, which consists of more than 15,700 cat images. Since generative modeling is an unsupervised learning method, there are no labels on the images.
First, import all the required libraries.
Since we are running this notebook on Kaggle, we can use the “+ Add data” option from the sidebar and search for Cats Faces Dataset. This will download the data to ../input.
Save the path to the dataset in DATA_DIR.
The dataset has a single folder named cats, which contains more than 15,700 images in JPG format.
Now load this dataset using the ImageFolder class from torchvision. We will also resize and crop the images to 64×64 px, and normalize the pixel values with a mean and standard deviation of 0.5 for each channel. This maps pixel values from [0, 1] to the range [-1, 1], which is more convenient for discriminator training. We will also create a data loader to load the data in batches.
Generative Adversarial Network (GAN)
Now let’s build the Discriminator and Generator neural networks. The Generator will generate images, and the Discriminator will detect whether a given image is real or fake. The two compete with each other, and both get better at their jobs.
The input to the generator is a vector or matrix of random numbers that is used as a seed for image generation. The generator converts a latent tensor of shape (128, 1, 1) into an image tensor of shape (3 x 64 x 64), matching the 64×64 px training images.
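One way to get from a (128, 1, 1) latent tensor to a 64×64 px RGB image is a stack of transposed convolutions in the DCGAN style. The layer widths below (512, 256, 128, 64) are an assumption, not necessarily the ones used in the original notebook; only the input and output shapes come from the text.

```python
import torch
import torch.nn as nn

LATENT_SIZE = 128

generator = nn.Sequential(
    # in: LATENT_SIZE x 1 x 1
    nn.ConvTranspose2d(LATENT_SIZE, 512, kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(True),
    # 512 x 4 x 4
    nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.ReLU(True),
    # 256 x 8 x 8
    nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(True),
    # 128 x 16 x 16
    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(True),
    # 64 x 32 x 32
    nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh(),  # squashes pixel values into [-1, 1], like the normalized data
    # out: 3 x 64 x 64
)

# A batch of random latent vectors produces a batch of (untrained) images.
latent = torch.randn(2, LATENT_SIZE, 1, 1)
fake_images = generator(latent)
```

The final Tanh keeps the generated pixels in the same [-1, 1] range as the normalized training images, so the discriminator sees real and fake inputs on the same scale.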
The discriminator takes an image as its input and attempts to classify it as “real” or “generated”. It takes a (3 x 64 x 64) tensor as input and converts it into a (1, 1, 1) tensor, a single score for the image.
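A discriminator with these input and output shapes can be sketched as a mirror image of a convolutional generator: strided convolutions halve the resolution until a single value remains. The channel widths, the LeakyReLU slope of 0.2, and the final Flatten are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

discriminator = nn.Sequential(
    # in: 3 x 64 x 64
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2, inplace=True),
    # 64 x 32 x 32
    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True),
    # 128 x 16 x 16
    nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2, inplace=True),
    # 256 x 8 x 8
    nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(512),
    nn.LeakyReLU(0.2, inplace=True),
    # 512 x 4 x 4
    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
    # 1 x 1 x 1
    nn.Flatten(),  # one value per image
    nn.Sigmoid(),  # interpret it as the probability that the image is real
)

# Scores for a batch of two random "images".
scores = discriminator(torch.randn(2, 3, 64, 64))
```

The Sigmoid output pairs naturally with binary cross-entropy as the loss, with a target of 1 for real images and 0 for generated ones.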
Below is the function to save the generated images after every epoch.
The next step is to build the training functions for the discriminator and the generator. These functions will be used to train both neural networks.
This is the training function for the discriminator.
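A discriminator update could be sketched as follows, assuming the discriminator ends in a sigmoid and binary cross-entropy is used as the loss. The function signature is an assumption for illustration; the models and optimizer are passed in rather than read from globals.

```python
import torch
import torch.nn.functional as F


def train_discriminator(discriminator, generator, real_images, opt_d, latent_size=128):
    """Run one discriminator update and return (loss, real_score, fake_score)."""
    opt_d.zero_grad()
    batch_size = real_images.size(0)

    # Real images should be scored as 1.
    real_preds = discriminator(real_images)
    real_loss = F.binary_cross_entropy(real_preds, torch.ones_like(real_preds))

    # Generated images should be scored as 0. detach() stops gradients
    # from flowing back into the generator during this step.
    latent = torch.randn(batch_size, latent_size, 1, 1, device=real_images.device)
    fake_images = generator(latent).detach()
    fake_preds = discriminator(fake_images)
    fake_loss = F.binary_cross_entropy(fake_preds, torch.zeros_like(fake_preds))

    # Only the discriminator's weights are updated here.
    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()
    return loss.item(), real_preds.mean().item(), fake_preds.mean().item()
```

The two returned scores are the average predictions on real and fake batches. Early in training the real score tends to sit near 1 and the fake score near 0; as the generator improves, the two usually move closer together.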
This is the training function for the generator.
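A generator update can be sketched in the same style. The generator is rewarded when the discriminator labels its output as real, so the targets are all ones. As before, the signature is an assumption, and the device is inferred from the generator's parameters.

```python
import torch
import torch.nn.functional as F


def train_generator(discriminator, generator, opt_g, batch_size, latent_size=128):
    """Run one generator update and return the loss as a float."""
    opt_g.zero_grad()
    device = next(generator.parameters()).device

    # Generate a batch of fake images from random latent vectors.
    latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
    fake_images = generator(latent)

    # Try to fool the discriminator: the target for fake images is "real" (1).
    preds = discriminator(fake_images)
    loss = F.binary_cross_entropy(preds, torch.ones_like(preds))

    # Gradients flow through the discriminator, but only the generator's
    # weights are updated because opt_g holds only generator parameters.
    loss.backward()
    opt_g.step()
    return loss.item()
```

Note that no detach() is used here: the gradient has to pass through the discriminator to reach the generator's weights.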
Now we will build a training function that trains the discriminator and the generator together.
Now we are ready to train the model. We use a learning rate of lr = 0.0002 and train for 60 epochs.
After training the model for 60 epochs, we can see the generated images.
As we can see, the generated images from the model are similar to real ones.
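The overall loop can be sketched as below, with lr = 0.0002 and 60 epochs as defaults. The Adam optimizer with betas of (0.5, 0.999) is a common choice for this kind of GAN and is an assumption here, as are the function name fit and the d_step and g_step parameters, which stand for the per-batch discriminator and generator update functions.

```python
import torch


def fit(discriminator, generator, train_dl, d_step, g_step,
        epochs=60, lr=0.0002, on_epoch_end=None):
    """Alternate discriminator and generator updates over the data loader.

    d_step(discriminator, generator, real_images, opt_d) must return
    (loss_d, real_score, fake_score); g_step(discriminator, generator,
    opt_g, batch_size) must return loss_g. Returns one tuple per epoch
    holding the values from the last batch of that epoch.
    """
    # betas=(0.5, 0.999) is an assumed setting, not taken from the article.
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

    history = []
    for epoch in range(epochs):
        for real_images, _ in train_dl:  # labels are ignored
            loss_d, real_score, fake_score = d_step(
                discriminator, generator, real_images, opt_d)
            loss_g = g_step(discriminator, generator, opt_g, real_images.size(0))

        history.append((loss_g, loss_d, real_score, fake_score))
        print(f"Epoch [{epoch + 1}/{epochs}], loss_g: {loss_g:.4f}, "
              f"loss_d: {loss_d:.4f}, real_score: {real_score:.4f}, "
              f"fake_score: {fake_score:.4f}")

        # Optional hook, e.g. to save a grid of generated images.
        if on_epoch_end is not None:
            on_epoch_end(epoch + 1)
    return history
```

Each batch gets one discriminator update followed by one generator update, so neither network gets far ahead of the other. The sketch assumes the data loader yields at least one batch per epoch.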
You can find the full code on my Jovian Profile: