Conquer Class Imbalanced Dataset Issues using GANs

Source: Deep Learning on Medium

The Keras summary() method can also be used to inspect each model's layer layout and number of trainable parameters. Calling the train() function then starts training the discriminator and generator.
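
The author's train() function is not reproduced in this excerpt. As a rough, framework-free sketch of the alternating schedule such a loop usually follows, the function below takes hypothetical callables in place of real models: train_d performs one discriminator update on a batch with a given label and returns its loss, train_g performs one generator update through the combined model and returns its loss, and generate_fake produces generated samples. In Keras these callables would typically wrap Model.train_on_batch. Reading d1 as the loss on real samples and d2 as the loss on generated samples is an assumption based on a common convention, not something this excerpt confirms.

```python
import random
from typing import Callable, List, Sequence, Tuple

# Hypothetical stand-ins for real models: each callable performs one update and returns a loss.
TrainD = Callable[[Sequence[float], float], float]  # (batch, label) -> discriminator loss
TrainG = Callable[[int], float]                     # (batch size) -> generator loss
GenerateFake = Callable[[int], List[float]]         # (n) -> n generated samples


def train_gan_sketch(train_d: TrainD, train_g: TrainG, generate_fake: GenerateFake,
                     real_data: Sequence[float], n_epochs: int,
                     batch_size: int) -> List[Tuple[int, int, float, float, float]]:
    """Alternate discriminator and generator updates and log one line per batch."""
    batches_per_epoch = len(real_data) // batch_size
    half_batch = batch_size // 2
    history = []
    for epoch in range(1, n_epochs + 1):
        for batch in range(1, batches_per_epoch + 1):
            # Discriminator step on a half batch of real samples labelled 1.0 (d1)
            real = random.sample(list(real_data), half_batch)
            d1 = train_d(real, 1.0)
            # Discriminator step on a half batch of generated samples labelled 0.0 (d2)
            d2 = train_d(generate_fake(half_batch), 0.0)
            # Generator step through the combined model; train_g is expected to label its
            # generated samples 1.0 so the generator learns to fool the discriminator (g)
            g = train_g(batch_size)
            print(f">{epoch}, {batch}/{batches_per_epoch}, d1={d1:.3f}, d2={d2:.3f} g={g:.3f}")
            history.append((epoch, batch, d1, d2, g))
    return history


if __name__ == "__main__":
    # Dummy callables so the loop runs without a deep-learning framework.
    data = [random.gauss(0.0, 1.0) for _ in range(200)]
    train_gan_sketch(
        train_d=lambda xs, label: 0.693,
        train_g=lambda n: 0.693,
        generate_fake=lambda n: [random.random() for _ in range(n)],
        real_data=data,
        n_epochs=1,
        batch_size=20,
    )
```

The dummy sizes (200 samples, batch size 20) are arbitrary; they simply give 10 batches per epoch, mirroring the "x/10" counter in the log format.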

Discriminator model

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 128, 128, 16)      448
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 128, 128, 16)      0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 64, 64, 8)         1160
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 64, 64, 8)         0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 32, 32, 16)        1168
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 32, 32, 16)        0
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 16, 16, 8)         1160
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 16, 16, 8)         0
_________________________________________________________________
flatten_1 (Flatten)          (None, 2048)              0
_________________________________________________________________
dropout_1 (Dropout)          (None, 2048)              0
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 2049
=================================================================
Total params: 11,970
Trainable params: 5,985
Non-trainable params: 5,985
_________________________________________________________________
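
The discriminator's per-layer counts can be reproduced by hand: a Conv2D layer has kernel x kernel x input channels x filters weights plus one bias per filter, while LeakyReLU, Flatten and Dropout add none. The sketch below assumes 3x3 kernels, 'same' padding, stride 1 for the first convolution and stride 2 for the other three; these settings are inferred from the printed shapes and counts, not taken from the author's code. The per-layer counts sum to 5,985, exactly half the printed total of 11,970. The doubling is very likely a reporting quirk of older standalone Keras: when a discriminator is compiled and then frozen with trainable = False (the usual step before stacking it into the combined GAN model), summary() counts the same weights once as trainable and once as non-trainable.

```python
import math
from typing import List, Tuple


def conv2d_same(h: int, w: int, c_in: int, filters: int,
                kernel: int = 3, stride: int = 1) -> Tuple[Tuple[int, int, int], int]:
    """Output shape and parameter count of a Conv2D layer with padding='same'."""
    out_shape = (math.ceil(h / stride), math.ceil(w / stride), filters)
    params = kernel * kernel * c_in * filters + filters  # weights + one bias per filter
    return out_shape, params


def discriminator_summary(input_shape=(128, 128, 3)) -> List[Tuple[str, tuple, int]]:
    # Assumed settings inferred from the summary: 3x3 kernels, 'same' padding,
    # stride 1 for the first convolution and stride 2 for the other three.
    conv_specs = [(16, 1), (8, 2), (16, 2), (8, 2)]  # (filters, stride)
    h, w, c = input_shape
    rows = []
    for i, (filters, stride) in enumerate(conv_specs, start=1):
        (h, w, c), params = conv2d_same(h, w, c, filters, stride=stride)
        rows.append((f"conv2d_{i}", (h, w, c), params))
    # LeakyReLU, Flatten and Dropout have no weights; Flatten yields h * w * c features.
    features = h * w * c
    rows.append(("dense_1", (1,), features + 1))  # Dense(1): one weight per feature + bias
    return rows


if __name__ == "__main__":
    rows = discriminator_summary()
    for name, shape, params in rows:
        print(f"{name:<10} {str(shape):<16} {params:>6,}")
    print("sum of per-layer params:", f"{sum(p for _, _, p in rows):,}")
```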

Generator model

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_2 (Dense)              (None, 65536)             6619136
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 65536)             0
_________________________________________________________________
reshape_1 (Reshape)          (None, 16, 16, 256)       0
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 32, 32, 128)       524416
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 32, 32, 128)       0
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 64, 64, 128)       262272
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 64, 64, 128)       0
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 128, 128, 128)     262272
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 128, 128, 128)     0
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 128, 128, 3)       3459
=================================================================
Total params: 7,671,555
Trainable params: 7,671,555
Non-trainable params: 0
_________________________________________________________________
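
The generator's counts can be checked the same way. Working backwards from the first Dense layer, 6,619,136 / 65,536 = 101, i.e. 100 weights plus one bias per unit, which implies a 100-dimensional latent input. The transposed-convolution counts (524,416 and 262,272) are consistent with 4x4 kernels, and each stride-2 layer with 'same' padding doubles the height and width, taking the reshaped 16x16x256 tensor up to 128x128. The sketch below encodes those inferences; they come from the summary, not from the author's published code.

```python
from typing import List, Tuple


def dense_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out  # weight matrix + bias vector


def conv_params(kernel: int, c_in: int, filters: int) -> int:
    # Same count for Conv2D and Conv2DTranspose: one kernel per (input, output) channel pair + biases
    return kernel * kernel * c_in * filters + filters


def generator_summary(latent_dim: int = 100) -> List[Tuple[str, tuple, int]]:
    # Assumed settings inferred from the summary: Dense -> Reshape to 16x16x256, three 4x4
    # transposed convolutions (stride 2, padding='same'), final 3x3 Conv2D with 3 filters.
    units = 16 * 16 * 256
    rows = [("dense_2", (units,), dense_params(latent_dim, units))]
    h, w, c = 16, 16, 256  # after Reshape (no weights)
    for i in range(1, 4):
        h, w = h * 2, w * 2  # stride 2 with 'same' padding doubles height and width
        rows.append((f"conv2d_transpose_{i}", (h, w, 128), conv_params(4, c, 128)))
        c = 128
    rows.append(("conv2d_5", (h, w, 3), conv_params(3, c, 3)))
    return rows


# Working backwards: each Dense unit holds latent_dim weights plus one bias.
inferred_latent_dim = 6_619_136 // (16 * 16 * 256) - 1

if __name__ == "__main__":
    rows = generator_summary(inferred_latent_dim)
    for name, shape, params in rows:
        print(f"{name:<20} {str(shape):<16} {params:>9,}")
    print("total params:", f"{sum(p for _, _, p in rows):,}")
```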

The training process prints progress like the snippet below:

.....
>319, 5/10, d1=0.692, d2=0.761 g=0.709
>319, 6/10, d1=0.822, d2=0.759 g=0.690
>319, 7/10, d1=0.733, d2=0.764 g=0.723
>319, 8/10, d1=0.662, d2=0.740 g=0.743
>319, 9/10, d1=0.701, d2=0.683 g=0.758
>319, 10/10, d1=0.830, d2=0.744 g=0.728
>320, 1/10, d1=0.749, d2=0.717 g=0.731
>320, 2/10, d1=0.677, d2=0.796 g=0.722
>320, 3/10, d1=0.766, d2=0.700 g=0.717
>320, 4/10, d1=0.676, d2=0.736 g=0.765
>320, 5/10, d1=0.792, d2=0.762 g=0.730
>320, 6/10, d1=0.690, d2=0.710 g=0.719
>320, 7/10, d1=0.807, d2=0.759 g=0.708
>320, 8/10, d1=0.715, d2=0.747 g=0.711
>320, 9/10, d1=0.719, d2=0.720 g=0.731
>320, 10/10, d1=0.695, d2=0.717 g=0.694
################# Summarize ###################
>Accuracy real: 35%, fake: 57%
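
One way to read this output: assuming the discriminator is trained with binary cross-entropy, a discriminator that outputs 0.5 for every image has a loss of -ln(0.5) = ln 2 ≈ 0.693, so losses hovering between roughly 0.66 and 0.83 suggest it can only barely tell real images from generated ones. The sketch below parses log lines of the format above and averages d1, d2 and g over epoch 320. parse_log and threshold_accuracy are hypothetical helpers, and the accuracy definition (share of real images scored above 0.5, share of generated images scored at or below 0.5) is an assumption; the excerpt does not show how the reported percentages were computed.

```python
import math
import re
import statistics
from typing import List, Sequence, Tuple

# Matches lines such as ">320, 1/10, d1=0.749, d2=0.717 g=0.731"
LOG_LINE = re.compile(r">(\d+), (\d+)/(\d+), d1=([\d.]+), d2=([\d.]+) g=([\d.]+)")

LOG = """\
>320, 1/10, d1=0.749, d2=0.717 g=0.731
>320, 2/10, d1=0.677, d2=0.796 g=0.722
>320, 3/10, d1=0.766, d2=0.700 g=0.717
>320, 4/10, d1=0.676, d2=0.736 g=0.765
>320, 5/10, d1=0.792, d2=0.762 g=0.730
>320, 6/10, d1=0.690, d2=0.710 g=0.719
>320, 7/10, d1=0.807, d2=0.759 g=0.708
>320, 8/10, d1=0.715, d2=0.747 g=0.711
>320, 9/10, d1=0.719, d2=0.720 g=0.731
>320, 10/10, d1=0.695, d2=0.717 g=0.694
"""


def parse_log(text: str) -> List[Tuple[int, int, float, float, float]]:
    """Extract (epoch, batch, d1, d2, g) from every matching log line."""
    rows = []
    for m in LOG_LINE.finditer(text):
        epoch, batch, _n_batches, d1, d2, g = m.groups()
        rows.append((int(epoch), int(batch), float(d1), float(d2), float(g)))
    return rows


def threshold_accuracy(scores: Sequence[float], label: int, threshold: float = 0.5) -> float:
    """Share of discriminator outputs on the correct side of the (assumed) threshold."""
    if label == 1:
        hits = sum(s > threshold for s in scores)   # real image scored as real
    else:
        hits = sum(s <= threshold for s in scores)  # generated image scored as fake
    return hits / len(scores)


if __name__ == "__main__":
    rows = parse_log(LOG)
    for name, idx in (("d1", 2), ("d2", 3), ("g", 4)):
        print(f"mean {name} at epoch 320: {statistics.mean(r[idx] for r in rows):.3f}")
    print(f"ln 2 = {math.log(2):.3f}  (loss of a discriminator that always outputs 0.5)")
```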

After 320 epochs, the images below show the quality I was able to produce. Training this GAN took about 30 minutes. More complex generator and discriminator models could have produced better-quality images.

The save_plot() function generates a 7 by 7 grid of images.
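
save_plot() itself is not shown in this excerpt. The plain-Python sketch below only illustrates the tiling step behind a 7 by 7 image grid, with images stored as nested lists of pixel tuples. make_grid and rescale are hypothetical helpers, and the rescaling from [-1, 1] to [0, 1] assumes a tanh output layer, which the generator summary above does not show.

```python
from typing import List, Sequence, Tuple

Pixel = Tuple[float, ...]   # e.g. (r, g, b)
Image = List[List[Pixel]]   # height x width grid of pixels


def rescale(img: Image) -> Image:
    """Map pixel values from [-1, 1] (assumed tanh output) to [0, 1] for display."""
    return [[tuple((v + 1) / 2 for v in px) for px in row] for row in img]


def make_grid(images: Sequence[Image], n: int = 7) -> Image:
    """Tile the first n * n equally sized images into one n-by-n mosaic, row by row."""
    if len(images) < n * n:
        raise ValueError(f"need at least {n * n} images, got {len(images)}")
    tile_h = len(images[0])
    grid: Image = []
    for r in range(n):
        tiles = images[r * n:(r + 1) * n]   # the n images in this grid row
        for y in range(tile_h):
            row: List[Pixel] = []
            for tile in tiles:
                row.extend(tile[y])         # append row y of each tile, left to right
            grid.append(row)
    return grid


if __name__ == "__main__":
    # 49 tiny 2x2 single-channel "images" whose pixel value is the image index
    imgs = [[[(float(i),)] * 2 for _ in range(2)] for i in range(49)]
    mosaic = make_grid(imgs)
    print(len(mosaic), "x", len(mosaic[0]))
```

In practice the same layout is usually built with NumPy arrays or matplotlib subplots; the nested-list version just makes the row-by-row tiling explicit.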

Progression of the GAN towards a perfect cube