Cycle GAN with PyTorch

The original article was published on Deep Learning on Medium.


General idea of the cycle GAN. (Source: Hardik Bansal)

In this article I share an interesting project I was part of. The goal of the project was to build a cycle GAN that takes in images of class A and transforms them into images of class B, in this case horses and zebras. I will cover the following topics in order:

  • Cycle GAN description, main features.
  • Where and how to find image data.
  • Implementation of the cycle GAN in PyTorch.
  • Presentation of the results.

Cycle GAN description

The cycle GAN was (to the best of my knowledge) first introduced in the paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks. I recommend the paper to anyone interested in computer vision. It is not very long, and it provides both the intuition behind the cycle GAN and the mathematical formulation of its loss functions.

What makes the cycle GAN interesting is that it is a game between two kinds of neural networks that learn in parallel:

  • generators, which take in real images and output fake images;
  • discriminators, which classify whether an image is real or fake.

A cycle GAN contains two of each: one generator per translation direction (horse to zebra and zebra to horse) and one discriminator per domain. The goal is that over time the generators become better at tricking the discriminators, and the discriminators become better at not being tricked.

To optimize the generators and discriminators accordingly, a total of eight loss functions are introduced:

  • one adversarial loss for each of the two discriminators;
  • one adversarial loss for each of the two generators;
  • a cycle consistency loss for each generator;
  • an identity loss for each generator.

Various loss functions can be used. In this project we use the LSGAN loss, which is based on least squares, and the L1 loss, which is the mean absolute error.

Discriminators: LSGAN loss

As mentioned earlier, the goal of the discriminator is to classify real images as real and fake images as fake. To optimize this, the following least squares loss function is used:
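Written in the least-squares form that the cycle GAN paper gives in its training details, with this notation:

  • G: the horse-to-zebra generator;
  • F: the zebra-to-horse generator;
  • D_X and D_Y: the horse and zebra discriminators;
  • x and y: real horse and zebra images.

The two discriminator losses read as follows. This is a sketch in the paper's notation; the exact constant factors used in this project are not stated.

```latex
\begin{aligned}
\mathcal{L}_{\mathrm{LSGAN}}(D_Y) &= \mathbb{E}_{y}\left[(D_Y(y) - 1)^2\right] + \mathbb{E}_{x}\left[D_Y(G(x))^2\right] \\
\mathcal{L}_{\mathrm{LSGAN}}(D_X) &= \mathbb{E}_{x}\left[(D_X(x) - 1)^2\right] + \mathbb{E}_{y}\left[D_X(F(y))^2\right]
\end{aligned}
```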

The intuition is as follows:

  • For a real image, a perfect discriminator would output all ones and get zero loss from the first term.
  • For a fake image, a perfect discriminator would output all zeros and get zero loss from the second term.

The LSGAN loss is used once for each discriminator in the model.
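A minimal PyTorch sketch of this discriminator loss follows. It assumes the discriminator returns raw, unsquashed scores. `discriminator_loss` is a hypothetical helper name. `nn.MSELoss` averages the squared error over every output element, which plays the role of the expectation.

```python
import torch
import torch.nn as nn

# Least-squares (LSGAN) criterion: mean squared error against an all-ones or all-zeros target.
mse = nn.MSELoss()


def discriminator_loss(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
    """LSGAN loss for one discriminator.

    d_real: scores for real images, pushed towards 1.
    d_fake: scores for generated images, pushed towards 0.
    """
    real_term = mse(d_real, torch.ones_like(d_real))
    fake_term = mse(d_fake, torch.zeros_like(d_fake))
    return real_term + fake_term
```

A perfect discriminator (all ones on real images, all zeros on fakes) gets a loss of exactly zero. During training, `d_fake` would normally be computed from detached generator outputs (`fake.detach()`), so that this loss updates only the discriminator.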

Generators: trick discriminator

One of the goals of the generator is to trick the discriminator into believing the fake images are real. To do a perfect job of this, the generator must make the discriminator output all ones for the generated image, in which case the generator gets a loss of zero. To accomplish this, a slightly modified version of the least squares loss above is used:
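In the cycle GAN paper's least-squares form, each generator is pushed to make its discriminator output 1 on the images it generates. The notation is the same as before: G maps horses to zebras, F maps zebras to horses, D_X and D_Y are the horse and zebra discriminators, and x and y are real horse and zebra images. This is a sketch in the paper's notation, not a formula quoted from this project.

```latex
\begin{aligned}
\mathcal{L}_{\mathrm{adv}}(G) &= \mathbb{E}_{x}\left[(D_Y(G(x)) - 1)^2\right] \\
\mathcal{L}_{\mathrm{adv}}(F) &= \mathbb{E}_{y}\left[(D_X(F(y)) - 1)^2\right]
\end{aligned}
```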

Generators: cycle consistency

We also impose a constraint on the generators known as cycle consistency. A generator should be able to map an image from one class domain to the other, and the opposite generator should then map it back to the first domain without any change to the image. This constraint is imposed by applying the L1 loss pixel by pixel between the real image and the restored image:
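Following the cycle GAN paper, the two cycle consistency terms (one per direction) can be written as below. G maps horses (x) to zebras (y), and F maps zebras back to horses. The L1 norm is taken over all pixels.

```latex
\mathcal{L}_{\mathrm{cyc}}(G, F) = \mathbb{E}_{x}\left[\lVert F(G(x)) - x \rVert_1\right] + \mathbb{E}_{y}\left[\lVert G(F(y)) - y \rVert_1\right]
```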

Generators: Identity loss

The identity loss is needed to stop a generator from altering images that already belong to the class it is trying to generate. For example, if a horse image is given as input to the horse generator, we want the input to pass through the generator unaltered. This trait is imposed via the identity loss, which is based on the same L1 loss function as the cycle consistency loss.
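The three generator objectives can be combined into one scalar loss, sketched below. The names `G`, `F`, `D_X` and `D_Y` follow the paper's notation (G: horse to zebra, F: zebra to horse). `generator_loss` is a hypothetical helper.

The weights are assumptions:

  • 10 for the cycle term is the λ used in the cycle GAN paper.
  • 5 (0.5λ) for the identity term is the weight the paper uses for the identity loss in some experiments. It is not a value confirmed for this project.

```python
import torch
import torch.nn as nn

mse = nn.MSELoss()  # least-squares adversarial term
l1 = nn.L1Loss()    # mean absolute error for the cycle and identity terms

# Assumed weights: lambda = 10 for cycle consistency, 0.5 * lambda for identity.
LAMBDA_CYC = 10.0
LAMBDA_ID = 5.0


def generator_loss(G, F, D_X, D_Y, real_x, real_y):
    """Total loss for both generators (G: X -> Y, F: Y -> X)."""
    fake_y = G(real_x)
    fake_x = F(real_y)

    # Adversarial: each generator wants its discriminator to output all ones.
    pred_y = D_Y(fake_y)
    pred_x = D_X(fake_x)
    adv = mse(pred_y, torch.ones_like(pred_y)) + mse(pred_x, torch.ones_like(pred_x))

    # Cycle consistency: X -> Y -> X and Y -> X -> Y should restore the input.
    cyc = l1(F(fake_y), real_x) + l1(G(fake_x), real_y)

    # Identity: an image already in a generator's target domain should pass through unchanged.
    idt = l1(G(real_y), real_y) + l1(F(real_x), real_x)

    return adv + LAMBDA_CYC * cyc + LAMBDA_ID * idt
```

With identity mappings as generators and discriminators that always answer "real", every term is zero. This illustrates why the identity and cycle terms alone cannot force any change: the adversarial term is what drives the translation.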

Generator architecture

Generator structure (dimensions will depend on image sizes)

As the figure above shows, the generator has three main components: the encoding phase, the transformation phase and the decoding phase.

Encoding:

The encoding phase converts the features in the image to a latent space representation via multiple convolution layers.

Transformation:

The transformation phase consists of six or nine ResNet blocks and is used to assemble the appropriate latent features captured in the encoding phase. The ResNet blocks also have the benefit of skip connections, which help avoid vanishing gradients in deep networks.

Decoding:

The decoding phase does the opposite of the encoding phase: it builds the image back up from the latent representation via transposed convolutions.
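A minimal PyTorch sketch of a generator with these three phases is shown below. It assumes the layer layout described in the cycle GAN paper's appendix:

  • 7×7 convolutions at the input and output;
  • two stride-2 downsampling convolutions;
  • residual blocks;
  • two transposed convolutions;
  • instance normalization and reflection padding throughout.

The class names and channel counts are illustrative, not taken from this project's notebook.

```python
import torch
import torch.nn as nn


class ResnetBlock(nn.Module):
    """Two 3x3 convolutions wrapped in a skip connection: output = x + block(x)."""

    def __init__(self, channels: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        return x + self.block(x)  # skip connection


class Generator(nn.Module):
    def __init__(self, in_channels: int = 3, base: int = 64, n_blocks: int = 9):
        super().__init__()
        # Encoding: 7x7 conv, then two stride-2 convs that each halve height and width.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, base, kernel_size=7),
            nn.InstanceNorm2d(base), nn.ReLU(inplace=True),
            nn.Conv2d(base, base * 2, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(base * 2), nn.ReLU(inplace=True),
            nn.Conv2d(base * 2, base * 4, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(base * 4), nn.ReLU(inplace=True),
        ]
        # Transformation: residual blocks at the lowest resolution.
        layers += [ResnetBlock(base * 4) for _ in range(n_blocks)]
        # Decoding: two transposed convs that each double height and width, then back to RGB.
        layers += [
            nn.ConvTranspose2d(base * 4, base * 2, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(base * 2), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(base * 2, base, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(base), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(3),
            nn.Conv2d(base, in_channels, kernel_size=7),
            nn.Tanh(),  # outputs in [-1, 1], matching inputs normalized to that range
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
```

For 256 × 256 inputs, `Generator(n_blocks=9)` maps a `(1, 3, 256, 256)` tensor to one of the same shape. The paper uses six residual blocks for 128 × 128 images and nine for 256 × 256, which matches the "six or nine" above.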

Discriminator architecture

Discriminator structure

The discriminator is simply a network of convolutional layers whose function is to perform binary classification, i.e. to decide whether an image is real or fake.
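One common realization is the 70 × 70 PatchGAN discriminator used in the cycle GAN paper. It outputs a grid of real/fake scores, one per overlapping image patch, rather than a single number. The sketch below assumes that design; the article does not state the exact layers used. There is no final sigmoid, because the least squares loss works on raw scores.

```python
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """PatchGAN-style discriminator: 4x4 convolutions, LeakyReLU, instance norm."""

    def __init__(self, in_channels: int = 3, base: int = 64):
        super().__init__()

        def block(c_in, c_out, stride, norm=True):
            layers = [nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1)]
            if norm:
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(in_channels, base, stride=2, norm=False),  # no norm on the first layer
            *block(base, base * 2, stride=2),
            *block(base * 2, base * 4, stride=2),
            *block(base * 4, base * 8, stride=1),
            nn.Conv2d(base * 8, 1, kernel_size=4, stride=1, padding=1),  # one score per patch
        )

    def forward(self, x):
        return self.model(x)
```

A `(1, 3, 256, 256)` input yields a `(1, 1, 30, 30)` grid of scores. The least squares loss compares this grid against a tensor of ones or zeros of the same shape.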

Finding image data

For anyone interested in performing pattern recognition tasks on image data, being able to find large quantities of domain-specific images quickly and easily is important. Most deep learning frameworks, such as TensorFlow, Keras and PyTorch, have functions that give the user easy access to popular data sets such as CIFAR-10, MNIST and Fashion-MNIST. Finding your own image data takes a bit more effort, but is still not very difficult. I have found the following forum discussion informative on this topic: Link. Personally, I have found ImageNet to be an easy and reliable source for scraping image data. In my experience, ImageNet provides roughly 1000 images of a given object (such as strawberries), which may or may not be enough data depending on the deep learning application.

The data used for this project were 256 × 256 images of horses and zebras. The horse/zebra data set (which I did not collect myself) can be found at the following link: Link.

Cycle GAN implementation with PyTorch

The generator network is defined as seen below.

The discriminator network is defined as seen below.

The aforementioned loss functions used for the generator and discriminator networks respectively are implemented as seen below.

The full code used for training the networks can be found in the following notebook. The networks were trained on a GPU via Google Colaboratory.
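A single training iteration could look like the sketch below. Generators are updated first, then discriminators, using detached fake images so that the discriminator update does not backpropagate into the generators.

The following are assumptions, not details from this project's notebook:

  • the Adam settings from the cycle GAN paper and its reference implementation (learning rate 0.0002, betas (0.5, 0.999));
  • the loss weights 10 and 5;
  • the helper names `make_optimizers` and `train_step`.

```python
import itertools

import torch
import torch.nn as nn

# Train on the GPU when one is available (e.g. a Colab GPU runtime), else on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
mse, l1 = nn.MSELoss(), nn.L1Loss()


def make_optimizers(G, F, D_X, D_Y, lr=2e-4):
    """One Adam optimizer for both generators and one for both discriminators."""
    opt_G = torch.optim.Adam(itertools.chain(G.parameters(), F.parameters()), lr=lr, betas=(0.5, 0.999))
    opt_D = torch.optim.Adam(itertools.chain(D_X.parameters(), D_Y.parameters()), lr=lr, betas=(0.5, 0.999))
    return opt_G, opt_D


def train_step(G, F, D_X, D_Y, opt_G, opt_D, real_x, real_y, lambda_cyc=10.0, lambda_id=5.0):
    """One update of both generators followed by one update of both discriminators."""
    real_x, real_y = real_x.to(device), real_y.to(device)

    # Generators: adversarial + cycle consistency + identity losses.
    opt_G.zero_grad()
    fake_y, fake_x = G(real_x), F(real_y)
    pred_y, pred_x = D_Y(fake_y), D_X(fake_x)
    loss_adv = mse(pred_y, torch.ones_like(pred_y)) + mse(pred_x, torch.ones_like(pred_x))
    loss_cyc = l1(F(fake_y), real_x) + l1(G(fake_x), real_y)
    loss_id = l1(G(real_y), real_y) + l1(F(real_x), real_x)
    loss_G = loss_adv + lambda_cyc * loss_cyc + lambda_id * loss_id
    loss_G.backward()
    opt_G.step()

    # Discriminators: LSGAN loss on real images (target 1) and detached fakes (target 0).
    opt_D.zero_grad()  # also clears gradients the generator pass left in D_X and D_Y
    loss_D = 0.0
    for D, real, fake in ((D_Y, real_y, fake_y), (D_X, real_x, fake_x)):
        pred_real, pred_fake = D(real), D(fake.detach())
        loss_D = loss_D + mse(pred_real, torch.ones_like(pred_real)) + mse(pred_fake, torch.zeros_like(pred_fake))
    loss_D.backward()
    opt_D.step()
    return loss_G.item(), loss_D.item()
```

The networks themselves must be moved to the same device (e.g. `G.to(device)`) before the optimizers are created.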

Results

The figure below shows some of the images obtained with the implementation described above.

Cycle GAN results

When we look at the images generated by the model, the quality of the results varies considerably. Some results are good, whereas others are of lower quality; for example, parts of the image that contain no animal are altered.

Generally, when generating a fake zebra from a horse image, the best result is obtained when:

  • the horse is brown;
  • all or most of its body is in the picture;
  • it covers a large portion of the image.

If there is a lot of noise in the image, i.e. the horse's body takes up only a small portion of it, the generated image is often barely altered by the generator. One hypothesis for why this happens is that the model interprets the image as a zebra. In that case the identity loss constrains the generator from altering the input image.

Generally, the model performs worse when transforming a zebra into a horse than in the opposite direction. The poorer performance can most likely be attributed to the appearance of horses varying more than that of zebras.

Another observation is that even though white and black horses exist in the training images, the only change seen in the zebra-to-horse conversions is that brown is added to a larger or smaller part of the zebra's body. The reason is most likely that the data set contains more images of brown horses, so the model has overfit to the color brown.