Source: Deep Learning on Medium
Coloring Photos with a Generative Adversarial Network
Ever since I started learning about data science and machine learning, there has always been one algorithm that has continually grabbed my attention: Generative Adversarial Networks (GANs). So much so that the second blog I ever wrote covered, in detail, how these models work and what they can create. When I first learned about them, most of the articles (including my own) only covered how they are able to take in a vector of random noise and produce lifelike photos.
Shortly after I wrote that blog article, I came across CycleGAN (Zhu et al.), which allows image-to-image translation. This means we are able to pass in an image as the input and receive an altered copy of it as output. This includes mapping an artist's style, such as Van Gogh's, onto a photo, changing a photo from summer to winter, or transforming a zebra to look like a horse.
After I wrote the aforementioned blog post, I felt comfortable enough with GANs, how they work, and their variations that I decided to build my own. However, with so many applications to choose from, I struggled to decide on what I wanted to create. One night, while I was still considering different options, I watched Peter Jackson’s They Shall Not Grow Old, a WWI documentary that was created through image reconstruction and recolorization. This, combined with my knowledge of transfer learning, inspired me to create a model that can convert black-and-white photos to color.
(For those not familiar with Generative Adversarial Networks I suggest reading my first blog before continuing)
Steps to Create a GAN:
- Get Data
- Preprocess Images
- Create Architecture
- Train, Monitor, and Tune Parameters
Before I could start coding or planning my model’s architecture, I needed to find a dataset to work with. I learned that when it comes to image translation, GANs are much better at understanding texture and symmetry in photos than they are at identifying complex geometry. This means, for example, that they can create photos of landscapes much more easily than of a person or a dog. This prompted me to collect two different datasets:
- MIT Computational Visual Library Dataset: I only used the coastal photos, since they contain simple landscapes.
- MPII Human Pose Dataset: Images of people performing various activities.
I started with the MIT dataset for my first attempt since these images are easier to model, allowing me to focus on creating my first architecture and all the necessary preprocessing before working with the much larger MPII dataset. Once I acquired all of my data I could finally begin coding.
Here is the Github for reference.
Since GANs are extremely computationally intensive, it’s important to limit the size of the images; I chose 256×256. The MIT coastal photos were already this size, so nothing needed to be done with the pictures from that dataset. The MPII images, on the other hand, came in varying sizes larger than 256×256 and needed to be resized, which is easily done using NumPy.
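The resizing step can be sketched with a naive nearest-neighbor approach in pure NumPy (a minimal sketch; a library resizer such as `skimage.transform.resize` gives smoother results):

```python
import numpy as np

def resize_nearest(img, size=256):
    # Naive nearest-neighbor resize to size x size: pick one source row
    # and column for each output position via integer index arrays.
    h, w = img.shape[:2]
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    return img[rows][:, cols]

photo = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for an MPII image
small = resize_nearest(photo)
print(small.shape)  # (256, 256, 3)
```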
My first idea was to take every image, simply convert it to grayscale, and pass that into the generator to create a fully colored image using the standard RGB (red, green, blue) color channels with their normal values of 0–255. I quickly realized this method was very inefficient, and that in order to transform my images I needed a few different methods instead of one simple solution.
During my research, I carefully examined examples from researchers who had attempted to color black-and-white photos, and noticed a pattern among them; many transformed the images from RGB color values to LAB color values. Unlike RGB, which merges red, blue, and green values to create the color image, LAB consists of a lightness channel and two color channels. The L channel contains the lightness information of a photo and is equivalent to a black-and-white version. A and B are the color channels, where A controls the green-red tradeoff and B controls the blue-yellow tradeoff. Python’s Scikit-Image library comes with a great method that allows us to easily convert our RGB photos to LAB.
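The conversion is a one-liner with Scikit-Image's `rgb2lab` (the array shapes here are just illustrative):

```python
import numpy as np
from skimage.color import rgb2lab

rgb = np.random.rand(256, 256, 3)  # float RGB image in [0, 1]
lab = rgb2lab(rgb)
L = lab[:, :, 0]    # lightness channel, roughly 0-100
AB = lab[:, :, 1:]  # A and B color channels, roughly -128 to 127
```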
Instead of trying to create all three color channels (my initial method), I decided to pass the L channel through the generator as the input and output the new A and B color channels. Lastly, once I converted the images to LAB, I needed to normalize them, since their raw pixel values span ranges that are inefficient for my model to work with. I decided to scale the pixel values from -1 to 1 so that I could easily replicate them with a tanh activation.
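The scaling might look something like this (a sketch assuming Scikit-Image's LAB ranges: L in 0–100, A/B roughly -128 to 127):

```python
import numpy as np

def normalize_lab(lab):
    # Scale the LAB channels to [-1, 1] so a tanh output can reproduce
    # them: L maps 0-100 -> [-1, 1], A/B map ~[-128, 127] -> ~[-1, 1].
    L = lab[:, :, :1] / 50.0 - 1.0
    AB = lab[:, :, 1:] / 128.0
    return L, AB

lab = np.zeros((256, 256, 3))
lab[:, :, 0] = 100.0  # pure white in LAB
L, AB = normalize_lab(lab)
print(L.max(), AB.max())  # 1.0 0.0
```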
To create the models I used Keras, a deep learning library for Python. First, I created the generator and discriminator separately; then I connected the two so the generator can learn based on how well it fools the other model.
The generator can be broken down into two pieces: the encoder and decoder. The original image, in my case the L channel, is downsized through convolutional layers until the desired feature map size is achieved. This group of feature maps, before I upsize the image again, is known as the latent space representation of the image. From here, I perform what are known as transpose convolutions, or deconvolutions, which allow me to upsample my image; the animation linked here shows an excellent example of how this works. I repeat this process until I am back at the original size, and then output the new image, the A and B channels.
My generator’s encoder consists of four convolutional layers with a stride of two so that I can downsize the image. I start with 64 feature maps and double the number of maps at each layer, so that by the time I reach the latent space representation, i.e. the middle of the network, I have a maximum of 512 maps, each with a size of 16×16. The output of each layer then goes through batch normalization and finally a leaky ReLU activation. The decoder is almost a reverse copy, as it has four deconvolutional layers whose outputs go through batch normalization and ReLU activations. The very last layer is a one-stride convolution with two output channels, the A and B color channels, with a tanh activation to match the -1 to 1 scale we set earlier when preprocessing our images.
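A Keras sketch of this generator might look as follows (the kernel sizes and the decoder's exact filter counts are my assumptions; the strides, encoder feature-map counts, normalization, and activations come from the description above):

```python
from tensorflow.keras import layers, Model

def build_generator(size=256):
    inp = layers.Input(shape=(size, size, 1))  # the L channel
    x = inp
    # Encoder: four stride-2 convolutions, 64 -> 512 feature maps,
    # each followed by batch norm and a leaky ReLU.
    for filters in (64, 128, 256, 512):
        x = layers.Conv2D(filters, 4, strides=2, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)
    # x is now the 16x16x512 latent space representation.
    # Decoder: four transpose convolutions with batch norm and ReLU.
    for filters in (256, 128, 64, 32):
        x = layers.Conv2DTranspose(filters, 4, strides=2, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
    # Final one-stride convolution: two channels (A and B) through tanh.
    out = layers.Conv2D(2, 3, strides=1, padding="same", activation="tanh")(x)
    return Model(inp, out)

gen = build_generator()
```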
The discriminator is a much simpler model than its counterpart, the generator, because it’s a standard Convolutional Neural Network (CNN) that is used to predict whether the AB channels are real or fake. It has four 2-stride convolutional layers, each of which includes dropout, a leaky ReLU activation, and, except for the first layer, batch normalization. Like every other CNN, I then flatten the last layer and pass it through a sigmoid to predict whether the image is real or fake.
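A sketch of the discriminator in Keras (the kernel size and dropout rate are my assumptions, and this version takes only the A and B channels as input, matching the description above):

```python
from tensorflow.keras import layers, Model

def build_discriminator(size=256):
    inp = layers.Input(shape=(size, size, 2))  # the A and B channels
    x = inp
    # Four stride-2 convolutions with leaky ReLU and dropout;
    # batch norm on every layer except the first.
    for i, filters in enumerate((64, 128, 256, 512)):
        x = layers.Conv2D(filters, 4, strides=2, padding="same")(x)
        if i > 0:
            x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.Dropout(0.25)(x)
    # Flatten and predict real vs. fake with a sigmoid.
    x = layers.Flatten()(x)
    out = layers.Dense(1, activation="sigmoid")(x)
    return Model(inp, out)

disc = build_discriminator()
```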
Both models use binary cross-entropy for their loss function and the Adam optimizer. Contrary to everything I read, I needed to use a higher learning rate for the generator to keep it from being overpowered. First, the discriminator is initialized, and then I create and link the generator’s loss to the discriminator’s output.
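The wiring might look like this in Keras, using tiny stand-in models so the snippet is self-contained (the specific learning rates are my assumptions; note that the generator gets the higher one, as described above):

```python
from tensorflow.keras import layers, Model, optimizers

# Minimal stand-ins for the real generator and discriminator.
gen_in = layers.Input(shape=(256, 256, 1))
gen = Model(gen_in, layers.Conv2D(2, 3, padding="same", activation="tanh")(gen_in))
disc_in = layers.Input(shape=(256, 256, 2))
disc = Model(disc_in, layers.Dense(1, activation="sigmoid")(layers.Flatten()(disc_in)))

# Compile the discriminator on its own first.
disc.compile(loss="binary_crossentropy", optimizer=optimizers.Adam(2e-4))

# Freeze the discriminator inside the combined model so only the
# generator updates when training on the discriminator's output.
disc.trainable = False
combined = Model(gen.input, disc(gen.output))
combined.compile(loss="binary_crossentropy", optimizer=optimizers.Adam(4e-4))
```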
Train, Monitor, and Tune Parameters:
After completing the processes explained above, I was ready to train the model, or so I believed. As I ran the model on my laptop, I received the message “Kernel Died”. The memory and computational requirements were way too intensive for my laptop. It can also take multiple days to train a GAN on a CPU, and that showed me two things: I needed a strong computer/instance and a GPU. For this I went to AWS SageMaker and rented a GPU instance which provided me with more than enough power.
The last tweak I made was limiting my epoch size to only 320 images; instead of training for a few hundred epochs, I trained for a few thousand. I trained the discriminator on a half epoch of real images followed by a half epoch of fakes. Finally, the generator produces a full epoch’s worth of images, which are passed to the discriminator, and the generator is penalized if the images are marked as fake. GANs are hard to train because, unlike most other deep learning algorithms, they don’t simply minimize or maximize a loss function. Instead, I look for an equilibrium, or saddle point, where the two networks stay competitive and continue to build off each other. As you can see from the graph of the losses below, it is a very sporadic training cycle.
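One training cycle as described above can be sketched like this; tiny stand-in models and random data keep the snippet runnable (real runs use 256×256 images and an epoch size of 320):

```python
import numpy as np
from tensorflow.keras import layers, Model

SIZE, EPOCH = 32, 8  # stand-in sizes; the real run uses 256 and 320
gen_in = layers.Input(shape=(SIZE, SIZE, 1))
gen = Model(gen_in, layers.Conv2D(2, 3, padding="same", activation="tanh")(gen_in))
disc_in = layers.Input(shape=(SIZE, SIZE, 2))
disc = Model(disc_in, layers.Dense(1, activation="sigmoid")(layers.Flatten()(disc_in)))
disc.compile(loss="binary_crossentropy", optimizer="adam")
disc.trainable = False
combined = Model(gen.input, disc(gen.output))
combined.compile(loss="binary_crossentropy", optimizer="adam")

L_chan = np.random.uniform(-1, 1, (EPOCH, SIZE, SIZE, 1))
real_AB = np.random.uniform(-1, 1, (EPOCH, SIZE, SIZE, 2))
half = EPOCH // 2

# Half an epoch of real color channels, then half an epoch of fakes.
d_real = disc.train_on_batch(real_AB[:half], np.ones((half, 1)))
fakes = gen.predict(L_chan[:half], verbose=0)
d_fake = disc.train_on_batch(fakes, np.zeros((half, 1)))
# The generator produces a full epoch and is penalized when caught:
# it trains toward the "real" label through the frozen discriminator.
g_loss = combined.train_on_batch(L_chan, np.ones((EPOCH, 1)))
```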
It can take many attempts to get the model stable, but once you do, it’s important to be patient, as it can take anywhere from 6–12 hours on a GPU instance to train the model. The two things that helped me keep control of the long training times and vague metrics were printing images frequently and saving the model about half as often. To view a photo, I simply scale the pixel values back to their original range and convert the image to RGB.
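Converting a generated image back for viewing might look like this (assuming the -1 to 1 scaling from preprocessing):

```python
import numpy as np
from skimage.color import lab2rgb

def to_rgb(L_norm, AB_norm):
    # Undo the [-1, 1] scaling and convert back to RGB for viewing.
    lab = np.zeros(L_norm.shape[:2] + (3,))
    lab[:, :, 0] = (L_norm[:, :, 0] + 1.0) * 50.0  # L back to 0-100
    lab[:, :, 1:] = AB_norm * 128.0                # A/B back to ~-128-127
    return lab2rgb(lab)

# Mid-gray example: L = 0 (i.e. 50 after rescaling), no color.
rgb = to_rgb(np.zeros((64, 64, 1)), np.zeros((64, 64, 2)))
```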
As you can see from the images above, the model sometimes has a bias toward certain color hues; the last image in particular leans toward greens and reds and shows some color bleeding.
This project is most definitely still a work in progress, and my next step is to build out a self-attention layer to help the model better understand the geometry in photos. Once I get the new model built, my goal is to train it on the ImageNet dataset so it can generalize to more photos, and to create a web application that allows anyone to instantly color their pictures.