Message Passing in GNNs
Message passing is the unifying framework for Graph Neural Networks. In each layer, nodes send feature messages to their neighbors, aggregate incoming messages, and update their internal hidden states to capture local structural context.
The Message Passing Framework
Message passing updates node representations through three sequential phases: message generation, aggregation, and state update.
Message, Aggregation, and Update steps
The message passing framework (Gilmer et al., 2017) generalizes spatial graph convolutions into three steps. Let \\( \\mathbf{h}_v^{(k)} \\) be the hidden feature vector of node \\( v \\) at layer \\( k \\). First, each neighbor node \\( u \\) in the neighborhood \\( \\mathcal{N}(v) \\) generates a message using a differentiable message function \\( M_k \\): \\( \\mathbf{m}_{uv}^{(k)} = M_k(\\mathbf{h}_u^{(k-1)}, \\mathbf{h}_v^{(k-1)}, \\mathbf{e}_{uv}) \\), where \\( \\mathbf{e}_{uv} \\) represents edge features.
Second, node \\( v \\) aggregates all incoming messages using an aggregation function \\( \\bigoplus \\): \\( \\mathbf{m}_v^{(k)} = \\bigoplus_{u \\in \\mathcal{N}(v)} \\mathbf{m}_{uv}^{(k)} \\). Third, the node updates its hidden state by combining its previous state and the aggregated message using a differentiable update function \\( U_k \\):
\\( \\mathbf{h}_v^{(k)} = U_k(\\mathbf{h}_v^{(k-1)}, \\mathbf{m}_v^{(k)}) \\)
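The three steps can be sketched in plain Python as a single function that takes the message function, the aggregator, and the update function as arguments. This is a minimal illustration, not a library API. The helper name `message_passing_step`, the adjacency-list format, and the scalar node features are assumptions, and edge features are omitted for brevity.

<pre><code class="language-python">from typing import Callable, Dict, List


def message_passing_step(
    h: Dict[int, float],
    neighbors: Dict[int, List[int]],
    M: Callable[[float, float], float],
    aggregate: Callable[[List[float]], float],
    U: Callable[[float, float], float],
) -> Dict[int, float]:
    """One layer of message passing on scalar features (hypothetical helper)."""
    new_h = {}
    for v in h:
        # Step 1: each neighbor u of v produces a message m_uv = M(h_u, h_v)
        messages = [M(h[u], h[v]) for u in neighbors[v]]
        # Step 2: combine the messages with a permutation-invariant aggregator
        m_v = aggregate(messages)
        # Step 3: update v's state from its previous state and the aggregated message
        new_h[v] = U(h[v], m_v)
    return new_h


# Path graph 0 - 1 - 2 with one scalar feature per node
neighbors = {0: [1], 1: [0, 2], 2: [1]}
h0 = {0: 1.0, 1: 2.0, 2: 3.0}
h1 = message_passing_step(
    h0,
    neighbors,
    M=lambda h_u, h_v: h_u - h_v,        # message: feature difference
    aggregate=sum,                       # Sum aggregator
    U=lambda h_v, m_v: h_v + 0.5 * m_v,  # update: move halfway toward the neighbors
)
print(h1)  # {0: 1.5, 1: 2.0, 2: 2.5}
</code></pre>

Because `M`, `aggregate`, and `U` are ordinary arguments, swapping in a different aggregator or update function changes the layer without touching the loop.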
By repeating this process for \\( K \\) layers, each node gathers information from nodes up to \\( K \\) hops away, learning representations that integrate local neighborhood structures.
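The K-hop claim can be checked symbolically by tracking which nodes' input features can influence each node. The sketch below uses a hypothetical `receptive_field` helper and assumes an undirected adjacency list. Each "layer" merges a node's own set of node IDs with its neighbors' sets, mirroring an update that combines the node's previous state with its aggregated messages.

<pre><code class="language-python">from typing import Dict, List, Set


def receptive_field(neighbors: Dict[int, List[int]], K: int) -> Dict[int, Set[int]]:
    """For every node, the set of nodes whose features can reach it after K layers (hypothetical helper)."""
    # Layer 0: each node only depends on its own input features.
    field = {v: {v} for v in neighbors}
    for _ in range(K):
        # One layer: a node's state depends on its own previous state and its neighbors' states.
        field = {v: field[v].union(*(field[u] for u in neighbors[v])) for v in neighbors}
    return field


# Path graph 0 - 1 - 2 - 3 - 4
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
for K in range(4):
    print(K, sorted(receptive_field(path, K)[0]))
# 0 [0]
# 1 [0, 1]
# 2 [0, 1, 2]
# 3 [0, 1, 2, 3]
</code></pre>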
Permutation Invariance in Aggregation
The aggregation function \\( \\bigoplus \\) is a critical component of the message passing framework. Because a node's neighbors have no inherent ordering, the aggregation function must be permutation invariant: for any multiset of incoming messages, the output must be identical regardless of the order in which the messages are processed. The most common permutation-invariant aggregators are Sum, Mean, and Max.
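Permutation invariance can be tested directly by evaluating an aggregator on every ordering of the same messages. The sketch below uses scalar messages as a simplification of message vectors. It contrasts the three standard aggregators with a deliberately order-dependent one, a hypothetical position-weighted sum, and the helper `is_permutation_invariant` is likewise illustrative.

<pre><code class="language-python">import itertools
from typing import Callable, List


def is_permutation_invariant(agg: Callable[[List[float]], float], msgs: List[float]) -> bool:
    """True if agg returns the same value for every ordering of msgs (hypothetical helper)."""
    outputs = {agg(list(p)) for p in itertools.permutations(msgs)}
    return len(outputs) == 1


aggregators = {
    "sum": sum,
    "mean": lambda m: sum(m) / len(m),
    "max": max,
    # Hypothetical order-dependent aggregator: weights each message by its list position.
    "position_weighted": lambda m: sum(i * x for i, x in enumerate(m)),
}
msgs = [1.0, 4.0, 2.0]  # small integer values keep floating-point sums exact in any order
for name, agg in aggregators.items():
    print(name, is_permutation_invariant(agg, msgs))
# sum True
# mean True
# max True
# position_weighted False
</code></pre>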
The choice of aggregator involves trade-offs. The Sum aggregator retains information about neighborhood size: every message adds to the total, so the result scales with the number of neighbors (the node's degree), which makes Sum the most expressive of the three at distinguishing multisets of messages (Xu et al., 2019). The Mean aggregator reports the average characteristics of the neighborhood and normalizes away degree, which is useful when neighborhood sizes vary widely. The Max aggregator keeps the most salient activation in each feature dimension (an element-wise maximum), regardless of neighborhood size.
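A toy example makes these trade-offs concrete. The neighbor messages below are one-dimensional values chosen only for illustration, and `aggregate_all` is a hypothetical helper. The first pair differs only in neighborhood size, so only Sum separates it. The second pair shares the same maximum, so Max cannot separate it.

<pre><code class="language-python">from typing import Dict, List


def aggregate_all(msgs: List[float]) -> Dict[str, float]:
    """Apply Sum, Mean, and Max to a list of scalar neighbor messages (hypothetical helper)."""
    return {"sum": sum(msgs), "mean": sum(msgs) / len(msgs), "max": max(msgs)}


# Same message value, different neighborhood sizes: only Sum tells them apart.
print(aggregate_all([1.0]))            # {'sum': 1.0, 'mean': 1.0, 'max': 1.0}
print(aggregate_all([1.0, 1.0, 1.0]))  # {'sum': 3.0, 'mean': 1.0, 'max': 1.0}
# Same maximum, different distributions: Max cannot tell them apart, Mean and Sum can.
print(aggregate_all([0.0, 2.0]))       # {'sum': 2.0, 'mean': 1.0, 'max': 2.0}
print(aggregate_all([1.0, 2.0]))       # {'sum': 3.0, 'mean': 1.5, 'max': 2.0}
</code></pre>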
Custom PyTorch Message Passing Layer
A custom PyTorch implementation shows how to perform neighborhood aggregation using tensor scatter operations.
Implementation of Message Passing
This PyTorch implementation demonstrates a message passing layer that computes messages and aggregates them using scatter operations:
<pre><code class="language-python">import torch
import torch.nn as nn


class CustomMessagePassing(nn.Module):
    def __init__(self, node_dim, message_dim):
        super().__init__()
        # Message MLP: input is concatenated source and target node features
        self.msg_mlp = nn.Sequential(
            nn.Linear(node_dim * 2, message_dim),
            nn.ReLU()
        )
        # Update function: a single linear layer applied to the concatenated
        # old node feature and aggregated message
        self.update_mlp = nn.Linear(node_dim + message_dim, node_dim)

    def forward(self, x, edge_index):
        """
        Args:
            x: Node features [num_nodes, node_dim]
            edge_index: Edge list in COO format [2, num_edges] (row 0: source, row 1: target)
        """
        num_nodes = x.size(0)
        src_idx, trg_idx = edge_index[0], edge_index[1]

        # Step 1: Generate messages from source to target nodes
        # Gather node features for all edge endpoints
        x_src = x[src_idx]  # [num_edges, node_dim]
        x_trg = x[trg_idx]  # [num_edges, node_dim]
        # Concatenate features and pass through the message MLP
        messages = self.msg_mlp(torch.cat((x_src, x_trg), dim=-1))  # [num_edges, message_dim]

        # Step 2: Aggregate messages using scatter_add_ (sum aggregator)
        # Zero tensor for aggregated messages, matching the messages' dtype and device
        agg_messages = torch.zeros(num_nodes, messages.size(1), dtype=messages.dtype, device=x.device)
        # Scatter-add each message into the row of its target node
        trg_idx_expanded = trg_idx.unsqueeze(-1).expand(-1, messages.size(1))
        agg_messages.scatter_add_(0, trg_idx_expanded, messages)

        # Step 3: Update node states
        # Concatenate original features and aggregated messages
        update_in = torch.cat((x, agg_messages), dim=-1)  # [num_nodes, node_dim + message_dim]
        new_x = self.update_mlp(update_in)  # [num_nodes, node_dim]
        return new_x


# Example run
layer = CustomMessagePassing(node_dim=8, message_dim=16)
x_nodes = torch.randn(4, 8)  # 4 nodes, 8 features
# Edges: 0->1, 1->2, 2->3, 3->0
edges = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
out = layer(x_nodes, edges)
print("Updated nodes shape:", out.shape)  # torch.Size([4, 8])
</code></pre>
Tensor operations
In this implementation, the graph connectivity is represented as an edge_index tensor of shape [2, num_edges] (a coordinate-format edge list), which is the standard format in graph libraries such as PyTorch Geometric. We extract the source (src_idx) and target (trg_idx) node indices and use them to index x, gathering the features of all edge endpoints in parallel. This yields two tensors, x_src and x_trg, each of shape [num_edges, node_dim].
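If a graph starts out as a dense adjacency matrix, one way to obtain an edge_index in this [2, num_edges] layout is `nonzero()` followed by a transpose, after which plain integer indexing performs the gather. This is a minimal sketch assuming a small directed 0/1 matrix `adj` in which `adj[i, j] = 1` means an edge from node i to node j; it is not a PyTorch Geometric utility.

<pre><code class="language-python">import torch

# Directed 4-cycle 0->1->2->3->0 as a dense adjacency matrix (adj[src, trg] = 1).
adj = torch.tensor([
    [0, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1],
    [1, 0, 0, 0],
])
# nonzero() returns [num_edges, 2] (row, col) pairs in row-major order; transpose to [2, num_edges].
edge_index = adj.nonzero().t()
print(edge_index.tolist())  # [[0, 1, 2, 3], [1, 2, 3, 0]]

x = torch.arange(8, dtype=torch.float32).view(4, 2)  # 4 nodes, 2 features each
src_idx, trg_idx = edge_index[0], edge_index[1]
x_src = x[src_idx]  # row e holds the features of the source node of edge e
x_trg = x[trg_idx]  # row e holds the features of the target node of edge e
print(x_src.shape, x_trg.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
</code></pre>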
We concatenate these endpoint features and pass them through the message MLP in a single batched operation. To aggregate the edge-level messages back to the node level, we use scatter_add_, which sums each message vector into the row given by its target node index. The cost of this step scales linearly with the number of edges, and no adjacency matrix, dense or sparse, has to be built or multiplied.
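Two quick checks on this scatter-based aggregation are sketched below, assuming messages that depend only on the source node (here the messages are simply the source features). First, summing with scatter_add_ gives the same result as the dense product `adj.t() @ x`, where `adj[src, trg] = 1`. Second, dividing the sum by in-degree counts, also computed with scatter_add_, turns it into a Mean aggregator. The helper `scatter_sum` is hypothetical.

<pre><code class="language-python">import torch


def scatter_sum(messages: torch.Tensor, trg_idx: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Sum edge-level messages [num_edges, d] into node rows [num_nodes, d] (hypothetical helper)."""
    out = torch.zeros(num_nodes, messages.size(1), dtype=messages.dtype, device=messages.device)
    index = trg_idx.unsqueeze(-1).expand(-1, messages.size(1))
    return out.scatter_add_(0, index, messages)


# Small directed graph: 0->2, 1->2, 2->0, 3->1 (node 3 has no incoming edges)
edge_index = torch.tensor([[0, 1, 2, 3], [2, 2, 0, 1]])
num_nodes = 4
x = torch.tensor([[1.0, 0.0], [0.0, 2.0], [3.0, 3.0], [4.0, 1.0]])
src_idx, trg_idx = edge_index[0], edge_index[1]

# Sum aggregation of source features via scatter_add_
summed = scatter_sum(x[src_idx], trg_idx, num_nodes)

# Same result from a dense adjacency matrix: node v receives the sum over u of adj[u, v] * x[u]
adj = torch.zeros(num_nodes, num_nodes)
adj[src_idx, trg_idx] = 1.0
assert torch.allclose(summed, adj.t() @ x)

# Mean aggregation: divide by in-degree; clamping avoids division by zero for node 3
in_degree = scatter_sum(torch.ones(edge_index.size(1), 1), trg_idx, num_nodes)  # [num_nodes, 1]
mean = summed / in_degree.clamp(min=1)
print(mean.tolist())  # [[3.0, 3.0], [4.0, 1.0], [0.5, 1.0], [0.0, 0.0]]
</code></pre>

For edge-dependent messages, such as those produced by the message MLP above, the scatter form still works, whereas the dense product no longer applies directly.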
Expressive Power and Limits
The expressive power of standard message passing GNNs is bounded by the 1-WL isomorphism test, and model depth must be balanced against over-smoothing and over-squashing.
The Weisfeiler-Lehman (1-WL) Graph Isomorphism Test
A key theoretical result in graph machine learning (Morris et al., 2019; Xu et al., 2019) is that the expressive power of any standard message passing GNN is bounded by the 1-Weisfeiler-Lehman (1-WL) graph isomorphism test. The 1-WL test is a classic heuristic for deciding whether two graphs are isomorphic: if it produces different color distributions, the graphs are certainly non-isomorphic, but identical distributions do not guarantee isomorphism. It works by color refinement: each node is assigned an initial color, and at each step, every node updates its color by hashing its current color together with the multiset of its neighbors' colors.
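The color refinement loop can be sketched in a few lines of Python. This is a minimal illustration with a hypothetical `wl_test` function. It assumes undirected graphs given as adjacency lists and no node features, and it replaces the hash function with a dictionary shared by both graphs, so that colors are directly comparable across them.

<pre><code class="language-python">from collections import Counter
from typing import Dict, List, Tuple

Graph = Dict[int, List[int]]


def wl_test(g1: Graph, g2: Graph) -> bool:
    """False if 1-WL proves g1 and g2 non-isomorphic, True if it cannot tell them apart (hypothetical)."""
    # Without node features, every node starts with the same color 0.
    colors = [{v: 0 for v in g1}, {v: 0 for v in g2}]
    palette: Dict[Tuple, int] = {}  # shared across both graphs, standing in for the hash function
    for _ in range(len(g1) + len(g2)):  # refinement stabilizes within this many rounds
        new_colors = []
        for g, c in zip((g1, g2), colors):
            refined = {}
            for v in g:
                # Signature: own color plus the sorted multiset of neighbor colors.
                sig = (c[v], tuple(sorted(c[u] for u in g[v])))
                refined[v] = palette.setdefault(sig, len(palette))
            new_colors.append(refined)
        colors = new_colors
        if Counter(colors[0].values()) != Counter(colors[1].values()):
            return False  # different color histograms: certainly non-isomorphic
    return True  # histograms never diverged: 1-WL cannot distinguish the graphs


path = {0: [1], 1: [0, 2], 2: [1]}
relabeled_path = {0: [2], 1: [2], 2: [0, 1]}  # same path with the center renamed to node 2
triangle = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
print(wl_test(path, triangle))        # False
print(wl_test(path, relabeled_path))  # True
</code></pre>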
If two graphs are non-isomorphic but the 1-WL test fails to distinguish them (assigning them the same color distributions), no standard message passing GNN can distinguish them either. This limitation has motivated research into more expressive architectures, such as subgraph GNNs or GNNs that incorporate node identifiers or structural features such as cycle counts.
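A classic example of this limitation is a 6-cycle versus two disjoint triangles. Both graphs have six nodes and every node has degree 2, so 1-WL assigns every node the same color in both graphs. The sketch below assumes constant initial node features, random weights, Sum aggregation, and a Sum readout, and shows that such a network then yields identical graph embeddings. The functions `sum_gnn_embedding` and `undirected` are hypothetical. The aggregation uses `Tensor.index_add_`, which computes the same sum as the scatter_add_ call in the layer above.

<pre><code class="language-python">import torch


def sum_gnn_embedding(edge_index: torch.Tensor, num_nodes: int, weights: list) -> torch.Tensor:
    """Graph embedding from a simple Sum-aggregation MPNN with a Sum readout (hypothetical)."""
    h = torch.ones(num_nodes, weights[0].size(0))  # identical initial features for every node
    src, trg = edge_index
    for W in weights:
        agg = torch.zeros_like(h).index_add_(0, trg, h[src])  # sum of neighbor features
        h = torch.relu((h + agg) @ W)  # update: combine own state with aggregated messages
    return h.sum(dim=0)  # permutation-invariant readout over all nodes


def undirected(edges):
    """Turn a list of undirected edges into an edge_index with both directions (hypothetical)."""
    src = [u for u, v in edges] + [v for u, v in edges]
    trg = [v for u, v in edges] + [u for u, v in edges]
    return torch.tensor([src, trg])


cycle6 = undirected([(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0)])
two_triangles = undirected([(0, 1), (1, 2), (2, 0), (3, 4), (4, 5), (5, 3)])

torch.manual_seed(0)
weights = [torch.randn(4, 4) for _ in range(3)]  # 3 layers, 4-dimensional features
e1 = sum_gnn_embedding(cycle6, 6, weights)
e2 = sum_gnn_embedding(two_triangles, 6, weights)
print(torch.allclose(e1, e2))  # True
</code></pre>

Giving each node a distinct identifier as its initial feature would break this symmetry, which is one motivation for the architectures mentioned above.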
Over-smoothing and Over-squashing
When training deep GNNs, two related bottlenecks arise: over-smoothing and over-squashing. Over-smoothing occurs when too many message passing layers are stacked. As the number of layers \\( K \\) increases, each node aggregates information from an increasingly large neighborhood, and repeated aggregation makes the representations of nearby nodes increasingly similar. In the limit, the representations of all nodes in a connected component become nearly indistinguishable, which makes node classification impossible.
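Over-smoothing can be reproduced with nothing but repeated mean aggregation. The sketch below assumes a small connected graph, scalar features, no learned weights, and a hypothetical parameter-free layer `mean_layer` that averages each node with its neighbors (Mean aggregation with a self-loop). It records the spread between the largest and smallest node values, which shrinks toward zero as layers are stacked.

<pre><code class="language-python">from typing import Dict, List


def mean_layer(h: Dict[int, float], neighbors: Dict[int, List[int]]) -> Dict[int, float]:
    """Hypothetical parameter-free layer: each node averages itself and its neighbors."""
    return {v: (h[v] + sum(h[u] for u in neighbors[v])) / (1 + len(neighbors[v])) for v in h}


# Small connected graph: triangle 0-1-2 with a tail 2-3-4
neighbors = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2, 4], 4: [3]}
h = {0: 1.0, 1: -1.0, 2: 0.5, 3: 2.0, 4: -3.0}

spreads = {}
for layer in range(1, 51):
    h = mean_layer(h, neighbors)
    # Spread = max minus min node value; it can only shrink, since each update is a weighted average
    spreads[layer] = max(h.values()) - min(h.values())

for layer in (1, 10, 50):
    print(f"layer {layer}: spread = {spreads[layer]:.6f}")
</code></pre>

Learned weights and nonlinearities change the details, but the averaging effect of each aggregation step is what drives node representations toward a common value.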
Over-squashing occurs because a node's K-hop neighborhood can grow exponentially with the number of layers, for example in tree-like graphs. The model must compress information from this exponentially growing neighborhood into a single fixed-size node feature vector, creating an information bottleneck. Practitioners manage these bottlenecks with skip connections and layer normalization, or by keeping GNNs shallow (typically 2 to 4 layers).
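The exponential growth behind over-squashing is easy to count on a tree. Assume a hypothetical complete tree in which every node has exactly \\( b \\) children. The number of nodes within \\( K \\) hops of the root is then \\( 1 + b + b^{2} + \\cdots + b^{K} \\), and after \\( K \\) layers all of them must be summarized in the root's single fixed-size vector. The sketch below builds such a tree with a hypothetical `complete_tree` helper and counts the K-hop neighborhood with a breadth-first search.

<pre><code class="language-python">from collections import deque
from typing import Dict, List


def complete_tree(branching: int, depth: int) -> Dict[int, List[int]]:
    """Undirected adjacency list of a complete tree rooted at node 0 (hypothetical helper)."""
    adj: Dict[int, List[int]] = {0: []}
    frontier, next_id = [0], 1
    for _ in range(depth):
        new_frontier = []
        for parent in frontier:
            for _ in range(branching):
                child = next_id
                next_id += 1
                adj[child] = [parent]
                adj[parent].append(child)
                new_frontier.append(child)
        frontier = new_frontier
    return adj


def k_hop_size(adj: Dict[int, List[int]], source: int, K: int) -> int:
    """Number of nodes within K hops of source, found by breadth-first search."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        v = queue.popleft()
        if dist[v] == K:
            continue  # do not expand beyond K hops
        for u in adj[v]:
            if u not in dist:
                dist[u] = dist[v] + 1
                queue.append(u)
    return len(dist)


tree = complete_tree(branching=2, depth=6)
for K in range(1, 6):
    print(K, k_hop_size(tree, 0, K))
# 1 3
# 2 7
# 3 15
# 4 31
# 5 63
</code></pre>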