Keras custom data generators example with MNIST Dataset

Source: Deep Learning on Medium

In real-world problems, the datasets used to train our models often take up much more memory than we have in RAM. In that case we cannot load the entire dataset into memory and train the model with the standard Keras fit method.

One approach to this problem is to load only one batch of data into memory at a time and feed it to the network, repeating the process until the network has seen the whole dataset. Then we shuffle the dataset and start again.
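As a rough illustration of this loop, the sketch below walks over a dataset one batch at a time and reshuffles the sample order for each pass. It is plain Python with no Keras dependency; the function name `iterate_batches` and its parameters are hypothetical, chosen only for this example.

```python
import random


def iterate_batches(samples, batch_size, epochs, seed=0):
    """Yield (epoch, batch) pairs, reshuffling the sample order every epoch.

    Hypothetical helper: only one batch is materialized at a time.
    """
    rng = random.Random(seed)
    order = list(range(len(samples)))
    for epoch in range(epochs):
        rng.shuffle(order)  # new sample order for this pass
        for start in range(0, len(order), batch_size):
            batch_ids = order[start:start + batch_size]
            # In a real problem this is where samples would be read from disk
            yield epoch, [samples[i] for i in batch_ids]


# Toy dataset of 10 "samples", batches of 4 -> 3 batches per epoch
batches = list(iterate_batches(list(range(10)), batch_size=4, epochs=2))
print(len(batches))  # 6 batches in total over 2 epochs
```

Every sample appears exactly once per epoch, and the last batch of each epoch is smaller when the dataset size is not a multiple of the batch size.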

To build a custom generator, Keras provides the Sequence class. This class is abstract, so we write our own class that inherits from it and implements its __len__ and __getitem__ methods.

We are going to code a custom data generator that yields batches of samples from the MNIST dataset.

First, we import the Python libraries:

import tensorflow as tf
import os
import tensorflow.keras as keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D
import numpy as np
import math

Then we are going to load the MNIST dataset into RAM:

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

The MNIST dataset consists of 60000 training images of handwritten digits and 10000 testing images.

Each image has dimensions of 28 x 28 pixels. Keep in mind that, in order to train the model, we have to convert the uint8 data to float32, and each float32 pixel needs 4 bytes of memory.

Therefore the whole dataset needs:

4 bytes per pixel * (28 * 28) pixels per image * 70000 images + (70000 * 10) one-hot label values.

In total this is about 220 MB, which fits comfortably in RAM, but real-world problems may need much more memory.
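The arithmetic above can be checked with a few lines of Python. This is only a back-of-the-envelope sketch; it assumes the one-hot labels are also stored as float32 (4 bytes per value), which the text does not state.

```python
# Back-of-the-envelope memory estimate for MNIST held as float32.
BYTES_PER_FLOAT32 = 4
n_images = 60000 + 10000          # training + testing images
pixels_per_image = 28 * 28
n_classes = 10

image_bytes = BYTES_PER_FLOAT32 * pixels_per_image * n_images
# Assumption: one-hot labels are also stored as float32 (4 bytes per value)
label_bytes = BYTES_PER_FLOAT32 * n_classes * n_images
total_bytes = image_bytes + label_bytes

print(f"images: {image_bytes / 1e6:.1f} MB")  # images: 219.5 MB
print(f"labels: {label_bytes / 1e6:.1f} MB")  # labels: 2.8 MB
print(f"total:  {total_bytes / 1e6:.1f} MB")  # total:  222.3 MB
```

The images dominate the total; the labels add only a few megabytes.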

Our simulated generator is going to load the images from RAM, but in a real problem they would be loaded from the hard disk.

class DataGenerator(keras.utils.Sequence):
    """Yields batches of MNIST images (and one-hot labels) held in memory."""

    def __init__(self, X_data, y_data, batch_size, dim, n_classes,
                 to_fit, shuffle=True):
        super().__init__()
        self.batch_size = batch_size
        self.X_data = X_data
        self.labels = y_data
        self.y_data = y_data
        self.to_fit = to_fit
        self.n_classes = n_classes
        self.dim = dim
        self.shuffle = shuffle
        self.n = 0
        self.list_IDs = np.arange(len(self.X_data))
        # Build (and optionally shuffle) self.indexes before first use
        self.on_epoch_end()

    def __next__(self):
        # Get one batch of data
        data = self.__getitem__(self.n)
        # Batch index
        self.n += 1

        # If we have processed the entire dataset then start from batch 0
        if self.n >= self.__len__():
            self.n = 0

        return data

    def __len__(self):
        # Return the number of batches of the dataset
        return math.ceil(len(self.indexes) / self.batch_size)

    def __getitem__(self, index):
        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size:
                               (index + 1) * self.batch_size]
        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        X = self._generate_x(list_IDs_temp)

        if self.to_fit:
            y = self._generate_y(list_IDs_temp)
            return X, y
        return X

    def on_epoch_end(self):
        # Reset the sample order and reshuffle it if requested
        self.indexes = np.arange(len(self.X_data))

        if self.shuffle:
            np.random.shuffle(self.indexes)

    def _generate_x(self, list_IDs_temp):
        # Size by the actual batch length: the last batch may be smaller
        X = np.empty((len(list_IDs_temp), *self.dim))

        for i, ID in enumerate(list_IDs_temp):
            X[i,] = self.X_data[ID]

        # Normalize data
        X = (X / 255).astype('float32')

        # Add a channel axis: (batch, 28, 28) -> (batch, 28, 28, 1)
        return X[:, :, :, np.newaxis]

    def _generate_y(self, list_IDs_temp):
        y = np.empty(len(list_IDs_temp), dtype=int)

        for i, ID in enumerate(list_IDs_temp):
            y[i] = self.y_data[ID]

        # One-hot encode the labels
        return keras.utils.to_categorical(y, num_classes=self.n_classes)
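
The _generate_y step turns integer digit labels into one-hot vectors. The sketch below shows the idea in plain Python; `one_hot` is a hypothetical stand-in written for illustration, not the Keras to_categorical implementation.

```python
def one_hot(labels, n_classes):
    """Return one row per label, with a 1.0 at the label's index."""
    rows = []
    for label in labels:
        row = [0.0] * n_classes
        row[int(label)] = 1.0
        rows.append(row)
    return rows


encoded = one_hot([3, 0, 9], n_classes=10)
print(encoded[0])  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

Each label becomes a vector of length n_classes, which matches the 10-unit softmax output of the classifier.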

Then we are going to build the classification net:

n_classes = 10
input_shape = (28, 28)
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu',
                 input_shape=(28, 28, 1)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
# Flatten the feature maps so the Dense layers output one vector per image
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dense(n_classes, activation='softmax'))

The next step is to make an instance of our generators:

train_generator = DataGenerator(x_train, y_train, batch_size=64,
                                dim=input_shape,
                                n_classes=n_classes,
                                to_fit=True, shuffle=True)
val_generator = DataGenerator(x_test, y_test, batch_size=64,
                              dim=input_shape,
                              n_classes=n_classes,
                              to_fit=True, shuffle=True)

To check that the generator works correctly, we can call next() on it, which returns one batch of samples and labels. We can then verify that the data types of the images and labels are correct, check the dimensions of the batch, and so on.

images, labels = next(train_generator)
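
These checks can be written down as assertions. The helper below is a minimal sketch: `check_batch` is a hypothetical name, and the zero-filled arrays are stand-ins with the shapes the generator above is expected to produce, used here only so that the example runs without loading MNIST.

```python
import numpy as np


def check_batch(images, labels, batch_size, dim, n_classes):
    """Raise AssertionError if a batch has an unexpected shape, dtype or range."""
    assert images.shape == (batch_size, *dim, 1), images.shape
    assert images.dtype == np.float32, images.dtype
    assert 0.0 <= images.min() and images.max() <= 1.0
    assert labels.shape == (batch_size, n_classes), labels.shape
    # Each one-hot row must contain exactly one 1
    assert np.all(labels.sum(axis=1) == 1)
    return True


# Stand-in batch: 64 blank images, all labelled as class 0
fake_images = np.zeros((64, 28, 28, 1), dtype=np.float32)
fake_labels = np.eye(10, dtype=np.float32)[np.zeros(64, dtype=int)]
batch_ok = check_batch(fake_images, fake_labels,
                       batch_size=64, dim=(28, 28), n_classes=10)
```

With the real generator, the same helper would be called on the batch returned by next(train_generator), passing input_shape and n_classes from the code above.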

To feed the whole dataset to the network in every epoch, we set the number of steps to the number of batches in each generator:

steps_per_epoch = len(train_generator)
validation_steps = len(val_generator)
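
The length of a generator is the number of batches, that is, the number of samples divided by the batch size, rounded up. The short calculation below shows the values for MNIST with a batch size of 64; the variable names `n_train` and `n_test` are introduced only for this illustration.

```python
import math

n_train, n_test, batch_size = 60000, 10000, 64

steps_per_epoch = math.ceil(n_train / batch_size)
validation_steps = math.ceil(n_test / batch_size)
# Size of the final, partial batch (0 would mean the data divides evenly)
last_train_batch = n_train % batch_size
last_val_batch = n_test % batch_size

print(steps_per_epoch, last_train_batch)   # 938 32
print(validation_steps, last_val_batch)    # 157 16
```

Because neither size is a multiple of 64, the final batch of each pass is smaller than batch_size.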

Finally, we train the network with the Keras function fit_generator(). Note that since TensorFlow 2.1 fit_generator() is deprecated, and fit() accepts a Sequence directly.