
The Concept of the Attention Mechanism (Bahdanau)

The attention mechanism, introduced by Dzmitry Bahdanau and colleagues in 2014, changed sequence-to-sequence modeling by letting the decoder focus on different parts of the input sequence at each output step. It replaces the static bottleneck of a single fixed-length encoding with a context vector that is recomputed at every decoder step.


The Core Mechanism of Bahdanau (Additive) Attention

Bahdanau attention computes alignment scores between the current decoder state and all encoder hidden states to calculate a weighted context vector.

Alignment Scores and Weights

In standard sequence-to-sequence models, the encoder compresses the input sentence into a single fixed-length vector. Bahdanau attention relaxes this constraint. Let \\( h_1, h_2, \\dots, h_{T_x} \\) be the hidden states of the encoder, and \\( s_{i-1} \\) be the hidden state of the decoder at the previous step. The alignment score \\( e_{ij} \\) measures how well the inputs around position \\( j \\) match the output at position \\( i \\). Bahdanau parameterizes this alignment score as a feedforward neural network that is trained jointly with the rest of the system:

\\( e_{ij} = \\mathbf{v}_a^T \\tanh(\\mathbf{W}_a s_{i-1} + \\mathbf{U}_a h_j) \\)

Where \\( \\mathbf{W}_a \\) and \\( \\mathbf{U}_a \\) are learnable weight matrices, and \\( \\mathbf{v}_a \\) is a learnable projection vector. This is called "additive attention" because the projected decoder state and the projected encoder state are summed before the \\( \\tanh \\) nonlinearity, rather than combined through a dot product. The raw alignment scores are then normalized across the input sequence with a softmax function to obtain the attention weights: \\( \\alpha_{ij} = \\frac{\\exp(e_{ij})}{\\sum_{k=1}^{T_x} \\exp(e_{ik})} \\). For each decoder step \\( i \\), these weights form a probability distribution over the encoder states.
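As a small worked example of the score and softmax formulas above, the sketch below evaluates them in plain Python for one decoder state and three encoder states. The two-dimensional sizes and every numeric value are made up for illustration, and the helper names (`matvec`, `additive_score`, `softmax`) are hypothetical.

```python
import math

def matvec(M, x):
    """Multiply matrix M (a list of rows) by vector x."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]

def additive_score(s_prev, h_j, W_a, U_a, v_a):
    """e_ij = v_a^T tanh(W_a s_{i-1} + U_a h_j)."""
    ws = matvec(W_a, s_prev)
    uh = matvec(U_a, h_j)
    hidden = [math.tanh(a + b) for a, b in zip(ws, uh)]
    return sum(v * z for v, z in zip(v_a, hidden))

def softmax(scores):
    """Normalize scores into weights; the max is subtracted for numerical stability."""
    m = max(scores)
    exps = [math.exp(e - m) for e in scores]
    total = sum(exps)
    return [x / total for x in exps]

# Toy values (assumed): 2-dim states, 2-dim attention space, T_x = 3
W_a = [[0.5, -0.2], [0.1, 0.3]]
U_a = [[0.4, 0.0], [-0.3, 0.6]]
v_a = [1.0, -1.0]
s_prev = [0.2, -0.1]
encoder_states = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]

scores = [additive_score(s_prev, h, W_a, U_a, v_a) for h in encoder_states]
alphas = softmax(scores)
print("scores:", [round(e, 4) for e in scores])
print("weights:", [round(a, 4) for a in alphas])
```

The encoder state with the largest score receives the largest weight, and the weights always sum to 1 regardless of the scale of the raw scores.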

Context Vector Computation

Once the attention weights \\( \\alpha_{ij} \\) are computed, they are used to calculate the context vector \\( c_i \\) for decoder step \\( i \\). The context vector is a weighted sum of the encoder hidden states:

\\( c_i = \\sum_{j=1}^{T_x} \\alpha_{ij} h_j \\)

The context vector summarizes the parts of the input sequence that are relevant to the current decoder step. It is concatenated with the current decoder input (typically the embedding of the previously generated token) and used to compute the new decoder state \\( s_i \\) and the output logits. Because a new context vector is computed at each step, the decoder can draw on different parts of the input sequence as generation proceeds.
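One way to picture how the context vector enters the decoder is sketched below, using `nn.GRUCell` as the recurrent unit. The class name `DecoderStep`, the layer sizes, and the choice to predict logits from the new state alone are illustrative assumptions; this is a simplification and not the exact gated formulation of the original paper.

```python
import torch
import torch.nn as nn

class DecoderStep(nn.Module):
    """Hypothetical single decoder step that consumes a precomputed context vector."""

    def __init__(self, vocab_size, embed_dim, encoder_dim, decoder_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # Recurrent input is [embedded previous token ; context vector]
        self.rnn = nn.GRUCell(embed_dim + encoder_dim, decoder_dim)
        self.out = nn.Linear(decoder_dim, vocab_size)

    def forward(self, prev_token, prev_state, context):
        # prev_token: [batch_size] (long), prev_state: [batch_size, decoder_dim]
        # context:    [batch_size, encoder_dim]
        rnn_input = torch.cat([self.embed(prev_token), context], dim=-1)
        new_state = self.rnn(rnn_input, prev_state)  # s_i
        logits = self.out(new_state)                 # [batch_size, vocab_size]
        return new_state, logits

step = DecoderStep(vocab_size=100, embed_dim=16, encoder_dim=64, decoder_dim=64)
prev_token = torch.randint(0, 100, (4,))
prev_state = torch.zeros(4, 64)
context = torch.randn(4, 64)  # stand-in for c_i produced by an attention module
new_state, logits = step(prev_token, prev_state, context)
print(new_state.shape, logits.shape)
```

In a full decoder loop, the returned `new_state` would serve as \\( s_{i-1} \\) when computing the attention weights for the next step.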

PyTorch Implementation of Additive Attention

A custom PyTorch module demonstrates the linear projections and tanh activations involved in calculating additive alignment scores.

Bahdanau Attention Module

The following PyTorch code implements the additive attention scoring and context vector generation steps:

<pre><code class="language-python">import torch
import torch.nn as nn


class BahdanauAttention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.W = nn.Linear(decoder_dim, attention_dim, bias=False)
        self.U = nn.Linear(encoder_dim, attention_dim, bias=False)
        self.v = nn.Linear(attention_dim, 1, bias=False)

    def forward(self, prev_decoder_state, encoder_states):
        """
        Args:
            prev_decoder_state: Tensor of shape [batch_size, decoder_dim]
            encoder_states: Tensor of shape [batch_size, seq_len, encoder_dim]
        """
        # Project states into attention space
        # W_s shape: [batch_size, attention_dim] -> unsqueezed to [batch_size, 1, attention_dim]
        W_s = self.W(prev_decoder_state).unsqueeze(1)
        # U_h shape: [batch_size, seq_len, attention_dim]
        U_h = self.U(encoder_states)

        # Additive combination: tanh(W_s + U_h)
        # combined shape: [batch_size, seq_len, attention_dim]
        combined = torch.tanh(W_s + U_h)

        # Calculate alignment scores
        # scores shape: [batch_size, seq_len, 1] -> squeezed to [batch_size, seq_len]
        scores = self.v(combined).squeeze(2)

        # Normalize via softmax to get attention weights
        weights = torch.softmax(scores, dim=-1)  # [batch_size, seq_len]

        # Compute context vector as weighted sum of encoder states
        # weights.unsqueeze(1) shape: [batch_size, 1, seq_len]
        # encoder_states shape: [batch_size, seq_len, encoder_dim]
        # context shape: [batch_size, 1, encoder_dim] -> squeezed to [batch_size, encoder_dim]
        context = torch.bmm(weights.unsqueeze(1), encoder_states).squeeze(1)
        return context, weights


# Example run
attn = BahdanauAttention(encoder_dim=64, decoder_dim=64, attention_dim=32)
d_state = torch.randn(4, 64)      # Batch size of 4
e_states = torch.randn(4, 10, 64)  # Sequence length of 10
context, weights = attn(d_state, e_states)
print("Context vector shape:", context.shape)    # Should be [4, 64]
print("Attention weights shape:", weights.shape)  # Should be [4, 10]
</code></pre>

Tensor Shape Verification

In the attention forward pass, tracking tensor shapes is crucial. The decoder state \\( s_{i-1} \\) of shape [batch_size, decoder_dim] is projected and unsqueezed to shape [batch_size, 1, attention_dim] to enable broadcasting. The encoder states of shape [batch_size, seq_len, encoder_dim] are projected to [batch_size, seq_len, attention_dim].

Adding these tensors triggers broadcasting along the sequence dimension, resulting in a joint tensor of shape [batch_size, seq_len, attention_dim]. The final projection onto the vector \\( \\mathbf{v}_a \\) yields a score tensor of shape [batch_size, seq_len, 1], which is squeezed to [batch_size, seq_len]. Softmax is applied along the sequence dimension, so the weights for each sample sum to 1. The context vector is then computed with batch matrix multiplication (torch.bmm) between the weights, unsqueezed to [batch_size, 1, seq_len], and the encoder states, which yields [batch_size, 1, encoder_dim] and is squeezed to [batch_size, encoder_dim].
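The shape bookkeeping described above can be checked directly with assertions. The sketch below repeats the same sequence of operations on raw tensors, with random stand-ins for the learned parameters; the specific sizes, including a decoder dimension that differs from the encoder dimension, are assumptions chosen only to make each axis distinguishable.

```python
import torch

batch_size, seq_len = 4, 10
encoder_dim, decoder_dim, attention_dim = 64, 48, 32

# Random stand-ins for the learned parameters and the inputs
W_a = torch.randn(decoder_dim, attention_dim)
U_a = torch.randn(encoder_dim, attention_dim)
v_a = torch.randn(attention_dim)
s_prev = torch.randn(batch_size, decoder_dim)
h = torch.randn(batch_size, seq_len, encoder_dim)

W_s = (s_prev @ W_a).unsqueeze(1)        # [4, 1, 32]
U_h = h @ U_a                            # [4, 10, 32]
combined = torch.tanh(W_s + U_h)         # broadcast over seq_len -> [4, 10, 32]
scores = combined @ v_a                  # [4, 10]
weights = torch.softmax(scores, dim=-1)  # [4, 10]
context = torch.bmm(weights.unsqueeze(1), h).squeeze(1)  # [4, 64]

assert W_s.shape == (batch_size, 1, attention_dim)
assert U_h.shape == (batch_size, seq_len, attention_dim)
assert combined.shape == (batch_size, seq_len, attention_dim)
assert scores.shape == (batch_size, seq_len)
assert context.shape == (batch_size, encoder_dim)
# Each row of weights is a probability distribution over source positions
assert torch.allclose(weights.sum(dim=-1), torch.ones(batch_size), atol=1e-5)
print("all shape checks passed")
```

Note that the context vector inherits the encoder dimension, not the decoder or attention dimension, because it is a weighted average of encoder states.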

Cognitive and Architectural Interpretation

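Attention is loosely analogous to selective human focus. Architecturally, it removes the fixed-length bottleneck between encoder and decoder, and its weights offer a view into which inputs influenced each output.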

Soft vs. Hard Attention

Attention mechanisms fall into two main categories: soft attention and hard attention. Bahdanau attention is a "soft" mechanism because it uses a continuous, differentiable softmax distribution over the input states. Because every step is differentiable, the entire model can be trained end-to-end using standard backpropagation and gradient descent. The context vector is a deterministic expectation across all inputs.

In contrast, "hard" attention selects a single input state to focus on at each step, treating the attention weights as a categorical distribution from which one index is sampled. While hard attention can reduce computational overhead during inference, the sampling operation is non-differentiable. Training hard attention models therefore requires reinforcement learning techniques such as policy gradient methods (REINFORCE), whose gradient estimates have high variance and are typically less stable to train than soft attention with backpropagation.
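The difference between the two variants can be illustrated with a few lines of plain Python. In this sketch the weights and states are made-up toy values and the function names are hypothetical; it only contrasts the deterministic weighted sum with a sampled selection and does not show how either variant is trained.

```python
import random

def soft_context(weights, states):
    """Soft attention: expectation of the states under the attention weights."""
    dim = len(states[0])
    return [sum(w * h[d] for w, h in zip(weights, states)) for d in range(dim)]

def hard_context(weights, states, rng):
    """Hard attention: sample one index and return only that state."""
    j = rng.choices(range(len(states)), weights=weights, k=1)[0]
    return j, states[j]

weights = [0.7, 0.2, 0.1]                      # toy attention weights
states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy encoder states

soft = soft_context(weights, states)
print("soft context:", soft)

rng = random.Random(0)
n = 10000
mean = [0.0, 0.0]
for _ in range(n):
    _, h = hard_context(weights, states, rng)
    mean = [m + x / n for m, x in zip(mean, h)]
print("mean of hard samples:", [round(m, 3) for m in mean])
```

Averaged over many samples, the hard selections approach the soft context vector, which is the sense in which soft attention computes an expectation over the inputs.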

Interpretability through Attention Maps

Deep learning models are often criticized as "black boxes." Attention mechanisms provide a window into the model's inner workings. By logging the attention weight matrix \\( \\alpha \\) during inference, we can generate a 2D alignment map (or heat map) showing which input words the model focused on when generating each output word.

In machine translation, this map visualizes soft word alignments between the source and target languages. If the model translates "white house" to "maison blanche," a well-trained model's attention map would typically show that it focused on "house" while generating "maison," and then on "white" while generating "blanche," reflecting the reversed adjective order in French. This view is valuable for debugging and for explaining model decisions.
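A minimal sketch of how logged weights could be turned into a readable alignment map is shown below. The weight values are hypothetical numbers chosen to match the "white house" example, not output from a trained model, and the function names are illustrative.

```python
def render_attention_map(source_tokens, target_tokens, alpha):
    """Return a text table: one row per target token, one column per source token."""
    width = max(len(t) for t in source_tokens + target_tokens) + 2
    header = " " * width + "".join(t.rjust(width) for t in source_tokens)
    rows = [header]
    for token, weights in zip(target_tokens, alpha):
        cells = "".join(f"{w:.2f}".rjust(width) for w in weights)
        rows.append(token.rjust(width) + cells)
    return "\n".join(rows)

def strongest_alignment(source_tokens, target_tokens, alpha):
    """For each target token, pick the source token with the largest weight."""
    return {
        tgt: source_tokens[max(range(len(weights)), key=weights.__getitem__)]
        for tgt, weights in zip(target_tokens, alpha)
    }

source = ["white", "house"]
target = ["maison", "blanche"]
# Hypothetical weights, one row per output step; each row sums to 1
alpha = [[0.15, 0.85],
         [0.90, 0.10]]

print(render_attention_map(source, target, alpha))
print(strongest_alignment(source, target, alpha))
```

With real models, the rows of `alpha` would be the `weights` tensors collected at each decoder step, and a plotting library could replace the text table with a heat map.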