Gated Recurrent Units (GRU) Comparison

The Gated Recurrent Unit (GRU), proposed by Cho et al. in 2014, is a popular variant of the LSTM. It merges the cell state into the hidden state and uses two gates instead of three. As a result, GRUs have a simpler structure and fewer parameters, and on many sequence tasks they perform comparably to LSTMs.


GRU Gates and Math

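The GRU does not maintain a separate cell state. It tracks context in a single hidden state h_t, which is updated through two gates: the update gate and the reset gate. In the equations below, [h_{t-1}, x_t] denotes the concatenation of the previous hidden state and the current input, \sigma is the logistic sigmoid, and \odot is element-wise multiplication.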

Update and Reset Equations

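1. Update Gate: z_t = \sigma(W_z [h_{t-1}, x_t] + b_z) decides how much of the hidden state is overwritten by the candidate state. Its complement, 1 - z_t, is the fraction of the past state that is kept.
2. Reset Gate: r_t = \sigma(W_r [h_{t-1}, x_t] + b_r) decides how much of the past state feeds into the candidate. Values near 0 let the unit ignore its history when computing \tilde{h}_t.
3. Candidate State: \tilde{h}_t = \tanh(W_h [r_t \odot h_{t-1}, x_t] + b_h).
4. Hidden State: h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t.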
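To make the four equations concrete, here is a minimal NumPy sketch of a single GRU step and a loop over one sequence. It assumes one weight matrix per gate acting on the concatenation [h_{t-1}, x_t], exactly as the equations above are written. The helper names (gru_step, gru_sequence, init_params) and the uniform initialization are illustrative choices, not part of any library.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(x_t, h_prev, params):
    """One GRU step following the update/reset/candidate/hidden equations.

    W_z, W_r, W_h have shape (hidden, hidden + input); b_z, b_r, b_h have shape (hidden,).
    """
    W_z, W_r, W_h = params["W_z"], params["W_r"], params["W_h"]
    b_z, b_r, b_h = params["b_z"], params["b_r"], params["b_h"]

    hx = np.concatenate([h_prev, x_t])           # [h_{t-1}, x_t]
    z_t = sigmoid(W_z @ hx + b_z)                # update gate
    r_t = sigmoid(W_r @ hx + b_r)                # reset gate
    rhx = np.concatenate([r_t * h_prev, x_t])    # [r_t * h_{t-1}, x_t]
    h_tilde = np.tanh(W_h @ rhx + b_h)           # candidate state
    return (1.0 - z_t) * h_prev + z_t * h_tilde  # blend old state and candidate

def gru_sequence(xs, h0, params):
    """Apply gru_step over a (seq_len, input) array; return all states and the last one."""
    h = h0
    states = []
    for x_t in xs:
        h = gru_step(x_t, h, params)
        states.append(h)
    return np.stack(states), h

def init_params(input_size, hidden_size, rng):
    """Hypothetical small random initialization, for illustration only."""
    shape = (hidden_size, hidden_size + input_size)
    scale = 1.0 / np.sqrt(hidden_size)
    return {
        "W_z": rng.uniform(-scale, scale, shape),
        "W_r": rng.uniform(-scale, scale, shape),
        "W_h": rng.uniform(-scale, scale, shape),
        "b_z": np.zeros(hidden_size),
        "b_r": np.zeros(hidden_size),
        "b_h": np.zeros(hidden_size),
    }

rng = np.random.default_rng(0)
params = init_params(input_size=10, hidden_size=20, rng=rng)
xs = rng.standard_normal((5, 10))  # 5 time steps, 10 features each
outputs, h_last = gru_sequence(xs, np.zeros(20), params)
print(outputs.shape, h_last.shape)  # (5, 20) (20,)
```

Stacking the per-step states gives the analogue of PyTorch's output tensor for a single sequence, and the last state plays the role of h_n for a one-layer, unidirectional GRU.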

GRU in PyTorch

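In PyTorch, a GRU is constructed the same way as an LSTM. The difference is in the return value: its forward pass returns a single final hidden-state tensor h_n, whereas nn.LSTM returns the tuple (h_n, c_n).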

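<pre><code class="language-python">import torch
import torch.nn as nn

# Input size = 10, hidden size = 20; batch_first=True expects (batch, seq_len, features)
gru = nn.GRU(input_size=10, hidden_size=20, batch_first=True)

x = torch.randn(3, 5, 10)  # 3 sequences, 5 time steps, 10 features each

# Forward pass returns the output sequence and the final hidden state
output, h_n = gru(x)
print(output.shape)  # torch.Size([3, 5, 20])
print(h_n.shape)     # torch.Size([1, 3, 20]); h_n is (num_layers * num_directions, batch, hidden) even with batch_first=True
</code></pre>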
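One detail is worth checking when comparing nn.GRU with the equations above. The nn.GRU documentation writes the final update as h_t = (1 - z_t) ⊙ n_t + z_t ⊙ h_{t-1}, where n_t is the candidate. In that form, z_t weights the previous state rather than the candidate. The reset gate also multiplies the projected hidden state (W_hn h_{t-1} + b_hn) rather than h_{t-1} itself. The NumPy sketch below is a transcription of those documented formulas. The dictionary keys are hypothetical names that mirror the documented symbols; they are not PyTorch attribute names.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def torch_style_gru_step(x_t, h_prev, p):
    """One GRU step written after the formulas in the nn.GRU documentation.

    p holds hypothetical keys W_i*, W_h*, b_i*, b_h* for * in {r, z, n}:
    W_i* is (hidden, input), W_h* is (hidden, hidden), and the biases are (hidden,).
    """
    r = sigmoid(p["W_ir"] @ x_t + p["b_ir"] + p["W_hr"] @ h_prev + p["b_hr"])
    z = sigmoid(p["W_iz"] @ x_t + p["b_iz"] + p["W_hz"] @ h_prev + p["b_hz"])
    # The reset gate scales the projected hidden state, not h_prev itself
    n = np.tanh(p["W_in"] @ x_t + p["b_in"] + r * (p["W_hn"] @ h_prev + p["b_hn"]))
    # Here z weights the PREVIOUS state; in the textbook form it weights the candidate
    return (1.0 - z) * n + z * h_prev

rng = np.random.default_rng(0)
hidden, inp = 4, 3
p = {}
for g in "rzn":
    p[f"W_i{g}"] = rng.standard_normal((hidden, inp))
    p[f"W_h{g}"] = rng.standard_normal((hidden, hidden))
    p[f"b_i{g}"] = np.zeros(hidden)
    p[f"b_h{g}"] = np.zeros(hidden)

# Saturate the update gate (z close to 1): the previous state is carried over unchanged
p["b_iz"] = np.full(hidden, 50.0)
h_prev = rng.uniform(-1, 1, hidden)
h_next = torch_style_gru_step(rng.standard_normal(inp), h_prev, p)
print(np.allclose(h_next, h_prev))  # True
```

The two parameterizations express the same family of updates: replacing z_t with 1 - z_t converts one interpolation into the other. Reading z_t from a trained nn.GRU with the textbook meaning would invert its interpretation.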

LSTM vs. GRU Trade-offs

Selecting between LSTMs and GRUs involves balancing model capacity, data size, and training speed.

Structural Simplicity

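GRUs have fewer parameters because they compute three weight blocks (update gate, reset gate, candidate), whereas an LSTM computes four (input, forget, and output gates plus the candidate). A GRU layer therefore has about three quarters of the parameters of an LSTM layer of the same size. The smaller model can reduce the risk of overfitting on small datasets, uses less GPU memory, and typically makes each training step somewhat faster.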
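Counting parameters makes the size difference concrete. This sketch assumes PyTorch's default layout (bias=True), where each block has an input-to-hidden matrix, a hidden-to-hidden matrix, and two bias vectors. It uses the sizes from the nn.GRU example; rnn_layer_params is a hypothetical helper, not a library function.

```python
def rnn_layer_params(input_size, hidden_size, num_blocks):
    """Parameter count of one recurrent layer with PyTorch-style weights.

    Each block (a gate or the candidate) has a (hidden, input) matrix,
    a (hidden, hidden) matrix, and two bias vectors of length hidden.
    """
    per_block = hidden_size * input_size + hidden_size * hidden_size + 2 * hidden_size
    return num_blocks * per_block

gru_params = rnn_layer_params(10, 20, num_blocks=3)   # update, reset, candidate
lstm_params = rnn_layer_params(10, 20, num_blocks=4)  # input, forget, output, candidate
print(gru_params, lstm_params, gru_params / lstm_params)  # 1920 2560 0.75
```

For nn.GRU(input_size=10, hidden_size=20), summing p.numel() over gru.parameters() gives the same 1920.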

Empirical Comparison

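GRUs are cheaper per step and often work well on small-to-medium datasets. LSTMs have a separate cell state and an additional gate, which gives them somewhat more capacity. They may outperform GRUs on large datasets with complex, long-range dependencies. Neither architecture wins consistently across tasks, so when the choice matters, it is worth training both and comparing them on a validation set.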