How to Quickly Implement Generative Adversarial Networks Using TFGAN


Editor: Debra
AI Frontline introduction: Generative adversarial networks (GANs) are widely used in scenarios such as image generation, super-resolution, image compression, image style transfer, data augmentation, and text generation. A growing number of researchers work on GANs and have proposed many variants, including CGAN, InfoGAN, WGAN, and CycleGAN. To make GAN models easier to apply, Google has open-sourced TFGAN, a TensorFlow library for quickly implementing various GAN models. For more content, follow the WeChat public account "AI Frontline" (ID: ai-front).

This article explains how to use TFGAN to implement the original GAN, CGAN, InfoGAN, WGAN, and other models. Example results are shown below:

[Figure: MNIST digits generated by GAN, CGAN and InfoGAN, and an ImageToImage translation example]

In these results, the MNIST digits generated by the original GAN cannot be controlled. CGAN generates the digit image that matches a given label. InfoGAN can be viewed as an unsupervised CGAN, and each latent variable controls a different property:

- a categorical latent variable controls the digit class (first two rows);
- one continuous latent variable controls stroke thickness (middle two rows);
- another continuous latent variable controls the slant of the digits (last two rows).

ImageToImage, a type of CGAN, performs image style transfer.

Generative Adversarial Networks and TFGAN

GANs were first proposed by Goodfellow et al. A GAN consists of two parts, the generator (G) and the discriminator (D):

- The generator turns noise z into samples that resemble real data. The more realistic the samples, the better.
- The discriminator estimates whether a sample comes from the real data or from the generator. The more accurate its judgment, the better.

The structure is shown in the figure below:

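[Figure: GAN structure. Noise z passes through the generator G to produce G(z); real data x and G(z) pass through the discriminator D, which outputs the probability that its input is real]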

In the figure above, a real sample x passed through the discriminator yields D(x). D(x) is a number between 0 and 1 giving the probability that the image is real, so for real data, the closer D(x) is to 1, the better.

Random noise z passed through the generator G is transformed into generated data G(z). In image generation, G(z) is a fake image. The two networks pull in opposite directions:

- The discriminator D tries to push D(G(z)) toward 0, recognizing the generated image as fake.
- The generator G tries to push D(G(z)) toward 1, fooling D into believing that G(z) is real.

Training is therefore a game between D and G. It ideally ends when D can no longer tell generated images from real ones, that is, when D outputs 1/2 for every input.

Let P_r and P_g denote the distributions of real data and generated data, respectively. The discriminator minimizes the loss:

$$L_D = -\mathbb{E}_{x\sim P_r}[\log D(x)] - \mathbb{E}_{x\sim P_g}[\log(1 - D(x))]$$

The generator tries to make the discriminator unable to distinguish real data from generated data, so it minimizes:

$$L_G = \mathbb{E}_{x\sim P_g}[\log(1 - D(x))]$$

Here the samples x ~ P_g are generated images G(z).
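To make the discriminator and generator objectives concrete, the sketch below evaluates them on a small batch of discriminator outputs, replacing each expectation with a batch mean. It is a plain-Python illustration, not TFGAN code. The function names are hypothetical, and the EPS clamp is an added assumption that keeps log() away from 0.

```python
import math

EPS = 1e-7  # assumption: clamp probabilities so log() never sees exactly 0 or 1

def _clamp(p):
    return min(max(p, EPS), 1.0 - EPS)

def discriminator_loss(d_real, d_fake):
    """L_D = -mean log D(x) - mean log(1 - D(G(z))), batch means as expectations."""
    real_term = sum(math.log(_clamp(p)) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - _clamp(p)) for p in d_fake) / len(d_fake)
    return -real_term - fake_term

def generator_loss(d_fake):
    """L_G = mean log(1 - D(G(z))); G lowers it by pushing D(G(z)) toward 1."""
    return sum(math.log(1.0 - _clamp(p)) for p in d_fake) / len(d_fake)

# A confident, correct discriminator has a small loss ...
print(discriminator_loss([0.9, 0.95], [0.1, 0.05]))
# ... while at the equilibrium D = 1/2 everywhere, L_D equals 2 log 2.
print(discriminator_loss([0.5, 0.5], [0.5, 0.5]), 2 * math.log(2))
```

The equilibrium value 2 log 2 is what the discriminator loss settles at when it can no longer separate real from generated samples.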

The TFGAN library, available at https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/gan, consists mainly of the following components:

  1. Core framework: creating a TFGAN model, adding losses, creating training operations, and running them.

  2. Common operations: gradient clipping, normalization, conditioning, and so on.

  3. Loss functions: losses and penalties commonly used in GANs, such as the Wasserstein loss, the gradient penalty, and the mutual-information penalty.

  4. Model evaluation: the Inception Score and Fréchet Inception Distance metrics for evaluating unconditional generative models.

  5. Examples: Google has also open-sourced example code for common GANs, including unconditional GAN, conditional GAN, InfoGAN, and WGAN. It can be downloaded from https://github.com/tensorflow/models/tree/master/research/gan/.
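The Inception Score and Fréchet distance metrics can be sketched without a real Inception network. In this sketch:

- The class probabilities p(y|x) are assumed to be given. In practice they come from a pretrained Inception classifier.
- The Fréchet distance ||mu_1 - mu_2||^2 + Tr(S_1 + S_2 - 2 (S_1 S_2)^(1/2)) is simplified to Gaussians with diagonal covariances. In that case the matrix square root reduces to an elementwise product.

This diagonal simplification and the function names are assumptions for illustration, not the TFGAN implementation.

```python
import math

def inception_score(class_probs):
    """IS = exp(mean_x KL(p(y|x) || p(y))), where p(y) is the average of p(y|x)."""
    n, k = len(class_probs), len(class_probs[0])
    marginal = [sum(p[j] for p in class_probs) / n for j in range(k)]
    mean_kl = 0.0
    for p in class_probs:
        # Terms with p(y|x) = 0 contribute nothing to the KL divergence.
        mean_kl += sum(pj * math.log(pj / mj) for pj, mj in zip(p, marginal) if pj > 0) / n
    return math.exp(mean_kl)

def frechet_distance_diag(mu1, var1, mu2, var2):
    """Fréchet distance between Gaussians with diagonal covariances var1, var2."""
    mean_term = sum((a - b) ** 2 for a, b in zip(mu1, mu2))
    cov_term = sum(v1 + v2 - 2.0 * math.sqrt(v1 * v2) for v1, v2 in zip(var1, var2))
    return mean_term + cov_term

print(inception_score([[0.9, 0.1], [0.1, 0.9]]))
print(frechet_distance_diag([0.0, 0.0], [1.0, 1.0], [3.0, 4.0], [1.0, 1.0]))  # 25.0
```

A higher Inception Score means confident and varied predictions. A lower Fréchet distance means the generated feature statistics are closer to the real ones.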

Training a GAN with TFGAN involves the following main steps:

  1. Define the inputs of the GAN: the real data and the noise fed to the generator.

  2. Build a GANModel containing the generator and the discriminator, using tfgan.gan_model(generator_fn, discriminator_fn, real_data, generator_inputs).

  3. Define the loss functions in a GANLoss, using tfgan.gan_loss with the chosen generator_loss_fn and discriminator_loss_fn.

  4. Create the training operations (GANTrainOps) with tfgan.gan_train_ops, which takes one optimizer for the generator and one for the discriminator.

  5. Run the training with tfgan.gan_train.
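The five steps can be mirrored in a toy, framework-free sketch. It is a one-dimensional GAN:

- the generator is G(z) = a*z + b;
- the discriminator is D(x) = sigmoid(w*x + c);
- training uses hand-derived gradients of the discriminator and generator objectives given earlier.

This is not TFGAN code. The data distribution N(4, 1), the learning rate, the single-sample simultaneous updates, and all names are illustrative assumptions. With such a weak, linear discriminator the generator can at best learn to match the mean of the real data, and the dynamics may oscillate.

```python
import math
import random

def sigmoid(u):
    # Numerically stable logistic function.
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)

def train_toy_gan(steps=2000, lr=0.01, seed=0):
    rng = random.Random(seed)
    eps = 1e-7  # keeps log() away from 0

    # Step 1: inputs. Real data ~ N(4, 1); generator noise z ~ N(0, 1).
    def real_sample():
        return rng.gauss(4.0, 1.0)

    def noise_sample():
        return rng.gauss(0.0, 1.0)

    # Step 2: model. G(z) = a*z + b and D(x) = sigmoid(w*x + c).
    a, b, w, c = 1.0, 0.0, 0.1, 0.0
    history = []

    for _ in range(steps):
        x_real, z = real_sample(), noise_sample()
        x_fake = a * z + b
        d_real = min(max(sigmoid(w * x_real + c), eps), 1.0 - eps)
        d_fake = min(max(sigmoid(w * x_fake + c), eps), 1.0 - eps)

        # Step 3: losses. L_D = -log D(x) - log(1 - D(G(z))), L_G = log(1 - D(G(z))).
        d_loss = -math.log(d_real) - math.log(1.0 - d_fake)
        g_loss = math.log(1.0 - d_fake)

        # Step 4: "train ops": gradients of the losses, derived by hand.
        grad_w = -(1.0 - d_real) * x_real + d_fake * x_fake
        grad_c = -(1.0 - d_real) + d_fake
        grad_x_fake = -d_fake * w  # dL_G / dG(z)
        grad_a, grad_b = grad_x_fake * z, grad_x_fake

        # Step 5: run training; both players update simultaneously here.
        w, c = w - lr * grad_w, c - lr * grad_c
        a, b = a - lr * grad_a, b - lr * grad_b
        history.append((d_loss, g_loss))

    return (a, b), history

params, history = train_toy_gan()
print("generator offset b after training:", params[1])
```

TFGAN performs the same pipeline on minibatches, computes gradients automatically, and lets you choose the update schedule through its train ops.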

CGAN

CGAN (Conditional Generative Adversarial Nets) addresses the lack of control in GANs by adding supervised information that guides generation, turning unsupervised training into supervised training. For example, given a class label as input, it can generate an image of that class. The CGAN objective therefore becomes:

$$\min_G \max_D V(D,G) = \mathbb{E}_{x\sim P_r}[\log D(x|y)] + \mathbb{E}_{z\sim p_z}[\log(1 - D(G(z|y)))]$$

where $p_z$ is the noise distribution.

Here y is the added supervised information. D(x|y) is the discriminator's judgment of real data x under condition y, and D(G(z|y)) is its judgment of generated data G(z|y) under condition y. For example, on MNIST the digit label can be used to generate an image of that digit, and on face datasets attributes such as gender, smile, or age can be used to generate matching faces. The architecture of CGAN is shown in the figure below:

[Figure: CGAN architecture, with the condition y fed to both the generator and the discriminator]

TFGAN provides an API that conditions an input tensor on one-hot labels:

tfgan.features.condition_tensor_from_onehot(tensor, one_hot_labels, embedding_size)

Its arguments are:

- tensor: the input data;
- one_hot_labels: a one-hot label tensor of shape [batch_size, num_classes];
- embedding_size: the size of the embedding for each label.

The function returns the conditioned tensor.
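As an assumption based on reading the contrib source rather than on documented behavior, condition_tensor_from_onehot works roughly in three steps:

1. It multiplies the one-hot labels by a learned embedding matrix of shape [num_classes, embedding_size].
2. It maps the result linearly to the size of tensor.
3. It adds that result to tensor.

The plain-Python sketch below mirrors this additive scheme for a single example. Fixed, hypothetical weights stand in for TFGAN's trainable variables, and condition_from_onehot and matvec are illustrative names.

```python
def matvec(vector, matrix):
    """Row vector (length m) times matrix (m x n) -> row vector of length n."""
    return [sum(v * matrix[i][j] for i, v in enumerate(vector)) for j in range(len(matrix[0]))]

def condition_from_onehot(tensor, one_hot, embedding, projection):
    """Additive conditioning: tensor + (one_hot @ embedding) @ projection.

    tensor:     flattened features of one example, length num_features
    one_hot:    one-hot label, length num_classes
    embedding:  num_classes x embedding_size  (hypothetical fixed weights)
    projection: embedding_size x num_features (hypothetical fixed weights)
    """
    label_embedding = matvec(one_hot, embedding)   # look up this label's embedding
    offset = matvec(label_embedding, projection)   # map it to the tensor's size
    return [t + o for t, o in zip(tensor, offset)]

embedding = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # 3 classes, embedding_size = 2
projection = [[1.0, 0.0, 0.0], [0.0, 0.0, 2.0]]   # maps 2 -> 3 features
print(condition_from_onehot([0.0, 0.0, 0.0], [0, 1, 0], embedding, projection))  # [0.0, 0.0, 2.0]
```

Because each label shifts the input by a different learned offset, the same noise vector can produce different outputs for different labels.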

ImageToImage

In "Image-to-Image Translation with Conditional Adversarial Networks," Phillip Isola et al. proposed a CGAN-based conditional adversarial network for image-to-image translation. The basic idea of the network design is shown below:

[Figure: ImageToImage design. The discriminator D receives the input drawing x paired either with the real image y or with the generated image G(x)]

Here x is the input line drawing, G(x) is the generated image, and y is the real (rendered) image corresponding to x. The generator G produces images, and the discriminator D judges whether they are real. The two networks play the following game:

- The discriminator tries to classify the pair (x, y) as real and the pair (x, G(x)) as fake.
- The generator tries to make the discriminator classify (x, G(x)) as real.

The generator should not only fool the discriminator but also produce images close to the real ones. To ensure this, the L1 distance between the real and generated images is added to the objective:

$$G^* = \arg\min_G \max_D \; \mathcal{L}_{cGAN}(G,D) + \lambda\,\mathcal{L}_{L1}(G)$$

$$\mathcal{L}_{cGAN}(G,D) = \mathbb{E}_{x,y}[\log D(x,y)] + \mathbb{E}_{x}[\log(1 - D(x,G(x)))], \qquad \mathcal{L}_{L1}(G) = \mathbb{E}_{x,y}\big[\lVert y - G(x)\rVert_1\big]$$

where $\lambda$ weights the L1 term.

TFGAN's ImageToImage example combines the L1 term with the standard GAN loss as follows:

# Define the L1 loss between the real data and the generated data

l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)

# gan_loss is the adversarial GANLoss; combine it with the L1 loss,
# with weight_factor setting the balance between the two terms

gan_loss = tfgan.losses.combine_adversarial_loss(
    gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)
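Setting the TFGAN specifics aside, the ImageToImage generator objective can be sketched in plain Python. It is an adversarial term plus lambda times the mean absolute pixel difference. The following are assumptions for illustration:

- images are flattened lists;
- a single discriminator output stands in for D(x, G(x));
- the default lambda = 100 is taken from the value reported in the paper.

The function names are hypothetical.

```python
import math

def l1_loss(real, generated):
    """Mean absolute difference between two equally sized, flattened images."""
    if len(real) != len(generated):
        raise ValueError("images must have the same number of pixels")
    return sum(abs(r - g) for r, g in zip(real, generated)) / len(real)

def pix2pix_generator_loss(d_fake, real, generated, lam=100.0):
    """Generator objective: log(1 - D(x, G(x))) + lam * L1(y, G(x))."""
    eps = 1e-7  # keep log() away from 0
    d_fake = min(max(d_fake, eps), 1.0 - eps)
    adversarial = math.log(1.0 - d_fake)
    return adversarial + lam * l1_loss(real, generated)

print(pix2pix_generator_loss(0.3, [0.0, 0.5, 1.0], [0.1, 0.5, 0.8]))
```

The adversarial term rewards fooling D, while the L1 term penalizes every pixel that drifts from the ground truth y.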

InfoGAN

In a standard GAN, the generator produces data from noise z without any constraint, so no individual dimension of z corresponds to a meaningful semantic feature. It is therefore hard to control which kind of data a given z will produce, which greatly limits how GANs can be used.

InfoGAN can be viewed as an unsupervised CGAN. It adds a latent code c alongside the noise z and requires the generated data to have high mutual information with c (the "Info" in the name stands for mutual information). The mutual information between x and y is the difference between two entropies:

I(x;y) = H(x) - H(x|y)

where H(x) is the entropy of the prior distribution and H(x|y) is the entropy of the posterior distribution. If x and y are independent, the mutual information is 0, meaning y tells us nothing about x. If they are dependent, it is greater than 0, and knowing y lets us infer which values of x are likely. The objective of InfoGAN is:

$$\min_G \max_D V_I(D,G) = V(D,G) - \lambda\, I(c; G(z,c))$$

where $V(D,G)$ is the original GAN objective and $\lambda$ weights the mutual-information term.
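The definition of mutual information can be checked numerically for two discrete variables. The sketch computes I(x;y) = H(x) - H(x|y) in nats from a joint probability table joint[i][j] = P(x=i, y=j). The table format and the function names are illustrative choices.

```python
import math

def entropy(probs):
    """Shannon entropy in nats; zero-probability entries contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0)

def mutual_information(joint):
    """I(x;y) = H(x) - H(x|y) for a discrete joint table joint[i][j] = P(x=i, y=j)."""
    p_x = [sum(row) for row in joint]
    p_y = [sum(col) for col in zip(*joint)]
    # H(x|y) = sum over y of p(y) * H(x | y), the entropy of the posterior.
    h_x_given_y = 0.0
    for j, pj in enumerate(p_y):
        if pj > 0:
            posterior = [joint[i][j] / pj for i in range(len(joint))]
            h_x_given_y += pj * entropy(posterior)
    return entropy(p_x) - h_x_given_y

independent = [[0.25, 0.25], [0.25, 0.25]]
correlated = [[0.5, 0.0], [0.0, 0.5]]
print(mutual_information(independent))  # about 0: y says nothing about x
print(mutual_information(correlated))   # log 2: y fully determines x
```

InfoGAN wants the latent code c and the generated image to behave like the second table rather than the first.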

The network structure of InfoGAN is shown below:

[Figure: InfoGAN network structure. The generator takes noise z and latent code c; the discriminator outputs D(x) and Q(c|x)]

In the figure above, InfoGAN differs from a standard GAN in that the discriminator network, besides outputting D(x), also outputs a variational distribution Q(c|x). Training Q(c|x) to approximate the true posterior P(c|x) increases the mutual information between the generated data and the latent code c. TFGAN provides APIs for InfoGAN:

The InfoGAN model is defined with tfgan.infogan_model, and its loss is then computed with tfgan.gan_loss.

The InfoGAN loss is the GAN loss plus a mutual-information penalty. The penalty is the negative of a lower bound on I(c;G(z,c)), so minimizing the loss increases the mutual information. TFGAN computes this penalty with tfgan.losses.mutual_information_penalty, which takes two inputs:

- structured_generator_inputs: the latent codes c fed to the generator;
- predicted_distributions: the variational distributions Q(c|x).
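For a categorical code, the penalty can be sketched as the negative mean log-likelihood of the code that was fed to the generator, under the predicted distribution Q(c|x). Adding the code entropy H(c) gives the variational lower bound on I(c; G(z,c)) used by InfoGAN. This plain-Python version is an illustration under those assumptions, not TFGAN's implementation, which, as far as the source suggests, works with TensorFlow distribution objects. The function names are hypothetical.

```python
import math

def mi_penalty_categorical(codes, predicted_probs):
    """Negative mean log Q(c|x) of the codes c that were fed to the generator.

    codes[i] is the category index used to generate sample i, and
    predicted_probs[i] is Q(c|x_i), the distribution predicted for that sample.
    """
    eps = 1e-12  # avoid log(0)
    log_likelihood = sum(math.log(max(q[c], eps)) for c, q in zip(codes, predicted_probs))
    return -log_likelihood / len(codes)

def mi_lower_bound(codes, predicted_probs, prior):
    """L_I = E[log Q(c|x)] + H(c), a lower bound on I(c; G(z, c))."""
    h_c = -sum(p * math.log(p) for p in prior if p > 0)
    return h_c - mi_penalty_categorical(codes, predicted_probs)

prior = [0.1] * 10  # uniform categorical code over 10 digit classes
perfect = [[1.0 if k == c else 0.0 for k in range(10)] for c in (3, 7)]
uniform = [[0.1] * 10, [0.1] * 10]
print(mi_lower_bound([3, 7], perfect, prior))  # H(c), about log 10: Q recovers c exactly
print(mi_lower_bound([3, 7], uniform, prior))  # about 0: Q has learned nothing about c
```

Minimizing the penalty therefore pushes the bound up toward H(c), the largest mutual information a code with this prior can have.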

WGAN

Martin Arjovsky et al. proposed WGAN (Wasserstein GAN) to address three problems of traditional GANs: training is difficult, the generator and discriminator losses give little indication of training progress, and generated samples lack diversity. WGAN has the following advantages:

  1. It no longer requires carefully balancing how well the generator and discriminator are trained, which makes GAN training stable.

  2. It greatly reduces mode collapse, so the generated samples are diverse.

  3. It provides a meaningful measure of training progress, the estimated Wasserstein distance: a smaller value indicates better training and higher-quality generated images.

WGAN differs from the original GAN algorithm mainly in four ways:

  1. The sigmoid in the last layer of the discriminator is removed.

  2. The generator and discriminator losses do not take the log.

  3. After each discriminator update, the discriminator's parameters are clipped so that their absolute values do not exceed a fixed constant c.

  4. RMSProp is used instead of momentum-based optimizers such as Momentum and Adam.
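The four differences can be sketched in plain Python:

- the critic's raw, unbounded outputs enter the losses with no sigmoid and no log;
- each critic update is an RMSProp step followed by weight clipping.

The clipping constant c = 0.01 and learning rate 5e-5 follow the values suggested in the WGAN paper. The RMSProp decay rho = 0.9, the epsilon, and all function names are assumptions.

```python
import math

def wgan_critic_loss(d_real, d_fake):
    """Critic minimizes mean D(G(z)) - mean D(x): raw scores, no sigmoid, no log."""
    return sum(d_fake) / len(d_fake) - sum(d_real) / len(d_real)

def wgan_generator_loss(d_fake):
    """Generator minimizes -mean D(G(z))."""
    return -sum(d_fake) / len(d_fake)

def clip_weights(weights, c=0.01):
    """Clip every critic weight into [-c, c]."""
    return [max(-c, min(c, w)) for w in weights]

def rmsprop_step(weights, grads, cache, lr=5e-5, rho=0.9, eps=1e-8):
    """One RMSProp update: scale each gradient by a running RMS, no momentum."""
    cache = [rho * s + (1.0 - rho) * g * g for s, g in zip(cache, grads)]
    weights = [w - lr * g / (math.sqrt(s) + eps) for w, g, s in zip(weights, grads, cache)]
    return weights, cache

def critic_update(weights, grads, cache):
    """An RMSProp step followed by weight clipping, as in the WGAN algorithm."""
    weights, cache = rmsprop_step(weights, grads, cache)
    return clip_weights(weights), cache

weights, cache = critic_update([0.0, 0.0099], [1.0, -1.0], [0.0, 0.0])
print(weights)  # the second weight is pushed past 0.01 and clipped back to 0.01
```

Clipping keeps the critic approximately Lipschitz, which its score difference needs in order to estimate the Wasserstein distance.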

The full WGAN training algorithm is shown below:

[Figure: WGAN training algorithm]

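TFGAN provides the WGAN loss functions, which are passed as the generator_loss_fn and discriminator_loss_fn arguments of tfgan.gan_loss (tfgan_losses here is presumably an alias for TFGAN's losses module, tfgan.losses):

# Generator loss function

generator_loss_fn=tfgan_losses.wasserstein_generator_loss

# Discriminator loss function

discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss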

Summary

This article first introduced GANs and TFGAN. GAN models are used for image generation, super-resolution, image compression, image style transfer, data augmentation, text generation, and other scenarios, and TFGAN is a TensorFlow library for quickly implementing various GAN models. The article then explained the main ideas of CGAN, ImageToImage, InfoGAN, and WGAN, analyzing their key techniques, including objective functions, network architectures, loss functions, and the corresponding TFGAN APIs. With TFGAN, readers can quickly experiment with GAN models and apply them to relevant industrial scenarios.

References

[1] Generative Adversarial Networks.

[2] Conditional Generative Adversarial Nets.

[3] InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets.

[4] Wasserstein GAN.

[5] Image-to-Image Translation with Conditional Adversarial Networks.

[6] https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/gan.

[7] https://github.com/tensorflow/models/tree/master/research/gan.

Author Introduction

Wu Wei (WeChat: allawnweiwu): PhD, currently an architect at IBM. He works mainly on research into deep learning platforms and applications, and on research and development in big data.
