Loss Functions: Categorical Cross-Entropy

Categorical Cross-Entropy (CCE) is the standard loss function for multi-class classification tasks. It evaluates the alignment between the target distribution (typically one-hot encoded) and the predicted class probabilities.


Multi-Class Formulations

CCE generalizes binary cross-entropy from two classes to $C$ classes. It measures how poorly the predicted distribution $\hat{y}$ accounts for the target distribution $y$ across all categories.
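The generalization can be checked numerically for $C = 2$. Softmax over two logits $[z_0, z_1]$ gives $P(\text{class } 1) = \sigma(z_1 - z_0)$, so CCE should equal binary cross-entropy applied to the logit difference. This is a minimal sketch; the logits and labels are made-up illustration values.

<pre><code class="language-python">import torch
import torch.nn.functional as F

# Hypothetical logits for 4 samples over C = 2 classes, with integer labels
logits = torch.tensor([[0.3, 1.2], [2.0, -1.0], [0.0, 0.0], [-0.5, 0.7]])
targets = torch.tensor([1, 0, 1, 1])

# Categorical cross-entropy over the two classes
cce = F.cross_entropy(logits, targets)

# Binary cross-entropy on the logit difference:
# softmax([z0, z1])[1] == sigmoid(z1 - z0)
bce = F.binary_cross_entropy_with_logits(logits[:, 1] - logits[:, 0], targets.float())

print(torch.allclose(cce, bce))  # True</code></pre>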

The CCE Formula

For a batch of $N$ samples and $C$ categories, Categorical Cross-Entropy is defined as:

$$\mathcal{L}_{CCE} = -\frac{1}{N} \sum_{i=1}^N \sum_{c=1}^C y_{i,c} \log(\hat{y}_{i,c})$$

where $y_{i,c}$ is a binary indicator ($0$ or $1$) showing whether class $c$ is the correct label for sample $i$, and $\hat{y}_{i,c}$ is the predicted probability for class $c$. If targets are one-hot encoded, only the term corresponding to the true class contributes to the loss.
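As a worked example, the formula can be transcribed directly with hypothetical probabilities for $N = 2$ samples and $C = 3$ classes. With one-hot targets, the double sum collapses to the average of $-\log \hat{y}$ at the true classes.

<pre><code class="language-python">import torch

# Hypothetical predicted probabilities (each row sums to 1): N = 2, C = 3
y_hat = torch.tensor([[0.7, 0.2, 0.1],
                      [0.1, 0.8, 0.1]])
# One-hot targets: sample 0 is class 0, sample 1 is class 1
y = torch.tensor([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0]])

# Direct transcription: -(1/N) * sum_i sum_c y_ic * log(y_hat_ic)
N = y.shape[0]
cce = -(y * torch.log(y_hat)).sum() / N

# Only the true-class terms survive with one-hot targets
true_class_only = -(torch.log(y_hat[0, 0]) + torch.log(y_hat[1, 1])) / 2

print(round(cce.item(), 4))  # 0.2899</code></pre>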

Target Representations

In classification, target labels can be represented in two formats: one-hot vectors or integer class labels. One-hot vectors ($[0, 1, 0]$) are mathematically explicit but consume significant memory when the number of classes $C$ is large (e.g., in language modeling).

Integer targets (e.g., $y = 1$ for the one-hot vector above, using zero-based indexing) store only the class index, so each sample needs one integer instead of $C$ values. PyTorch's nn.CrossEntropyLoss accepts integer targets directly and selects the log-probability at each target index, so a one-hot matrix never has to be built.
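The sketch below compares the two formats. It assumes PyTorch 1.10 or later, where nn.CrossEntropyLoss also accepts per-class probabilities as targets. It also measures the storage gap for a hypothetical vocabulary of $C = 50{,}000$ classes and a batch of 32.

<pre><code class="language-python">import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 3.0, 1.5]])
int_targets = torch.tensor([0, 1])                       # shape (N,), dtype int64
one_hot = F.one_hot(int_targets, num_classes=3).float()  # shape (N, C)

loss_fn = nn.CrossEntropyLoss()
loss_from_indices = loss_fn(logits, int_targets)
# PyTorch 1.10+: float targets of shape (N, C) are treated as class probabilities
loss_from_probs = loss_fn(logits, one_hot)
print(torch.allclose(loss_from_indices, loss_from_probs))  # True

# Storage for a hypothetical language-model vocabulary
N, C = 32, 50_000
index_bytes = N * torch.empty((), dtype=torch.int64).element_size()
one_hot_bytes = N * C * torch.empty((), dtype=torch.float32).element_size()
print(index_bytes, one_hot_bytes)  # 256 6400000</code></pre>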

Gradient Dynamics and Regularization

Pairing Softmax with CCE yields a simple, well-behaved gradient. The targets themselves can also be softened to regularize the predicted probabilities during training.

The Softmax-CCE Unified Gradient

When predictions are generated using Softmax ($\hat{y} = S(\mathbf{z})$), the derivative of the CCE loss with respect to the input logit $z_k$ simplifies to:

$$\frac{\partial \mathcal{L}_{CCE}}{\partial z_k} = \hat{y}_k - y_k$$
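A short derivation shows where this simplification comes from. It is written for a single sample, so the $\frac{1}{N}$ factor is dropped, and it assumes only that the target entries sum to $1$. Since $\log \hat{y}_c = z_c - \log \sum_j e^{z_j}$, we have $\frac{\partial \log \hat{y}_c}{\partial z_k} = \delta_{ck} - \hat{y}_k$, where $\delta_{ck}$ is $1$ if $c = k$ and $0$ otherwise:

$$\frac{\partial \mathcal{L}_{CCE}}{\partial z_k} = -\sum_{c=1}^C y_c \left(\delta_{ck} - \hat{y}_k\right) = -y_k + \hat{y}_k \sum_{c=1}^C y_c = \hat{y}_k - y_k$$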

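The gradient is simply the difference between the predicted probability and the ground truth. With a loss such as mean squared error computed on the Softmax outputs, the error would also be multiplied by a Softmax-derivative factor; here no such factor appears. As a result, the gradient does not vanish when the network is confidently wrong, and its size stays proportional to the error. This typically speeds up multi-class training.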
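To illustrate the non-saturation claim, the sketch below compares two gradients for one hypothetical, confidently wrong prediction. The first is the CCE gradient, which should equal $\hat{y} - y$. The second comes from a squared-error loss applied to the Softmax outputs, included only for comparison.

<pre><code class="language-python">import torch
import torch.nn.functional as F

# Hypothetical confidently wrong prediction: logits favour class 0, label is class 1
target = torch.tensor([1])
one_hot = F.one_hot(target, num_classes=3).float()

# CCE gradient with respect to the logits
z_cce = torch.tensor([[10.0, 0.0, 0.0]], requires_grad=True)
F.cross_entropy(z_cce, target).backward()

# For comparison: squared error applied to the Softmax outputs
z_mse = torch.tensor([[10.0, 0.0, 0.0]], requires_grad=True)
((torch.softmax(z_mse, dim=-1) - one_hot) ** 2).sum().backward()

# The CCE gradient matches softmax(z) - y
expected = torch.softmax(z_cce.detach(), dim=-1) - one_hot
print(torch.allclose(z_cce.grad, expected))  # True

cce_norm = z_cce.grad.norm().item()
mse_norm = z_mse.grad.norm().item()
print(cce_norm, mse_norm)</code></pre>

Here the CCE gradient norm is about 1.4. The squared-error gradient is more than three orders of magnitude smaller, because the saturated Softmax derivative shrinks it.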

Label Smoothing

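In standard CCE, targets are hard probabilities ($0$ or $1$). The loss approaches zero only as the predicted probability of the true class approaches $1$, which pushes the network to keep widening the gap between the true-class logit and the others. This can lead to overfitting and overconfident predictions. Label smoothing regularizes this by relaxing the targets to: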

$$y_k^{smoothed} = y_k(1 - \epsilon) + \frac{\epsilon}{C}$$

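where $\epsilon$ is a small smoothing parameter. For example, $\epsilon = 0.1$ with $C = 3$ gives a target of about $0.933$ for the true class and $0.033$ for each of the others. Because every class now has a nonzero target, the loss is minimized at finite logits, so the logits no longer grow without bound. In practice this often improves generalization to unseen test samples.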
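The smoothed targets can be built by hand and compared with PyTorch's built-in option. This sketch assumes PyTorch 1.10 or later, where nn.CrossEntropyLoss takes a `label_smoothing` argument that mixes the one-hot target with a uniform distribution in the same way as the formula above. The logits are illustration values.

<pre><code class="language-python">import torch
import torch.nn as nn
import torch.nn.functional as F

eps, C = 0.1, 3
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 3.0, 1.5]])
targets = torch.tensor([0, 1])

# Smoothed targets: y * (1 - eps) + eps / C
y = F.one_hot(targets, num_classes=C).float()
y_smoothed = y * (1 - eps) + eps / C
print(y_smoothed[0])  # approximately [0.9333, 0.0333, 0.0333]

# CCE against the smoothed distribution, written out from the formula
manual = -(y_smoothed * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()

# Built-in equivalent (PyTorch 1.10+)
builtin = nn.CrossEntropyLoss(label_smoothing=eps)(logits, targets)
print(torch.allclose(manual, builtin))  # True</code></pre>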

PyTorch Implementation

We can use PyTorch to compute multi-class cross-entropy loss, comparing the built-in module with a manual implementation.

Using nn.CrossEntropyLoss

PyTorch's nn.CrossEntropyLoss combines nn.LogSoftmax and nn.NLLLoss (Negative Log-Likelihood Loss) in a single stable step:

<pre><code class="language-python">import torch
import torch.nn as nn

# Logits representing a batch of 2 samples, 3 classes
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 3.0, 1.5]], requires_grad=True)
# Target class indexes (integer format)
targets = torch.tensor([0, 1])  # sample 0 is class 0, sample 1 is class 1

# Compute Cross-Entropy Loss
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, targets)

# Backpropagate gradients
loss.backward()

print("PyTorch Cross Entropy Loss:", loss.item())
print("Gradients on logits:\n", logits.grad)</code></pre>

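This code prints the loss and the gradients. nn.CrossEntropyLoss expects raw, unnormalized logits, so Softmax should not be applied beforehand; doing so would apply Softmax twice and distort both the loss and its gradients. Because the default reduction is 'mean', each row of logits.grad equals $(\hat{y}_i - y_i)/N$. This is the per-class prediction error from the unified gradient, divided by the batch size $N = 2$.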
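The sketch below recomputes the same batch as a self-contained example and checks two claims. The first is that nn.CrossEntropyLoss matches nn.LogSoftmax followed by nn.NLLLoss. The second is that, with mean reduction, the gradient is $(\hat{y} - y)/N$.

<pre><code class="language-python">import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 3.0, 1.5]], requires_grad=True)
targets = torch.tensor([0, 1])
N = logits.shape[0]

ce = nn.CrossEntropyLoss()(logits, targets)
# Two-step version: LogSoftmax over the class dimension, then NLLLoss
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)
print(torch.allclose(ce, nll))  # True

ce.backward()
# With reduction='mean', each gradient row is (softmax - one_hot) / N
expected_grad = (torch.softmax(logits.detach(), dim=1) - F.one_hot(targets, 3).float()) / N
print(torch.allclose(logits.grad, expected_grad))  # True</code></pre>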

Coding Manual Stable CCE

We can verify the unified computation by implementing stable CCE manually using the log-sum-exp formulation in PyTorch:

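<pre><code class="language-python"># Continues the previous example: reuses logits, targets and loss
# Manual calculation per sample: -logits[target] + log(sum(exp(logits)))
# Subtract each row's maximum before exponentiating so exp() cannot overflow;
# adding the maximum back afterwards leaves the log-sum-exp value unchanged
max_logits = logits.max(dim=-1, keepdim=True).values
log_sum_exp = max_logits.squeeze(-1) + torch.log(torch.sum(torch.exp(logits - max_logits), dim=-1))
# (torch.logsumexp(logits, dim=-1) computes the same quantity in one call)

# Extract logits corresponding to the true target classes
gathered_logits = logits[torch.arange(len(targets)), targets]
manual_loss = torch.mean(-gathered_logits + log_sum_exp)

print("Manual Loss matches PyTorch?:", torch.allclose(loss, manual_loss))</code></pre>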

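Numerical stability here comes from the log-sum-exp identity $\log \sum_c e^{z_c} = m + \log \sum_c e^{z_c - m}$ with $m = \max_c z_c$. After the shift, the largest exponential evaluated is $e^0 = 1$, so large logits cannot overflow to infinity. The sum is also at least $1$, so its logarithm cannot underflow to $-\infty$. nn.CrossEntropyLoss obtains its log-probabilities through a log-softmax that uses the same kind of shift, which is why the stable manual result matches it.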
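To see why the shift matters, the sketch below uses hypothetical logits large enough that `exp()` overflows in float32. It compares a naive log-sum-exp against `torch.logsumexp` and `F.cross_entropy`.

<pre><code class="language-python">import torch
import torch.nn.functional as F

# Hypothetical logits: exp(1000) overflows to inf in float32
logits = torch.tensor([[1000.0, 0.0], [0.0, 1000.0]])
targets = torch.tensor([0, 0])

# Naive log-sum-exp exponentiates the raw logits
naive_lse = torch.log(torch.exp(logits).sum(dim=-1))
# Stable log-sum-exp (max-shifted internally)
stable_lse = torch.logsumexp(logits, dim=-1)

gathered = logits[torch.arange(len(targets)), targets]
naive_loss = (-gathered + naive_lse).mean()
stable_loss = (-gathered + stable_lse).mean()

print(naive_loss.item())                        # inf
print(stable_loss.item())                       # 500.0
print(F.cross_entropy(logits, targets).item())  # 500.0</code></pre>

The first sample is classified correctly (loss $0$) and the second is wrong by a logit gap of $1000$ (loss $1000$). The stable mean of $500$ is therefore exact, while the naive version returns infinity.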