TensorFlow 2.0 PGGAN: Progressive Growing of GANs for Improved Quality, Stability, and Variation

Here, I introduce simple code that implements PGGAN in TensorFlow 2.0.

Introduction

The key idea of “PGGAN” is to grow the generator and discriminator progressively. This approach speeds up training, makes it much more stable, and produces high-quality images.

Progressive Growing of GANs

Existing methods learn image features at all resolutions at the same time. This paper instead proposes a useful method for generating high-resolution images: training starts at a low resolution of “4x4”, so that the large-scale structure is learned first, and the resolution is then gradually increased up to “1024x1024” while finer-scale details are learned.

To double the resolution of the image features, new layers are faded into the network smoothly, as shown in the figure above. This prevents sudden shocks to the already well-trained low-resolution layers. Let’s look at how to implement this progressively growing structure in TensorFlow 2.0.
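In formula form, the fade-in during a transition can be written as below. The notation is my own, not the paper's: h is the output of the last generator block at the old resolution, x is an input image at the new resolution, up(·) is 2x upsampling, down(·) is 2x average pooling, B_G and B_D are the new blocks (two 3x3 convolutions each), and α ramps from 0 to 1 during the fade-in.

```latex
\begin{aligned}
y_G &= (1 - \alpha)\,\mathrm{toRGB}_{\mathrm{old}}\big(\mathrm{up}(h)\big)
      + \alpha\,\mathrm{toRGB}_{\mathrm{new}}\big(B_G(\mathrm{up}(h))\big) \\
f_D &= (1 - \alpha)\,\mathrm{fromRGB}_{\mathrm{old}}\big(\mathrm{down}(x)\big)
      + \alpha\,\mathrm{down}\big(B_D(\mathrm{fromRGB}_{\mathrm{new}}(x))\big)
\end{aligned}
```

At α = 0 the network behaves exactly like the old, lower-resolution one; at α = 1 only the new layers are used.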

The lowest-level (4x4 resolution) discriminator and generator are built first. Refer to appendix A of the paper for the hyperparameters and the structure of each level. The networks consist of 4x4 and 3x3 convolutions with ‘LeakyReLU’ activations, and the “toRGB” and “fromRGB” blocks use 1x1 convolutions.
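The table in appendix A keeps 512 feature maps up to 32x32 and then halves the count each time the resolution doubles, down to 16 at 1024x1024. A small helper like the one below is convenient when building each level. It is a hypothetical helper of my own (`num_filters`, `fmap_base`, and `fmap_max` are illustrative names, with values chosen to reproduce that table), not code from the paper or my repository.

```python
def num_filters(res, fmap_base=16384, fmap_max=512):
    """Number of feature maps for the block at resolution `res` (hypothetical helper)."""
    # 512 channels up to 32x32, then halved each time the resolution doubles
    return min(fmap_max, fmap_base // res)


# Channels per level, from the 4x4 base networks up to 1024x1024
schedule = {res: num_filters(res) for res in [4, 8, 16, 32, 64, 128, 256, 512, 1024]}
print(schedule)
```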

Let me show how to grow the generator. The “fade in” block is placed alongside the existing “toRGB” block. First, take the node just above the existing “toRGB” block and double its resolution by upsampling. Then reuse the existing “toRGB” block on this upsampled node and call the result “x1”. Next, define a “fade in” block (x2) with two 3x3 convolutions and a new “toRGB”. Finally, combine x1 and x2 with “WeightedSum” so that the “fade in” block is introduced smoothly. For convenience, the generator in the stabilized state (c) of Figure 2 in the paper (`generator_stabilize`) is also defined here. To be honest, the activation of toRGB is just ‘linear’ in the table in appendix A, but I used “tanh” to get more stable output, although it may reduce diversity.
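The data flow of the generator transition can be sketched framework-agnostically in NumPy. This is an illustration under assumptions, not the TensorFlow code from my repository: arrays are NHWC, upsampling is nearest-neighbor (the default of Keras `UpSampling2D`), toRGB is a linear 1x1 convolution as in appendix A, and `new_block` is a hypothetical stand-in for the two 3x3 convolutions. `upsample2x`, `conv1x1`, and `fade_in_generator` are illustrative names.

```python
import numpy as np


def upsample2x(x):
    """Nearest-neighbor 2x upsampling of an NHWC array."""
    return x.repeat(2, axis=1).repeat(2, axis=2)


def conv1x1(x, w):
    """A 1x1 convolution is a per-pixel matrix multiply over channels; w: (C_in, C_out)."""
    return x @ w


def fade_in_generator(h, w_rgb_old, w_rgb_new, new_block, alpha):
    up = upsample2x(h)                          # double the node above the old toRGB
    x1 = conv1x1(up, w_rgb_old)                 # reuse the existing toRGB
    x2 = conv1x1(new_block(up), w_rgb_new)      # new block followed by a new toRGB
    return (1.0 - alpha) * x1 + alpha * x2      # WeightedSum


rng = np.random.default_rng(0)
h = rng.normal(size=(2, 4, 4, 8))     # features from the last 4x4 block
w_old = rng.normal(size=(8, 3))       # existing toRGB kernel
w_new = rng.normal(size=(8, 3))       # new toRGB kernel
new_block = np.tanh                   # hypothetical stand-in for two 3x3 convolutions

out = fade_in_generator(h, w_old, w_new, new_block, alpha=0.0)
print(out.shape)  # (2, 8, 8, 3)
```

Because toRGB is a 1x1 convolution, applying the old toRGB after upsampling gives the same image as upsampling the old network's output, so at alpha = 0 the grown generator reproduces the previous one exactly.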

The fade-in network for the discriminator is implemented similarly to that of the generator. First, double the input resolution. Reuse the existing “fromRGB” block on the input downsampled by average pooling and call the result “x1”. Define a “fade in” block (x2) with a new “fromRGB” and two 3x3 convolutions, followed by an AveragePooling2D layer. Then combine x1 and x2 with “WeightedSum” so that the “fade in” block is introduced smoothly, and append the existing discriminator layers. I also define the discriminator in the stabilized state (c) here.
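The discriminator side can be sketched the same way. Again this is a NumPy illustration under assumptions: `avg_pool2x` mimics a 2x2, stride-2 `AveragePooling2D`, fromRGB is a 1x1 convolution written as a matrix multiply, and `new_block` is a hypothetical stand-in for the two 3x3 convolutions (in the real network the new fromRGB typically outputs fewer channels than the old one, and the convolutions map them up; the stand-in keeps the channel count equal for simplicity).

```python
import numpy as np


def avg_pool2x(x):
    """2x2 average pooling with stride 2 on an NHWC array."""
    n, h, w, c = x.shape
    return x.reshape(n, h // 2, 2, w // 2, 2, c).mean(axis=(2, 4))


def fade_in_discriminator(img, w_from_old, w_from_new, new_block, alpha):
    """Blend the old and new input paths of the discriminator during a fade-in."""
    x1 = avg_pool2x(img) @ w_from_old               # downsample input, reuse existing fromRGB
    x2 = avg_pool2x(new_block(img @ w_from_new))    # new fromRGB -> new convs -> AveragePooling2D
    return (1.0 - alpha) * x1 + alpha * x2          # WeightedSum, then the existing layers


rng = np.random.default_rng(0)
img = rng.normal(size=(2, 8, 8, 3))    # images at the new (doubled) resolution
w_old = rng.normal(size=(3, 16))       # existing 1x1 fromRGB (3 -> 16 channels)
w_new = rng.normal(size=(3, 16))       # new 1x1 fromRGB
new_block = np.tanh                    # hypothetical stand-in for two 3x3 convolutions

features = fade_in_discriminator(img, w_old, w_new, new_block, alpha=0.5)
print(features.shape)  # (2, 4, 4, 16)
```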

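WeightedSum can be implemented by inheriting from the Keras “Add” layer and overriding “_merge_function” to compute the weighted sum (1 - alpha) * x1 + alpha * x2. Note that alpha is updated during training, so it should be defined with backend.variable rather than as a plain Python float, which could otherwise be baked in as a constant when the model is built.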
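The role of alpha during training can be shown with a plain NumPy stand-in. Assumptions: alpha ramps linearly from 0 to 1 over the fade-in phase, `fade_images=800_000` mirrors the 800k images per step mentioned in the results below, and `WeightedSum` and `fade_in_alpha` here are illustrative, not the Keras layer itself. In TF 2.x's tf.keras, `backend.variable` returns a `tf.Variable`, so the real layer's alpha can be updated between steps with `layer.alpha.assign(new_alpha)`.

```python
import numpy as np


class WeightedSum:
    """NumPy stand-in for the Keras WeightedSum layer; alpha is mutable state."""

    def __init__(self, alpha=0.0):
        self.alpha = alpha

    def __call__(self, old, new):
        return (1.0 - self.alpha) * old + self.alpha * new


def fade_in_alpha(images_seen, fade_images=800_000):
    """Linearly ramp alpha from 0 to 1 over `fade_images` training images (assumed schedule)."""
    return min(1.0, images_seen / fade_images)


ws = WeightedSum()
old, new = np.zeros(4), np.ones(4)
for images_seen in [0, 200_000, 400_000, 800_000, 1_000_000]:
    ws.alpha = fade_in_alpha(images_seen)   # with Keras: ws_layer.alpha.assign(...)
    print(images_seen, ws.alpha, ws(old, new)[0])
```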

Minibatch Discriminator

GANs tend to capture only part of the variation in the training data. This phenomenon is called mode collapse, and methods such as feature matching and historical averaging can be used to address it. PGGAN uses a simplified variant of the minibatch discrimination introduced in “Improved Techniques for Training GANs” by Salimans et al.; unlike the original, it has no learnable parameters. The method computes feature statistics not only for individual images but across the whole minibatch, which encourages the minibatches of generated and training images to have similar statistics. It is applied by adding a minibatch standard deviation layer near the end of the discriminator.

The minibatch layer, which adds these minibatch feature statistics, can be implemented as follows.

1) Compute the standard deviation of each feature at each spatial location over the minibatch.
2) Average these estimates over all features and spatial locations to get a single value.
3) Create a constant feature map with this average by tiling.
4) Concatenate the input and this average-standard-deviation feature map.
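The four steps above can be sketched in NumPy for an NHWC batch. This is a minimal illustration, not my TensorFlow layer: the name `minibatch_stddev` and the small `eps` added inside the square root for numerical stability are my assumptions.

```python
import numpy as np


def minibatch_stddev(x, eps=1e-8):
    """Append the average minibatch standard deviation as one extra channel (NHWC)."""
    n, h, w, _ = x.shape
    # 1) std of each feature at each spatial location, computed over the minibatch axis
    std = np.sqrt(x.var(axis=0) + eps)                          # shape (H, W, C)
    # 2) average over all features and spatial locations -> a single scalar
    mean_std = std.mean()
    # 3) constant feature map built by tiling the scalar
    feat = np.tile(np.reshape(mean_std, (1, 1, 1, 1)), (n, h, w, 1))
    # 4) concatenate it to the input as an extra channel
    return np.concatenate([x, feat], axis=-1)


x = np.stack([np.zeros((2, 2, 3)), np.full((2, 2, 3), 2.0)])    # batch of 2 images
y = minibatch_stddev(x)
print(y.shape)  # (2, 2, 2, 4)
```

In this toy batch every feature is 0 in one image and 2 in the other, so each standard deviation is 1 and the appended channel is (almost exactly) 1 everywhere.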

Equalized Learning Rate

Instead of relying on careful weight initialization, PGGAN initializes the weights from a standard normal distribution N(0, 1) and scales them at runtime: ŵ = w / c, where c is the per-layer normalization constant from He’s initializer.

Optimizers like Adam or RMSProp normalize each gradient update by its estimated standard deviation, which makes the update size independent of the parameter’s scale. As a result, parameters with a large dynamic range take longer to converge. The appropriate scale of a layer’s weights depends on the number of inputs feeding each output (fan_in), so scaling the weights by the He constant (stddev = sqrt(2 / fan_in)) gives every layer a similar dynamic range. This ensures that all weight parameters have the same dynamic range and thus the same learning speed.
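The per-layer constant can be computed from the kernel shape. A minimal NumPy sketch, assuming Keras kernel layouts ((in, out) for Dense, (kh, kw, in, out) for Conv2D) and the hypothetical helper name `he_scale`. The stored weights have standard deviation 1 in every layer, while the effective weights used in the forward pass get the He standard deviation.

```python
import numpy as np


def he_scale(kernel_shape):
    """Runtime multiplier sqrt(2 / fan_in), i.e. 1 / c in w_hat = w / c."""
    # Keras layouts: Dense kernel (in, out); Conv2D kernel (kh, kw, in, out)
    fan_in = int(np.prod(kernel_shape[:-1]))
    return np.sqrt(2.0 / fan_in)


rng = np.random.default_rng(0)
for shape in [(512, 256), (3, 3, 64, 64), (1, 1, 16, 3)]:
    w = rng.standard_normal(shape)     # stored weights: N(0, 1) in every layer
    w_hat = w * he_scale(shape)        # effective weights used in the forward pass
    print(shape, round(float(w.std()), 2), round(float(w_hat.std()), 4))
```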

To implement the “Equalized Learning Rate” scheme, I defined a WeightScaling class that adjusts the dynamic range, and redefined the dense and convolution layers as weight-scaled versions that use the WeightScaling layer.
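A weight-scaled dense layer can be sketched in NumPy as follows. This is a hypothetical illustration of the idea, not my WeightScaling class: `WeightScaledDense`, the `gain=2.0` default, and the seeded initialization are assumptions. The stored weights stay N(0, 1), and the He constant is applied on every forward pass by scaling the output, which is equivalent to scaling the weights.

```python
import numpy as np


class WeightScaledDense:
    """Dense layer with an equalized learning rate (hypothetical NumPy sketch)."""

    def __init__(self, in_units, out_units, gain=2.0, seed=0):
        rng = np.random.default_rng(seed)
        self.w = rng.standard_normal((in_units, out_units))  # stored weights ~ N(0, 1)
        self.b = np.zeros(out_units)
        self.scale = np.sqrt(gain / in_units)                 # He constant, fan_in = in_units

    def __call__(self, x):
        # Scaling the output is equivalent to using x @ (w * scale)
        return (x @ self.w) * self.scale + self.b


layer = WeightScaledDense(8, 4)
x = np.ones((2, 8))
print(layer(x).shape)  # (2, 4)
```

Because the optimizer only ever sees the unscaled N(0, 1) weights, every layer’s parameters live on the same scale, which is the point of the scheme.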

Pixel normalization

For stable training, pixel normalization is added after each convolution layer of the generator. It normalizes the feature vector at each pixel (axis=-1) to unit length, in the sense of dividing it by its root-mean-square value: b = a / sqrt(mean(a²) + ε). This effectively prevents the signal magnitudes from escalating.
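Pixel normalization is a one-liner in NumPy. A minimal sketch, assuming NHWC features and ε = 1e-8; `pixel_norm` is an illustrative name.

```python
import numpy as np


def pixel_norm(x, eps=1e-8):
    """Divide each pixel's feature vector (last axis) by its root-mean-square value."""
    return x / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)


rng = np.random.default_rng(0)
features = 5.0 * rng.normal(size=(2, 4, 4, 16))   # deliberately large activations
normed = pixel_norm(features)
print(np.mean(np.square(normed), axis=-1).max())
```

After normalization, the mean of squares at every pixel is 1, regardless of how large the incoming activations were.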

Loss Functions

Discriminator loss

The discriminator loss is defined as the sum of three losses: the adversarial loss, the gradient penalty loss, and the drift loss, a small penalty on D(x)² that keeps the discriminator output from drifting too far from zero. Please refer to the Keras WGAN-GP Example if you want to know more about the gradient penalty loss.
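The three terms can be sketched in NumPy as a function of the discriminator’s scores. Assumptions: `discriminator_loss` and its argument names are hypothetical; the adversarial term is the Wasserstein critic loss implied by the gradient penalty; the gradient norms of D at random interpolates between real and fake images are passed in precomputed (in TF 2.0 they would come from `tf.GradientTape`, as in the Keras WGAN-GP Example); and the default weights follow the WGAN-GP and PGGAN papers (λ = 10, ε_drift = 0.001).

```python
import numpy as np


def discriminator_loss(real_scores, fake_scores, interp_grad_norms,
                       gp_weight=10.0, drift_weight=0.001):
    """WGAN-GP critic loss plus PGGAN's drift term (NumPy sketch)."""
    # Adversarial (Wasserstein) loss: push real scores up and fake scores down
    adversarial = np.mean(fake_scores) - np.mean(real_scores)
    # Gradient penalty: the gradient norm at interpolated images should be close to 1
    gradient_penalty = np.mean((interp_grad_norms - 1.0) ** 2)
    # Drift loss: keeps D(x) on real images from drifting far away from 0
    drift = np.mean(np.square(real_scores))
    return adversarial + gp_weight * gradient_penalty + drift_weight * drift


loss = discriminator_loss(real_scores=np.array([1.0, 3.0]),
                          fake_scores=np.array([0.0, 0.0]),
                          interp_grad_norms=np.array([1.0, 3.0]))
print(round(loss, 6))  # -2 + 10 * 2 + 0.001 * 5 = 18.005
```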

The generator loss is really simple: it uses only the adversarial loss.
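Assuming the Wasserstein (WGAN-GP) formulation that the gradient penalty term implies, the generator’s adversarial loss is the negated mean discriminator score on generated images, the counterpart of the first term of the discriminator loss:

```latex
L_G = -\,\mathbb{E}_{z}\big[\,D(G(z))\,\big]
```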

You can find the full code in my GitHub repository. :)

Results

I trained the model on images from the CelebAMask-HQ dataset.
For each step, the model was trained on 800k images.
The figure below shows the training process for 256x256 resolution images.
The model was trained on a GeForce GTX 1080 Ti graphics card, and training took about 3 days.
In the figure below, you can see that the perturbation is stronger than I expected. The main reason is that the discriminator in this example does not learn label information.

However, it shows pretty good results, as shown in the figure below. The interpolation results are also nice, and they show that the generator does not learn only a small part of the features of the data but represents the entire distribution well.

Acknowledgements

TKarras’s repository, https://github.com/tkarras/progressive_growing_of_gans

Keras WGAN-GP Example, https://keras.io/examples/generative/wgan_gp/
