The Softmax Function for Output Layers
The Softmax function is the standard activation function for multi-class classification output layers. It normalizes a vector of raw scores (logits) into a probability distribution where all values are positive and sum to one.
Mathematical Formulation
The Softmax function exponentiates and normalizes logits to generate a probability distribution.
The Softmax Equation
For a vector of raw prediction scores (logits) $\mathbf{z} = [z_1, z_2, \dots, z_C]^T$ representing $C$ classes, the Softmax activation $a_i$ for class $i$ is defined as:
$$a_i = S(z_i) = \frac{e^{z_i}}{\sum_{j=1}^C e^{z_j}}$$
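This formulation ensures that every output lies strictly between 0 and 1 and that the class probabilities sum to exactly 1. The exponential amplifies differences between logits, so larger logits receive disproportionately larger probabilities. The ratio of two outputs is $a_i / a_k = e^{z_i - z_k}$, which means a logit that is larger by 1 receives about $e \approx 2.72$ times the probability. The result is a 'soft' version of the max: a smooth, differentiable approximation of the one-hot vector that selects the largest logit.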
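The following minimal sketch makes these properties concrete. It uses plain Python (the standard-library math module only, rather than PyTorch) to evaluate the formula directly. The function name softmax is our own. Direct evaluation like this is only safe for small logits, for the reasons given under numerical stability.

```python
import math

def softmax(z):
    # Direct evaluation of a_i = e^{z_i} / sum_j e^{z_j} (safe only for small logits)
    exps = [math.exp(z_i) for z_i in z]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]
probs = softmax(logits)
print([round(p, 4) for p in probs])  # [0.6652, 0.2447, 0.09]
print(sum(probs))                     # 1.0, up to floating-point rounding
# Logits that differ by 1 give probabilities whose ratio is e (about 2.718)
print(probs[0] / probs[1])
```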
Temperature Scaling
Temperature scaling modifies the Softmax distribution by dividing logits by a positive scalar $T$:
$$S(z_i; T) = \frac{e^{z_i / T}}{\sum_{j=1}^C e^{z_j / T}}$$
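When $T > 1$ (high temperature), the distribution becomes flatter (higher entropy), spreading probability more evenly across classes and representing higher uncertainty. As $T \to \infty$ it approaches the uniform distribution. When $T < 1$ (low temperature), the distribution becomes sharper, concentrating probability on the maximum logit. As $T \to 0^+$ it approaches a one-hot vector at the argmax (assuming a unique maximum). Setting $T = 1$ recovers the standard Softmax. Temperature is widely used in reinforcement learning, knowledge distillation, and text generation with large language models (LLMs).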
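The effect of $T$ can be observed with the sketch below, again in plain Python. The helper names softmax_with_temperature and entropy are hypothetical. The sketch divides the logits by $T$ and reports the entropy (in nats) of the resulting distribution. It also subtracts the maximum scaled logit before exponentiating. That stabilization step is explained in the next section and does not change the result.

```python
import math

def softmax_with_temperature(z, T=1.0):
    # Divide every logit by the positive temperature T, then apply Softmax
    if T <= 0:
        raise ValueError("temperature must be positive")
    scaled = [z_i / T for z_i in z]
    m = max(scaled)  # shift by the max so small T cannot overflow math.exp
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(p):
    # Shannon entropy in nats; zero-probability terms contribute nothing
    return -sum(p_i * math.log(p_i) for p_i in p if p_i > 0)

logits = [2.0, 1.0, 0.0]
for T in (0.5, 1.0, 5.0):
    p = softmax_with_temperature(logits, T)
    # Lower T: larger top probability, lower entropy; higher T: flatter, higher entropy
    print(T, [round(x, 3) for x in p], round(entropy(p), 3))
```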
Numerical Stability in Computation
Evaluating the Softmax equation directly can overflow or underflow on standard floating-point hardware.
Exploding Exponentials and Underflow
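Standard 32-bit floating-point (float32) values overflow when evaluating $e^z$ for $z$ greater than about $88.7$. The reason is that the largest finite float32 is about $3.4 \times 10^{38}$, and $\ln(3.4 \times 10^{38}) \approx 88.7$. If a logit in the output layer is large (e.g., $100.0$), both the numerator and the denominator of Softmax evaluate to infinity (overflow). The division $\infty / \infty$ then yields NaN outputs.
Conversely, if all logits are very negative, every $e^{z_j}$ can underflow to exactly zero. In float32 this happens for $z$ below about $-104$ (e.g., $-1000.0$). The denominator is then zero, and the division $0/0$ also yields NaN. Moderately negative logits such as $-100.0$ do not reach zero, but they fall into the subnormal range, where precision is lost. These stability issues require mathematical adjustments before Softmax can be implemented reliably in numerical libraries.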
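Both failure modes can be reproduced in float32 with the short sketch below. It assumes NumPy is available (NumPy is not otherwise used in this article). naive_softmax is a hypothetical helper that applies the equation with no stabilization.

```python
import numpy as np

def naive_softmax(z):
    # Direct formula with no max-subtraction; computed in z's dtype (float32 here)
    exps = np.exp(z)
    return exps / np.sum(exps)

big = np.array([100.0, 99.0, 98.0], dtype=np.float32)
tiny = np.array([-1000.0, -1001.0, -1002.0], dtype=np.float32)

# Silence NumPy's overflow/invalid-value warnings so the NaNs appear in the output
with np.errstate(all="ignore"):
    print(np.exp(np.float32(88.0)), np.exp(np.float32(89.0)))  # finite, then inf
    print(naive_softmax(big))   # inf / inf -> nan in every entry
    print(naive_softmax(tiny))  # 0 / 0 -> nan in every entry

# Overflow threshold for exp in float32: ln(largest finite float32)
print(np.log(np.finfo(np.float32).max))  # about 88.72
```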
The Max-Subtraction Trick and Log-Softmax
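To avoid numerical instability, we exploit the shift invariance of Softmax. Subtracting the same constant $c$ from every logit leaves the output unchanged, because $e^{z_i - c} = e^{z_i} e^{-c}$ and the common factor $e^{-c}$ cancels between numerator and denominator. We choose $c = M = \max(\mathbf{z})$ and subtract it from each element before exponentiating. This shifts the inputs so that the maximum value is zero: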
$$S(z_i) = \frac{e^{z_i - M}}{\sum_{j=1}^C e^{z_j - M}}$$
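After the shift, every exponent is at most zero, so each term in the sum lies in $(0, 1]$ and the largest term is $e^0 = 1$. This rules out overflow. Because the denominator is at least 1, it can never be zero. Any terms that still underflow correspond to negligibly small probabilities. In loss calculations, we evaluate the log-softmax directly, $\log S(z_i) = z_i - M - \log \sum_{j=1}^C e^{z_j - M}$, rather than taking the logarithm of an already computed probability. This avoids $\log 0 = -\infty$ when a probability underflows and improves numerical stability during backpropagation.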
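A plain-Python sketch of both stabilized formulas follows. softmax_shifted and log_softmax_shifted are hypothetical names. The sketch first checks shift invariance: two inputs that differ by a constant produce identical outputs. It then shows a case where a probability underflows to 0.0. Taking its logarithm would fail (math.log(0.0) raises ValueError), but the direct log-softmax stays finite.

```python
import math

def softmax_shifted(z):
    # Shift so the largest logit is 0: every exponent is <= 0 and the sum is >= 1
    m = max(z)
    exps = [math.exp(z_i - m) for z_i in z]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax_shifted(z):
    # log S(z_i) = z_i - M - log(sum_j e^{z_j - M}); the log argument is >= 1
    m = max(z)
    log_sum = math.log(sum(math.exp(z_j - m) for z_j in z))
    return [z_i - m - log_sum for z_i in z]

# Same result for both inputs: they differ by a constant (shift invariance)
print(softmax_shifted([100.0, 99.0, 98.0]))
print(softmax_shifted([-1000.0, -1001.0, -1002.0]))

# e^{-1000} underflows to 0.0, so log of that probability is undefined,
# but the log-softmax computed directly is finite
print(softmax_shifted([0.0, -1000.0]))      # [1.0, 0.0]
print(log_softmax_shifted([0.0, -1000.0]))  # [0.0, -1000.0]
```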
PyTorch Implementation
PyTorch provides numerically stable Softmax and LogSoftmax layers. We can also write a manual stable version to verify the max-subtraction trick.
Softmax and LogSoftmax in PyTorch
Here is how to apply Softmax and LogSoftmax in PyTorch, specifying the target dimension for normalization:
<pre><code class="language-python">import torch
import torch.nn as nn

# Logits for a batch of 2 samples and 3 classes:
# the first row overflows a naive float32 exp, the second underflows to 0
logits = torch.tensor([[100.0, 99.0, 98.0],
                       [-1000.0, -1001.0, -1002.0]])

# Apply standard Softmax (internally stabilized by PyTorch)
softmax = nn.Softmax(dim=-1)
probs = softmax(logits)

# Apply LogSoftmax, which computes log-probabilities directly for better training stability
log_softmax = nn.LogSoftmax(dim=-1)
log_probs = log_softmax(logits)

print("Probabilities:\n", probs)
print("Log Probabilities:\n", log_probs)</code></pre>
In this code, dim=-1 normalizes scores across the last dimension (the class dimension). PyTorch applies numerical stabilization internally, so the outputs contain no NaN values even for extreme inputs such as $100.0$ and $-1000.0$.
Manual Stable Softmax Function
To understand the mechanics, we can write a manual version of the stable Softmax equation using PyTorch tensor operations:
<pre><code class="language-python">import torch

def stable_softmax(z):
    # z shape: [batch_size, classes]
    # Row-wise maximum; keepdim=True keeps shape [batch_size, 1] for broadcasting
    max_z, _ = torch.max(z, dim=-1, keepdim=True)
    # Subtract the max (shift invariance) so every exponent is at most 0
    shifted_z = z - max_z
    exps = torch.exp(shifted_z)
    sum_exps = torch.sum(exps, dim=-1, keepdim=True)
    return exps / sum_exps

# Reuses logits and probs from the previous example
manual_probs = stable_softmax(logits)
print("Manual check matches PyTorch?:", torch.allclose(probs, manual_probs))</code></pre>
This implementation extracts the maximum logit along the class axis and keeps the dimension intact for proper broadcasting. Subtracting the max stabilizes the exponential calculations. torch.allclose then confirms that the result matches PyTorch's native function to within floating-point tolerance.
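The same approach extends to LogSoftmax. The sketch below defines a hypothetical stable_log_softmax helper from the tensor operations used above and compares it against torch.log_softmax. It redefines logits, including a row of very negative values, so that it runs on its own.

```python
import torch

def stable_log_softmax(z: torch.Tensor) -> torch.Tensor:
    # Row-wise max shift, exactly as in the manual stable Softmax
    max_z, _ = torch.max(z, dim=-1, keepdim=True)
    shifted_z = z - max_z
    # log S(z_i) = (z_i - M) - log sum_j exp(z_j - M); the sum is >= 1, so the log is finite
    log_sum = torch.log(torch.sum(torch.exp(shifted_z), dim=-1, keepdim=True))
    return shifted_z - log_sum

logits = torch.tensor([[100.0, 99.0, 98.0],
                       [-1000.0, -1001.0, -1002.0]])
manual_log_probs = stable_log_softmax(logits)

print(manual_log_probs)
print("Matches torch.log_softmax?:",
      torch.allclose(manual_log_probs, torch.log_softmax(logits, dim=-1)))
```

Both rows shift to the same values $[0, -1, -2]$, so they produce identical log-probabilities, another instance of shift invariance.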