Wasserstein GAN in Swift for TensorFlow

Source: Deep Learning on Medium

Vanilla generative adversarial network (GAN), as explained by Ian Goodfellow in the original paper.

I am a big fan of Apple’s Swift and of deep neural networks, and Swift for TensorFlow is the up-and-coming framework for deep learning. So, obviously, I had to jump right into it! I have already written Wasserstein GAN and other GANs in TensorFlow or PyTorch, but this Swift for TensorFlow thing is super cool. Behind the scenes, it is an ambitious effort to make Swift a machine learning language at the compiler level. In this post I will share my work on writing and training a Wasserstein GAN in Swift for TensorFlow. The code is open source on GitHub and can be run in Google Colab right now!

History of Generative Adversarial Nets

Generative adversarial networks (GANs) were introduced by Ian Goodfellow and colleagues in 2014. A GAN usually has two neural networks: a generator G and a critic C. The only available data is an unlabelled collection of real-world, real-valued samples (from nature), which can be images, audio, etc. GANs were designed for better modeling of real data, so that when a model is asked to, say, generate an image, it is able to do so; this is what G is for. C helps G learn to generate more realistic data by itself learning to predict that an image generated by G is fake. C is also shown real images, which it learns to call real. This is an iterative process: C gets better at telling fake data from real data, and in turn helps G shape its parameters so that it generates more realistic data.

These vanilla GANs don’t produce very good quality images. So work on improving the quality of image generation continued, and one of the most important directions taken in this subfield is constraining the critic network to the set of 1-Lipschitz functions and minimizing the Wasserstein distance between the G distribution (fake) and the P distribution (real). Check out the Wikipedia page on Lipschitz continuity to understand the constraint. Now let’s go on to code WGAN with Swift for TensorFlow!
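For reference, the distance mentioned above is usually written in its dual (Kantorovich-Rubinstein) form, which is the form that makes the 1-Lipschitz constraint on the critic appear. The notation here is adapted to this post: P is the real data distribution, P_G is the distribution of samples produced by G (the "G distribution" above), and the supremum runs over critics C that are 1-Lipschitz.

```latex
W(P, P_G) = \sup_{\|C\|_{L} \le 1} \;
    \mathbb{E}_{x \sim P}\big[ C(x) \big]
  - \mathbb{E}_{\tilde{x} \sim P_G}\big[ C(\tilde{x}) \big]
```

In practice the supremum is only approximated by training C, which is why the critic's parameters are constrained during training.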


Data is the first thing you need to get a neural network learning. So, I used the CIFAR-10 dataset, which contains images from 10 categories, as follows:

  • Airplane
  • Automobile
  • Bird
  • Cat
  • Deer
  • Dog
  • Frog
  • Horse
  • Ship
  • Truck

Each image is a 32×32 RGB image. There are 50k training images and 10k test images. Whoa, I never noticed that the number of images and classes is close to that of the MNIST dataset 🤔. Anyway, I used this data to train my Wasserstein GAN to generate such images.

Data downloading and loading

First, import TensorFlow and Python (version 3.x) using the Swift for TensorFlow toolchain. Then import the Python libraries through the Python interoperability feature! Now define some functions for downloading, loading, and preprocessing CIFAR-10. Finally, load the dataset.


There are some important configurations to set for training the network. I have tried to keep the configuration as close to the WGAN paper as possible. So, I set the batch size to 64, resized the images to 64×64, and kept the number of channels (RGB) at three. WGAN was trained for 5 epochs (as suggested in the PyTorch tutorials). The latent space of G was set to 128 dimensions. The number of iterations of C per iteration of G was set to 5, which helps C approximate the 1-Lipschitz function well, as suggested in the paper. Also, the trainable parameters of C must be clipped to the range [-0.01, 0.01].
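As a rough sketch, these settings could be gathered in one place like this. It is plain Swift with no TensorFlow import, so it runs anywhere; the struct, its property names, and the `clipped` helper are my own illustrative assumptions, not the identifiers from the original code. Only the numbers come from the text above.

```swift
// Hyperparameters quoted in the text. The type and property names are
// illustrative assumptions, not the author's identifiers.
struct WGANConfig {
    let batchSize = 64
    let imageSize = 64           // images are resized to 64x64
    let channels = 3             // RGB
    let epochs = 5
    let latentSize = 128         // dimensionality of G's input vector
    let criticIterations = 5     // C updates per G update
    let clipBound: Float = 0.01  // C's parameters are kept in [-0.01, 0.01]
}

/// Clamps every value into [-bound, bound], mimicking weight clipping
/// on a flat list of parameters.
func clipped(_ weights: [Float], bound: Float) -> [Float] {
    return weights.map { min(max($0, -bound), bound) }
}

let config = WGANConfig()
let example = clipped([-0.5, -0.004, 0.0, 0.02], bound: config.clipBound)
print(example)
```

Values already inside the range pass through unchanged, while anything outside is pinned to the nearest bound.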


Wasserstein Generative Adversarial Network

A WGAN model contains a C network and a G network, as discussed above. C contains multiple convolutional layers, whilst G is made up of sequential transposed convolutional layers, which are also sometimes wrongly called deconvolution layers.

Custom Layers

To make a custom neural layer in Swift for TensorFlow, make your structure conform to the Layer protocol. The parameters are accessed by conforming your neural structure to the KeyPathIterable protocol. This conformance is the default, but I wrote it out explicitly to remind myself how iteration over the properties of a type happens in Swift. Currently, the TransposedConv2D implementation in Swift for TensorFlow doesn’t work well, so I decided to use an UpSampling2D op followed by a Conv2D layer, as suggested by Odena et al., 2016. The Conv2D layer is used as is, followed by a BatchNorm op, and the same BatchNorm is applied after the UpSampling2D and Conv2D sequence that stands in for TransposedConv2D. The custom layers I wrote are shown below in Swift code.

Custom neural layers
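To give a feel for what the upsampling half of that replacement does, here is a minimal plain-Swift sketch on a single-channel image stored as nested arrays. It assumes nearest-neighbour upsampling, and the `upsampled` function is a hypothetical helper of mine, not the UpSampling2D layer itself. In the real network a Conv2D layer would then run over the enlarged image.

```swift
/// Enlarges a single-channel image by repeating each pixel `scale`
/// times along both axes (nearest-neighbour upsampling).
func upsampled(_ image: [[Float]], scale: Int = 2) -> [[Float]] {
    var result: [[Float]] = []
    for row in image {
        // Repeat every pixel horizontally...
        let widened = row.flatMap { Array(repeating: $0, count: scale) }
        // ...then repeat the widened row vertically.
        result.append(contentsOf: Array(repeating: widened, count: scale))
    }
    return result
}

let small: [[Float]] = [[1, 2],
                        [3, 4]]
let large = upsampled(small)
for row in large { print(row) }
```

A 2×2 input becomes a 4×4 output in which every original pixel covers a 2×2 patch.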

WGAN Architecture

The G network takes a random vector of shape [128] drawn from a Gaussian distribution. This goes through a dense (fully connected) layer followed by BatchNorm and a relu activation function. The resulting output is reshaped so that it can go through a sequence of a couple of blocks of UpSampling2D, Conv2D, and BatchNorm layers with relu activations. The final block simply upsamples this output and follows it with a convolution layer and a tanh activation function.

The architecture of network C has 4 blocks of Conv2D followed by BatchNorm. Each block is followed by a leakyReLU activation function with a negative slope of 0.2. At the time of this writing, a leakyReLU function is not implemented in Swift for TensorFlow, so I implemented my own and made it differentiable. Finally, the output is flattened and passed through a dense layer, producing a [1]-dimensional output that scores how real or fake the image is.

Wasserstein GAN architecture
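To see how the spatial size shrinks through C, here is a small plain-Swift sketch. It assumes a 64×64 input and that each of the four Conv2D blocks uses a stride of 2 with "same" padding, in which case the output size along each axis is the input size divided by the stride, rounded up. The helper name is hypothetical.

```swift
/// Output length along one spatial axis for a convolution with "same"
/// padding: ceil(input / stride), independent of the kernel size.
func samePaddingOutputSize(input: Int, stride: Int) -> Int {
    return (input + stride - 1) / stride
}

var size = 64                      // input resolution from the configuration
var trace = [size]
for _ in 1...4 {                   // four Conv2D blocks in C
    size = samePaddingOutputSize(input: size, stride: 2)
    trace.append(size)
}
print(trace)                       // [64, 32, 16, 8, 4]
```

Under these assumptions the feature maps are 4×4 by the time they are flattened for the final dense layer.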

Notice that in both G and C the stride is (2, 2) and the padding is set to same. The kernel size in both is (4, 4). In both networks above, the call(_:) functions are made differentiable using the @differentiable attribute. The differentiable LeakyReLU activation function is created in a similar way.

Differentiable Leaky ReLU
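For readers without the Swift for TensorFlow toolchain, the idea can be sketched in plain Swift with the derivative written by hand. This is a stand-in for illustration only: it works on single Float values rather than tensors, it does not use the @differentiable attribute, and both function names are my own.

```swift
/// Leaky ReLU: passes positive inputs through unchanged and scales
/// negative inputs by `slope`.
func leakyReLU(_ x: Float, slope: Float = 0.2) -> Float {
    return x > 0 ? x : slope * x
}

/// Derivative of leaky ReLU with respect to `x`, written by hand.
/// At x == 0 the slope of the negative side is used, a common convention.
func leakyReLUDerivative(_ x: Float, slope: Float = 0.2) -> Float {
    return x > 0 ? 1 : slope
}

let inputs: [Float] = [-2, -0.5, 0, 1.5]
let outputs = inputs.map { leakyReLU($0) }
let gradients = inputs.map { leakyReLUDerivative($0) }
print(outputs)
print(gradients)
```

The small but nonzero gradient on the negative side is what keeps units from going completely silent during training.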


The networks are trained for 5 epochs. At each iteration, C was trained 5 times for every single training step of G. A batch size of 64 was used for each network. I used the RMSProp optimizer with a learning rate of 0.00005 for both networks. The zInput for G was sampled from a uniform distribution. I also tracked the time it took to train for 1 epoch, which was around 12 minutes for me. Graphs of the Wasserstein distance and a couple of losses, the G loss and the C loss, were plotted using Matplotlib, which is possible because Swift for TensorFlow allows Python interoperability. One more thing I learned about Swift for TensorFlow is that you can iterate over the properties of any type conforming to the KeyPathIterable protocol, which is implemented in the tensorflow/swift fork of apple/swift. It’s a super cool idea, but much work still needs to be done, such as accessing the parameters of a particular layer, and the way activations are applied per layer needs modification.

Anyway, the following is the more detailed, self-explanatory Swift code for training my Wasserstein GAN!
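As a companion, here is a tiny plain-Swift sketch of the two ingredients described above: the losses and the 5-to-1 update schedule. It uses the standard WGAN loss formulation, which I am assuming matches the training code. The scores are made-up numbers standing in for C's outputs, and all function names are hypothetical.

```swift
/// Mean of a non-empty list of critic scores.
func mean(_ values: [Float]) -> Float {
    return values.reduce(0, +) / Float(values.count)
}

/// C minimizes this: it pushes scores of real images up and fake ones down.
func criticLoss(realScores: [Float], fakeScores: [Float]) -> Float {
    return mean(fakeScores) - mean(realScores)
}

/// G minimizes this: it tries to raise C's score on generated images.
func generatorLoss(fakeScores: [Float]) -> Float {
    return -mean(fakeScores)
}

// Toy scores standing in for C's outputs on one batch (made-up numbers).
let realScores: [Float] = [1.0, 2.0, 3.0]
let fakeScores: [Float] = [-1.0, 0.0, 1.0]
let wassersteinEstimate = -criticLoss(realScores: realScores, fakeScores: fakeScores)
print(wassersteinEstimate)             // 2.0

// Update schedule: `criticIterations` C steps for every G step.
let criticIterations = 5
var criticSteps = 0, generatorSteps = 0
for _ in 1...3 {                       // three outer iterations for illustration
    for _ in 1...criticIterations {
        criticSteps += 1               // here: update C, then clip its weights
    }
    generatorSteps += 1                // here: update G
}
print(criticSteps, generatorSteps)     // 15 3
```

The negated critic loss is the quantity one would plot as the Wasserstein distance estimate.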

Note that one needs to set the training/inference mode via the following lines of code.

// For setting to training mode
Context.local.learningPhase = .training
// If you want to perform only inference
Context.local.learningPhase = .inference


It has been a good experience working on deep neural networks in the Swift language. I think Swift for TensorFlow will grow into a new mainstream machine learning framework, most probably replacing much of PyTorch, just as PyTorch did to vanilla TensorFlow in the past. I really like that it supports eager execution by default, that I don’t have to manually set the device to train on, and that I can use Cloud TPUs in Google Colab. It is good to see that the TensorFlow community has been putting in the effort to make Swift for TensorFlow usable on Colab by providing a Jupyter-based Swift environment. One can also use the Raw namespace (it’s more like a structure, I guess, because Swift has no concept of a namespace) to access some basic ops like mul(_:_:), add(_:_:), etc. in Swift for TensorFlow. It is an awesome effort by Chris Lattner and the Swift for TensorFlow team at Google.

Following is a non-exhaustive list of my concerns about the flexibility Swift for TensorFlow currently offers for research prototyping.

Accessing parameters per Layer

Swift for TensorFlow adopts a novel approach to accessing the properties of a type for modification. This is required to access and update the parameters of a neural network. The facility is provided to any type that conforms to the KeyPathIterable protocol, and neural network structures conforming to the Layer protocol get it by default. So you don’t have to write it again and again, because you will almost always want to access the properties of a neural network. This works fine, as I have done above in my Wasserstein GAN code, but I still don’t get the flexibility to access parameters per layer; I just iterate over all properties at once without knowing which layer they currently belong to. This was not actually a requirement for me, but it can be one in other cases, say, when you want to do transfer learning, which requires access to particular layers for updating parameters. Hopefully, this flexibility will be provided soon in Swift for TensorFlow.
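The difference between the two kinds of access can be illustrated with ordinary Swift key paths. This is a hand-rolled stand-in, not KeyPathIterable itself: the real protocol derives the list of key paths for you, whereas here `TinyModel` and `parameterKeyPaths` are hypothetical and written by hand.

```swift
// A hand-rolled stand-in for the idea behind KeyPathIterable.
struct TinyModel {
    var dense1: [Float] = [0.5, -0.5]
    var dense2: [Float] = [0.02, -0.03]

    /// Key paths to every parameter array, listed by hand for this sketch.
    static var parameterKeyPaths: [WritableKeyPath<TinyModel, [Float]>] {
        return [\TinyModel.dense1, \TinyModel.dense2]
    }
}

var model = TinyModel()

// Visit all parameters without knowing which layer they belong to.
for keyPath in TinyModel.parameterKeyPaths {
    model[keyPath: keyPath] = model[keyPath: keyPath].map { $0 * 2 }
}

// Touch a single layer by naming its key path explicitly.
model[keyPath: \TinyModel.dense2] = [0, 0]
print(model.dense1, model.dense2)
```

The loop is the "iterate over everything" style used for weight clipping, while the last assignment is the kind of per-layer access that transfer learning would need.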

Activation Function per Layer

I think that Google will, and possibly needs to, change the way activations are applied to each layer (a feature I don’t use, because LeakyReLU can’t be applied that way). It looks something like Conv2D(…, activation: relu), where relu can also be replaced by linear for a linear activation, and so on. But LeakyReLU doesn’t fit this design well, because there is no way to give it a negative slope value. Here the function is not being called; it is simply passed as an argument to the Conv2D layer or any other layer. The best solution I can think of is to use an enumeration instead of passing a function, as in Conv2D(…, activation: .relu), or for LeakyReLU something like Conv2D(…, activation: .leakyrelu(0.2)), where 0.2 is the associated value of the enumeration case leakyrelu.
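Here is a minimal plain-Swift sketch of what such an enumeration could look like. It is only an illustration of the proposal, not an API that exists in Swift for TensorFlow. `Activation`, `ScaleLayer`, and `apply(_:)` are hypothetical names, and the layer works on a single Float to keep the example self-contained.

```swift
/// Hypothetical activation enumeration along the lines proposed above.
enum Activation {
    case linear
    case relu
    case leakyrelu(Float)   // associated value: the negative slope

    func apply(_ x: Float) -> Float {
        switch self {
        case .linear:
            return x
        case .relu:
            return max(0, x)
        case .leakyrelu(let slope):
            return x > 0 ? x : slope * x
        }
    }
}

/// Hypothetical layer that stores its activation as an enumeration case.
struct ScaleLayer {
    var weight: Float
    var activation: Activation

    func apply(_ x: Float) -> Float {
        return activation.apply(weight * x)
    }
}

let layer = ScaleLayer(weight: 2, activation: .leakyrelu(0.2))
print(layer.apply(-1), layer.apply(3))
```

Because the slope travels with the enumeration case, the layer's initializer does not need a separate parameter for every activation that happens to be configurable.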

I hope you enjoyed the read. If you liked my article, please tap some claps, share it with your friends, and follow me for more stuff like this! Apart from writing articles, I also actively tweet about machine learning, blockchain, and quantum computing on Twitter @Rahul_Bhalley!