AI Synthwave


A single 360 x 640 RGB image consists of 360*640*3 = 691200 values; down-sampling by a factor of two in each dimension gives us a four-fold reduction to 180*320*3 = 172800 values. This is a significant reduction in complexity, and therefore runtime.
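
As a quick sketch of that down-sampling step (using Pillow; the filename below is just a placeholder):

```python
from PIL import Image
import numpy as np

# load a single frame and halve both of its dimensions
frame = Image.open('frame.jpg')  # placeholder filename
small = frame.resize((frame.width // 2, frame.height // 2))

pixels = np.asarray(small)
print(pixels.shape, pixels.size)  # for a 360 x 640 frame: 180*320*3 = 172800 values
```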

If your network is too large (which is not necessarily due to the input image size, but can be) — you will receive an out-of-memory error along the lines of Error: OOM when allocating tensor.

We then need to format the data into something readable by TensorFlow.

This is easy with the TensorFlow dataset object. We first shuffle the data, so that we are not feeding the same group of images to the discriminator at once. Then we batch the data, meaning that a batchsize number of examples will be seen by the network before we update the network weights.

I’ve seen the batchsize most commonly range from 16 to 64 for image generation with GANs.

There is also the suggested use of the SGD optimizer for the discriminator, which in its strictest form would update weights after every single training example. Nonetheless, we use Adam and a batchsize of 64.
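
A minimal sketch of that data pipeline, assuming the frames have already been loaded into a NumPy array called images:

```python
import tensorflow as tf

batchsize = 64

# scale pixel values to [-1, 1] to match the generator's tanh output
# (an assumption here, but standard practice for DCGANs)
images = (images.astype('float32') - 127.5) / 127.5

# build the TensorFlow dataset object: shuffle, then batch
dataset = (
    tf.data.Dataset.from_tensor_slices(images)
    .shuffle(buffer_size=len(images))
    .batch(batchsize)
)
```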

DCGAN

We use the Deep Convolutional GAN architecture. Elegant in its simplicity by design — perfect for maximizing results and minimizing complexity.

Unlike the typical Convolutional Neural Network (CNN) or Recurrent Neural Network (RNN) setup, GANs require us to work with TensorFlow at a slightly lower level.

This means that we will split the code into four key parts:

  • Generator setup — including the architecture, loss calculation, and optimizer.
  • Discriminator setup — including the same as above.
  • Step — the process we take for every single training iteration: noise generation, model prediction, loss calculation, and weight updates.
  • Fit — the training controller function, which includes visualization and model checkpoint saving.

The model architecture itself is quite interesting. The DCGAN uses convolutional layers almost exclusively (no max pooling, and only minimal use of densely connected layers).

We will go into more depth in the following sections, but in short — the generator and discriminator are two separate, and almost opposite, CNNs.

The generator generates images, and the discriminator distinguishes between fake images (produced by the generator) and real images (which we retrieved from YouTube).

Through this competition, we produce an image generator capable of fooling a CNN (the discriminator).

Generator

The generator is stored in a class creatively named Generator. In this we will have three methods — dcgan, optimiser, and loss.

Our first method, dcgan, contains the model setup. We will build this to produce RGB images of size 360 * 640 (an array shape of 640, 360, 3).

We first initialize a TensorFlow Sequential model. After this we simply use the add method to add each successive layer. There are three layer sections, the input Dense and Reshape, followed by two Conv2DTranspose sections.

The input section takes our latent_units, which is simply an input layer of noise that we feed into the generator, and feeds these to a densely connected NN. These NN activations are then reshaped to fit the following transposed convolution layer.

LeakyReLU activation function

Following each NN layer, we include a BatchNormalization and a LeakyReLU layer. BatchNormalization normalizes the activations of the preceding layer, reducing the chances of vanishing or exploding gradients. LeakyReLU is our activation function.

The final Conv2DTranspose layer produces our generated images. We use a tanh activation function here, which is recommended over sigmoid when developing GANs.
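
Putting that together, a sketch of the dcgan method might look like the following. The filter counts, kernel sizes, and intermediate resolutions here are assumptions; only the overall structure of a Dense and Reshape input section followed by Conv2DTranspose layers follows the description above.

```python
import tensorflow as tf
from tensorflow.keras import layers

latent_units = 100  # length of the input noise vector (an assumed value)

class Generator:
    def dcgan(self):
        model = tf.keras.Sequential()

        # input section: densely connected layer fed by the latent noise,
        # reshaped into a small feature map for the transposed convolutions
        model.add(layers.Dense(160 * 90 * 64, use_bias=False,
                               input_shape=(latent_units,)))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU())
        model.add(layers.Reshape((160, 90, 64)))

        # first Conv2DTranspose section: upsample 160 x 90 -> 320 x 180
        model.add(layers.Conv2DTranspose(32, kernel_size=5, strides=2,
                                         padding='same', use_bias=False))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU())

        # final Conv2DTranspose: upsample to shape (640, 360, 3) with tanh activation
        model.add(layers.Conv2DTranspose(3, kernel_size=5, strides=2,
                                         padding='same', activation='tanh'))
        return model
```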

Our generator optimizer and loss are both defined with optimiser and loss respectively. The generator’s loss is calculated with fake_preds — output from the discriminator.

Generator loss is minimized when the discriminator mistakenly predicts all fake images (label 0) as real images (label 1). Hence we use tf.ones_like in our binary cross-entropy calculation.
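
Continuing the Generator class above, the optimiser and loss methods might look like this (assuming the discriminator outputs raw logits, and assuming a learning rate of 1e-4):

```python
    def optimiser(self):
        # Adam, as discussed earlier; the 1e-4 learning rate is an assumption
        return tf.keras.optimizers.Adam(1e-4)

    def loss(self, fake_preds):
        # binary cross-entropy against a target of all ones: the loss is
        # smallest when the discriminator labels every fake image as real (1)
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        return bce(tf.ones_like(fake_preds), fake_preds)
```

In the Train class later on, I assume each model class builds its network and optimizer once in __init__, i.e. self.model = self.dcgan() and self.opt = self.optimiser().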

Discriminator

The discriminator is much the same as our generator, but rather than using Conv2DTranspose layers we use the normal Conv2D. We also include a Dropout layer, which masks each value with a 50% probability at any one time, helping us prevent over-fitting and encouraging generalization.

We use the same optimizer, and a similar loss function. In this case the loss is minimized when the discriminator successfully identifies real images as 1 and fake images as 0.
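
A sketch along the same lines (again, the filter counts, kernel sizes, and the single-logit output layer are assumptions; the Conv2D, Dropout, and loss structure follows the description above):

```python
import tensorflow as tf
from tensorflow.keras import layers

class Discriminator:
    def dcgan(self):
        model = tf.keras.Sequential()

        # strided Conv2D layers in place of the generator's Conv2DTranspose layers
        model.add(layers.Conv2D(32, kernel_size=5, strides=2, padding='same',
                                input_shape=(640, 360, 3)))
        model.add(layers.LeakyReLU())
        model.add(layers.Dropout(0.5))  # mask each activation with 50% probability

        model.add(layers.Conv2D(64, kernel_size=5, strides=2, padding='same'))
        model.add(layers.LeakyReLU())
        model.add(layers.Dropout(0.5))

        # flatten to a single logit: real (1) vs fake (0)
        model.add(layers.Flatten())
        model.add(layers.Dense(1))
        return model

    def optimiser(self):
        # same optimizer as the generator
        return tf.keras.optimizers.Adam(1e-4)

    def loss(self, real_preds, fake_preds):
        # loss is minimized when real images score 1 and fake images score 0
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        real_loss = bce(tf.ones_like(real_preds), real_preds)
        fake_loss = bce(tf.zeros_like(fake_preds), fake_preds)
        return real_loss + fake_loss
```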

Step

Both step and fit are contained within the Train class, which we initialize with our generator G and discriminator D models.

For each step/iteration, we must do several things:

  • generate noise — using a normal distribution of size batchsize, latent_units, in our case 64, 100.
  • generate fake_images — by processing noise through G.model.
  • make predictions — by processing real_images and fake_images through D.model.
  • calculate G_loss and D_loss — feeding real_output and fake_output to the loss functions we defined in each class, G.loss and D.loss.
  • calculate the G and D gradients — using a tf.GradientTape for each model, we feed in the model loss and trainable variables.
  • finally, we apply the gradient updates — using opt.apply_gradients.

In code, a minimal sketch of this method looks like the following (assuming each model class stores its built network as model and its optimizer as opt, and that Train keeps a two-column history dataframe):
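
```python
import pandas as pd
import tensorflow as tf

class Train:
    def __init__(self, G, D):
        # initialized with the Generator and Discriminator instances
        self.G = G
        self.D = D
        # history dataframe for logging the mean loss of each model per step
        self.history = pd.DataFrame(columns=['G_loss', 'D_loss'])

    def step(self, real_images):
        # sample noise from a normal distribution,
        # shape (batchsize, latent_units) = (64, 100) as defined earlier
        noise = tf.random.normal([batchsize, latent_units])

        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            # generate fake images, then score real and fake batches with the discriminator
            fake_images = self.G.model(noise, training=True)
            real_output = self.D.model(real_images, training=True)
            fake_output = self.D.model(fake_images, training=True)

            # losses as defined in the Generator and Discriminator classes
            G_loss = self.G.loss(fake_output)
            D_loss = self.D.loss(real_output, fake_output)

        # gradients of each loss with respect to each model's trainable variables
        G_grads = g_tape.gradient(G_loss, self.G.model.trainable_variables)
        D_grads = d_tape.gradient(D_loss, self.D.model.trainable_variables)

        # apply the weight updates through each model's optimizer
        self.G.opt.apply_gradients(zip(G_grads, self.G.model.trainable_variables))
        self.D.opt.apply_gradients(zip(D_grads, self.D.model.trainable_variables))

        # log mean losses to the history dataframe for later visualization
        self.history.loc[len(self.history)] = [
            float(tf.reduce_mean(G_loss)),
            float(tf.reduce_mean(D_loss)),
        ]
```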

Note that self is being used as this is a method inside the Train class. At the end of the method we add the mean of the generator and discriminator loss to our history dataframe, which we later use for visualizing loss.

Fit

This is our final primary method, which acts as a controller for all of the code we have just covered.

Compared to the previous code, this part is much easier to grasp. First, we create gen_noise to act as the sole noise profile that we use for generating visualizations. This is not used in training, and exists solely for us to visually measure how the generator is progressing over time.

We feed gen_noise into the generator G.model, which gives us a generated image array gen_image. This array is formatted, converted into a PIL Image object, then saved to file.

Loss is visualized for both the generator and discriminator with sns.lineplot, then saved to file. Both loss values are also printed to the console.

Lastly, we save model weights every 500 epochs using the save_weights method.
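
Continuing the Train class, a rough sketch of fit (the file paths, the preview-image formatting, and the epoch count are assumptions):

```python
    # (inside the Train class; assumes numpy as np, seaborn as sns,
    #  matplotlib.pyplot as plt, and PIL's Image are already imported)
    def fit(self, dataset, epochs=3000):
        # a single fixed noise profile, used only for visualization, never for training
        gen_noise = tf.random.normal([1, latent_units])

        for epoch in range(epochs):
            for real_images in dataset:
                self.step(real_images)

            # generate a preview image from the fixed noise and save it to file
            gen_image = self.G.model(gen_noise, training=False)[0].numpy()
            gen_image = ((gen_image + 1) / 2 * 255).astype(np.uint8)  # tanh output -> [0, 255]
            Image.fromarray(gen_image).save(f'visuals/epoch_{epoch}.jpg')

            # plot generator and discriminator loss, save the figure, and print the values
            sns.lineplot(data=self.history.astype(float))
            plt.savefig(f'visuals/loss_{epoch}.jpg')
            plt.close()
            print(f"epoch {epoch} | G loss {self.history['G_loss'].iloc[-1]:.3f} | "
                  f"D loss {self.history['D_loss'].iloc[-1]:.3f}")

            # save model weights every 500 epochs
            if epoch % 500 == 0:
                self.G.model.save_weights(f'weights/generator_{epoch}.h5')
                self.D.model.save_weights(f'weights/discriminator_{epoch}.h5')
```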

Generated Image