Generative Adversarial Networks in Tensorflow 2.0

At SpringML we are always keeping up with the latest and greatest technologies in Machine Learning and Artificial Intelligence.

In this post, I am going to discuss:

  1. Generative Adversarial Networks (GANs)
  2. A working version of the code in Tensorflow 2.0

With the release of Tensorflow 2.0 at the Tensorflow Dev Summit, there were lots of updates and takeaways.  With so many new additions and functionalities, it was hard to narrow down something to try.  I decided to try Pix2Pix, image-to-image translation with Generative Adversarial Networks (GANs).  This was appealing because the technique can take an image from one representation into another, much like translating a letter from English to Spanish.  Done by hand, this is typically a very manual and tedious process that needs expert knowledge, and whatever steps are taken tend to be specific to that one use case.  By using the power of Artificial Intelligence, you can instead train a general-purpose model to infer what the picture could look like in another medium.  Pix2Pix has been implemented in several personal Github repos, such as here, but this is the official implementation in Tensorflow 2.0 code.

How GANs Work

The GAN algorithm pairs two models: a generator that constantly creates new images based on the training set, and a discriminator that tries to distinguish whether an image is the real thing or was generated by the model.  The two models are trained simultaneously, playing a constant cat-and-mouse game in which the generator keeps trying to produce more convincing images and the discriminator keeps trying to better distinguish them.  Eventually the generator does so well that the discriminator can no longer tell the difference.  This constant discrimination is what improves the model over time.

This particular flavor of GAN is conditional: an additional conditioning input is fed into the generator alongside the noise vector.  Conditioning also makes it possible to focus on generating outcomes for specific classes in a multimodal situation, rather than generating classes at random in an unconditional way.
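To make the cat-and-mouse dynamic concrete, here is a minimal sketch of how the two competing losses might be written in Tensorflow 2.0 with binary cross-entropy.  The function names and the l1_weight parameter are illustrative choices rather than the tutorial's exact code, though the Pix2Pix paper does add an L1 term to the generator loss in this way.

    import tensorflow as tf

    # Binary cross-entropy applied to the discriminator's raw (logit) outputs.
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    def discriminator_loss(disc_real_output, disc_generated_output):
        # Real (input, target) pairs should be scored as 1, generated pairs as 0.
        real_loss = bce(tf.ones_like(disc_real_output), disc_real_output)
        generated_loss = bce(tf.zeros_like(disc_generated_output), disc_generated_output)
        return real_loss + generated_loss

    def generator_loss(disc_generated_output, gen_output, target, l1_weight=100):
        # The generator is rewarded for fooling the discriminator (label 1)...
        gan_loss = bce(tf.ones_like(disc_generated_output), disc_generated_output)
        # ...and for staying close to the ground-truth image (the L1 term from the paper).
        l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
        return gan_loss + l1_weight * l1_loss

As the discriminator loss shrinks, the generator loss grows, and vice versa, which is exactly the tug of war described above.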

Use Cases for GANs

Use Case 1
Colorizing historical photos, modernizing them instead of hand-coloring them in Photoshop.

 

Use Case 2
Translating satellite maps into Google Maps images.

Use Case 3
Standardizing images from different conditions such as day to night.

Use Case 4
There are fun things as well, such as translating a horse into a zebra.

Here are some of the key things that I noticed in the Pix2Pix example that highlight Tensorflow 2.0:

  1. tf.function
    With this new release, several older functions are deprecated in favor of eager execution. With eager execution, you can execute Tensorflow code immediately instead of being bogged down with setting up your layers and graphs as in Tensorflow 1.X. Based on the documentation examples, things such as tf.Session, tf.placeholder, and tf.global_variables_initializer are going away. You can learn more about it here. The @tf.function decorator sits just above a regular Python function and compiles it into a graph, as used in the Pix2Pix example and in the sketch after this list. That one decorator is all you need!
  2. tf.keras
    Tensorflow is further adopting Keras as its high-level API, so it's easier to start modeling.  For example, tf.layers is being folded into tf.keras as tf.keras.layers, and tf.keras.Sequential lets you stack your sequence of layers for your next model, as seen in the Pix2Pix example and in the sketch after this list.
  3. tf.GradientTape
    A “tape” records the operations executed during the forward pass so that gradients can be computed later by “rewinding the tape”.  This saves the time of re-implementing those steps yourself, since they are recorded automatically; the training step in the sketch after this list shows it in action.
  4. tf.data
    This API handles a lot of the ETL process in Tensorflow with its pre-built calls.  Normally, a great deal of code would be needed to pipe in data, convert it into a flat matrix, one-hot encode the labels, reshape the data, transform it into a Tensor, and batch it, just so you could execute it in a tensor graph.  Now the pipeline is much more straightforward, as shown in the sketch after this list.
  5. tf.train.Checkpoint
    The new release makes it easy to save and restore checkpoints of your Tensorflow model without writing verbose lines of code.  The sketch after this list shows how it is used in a training loop.
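Rather than reproducing each of the tutorial's snippets, here is a single condensed sketch that ties all five pieces together: a tf.data input pipeline, toy tf.keras Sequential models for the generator and discriminator, a @tf.function-decorated training step built around tf.GradientTape, and a tf.train.Checkpoint for saving progress.  This is not the tutorial's exact code; the file pattern, layer sizes, and hyperparameters are placeholder assumptions.

    import tensorflow as tf

    # tf.data: build the input pipeline from paired image files (the path is a placeholder).
    def load_image(path):
        # Pix2Pix training files hold the two paired images side by side; which half
        # is the input depends on the dataset, so treat this split as a placeholder.
        image = tf.io.decode_jpeg(tf.io.read_file(path), channels=3)
        w = tf.shape(image)[1] // 2
        input_image = tf.image.resize(tf.cast(image[:, w:, :], tf.float32), [256, 256])
        target = tf.image.resize(tf.cast(image[:, :w, :], tf.float32), [256, 256])
        # Scale pixels to [-1, 1] to match the tanh output of the generator.
        return input_image / 127.5 - 1.0, target / 127.5 - 1.0

    train_dataset = (tf.data.Dataset.list_files("train/*.jpg")
                     .map(load_image)
                     .shuffle(100)
                     .batch(1))

    # tf.keras: toy generator and discriminator built with the Sequential API.
    generator = tf.keras.Sequential([
        tf.keras.layers.Conv2D(64, 4, strides=2, padding="same", activation="relu",
                               input_shape=[256, 256, 3]),
        tf.keras.layers.Conv2DTranspose(3, 4, strides=2, padding="same", activation="tanh"),
    ])

    discriminator = tf.keras.Sequential([
        tf.keras.layers.Conv2D(64, 4, strides=2, padding="same", activation="relu",
                               input_shape=[256, 256, 6]),  # input and target stacked on channels
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(1),
    ])

    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    gen_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    disc_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    # tf.train.Checkpoint: bundle everything that needs to be saved and restored.
    checkpoint = tf.train.Checkpoint(generator=generator, discriminator=discriminator,
                                     gen_optimizer=gen_optimizer, disc_optimizer=disc_optimizer)

    # @tf.function compiles the eager training step into a graph, and
    # tf.GradientTape records the forward pass so gradients can be taken afterwards.
    @tf.function
    def train_step(input_image, target):
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            gen_output = generator(input_image, training=True)
            disc_real = discriminator(tf.concat([input_image, target], axis=-1), training=True)
            disc_fake = discriminator(tf.concat([input_image, gen_output], axis=-1), training=True)
            gen_loss = (bce(tf.ones_like(disc_fake), disc_fake)
                        + 100 * tf.reduce_mean(tf.abs(target - gen_output)))  # adversarial + L1
            disc_loss = (bce(tf.ones_like(disc_real), disc_real)
                         + bce(tf.zeros_like(disc_fake), disc_fake))
        gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
        disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
        gen_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
        disc_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))

    for epoch in range(200):
        for input_image, target in train_dataset:
            train_step(input_image, target)
        checkpoint.save(file_prefix="./checkpoints/ckpt")

Notice that no tf.Session or placeholder appears anywhere: the dataset is iterated directly, the models are plain Keras objects, and the only graph-related code is the one-line @tf.function decorator.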

Trying On Another Dataset

To avoid interfering with the current code base on my local machine, I turned to Colab, where all the code runs in a virtual Jupyter notebook environment. There is even the capability to run on a GPU.

I decided to use the Cityscapes dataset, translating images that have been semantically segmented back into real images.


Click here for the Curated Pix2Pix dataset.  

The original code from the Pix2Pix tutorial used the ‘facades’ dataset, which was small.  After going over the code base, I learned where to make the changes to bring in another dataset.  At first, the Colab notebook crashed several times because the dataset was too big for the free resources available on Colab.  So, I cut the dataset down to 100 images for training, 20 for testing, and 20 for validation, just for experimental purposes.  The original code base used only a train and test set, but there were more than enough images to create a validation set.
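For reference, here is one way to carve out that small split with tf.data.  The directory layout is an assumption based on how the curated Pix2Pix archives typically unpack into train and val folders, so adjust the patterns to match the actual files.

    import tensorflow as tf

    # Placeholder paths: adjust to wherever the cityscapes archive was extracted.
    train_files = tf.data.Dataset.list_files("cityscapes/train/*.jpg", shuffle=True, seed=42)
    val_files = tf.data.Dataset.list_files("cityscapes/val/*.jpg", shuffle=True, seed=42)

    # Keep only a small subset so the free Colab runtime does not run out of resources.
    small_train = train_files.take(100)
    small_test = val_files.take(20)
    small_validation = val_files.skip(20).take(20)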

Results 1
The initial results were a bit blurry.

Results 2
The results did get better over time.

 

Results 3
After 200 iterations, I tried the saved checkpoint on some validation images, separate from train/test altogether.
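Restoring the latest checkpoint and generating a prediction on a held-out image only takes a couple of lines; the names below reuse the hypothetical objects from the sketch earlier, with validation_dataset assumed to be built the same way as train_dataset.

    # Restore the most recent checkpoint saved during training (path is a placeholder).
    checkpoint.restore(tf.train.latest_checkpoint("./checkpoints"))

    # Run the generator on one validation batch and keep the predicted image.
    for input_image, target in validation_dataset.take(1):
        prediction = generator(input_image, training=False)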

 

Results 4

 

Results 5

The results were pretty good considering the small number of samples. The predicted images appear a bit grainy, but I am sure that with more images and iterations the results would improve. According to the paper, the authors trained on 2975 images from the Cityscapes training set for 200 epochs with a batch size of 10, whereas this code base defaults to a batch size of 1; they said the larger batch was better for the encoder-decoder part of the model. The paper also states that with that setup the images become sharper and more colorful, but reproducing that is probably not possible within the limitations of Colab. The good news is that, according to the paper, blurry images get screened out automatically because they look fake to the discriminator.

In order to apply this for a business use case:

  1. A custom dataset would have to be created for the client. The same scene would have to be captured in the exact same position but in two different mediums. For example, if there is a hand-drawn map, there needs to be an exact equivalent of a satellite image for the model to infer upon.
  2. A Compute Engine backed by a GPU in Google Cloud would definitely solve the memory and resource allocation issues.
  3. The color of the translated images may not be right at times, but the translation will bring the images into the same medium so that one general model can be applied. There may be hallucination effects, according to the paper, but those seem to go away with more data and more training.

Conclusion

Generative Adversarial Networks themselves are still relatively new, but image-to-image translation has already been used in a variety of different ways. With the use of Colab, I was able to explore the code base without interfering with my local setup.

As a Google Premier Partner, we are excited about the release of Tensorflow 2.0 and we are looking forward to applying this to your next data science project. Feel free to contact us at info@springml.com if you have any questions or need help with any similar use cases.

References: