Variational Autoencoders (VAEs) Basics

Variational Autoencoders (VAEs) are probabilistic generative models that encode each input as a probability distribution over a continuous latent space rather than as a single point. By maximizing the Evidence Lower Bound (ELBO), VAEs learn structured latent spaces suitable for sampling and data generation.


Probabilistic Formulation

VAEs frame encoding as estimating a probability distribution over latent variables, enabling generation via sampling.

Generative Modeling

Traditional autoencoders are deterministic: they map an input to a single static vector in the latent space. This design makes generation difficult because the latent space can have large, unconstrained gaps where the decoder has never been trained. Variational Autoencoders (VAEs), introduced by Kingma and Welling in 2013, address this by framing the latent space probabilistically. The encoder estimates a probability distribution over the latent variables, conditioned on the input: \\( q_\\phi(\\mathbf{z} \\mid \\mathbf{x}) \\).

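In practice, the encoder outputs the parameters of a multivariate Gaussian distribution with diagonal covariance: a mean vector \\( \\mathbf{\\mu} \\) and a log-variance vector \\( \\log(\\mathbf{\\sigma}^2) \\). Predicting the log-variance rather than the variance lets the network output any real number while keeping \\( \\mathbf{\\sigma}^2 = \\exp(\\log(\\mathbf{\\sigma}^2)) \\) strictly positive. The decoder is a probabilistic model \\( p_\\theta(\\mathbf{x} \\mid \\mathbf{z}) \\) that reconstructs the input from a sampled latent vector \\( \\mathbf{z} \\). By mapping inputs to distributions rather than static points, and by regularizing those distributions toward a shared prior, VAEs encourage a continuous latent space that is suitable for sampling.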
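A minimal PyTorch sketch of an encoder that represents \\( q_\\phi(\\mathbf{z} \\mid \\mathbf{x}) \\) as a diagonal Gaussian, assuming hypothetical layer sizes (784 inputs, 128 hidden units, 16 latent dimensions) and random inputs. It uses `torch.distributions.Independent` to treat the latent dimensions as one multivariate distribution per input, and shows that the standard deviation derived from the log-variance is always positive.

<pre><code class="language-python">import torch
import torch.nn as nn
from torch.distributions import Independent, Normal

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
input_dim, hidden_dim, latent_dim = 784, 128, 16

backbone = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
fc_mu = nn.Linear(hidden_dim, latent_dim)      # predicts the mean vector mu
fc_logvar = nn.Linear(hidden_dim, latent_dim)  # predicts log(sigma^2): any real number

x = torch.rand(4, input_dim)  # a batch of 4 random inputs
h = backbone(x)
mu, logvar = fc_mu(h), fc_logvar(h)

# sigma = exp(0.5 * log(sigma^2)) is strictly positive whatever the sign of logvar.
std = torch.exp(0.5 * logvar)

# q_phi(z | x) as a diagonal Gaussian for each input. Independent(..., 1) treats the
# latent_dim axis as one multivariate event rather than 16 separate scalars.
q_z_given_x = Independent(Normal(mu, std), 1)

print(q_z_given_x.batch_shape, q_z_given_x.event_shape)  # torch.Size([4]) torch.Size([16])
print(q_z_given_x.log_prob(mu).shape)  # torch.Size([4]): one log-density per input
</code></pre>

Each input therefore gets its own distribution over the latent space, which is what the decoder later samples from.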

The Reparameterization Trick

To generate an output, the model must sample a latent vector \\( \\mathbf{z} \\) from the distribution \\( \\mathcal{N}(\\mathbf{\\mu}, \\mathbf{\\sigma}^2) \\). However, drawing a random sample is not a differentiable function of the distribution's parameters, so backpropagation cannot compute gradients of the loss with respect to \\( \\mathbf{\\mu} \\) and \\( \\mathbf{\\sigma} \\). Without those gradients, the encoder receives no updates from the reconstruction loss. VAEs solve this using the reparameterization trick.
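A minimal PyTorch sketch of the problem, assuming toy values for \\( \\mathbf{\\mu} \\) and \\( \\mathbf{\\sigma} \\): `torch.distributions.Normal.sample()` draws its sample without tracking gradients, so nothing computed from that sample can send updates back to the encoder.

<pre><code class="language-python">import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Toy encoder outputs for one input with a 2-dimensional latent space.
mu = torch.tensor([0.5, -1.0], requires_grad=True)
sigma = torch.tensor([1.0, 0.5], requires_grad=True)

# Normal.sample() draws z inside torch.no_grad(), so the result is detached
# from mu and sigma: autograd has no path back to the encoder.
z = Normal(mu, sigma).sample()
print(z.requires_grad)  # False

# Any loss built only from z therefore carries no gradient for mu or sigma.
loss = (z ** 2).sum()
print(loss.requires_grad)  # False
</code></pre>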

Instead of sampling directly from \\( \\mathcal{N}(\\mathbf{\\mu}, \\mathbf{\\sigma}^2) \\), the model samples an auxiliary noise variable \\( \\mathbf{\\epsilon} \\) from a standard normal distribution: \\( \\mathbf{\\epsilon} \\sim \\mathcal{N}(0, \\mathbf{I}) \\). The latent vector \\( \\mathbf{z} \\) is then computed deterministically as:

\\( \\mathbf{z} = \\mathbf{\\mu} + \\mathbf{\\sigma} \\odot \\mathbf{\\epsilon} \\)

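where \\( \\odot \\) denotes element-wise multiplication and \\( \\mathbf{\\sigma} = \\exp(0.5 \\cdot \\log(\\mathbf{\\sigma}^2)) \\) is recovered from the encoder's log-variance output. This reformulation moves all of the randomness into \\( \\mathbf{\\epsilon} \\), which does not depend on the model's parameters, so gradients can flow backward through the deterministic operations on \\( \\mathbf{\\mu} \\) and \\( \\mathbf{\\sigma} \\) during backpropagation.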
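A small PyTorch check of the trick, assuming toy values for the encoder outputs. Computing \\( \\mathbf{z} = \\mathbf{\\mu} + \\mathbf{\\sigma} \\odot \\mathbf{\\epsilon} \\) by hand gives exactly the gradients the chain rule predicts: \\( \\partial z_i / \\partial \\mu_i = 1 \\) and \\( \\partial z_i / \\partial \\log \\sigma_i^2 = 0.5 \\, \\sigma_i \\epsilon_i \\). The last lines show that `torch.distributions.Normal.rsample()` packages the same idea.

<pre><code class="language-python">import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Toy encoder outputs for one input with a 3-dimensional latent space.
mu = torch.tensor([0.0, 1.0, -2.0], requires_grad=True)
logvar = torch.tensor([0.0, -1.0, 2.0], requires_grad=True)

sigma = torch.exp(0.5 * logvar)  # deterministic function of logvar
eps = torch.randn(3)             # all randomness lives here; it needs no gradient
z = mu + sigma * eps             # element-wise, as in the equation above

z.sum().backward()

# dz_i/dmu_i = 1 and, by the chain rule through exp, dz_i/dlogvar_i = 0.5 * sigma_i * eps_i.
print(mu.grad)  # tensor([1., 1., 1.])
print(torch.allclose(logvar.grad, 0.5 * sigma.detach() * eps))  # True

# torch.distributions exposes the same trick as rsample(), which keeps z in the graph.
z_r = Normal(mu, torch.exp(0.5 * logvar)).rsample()
print(z_r.requires_grad)  # True
</code></pre>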

The ELBO Objective Function

Training a VAE maximizes the Evidence Lower Bound (ELBO), which balances reconstruction accuracy against latent space regularization.

Mathematical Derivation

The training objective of a VAE is to maximize the marginal log-likelihood of the observed data, \\( \\log p(\\mathbf{x}) \\). Computing this likelihood directly is intractable because it requires integrating over every possible latent vector, so we instead derive and maximize the Evidence Lower Bound (ELBO):

\\( \\text{ELBO}(\\theta, \\phi; \\mathbf{x}) = \\mathbb{E}_{q_\\phi(\\mathbf{z} \\mid \\mathbf{x})}[\\log p_\\theta(\\mathbf{x} \\mid \\mathbf{z})] - D_{KL}(q_\\phi(\\mathbf{z} \\mid \\mathbf{x}) \\parallel p(\\mathbf{z})) \\)
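The bound follows from Jensen's inequality, since the logarithm is concave. Rewriting the marginal likelihood as an expectation under the encoder's distribution and moving the logarithm inside the expectation gives:

\\( \\log p(\\mathbf{x}) = \\log \\int p_\\theta(\\mathbf{x} \\mid \\mathbf{z}) \\, p(\\mathbf{z}) \\, d\\mathbf{z} = \\log \\mathbb{E}_{q_\\phi(\\mathbf{z} \\mid \\mathbf{x})}\\left[ \\frac{p_\\theta(\\mathbf{x} \\mid \\mathbf{z}) \\, p(\\mathbf{z})}{q_\\phi(\\mathbf{z} \\mid \\mathbf{x})} \\right] \\)

\\( \\geq \\mathbb{E}_{q_\\phi(\\mathbf{z} \\mid \\mathbf{x})}\\left[ \\log p_\\theta(\\mathbf{x} \\mid \\mathbf{z}) + \\log p(\\mathbf{z}) - \\log q_\\phi(\\mathbf{z} \\mid \\mathbf{x}) \\right] \\)

\\( = \\mathbb{E}_{q_\\phi(\\mathbf{z} \\mid \\mathbf{x})}[\\log p_\\theta(\\mathbf{x} \\mid \\mathbf{z})] - D_{KL}(q_\\phi(\\mathbf{z} \\mid \\mathbf{x}) \\parallel p(\\mathbf{z})) = \\text{ELBO}(\\theta, \\phi; \\mathbf{x}) \\)

The gap between \\( \\log p(\\mathbf{x}) \\) and the ELBO equals \\( D_{KL}(q_\\phi(\\mathbf{z} \\mid \\mathbf{x}) \\parallel p_\\theta(\\mathbf{z} \\mid \\mathbf{x})) \\), so the bound becomes tight when the encoder's distribution matches the true posterior.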

The first term is the reconstruction term: it rewards the decoder for assigning high likelihood to the input when \\( \\mathbf{z} \\) is sampled from the encoder's distribution. The second term is the Kullback-Leibler (KL) divergence, which measures how far the encoder's predicted distribution \\( q_\\phi(\\mathbf{z} \\mid \\mathbf{x}) \\) is from a prior distribution \\( p(\\mathbf{z}) \\), typically chosen to be a standard normal distribution \\( \\mathcal{N}(0, \\mathbf{I}) \\). Because the KL term is subtracted, maximizing the ELBO minimizes it, which regularizes the latent space by pulling each encoded distribution toward the prior: means toward the origin and variances toward one.
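When the encoder is a diagonal Gaussian and the prior is \\( \\mathcal{N}(0, \\mathbf{I}) \\), the KL term has the closed form \\( -\\tfrac{1}{2} \\sum_j (1 + \\log \\sigma_j^2 - \\mu_j^2 - \\sigma_j^2) \\). A minimal sketch, assuming random toy encoder outputs, that checks this expression against `torch.distributions.kl_divergence` and confirms the penalty is zero when the encoder outputs exactly the prior (the helper `kl_to_standard_normal` is hypothetical):

<pre><code class="language-python">import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Toy encoder outputs: a batch of 4 inputs with 16 latent dimensions.
mu = torch.randn(4, 16)
logvar = torch.randn(4, 16)

def kl_to_standard_normal(mu, logvar):
    # KL(N(mu, sigma^2) || N(0, I)) for a diagonal Gaussian, summed over latent dims:
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)

kl_closed = kl_to_standard_normal(mu, logvar)

# Reference: torch.distributions computes the KL per dimension; sum over latent dims.
q = Normal(mu, torch.exp(0.5 * logvar))
prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_reference = kl_divergence(q, prior).sum(dim=-1)

print(torch.allclose(kl_closed, kl_reference, rtol=1e-4, atol=1e-5))  # True

# mu = 0 and logvar = 0 means the encoder outputs the prior itself: no penalty.
print(kl_to_standard_normal(torch.zeros(16), torch.zeros(16)).item() == 0.0)  # True
</code></pre>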

PyTorch VAE Implementation

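The following PyTorch code implements a complete VAE, including the encoder heads, the reparameterization trick, and a loss function equal to the negative ELBO, so that minimizing the loss maximizes the ELBO. The reconstruction term uses binary cross-entropy (the negative log-likelihood of a Bernoulli decoder, evaluated at a single sampled \\( \\mathbf{z} \\)), and the KL term uses its closed form for a diagonal Gaussian encoder and a standard normal prior: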

<pre><code class="language-python">import torch
import torch.nn as nn


class VariationalAutoencoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # Shared encoder backbone
        self.encoder_backbone = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU()
        )
        # Encoder heads for mean and log-variance
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # Restricts reconstruction values to [0, 1]
        )

    def encode(self, x):
        h = self.encoder_backbone(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        # sigma = exp(0.5 * logvar); all randomness comes from eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar


# Loss = negative ELBO: reconstruction (BCE) + KL divergence
def vae_loss_fn(recon_x, x, mu, logvar):
    # Sum (rather than average) over pixels and batch so the reconstruction term
    # stays on the same scale as the KL term, which is also summed
    recon_loss = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL divergence for Gaussian prior N(0, I)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_loss


# Example run on random inputs in [0, 1], as BCE targets require
vae = VariationalAutoencoder(input_dim=784, hidden_dim=128, latent_dim=16)
x_sample = torch.rand(4, 784)
recon, mu, logvar = vae(x_sample)
loss = vae_loss_fn(recon, x_sample, mu, logvar)
print("Total loss:", loss.item())</code></pre>

Latent Space Properties and Generation

The KL divergence term forces the latent space to be continuous, though the resulting reconstructions can be blurry.

Continuity and Completeness

The KL divergence term in the ELBO loss function acts as a regularizer, forcing the latent distributions to stay close to a standard normal distribution \\( \\mathcal{N}(0, \\mathbf{I}) \\). This pressure shapes the latent space to have two key properties: continuity and completeness. Continuity means that points that are close in the latent space decode to visually similar outputs. Completeness means that any point sampled from the prior distribution decodes to a valid output.
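A minimal sketch of how both properties are used, assuming a decoder with the same layer sizes as the PyTorch example above. The weights here are freshly initialized, so the outputs are meaningless until the model is trained; with a trained VAE you would call its `decode` method instead. Decoding draws from the prior relies on completeness, and decoding points along a straight line between two latent codes (linear interpolation) relies on continuity.

<pre><code class="language-python">import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in decoder (untrained, hypothetical) with the same shape as the VAE's decoder.
latent_dim, hidden_dim, input_dim = 16, 128, 784
decoder = nn.Sequential(
    nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
    nn.Linear(hidden_dim, input_dim), nn.Sigmoid(),
)
decoder.eval()

with torch.no_grad():
    # Completeness: draw latent vectors straight from the prior N(0, I) and decode them.
    z_prior = torch.randn(8, latent_dim)
    samples = decoder(z_prior)  # shape (8, 784)

    # Continuity: walk a straight line between two latent codes. In a trained model,
    # neighboring points on this path should decode to similar outputs.
    z_a, z_b = torch.randn(latent_dim), torch.randn(latent_dim)
    alphas = torch.linspace(0.0, 1.0, steps=5).unsqueeze(1)  # shape (5, 1)
    z_path = (1 - alphas) * z_a + alphas * z_b                 # shape (5, 16)
    interpolations = decoder(z_path)                           # shape (5, 784)

print(samples.shape, interpolations.shape)
</code></pre>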

Without the KL regularization term, nothing stops the encoder from shrinking \\( \\mathbf{\\sigma} \\) toward zero and spreading the means far apart, so the model would behave like a standard autoencoder, encoding different classes of data into isolated clusters in the latent space. Sampling from the gaps between these clusters would produce meaningless outputs, making generation unreliable.

Posterior Collapse and Blurriness

While VAEs optimize a single, stable objective (unlike the adversarial training of GANs), they suffer from a well-known limitation: their reconstructions and samples tend to be blurry. This blurriness arises because the model is trained with element-wise loss functions such as MSE or BCE. These losses score each pixel independently, so when the model is uncertain about the exact position of a high-frequency detail such as an edge, it minimizes the loss by averaging over the plausible positions, producing a blurry output.
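A toy numeric sketch of the averaging argument, assuming a 2-pixel "image" whose bright pixel is equally likely to be on the left or on the right. Under MSE, the hedged prediction [0.5, 0.5] has a lower expected loss than either sharp guess, which is the pixel-level version of blurriness (`expected_mse` is a hypothetical helper):

<pre><code class="language-python">import torch
import torch.nn.functional as F

# Two equally likely ground-truth images: the bright pixel is either left or right.
targets = torch.tensor([[1.0, 0.0], [0.0, 1.0]])

def expected_mse(prediction):
    # Average the per-pixel MSE over both possible targets.
    return torch.stack([F.mse_loss(prediction, t) for t in targets]).mean()

sharp = torch.tensor([1.0, 0.0])   # commits to one position
blurry = torch.tensor([0.5, 0.5])  # hedges: the pixel-wise mean of the targets

print(expected_mse(sharp).item())   # 0.5
print(expected_mse(blurry).item())  # 0.25
</code></pre>

The sharp guess is perfect half the time and badly wrong the other half, while the blurry guess is moderately wrong every time; the element-wise loss prefers the latter.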

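Additionally, if the KL divergence term is weighted too heavily, or the decoder is expressive enough to model the data on its own, the model can suffer from "posterior collapse." In this scenario, the encoder outputs the prior distribution \\( \\mathcal{N}(0, \\mathbf{I}) \\) for all inputs, and the decoder ignores the latent codes entirely. A powerful autoregressive decoder can still model the input from local correlations in the data, while a simple decoder falls back to producing roughly the same average output for every input.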
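One way to make the KL weight explicit is a `beta` coefficient on the KL term, the idea behind β-VAE. In the sketch below, `beta_vae_loss` and `informative_dims` are hypothetical helpers and the 0.01 threshold is an assumption chosen for illustration. Setting `beta=0` removes the regularizer (the plain-autoencoder case), while large values push every posterior toward the prior. Because a collapsed latent dimension has a posterior equal to the prior, its KL contribution is near zero, which gives a simple diagnostic.

<pre><code class="language-python">import torch
import torch.nn.functional as F

def beta_vae_loss(recon_x, x, mu, logvar, beta=1.0):
    # beta = 1 gives the standard negative ELBO; beta = 0 drops the KL term;
    # large beta weights the KL term heavily and can cause posterior collapse.
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss

def informative_dims(mu, logvar, threshold=0.01):
    # Per-dimension KL to N(0, 1), averaged over the batch. A dimension whose KL
    # stays near zero carries no information about x (its posterior is the prior).
    kl_per_dim = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean(dim=0)
    return kl_per_dim > threshold

# Toy encoder outputs: dimensions 0-1 are informative, while dimensions 2-3 have
# collapsed to the prior (mu = 0, logvar = 0) for every input.
mu = torch.tensor([[1.5, -2.0, 0.0, 0.0], [-1.0, 0.5, 0.0, 0.0]])
logvar = torch.tensor([[-2.0, -1.0, 0.0, 0.0], [-2.0, -1.0, 0.0, 0.0]])
print(informative_dims(mu, logvar).tolist())  # [True, True, False, False]

# The KL weight changes the total loss for the same reconstruction.
x = torch.rand(2, 4)
recon_x = torch.full((2, 4), 0.5)
for beta in (0.0, 1.0, 4.0):
    print(beta, beta_vae_loss(recon_x, x, mu, logvar, beta).item())
</code></pre>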