TensorFlow: Making a Dataset from Images
A dataset is needed to train the model. In the last article, we covered a model for generating faces. To train it, I used the CelebA dataset, which contains about 200k portrait photos.
In a GAN, the dataset serves as the ground truth: the discriminator learns to tell these real images apart from the generator's output. In the training code, we fed “batch_x” to the discriminator. How do we get “batch_x” from image files?
iterator, image_count = ImageIterator(data_root, batch_size, model.image_size, model.image_channels).get_iterator()
I wrote an “ImageIterator” class that supplies “batch_x” through an iterator.
The functions preprocess_image and load_and_preprocess_image read an image file and process it, for example by cropping and normalizing. If needed, other augmentations such as flipping, tilting, or rotating can be added. Now let’s look at the get_iterator function, which is the core of the ImageIterator class.
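The call above also returns image_count. The pipeline in get_iterator repeats the data indefinitely, so the iterator never signals the end of an epoch by itself. One plausible use of image_count, assumed here rather than taken from the repository, is to work out how many batches make up one epoch. The following is a minimal sketch with a hypothetical helper named steps_per_epoch:

```python
def steps_per_epoch(image_count: int, batch_size: int) -> int:
    """Approximate number of batches in one pass over all images.

    With repeat() applied before batch(), every batch is full and a batch
    may straddle two passes, so integer division approximates one epoch.
    """
    return image_count // batch_size


# Illustrative numbers: roughly CelebA-sized, with a batch size of 64
print(steps_per_epoch(200_000, 64))  # 3125
```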
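This section does not show the exact crop size and value range that preprocess_image uses. The following plain-Python sketch of the arithmetic therefore makes two assumptions: a center crop, and pixel scaling to [-1, 1], which is a common choice for DCGAN generators that end in tanh. The helper names center_crop, normalize, and flip_left_right are hypothetical. In TensorFlow itself, tf.image.central_crop and tf.image.flip_left_right provide tensor versions of the crop and the flip.

```python
def center_crop(image, crop_h, crop_w):
    """Return the central crop_h x crop_w region of an image given as a list of pixel rows."""
    h, w = len(image), len(image[0])
    if crop_h > h or crop_w > w:
        raise ValueError("crop size is larger than the image")
    top = (h - crop_h) // 2
    left = (w - crop_w) // 2
    return [row[left:left + crop_w] for row in image[top:top + crop_h]]


def normalize(image):
    """Map 8-bit pixel values from [0, 255] to [-1, 1]."""
    return [[p / 127.5 - 1.0 for p in row] for row in image]


def flip_left_right(image):
    """Mirror the image horizontally, an example of an extra augmentation."""
    return [row[::-1] for row in image]


# A tiny 4x4 grayscale "image"
image = [
    [0, 0, 0, 0],
    [0, 255, 0, 0],
    [0, 0, 255, 0],
    [0, 0, 0, 0],
]
cropped = center_crop(image, 2, 2)   # [[255, 0], [0, 255]]
flipped = flip_left_right(cropped)   # [[0, 255], [255, 0]]
print(normalize(flipped))            # [[-1.0, 1.0], [1.0, -1.0]]
```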
First, we collect all the image paths in the directory and shuffle them.
Then, we create the dataset from the shuffled paths. TensorFlow provides tf.data.Dataset.from_tensor_slices for this job, so it takes only a few lines of code.
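The path-collection step can be sketched in plain Python. The suffix filter, the recursive search, and the seeded shuffle are all assumptions made for illustration; the original code may, for instance, glob a fixed pattern instead. The name list_shuffled_image_paths is hypothetical.

```python
import pathlib
import random
import tempfile

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # assumption: which files count as images


def list_shuffled_image_paths(data_root, seed=None):
    """Collect image file paths under data_root and return them shuffled, as strings."""
    root = pathlib.Path(data_root)
    # rglob("*") also walks subdirectories; keep only files with an image suffix
    paths = [str(p) for p in root.rglob("*")
             if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES]
    paths.sort()  # fixed starting order, so the same seed always gives the same shuffle
    random.Random(seed).shuffle(paths)
    return paths


# Usage: build a throwaway directory with a few files and list the images in it
with tempfile.TemporaryDirectory() as tmp:
    for name in ["a.jpg", "b.PNG", "notes.txt", "sub/c.jpeg"]:
        f = pathlib.Path(tmp, name)
        f.parent.mkdir(parents=True, exist_ok=True)
        f.touch()
    shuffled = list_shuffled_image_paths(tmp, seed=0)
    reshuffled = list_shuffled_image_paths(tmp, seed=0)
    names = sorted(pathlib.Path(p).name for p in shuffled)

print(names)  # ['a.jpg', 'b.PNG', 'c.jpeg']
```

The resulting list of strings is the kind of input that from_tensor_slices turns into one dataset element per path.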
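import tensorflow as tf
AUTOTUNE = tf.data.experimental.AUTOTUNE  # let tf.data choose the prefetch buffer size

path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)  # one element per file path
image_ds = path_ds.map(self.load_and_preprocess_image)  # path -> preprocessed image tensor
ds = image_ds.repeat()  # cycle through the images indefinitely
ds = ds.batch(self.batch_size)  # group images into batches of batch_size
ds = ds.prefetch(buffer_size=AUTOTUNE)  # prepare the next batch while the current one trains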
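The following plain-Python analogue, built on itertools, shows what repeat() followed by batch() produces. It mimics only the ordering of elements and is not how tf.data is implemented. The name repeat_then_batch is hypothetical.

```python
import itertools


def repeat_then_batch(items, batch_size):
    """Plain-Python analogue of dataset.repeat().batch(batch_size).

    repeat() cycles over the items forever, so batch() always has enough
    elements and every batch is full; a batch can span the end of one pass
    and the start of the next.
    """
    items = list(items)
    if not items:
        raise ValueError("need at least one item")
    endless = itertools.cycle(items)
    while True:
        yield list(itertools.islice(endless, batch_size))


batches = repeat_then_batch(["img0", "img1", "img2", "img3", "img4"], batch_size=2)
first_three = [next(batches) for _ in range(3)]
print(first_three)  # [['img0', 'img1'], ['img2', 'img3'], ['img4', 'img0']]
```

Because the paths are shuffled only once, before from_tensor_slices, every pass over the data visits the images in the same order. If a fresh order for each pass is wanted, one option is to insert tf.data's Dataset.shuffle(buffer_size) before repeat(); its reshuffle_each_iteration argument defaults to True. This is a suggestion, not part of the original code.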
We want to fetch one batch of data per step of the training loop (for or while). An iterator does exactly this, and creating one takes a single line:
iterator = ds.make_initializable_iterator()
We can then get “batch_x” with the following code:
next_element = iterator.get_next()  # create this op once, outside the training loop
batch_x = sess.run(next_element)  # each run returns the next batch
Don’t forget to initialize the iterator before training:
sess.run(iterator.initializer)
You can download the full code from https://github.com/fabulousjeong/dcgan-tensorflow
Refer to the TensorFlow tutorials for more details.