Tensorflow-GAN: Basics of Generative Adversarial Networks

moonhwan Jeong
Feb 21, 2019


Machine learning is generally classified into three types: supervised learning, unsupervised learning, and reinforcement learning.

Categories of machine learning. Image by https://www.techleer.com/articles/203-machine-learning-algorithm-backbone-of-emerging-technologies/

Understanding objects is the ultimate goal of supervised and unsupervised learning. With a well-trained discriminative model, we can classify images; with a well-trained generative model, we can create new sample images.

What I cannot create, I do not understand.
-Richard Feynman

Put the other way around: if I truly understand something, I should be able to create it. But Feynman also said, “What does it mean, to understand? … I don’t know.” Understanding is a genuinely difficult task.

Ian Goodfellow introduced GANs (Generative Adversarial Networks) as a new approach to understanding data.

What Are GANs?

A GAN consists of a generator and a discriminator that compete against each other, and through this competition both gradually improve.

Ian Goodfellow likened this process to a counterfeiter (the generator) and the police (the discriminator). The counterfeiter tries to fool the police, while the police try to classify bills as real or counterfeit. Through this competition, both get better at deceiving and at detecting. Eventually, the police can no longer distinguish counterfeit bills from real ones.

Conceptual diagram of a GAN

MNIST Tutorial

A sample MNIST image (the digit 1) as a matrix. Image from https://tensorflow.rstudio.com/tensorflow/articles/tutorial_mnist_beginners.html

The MNIST database consists of images of handwritten digits, each stored as a matrix of pixel values. In this tutorial, we will generate new images of handwritten digits. As noted above, to generate a digit image we need to know the distribution of the pixel values that make it up.

Evolution of G’s distribution (green) and D’s output (blue). Image by https://arxiv.org/abs/1406.2661

The figure above shows how G learns the true data distribution (black dots). As training proceeds, G’s distribution is pulled toward the true distribution. When the two match exactly, D can no longer tell real images from generated ones and outputs P = 0.5 for every input.
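The “P = 0.5” claim follows from a result in the original GAN paper: for a fixed G, the best possible discriminator is D*(x) = p_data(x) / (p_data(x) + p_g(x)). The NumPy sketch below evaluates that formula on a toy setup, with 1-D Gaussians standing in for the image distributions (an assumption made purely for illustration), and shows D* collapsing to 1/2 once p_g matches p_data.

```python
import numpy as np

def gaussian_pdf(x, mean, std):
    """Density of a 1-D normal distribution N(mean, std^2)."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))

def optimal_discriminator(p_data, p_g):
    """D*(x) = p_data(x) / (p_data(x) + p_g(x)) for a fixed generator."""
    return p_data / (p_data + p_g)

x = np.linspace(-3.0, 3.0, 7)  # evaluation points -3, -2, ..., 3
p_data = gaussian_pdf(x, mean=0.0, std=1.0)  # toy "true" distribution

# Generator far from the data: D* is near 1 where real data dominates
# and near 0 where generated data dominates.
d_far = optimal_discriminator(p_data, gaussian_pdf(x, mean=2.0, std=1.0))

# Generator matching the data exactly: D* = 1/2 everywhere.
d_match = optimal_discriminator(p_data, gaussian_pdf(x, mean=0.0, std=1.0))

print(np.round(d_far, 3))
print(d_match)
```

In the mismatched case the two densities cross at x = 1, which is exactly where D* equals 1/2; everywhere else D* leans toward whichever distribution is denser.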

Training requires two networks (a generator and a discriminator) and a training dataset.

Generator

A 2-layer Neural Network. Image by http://cs231n.github.io/neural-networks-1/

Our generator is very simple: it consists of two fully connected layers. The input layer has n_noise nodes, and the output layer has n_input nodes, which equals the number of pixels in an MNIST image. This two-layer network has four trainable variables: two weight matrices, G_W1 and G_W2, and two bias vectors, G_b1 and G_b2. The hidden layer uses ReLU as its activation function. Since MNIST pixel values lie in [0, 1], the output layer uses a sigmoid activation to squash its outputs into the same range.
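As a framework-free illustration, here is a minimal NumPy sketch of the generator’s forward pass. n_noise and n_input come from this article; the hidden width n_hidden = 256 and the small random initialization are assumptions, and the sketch only shows the computation, not training.

```python
import numpy as np

n_noise = 128      # length of the latent vector Z (from the article)
n_input = 28 * 28  # pixels per MNIST image (from the article)
n_hidden = 256     # hidden-layer width: an assumption, not stated in the article

rng = np.random.default_rng(0)

# The four trainable variables: two weight matrices and two bias vectors.
G_W1 = rng.normal(scale=0.01, size=(n_noise, n_hidden))
G_b1 = np.zeros(n_hidden)
G_W2 = rng.normal(scale=0.01, size=(n_hidden, n_input))
G_b2 = np.zeros(n_input)

def relu(x):
    return np.maximum(x, 0.0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def generator(Z):
    """Map a batch of noise vectors (batch, n_noise) to fake images (batch, n_input)."""
    hidden = relu(Z @ G_W1 + G_b1)        # hidden layer with ReLU
    return sigmoid(hidden @ G_W2 + G_b2)  # outputs in (0, 1), like MNIST pixels

Z = rng.normal(size=(16, n_noise))
fake = generator(Z)
print(fake.shape)
```

Each row of fake is one flattened 28×28 image; reshaping a row with fake[0].reshape(28, 28) recovers the image layout.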

Discriminator

Our discriminator also consists of two fully connected layers. Its input layer has n_input nodes, one per pixel. The discriminator outputs a single real-or-fake score, so its output layer has one node, with a sigmoid activation that maps the result to [0, 1]. If the input is judged to be fake, the output is close to 0; if it is judged to be real, the output is close to 1.
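The discriminator’s forward pass can be sketched the same way. Again this is NumPy rather than the article’s TensorFlow code; the hidden width of 256 and the ReLU hidden activation (mirroring the generator) are assumptions, and random pixels stand in for a real MNIST batch.

```python
import numpy as np

n_input = 28 * 28  # pixels per MNIST image (from the article)
n_hidden = 256     # hidden-layer width: an assumption, not stated in the article

rng = np.random.default_rng(1)

D_W1 = rng.normal(scale=0.01, size=(n_input, n_hidden))
D_b1 = np.zeros(n_hidden)
D_W2 = rng.normal(scale=0.01, size=(n_hidden, 1))  # a single output node
D_b2 = np.zeros(1)

def discriminator(X):
    """Return a realism score in (0, 1) for each image in X (batch, n_input)."""
    hidden = np.maximum(X @ D_W1 + D_b1, 0.0)  # ReLU hidden layer (assumed)
    logits = hidden @ D_W2 + D_b2              # shape (batch, 1)
    return 1.0 / (1.0 + np.exp(-logits))       # after training: near 0 = fake, near 1 = real

X = rng.uniform(0.0, 1.0, size=(16, n_input))  # stand-in for a batch of MNIST images
scores = discriminator(X)
print(scores.shape)
```

Because the same function scores both real and generated batches, D_real and D_gene in the training code share one set of discriminator weights.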

Train

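With the generator and discriminator defined, we now need training data: the MNIST dataset. Fortunately, TensorFlow provides it. The code in this article is based on golbin’s GitHub repository.

First, import the libraries: tensorflow, numpy, os, and plt (matplotlib.pyplot, for saving result images). Then import the generator and discriminator classes.

You can download and load the MNIST dataset with just a couple of lines (the input_data helper ships with TensorFlow 1.x and was removed in TensorFlow 2):

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./mnist/data/", one_hot=True)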

Next, define the training hyperparameters: total_epoch, batch_size, and learning_rate.

n_input is 28*28 = 784, the number of pixels in an MNIST image.
n_noise is the length of the latent (noise) vector; it is set to 128.
We also define a get_noise function that generates an array of random noise vectors.
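A minimal sketch of these definitions follows. The article names total_epoch, batch_size, and learning_rate without listing their values here, so the numbers below are illustrative placeholders; sampling the noise from a standard normal distribution is also an assumption.

```python
import numpy as np

# Illustrative values only; the article names these parameters but not their values.
total_epoch = 100
batch_size = 100
learning_rate = 0.0002

n_input = 28 * 28  # 784 pixels per MNIST image
n_noise = 128      # length of the latent vector

def get_noise(batch_size, n_noise):
    """Return a (batch_size, n_noise) array of random latent vectors.

    Sampling from a standard normal distribution is an assumption; the
    article only says the function generates a random vector array.
    """
    return np.random.normal(size=(batch_size, n_noise))

noise = get_noise(batch_size, n_noise)
print(noise.shape)
```

Each row of noise is one latent vector Z, so a single call gives the generator one input per image in the batch.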

After that, we instantiate the generator and the discriminator. G.net(Z) returns a generated (fake) sample from a random vector Z, and D.net() measures how realistic a sample is. D_gene is the realism score of a fake sample, and D_real is the realism score of a real sample from the MNIST dataset.

We want D to return a high score when it sees a real image and a low score when it sees a fake one. G, on the other hand, should create fake images that trick D into giving them a high score, so the two networks have opposing goals. Ian Goodfellow introduced the following value function V(D, G) to describe this two-player minimax game between D and G:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

The following line computes this value as an average over a batch:

loss_D = tf.reduce_mean(tf.log(D_real) + tf.log(1 - D_gene))

We train D to maximize this quantity. Because the optimizer’s minimize function minimizes whatever it is given, we pass it -loss_D: minimizing -loss_D is equivalent to maximizing loss_D.
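A quick numeric check shows how loss_D behaves. This sketch uses NumPy in place of the TensorFlow ops, and the discriminator outputs are made-up numbers for illustration: a confident, correct D pushes loss_D toward its maximum of 0, while a D that outputs 1/2 for everything, as at the GAN equilibrium, gives -log 4 ≈ -1.386.

```python
import numpy as np

def d_objective(D_real, D_gene):
    """Batch estimate of V(D, G): mean of log D(x) + log(1 - D(G(z)))."""
    return np.mean(np.log(D_real) + np.log(1.0 - D_gene))

# A confident, correct discriminator: real scores near 1, fake scores near 0.
good = d_objective(np.array([0.99, 0.98]), np.array([0.01, 0.02]))

# A discriminator that cannot tell real from fake outputs 1/2 everywhere.
confused = d_objective(np.array([0.5, 0.5]), np.array([0.5, 0.5]))

print(good, confused, -np.log(4.0))
```

Since good > confused, the quantity handed to minimize, -loss_D, is smaller for the better discriminator, which is exactly the direction the optimizer pushes D.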

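In the equation above, G should be trained to minimize log(1 − D(G(z))). Early in training, however, G produces poor samples that D rejects with high confidence, so D(G(z)) is close to 0. In that region the gradient of log(1 − D(G(z))) is small, and G learns very slowly. Instead, we train G to maximize log D(G(z)), which gives much stronger gradients early in training. Refer to the figure below.

Image by http://edoc.sub.uni-hamburg.de/haw/volltexte/2018/4361/pdf/bachelor_thesis.pdf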
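To see why the swap helps, write D(G(z)) = sigmoid(a), where a is the discriminator’s logit for a fake sample. The NumPy sketch below (with a handful of arbitrary logits chosen for illustration) compares the gradient of each generator objective with respect to a: when D confidently rejects a fake (very negative a), the gradient of log(1 − D(G(z))) nearly vanishes, while the gradient of log D(G(z)) stays close to 1.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

# Discriminator logits for generated samples; very negative logits mean
# D(G(z)) is close to 0, i.e. D easily rejects the fakes (early training).
logits = np.array([-6.0, -3.0, 0.0, 3.0])
d_fake = sigmoid(logits)

# Gradients with respect to the logit a, where D(G(z)) = sigmoid(a):
#   d/da log(1 - sigmoid(a)) = -sigmoid(a)     (original minimax objective)
#   d/da log(sigmoid(a))     = 1 - sigmoid(a)  (non-saturating objective)
grad_minimax = -d_fake
grad_nonsat = 1.0 - d_fake

for a, d, g1, g2 in zip(logits, d_fake, grad_minimax, grad_nonsat):
    print(f"logit={a:5.1f}  D(G(z))={d:.4f}  "
          f"grad log(1-D)={g1:.4f}  grad log D={g2:.4f}")
```

Both objectives push D(G(z)) upward, but only the non-saturating one does so with a useful gradient when the generator is still bad.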

Initialize all variables using sess.run(tf.global_variables_initializer()).

We optimize D by running sess.run([train_D, loss_D]), which requires feeding inputs. train_D depends on loss_D, which depends on D_gene and D_real. D_gene depends on G_out, which depends on Z, and D_real depends on X. So we must feed both X and Z: X is assigned batch_xs, a batch of images drawn from the MNIST dataset, and Z is assigned noise, generated by the get_noise function.
As with D, G is optimized with sess.run([train_G, loss_G], feed_dict={Z: noise}). Only Z needs to be fed here, because the generator’s loss does not involve real images.
We print the loss values once per epoch.

We also save generated images every 10 epochs.
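One way to prepare samples for saving is to tile the flattened 784-pixel outputs into a single image. The helper below is hypothetical (make_grid and should_save are not from the article), and the save schedule, after the first epoch and then every 10th, is an assumed reading of “every 10 epochs”; random pixels stand in for the generator’s output.

```python
import numpy as np

def make_grid(samples, n_cols=10, size=28):
    """Tile a batch of flattened images (n, size*size) into one 2-D array.

    Hypothetical helper for illustration; the layout is an assumption.
    """
    n = samples.shape[0]
    n_rows = int(np.ceil(n / n_cols))
    grid = np.zeros((n_rows * size, n_cols * size))
    for i, flat in enumerate(samples):
        r, c = divmod(i, n_cols)  # row and column of this tile
        grid[r * size:(r + 1) * size, c * size:(c + 1) * size] = flat.reshape(size, size)
    return grid

def should_save(epoch, every=10):
    """Save after the first epoch and then every `every` epochs (assumed schedule)."""
    return epoch == 0 or (epoch + 1) % every == 0

samples = np.random.uniform(0.0, 1.0, size=(10, 28 * 28))  # stand-in for G's output
grid = make_grid(samples)
print(grid.shape)
print([e for e in range(30) if should_save(e)])
```

The resulting array can then be written to disk with matplotlib, for example plt.imsave("samples.png", grid, cmap="Greys"), where the file name is only a placeholder.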

Generated Images

The generated images (fake samples) look like real handwritten digits.

Animation of the convergence process

The figure above shows that the generator gradually converges as training proceeds. A GAN can create digit images, but some artifacts remain: in some images, 7 and 9 are not clearly distinguishable.

With DCGAN, you can get much better images. I’ll cover this in the next article.

The full code for this article is available at the following link: https://github.com/fabulousjeong/gan-tensorflow
