Black and White Image Colorization with Deep Learning
This blog post summarizes the results of my first project using deep learning. I chose to work on colorizing black and white pictures. I started by reproducing two models from the Medium article by Emil Wallner. Then, I implemented the Pix2Pix model. My goal was to learn how to handle large image datasets and how to implement, train, and test deep learning models using both Keras and PyTorch. Along the way, I learned how to debug the models and tune the hyper-parameters. You will also find an example of how the choice of dataset impacts the colorization task.
All the code is available here.
As a researcher, my main focus is solar energy conversion technologies using molecular semiconductors. You have already seen some molecular semiconductors without realizing it! Check out your mobile phone! The screen is probably an OLED (organic light-emitting diode) display, which is built from molecular semiconductors!
You may wonder what the link is between my research and deep learning. Molecular semiconductors hold the promise of tailoring their properties to applications by designing their chemical structure. What better than using simulations to do this! Unfortunately, simulating large assemblies of molecules at the atomistic level is too computationally demanding, even for state-of-the-art high-performance computers! Also, we’d like to guess what would be a good molecule for a specific application! Could deep learning solve those problems? I believe so.
Last August, I took the plunge and enrolled in a bootcamp about machine learning. During the last 6 months, I worked on a capstone project to colorize black and white pictures using deep learning. This blog post summarizes the results of my project. Below, you can see the original image on the left and, on the right, the image colorized by one of the models I present in this post.
Colorizing Black & White Pictures
While turning a color image into black and white is almost trivial, the inverse problem is not. Let’s take a simple drawing: the grey level undoubtedly tells us how light a color is, but how do we pick the color itself? We all know that grass is green and the sky is blue, but we would still have to pick the shade we feel is most appropriate. What about the color of a jumper? It only needs to be a realistic color, does it not? One can go further and use contextual information; for example, the historical period can tell us which colors were most likely for a garment. Some colors were more expensive, and others did not exist yet. If you want to know more, I recommend the books of Victoria Finlay. Now, if we take a photograph, the exposure to light affects the picture by changing the saturation of the colors. A grey shade also gives a feeling of how bright part of the picture is, and saturation changes on a pixel basis. How can a model learn the feeling of which color is right and how bright a picture is?
This blog post is organised as follows:
- How to handle colors?
- Neural Networks
- How to improve the training?
- Final thoughts.
All the code is available here. I used the kora library to organize the Jupyter notebooks as modules.
How to handle colors?
Color images are often represented with 3 channels (RGB) due to the simplicity of the system. The color is reconstituted by simply adding the 3 channels, each of which takes values between 0 and 255. However, for compressing images, another kind of system is preferred where luminance and chroma are separated. The human eye is more sensitive to variations in brightness than to variations in color. Thus, the chroma channels can be sampled at a lower rate without perceived loss. In the Lab color space, the lightness (L) is nothing else than the black and white image. Why not use this trick for colorizing? It reduces the problem from Net(L) = [R,G,B] to Net(L) = [a,b]. Instead of having to learn to reconstruct the 3 channels R, G, B, the neural network only needs to reconstruct the a and b channels. The a and b channels encode the four unique colors of human vision as two opponent axes: a runs from green to red and b from blue to yellow. The L channel ranges from 0 to 100 and the a and b channels roughly from -128 to 128.
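To make the Lab decomposition concrete, here is a minimal sketch of the conversion for a single pixel, using only the Python standard library. It assumes sRGB input and the D65 reference white. The helper names and the final rescaling to roughly [-1, 1] are my own illustrative assumptions, not necessarily what the project code does; in practice a library routine would convert whole images at once.

```python
import math

# D65 reference white, the usual choice for sRGB images (assumption).
WHITE = (0.95047, 1.0, 1.08883)

def srgb_to_linear(c):
    """Undo the sRGB gamma curve for one channel value in [0, 1]."""
    return c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4

def _f(t):
    """Piecewise cube root used in the CIE Lab definition."""
    delta = 6 / 29
    return t ** (1 / 3) if t > delta ** 3 else t / (3 * delta ** 2) + 4 / 29

def rgb_to_lab(r, g, b):
    """Convert one 8-bit sRGB pixel to (L, a, b)."""
    r, g, b = (srgb_to_linear(v / 255) for v in (r, g, b))
    # Linear RGB -> CIE XYZ
    x = 0.4124564 * r + 0.3575761 * g + 0.1804375 * b
    y = 0.2126729 * r + 0.7151522 * g + 0.0721750 * b
    z = 0.0193339 * r + 0.1191920 * g + 0.9503041 * b
    fx, fy, fz = (_f(v / w) for v, w in zip((x, y, z), WHITE))
    return 116 * fy - 16, 500 * (fx - fy), 200 * (fy - fz)

def split_for_training(r, g, b):
    """Hypothetical helper: network input (L) and target (a, b), rescaled to about [-1, 1]."""
    lightness, a, b_ = rgb_to_lab(r, g, b)
    return lightness / 50 - 1, (a / 128, b_ / 128)
```

A grey pixel ends up with a and b close to zero, which is why the L channel alone is the black and white image and the network only has to predict the two remaining channels.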
Neural Networks
I started by trying to reproduce the models posted here using both Keras and PyTorch. The beta model in this post is a simple auto-encoder. In the gamma version, a classifier is used in parallel to the encoder. The output of the classifier is fused with the output of the encoder and passed to the decoder. In the original blog post, Inception ResNet V2 was used as the classifier. For the PyTorch implementation, I used MobileNetV2 as it is lighter.
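The fusion step of the gamma model can be sketched as follows in PyTorch. This is only an illustration of the idea under my own assumptions: the classifier output is treated as a flat vector, copied to every spatial position of the encoder feature map, concatenated along the channel axis, and mixed back with a 1x1 convolution. The class name `FusionLayer` and the channel sizes are hypothetical.

```python
import torch
import torch.nn as nn

class FusionLayer(nn.Module):
    """Fuse a global classifier vector with a spatial encoder feature map."""

    def __init__(self, feat_channels=256, embed_dim=1000):
        super().__init__()
        # A 1x1 convolution mixes the concatenated channels back to feat_channels.
        self.mix = nn.Conv2d(feat_channels + embed_dim, feat_channels, kernel_size=1)

    def forward(self, features, embedding):
        # features: (N, C, H, W) from the encoder, embedding: (N, E) from the classifier
        n, _, h, w = features.shape
        # Copy the embedding to every spatial location: (N, E) -> (N, E, H, W)
        tiled = embedding[:, :, None, None].expand(n, embedding.shape[1], h, w)
        fused = torch.cat([features, tiled], dim=1)
        return torch.relu(self.mix(fused))
```

The output keeps the spatial size and channel count of the encoder features, so the decoder can stay unchanged whether or not the classifier branch is used.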
Built on TensorFlow, Keras provides a high level of abstraction and seems really easy to start with. PyTorch, in comparison, is a lower-level environment but still very user-friendly in my opinion. It is well-documented and you can find lots of resources online. The PyTorch tutorials are a really great place to start. From my point of view, learning with PyTorch forces you to understand the concepts of deep learning a bit more deeply. Finally, in my research field, graph neural networks (GNNs) are used, and there is a great library built on PyTorch, PyTorch Geometric, to handle GNNs more easily.
Then, I tried to implement the Pix2Pix architecture from scratch using PyTorch. In the original Pix2Pix paper, the model generates an RGB picture from the black and white picture. I implemented both the RGB and the Lab versions. I checked my implementation against the tutorial that can be found here. Pix2Pix is a type of conditional generative adversarial network (cGAN) that uses a U-Net as the generative network and a patch discriminator. U-Nets are auto-encoders with skip connections. The generator is trained via both an adversarial loss and an L1 loss measured between the generated image and the target image, in a similar way as for an auto-encoder such as the previous models. The adversarial loss encourages the generation of plausible images, and the L1 loss keeps the output close to the target so that it is a plausible translation of the input image.
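As a rough sketch of how the two loss terms combine, the snippet below writes the generator and discriminator objectives in PyTorch. It assumes the discriminator returns raw logits for each patch. The weight of 100 on the L1 term is the value suggested in the Pix2Pix paper; whether my own runs used exactly this value is not stated here, so treat it and the function names as assumptions.

```python
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()  # adversarial term, applied to raw patch logits
l1 = nn.L1Loss()              # reconstruction term

def generator_loss(disc_logits_fake, fake, target, lambda_l1=100.0):
    """Adversarial loss plus weighted L1 loss for the generator."""
    # The generator wants every patch of its output to be labelled real (1).
    adversarial = bce(disc_logits_fake, torch.ones_like(disc_logits_fake))
    reconstruction = l1(fake, target)
    return adversarial + lambda_l1 * reconstruction

def discriminator_loss(disc_logits_real, disc_logits_fake):
    """Real patches should be labelled 1 and generated patches 0."""
    real = bce(disc_logits_real, torch.ones_like(disc_logits_real))
    fake = bce(disc_logits_fake, torch.zeros_like(disc_logits_fake))
    return 0.5 * (real + fake)
```

In a training step, the discriminator would be updated with `discriminator_loss` on detached generator outputs, and the generator would then be updated with `generator_loss`.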
With the simplest model, the beta model, the images are dominated by brown. From painting, we know that mixing the three primary colors (red, yellow and blue) makes brown. Therefore, this should not come as a surprise: when the network is unsure, an average color such as brown produces the smallest error. The gamma model should improve the results by matching an object class with a coloring scheme, but after 50 epochs this is difficult to see. The Pix2Pix(Lab) model clearly gives the best results. The Pix2Pix(RGB) model not only has to learn to colorize the image, it also needs to learn to reconstruct the drawing itself, to which we are more sensitive. Some pictures are clearly blurred, a problem that is discussed in the original Pix2Pix paper. However, this model has some merit when we want to colorize old black and white images that have been damaged over time, for example.
How to improve the training?
Beyond modifying the model itself, it is possible to play with the other hyper-parameters of the training. The GAN architecture is not the best one for experimenting with the optimization algorithm or the batch size, as there are no good metrics for evaluating GANs during training; the loss function in itself is not sufficient. Thus, I varied the optimizer and the batch size with the beta model and recorded the loss during training. Based on this experiment, I chose the Adam optimizer and a batch size of 64 images, which also keeps the time per epoch reasonable.
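A comparison of this kind can be sketched as a small grid over optimizers and batch sizes, recording the loss at each step. Everything specific here is an assumption made for illustration: the tiny two-layer network stands in for the beta model, the data are random tensors, the loss is a mean squared error, and the learning rate and candidate values are arbitrary.

```python
import torch
import torch.nn as nn

OPTIMIZERS = {
    "adam": torch.optim.Adam,
    "sgd": torch.optim.SGD,
    "rmsprop": torch.optim.RMSprop,
}

def run_config(optimizer_name, batch_size, steps=5, seed=0):
    """Train a tiny stand-in model and return the loss recorded at each step."""
    torch.manual_seed(seed)  # same initial weights and data for every configuration
    model = nn.Sequential(
        nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
        nn.Conv2d(8, 2, 3, padding=1), nn.Tanh(),
    )
    optimizer = OPTIMIZERS[optimizer_name](model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    losses = []
    for _ in range(steps):
        lightness = torch.rand(batch_size, 1, 16, 16)   # random stand-in for the L channel
        ab = torch.rand(batch_size, 2, 16, 16) * 2 - 1  # random stand-in for the a, b targets
        optimizer.zero_grad()
        loss = criterion(model(lightness), ab)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses

# One loss curve per (optimizer, batch size) pair.
history = {
    (name, batch_size): run_config(name, batch_size)
    for name in OPTIMIZERS
    for batch_size in (16, 64)
}
```

With real data, the curves in `history` can be plotted side by side, together with the time per epoch, to pick a configuration.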
Training GANs can be challenging, as the training process is inherently unstable due to the competing nature of the generator and discriminator training. Every time the parameters of the generator or the discriminator are changed, the optimization problem of the other one changes. Thus, improving one model comes at the expense of the other. For the Pix2Pix model, I explicitly initialized the weights of the neural networks, as it can help prevent the problem of exploding or vanishing gradients. As mentioned above, the loss function for GANs is not sufficient, so I also recorded the outputs on a selected set of images seen during training and on a selected set of images not seen during training, to get a better feeling of what the neural network is learning. Until about 20 epochs, the results are dominated by brown and blue colors. With longer training, the neural network starts learning more diverse colors.
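For reference, a common initialization scheme for Pix2Pix, following the original paper, draws the convolution weights from a normal distribution with mean 0 and standard deviation 0.02. The sketch below applies it in PyTorch; the function name, the handling of batch-norm layers, and the small example block are assumptions for illustration and may differ from my actual code.

```python
import torch.nn as nn

def init_weights(module, std=0.02):
    """Draw conv weights from N(0, std) and batch-norm scales from N(1, std)."""
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.normal_(module.weight, mean=1.0, std=std)
        nn.init.zeros_(module.bias)

# Hypothetical small block standing in for the real generator or discriminator.
net = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2),
)
net.apply(init_weights)  # apply() visits every submodule recursively
```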
Large datasets of pictures are available (MS-COCO, ImageNet, Open Images dataset, CIFAR-10, CelebA). For the black and white problem, labels are not needed, as the images can be preprocessed to generate pairs of the color image and the corresponding black and white picture. At first, I used the dataset provided along with the blog post of the beta/gamma models. However, how is the chosen dataset going to impact the results? How diverse and how big does the dataset have to be to generalise the colorization? I retrained the Pix2Pix model on part of the CelebA dataset, so as to keep the size of the training dataset similar to the one I had used so far.
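The two preprocessing ideas above, building training pairs without labels and keeping only part of a large dataset, can be sketched with the standard library alone. Images are represented here as nested lists of (R, G, B) tuples purely for illustration, the grey level uses the common Rec. 601 luma weights, and the function names are hypothetical.

```python
import random

def to_grayscale(image):
    """Convert rows of (R, G, B) tuples to rows of grey levels (Rec. 601 luma weights)."""
    return [
        [round(0.299 * r + 0.587 * g + 0.114 * b) for r, g, b in row]
        for row in image
    ]

def make_pair(image):
    """Return (input, target): the derived grey image and the original color image."""
    return to_grayscale(image), image

def subsample(paths, n, seed=0):
    """Pick a reproducible random subset of n file paths from a larger dataset."""
    return random.Random(seed).sample(list(paths), n)
```

Fixing the seed makes the subset reproducible, so that models trained on different datasets are compared at the same training set size.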
Looking at what the neural network learns during training, it is clear that the dataset impacts the training. The network trained on CelebA first learns colors that resemble skin tones and then learns other colors like green, blue and yellow. Interestingly, it seems to learn more vivid colors.
As expected, the network trained on the CelebA dataset performs better on portraits and worse on landscapes. However, it seems that with longer training it could give some good results on landscapes too, as it picks up colors from the backgrounds of the images.
Final thoughts
It is definitely possible to reach better results by training the models longer and probably by using a larger dataset (here, I limited the dataset to about 10 000 images). As shown above, the choice of dataset is important and can significantly improve the results on subsets of the testing set.
Different strategies could be used to get a good colorization utility:
- Several Pix2Pix networks could be trained on “specialized” datasets, as this can be done in parallel. A new B&W image would first go through a classifier and then be sent to one of the networks according to the most likely label. For images where the classifier does not perform well, we could also use the discriminator prediction to decide which colorized image to return.
- A multi-scale discriminator could be introduced. A patch discriminator, such as the Pix2Pix discriminator, works on the assumption that pixels separated by more than one patch size are independent and thus fails to capture the global structure of the images. Using different subnets with different depths could improve the learning.
- Alternatively, self-attention mechanisms could be introduced either in the discriminator or in both the discriminator and the generator, as in the self-attention generative adversarial network.
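The first strategy in the list above could be wired together roughly as follows. This is an untested design sketch, not something I have built: `classifier`, `generators` and `discriminator` are hypothetical callables, and the confidence threshold is an arbitrary assumption.

```python
def colorize(image, classifier, generators, discriminator, min_confidence=0.6):
    """Route a B&W image to a specialized generator.

    classifier(image)        -> dict mapping label to probability
    generators               -> dict mapping label to a callable(image) -> colorized image
    discriminator(image, out) -> realism score, higher means more plausible
    """
    probs = classifier(image)
    label = max(probs, key=probs.get)
    if probs[label] >= min_confidence and label in generators:
        # The classifier is confident: use the matching specialized network.
        return generators[label](image)
    # The classifier is unsure: run every generator and keep the output
    # that the discriminator rates as most plausible.
    candidates = [generate(image) for generate in generators.values()]
    return max(candidates, key=lambda out: discriminator(image, out))
```

The fallback branch costs one forward pass per specialized network, so it only makes sense when the number of networks stays small.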
Big thanks to Springboard and Dat Tran.