Understanding GAN Networks from a Beginner’s Perspective

From | CSDN Blog Author | JensLee

Edited | Deep Learning This Little Thing Public Account

This article is for academic exchange only. If there is any infringement, please contact the backend to delete it.

From a beginner's perspective, a GAN (Generative Adversarial Network) can be thought of as a forgery machine: it creates things that look real. Let's look at how the forgery is done. (This article mainly walks through the GAN code, which is quite simple.)
We will first take the example of generating fake images of puppies.
First, we need a model to generate puppy images, which we call the generator, and a model to judge whether the puppy images are real or fake, called the discriminator.
First, input a 1000-dimensional noise vector into the generator. The specific structure of the generator is shown below (you can look at it later if you want):
[Figure: structure of the generator network]
It’s actually quite simple. The code is as follows:
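from keras.models import Sequential
from keras.layers import (Activation, BatchNormalization, Conv2D, Dense,
                          Reshape, UpSampling2D)


def generator_model():
    model = Sequential()
    # Project the 1000-dimensional noise vector to 1024 units
    model.add(Dense(1024, input_dim=1000))
    model.add(Activation('tanh'))
    model.add(Dense(128 * 8 * 8))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    # Turn the flat vector into an 8x8 feature map with 128 channels
    model.add(Reshape((8, 8, 128), input_shape=(8 * 8 * 128,)))
    model.add(UpSampling2D(size=(4, 4)))  # 8x8 -> 32x32
    model.add(Conv2D(64, (5, 5), padding='same'))
    model.add(Activation('tanh'))
    model.add(UpSampling2D(size=(2, 2)))  # 32x32 -> 64x64
    # Three output channels give a 64x64 RGB image with values in [-1, 1]
    model.add(Conv2D(3, (5, 5), padding='same'))
    model.add(Activation('tanh'))
    return model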
The generator takes a 1000-dimensional random vector as input and outputs a 64×64×3 image. There's no need to dig into the details: 1000 random numbers go in, and an image comes out.
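To check that 1000 random numbers really end up as a 64×64×3 image, the shapes can be traced by hand. The sketch below is plain Python, not Keras. The helper names (`upsample`, `conv_same`, `generator_output_shape`) are made up for illustration and only mimic how each layer type changes the tensor shape.

```python
def upsample(shape, factor):
    """UpSampling2D repeats rows and columns, so height and width are multiplied by `factor`."""
    h, w, c = shape
    return (h * factor, w * factor, c)


def conv_same(shape, filters):
    """Conv2D with padding='same' and stride 1 keeps height and width; channels become `filters`."""
    h, w, _ = shape
    return (h, w, filters)


def generator_output_shape():
    # The two Dense layers produce flat vectors; Reshape turns the
    # 128 * 8 * 8 = 8192 values into an 8x8 map with 128 channels.
    shape = (8, 8, 128)
    shape = upsample(shape, 4)    # (32, 32, 128)
    shape = conv_same(shape, 64)  # (32, 32, 64)
    shape = upsample(shape, 2)    # (64, 64, 64)
    shape = conv_same(shape, 3)   # (64, 64, 3)
    return shape


print(generator_output_shape())
```

The noise dimension (1000) does not affect the output size. Only the reshape target and the two upsampling factors (8 × 4 × 2 = 64) do.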
Next, let’s look at the discriminator’s code and structure:
[Figure: structure of the discriminator network]
The code is as follows:
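from keras.models import Sequential
from keras.layers import Activation, Conv2D, Dense, Flatten, MaxPooling2D


def discriminator_model():
    model = Sequential()
    # The input is a 64x64 RGB image
    model.add(Conv2D(64, (5, 5), padding='same', input_shape=(64, 64, 3)))
    model.add(Activation('tanh'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(128, (5, 5)))
    model.add(Activation('tanh'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(1024))
    model.add(Activation('tanh'))
    # A single sigmoid unit: the estimated probability that the image is real
    model.add(Dense(1))
    model.add(Activation('sigmoid'))
    return model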
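The input is a 64×64×3 image, and the output is a single number between 0 and 1 (the sigmoid output). A value near 1 means the discriminator judges the image to be a real dog picture, and a value near 0 means it judges the image to be fake.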
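The same kind of hand trace works for the discriminator. This is a plain-Python sketch with hypothetical helpers (`conv`, `max_pool`, `sigmoid`), not Keras code. It assumes stride-1 convolutions and that the second `Conv2D` uses Keras's default `padding='valid'`, since the code above does not set it. The 0.5 decision threshold and the example logit 0.8 are also assumptions for illustration.

```python
import math


def conv(shape, filters, kernel, padding):
    """Shape after a stride-1 Conv2D: 'same' keeps the size, 'valid' shrinks it by kernel - 1."""
    h, w, _ = shape
    if padding == "valid":
        h, w = h - kernel + 1, w - kernel + 1
    return (h, w, filters)


def max_pool(shape, pool=2):
    """MaxPooling2D with a 2x2 window halves height and width (rounding down)."""
    h, w, c = shape
    return (h // pool, w // pool, c)


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


shape = (64, 64, 3)
shape = conv(shape, 64, 5, "same")    # (64, 64, 64)
shape = max_pool(shape)               # (32, 32, 64)
shape = conv(shape, 128, 5, "valid")  # (28, 28, 128)
shape = max_pool(shape)               # (14, 14, 128)
flat_size = shape[0] * shape[1] * shape[2]  # inputs to Dense(1024) after Flatten

score = sigmoid(0.8)    # 0.8 is an arbitrary example value for the final Dense(1) output
is_real = score >= 0.5  # thresholding at 0.5 is a common convention, assumed here
print(shape, flat_size, round(score, 3), is_real)
```

The point is that the network outputs a score rather than a hard 0 or 1. The hard decision only appears when the score is compared with a threshold.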
Next, let’s discuss the specific operations based on the code:
We concatenate the real and fake images, label the real images as 1 and the fake images as 0, and input them into the training network.
# Randomly generated 1000-dimensional noise, one vector per image in the batch
noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 1000))
# X_train is the training image data; take one batch of real images (64 images when BATCH_SIZE is 64)
image_batch = X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
# Fake images produced by the generator (g is the model returned by generator_model())
generated_images = g.predict(noise, verbose=0)
# Concatenate the real and fake images
X = np.concatenate((image_batch, generated_images))
# Labels for X: the first 64 images are real (label 1), the last 64 images are fake (label 0)
y = [1] * BATCH_SIZE + [0] * BATCH_SIZE
# Send the concatenated batch to the discriminator for training (d is the model returned by discriminator_model())
d_loss = d.train_on_batch(X, y)
If this part is unclear, it may help to read it alongside other explanations of GANs.
As this training step is repeated, the discriminator's accuracy keeps improving.
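What "improving" means can be made concrete with the loss the discriminator is usually trained on. The article does not show the `compile` call, so binary cross-entropy is an assumption here (it is the usual pairing with a sigmoid output). The sketch below is plain Python with made-up discriminator outputs, and `binary_crossentropy` is a hand-written helper, not the Keras function.

```python
import math


def binary_crossentropy(labels, probs, eps=1e-7):
    """Mean binary cross-entropy between 0/1 labels and predicted probabilities."""
    total = 0.0
    for y, p in zip(labels, probs):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(labels)


BATCH_SIZE = 4  # tiny batch for illustration; the article uses 64
labels = [1] * BATCH_SIZE + [0] * BATCH_SIZE  # real images first, then fakes

# Hypothetical discriminator outputs before and after some training
undecided = [0.5] * (2 * BATCH_SIZE)
confident = [0.9] * BATCH_SIZE + [0.1] * BATCH_SIZE

loss_before = binary_crossentropy(labels, undecided)
loss_after = binary_crossentropy(labels, confident)
print(loss_before, loss_after)
```

A discriminator that answers 0.5 for everything has a loss of log 2 (about 0.693). The loss falls as it learns to score real images high and fake images low.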
Now we come to the core of GAN networks:
def generator_containing_discriminator(g, d):
    model = Sequential()
    model.add(g)
    # The parameters of the discriminator are not modified
    d.trainable = False
    model.add(d)
    return model
The network structure is shown below:
[Figure: structure of the combined generator and discriminator network]
This model consists of a generator followed by a discriminator. In the code, the first part of the model is the generator network and the second part is the discriminator network. The generator first produces fake images, which are then passed to the discriminator for judgment. Note the line d.trainable = False: when this combined model is trained, only the generator's parameters are adjusted, and the discriminator's parameters stay unchanged. It's simply ingenious.
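The effect of freezing the discriminator can be seen in a toy version with hand-written gradients. This is a minimal sketch, not the article's Keras model. The generator is assumed to be a line, G(z) = a*z + b, and the discriminator a single sigmoid unit, D(x) = sigmoid(w*x + c). All parameter values are made up.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


# Toy stand-ins with made-up parameters: G(z) = a*z + b, D(x) = sigmoid(w*x + c)
gen = {"a": 1.0, "b": 0.0}
disc = {"w": 2.0, "c": -4.0}


def generator_loss(z_batch):
    """Mean of -log D(G(z)): the cross-entropy of the fakes against the label 1."""
    losses = []
    for z in z_batch:
        x = gen["a"] * z + gen["b"]
        losses.append(-math.log(sigmoid(disc["w"] * x + disc["c"])))
    return sum(losses) / len(losses)


def train_generator_step(z_batch, lr=0.1):
    """One gradient step on the stacked model; only `gen` is updated, `disc` is frozen."""
    grad_a = grad_b = 0.0
    for z in z_batch:
        x = gen["a"] * z + gen["b"]
        d_out = sigmoid(disc["w"] * x + disc["c"])
        # d/dx of -log(sigmoid(w*x + c)) is -(1 - d_out) * w; the chain rule carries it into G
        dx = -(1.0 - d_out) * disc["w"]
        grad_a += dx * z
        grad_b += dx
    gen["a"] -= lr * grad_a / len(z_batch)
    gen["b"] -= lr * grad_b / len(z_batch)


z_batch = [-0.5, 0.0, 0.5, 1.0]
disc_before = dict(disc)
loss_before = generator_loss(z_batch)
train_generator_step(z_batch)
loss_after = generator_loss(z_batch)
print(loss_before, loss_after, disc == disc_before)
```

The gradient flows through the discriminator (its weight `w` appears in `dx`), but the update is applied only to the generator's parameters. That is the role `d.trainable = False` plays in the combined model.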
Now let’s see how to train the generation network, which is also a core area:
# d_on_g is the combined model returned by generator_containing_discriminator(g, d)
# Train one batch at a time
for index in range(int(X_train.shape[0] / BATCH_SIZE)):
    # Generate random noise
    noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 1000))
    # These are all real images
    image_batch = X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
    # Generate fake images
    generated_images = g.predict(noise, verbose=0)
    # Concatenate real and fake images
    X = np.concatenate((image_batch, generated_images))
    # The first 64 labels are 1 (real images), the last 64 labels are 0 (fake images)
    y = [1] * BATCH_SIZE + [0] * BATCH_SIZE
    # Train the discriminator to keep improving its recognition accuracy
    d_loss = d.train_on_batch(X, y)
    # Generate random noise again
    noise = np.random.uniform(-1, 1, (BATCH_SIZE, 1000))
    # Freeze the discriminator's parameters
    d.trainable = False
    # ××××××××××××××××××××××××××××××××××××××××××××××××××××××××××
    # Feed in noise and pretend that the images generated from it are labeled as real
    g_loss = d_on_g.train_on_batch(noise, [1] * BATCH_SIZE)
    # ××××××××××××××××××××××××××××××××××××××××××××××××××××××××××
    # Make the discriminator trainable again so that its parameters can be updated
    d.trainable = True
    # Print loss values
    print("batch %d d_loss : %s, g_loss : %f" % (index, d_loss, g_loss))
The key point is this line of code:
g_loss = d_on_g.train_on_batch(noise, [1] * BATCH_SIZE)
First, the noise is fed into the combined model defined above: the generator turns it into images, and these images are passed to the discriminator. The label here is 1, meaning "real", even though the images are actually fake. At first the discriminator judges them to be fake, so the loss is large. Because the discriminator's parameters are frozen at this moment (d.trainable = False), the only way to reduce the loss is to adjust the generator's parameters, and training keeps adjusting them until the discriminator believes the images are real. At that point the discriminator and the generator have reached a balance: the discriminator can no longer tell the generator's fake images from real ones. We then go back to training the discriminator to improve its accuracy, and repeat the whole process until even humans cannot tell the generated images from real ones.
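The label trick can be illustrated with a few numbers. The sketch below assumes binary cross-entropy as the loss (the article does not show which loss the models are compiled with) and uses made-up discriminator scores. `bce` and `tug_of_war` are hypothetical helper names.

```python
import math


def bce(label, prob):
    """Binary cross-entropy for a single prediction `prob` against a 0/1 `label`."""
    return -(label * math.log(prob) + (1 - label) * math.log(1.0 - prob))


def tug_of_war(d_fake):
    """Losses both players see for one fake image that the discriminator scores as `d_fake`."""
    g_loss = bce(1, d_fake)  # generator step: the fake image is labeled 1 ("real")
    d_loss = bce(0, d_fake)  # discriminator step: the same image is labeled 0 ("fake")
    return g_loss, d_loss


# Hypothetical discriminator scores for a fake image as the generator improves
for d_fake in (0.05, 0.25, 0.5, 0.75, 0.95):
    g_loss, d_loss = tug_of_war(d_fake)
    print("D(G(z))=%.2f  g_loss=%.3f  d_loss_on_fake=%.3f" % (d_fake, g_loss, d_loss))
```

The two losses pull in opposite directions: whatever lowers one raises the other. At a score of 0.5 both equal log 2, which corresponds to the balance described above, where the discriminator can do no better than guess.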
Finally, I trained for about 65 epochs because of time constraints. In practice, generating reasonably realistic dog images may take thousands of epochs, and different network structures need different numbers of iterations. Even so, you can see that the output bears some resemblance to a dog. This is the result after 65 epochs:
[Figure: generated images after 65 epochs of training]
That’s all the content.
Original link:https://blog.csdn.net/LEE18254290736/java/article/details/97371930
