How can we train a GAN faster and get better image quality? Many new cost functions have already been proposed, but Relativistic GAN stands out because it reaches impressive FID scores in fewer iterations. After 5K training iterations, the original GAN still produces noise while RaSGAN produces much higher-quality images.
In the original GAN, the discriminator maximizes its ability to distinguish real from fake images, and in practice it performs pretty well. For high-resolution images in particular, the higher degree of freedom makes natural images harder to generate and fakes much easier to detect. It is not hard to train a discriminator whose output D(X) is close to zero for both fake images below.
But does this help train the generator or produce a stable model? Ironically, an optimal discriminator often makes training harder. Pushing D(X) to 0 for every fake image can lead to greedy optimization that hurts performance and stability. Arguably, how can the generator learn effectively when both images above receive nearly the same score?
Trends in new cost functions
For convenience, we will call the original GAN SGAN (Standard GAN) in this article.
Let’s look at some general trends in research on new cost functions. SGAN tries to squeeze the discriminator output D(X) into the two ends (0 or 1). In contrast, new cost functions measure the difference (or distance) between real and fake images. This is particularly important when the discriminator is optimal, because the costs of all fake images are no longer squeezed into saturated regions where gradients vanish. In short, the cost function should score fakes differently depending on how natural they look.
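The following plain-Python sketch makes the saturation argument concrete. It compares two gradients, each taken with respect to the discriminator logit C(xf). The first comes from the minimax form of the SGAN generator loss, log(1 − D(xf)), and the second from a WGAN-style loss, −C(xf). The two logit values are made up. They stand for two fakes of different quality that a strong discriminator rejects.

```python
import math

def sigmoid(c: float) -> float:
    """Map a discriminator logit C(x) to a probability D(x)."""
    return 1.0 / (1.0 + math.exp(-c))

def minimax_generator_grad(c_fake: float) -> float:
    """Derivative of the minimax generator loss log(1 - sigmoid(C)) w.r.t. C.

    d/dC log(1 - sigmoid(C)) = -sigmoid(C), which vanishes as D(x_f) -> 0.
    """
    return -sigmoid(c_fake)

def critic_generator_grad(c_fake: float) -> float:
    """Derivative of a WGAN-style generator loss -C(x_f) w.r.t. C: always -1."""
    return -1.0

# Hypothetical logits for a "better" and a "worse" fake under a strong discriminator.
for c in (-8.0, -12.0):
    print(f"C={c:6.1f}  D={sigmoid(c):.2e}  "
          f"minimax grad={minimax_generator_grad(c):.2e}  "
          f"critic grad={critic_generator_grad(c):.1f}")
```

Both fakes get D(xf) ≈ 0 and a near-zero minimax gradient, so the generator receives almost no signal about which fake is better. The critic-style loss keeps a constant gradient. In practice, SGAN usually switches to the non-saturating generator loss −log D(xf) to soften this particular problem.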
By design or not, new cost functions add regularization to the discriminator through weight clipping, a Lipschitz constraint or a gradient penalty, which makes it harder for the discriminator to become optimal or to overfit. This better balances the training of the discriminator and the generator.
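Weight clipping, as used by WGAN, is the simplest of these regularizers. It clamps every discriminator (critic) parameter into [−c, c] after each update, and the WGAN paper uses c = 0.01 by default. In this sketch a flat list of floats stands in for a real network's parameter tensors, and the values are purely hypothetical.

```python
from typing import List

def clip_weights(weights: List[float], c: float = 0.01) -> List[float]:
    """Clamp every critic parameter into [-c, c] (WGAN-style weight clipping)."""
    return [max(-c, min(c, w)) for w in weights]

# Hypothetical critic parameters right after an optimizer step.
params = [0.5, -0.003, -2.0, 0.009]
params = clip_weights(params)
print(params)  # [0.01, -0.003, -0.01, 0.009]
```

Bounding the weights limits how steeply C can change with its input, which is a crude way to enforce a Lipschitz constraint. Gradient penalties (as in WGAN-GP) pursue the same goal more smoothly.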
In addition, the targets D(X) = 1 for real images and D(X) = 0 for fake images may not be desirable. When a GAN is in equilibrium, both the discriminator and the generator are optimal and D(X) should be 0.5: the discriminator can do no better than a random guess at detecting fakes. This equilibrium does not align well with the training targets for D(X). For various reasons, new cost functions redefine the target labels for real and fake images in the discriminator and generator cost functions.
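The 0.5 equilibrium follows from the closed-form optimal SGAN discriminator for a fixed generator, D*(x) = p_r(x) / (p_r(x) + p_g(x)), given in the original GAN paper. The sketch below evaluates it for two 1-D unit-variance Gaussians. The distributions and their means are assumptions chosen only for illustration.

```python
import math

def gaussian_pdf(x: float, mu: float, sigma: float = 1.0) -> float:
    """Density of N(mu, sigma^2) at x."""
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))

def optimal_discriminator(x: float, mu_real: float, mu_fake: float) -> float:
    """D*(x) = p_r(x) / (p_r(x) + p_g(x)) for Gaussian real and fake densities."""
    p_r = gaussian_pdf(x, mu_real)
    p_g = gaussian_pdf(x, mu_fake)
    return p_r / (p_r + p_g)

xs = [-2.0, 0.0, 1.0, 3.0]
# Generator far from the data: D* is confident (close to 1 or 0).
print([round(optimal_discriminator(x, 3.0, 0.0), 3) for x in xs])
# Generator matches the data (equilibrium): D* = 0.5 everywhere.
print([round(optimal_discriminator(x, 3.0, 3.0), 3) for x in xs])
```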
Relativistic GAN is not a new cost function. It is a general approach for devising new cost functions from existing ones. For example, RSGAN is derived from SGAN. SGAN measures the probability that the input data is real. Relativistic GANs measure the probability that the real data is more realistic than the generated data (or vice versa). Let’s see how to make the SGAN cost function “relativistic” to demonstrate the difference.
Relativistic standard GAN (RSGAN)
In SGAN, the goal is to determine how real the input is (D(x)). In RSGAN, we instead compute a “distance” D(xr, xf): the probability that the real data is more realistic than the fake data.
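We perform the transformation:
$$D(x) = \sigma(C(x)) \quad\longrightarrow\quad D(x_r, x_f) = \sigma\big(C(x_r) - C(x_f)\big)$$
where C(x) is the discriminator output before the sigmoid σ. The SGAN cost function (P and Q are the real and fake data distributions)
$$L_D = -\mathbb{E}_{x_r \sim P}\big[\log D(x_r)\big] - \mathbb{E}_{x_f \sim Q}\big[\log(1 - D(x_f))\big], \qquad L_G = -\mathbb{E}_{x_f \sim Q}\big[\log D(x_f)\big]$$
is transformed to RSGAN:
$$L_D = -\mathbb{E}_{(x_r, x_f)}\big[\log \sigma(C(x_r) - C(x_f))\big], \qquad L_G = -\mathbb{E}_{(x_r, x_f)}\big[\log \sigma(C(x_f) - C(x_r))\big]$$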
RSGAN reaches the optimal point when D(xr, xf) = 0.5, i.e. C(xr) = C(xf).
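Here is a minimal sketch of the RSGAN losses computed from a minibatch of paired critic outputs. It uses plain Python floats instead of a deep learning framework. The helper names are hypothetical, and `log_sigmoid` is written in a numerically stable form so that `exp` never overflows.

```python
import math
from typing import Sequence, Tuple

def log_sigmoid(y: float) -> float:
    """Numerically stable log(sigmoid(y))."""
    if y >= 0:
        return -math.log1p(math.exp(-y))
    return y - math.log1p(math.exp(y))

def rsgan_losses(c_real: Sequence[float], c_fake: Sequence[float]) -> Tuple[float, float]:
    """RSGAN discriminator and generator losses for paired critic outputs.

    c_real[i] = C(x_r) and c_fake[i] = C(x_f) for the i-th (real, fake) pair,
    and D(x_r, x_f) = sigmoid(C(x_r) - C(x_f)).
    """
    pairs = list(zip(c_real, c_fake))
    loss_d = -sum(log_sigmoid(r - f) for r, f in pairs) / len(pairs)
    loss_g = -sum(log_sigmoid(f - r) for r, f in pairs) / len(pairs)
    return loss_d, loss_g

# At the optimum C(x_r) = C(x_f), so both losses equal log 2.
print(rsgan_losses([1.0, 2.0], [1.0, 2.0]))
# Reals scored well above fakes: small discriminator loss, large generator loss.
print(rsgan_losses([5.0], [0.0]))
```

When C(xr) = C(xf) for every pair, D(xr, xf) = 0.5 and both losses equal log 2 ≈ 0.693, which matches the optimal point stated above.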
RSGAN pushes D(X) towards 0.5 (the left figure). For SGAN, the generator's goal is to push D(X) closer to 1 (the right figure).
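SGAN is thought to be optimizing the JS-divergence (Proof). With the optimal discriminator D*(x) = p(x) / (p(x) + q(x)), where p and q are the densities of the real and fake distributions P and Q, the SGAN objective becomes
$$\mathbb{E}_{x_r \sim P}\big[\log D^*(x_r)\big] + \mathbb{E}_{x_f \sim Q}\big[\log(1 - D^*(x_f))\big] = -\log 4 + 2\,\mathrm{JSD}(P \,\|\, Q)$$
This can be rewritten as:
$$\mathrm{JSD}(P \,\|\, Q) = \frac{1}{2}\Big(\log 4 + \max_D \big\{\mathbb{E}_{x_r \sim P}\big[\log D(x_r)\big] + \mathbb{E}_{x_f \sim Q}\big[\log(1 - D(x_f))\big]\big\}\Big)$$
The JS-divergence is minimized (JSD = 0) when D(xr) = D(xf) = 0.5, and maximized (JSD = log 2) when D(xr) = 1 and D(xf) = 0.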
As shown before, the goal of the SGAN generator is to push D(xf) towards 1. So SGAN training deviates from JS-divergence minimization. In contrast, RSGAN behaves more like divergence minimization.
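The link between SGAN and the JS-divergence used in this argument can be checked numerically. This sketch plugs the optimal discriminator D* = p / (p + q) into the SGAN objective for two small discrete distributions with made-up values. It then compares the result with −log 4 + 2·JSD(P‖Q), using natural logarithms throughout.

```python
import math
from typing import Sequence

def sgan_value_at_optimum(p: Sequence[float], q: Sequence[float]) -> float:
    """E_P[log D*] + E_Q[log(1 - D*)] with the optimal D* = p / (p + q)."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi + qi == 0:
            continue  # outcome impossible under both distributions
        d = pi / (pi + qi)
        if pi > 0:
            total += pi * math.log(d)
        if qi > 0:
            total += qi * math.log(1 - d)
    return total

def jsd(p: Sequence[float], q: Sequence[float]) -> float:
    """Jensen-Shannon divergence in nats, so 0 <= JSD <= log 2."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]

    def kl(a: Sequence[float], b: Sequence[float]) -> float:
        return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p = [0.5, 0.3, 0.2, 0.0]  # illustrative "real" distribution
q = [0.1, 0.2, 0.3, 0.4]  # illustrative "fake" distribution
lhs = sgan_value_at_optimum(p, q)
rhs = -math.log(4) + 2 * jsd(p, q)
print(lhs, rhs)  # the two values agree up to floating-point error
```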
As the discriminator and generator approach optimality, both D(xr) and D(xf) move towards 1 in SGAN. This completely ignores the fact that half of the discriminator's input is fake, so the expected value of D(x) should be 0.5. The RSGAN paper therefore questions whether the SGAN discriminator is making sensible predictions.
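The gradient of the SGAN discriminator cost with respect to the discriminator weights w is:
$$\nabla_w L_D = -\mathbb{E}_{x_r \sim P}\big[(1 - D(x_r))\,\nabla_w C(x_r)\big] + \mathbb{E}_{x_f \sim Q}\big[D(x_f)\,\nabla_w C(x_f)\big]$$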
When the discriminator is optimal, 1 − D(xr) → 0, so the discriminator's gradient comes mostly from the fake images (the second term in the discriminator cost). In other words, the discriminator stops learning from real images and learns mostly from fakes. At that point, SGAN is no longer learning what makes images natural. In contrast, the RSGAN discriminator gradient
$$\nabla_w L_D = -\mathbb{E}_{(x_r, x_f)}\big[(1 - D(x_r, x_f))\,(\nabla_w C(x_r) - \nabla_w C(x_f))\big]$$
depends on both xr and xf, so RSGAN learns from both.
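This small sketch puts numbers on the gradient argument. It prints the scalar coefficients that multiply ∇_w C(xr) and ∇_w C(xf) in the SGAN and RSGAN discriminator gradients. The logits C(xr) = 10 and C(xf) = 9 are hypothetical. They mimic the situation described above, where both D(xr) and D(xf) end up close to 1.

```python
import math
from typing import Tuple

def sigmoid(y: float) -> float:
    return 1.0 / (1.0 + math.exp(-y))

def sgan_grad_weights(c_real: float, c_fake: float) -> Tuple[float, float]:
    """Coefficients of grad_w C(x_r) and grad_w C(x_f) in the SGAN discriminator gradient:
    -(1 - D(x_r)) and D(x_f)."""
    return -(1.0 - sigmoid(c_real)), sigmoid(c_fake)

def rsgan_grad_weights(c_real: float, c_fake: float) -> Tuple[float, float]:
    """Same coefficients for RSGAN: -(1 - D(x_r, x_f)) and +(1 - D(x_r, x_f))."""
    w = 1.0 - sigmoid(c_real - c_fake)
    return -w, w

for name, fn in (("SGAN", sgan_grad_weights), ("RSGAN", rsgan_grad_weights)):
    real_coef, fake_coef = fn(10.0, 9.0)
    print(f"{name}: real coef {real_coef:.2e}, fake coef {fake_coef:.2e}")
```

SGAN's coefficient on the real sample is about −4.5e-5, so real images barely contribute. RSGAN applies the same weight, about 0.27, to both samples.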
Making cost functions relativistic
We can make other cost functions relativistic too. But first, let’s generalize the form of the cost functions.
Integral probability metrics
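Many new cost functions, like WGAN, can be generalized as Integral Probability Metrics (IPMs):
$$\mathrm{IPM}_F(P \,\|\, Q) = \sup_{C \in F} \; \mathbb{E}_{x_r \sim P}\big[C(x_r)\big] - \mathbb{E}_{x_f \sim Q}\big[C(x_f)\big]$$
where sup is the least upper bound and C is constrained to a specific class of functions F, such as Lipschitz functions. P and Q are the data distributions of the real and fake images respectively.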
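This Monte Carlo sketch estimates the quantity inside the supremum for one fixed critic at a time. When F is the set of 1-Lipschitz functions, the IPM is the Wasserstein-1 distance (the WGAN case). For the two shifted unit Gaussians assumed here, that distance equals 2.

The critic C(x) = x attains this value. A clipped critic is also 1-Lipschitz but only gives a lower bound. The distributions, sample size and seed are illustrative choices, and `statistics.fmean` needs Python 3.8+.

```python
import random
from statistics import fmean
from typing import Callable, Sequence

def ipm_estimate(critic: Callable[[float], float],
                 real_samples: Sequence[float], fake_samples: Sequence[float]) -> float:
    """Estimate E_P[C(x_r)] - E_Q[C(x_f)] for one fixed critic C.

    The IPM is the supremum of this value over the class F, so any single
    critic in F yields a lower bound.
    """
    return fmean([critic(x) for x in real_samples]) - fmean([critic(x) for x in fake_samples])

def identity(x: float) -> float:
    return x  # 1-Lipschitz; optimal for two shifted Gaussians

def clipped(x: float) -> float:
    return max(-1.0, min(1.0, x))  # also 1-Lipschitz, but suboptimal here

rng = random.Random(0)
real = [rng.gauss(2.0, 1.0) for _ in range(100_000)]  # P = N(2, 1), illustrative
fake = [rng.gauss(0.0, 1.0) for _ in range(100_000)]  # Q = N(0, 1), illustrative

print(ipm_estimate(identity, real, fake))  # close to 2.0
print(ipm_estimate(clipped, real, fake))   # smaller: a weaker critic
```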
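The objective function for SGAN is:
$$\min_G \max_D \; \mathbb{E}_{x_r \sim P}\big[\log D(x_r)\big] + \mathbb{E}_{x_f \sim Q}\big[\log(1 - D(x_f))\big]$$
D(x) can also be represented as sigmoid(C(x)).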
Generalized cost function
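IPM, SGAN and many other cost functions can be generalized as:
$$L_D = \mathbb{E}_{x_r \sim P}\big[f_1(C(x_r))\big] + \mathbb{E}_{x_f \sim Q}\big[f_2(C(x_f))\big], \qquad L_G = \mathbb{E}_{x_r \sim P}\big[g_1(C(x_r))\big] + \mathbb{E}_{x_f \sim Q}\big[g_2(C(x_f))\big]$$
where f1, f2, g1 and g2 are functions mapping a scalar input to another scalar, xr is a real image and xf is a fake image. For example, IPMs use f1(y) = g2(y) = −y and f2(y) = g1(y) = y.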
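The sketch below writes SGAN and IPM in this generalized form and evaluates both losses on a minibatch of critic outputs. For SGAN it assumes the commonly used non-saturating generator loss −log D(xf). That loss has no real-image term, so g1 = 0. The dictionary layout and function names are illustrative assumptions.

```python
import math
from statistics import fmean
from typing import Callable, Dict, Sequence, Tuple

Fn = Callable[[float], float]

def softplus(y: float) -> float:
    """Stable log(1 + e^y). Note that -log(sigmoid(y)) = softplus(-y)."""
    return max(y, 0.0) + math.log1p(math.exp(-abs(y)))

# (f1, f2, g1, g2) in L_D = E_P[f1(C(x_r))] + E_Q[f2(C(x_f))],
#                     L_G = E_P[g1(C(x_r))] + E_Q[g2(C(x_f))].
SGAN: Dict[str, Fn] = {
    "f1": lambda y: softplus(-y),  # -log sigmoid(y)
    "f2": lambda y: softplus(y),   # -log(1 - sigmoid(y))
    "g1": lambda y: 0.0,           # no real-image term in the non-saturating loss
    "g2": lambda y: softplus(-y),  # -log sigmoid(y)
}
IPM: Dict[str, Fn] = {  # e.g. WGAN; C must also be restricted to the class F
    "f1": lambda y: -y,
    "f2": lambda y: y,
    "g1": lambda y: y,
    "g2": lambda y: -y,
}

def generalized_losses(fns: Dict[str, Fn], c_real: Sequence[float],
                       c_fake: Sequence[float]) -> Tuple[float, float]:
    """Evaluate (L_D, L_G) for one minibatch of critic outputs."""
    loss_d = fmean([fns["f1"](c) for c in c_real]) + fmean([fns["f2"](c) for c in c_fake])
    loss_g = fmean([fns["g1"](c) for c in c_real]) + fmean([fns["g2"](c) for c in c_fake])
    return loss_d, loss_g

print(generalized_losses(IPM, [3.0, 1.0], [0.0, 2.0]))  # (-1.0, 1.0)
print(generalized_losses(SGAN, [0.0], [0.0]))  # (2·log 2, log 2): D(x) = 0.5 everywhere
```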
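The corresponding relativistic cost function is:
$$L_D = \mathbb{E}_{(x_r, x_f)}\big[f_1(C(x_r) - C(x_f))\big] + \mathbb{E}_{(x_r, x_f)}\big[f_2(C(x_f) - C(x_r))\big]$$
$$L_G = \mathbb{E}_{(x_r, x_f)}\big[g_1(C(x_r) - C(x_f))\big] + \mathbb{E}_{(x_r, x_f)}\big[g_2(C(x_f) - C(x_r))\big]$$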
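Many cost functions, like LSGAN and IPM, satisfy f2(y) = f1(−y), and their generator reuses the same function with the roles of real and fake swapped (g2 = f1 and g1(y) = f1(−y)). The two terms in each loss then coincide, and the losses simplify further (up to a factor of 2) to:
$$L_D = \mathbb{E}_{(x_r, x_f)}\big[f(C(x_r) - C(x_f))\big], \qquad L_G = \mathbb{E}_{(x_r, x_f)}\big[f(C(x_f) - C(x_r))\big]$$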
Algorithm for RGAN
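At its core, the RGAN algorithm evaluates these losses on minibatches of paired real and fake samples, alternating discriminator and generator updates. Below is a minimal sketch of just the loss computation, not the paper's code. Critic outputs are plain floats so the sketch runs without a deep learning framework, and all names are hypothetical.

The sketch also checks the simplification numerically. With the SGAN functions, f2(y) = f1(−y) and g1(y) = g2(−y). The full form should therefore be exactly twice the simplified RSGAN losses.

```python
import math
from statistics import fmean
from typing import Callable, Sequence, Tuple

Fn = Callable[[float], float]

def softplus(y: float) -> float:
    """Stable log(1 + e^y)."""
    return max(y, 0.0) + math.log1p(math.exp(-abs(y)))

def neg_log_sigmoid(y: float) -> float:
    return softplus(-y)  # -log sigmoid(y)

def neg_log_one_minus_sigmoid(y: float) -> float:
    return softplus(y)  # -log(1 - sigmoid(y))

def relativistic_losses(f1: Fn, f2: Fn, g1: Fn, g2: Fn,
                        c_real: Sequence[float], c_fake: Sequence[float]) -> Tuple[float, float]:
    """Generic RGAN losses on paired samples: c_real[i] = C(x_r), c_fake[i] = C(G(z_i))."""
    pairs = list(zip(c_real, c_fake))
    loss_d = fmean([f1(r - f) + f2(f - r) for r, f in pairs])
    loss_g = fmean([g1(r - f) + g2(f - r) for r, f in pairs])
    return loss_d, loss_g

# Hypothetical critic outputs for three (real, fake) pairs.
c_real = [2.0, 0.5, 1.0]
c_fake = [-1.0, 0.0, 1.5]

# RSGAN from the SGAN functions: f1 = g2 = -log sigmoid, f2 = g1 = -log(1 - sigmoid).
full_d, full_g = relativistic_losses(neg_log_sigmoid, neg_log_one_minus_sigmoid,
                                     neg_log_one_minus_sigmoid, neg_log_sigmoid,
                                     c_real, c_fake)

# Simplified form with a single f = f1.
simple_d = fmean([neg_log_sigmoid(r - f) for r, f in zip(c_real, c_fake)])
simple_g = fmean([neg_log_sigmoid(f - r) for r, f in zip(c_real, c_fake)])
print(full_d, 2 * simple_d)  # equal, because f2(y) = f1(-y)
print(full_g, 2 * simple_g)  # equal, because g1(y) = g2(-y)
```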
Relativistic average GAN (RaGAN)
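RGAN computes the probability that a sampled real data point is more realistic than a sampled fake one (or vice versa). These values have high variance: they swing widely from sample to sample. Alternatively, relativistic average GAN (RaGAN) compares each value with the average of its opponent. For example, the RaGAN version of SGAN (RaSGAN) is:
$$\bar{D}(x_r) = \sigma\big(C(x_r) - \mathbb{E}_{x_f \sim Q}[C(x_f)]\big), \qquad \bar{D}(x_f) = \sigma\big(C(x_f) - \mathbb{E}_{x_r \sim P}[C(x_r)]\big)$$
$$L_D = -\mathbb{E}_{x_r \sim P}\big[\log \bar{D}(x_r)\big] - \mathbb{E}_{x_f \sim Q}\big[\log(1 - \bar{D}(x_f))\big]$$
$$L_G = -\mathbb{E}_{x_f \sim Q}\big[\log \bar{D}(x_f)\big] - \mathbb{E}_{x_r \sim P}\big[\log(1 - \bar{D}(x_r))\big]$$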
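And the generic cost functions for RaGAN become:
$$L_D = \mathbb{E}_{x_r \sim P}\big[f_1\big(C(x_r) - \mathbb{E}_{x_f \sim Q}[C(x_f)]\big)\big] + \mathbb{E}_{x_f \sim Q}\big[f_2\big(C(x_f) - \mathbb{E}_{x_r \sim P}[C(x_r)]\big)\big]$$
$$L_G = \mathbb{E}_{x_r \sim P}\big[g_1\big(C(x_r) - \mathbb{E}_{x_f \sim Q}[C(x_f)]\big)\big] + \mathbb{E}_{x_f \sim Q}\big[g_2\big(C(x_f) - \mathbb{E}_{x_r \sim P}[C(x_r)]\big)\big]$$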
Algorithm for RaGAN
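RaGAN changes only what each score is compared with: the average of the opponent's scores instead of a single paired sample. Below is a minimal sketch of the generic RaGAN losses, instantiated as RaSGAN. Minibatch means stand in for the expectations E[C(xf)] and E[C(xr)]. Function names and scores are hypothetical, and `statistics.fmean` needs Python 3.8+.

```python
import math
from statistics import fmean
from typing import Callable, Sequence, Tuple

Fn = Callable[[float], float]

def softplus(y: float) -> float:
    """Stable log(1 + e^y)."""
    return max(y, 0.0) + math.log1p(math.exp(-abs(y)))

def neg_log_sigmoid(y: float) -> float:
    return softplus(-y)  # -log sigmoid(y)

def neg_log_one_minus_sigmoid(y: float) -> float:
    return softplus(y)  # -log(1 - sigmoid(y))

def relativistic_average_losses(f1: Fn, f2: Fn, g1: Fn, g2: Fn,
                                c_real: Sequence[float],
                                c_fake: Sequence[float]) -> Tuple[float, float]:
    """Generic RaGAN losses for one minibatch, using batch means as the averages."""
    mean_real, mean_fake = fmean(c_real), fmean(c_fake)
    rel_real = [c - mean_fake for c in c_real]  # C(x_r) - E[C(x_f)]
    rel_fake = [c - mean_real for c in c_fake]  # C(x_f) - E[C(x_r)]
    loss_d = fmean([f1(y) for y in rel_real]) + fmean([f2(y) for y in rel_fake])
    loss_g = fmean([g1(y) for y in rel_real]) + fmean([g2(y) for y in rel_fake])
    return loss_d, loss_g

# RaSGAN: f1 = g2 = -log sigmoid, f2 = g1 = -log(1 - sigmoid).
c_real = [2.0, 0.5, 1.0, 1.5]  # hypothetical critic outputs on real images
c_fake = [-1.0, 0.0, 1.5]      # hypothetical critic outputs on fake images
loss_d, loss_g = relativistic_average_losses(
    neg_log_sigmoid, neg_log_one_minus_sigmoid,
    neg_log_one_minus_sigmoid, neg_log_sigmoid, c_real, c_fake)
print(loss_d, loss_g)
```

Each side is compared with a mean, so the real and fake batches do not need to be paired or even the same size, as in this example.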
RaGAN and RGAN provide:
- more stable and faster training, and
- higher-quality and higher-resolution images.
Here are the results on the cat dataset after converting different cost functions into RGAN and RaGAN. The FID score (lower is better) is measured at different training iterations. SGAN and LSGAN fail to generate reasonable images at higher resolutions. The RGAN and RaGAN versions achieve better FID scores than the cost functions they are derived from. In addition, the standard deviation of the FID score is much smaller, indicating that the models are more stable.
Listing of RGAN and RaGAN
Source: Deep Learning on Medium