Source: Deep Learning on Medium
Sketch2Color anime translation using Generative Adversarial Networks (GANs)
A step-by-step guide for building a GAN that colors an anime sketch.
“Generative Adversarial Networks is the most interesting idea in the last 10 years in Machine Learning” — Yann LeCun.
Problem Statement:
The task is to generate a compatible colored anime image from a given black-and-white anime sketch with the help of Generative Adversarial Networks (GANs).
The outline of this article is as follows:
- Getting and preprocessing the data,
- Generator architecture,
- Discriminator architecture,
- Generator and discriminator losses,
- Training the generator and discriminator,
- TensorBoard logs,
- In-progress training results,
- The output of the sample sketches and
So, let’s keep going…
Generative Adversarial Networks (GANs) are an approach to generative modeling using deep learning methods, such as convolutional neural networks (CNNs).
Generative modeling is an unsupervised machine learning task that involves automatically discovering and learning the patterns in input data in such a way that the model can be used to generate new examples that plausibly could have been drawn from the original dataset.
The GAN model architecture involves two sub-models:
1. A generator model for generating new examples,
2. A discriminator model for classifying whether examples are real (from the domain) or fake (generated by the generator model).
GANs are based on a zero-sum, non-cooperative (minimax) game: if one player wins, the other loses. In game-theoretic terms, the GAN model converges when the discriminator and the generator reach a Nash equilibrium, which is the optimal point of the minimax objective.
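For reference, the value function from the original GAN paper, in which D(x) is the discriminator's estimate of the probability that x is real and G(z) is the sample the generator produces from noise z, is:

```latex
\min_{G} \max_{D} V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_{z}(z)}\left[\log\left(1 - D(G(z))\right)\right]
```

The discriminator pushes V up by scoring real samples near 1 and generated samples near 0, while the generator pushes it down by making D(G(z)) approach 1.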
GANs are an exciting and rapidly changing field, delivering on the promise of generative models in their ability to generate realistic examples across a range of problem domains.
They are most notable in image-to-image translation tasks, such as translating photos of summer to winter or day to night, and in generating photorealistic photos of objects, scenes, and people that even humans cannot tell are fake.
Today we're going to use GANs for an image-to-image translation task: automatically generating compatible colors for a given black-and-white anime sketch (not even a grayscale image).
If you would like to dig deeper into the math, check out the original paper by Ian J. Goodfellow et al.; I really enjoyed reading it.
2. Getting and preprocessing the data:
The anime sketch colorization dataset that I used to train the GAN can be downloaded from Kaggle, here.
After downloading and unzipping the dataset, I had to preprocess it, because each image contained both the sketch and the colored anime.
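A minimal NumPy sketch of that split, assuming each combined image holds the colored anime in its left half and the sketch in its right half (verify this against your copy of the dataset; `split_pair` is a hypothetical helper name, not from the original code):

```python
import numpy as np
from typing import Tuple


def split_pair(pair: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Split a combined H x W x C image into (sketch, color) halves.

    Assumption: colored image on the left, sketch on the right.
    Swap the two slices if your files are laid out the other way round.
    """
    width = pair.shape[1]
    if width % 2 != 0:
        raise ValueError(f"expected an even width, got {width}")
    half = width // 2
    color = pair[:, :half]
    sketch = pair[:, half:]
    return sketch, color


# Example with a dummy 512 x 1024 RGB pair: black left half, white right half.
pair = np.zeros((512, 1024, 3), dtype=np.uint8)
pair[:, 512:] = 255
sketch, color = split_pair(pair)
print(sketch.shape, color.shape)  # (512, 512, 3) (512, 512, 3)
```

The two halves can then be written to separate sketch and color folders, one pair per file name, so that the loader can match them up again.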
Once the sketch and colored images are saved to separate folders for both the training and validation/test data, we normalize them while loading, so that pixel values in the range [0, 255] are mapped to the range [-1, 1].
The reason is that, according to the well-known GAN hacks, normalizing input images to [-1, 1] and using "tanh" as the generator's output-layer activation yields much better results.
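The mapping is simply x / 127.5 - 1, and its inverse turns the generator's tanh outputs back into displayable pixels. A minimal NumPy sketch (the function names are illustrative, not from the original code):

```python
import numpy as np


def normalize(image: np.ndarray) -> np.ndarray:
    """Map uint8 pixel values in [0, 255] to float32 values in [-1, 1]."""
    return image.astype(np.float32) / 127.5 - 1.0


def denormalize(image: np.ndarray) -> np.ndarray:
    """Map tanh-range values in [-1, 1] back to uint8 pixels in [0, 255]."""
    scaled = (image + 1.0) * 127.5
    return np.clip(np.round(scaled), 0, 255).astype(np.uint8)


pixels = np.array([0, 127, 255], dtype=np.uint8)
scaled = normalize(pixels)
restored = denormalize(scaled)
print(scaled.min(), scaled.max())  # -1.0 1.0
```

Clipping in `denormalize` guards against tiny numerical overshoots before casting back to 8-bit pixels.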
3. Generator architecture:
The Sketch2Color anime GAN is a supervised learning model: given a black-and-white sketch, it generates a colored image based on the sketch-color image pairs it saw in the training data.
The architecture of the generator that is used for the sketch to color anime translation is a kind of “U-Net”.
Instead of using fully connected layers in the encoder-decoder units, as many previous solutions do, we use convolutions to downsample the input and transposed convolutions to upsample it back to the output size, which avoids the information loss that occurs when features pass through fully connected layers.
In the Sketch2Color anime problem especially, the edges are the most important information to keep from the input to ensure the quality of the output image.
Hence, a "U-Net" kind of architecture is employed, concatenating layers in the encoder to the corresponding layers of the decoder.
In the architecture diagram, yellow blocks represent layers in the encoder and blue blocks layers in the decoder. In each decoding step, the corresponding encoder layer is concatenated to the current layer to decode the next one.
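To make the skip connections concrete, the sketch below traces feature-map shapes through a U-Net-style generator. The 256x256 input, the depth, and the channel counts (borrowed from the pix2pix generator) are assumptions rather than the exact configuration used here; each encoder step halves the spatial size with a stride-2 convolution, and each decoder step doubles it with a stride-2 transposed convolution before concatenating the mirrored encoder output.

```python
from typing import List, Sequence, Tuple

Shape = Tuple[int, int, int]


def trace_unet_shapes(
    size: int = 256,
    in_channels: int = 3,
    out_channels: int = 3,
    enc_channels: Sequence[int] = (64, 128, 256, 512, 512, 512, 512, 512),
) -> List[Tuple[str, Shape]]:
    """List (layer name, (H, W, C)) for a U-Net with skip concatenation.

    Hypothetical configuration: every encoder block is a stride-2 conv
    ('same' padding) and every decoder block a stride-2 transposed conv.
    """
    shapes: List[Tuple[str, Shape]] = [("input", (size, size, in_channels))]
    skips: List[Tuple[int, int]] = []
    h = size
    for i, c in enumerate(enc_channels):
        h //= 2  # stride-2 convolution halves height and width
        skips.append((h, c))
        shapes.append((f"enc{i + 1}", (h, h, c)))
    # The decoder mirrors the encoder; the bottleneck itself is not concatenated.
    dec_channels = list(reversed(enc_channels[:-1]))
    for i, c in enumerate(dec_channels):
        h *= 2  # stride-2 transposed convolution doubles height and width
        skip_h, skip_c = skips[-(i + 2)]
        assert skip_h == h, "skip connection must match the spatial size"
        # Concatenating along channels adds the encoder's channels to the decoder's.
        shapes.append((f"dec{i + 1}+skip", (h, h, c + skip_c)))
    h *= 2  # final transposed convolution back to the input resolution
    shapes.append(("output (tanh)", (h, h, out_channels)))
    return shapes


for name, shape in trace_unet_shapes():
    print(f"{name:>14}: {shape}")
```

The doubled channel counts after each concatenation are exactly where the encoder's edge information re-enters the decoder.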
4. Discriminator architecture:
Compared to the generator, the discriminator has only encoder units. It aims to classify whether the input sketch-color image pair is "real" or "fake", i.e. whether the colored image comes from the actual data or from the generator.
The input of discriminator is either the pair of sketch (yellow) and real target image (red), or the pair of sketch (yellow) and generated image (blue).
The discriminator network is trained to maximize classification accuracy.
The discriminator output is a matrix of probabilities of shape 30x30x1, in which each element gives the probability of being real for a pair of corresponding patches from the input sketch and colored anime image.
Here too we avoid fully connected layers at the end, to prevent information loss deep in the network, and instead use global average pooling to get a single value.
The convolutional layers between the input and the output extract the high-level features of the input pairs to output the probability.
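The 30x30 output size can be checked with a little arithmetic. The sketch below assumes a pix2pix-style PatchGAN on 256x256 inputs (an assumption, since the layer list is not given here): three stride-2 4x4 convolutions with 'same' padding, then two stride-1 4x4 'valid' convolutions, each preceded by 1 pixel of zero padding. It also computes the receptive field, i.e. the size of the input patch that each output element judges.

```python
import math

# (kernel, stride, padding mode, explicit zero padding per side); hypothetical layer list
LAYERS = [
    (4, 2, "same", 0),
    (4, 2, "same", 0),
    (4, 2, "same", 0),
    (4, 1, "valid", 1),
    (4, 1, "valid", 1),
]


def output_size(size: int, layers=LAYERS) -> int:
    """Spatial size of the discriminator output for a size x size input."""
    for kernel, stride, mode, pad in layers:
        size += 2 * pad  # explicit zero padding on both sides
        if mode == "same":
            size = math.ceil(size / stride)
        else:  # 'valid': no implicit padding
            size = (size - kernel) // stride + 1
    return size


def receptive_field(layers=LAYERS) -> int:
    """Input patch size (in pixels) seen by one output element."""
    field = 1
    for kernel, stride, _, _ in reversed(layers):
        field = field * stride + (kernel - stride)
    return field


print(output_size(256))   # 30
print(receptive_field())  # 70
```

Under these assumptions, each of the 30x30 probabilities judges an overlapping 70x70 patch of the sketch-color pair.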
5. Generator and discriminator loss:
Generating colored anime from a black-and-white sketch is much harder than colorizing a grayscale image, because a simple line sketch carries far less useful information than a grayscale image.
Because of this, we may have to impose more constraints to get better results.
Since our Sketch2Color GAN is trained in a supervised way, we use a conditional GAN for this purpose.
Let's start with the loss function of a general conditional GAN.
Conditional GANs learn a mapping from a random noise vector "z" to an output image "y", conditioned on an observed input "x".
In our case, the GAN is conditioned on a black-and-white sketch “x” for generating a colored anime.
While the generator tries to minimize the loss, simultaneously the discriminator tries to maximize it and eventually they reach an equilibrium.
The optimal generator is the one that minimizes this loss during training, against a discriminator that maximizes it, so that it produces plausible colored anime images.
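Written out with these symbols, the standard conditional GAN objective (as formulated, for example, in the pix2pix paper by Isola et al.) and the resulting optimal generator are:

```latex
\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}\left[\log D(x, y)\right] + \mathbb{E}_{x,z}\left[\log\left(1 - D\left(x, G(x, z)\right)\right)\right]

G^{*} = \arg\min_{G} \max_{D} \mathcal{L}_{cGAN}(G, D)
```

Compared with the unconditional value function, both D and G now also see the sketch x, so the discriminator judges sketch-color pairs rather than color images alone.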
Training the discriminator simultaneously encourages more variation in the generated colored images. However, to produce realistic colored images, we mix the GAN loss with some more traditional loss functions.
The first is the PixelLevel loss, i.e. the L1 distance between each pixel of the target color image and the corresponding pixel of the generated color image.
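A minimal NumPy version of this loss for one image pair might look as follows; it averages instead of summing over pixels, which only rescales the loss and can be absorbed into its weight. `pixel_level_loss` is an illustrative name.

```python
import numpy as np


def pixel_level_loss(target: np.ndarray, generated: np.ndarray) -> float:
    """Mean absolute (L1) difference between target and generated pixels."""
    return float(np.mean(np.abs(target - generated)))


target = np.zeros((4, 4, 3), dtype=np.float32)
generated = np.full((4, 4, 3), 0.5, dtype=np.float32)
print(pixel_level_loss(target, generated))  # 0.5
```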
The second is the FeatureLevel loss, i.e. the L2 distance between the activations (φj) of the 4th layer of a 16-layer VGG network (VGG16) pre-trained on ImageNet, which helps retain high-level features such as the specific colors associated with objects and shapes.
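The feature-level loss compares activations rather than raw pixels. The sketch below keeps the feature extractor `phi` abstract so it runs without downloading weights. In practice `phi` would be a truncated VGG16 (for example built from `tf.keras.applications.VGG16(include_top=False, weights="imagenet")`), and exactly which layer counts as "the 4th layer" is an assumption you would need to pin down. The squared, averaged form of the L2 distance is a common choice assumed here, and `toy_phi` is purely hypothetical.

```python
import numpy as np
from typing import Callable


def feature_level_loss(
    target: np.ndarray,
    generated: np.ndarray,
    phi: Callable[[np.ndarray], np.ndarray],
) -> float:
    """Mean squared L2 distance between feature activations phi(target) and phi(generated)."""
    diff = phi(target) - phi(generated)
    return float(np.mean(diff ** 2))


def toy_phi(image: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for a VGG16 layer: 2x2 average pooling per channel."""
    h, w, c = image.shape
    return image.reshape(h // 2, 2, w // 2, 2, c).mean(axis=(1, 3))


target = np.zeros((4, 4, 3))
generated = np.ones((4, 4, 3))
print(toy_phi(target).shape)                           # (2, 2, 3)
print(feature_level_loss(target, generated, toy_phi))  # 1.0
```

Because the loss only depends on `phi`'s outputs, swapping in a real VGG16 layer changes nothing else in the function.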
The last is the TotalVariation loss, which helps the GAN produce colors similar to those used in the sketch-color image pairs of the training data.
It encourages spatial smoothness (acting as a form of regularization) and suppresses noise in the output.
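One common formulation is the anisotropic total variation, sketched below for a single H x W x C image. Whether absolute or squared differences, and sums or means, were used here is not stated, so treat this form as an assumption. For comparison, TensorFlow's `tf.image.total_variation` sums (rather than averages) the absolute neighbour differences.

```python
import numpy as np


def total_variation_loss(image: np.ndarray) -> float:
    """Anisotropic total variation of an H x W x C image (mean absolute neighbour difference)."""
    dy = np.abs(image[1:, :, :] - image[:-1, :, :])  # vertical neighbours
    dx = np.abs(image[:, 1:, :] - image[:, :-1, :])  # horizontal neighbours
    return float(dy.mean() + dx.mean())


flat = np.full((4, 4, 3), 0.3)
checker = (np.indices((4, 4)).sum(axis=0) % 2).astype(np.float64)[..., np.newaxis]
print(total_variation_loss(flat))     # 0.0
print(total_variation_loss(checker))  # 2.0
```

A flat color region costs nothing, while a pixel-level checkerboard (the kind of high-frequency noise a GAN can produce) is penalized as heavily as possible.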
Finally, the overall loss function is a weighted combination of all the losses above.
By minimizing this combined loss, the GAN learns better correspondences between sketches and their colored images.
The weights Wp, Wf, Wg, and Wtv control the relative importance of each loss and are tuned accordingly.
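Written out, with L_p, L_f and L_tv as shorthand (introduced here for illustration) for the PixelLevel, FeatureLevel and TotalVariation losses and L_cGAN for the conditional GAN loss, the weighted combination reads:

```latex
\mathcal{L}_{total} = W_{p}\,\mathcal{L}_{p} + W_{f}\,\mathcal{L}_{f} + W_{g}\,\mathcal{L}_{cGAN}(G, D) + W_{tv}\,\mathcal{L}_{tv}
```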
6. Training the generator and discriminator:
The GAN was trained for 43 epochs with a batch size of 8; the number of epochs was chosen based on the in-progress generator results for the seed sketches after every epoch.
Label smoothing was used during training, i.e. soft labels instead of hard labels (0.9 instead of 1).
The discriminator was trained less often, e.g. only on even-numbered batches, because if it learns too much about the real and generated color images, it starts to dominate the generator.
The generator then becomes weak and never learns to capture the distribution of the real target images. Finding a good schedule for the number of discriminator and generator iterations is therefore hard.
The Adam optimizer was used with a learning rate of 0.0002 and beta_1 = 0.5.
The discriminator and generator were trained alternately in a loop: first the discriminator, then the generator, then the discriminator again, and so on.
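The schedule described above can be sketched as a plain training loop. The step functions are hypothetical placeholders: in Keras they might wrap a custom `tf.GradientTape` update using `tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)`. The sketch only shows the alternation, the even-batch discriminator updates, the 0.9 soft labels for real pairs, and the per-epoch mean losses that were logged; it is not the original training code.

```python
from typing import Callable, Dict, List, Sequence, Tuple

REAL_LABEL = 0.9  # label smoothing: soft target for real pairs instead of 1.0
FAKE_LABEL = 0.0  # target for generated pairs

Batch = Tuple[object, object]  # (sketches, color_images); kept abstract here


def train(
    batches: Sequence[Batch],
    epochs: int,
    d_step: Callable[[Batch, float, float], float],
    g_step: Callable[[Batch], float],
) -> List[Dict[str, float]]:
    """Alternate D and G updates and return the mean losses of every epoch.

    d_step(batch, real_label, fake_label) and g_step(batch) are hypothetical
    callables that apply one optimizer update and return that update's loss.
    """
    history = []
    for _ in range(epochs):
        d_losses, g_losses = [], []
        for i, batch in enumerate(batches):
            # Update D only on even-numbered batches so it does not overpower G.
            if i % 2 == 0:
                d_losses.append(d_step(batch, REAL_LABEL, FAKE_LABEL))
            g_losses.append(g_step(batch))
        history.append({
            "d_loss": sum(d_losses) / len(d_losses) if d_losses else float("nan"),
            "g_loss": sum(g_losses) / len(g_losses) if g_losses else float("nan"),
        })
    return history


# Usage with dummy step functions that just count calls.
calls = {"d": 0, "g": 0}


def dummy_d_step(batch: Batch, real_label: float, fake_label: float) -> float:
    calls["d"] += 1
    return 0.3


def dummy_g_step(batch: Batch) -> float:
    calls["g"] += 1
    return 1.0


history = train([(None, None)] * 10, epochs=2, d_step=dummy_d_step, g_step=dummy_g_step)
print(calls)  # {'d': 10, 'g': 20}
```

Each entry of `history` holds the per-epoch means that would be written to TensorBoard.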
7. Tensorboard logs:
After training on each batch, the discriminator loss was computed on the pair of real target image and sketch and on the pair of generated color image and sketch, and the generator loss on the black-and-white sketches.
At the end of each epoch, the mean of these losses was computed for the corresponding model.
The discriminator and generator mean losses were logged after every epoch using TensorBoard callbacks to monitor their progress.
As expected, the discriminator loss fluctuates between about 0.23 and 0.35, while the generator loss decreases steadily, a good sign that the generator is capturing the distribution of the real target images.
A ModelCheckpoint of the generator was saved after every epoch, so that the best model so far can be used to predict colors once training is complete.
8. In-progress training results:
After every epoch, the generator was made to predict the colors for some of the fixed/seed sketches to check if it’s learning the correlations between sketch and its associated colored image.
Some of these results at different epochs are as shown below,