Multi-Head Attention
What it is
Multi-head attention runs multiple independent attention mechanisms in parallel, each operating on a projected subspace of the input. The results from all heads are concatenated and linearly transformed. This allows the model to attend to different aspects of the representation simultaneously: syntactic, semantic, long-range, and local.
[illustrate: Input → split into h subspaces; parallel attention heads; concatenate outputs; final linear projection]
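As a quick orientation, here is a shape-only trace under the illustrative assumption of d_model = 512 split across h = 8 heads (a sketch, not prescribed values):

```python
# Shape-only trace, assuming d_model = 512, h = 8 heads, n = 6 tokens.
n, d_model, h = 6, 512, 8
d_k = d_model // h                   # 64 dimensions per head
# input x:            (n, d_model)  = (6, 512)
# per-head Q, K, V:   (n, d_k)      = (6, 64)
# per-head output:    (n, d_k)      = (6, 64)
# concatenated heads: (n, h * d_k)  = (6, 512)
# after W_O:          (n, d_model)  = (6, 512)
assert h * d_k == d_model
```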
How it works
For h attention heads and dimension d_model:
- Project to subspaces: Each head projects to d_k = d_model / h dimensions
  - Q_i = X · W_Q^(i)
  - K_i = X · W_K^(i)
  - V_i = X · W_V^(i)
- Parallel attention: Apply scaled dot-product attention to each head independently
  - head_i = Attention(Q_i, K_i, V_i)
- Concatenate: MultiHead(Q, K, V) = Concat(head_1, …, head_h) · W_O
- Output projection: Linear transformation W_O combines head outputs (see the code sketch below)
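The four steps can be sketched directly in NumPy. This is a minimal illustration, not the reference Transformer implementation: the weight matrices are plain arrays, masking and batching are omitted, and the function names are ours.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, W_Q, W_K, W_V, W_O, h):
    """x: (n, d_model); each W_*: (d_model, d_model); returns (output, weights)."""
    n, d_model = x.shape
    d_k = d_model // h

    def split_heads(m):
        # (n, d_model) -> (h, n, d_k): one subspace per head
        return m.reshape(n, h, d_k).transpose(1, 0, 2)

    # 1. Project to subspaces
    Q, K, V = split_heads(x @ W_Q), split_heads(x @ W_K), split_heads(x @ W_V)

    # 2. Scaled dot-product attention, independently per head
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)   # (h, n, n)
    weights = softmax(scores)                          # each row sums to 1
    heads = weights @ V                                # (h, n, d_k)

    # 3. Concatenate heads: (h, n, d_k) -> (n, h * d_k) = (n, d_model)
    concat = heads.transpose(1, 0, 2).reshape(n, d_model)

    # 4. Output projection
    return concat @ W_O, weights
```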
Example
# 8-head attention; d_model = 512, so d_k = 64 per head
Input: "the cat sat on the mat" (6 tokens × 512-dim)
Head 1: Learns article-noun agreement
    "the" attends to "cat", "mat"
Head 2: Learns verb-subject relations
    "sat" attends to "cat"
    "cat" attends to "the"
Head 3: Learns spatial relations
    "on" attends to "mat", "sat"
... (5 more heads learning different patterns)
Output: Concatenate 8 heads (8 × 64 = 512-dim) → project
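Continuing the hypothetical multi_head_attention sketch above, the 6-token example can be pushed through end to end. The weights here are random and untrained, so the attention maps show the shapes involved, not the linguistic specializations described above.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, h = 6, 512, 8

x = rng.normal(size=(n, d_model))                      # 6 tokens x 512 dims
W_Q, W_K, W_V, W_O = (rng.normal(scale=d_model ** -0.5, size=(d_model, d_model))
                      for _ in range(4))

out, weights = multi_head_attention(x, W_Q, W_K, W_V, W_O, h)
print(out.shape)      # (6, 512): concatenated heads, projected by W_O
print(weights.shape)  # (8, 6, 6): one n x n attention map per head
```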
Variants and history
Multi-head attention was introduced in the Transformer paper (Vaswani et al., 2017). The paper showed empirically that the heads specialize: some capture local syntax, others distant semantic relations. Variants include sparse attention (limit which heads or positions attend), cross-attention with different input/output spaces, and factorized attention (decompose heads further). Using 8–16 heads is standard; larger models use more.
When to use it
Use multi-head attention:
- In Transformer encoder-decoder architectures
- For any task where diverse interaction patterns help
- When you have sufficient model capacity
- When building flexible, adaptable attention layers
Multi-head attention is standard in transformers. Cost: each head's score matrix costs O(n^2 · d_k), so with d_k = d_model / h the total across heads is O(n^2 · d_model), roughly the same as a single full-width head, and the heads run in parallel. Trade-off: with the per-head dimension fixed, more heads mean more parameters and computation; with d_model fixed, more heads mean a smaller subspace per head.
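A rough sanity check on the cost claim, counting only the Q·K^T score multiplications under this simplified cost model (an assumption; real layers also spend time on the projections and feed-forward blocks):

```python
# Per head: n^2 * d_k multiply-adds for the scores; h heads total n^2 * d_model.
n, d_model = 6, 512
for h in (1, 8, 16):
    d_k = d_model // h
    per_head = n * n * d_k
    total = h * per_head
    print(f"h={h:2d}  d_k={d_k:3d}  per-head={per_head:6d}  total={total}")
# total stays 18432 for every h: with d_model fixed, more heads add patterns,
# not more score-computation work, but each head's subspace shrinks.
```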