GAN — Self-Attention Generative Adversarial Networks (SAGAN)



Motivation

Image captioning and language translation use attention to improve accuracy. For example, an image captioning deep network focuses on different areas of the image to generate words in the caption.

The highlighted area is where the network focuses when generating the specific word shown below it.

GAN models trained on ImageNet are good at classes dominated by texture (landscapes, sky) but perform much worse on classes that depend on geometric structure (such as a dog's legs). Convolutional filters are good at exploiting local spatial information, but their receptive fields may not be large enough to cover larger structures. We could increase the filter size or the depth of the network, but this makes GANs even harder to train.
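To make the receptive-field argument concrete, the sketch below applies the standard recurrence for stacked convolutions: each layer adds (kernel_size − 1) × jump input pixels, where jump is the product of all earlier strides. The helper name `receptive_field` is hypothetical, and it computes only the theoretical receptive field (no dilation or padding effects).

```python
def receptive_field(layers):
    """Theoretical receptive field (in input pixels) of a stack of conv layers.

    layers: list of (kernel_size, stride) pairs, first layer first.
    Recurrence: r_out = r_in + (k - 1) * jump, where jump is the product
    of the strides of all earlier layers.
    """
    r, jump = 1, 1
    for k, s in layers:
        r += (k - 1) * jump   # this layer widens the field by (k - 1) input steps
        jump *= s             # later layers step over `jump` input pixels at a time
    return r


# Five 3x3 convolutions with stride 1 only see an 11x11 patch
print(receptive_field([(3, 1)] * 5))  # 11
# The same stack with stride 2 sees 63x63, but at the cost of resolution
print(receptive_field([(3, 2)] * 5))  # 63
```

With stride 1 the field grows only linearly with depth, so covering a structure that spans most of a feature map needs many layers. Self-attention sidesteps this by letting every location look at every other location in a single step.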

Alternatively, we can apply attention. For example, to refine the image quality of the eye region (the red dot in the left figure), SAGAN retrieves features from the highlighted region in the middle figure. The right figure shows the region attended to for the mouth (green dot).

Design

For a convolution layer with self-attention, we refine each location with an extra term computed by the self-attention mechanism (the first term, γo, below):

$$ y_i = \gamma\, o_i + x_i $$

where x is the original layer output and y is the new output.

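First, we compute the attention map β:

$$ \beta_{j,i} = \frac{\exp(s_{ij})}{\sum_{i=1}^{N} \exp(s_{ij})}, \qquad s_{ij} = f(x_i)^\top g(x_j), \qquad f(x) = W_f\, x, \qquad g(x) = W_g\, x $$

where N is the number of spatial locations and $\beta_{j,i}$ measures how much the model attends to location i when synthesizing location j.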

Both W_f and W_g are trainable model parameters (implemented as 1×1 convolutions). For each spatial location, the model builds an attention map over all locations, which acts as a mask.
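A minimal NumPy sketch of the attention-map step, assuming the feature map has been flattened to shape (C, N) with N = H × W, and that the 1×1 convolutions f and g are plain matrix multiplies with weights of shape (C/8, C). The function names and the C/8 channel reduction are assumptions for illustration. It computes s_ij = f(x_i)ᵀ g(x_j) and normalizes with a softmax over i, so each row β[j, :] is a distribution over all locations.

```python
import numpy as np


def softmax(a, axis=-1):
    # Subtract the max before exponentiating for numerical stability
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)


def attention_map(x, W_f, W_g):
    """Return beta with beta[j, i] = attention from location j to location i.

    x   : feature map flattened to (C, N), N = H * W
    W_f : (C_bar, C) weights of the 1x1 conv f
    W_g : (C_bar, C) weights of the 1x1 conv g
    """
    f = W_f @ x          # (C_bar, N), column i is f(x_i)
    g = W_g @ x          # (C_bar, N), column j is g(x_j)
    s = f.T @ g          # s[i, j] = f(x_i)^T g(x_j), shape (N, N)
    # Transpose so rows index j, then normalize over i
    return softmax(s.T, axis=1)


rng = np.random.default_rng(0)
C, H, W = 16, 4, 4
x = rng.standard_normal((C, H * W))
W_f = rng.standard_normal((C // 8, C)) * 0.1
W_g = rng.standard_normal((C // 8, C)) * 0.1
beta = attention_map(x, W_f, W_g)
print(beta.shape)                        # (16, 16)
print(np.allclose(beta.sum(axis=1), 1))  # True: each row is a mask over all locations
```

The (N, N) map is what makes self-attention expensive on large feature maps, which is one reason to reduce the channel count of f and g.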

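Next, we compute the self-attention output o:

$$ o_j = \sum_{i=1}^{N} \beta_{j,i}\, h(x_i), \qquad h(x_i) = W_h\, x_i $$

where W_h is another trainable weight matrix.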
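The output step can be sketched as a single matrix product: column j of o is the β-weighted sum of the projected features h(x_i) = W_h x_i. This is a self-contained NumPy sketch under the same assumptions as before (features flattened to (C, N), 1×1 convolutions as matrix multiplies); `self_attention_output` is a hypothetical name, and W_h is assumed to be (C, C) so that o has the same shape as x.

```python
import numpy as np


def self_attention_output(x, W_f, W_g, W_h):
    """Compute o[:, j] = sum_i beta[j, i] * h(x_i), with h(x_i) = W_h x_i.

    x is flattened to (C, N); W_f, W_g are (C_bar, C); W_h is (C, C).
    """
    s = (W_f @ x).T @ (W_g @ x)                        # s[i, j] = f(x_i)^T g(x_j)
    logits = s.T                                       # logits[j, i] = s[i, j]
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    beta = np.exp(logits)
    beta /= beta.sum(axis=1, keepdims=True)            # beta[j, i], rows sum to 1
    h = W_h @ x                                        # (C, N), column i is h(x_i)
    return h @ beta.T                                  # column j: sum_i beta[j, i] h(x_i)


rng = np.random.default_rng(1)
C, N = 8, 6
x = rng.standard_normal((C, N))
W_h = rng.standard_normal((C, C))
zeros = np.zeros((C // 8, C))
o = self_attention_output(x, zeros, zeros, W_h)
# With f and g zeroed, attention is uniform, so every column of o is the mean of h(x_i)
print(np.allclose(o, (W_h @ x).mean(axis=1, keepdims=True)))  # True
```

The uniform-attention check is a convenient sanity test: when β carries no information, each output location simply averages the features of the whole image.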

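The final output is

$$ y_i = \gamma\, o_i + x_i $$

where γ is a learnable scalar initialized to 0, so the model first relies on local spatial information and only gradually learns to give weight to non-local evidence.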
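Putting the pieces together, here is a minimal forward-only sketch of a self-attention block operating on a (C, H, W) feature map. The class name, the `reduction=8` default, and the 1/√C weight scaling are assumptions for illustration; a real implementation would use trainable 1×1 conv layers in a deep-learning framework. The key point it shows is the residual form y = γo + x: with γ = 0 the block is exactly the identity.

```python
import numpy as np


class SelfAttention:
    """Forward-only sketch of a SAGAN-style self-attention block (hypothetical)."""

    def __init__(self, channels, rng, reduction=8):
        c_bar = max(1, channels // reduction)
        scale = 1.0 / np.sqrt(channels)
        # 1x1 convolutions f, g, h written as (out_channels, in_channels) matrices
        self.W_f = rng.standard_normal((c_bar, channels)) * scale
        self.W_g = rng.standard_normal((c_bar, channels)) * scale
        self.W_h = rng.standard_normal((channels, channels)) * scale
        self.gamma = 0.0  # learnable scalar, initialized to 0

    def __call__(self, feat):
        C, H, W = feat.shape
        x = feat.reshape(C, H * W)                     # (C, N)
        s = (self.W_f @ x).T @ (self.W_g @ x)          # s[i, j] = f(x_i)^T g(x_j)
        logits = s.T - s.T.max(axis=1, keepdims=True)  # rows index j, stable softmax
        beta = np.exp(logits)
        beta /= beta.sum(axis=1, keepdims=True)        # beta[j, i], rows sum to 1
        o = (self.W_h @ x) @ beta.T                    # o[:, j] = sum_i beta[j, i] h(x_i)
        y = self.gamma * o + x                         # residual: y_i = gamma * o_i + x_i
        return y.reshape(C, H, W)


rng = np.random.default_rng(0)
layer = SelfAttention(channels=16, rng=rng)
feat = rng.standard_normal((16, 8, 8))
print(np.allclose(layer(feat), feat))  # True: gamma = 0 makes the block an identity
layer.gamma = 0.5                      # in training, gamma would be learned
print(np.allclose(layer(feat), feat))  # False: the attention term now contributes
```

Starting from an identity means inserting the block does not disturb a network that already works on local features; the optimizer decides how much non-local information to mix in.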

Here are some implementation details that apply to both the generator and the discriminator:

  • The attention mechanism is appended to the existing output of the convolutional layers, and
  • Spectral normalization is used to stabilize the GAN training.
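Spectral normalization divides each weight matrix by its largest singular value, which is usually estimated cheaply with power iteration rather than a full SVD. The sketch below shows that estimate in NumPy for a 2D matrix; `spectral_normalize` is a hypothetical name, and framework implementations additionally reshape convolution kernels to 2D and keep the vector u as a persistent buffer between training steps.

```python
import numpy as np


def spectral_normalize(W, u=None, n_iters=1, rng=None):
    """Return (W / sigma, u), where sigma estimates W's largest singular value.

    Uses power iteration; n_iters must be >= 1. Passing the returned u back in
    on the next call lets a single iteration per training step suffice.
    """
    if u is None:
        rng = rng or np.random.default_rng()
        u = rng.standard_normal(W.shape[0])
        u /= np.linalg.norm(u)
    for _ in range(n_iters):
        v = W.T @ u
        v /= np.linalg.norm(v)   # right singular vector estimate
        u = W @ v
        u /= np.linalg.norm(u)   # left singular vector estimate
    sigma = u @ W @ v            # estimate of the spectral norm of W
    return W / sigma, u


rng = np.random.default_rng(0)
W = rng.standard_normal((64, 32))
W_sn, u = spectral_normalize(W, n_iters=100, rng=rng)
print(np.linalg.norm(W_sn, ord=2))  # close to 1.0
# During training, reuse u so one power iteration per step keeps the estimate fresh
W_sn, u = spectral_normalize(W, u=u, n_iters=1)
```

Bounding every layer's spectral norm limits how sharply the network's output can change with its input, which is what makes the adversarial game better behaved.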

Loss function

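SAGAN trains the network with the hinge version of the adversarial loss:

$$ L_D = -\mathbb{E}_{(x,y)\sim p_{data}}\left[\min(0,\, -1 + D(x,y))\right] - \mathbb{E}_{z\sim p_z,\, y\sim p_{data}}\left[\min(0,\, -1 - D(G(z),y))\right] $$

$$ L_G = -\mathbb{E}_{z\sim p_z,\, y\sim p_{data}}\left[D(G(z),y)\right] $$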
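The hinge loss is easy to compute from raw discriminator scores. This sketch uses the equivalent form −min(0, −1 + t) = max(0, 1 − t) and replaces expectations with batch means; the function names are hypothetical, and the class label y of the conditional discriminator is folded into the precomputed scores rather than modeled explicitly.

```python
import numpy as np


def d_hinge_loss(d_real, d_fake):
    """Discriminator hinge loss on raw scores (no sigmoid).

    L_D = mean(max(0, 1 - D(x))) + mean(max(0, 1 + D(G(z))))
    """
    return np.mean(np.maximum(0.0, 1.0 - d_real)) + np.mean(np.maximum(0.0, 1.0 + d_fake))


def g_hinge_loss(d_fake):
    """Generator loss: L_G = -mean(D(G(z)))."""
    return -np.mean(d_fake)


d_real = np.array([2.0, 0.5, -1.0])   # discriminator scores on real images
d_fake = np.array([-2.0, 0.0, 1.5])   # discriminator scores on generated images
print(d_hinge_loss(d_real, d_fake))   # approximately 2.0
print(g_hinge_loss(d_fake))           # approximately 0.1667
```

The margin is the design point: real samples scored above 1 and fake samples scored below −1 contribute zero loss, so the discriminator stops pushing on examples it already separates confidently.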

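The discriminator and the generator are trained with different learning rates (the two-timescale update rule, TTUR): the paper uses 0.0004 for the discriminator and 0.0001 for the generator.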


Source: Deep Learning on Medium