This is part two of my “Notes on ‘Non-linear PCA’” blog article series, where I discuss various deep generative modeling techniques. In this section, I am going to discuss GANs.
Generative Adversarial Networks
Generative Adversarial Networks (GANs) are a generative architecture based on a game-theoretic approach where two agents, a generator and a discriminator, compete to outperform each other in a zero-sum game. The discriminator’s main task is to differentiate between legitimate ground truth samples and fake samples produced by the generator network. Similar to previously explored generative models, GANs work with a more compact latent space: the generator maps noise vectors drawn from a latent prior into data space, producing samples that become increasingly difficult for the discriminator to identify as artificial. After training, the discriminator is typically discarded. To train a GAN, we use the following minimax objective, which at the optimal discriminator amounts to minimizing the Jensen–Shannon divergence between the data distribution and the generator’s distribution, with generator network $G$, discriminator network $D$, and latent noise $z$:
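$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}(x)}\bigl[\log D(x)\bigr] + \mathbb{E}_{z \sim p_z(z)}\bigl[\log\bigl(1 - D(G(z))\bigr)\bigr]$$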
The pseudo-code for training the GAN can be found in Algorithm 1.
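As a concrete (if toy) illustration of the alternating updates such an algorithm describes, here is a minimal PyTorch sketch of a GAN training loop on a synthetic 2-D dataset. The architectures, learning rates, batch sizes, and the non-saturating generator loss used here are illustrative choices of mine, not the article’s Algorithm 1:

```python
import torch
import torch.nn as nn

# Toy 2-D data distribution standing in for the "ground truth" samples.
def sample_real(batch_size):
    return torch.randn(batch_size, 2) * 0.5 + torch.tensor([2.0, 2.0])

latent_dim = 8
G = nn.Sequential(nn.Linear(latent_dim, 32), nn.ReLU(), nn.Linear(32, 2))
D = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, 1), nn.Sigmoid())

opt_G = torch.optim.Adam(G.parameters(), lr=2e-4)
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4)
bce = nn.BCELoss()

for step in range(2000):
    # Discriminator update: push D(x) -> 1 on real data and D(G(z)) -> 0 on fakes.
    real = sample_real(64)
    z = torch.randn(64, latent_dim)
    fake = G(z).detach()  # detach so the generator gets no gradient from this step
    loss_D = bce(D(real), torch.ones(64, 1)) + bce(D(fake), torch.zeros(64, 1))
    opt_D.zero_grad()
    loss_D.backward()
    opt_D.step()

    # Generator update: push D(G(z)) -> 1 (the non-saturating form of the minimax loss).
    z = torch.randn(64, latent_dim)
    loss_G = bce(D(G(z)), torch.ones(64, 1))
    opt_G.zero_grad()
    loss_G.backward()
    opt_G.step()
```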
GANs often suffer from unstable training, vanishing gradients, and mode collapse. Training can be unstable because GAN training is not ordinary loss minimization: the generator and discriminator are updated by simultaneous gradient descent on a two-player game, and this search for an equilibrium can oscillate or diverge rather than converge. Vanishing gradients occur when the discriminator becomes too strong: once it confidently rejects every generated sample, the generator’s loss saturates and provides almost no gradient signal to learn from. Further, mode collapse happens when a GAN routinely fails to generate diverse samples reflecting the ground truth distribution. For example, a GAN trained on a dataset of farm animals that exclusively generates images of horses, instead of a diversity of animals like cows, chickens, pigs, or dogs, would be exhibiting mode collapse. A variety of explanations have been posited for this important shortcoming of GANs, including the generator failing to notice that it is omitting parts of the training data, or the generator learning “too fast” relative to the discriminator 1.
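To make the vanishing-gradient point concrete (this short derivation is mine, not from the sources above), write the discriminator as $D(x) = \sigma(a(x))$ for logit $a$. For the original minimax generator loss,

$$\frac{\partial}{\partial a} \log\bigl(1 - \sigma(a)\bigr) = -\sigma(a) = -D(G(z)),$$

which goes to zero exactly when the discriminator confidently rejects fakes ($D(G(z)) \to 0$). The non-saturating alternative $-\log D(G(z))$ instead has gradient $-(1 - D(G(z)))$ with respect to the logit, which stays close to $-1$ in that regime, which is why it is the form commonly used in practice.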
Wasserstein GANs
Wasserstein GANs are GANs that aim to mitigate the mode collapse problem. Wasserstein GANs narrow the discriminator’s strategy set by constraining it to a bounded Lipschitz norm $||D||_L \leq K$, where $K$ is a positive constant. This more tractable search space helps offset the mode collapse problem 2. Additionally, Wasserstein GANs build on the idea that the discriminator network’s main purpose is to provide feedback to the generator network by indicating how “far off” the generator is from producing samples indistinguishable from the ground truth. As such, Wasserstein GANs use the Wasserstein metric as a measure of “distance” between the distribution of the generator’s adversarial samples and the ground truth distribution, with the discriminator (often called a critic in this setting) trained to estimate it. The Wasserstein metric is defined as:
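$$W(p_r, p_g) = \inf_{\gamma \in \Pi(p_r, p_g)} \mathbb{E}_{(x, y) \sim \gamma}\bigl[\,||x - y||\,\bigr]$$

where $p_r$ is the ground truth distribution, $p_g$ is the generator’s distribution, and $\Pi(p_r, p_g)$ is the set of all joint distributions whose marginals are $p_r$ and $p_g$. In practice, Wasserstein GANs optimize the Kantorovich–Rubinstein dual form, $\sup_{||D||_L \leq 1} \mathbb{E}_{x \sim p_r}[D(x)] - \mathbb{E}_{x \sim p_g}[D(x)]$, which is where the Lipschitz constraint on the critic comes from.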
Using this improved objective function, Wasserstein GANs are able to avoid mode collapse and vanishing gradients by offering more meaningful and continuous gradients throughout training. This modified objective is typically paired with either weight clipping or a gradient penalty to enforce the Lipschitz constraint on the critic and avoid training failure.
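As a rough sketch of how the gradient-penalty variant of this objective is often implemented, here is a minimal PyTorch-style critic loss. The network, data, and penalty weight below are placeholder assumptions for illustration, not a prescription from the papers cited here:

```python
import torch
import torch.nn as nn

def critic_loss_with_gp(critic, real, fake, gp_weight=10.0):
    """Wasserstein critic loss with a gradient penalty encouraging ~1-Lipschitzness."""
    # Wasserstein estimate: the critic should score real samples higher than fakes.
    loss = critic(fake).mean() - critic(real).mean()

    # Gradient penalty evaluated on random interpolates between real and fake samples.
    eps = torch.rand(real.size(0), 1)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    grads = torch.autograd.grad(critic(interp).sum(), interp, create_graph=True)[0]
    penalty = ((grads.norm(2, dim=1) - 1) ** 2).mean()
    return loss + gp_weight * penalty

# Toy usage with a small critic on 2-D data (shapes and networks are placeholders).
critic = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1))
real = torch.randn(32, 2) + 2.0
fake = torch.randn(32, 2)
print(critic_loss_with_gp(critic, real, fake))
```

Note that the critic here has no final sigmoid: it outputs an unbounded score rather than a probability, since the Wasserstein formulation replaces classification with distance estimation.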