CycleGAN Image Processing Tool for Style Transfer


1. Introduction to GAN

“Foodie, food spirit, foodies are the best of people”.
This GAN has nothing to do with that "foodie" (a pun in the original Chinese, where the slang for foodie sounds like "GAN"). The GAN discussed here is the Generative Adversarial Network proposed by Goodfellow et al. in 2014. So what makes GAN so remarkable?
Conventional deep learning tasks such as image classification, object detection, and semantic or instance segmentation can all be summarized as predictions. Image classification predicts a single category, object detection predicts bounding boxes and categories, and semantic or instance segmentation predicts the category of each pixel. GAN, on the other hand, generates something new, like an image.
The principle of GAN can be summarized in one sentence:
  • Through adversarial learning, it learns a generative model of the data distribution. GAN is an unsupervised process: it captures the distribution of the dataset so that it can turn random noise into new data that follows the same distribution.
GAN consists of two components that play an adversarial game:
  • Discriminator D: learns the boundary between real and fake data
  • Generator G: learns the data distribution and generates data


The classic GAN objective is the following minimax game (the min-max form reflects its adversarial nature):

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$
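To make the objective concrete, here is a minimal pure-Python sketch that estimates V(D, G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))] from a handful of discriminator outputs. The probability lists stand in for D(x) on real samples and D(G(z)) on generated samples; the function names are illustrative and not taken from any library.

```python
import math


def gan_value(d_real, d_fake):
    """Monte Carlo estimate of V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))].

    d_real: discriminator probabilities on real samples, D(x)
    d_fake: discriminator probabilities on generated samples, D(G(z))
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term


def discriminator_loss(d_real, d_fake):
    # D maximizes V, which is the same as minimizing -V
    return -gan_value(d_real, d_fake)


def generator_loss(d_fake):
    # G only affects the second term, so it minimizes E[log(1 - D(G(z)))]
    return sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)


# A confident, correct discriminator yields a high V (low discriminator loss)
print(gan_value([0.9, 0.8], [0.1, 0.2]))
# At the theoretical equilibrium D outputs 0.5 everywhere and V = -log 4
print(gan_value([0.5, 0.5], [0.5, 0.5]), -math.log(4))  # both about -1.386
```

As the generator improves, D(G(z)) moves toward 1 and generator_loss decreases. In practice, Goodfellow et al. suggest training G to maximize log D(G(z)) instead, because the minimax form gives weak gradients early in training when D easily rejects generated samples.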

2. Practical CycleGAN Style Transfer

Now that we understand what GAN does, let's see its magic in action. Here, we use CycleGAN as an example to implement image style transfer. Style transfer means changing the style of an original image while keeping its content. As shown in the figure, the left is the original image, the middle is the style image (a Van Gogh painting), and the right is the generated image in Van Gogh's style, which retains most of the content of the original image.

[Figure: original image (left), Van Gogh style image (middle), generated image in Van Gogh's style (right)]

2.1 Introduction to CycleGAN

CycleGAN is essentially the same as GAN: it learns the underlying data distribution of a dataset. But whereas GAN generates images that follow this distribution from random noise, CycleGAN takes meaningful images from one domain as input and maps them into another domain. CycleGAN assumes that a latent relationship exists between the two domains in image-to-image translation.
With an adversarial loss alone, it is hard to guarantee that the learned mapping produces meaningful outputs: a generator can produce realistic images of the target domain that have nothing to do with its input. CycleGAN uses cycle consistency to keep the generated images structurally consistent with the input images. Let's take a look at the structure of CycleGAN:

[Figure: structure of CycleGAN]

Key features are summarized as follows:
  • Two-way GAN: two generators (G: X->Y, F: Y->X) and two discriminators (D_X, D_Y). G tries to generate images that D_Y (whose positive class is domain Y) cannot distinguish from real Y images; F and D_X play the same roles for domain X.
  • Cycle consistency: G generates Y-domain images and F generates X-domain images. Cycle consistency constrains the two mappings so that an image translated by G can be brought back to its original domain by F, i.e. x -> G(x) -> F(G(x)) ≈ x (and likewise y -> F(y) -> G(F(y)) ≈ y).
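The losses are as follows. The adversarial loss for G and D_Y is shown below; the loss for F and D_X is defined symmetrically:

$$\mathcal{L}_{\text{GAN}}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{\text{data}}(y)}[\log D_Y(y)] + \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log(1 - D_Y(G(x)))]$$

Cycle consistency loss:

$$\mathcal{L}_{\text{cyc}}(G, F) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\lVert F(G(x)) - x \rVert_1] + \mathbb{E}_{y \sim p_{\text{data}}(y)}[\lVert G(F(y)) - y \rVert_1]$$

Full objective, where λ controls the weight of cycle consistency:

$$\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{\text{GAN}}(G, D_Y, X, Y) + \mathcal{L}_{\text{GAN}}(F, D_X, Y, X) + \lambda \mathcal{L}_{\text{cyc}}(G, F)$$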
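Cycle consistency can be illustrated without any deep learning framework. The sketch below computes the L1 cycle loss E_x[||F(G(x)) - x||_1] + E_y[||G(F(y)) - y||_1] for hypothetical toy mappings on short vectors (plain functions standing in for the neural-network generators): when F exactly inverts G the loss is zero, and any mismatch shows up as a positive penalty. Note that the training loop in 2.2.3 averages the two directions instead of summing them, which only rescales λ.

```python
def l1(a, b):
    """Mean absolute error between two equal-length vectors."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def cycle_loss(G, F, xs, ys):
    """Forward cycle x -> G(x) -> F(G(x)) plus backward cycle y -> F(y) -> G(F(y))."""
    forward = sum(l1(F(G(x)), x) for x in xs) / len(xs)
    backward = sum(l1(G(F(y)), y) for y in ys) / len(ys)
    return forward + backward


# Hypothetical toy "generators" on 2-element vectors
def G(v):  # X -> Y
    return [2 * t + 1 for t in v]


def F_good(v):  # Y -> X, an exact inverse of G
    return [(t - 1) / 2 for t in v]


def F_bad(v):  # Y -> X, close to but not an inverse of G
    return [t / 2 for t in v]


xs = [[0.0, 1.0], [2.0, 3.0]]
ys = [[1.0, 3.0], [5.0, 7.0]]

print(cycle_loss(G, F_good, xs, ys))  # 0.0
print(cycle_loss(G, F_bad, xs, ys))   # 1.5
```

With F_bad, every forward reconstruction is off by 0.5 and every backward reconstruction by 1.0, so the cycle term pushes the two generators toward being mutual inverses even though no paired (x, y) examples are ever used.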

2.2 Implementing CycleGAN

2.2.1 Generator

As introduced above, there are two generators, one forward and one backward. Their architecture follows the paper Perceptual Losses for Real-Time Style Transfer and Super-Resolution: Supplementary Material, and can be roughly divided into downsampling + residual blocks + upsampling, as shown in the figure (excerpted from the paper):

[Figure: generator architecture, excerpted from the paper]

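Downsampling is implemented with stride-2 convolutions, and upsampling with nn.Upsample followed by a stride-1 convolution:
import torch
import torch.nn as nn

# Residual block: two reflection-padded 3x3 convs plus a skip connection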
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        return x + self.block(x)

class GeneratorResNet(nn.Module):
    def __init__(self, input_shape, num_residual_blocks):
        super(GeneratorResNet, self).__init__()

        channels = input_shape[0]

        # Initial convolution block
        out_features = 64
        model = [
            nn.ReflectionPad2d(3),  # pad 3 so the 7x7 conv keeps the spatial size
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

        # Downsampling
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual blocks
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(out_features)]

        # Upsampling
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer: pad 3 + 7x7 conv back to `channels`; Tanh maps outputs to [-1, 1]
        model += [nn.ReflectionPad2d(3), nn.Conv2d(out_features, channels, 7), nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
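A quick way to check the downsampling/upsampling arithmetic is to trace the feature-map shape by hand. The pure-Python sketch below applies the standard Conv2d output-size formula (dilation 1) layer by layer, assuming a square 3-channel input, a reflection padding of 3 around each 7x7 convolution, and the default two down/upsampling stages. It does not run the network itself; the helper names are illustrative.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d layer with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def generator_shapes(size, channels=3, n_down=2, n_up=2):
    """Trace (channels, spatial size) through GeneratorResNet for a square input."""
    trace = []
    # Initial block: ReflectionPad2d(3) + 7x7 conv keeps the size, 64 channels
    ch, size = 64, conv_out(size + 2 * 3, 7)
    trace.append((ch, size))
    for _ in range(n_down):  # 3x3 stride-2 convs halve the size and double channels
        ch, size = ch * 2, conv_out(size, 3, stride=2, padding=1)
        trace.append((ch, size))
    # Residual blocks pad by 1 before each 3x3 conv, so they keep (ch, size)
    for _ in range(n_up):  # Upsample(x2) then 3x3 stride-1 conv with padding 1
        ch, size = ch // 2, conv_out(size * 2, 3, stride=1, padding=1)
        trace.append((ch, size))
    # Output block: ReflectionPad2d(3) + 7x7 conv back to the input channel count
    trace.append((channels, conv_out(size + 2 * 3, 7)))
    return trace


print(generator_shapes(256))
# [(64, 256), (128, 128), (256, 64), (128, 128), (64, 256), (3, 256)]
print(generator_shapes(250)[-1])
# (3, 252): sizes not divisible by 4 do not round-trip
```

Because each stride-2 convolution rounds up odd sizes, the output only matches the input when the height and width are divisible by 4 (two halvings), which is why training images are usually resized or cropped to sizes such as 256x256.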

2.2.2 Discriminator

A traditional GAN discriminator outputs a single value that scores how real the whole image is. PatchGAN instead outputs N*N values, where each value scores whether a certain-sized receptive field (patch) of the input image is real. Intuitively, it judges the realism of overlapping local patches of the image, and it can be regarded as a fully convolutional network. PatchGAN was first proposed in pix2pix (Image-to-Image Translation with Conditional Adversarial Networks). Its advantages are fewer parameters and a better ability to capture high-frequency detail in local regions.
class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        channels, height, width = input_shape

        # Calculate output shape of image discriminator (PatchGAN)
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1)
        )

    def forward(self, img):
        return self.model(img)
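Both the output_shape formula (height // 2 ** 4) and the size of the patch that each output value judges can be checked with a little arithmetic. This pure-Python sketch traces the spatial size through the Discriminator above and computes the receptive field of one output value with the usual recurrence r_in = (r_out - 1) * stride + kernel; padding shifts the window but does not change its size. The helper names are illustrative.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d layer with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def patchgan_output_size(size):
    """Output map size of the Discriminator above for a square input."""
    for _ in range(4):  # four 4x4, stride-2, padding-1 blocks: each halves the size
        size = conv_out(size, 4, stride=2, padding=1)
    size += 1  # ZeroPad2d((1, 0, 1, 0)) adds one row (top) and one column (left)
    return conv_out(size, 4, stride=1, padding=1)  # final 4x4 conv to 1 channel


def receptive_field(layers):
    """Receptive field of one output unit for a stack of (kernel, stride) convs."""
    rf = 1
    for kernel, stride in reversed(layers):
        rf = (rf - 1) * stride + kernel
    return rf


disc_layers = [(4, 2)] * 4 + [(4, 1)]
print(patchgan_output_size(256))     # 16, i.e. 256 // 2 ** 4
print(receptive_field(disc_layers))  # 94
```

By this calculation, for a 256x256 input the discriminator produces a 16x16 map, and each value judges roughly a 94x94 patch, somewhat larger than the 70x70 patches of the pix2pix PatchGAN.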


2.2.3 Training

Loss and Model Initialization
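import datetime
import itertools
import time

import numpy as np
import torch
from torch.autograd import Variable

# `opt` (parsed hyperparameters) and `dataloader` (yielding dicts with "A" and "B" image batches) come from the full script
# Losses
criterion_GAN = torch.nn.MSELoss()      # least-squares adversarial loss
criterion_cycle = torch.nn.L1Loss()     # cycle-consistency loss
criterion_identity = torch.nn.L1Loss()  # identity loss

cuda = torch.cuda.is_available()
input_shape = (opt.channels, opt.img_height, opt.img_width)

# Initialize generators (A->B, B->A) and discriminators
G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)
if cuda:
    G_AB, G_BA, D_A, D_B = G_AB.cuda(), G_BA.cuda(), D_A.cuda(), D_B.cuda()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor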
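Note that criterion_GAN is MSELoss rather than the log loss of the classic objective: CycleGAN replaces the negative log-likelihood with a least-squares loss, which its authors report to be more stable in training. Applied to a PatchGAN output, the squared error is simply averaged over all N*N patch scores. The pure-Python sketch below mirrors that computation on hypothetical 2x2 score maps; the function names are illustrative, not the training script's.

```python
def mse(scores, target):
    """Mean squared error between a 2-D map of patch scores and a constant target."""
    values = [s for row in scores for s in row]
    return sum((s - target) ** 2 for s in values) / len(values)


def lsgan_d_loss(real_scores, fake_scores):
    # Mirrors loss_D_A: real patches -> 1 (valid), generated patches -> 0 (fake), averaged
    return (mse(real_scores, 1.0) + mse(fake_scores, 0.0)) / 2


def lsgan_g_loss(fake_scores):
    # Mirrors loss_GAN_AB: the generator wants every patch of its output scored as real
    return mse(fake_scores, 1.0)


real = [[1.0, 0.5], [0.0, 1.0]]  # hypothetical PatchGAN outputs on a real image
fake = [[0.0, 0.5], [1.0, 0.0]]  # hypothetical outputs on a generated image
print(lsgan_d_loss(real, fake))  # 0.3125
print(lsgan_g_loss(fake))        # 0.5625
```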
Optimizer and Training Strategy
# Optimizers
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

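# Learning rate update schedulers. LambdaLR(...) here is a custom helper class, not
# torch.optim.lr_scheduler.LambdaLR; its .step(epoch) returns the lr multiplier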
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
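The schedulers above pass LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step as lr_lambda, but that helper is not shown in the excerpt. The CycleGAN paper keeps the learning rate constant for the first 100 epochs and then decays it linearly to zero over the next 100, so a plausible sketch of the helper, assuming exactly that schedule, is below. The class body is a hypothetical reconstruction; only its name and constructor arguments come from the code above.

```python
class LambdaLR:
    """Multiplicative lr factor for torch.optim.lr_scheduler.LambdaLR.

    Hypothetical reconstruction: the factor stays at 1.0 until
    decay_start_epoch, then decays linearly so it reaches 0.0 at n_epochs.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        if n_epochs <= decay_start_epoch:
            raise ValueError("decay must start before training ends")
        self.n_epochs = n_epochs
        self.offset = offset  # starting epoch, e.g. when resuming from a checkpoint
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        # torch's LambdaLR calls this with its own counter, which starts at 0
        # even when training resumes, hence the offset
        elapsed = max(0, epoch + self.offset - self.decay_start_epoch)
        return 1.0 - elapsed / (self.n_epochs - self.decay_start_epoch)


schedule = LambdaLR(n_epochs=200, offset=0, decay_start_epoch=100)
print([round(schedule.step(e), 2) for e in (0, 99, 100, 150, 199)])
# [1.0, 1.0, 1.0, 0.5, 0.01]
```

The offset matters when resuming: a run restarted at epoch 150 gets a factor of 0.5 on its first scheduler step rather than restarting the schedule at 1.0.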
Training Iteration
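  • Each training sample contains one image from domain A and one from domain B, but the data is unpaired: A and B have no direct correspondence. A is the original-image domain and B is the style domain.
  • Generator training:
    • GAN loss: the generated images fake_B and fake_A are scored by D_B and D_A, and the generators are penalized (MSE against the all-ones valid target) when the discriminators do not judge them as real.
    • Cycle loss: the L1 difference between the reconstructions recov_A = G_BA(fake_B), recov_B = G_AB(fake_A) and the original images real_A, real_B.
    • Identity loss: the L1 difference between G_BA(real_A) and real_A, and between G_AB(real_B) and real_B; it discourages a generator from changing images that already belong to its target domain.
  • Discriminator training:
    • loss_real: MSE between the discriminator's output on real A/B images and the valid (all-ones) target.
    • loss_fake: MSE between the discriminator's output on the generated fake_A/fake_B and the fake (all-zeros) target.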
prev_time = time.time()
for epoch in range(opt.epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):

        # Each batch holds images from domain A and domain B; they are unpaired (no direct correspondence)
        real_A = Variable(batch["A"].type(Tensor))
        real_B = Variable(batch["B"].type(Tensor))

        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)

        # ------------------
        #  Train Generators
        # ------------------

        G_AB.train()
        G_BA.train()

        optimizer_G.zero_grad()

        # Identity loss
        loss_id_A = criterion_identity(G_BA(real_A), real_A)
        loss_id_B = criterion_identity(G_AB(real_B), real_B)

        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss
        fake_B = G_AB(real_A)
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
        fake_A = G_BA(real_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss
        recov_A = G_BA(fake_B)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        recov_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)

        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total loss
        loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity

        loss_G.backward()
        optimizer_G.step()

        # -----------------------
        #  Train Discriminator A
        # -----------------------

        optimizer_D_A.zero_grad()

        # Real loss
        loss_real = criterion_GAN(D_A(real_A), valid)
        # Fake loss on the current batch of generated images, detached so that no
        # gradients flow into the generators (the CycleGAN paper instead uses a
        # history buffer of previously generated images here)
        loss_fake = criterion_GAN(D_A(fake_A.detach()), fake)
        # Total loss
        loss_D_A = (loss_real + loss_fake) / 2

        loss_D_A.backward()
        optimizer_D_A.step()

        # -----------------------
        #  Train Discriminator B
        # -----------------------

        optimizer_D_B.zero_grad()

        # Real loss
        loss_real = criterion_GAN(D_B(real_B), valid)
        # Fake loss on the current batch of generated images, detached so that no
        # gradients flow into the generators (the CycleGAN paper instead uses a
        # history buffer of previously generated images here)
        loss_fake = criterion_GAN(D_B(fake_B.detach()), fake)
        # Total loss
        loss_D_B = (loss_real + loss_fake) / 2

        loss_D_B.backward()
        optimizer_D_B.step()

        loss_D = (loss_D_A + loss_D_B) / 2

        # --------------
        #  Log Progress
        # --------------

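        # Determine approximate time left
        batches_done = epoch * len(dataloader) + i
        batches_left = opt.n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print a one-line progress report
        print(
            f"[Epoch {epoch}/{opt.n_epochs}] [Batch {i}/{len(dataloader)}] "
            f"[D loss: {loss_D.item():.4f}] [G loss: {loss_G.item():.4f}, "
            f"adv: {loss_GAN.item():.4f}, cycle: {loss_cycle.item():.4f}, "
            f"identity: {loss_identity.item():.4f}] ETA: {time_left}"
        )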

    # Update learning rates
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()
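The comments on the discriminator fake losses mention a fake_A_buffer/fake_B_buffer with a push_and_pop method, which the excerpt does not define. The CycleGAN paper updates the discriminators with a history of 50 previously generated images rather than only the latest ones. Below is a framework-agnostic sketch of such a buffer; the class and method names follow those comments, while the 50% swap rule is an assumption borrowed from common open-source implementations.

```python
import random


class ReplayBuffer:
    """Pool of previously generated samples for training the discriminators.

    Sketch under assumptions: store up to max_size past samples; once the
    pool is full, each new sample is either passed through unchanged or
    (with probability 0.5) swapped for a randomly chosen stored sample.
    """

    def __init__(self, max_size=50, rng=None):
        if max_size <= 0:
            raise ValueError("max_size must be positive")
        self.max_size = max_size
        self.data = []
        self.rng = rng if rng is not None else random.Random()

    def push_and_pop(self, batch):
        out = []
        for element in batch:
            if len(self.data) < self.max_size:
                # Pool not full yet: store the sample and also return it
                self.data.append(element)
                out.append(element)
            elif self.rng.random() > 0.5:
                # Return a random old sample and store the new one in its place
                i = self.rng.randrange(self.max_size)
                out.append(self.data[i])
                self.data[i] = element
            else:
                # Pass the new sample through unchanged
                out.append(element)
        return out


# Possible use in the PyTorch loop (buffer created once, before the loop):
#   fake_A_buffer = ReplayBuffer()
#   fake_A_ = torch.stack(fake_A_buffer.push_and_pop(list(fake_A.detach())))
#   loss_fake = criterion_GAN(D_A(fake_A_), fake)
```

Feeding the discriminator a mix of old and new fakes keeps it from overfitting to the generator's latest outputs, which the CycleGAN authors use to reduce model oscillation.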

2.2.4 Result Presentation

This article trains a Monet style-transfer model. As shown in the figure, the first two rows are Monet paintings converted into ordinary photos, and the third and fourth rows are ordinary photos converted into Monet-style paintings.

[Figure: Monet-to-photo results (rows 1-2) and photo-to-Monet results (rows 3-4)]

Now let’s take a look at actual photos taken by a mobile phone:

[Figure: results on photos taken with a mobile phone]

2.2.5 Other Uses of CycleGAN

[Figures: other applications of CycleGAN]

3. Summary

This article introduced CycleGAN, a GAN variant, and its application to image style transfer. In summary:
  • GAN learns the distribution of the data and generates new data that follows the same distribution.
  • CycleGAN consists of two GANs running in opposite directions: two generators and two discriminators. To ensure that generated images are related to their inputs rather than produced arbitrarily, it adds cycle consistency, which penalizes the difference between A and its reconstruction A -> fake_B -> recov_A.
  • Generator: downsampling + residual blocks + upsampling.
  • Discriminator: instead of a single judgment for the whole image, it uses PatchGAN, which outputs N*N values whose losses are averaged.

Author Introduction: Wedo Experimenter, Data Analyst; loves life and writing.
