Microscale
Act II · Inside the Machine
lesson multi-head · 8 min · 40 xp

Multi-head attention

Why parallel heads learn different patterns

One head only sees one thing

The attention mechanism you learned about in the last lesson computes a single weighted average over values. “Single” is the problem. A real sentence has many relationships happening at once — syntax, semantics, coreference, positional nearness, dependency structure — and a single weighted average cannot express them all simultaneously without collapsing them.

The fix, almost trivially, is to run several attention heads in parallel, each with its own learned projections. Each head can specialise in a different pattern. Their outputs are then concatenated and linearly mixed into the final layer output.

\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O
\text{head}_i = \text{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)

The learned weight matrices $W_i^Q$, $W_i^K$, $W_i^V$ are different for each head; that's how the heads end up attending to different things even though the input is identical.
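The two formulas above can be sketched in a few lines of plain NumPy. This is a minimal illustration, not any particular library's API: for compactness it stores all heads' projections in single $(d_{\text{model}}, d_{\text{model}})$ matrices and splits them after the matmul, which is how most real implementations do it too.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, Wq, Wk, Wv, Wo, h):
    """X: (seq, d_model). Wq/Wk/Wv/Wo: (d_model, d_model).
    Each weight matrix packs all h heads side by side."""
    seq, d_model = X.shape
    d_h = d_model // h
    # Project once, then split into h heads of width d_h.
    Q = (X @ Wq).reshape(seq, h, d_h)
    K = (X @ Wk).reshape(seq, h, d_h)
    V = (X @ Wv).reshape(seq, h, d_h)
    heads = []
    for i in range(h):
        scores = Q[:, i] @ K[:, i].T / np.sqrt(d_h)   # (seq, seq)
        weights = softmax(scores, axis=-1)
        heads.append(weights @ V[:, i])               # (seq, d_h)
    # Concatenate per-head outputs, then mix them with W^O.
    return np.concatenate(heads, axis=-1) @ Wo        # (seq, d_model)

rng = np.random.default_rng(0)
d_model, h, seq = 16, 4, 5
X = rng.standard_normal((seq, d_model))
Wq, Wk, Wv, Wo = (rng.standard_normal((d_model, d_model)) * 0.1
                  for _ in range(4))
out = multi_head_attention(X, Wq, Wk, Wv, Wo, h=h)
```

Note that the loop over heads is for clarity only; in practice the per-head attention is computed as one batched matmul.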

Four heads, four archetypes

On the right is a sentence processed by four attention heads in parallel. Click through them to see the pattern each one has learned. These are hand-wired demonstrations of patterns that real attention heads learn; you can find almost all of them in Clark et al.'s 2019 analysis of BERT heads. Some heads are extremely boring (“always attend to the previous token”); others are sophisticated (“attend from verb to subject”). Both are useful, and a real transformer has heads of both kinds and many more besides.

The researcher who cited the paper won the prize
Head 1: Previous token
Simply attends to the immediately preceding token. Sounds trivial, but it's one of the most common patterns real attention heads learn — and it's how positional information propagates.
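A hand-wired head like this one is just a particular attention-weight matrix: row $i$ puts all of its weight on token $i-1$. The sketch below builds that matrix directly for the demo sentence (skipping the learned projections, since the pattern here is fixed by construction).

```python
import numpy as np

tokens = "The researcher who cited the paper won the prize".split()
n = len(tokens)

# Hand-wired "previous token" head: row i attends entirely to token
# i-1. The first token has no predecessor, so it attends to itself.
A = np.zeros((n, n))
A[0, 0] = 1.0
for i in range(1, n):
    A[i, i - 1] = 1.0

# Every row is a valid attention distribution: non-negative, sums to 1.
# With this pattern, each position's output is simply the previous
# token's value vector: the head copies information one step forward.
for i in range(1, n):
    print(f"{tokens[i]:>10} attends to {tokens[int(A[i].argmax())]}")
```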

The mixing happens at the end

After each head computes its own weighted sum, their outputs are concatenated and projected through a final matrix $W^O$. This is not decorative: the $W^O$ projection is what lets the downstream FFN receive a single unified vector that contains information from all heads. Without it, the layer would output a giant concatenated blob with no interaction between heads.

A subtle and important fact: the head dimension $d_h$ is usually $d_{\text{model}}/h$, so that after concatenation the shape is the same as the input. If $d_{\text{model}} = 3072$ and $h = 24$, each head operates in a 128-dimensional subspace. Notice that this is also how the modern Phi-4-mini is configured: 24 query heads, each with $d_h = 128$. You've now seen where that number comes from.

Why not give every head access to the full $d_{\text{model}}$? You could, but the total parameter count of the QKV projections would blow up as $O(h \cdot d_{\text{model}}^2)$. The $d_h = d_{\text{model}}/h$ choice keeps it at $O(d_{\text{model}}^2)$ regardless of head count: a design decision that makes extra heads free in parameters, and one that pays for itself forever.
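The arithmetic behind that claim is worth doing once by hand. Using the Phi-4-mini-style numbers from above (the "full-width heads" variant is hypothetical, for comparison only):

```python
d_model, h = 3072, 24            # Phi-4-mini-style configuration
d_h = d_model // h               # 128: each head's subspace width

# Q, K, V projections with d_h = d_model / h:
# h heads, 3 matrices each, shape (d_model, d_h).
params_split = h * 3 * d_model * d_h       # = 3 * d_model**2, no h

# Hypothetical full-width heads, shape (d_model, d_model) each:
params_full = h * 3 * d_model * d_model    # = 3 * h * d_model**2

print(params_full // params_split)         # 24: grows linearly with h
```

So the standard split makes the QKV parameter cost independent of how many heads you choose, while full-width heads would multiply it by $h$.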
Comprehension check · 1 / 3

Why do we use multiple attention heads instead of one big one?