Act VIII · Serving the Model
lesson flash-attention · 10 min · 50 xp

FlashAttention

Never materialize the L×L matrix

The L×L attention matrix is the problem

Naive attention computes the full $L \times L$ attention matrix, writes it to HBM (the GPU's main memory), softmaxes it, writes that back, then multiplies with $V$. For $L = 4096$, that's 16M floats per head per layer, and every one of them is both written and read over HBM. Attention is memory-bound, not compute-bound.

$$\text{FLOPs: } O(L^2 d) \qquad \text{Memory traffic: } O(L^2 + Ld)$$

The compute is fast; the memory traffic around the compute is the bottleneck. GPUs end up spending most of their time waiting for bytes.
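A quick roofline-style estimate makes this concrete. The sketch below is a back-of-envelope calculation, not a profile; the sequence length, head dimension, and hardware peak numbers are assumptions (roughly H100-class):

```python
# Back-of-envelope check that naive attention is memory-bound.
# All numbers are illustrative assumptions, not profiler output.

L, d = 4096, 64              # sequence length, head dimension
bytes_per = 2                # fp16/bf16 element size

flops = 2 * 2 * L * L * d    # two L×L×d matmuls (QK^T and P·V), 2 FLOPs per MAC
# a naive kernel round-trips the L×L score matrix through HBM
# (write scores, read them back) plus the L×d output
hbm_bytes = bytes_per * (2 * L * L + L * d)

intensity = flops / hbm_bytes    # FLOPs per byte of HBM traffic
peak_flops = 990e12              # assumed BF16 dense peak, ~990 TFLOPs/s
peak_bw = 3.35e12                # assumed HBM bandwidth, ~3.35 TB/s
ridge = peak_flops / peak_bw     # roofline ridge point, ~296 FLOPs/byte

print(f"intensity ≈ {intensity:.0f} FLOPs/byte, ridge ≈ {ridge:.0f}")
```

The kernel's arithmetic intensity lands around 64 FLOPs/byte, far below the ~296 FLOPs/byte ridge point, so bandwidth, not compute, sets the runtime.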

FlashAttention — never materialize the matrix

Dao et al. 2022: don't write the full $L \times L$ matrix to HBM at all. Load blocks of $Q$, $K$, $V$ into on-chip SRAM (small, fast scratchpad), compute attention over each block in place, accumulate the output directly, and never store intermediate softmax results. Use an online softmax algorithm to keep running normalizers correct as new blocks stream through.

$$\text{memory traffic: } O(L^2 d^2 / M) \text{ where } M = \text{SRAM size}$$

That's it. Same math, dramatically less HBM traffic: a 2–4× speedup on attention with the same quality. No approximation involved; the computation is reordered, not changed, so the output matches naive attention up to floating-point rounding.

The only subtle bit — online softmax rescaling

Softmax looks like it requires seeing the whole row: you need $m = \max_j s_j$ and $Z = \sum_j e^{s_j - m}$ before you can normalize anything. FlashAttention's trick is that when a new block of keys arrives, you can update those running stats and rescale the partial output you already accumulated. Say you've processed one block and hold $(m_1, Z_1, O_1)$ where $O_1 = \sum_j e^{s_j - m_1} v_j$. A new block gives you fresh stats $(m_2, Z_2, O_2)$. The merged max is $m = \max(m_1, m_2)$, and both the running sum and the running output get a single correction factor:

$$Z = e^{m_1 - m} Z_1 + e^{m_2 - m} Z_2, \qquad O = e^{m_1 - m} O_1 + e^{m_2 - m} O_2$$

Every time the max changes, you multiply the running output by a scalar correction and keep going. The final $O / Z$ after the last block equals the standard softmax exactly (in exact arithmetic; in floats, up to rounding). Milakov & Gimelshein 2018 published this "online softmax" trick years before attention needed it; Dao, Fu, Ermon et al. (NeurIPS 2022) saw that it unlocked tiled attention kernels and shipped FlashAttention. The whole speedup rides on this one numerical identity.
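The merge identity is easy to check end to end. Here is a minimal NumPy sketch of the block-streaming loop, keeping running $(m, Z, O)$ per query row and comparing against naive attention; the shapes and block size are illustrative assumptions, nothing like the real CUDA kernel's tile sizes:

```python
# Flash-style attention via online softmax, checked against the naive version.
# Single head; L, d, and block size B are illustrative assumptions.
import numpy as np

rng = np.random.default_rng(0)
L, d, B = 128, 16, 32                      # seq length, head dim, key block size
Q, K, V = (rng.standard_normal((L, d)) for _ in range(3))

# reference: naive softmax attention, materializing the full L×L matrix
S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=-1, keepdims=True))
ref = (P / P.sum(axis=-1, keepdims=True)) @ V

# flash-style: stream key/value blocks, keep running (m, Z, O) per query row
m = np.full(L, -np.inf)                    # running row max
Z = np.zeros(L)                            # running sum of exponentials
O = np.zeros((L, d))                       # running unnormalized output
for j in range(0, L, B):
    Sj = Q @ K[j:j+B].T / np.sqrt(d)       # scores for this key block only
    m_new = np.maximum(m, Sj.max(axis=-1)) # merged max
    scale = np.exp(m - m_new)              # correction for what's already accumulated
    Pj = np.exp(Sj - m_new[:, None])
    Z = scale * Z + Pj.sum(axis=-1)
    O = scale[:, None] * O + Pj @ V[j:j+B]
    m = m_new

out = O / Z[:, None]                       # normalize once, after the last block
assert np.allclose(out, ref, atol=1e-10)   # same result, no L×L matrix ever stored
```

Note that at no point does the loop hold more than an $L \times B$ slice of scores; the only per-row state carried between blocks is the scalar max, the scalar normalizer, and the $d$-dimensional output accumulator.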

HBM = slow off-chip GPU memory (~3 TB/s) · SRAM = fast on-chip scratchpad (~19 TB/s, ~6× faster)

Naive attention (memory traffic):

  • HBM: L×L attention matrix
  • HBM: softmax output
  • HBM: final output

3 HBM round trips per head per layer.

FlashAttention (block-streaming):

  • SRAM: Q, K, V blocks
  • online softmax accumulates in place (running stats: max, sum-exp, output)
  • HBM: final output only

1 HBM round trip per head per layer.

FA2 speedup vs naive: 2–4×. FA3 on H100: ~840 TFLOPs/s in BF16, ~1300 TFLOPs/s in FP8.

FlashAttention-3 and Hopper-specific tricks

FlashAttention-3 (Shah, Dao et al. 2024) adds three Hopper-specific optimizations:

  • WGMMA — use Hopper's new warpgroup matmul instruction, with dramatically higher throughput than Ampere's mma.sync.
  • TMA — Tensor Memory Accelerator, a hardware unit that handles global↔shared memory transfers asynchronously. Overlap data movement with compute.
  • FP8 with incoherent processing — block-quantize Q, K, V on the fly for FP8 matmuls, after multiplying Q and K by a random orthogonal (Hadamard) matrix so outliers get spread out and quantization error stays small.

FA3 reaches 85% of H100 peak in BF16 and over 1 PFLOPs/s in FP8, up from ~35% in FA2. It's in every serious serving engine on Hopper hardware.