The L×L attention matrix is the problem
Naive attention computes the full attention matrix $S = QK^\top / \sqrt{d}$, writes it to HBM (the GPU's main memory), softmaxes it, writes that back, then multiplies with $V$. For $L = 4096$, that's ~16M floats per head per layer — each read and write costing HBM bandwidth. Attention is memory-bound, not compute-bound.
The compute is fast; the memory traffic around the compute is the bottleneck. GPUs end up spending most of their time waiting for bytes.
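To make the memory problem concrete, here is a minimal NumPy sketch of naive single-head attention — illustrative only, not the actual GPU kernel. The full L×L score matrix is materialized, which is exactly the intermediate that costs the HBM round trips:

```python
import numpy as np

def naive_attention(Q, K, V):
    L, d = Q.shape
    S = Q @ K.T / np.sqrt(d)                      # L×L scores, written out in full
    P = np.exp(S - S.max(axis=1, keepdims=True))  # row-wise softmax: another L×L pass
    P /= P.sum(axis=1, keepdims=True)
    return P @ V                                  # only now does the L×L matrix collapse

L, d = 4096, 64
rng = np.random.default_rng(0)
Q = rng.standard_normal((L, d))
K = rng.standard_normal((L, d))
V = rng.standard_normal((L, d))
O = naive_attention(Q, K, V)
print(L * L)  # 16777216 intermediate floats per head per layer
```

The output `O` is only L×d, tiny compared to the L×L intermediate that had to travel to and from memory.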
FlashAttention — never materialize the matrix
Dao et al. 2022: don't write the full matrix to HBM at all. Load blocks of $Q$, $K$, $V$ into on-chip SRAM (small, fast, scratchpad), compute attention over each block in place, accumulate the output directly, and never store intermediate softmax results. Use an online softmax algorithm to keep running normalizers correct as new blocks stream through.
That's it. Same math, dramatically less HBM traffic — 2–4× speedup on attention, same quality. No approximation, no tricks: the output is mathematically identical to naive attention, up to floating-point rounding.
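A hedged NumPy sketch of the idea, for a single query row streamed over key/value blocks — the helper name and block size are illustrative, and the real kernel tiles queries too and runs in SRAM:

```python
import numpy as np

def flash_attention_row(q, K, V, block=128):
    """Attention output for one query, never materializing the full score row."""
    d = q.shape[0]
    m = -np.inf           # running max of scores seen so far
    l = 0.0               # running softmax normalizer
    o = np.zeros(d)       # running (unnormalized) output accumulator
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = Kb @ q / np.sqrt(d)        # scores for this block only
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)      # correction factor for the old stats
        p = np.exp(s - m_new)
        l = scale * l + p.sum()
        o = scale * o + p @ Vb
        m = m_new
    return o / l

# Check against one-shot softmax attention.
rng = np.random.default_rng(1)
L, d = 1024, 64
q = rng.standard_normal(d)
K = rng.standard_normal((L, d))
V = rng.standard_normal((L, d))

s = K @ q / np.sqrt(d)
p = np.exp(s - s.max()); p /= p.sum()
reference = p @ V
assert np.allclose(flash_attention_row(q, K, V), reference)
```

Only one block of scores exists at a time; everything else is the three running scalars/vectors `(m, l, o)`.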
The only subtle bit — online softmax rescaling
Softmax looks like it requires seeing the whole row: you need the row max $m = \max_j s_j$ and the normalizer $\ell = \sum_j e^{s_j - m}$ before you can normalize anything. FlashAttention's trick is that when a new block of keys arrives, you can update those running stats and rescale the partial output you already accumulated. Say you've processed one block and hold $(m^{(1)}, \ell^{(1)}, O^{(1)})$ where $O^{(1)} = \sum_{j \in B_1} e^{s_j - m^{(1)}} v_j$. A new block gives you fresh stats $(m^{(2)}, \ell^{(2)}, O^{(2)})$. The merged max is $m = \max(m^{(1)}, m^{(2)})$, and both the running sum and the running output get a single correction factor:

$$\ell = e^{m^{(1)} - m}\,\ell^{(1)} + e^{m^{(2)} - m}\,\ell^{(2)}, \qquad O = e^{m^{(1)} - m}\,O^{(1)} + e^{m^{(2)} - m}\,O^{(2)}.$$
Every time the max changes, you multiply the running output by a scalar correction and keep going. The final $O/\ell$ after the last block equals the standard softmax output exactly — no approximation, only floating-point rounding can differ. Milakov & Gimelshein 2018 published this "online softmax" trick years before attention needed it; Dao, Fu, Ermon et al. (NeurIPS 2022) saw that it unlocked tiled attention kernels and shipped FlashAttention. The whole speedup rides on this one numerical identity.
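The identity can be checked numerically in a few lines — a toy verification with an 8-element score vector split into two blocks, names chosen for this sketch:

```python
import numpy as np

rng = np.random.default_rng(42)
s = rng.standard_normal(8)        # one row of attention scores
v = rng.standard_normal((8, 4))   # matching value vectors

def block_stats(s_blk, v_blk):
    """Per-block (max, normalizer, unnormalized output) triple."""
    m = s_blk.max()
    p = np.exp(s_blk - m)
    return m, p.sum(), p @ v_blk

m1, l1, o1 = block_stats(s[:4], v[:4])
m2, l2, o2 = block_stats(s[4:], v[4:])

# Merge the two blocks with a single correction factor each.
m = max(m1, m2)
l = np.exp(m1 - m) * l1 + np.exp(m2 - m) * l2
o = np.exp(m1 - m) * o1 + np.exp(m2 - m) * o2

# One-shot softmax over the whole row for comparison.
p = np.exp(s - s.max()); p /= p.sum()
assert np.allclose(o / l, p @ v)
```

The same merge applies inductively: fold in block three against the merged stats, and so on through the sequence.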
FlashAttention-3 and Hopper-specific tricks
FlashAttention-3 (Shah, Dao et al. 2024) adds three Hopper-specific optimizations:
- WGMMA — use Hopper's new warpgroup matmul instruction, dramatically higher throughput than Ampere's `mma.sync`.
- TMA — Tensor Memory Accelerator, a hardware unit that handles global↔shared memory transfers asynchronously. Overlap data movement with compute.
- FP8 with incoherent processing — block-quantize Q, K, V on the fly for FP8 matmul, with incoherent processing (a random orthogonal transform) so quantization errors spread out instead of concentrating on outliers.
FA3 reaches ~85% of H100 peak in BF16 and over 1 PFLOP/s in FP8 — up from ~35% utilization in FA2. It's in every serious serving engine on Hopper hardware.