What did I learn by implementing pix2pix

Source: Deep Learning on Medium


Paper: https://arxiv.org/pdf/1611.07004.pdf

Pix2pix caught my attention as soon as I saw it in the fastai course (100% recommended!!). Pix2pix is a GAN that translates one image into another. Some examples from the paper:

I 100% recommend this paper: it is easy to read and understand, which sadly doesn’t happen often with deep learning papers.

Architecture

Generator: modified UNet

The generator is a UNet (the architecture that won the ISBI cell tracking challenge 2015) with a slight change in how the skip connections are handled. This is the original UNet architecture:

There are skip connections between the encoder (left branch) and the decoder (right branch). In the original UNet, the encoder outputs are concatenated with the decoder features, and then 2 conv layers are applied before upsampling again. In pix2pix, by contrast, the decoder upsamples again right after concatenating and applying a single conv layer. It is like having 1 conv layer per level instead of 2 as in the original version:

Source: https://www.researchgate.net/publication/328150573_Dialectical_GAN_for_SAR_image_translation_From_sentinel-1_to_TerraSAR-X
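To make the channel bookkeeping concrete, here is a small pure-Python trace of the tensor shapes in a pix2pix-style U-Net for 256x256 RGB images. It is a sketch, not the PyTorch code from my notebook. The encoder filter counts follow the paper (C64-C128-C256-C512-C512-C512-C512-C512). The decoder widths are an assumption: they mirror the encoder, as in the official repository's `unet_256`. `trace_unet` is a hypothetical helper.

```python
# Hypothetical shape trace of a pix2pix-style U-Net generator (256x256 input).
# Encoder filters follow the paper; decoder widths mirror the encoder, which is
# an assumption based on the official repo's unet_256.

ENCODER_FILTERS = [64, 128, 256, 512, 512, 512, 512, 512]


def trace_unet(size=256, out_channels=3):
    """Return (encoder, decoder) lists of (channels, height, width) tuples."""
    encoder = []
    spatial = size
    for filters in ENCODER_FILTERS:
        spatial //= 2  # each 4x4 stride-2 conv halves height and width
        encoder.append((filters, spatial, spatial))

    decoder = []
    # Walk back up, pairing each level with the mirror encoder output
    # (the 1x1 bottleneck itself has no skip partner).
    for skip_channels, skip_size, _ in reversed(encoder[:-1]):
        spatial *= 2  # one up-convolution per level doubles the resolution
        assert spatial == skip_size
        # up-conv output (same width as the skip) concatenated with the skip
        decoder.append((skip_channels * 2, spatial, spatial))
    spatial *= 2
    decoder.append((out_channels, spatial, spatial))  # final up-conv + tanh
    return encoder, decoder


encoder, decoder = trace_unet()
for shape in decoder:
    print(shape)
```

Each decoder level has a single up-convolution whose output is concatenated with the matching encoder output. That is why most decoder widths are double the encoder widths (512 + 512 = 1024).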

My implementation of the generator:

Discriminator

The discriminator receives 2 images, yes 2!! That’s why the paper calls it a conditional GAN (CGAN). The architecture looks like this:

Source: https://www.researchgate.net/publication/328150573_Dialectical_GAN_for_SAR_image_translation_From_sentinel-1_to_TerraSAR-X

Why two images? One of them is the X data, the same image the generator received as input. The other one depends on what the discriminator is learning at that moment: to recognize fake images or to recognize real images. So the second image is either the output of the generator (learning to recognize fake images) or the target image (learning to recognize real images).
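The pairing can be sketched in plain Python. Images here are hypothetical nested lists shaped `[channels][height][width]`, and `discriminator_pairs` is an illustrative helper. In PyTorch the same step is a `torch.cat` along the channel dimension, which is what the official repository does. For RGB pairs, the discriminator therefore sees 3 + 3 = 6 input channels.

```python
# Hypothetical helper: builds the two (input, label) pairs the discriminator
# sees per training example. Images are nested lists shaped [channels][H][W].

REAL_LABEL, FAKE_LABEL = 1.0, 0.0


def concat_channels(a, b):
    """Stack two [C][H][W] images along the channel axis -> [Ca + Cb][H][W]."""
    if len(a[0]) != len(b[0]) or len(a[0][0]) != len(b[0][0]):
        raise ValueError("images must have the same height and width")
    return a + b  # list concatenation joins the channel lists


def discriminator_pairs(x, y, fake):
    """x: generator input, y: target image, fake: generator output G(x)."""
    return [
        (concat_channels(x, y), REAL_LABEL),     # learn to recognize real pairs
        (concat_channels(x, fake), FAKE_LABEL),  # learn to recognize fake pairs
    ]
```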

The output of the discriminator is also quite interesting. A discriminator usually outputs a single value: is this a real image or a fake one? Here, instead, the output is a matrix in which each value scores one “patch” of the input image. These patches are interesting because some of them overlap and some don’t. The region of the input that each output value depends on is what the paper calls its receptive field.

Here is a schema of how the receptive field propagates through the discriminator:

Source: https://medium.com/@EricKuy/image-to-image-pix2pix-e690098231fd
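The receptive field and the size of the output matrix can be computed from the layer list alone. This sketch assumes the paper's 70x70 PatchGAN configuration (C64-C128-C256-C512 followed by a final 1-channel convolution). It also assumes 4x4 kernels, padding 1 and the strides used in the official repository's `NLayerDiscriminator`. `LAYERS`, `receptive_field` and `output_size` are names made up for this example.

```python
# Receptive field and output size of a 70x70 PatchGAN-style discriminator.
# Kernel/stride/padding values are assumed from the official repo's defaults.

LAYERS = [  # (kernel, stride, padding)
    (4, 2, 1),  # C64
    (4, 2, 1),  # C128
    (4, 2, 1),  # C256
    (4, 1, 1),  # C512
    (4, 1, 1),  # final conv -> 1 channel of logits
]


def receptive_field(layers):
    """Input pixels seen by one output value, computed from the last layer back."""
    rf = 1
    for kernel, stride, _ in reversed(layers):
        rf = (rf - 1) * stride + kernel
    return rf


def output_size(size, layers):
    """Spatial size of the output map for a square input of side `size`."""
    for kernel, stride, padding in layers:
        size = (size + 2 * padding - kernel) // stride + 1
    return size


print(receptive_field(LAYERS))   # 70
print(output_size(256, LAYERS))  # 30
```

For a 256x256 input this gives a 30x30 matrix of logits, and each value judges a 70x70 patch. Adjacent output values are only 2 * 2 * 2 = 8 input pixels apart, so their patches overlap heavily. Values far apart in the matrix look at disjoint regions.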

My implementation of the discriminator:

Notice that there is no sigmoid at the end. Why not, if we want values close to 0 or 1? Because that’s what the original implementation does: it sends the output of the last layer directly to the loss, BCEWithLogitsLoss, which applies the sigmoid internally in a numerically more stable way: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/41931e25c7d12e0ff2fcea4ee7ba2e597769e6f2/models/networks.py#L224-L225
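A plain-Python sketch shows why this is more stable. `naive_bce` applies the sigmoid first and then takes logs. `stable_bce_with_logits` uses the rearrangement max(z, 0) - z*y + log(1 + exp(-|z|)), which is algebraically the same loss. PyTorch's documentation describes `BCEWithLogitsLoss` as fusing the two steps with the log-sum-exp trick. These two functions are illustrative only, not PyTorch's source.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def safe_log(v):
    # math.log(0.0) raises ValueError; return -inf like float arithmetic would
    return math.log(v) if v > 0 else -math.inf


def naive_bce(logit, target):
    """Sigmoid first, then binary cross-entropy: breaks when the sigmoid saturates."""
    p = sigmoid(logit)
    loss = 0.0
    if target != 0:
        loss -= target * safe_log(p)
    if target != 1:
        loss -= (1 - target) * safe_log(1.0 - p)
    return loss


def stable_bce_with_logits(logit, target):
    """Same loss computed directly from the logit, without forming sigmoid(logit)."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


print(naive_bce(40.0, 0.0))               # inf
print(stable_bce_with_logits(40.0, 0.0))  # 40.0
```

With a logit of 40, the sigmoid rounds to exactly 1.0 in double precision, so log(1 - p) blows up. The fused form returns the correct value of about 40.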

This is the output of the discriminator for a fake image, passed through a sigmoid function. All values are almost 0, so no part of the image is recognized as anything close to real. The discriminator is doing a very good job! (The generator, not so much: it is unable to fool the discriminator.)

Left: the input of the generator; right: the generated image. The discriminator sees both at the same time and produces the output below. This sample is from the test set.

And this is the output for a real image. Everything looks pretty real, right?

Left: the input of the generator; right: the expected output. The discriminator sees both at the same time.

Losses

I’m not going to say much about the theory behind the loss function, as it is well explained in the paper and in many other articles such as this, this and this.

For completeness, here are the generator’s loss formulas from the paper:
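Written in LaTeX, the paper defines a conditional GAN term, an L1 term and the full objective:

```latex
\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}\left[\log D(x, y)\right] + \mathbb{E}_{x,z}\left[\log\left(1 - D(x, G(x, z))\right)\right]

\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}\left[\lVert y - G(x, z) \rVert_1\right]

G^{*} = \arg\min_{G} \max_{D} \; \mathcal{L}_{cGAN}(G, D) + \lambda \, \mathcal{L}_{L1}(G)
```

Here x is the input image, y the target and z the noise, which pix2pix provides only in the form of dropout. The paper uses λ = 100. As is usual for GANs, the generator is trained to maximize log D(x, G(x, z)) instead of minimizing log(1 - D(x, G(x, z))).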

What many people don’t show, though, is how to implement the loss, so here is how I did it. I have to admit that this part took me quite a bit of work to get right, with a lot of help from the paper’s GitHub repository (meh… this is my first time playing with GANs):
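As a framework-free sketch of the same logic (not my notebook code), here are both losses computed on flattened lists of discriminator logits and pixel values. The function names are hypothetical. λ = 100 and the halving of the discriminator loss come from the paper. In PyTorch, the equivalent would use `nn.BCEWithLogitsLoss` and `nn.L1Loss` on tensors.

```python
import math

LAMBDA_L1 = 100.0  # weight of the L1 term used in the paper


def bce_with_logits(logits, target):
    """Mean binary cross-entropy of raw logits against a constant 0/1 target."""
    total = sum(max(z, 0.0) - z * target + math.log1p(math.exp(-abs(z)))
                for z in logits)
    return total / len(logits)


def l1(a, b):
    """Mean absolute error between two flattened images."""
    return sum(abs(p - q) for p, q in zip(a, b)) / len(a)


def discriminator_loss(logits_real, logits_fake):
    # D(x, y) should say "real" (1) and D(x, G(x)) should say "fake" (0).
    # The sum is halved, as in the paper, to slow D down relative to G.
    return 0.5 * (bce_with_logits(logits_real, 1.0) +
                  bce_with_logits(logits_fake, 0.0))


def generator_loss(logits_fake, fake, target):
    # G wants D(x, G(x)) to be scored as "real" (non-saturating form),
    # plus the weighted L1 distance between its output and the target image.
    return bce_with_logits(logits_fake, 1.0) + LAMBDA_L1 * l1(fake, target)
```

One PyTorch detail that is easy to miss: when computing the discriminator loss, the official repository calls `.detach()` on the fake pair, so that step does not backpropagate into the generator.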

Examples

For fun and profit, here are some examples of generated images (from left to right: generator input, generator output and target image).

Training set

As we can see, the generator has almost memorized the training images, and the generated images are quite good.

Test set

But for the test set the results are much worse…

This isn’t too bad!

Full implementation

Here is the full notebook on Google Colab if you want to play with it:

Conclusion

Implementing papers is always a good idea because it is a real challenge and there is always something new to learn. Understanding the theory and being able to implement it are totally different things. I found the ideas of the CGAN and the PatchGAN quite interesting, although they might not be novel to this paper.