Predicting further ahead
Speculative decoding gave us our first answer to the memory wall: ask a cheap draft model to propose several tokens, then verify them in one big forward pass of the target. Multi-token prediction — MTP — is a different move from the same playbook. Instead of a separate draft model, the base model itself grows a short chain of prediction modules and learns, during pretraining, to emit two, three, even four tokens per step.
DeepSeek-V3 trained its MTP modules alongside the main trunk from the start. Qwen3-Next (Sept 2025) did the same and shipped MTP in its default serving path — the first major non-DeepSeek model to do so. Both quote inference speedups in the 1.5–2× range at production batch sizes, without touching the quality of the main head's output. The trick is elegant, the math is clean, and the why it works is the whole memory-wall lesson cashing its check.
[Interactive acceptance calculator] DeepSeek-V3 preset (per-position acceptance ~0.85; accept@k=2: 0.72; accept@k=3: 0.59), batch 1: one forward pass produces ~2.89 tokens in expectation — 2.51× faster than vanilla AR once the per-head compute tax and batch-size scaling are factored in.
Why decode waits on memory
Remember the roofline. An H100's fp16 matrix cores can do ~990 TFLOP/s; its HBM3 feeds ~3.35 TB/s. The ratio — the machine's arithmetic intensity balance point — sits around 295 FLOPs per byte. Above that intensity your kernel is compute-bound; below it, memory-bound.
Decode at batch 1 on a 70B model has an arithmetic intensity of roughly 1–2 FLOPs per byte. A single token forward reads ~140 GB of weights from HBM and does maybe 140 GFLOPs of actual arithmetic on them. Two hundred to six hundred times below the roofline. The GPU spends most of its cycles just waiting for the next weight shard to land in its registers, then crunches on that shard for a microsecond, then waits again.
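A back-of-envelope version of that arithmetic, as a sketch; the constants are the approximate figures quoted above and vary by SKU and precision:

```python
# Roofline arithmetic for batch-1 decode on an H100, using the rough figures above.
peak_flops = 990e12      # fp16 tensor-core peak, FLOP/s (approximate)
hbm_bw     = 3.35e12     # HBM3 bandwidth, bytes/s (approximate)
balance    = peak_flops / hbm_bw              # ~295 FLOPs per byte: the ridge point

# One decode step of a dense 70B model in fp16, batch 1, ignoring the KV cache:
params     = 70e9
bytes_read = params * 2                       # every weight streamed once, 2 bytes each (~140 GB)
flops_done = params * 2                       # ~2 FLOPs per weight per token (multiply-accumulate)
intensity  = flops_done / bytes_read          # ~1 FLOP per byte

print(f"ridge point: {balance:.0f} FLOPs/byte")
print(f"decode     : {intensity:.0f} FLOP/byte ({balance / intensity:.0f}x below it)")
```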
That is the memory-wall lesson. Speculative decoding exploits it by amortising the weight-load across k verified tokens; MTP exploits it by amortising across D predicted tokens in one pass. Both are fighting the same enemy from different angles — one at serving time with a draft model, the other in the architecture with extra heads. When you see the 45 → 125 tok/s jump on the ticker above, you're watching the roofline gap close.
The one-matmul-per-token contract
A vanilla autoregressive transformer produces exactly one token per forward pass. The reason is not a law of physics; it is an accidental contract. Each layer of the trunk emits a hidden state, the LM head projects the final one h_t to a vocabulary distribution, you sample a token x_{t+1}, and that token becomes the next input. To predict x_{t+2} you must know x_{t+1} — causal attention demands it — so you run the whole stack again.
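In code, the contract is just the decode loop every serving engine runs. A minimal greedy sketch, assuming an HF-style causal LM interface and eliding KV caching:

```python
import torch

@torch.no_grad()
def greedy_decode(model, input_ids, max_new_tokens):
    # One full forward pass of the whole stack per generated token:
    # every step streams the entire weight set from HBM to produce a single id.
    for _ in range(max_new_tokens):
        logits = model(input_ids).logits                      # [B, T, vocab]
        next_id = logits[:, -1, :].argmax(-1, keepdim=True)   # only the last position is used
        input_ids = torch.cat([input_ids, next_id], dim=-1)   # feed it back; run everything again
    return input_ids
```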
The model has, in fact, more information than it uses. The trunk at step t already has a rich context representation; it could plausibly guess several next tokens. The standard LM head just doesn't ask it to. Everything MTP does, at bottom, is ask it to.
The standard loss supervises only the next token. Nothing further. The supervised signal never touches what the model thinks about x_{t+2} or x_{t+3}.
MTP's insight: train on further targets
Gloeckle, Youbi Idrissi, Rozière, Lopez-Paz & Synnaeve (Meta, April 2024) asked the minimal question: what if we just add more output heads? Put n linear heads on top of the same shared trunk, have head k predict x_{t+k}, and sum the losses. No new hyperparameters. No routing. No separate draft model. Train that, and see what happens.
What happened was a small pretraining quality win, a non-trivial inference speedup, and a new architectural primitive that nobody had been using. On 13B-scale code models Gloeckle et al. reported +12% on HumanEval and +17% on MBPP at pass@1, with MTP trained from scratch versus a control next-token-only model — and up to 3× wall-clock inference speedup when the extra heads were used as speculative drafters at serving time.
The Meta design is deliberately minimal: parallel linear heads, all reading the same final trunk hidden state. That's the same shape as Medusa (Cai et al., 2024) but trained jointly with the backbone rather than bolted on after the fact. It leaves an obvious axis unexplored: what if the heads could see each other?
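For concreteness, a sketch of that minimal shape (n independent heads on one trunk, losses summed), with illustrative names rather than the paper's actual code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ParallelMTPHeads(nn.Module):
    """Gloeckle-style multi-token prediction in spirit: head k reads the shared
    trunk hidden state and is trained on the token k steps ahead."""
    def __init__(self, d_model: int, vocab_size: int, n_heads: int = 4):
        super().__init__()
        self.heads = nn.ModuleList(
            nn.Linear(d_model, vocab_size, bias=False) for _ in range(n_heads)
        )

    def loss(self, trunk_hidden: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
        # trunk_hidden: [B, T, d] final trunk states for input positions 0..T-1
        # tokens:       [B, T + n_heads] ids; tokens[:, i + k] is k steps ahead of position i
        T = trunk_hidden.size(1)
        total = trunk_hidden.new_zeros(())
        for k, head in enumerate(self.heads, start=1):
            logits = head(trunk_hidden)                       # [B, T, vocab]
            target = tokens[:, k : k + T]                     # the token k steps ahead
            total = total + F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), target.reshape(-1)
            )
        return total                                          # summed over heads
```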
DeepSeek-V3's upgrade — sequential modules
DeepSeek-V3 (Dec 2024) did exactly that. Their MTP implementation diverges from Gloeckle's parallel heads in one crucial way: the modules form a chain, not a fan. Each module takes as input both the previous module's hidden state and the embedding of the token that previous module just predicted, runs them through a full transformer block, and emits its own hidden state — which the next module then consumes. Concretely, module k:
- Take the previous hidden state h^{k−1} (for k = 1, this is the trunk's final hidden h_t).
- Take the embedding of the previously-predicted token x_{t+k} (for k = 1, the main head's sample) through the shared embedding table.
- RMSNorm both independently. Concatenate on the feature axis → a 2d-dim vector.
- Project through the unshared M_k to get a d-dim input to the transformer block.
- Run the unshared transformer block → hidden state h^k.
- Project through the shared output head → logits. Sample (or argmax) x_{t+k+1}. Feed everything to module k+1.
Notice what is shared and what is not. The embedding and the output projection are shared with the main model — no extra parameters for the token vocabulary, no extra parameters for output logits. But each module's projection matrix and transformer block are unshared across depths. This matters: the job of predicting t+2 is qualitatively different from predicting t+3 (the model knows one less actual token and has to commit to a further-out guess), so giving each depth its own block lets it specialize.
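A sketch of one depth of that chain, assuming a recent PyTorch for nn.RMSNorm; the `block` argument stands in for whatever transformer block the trunk uses, and all names here are illustrative rather than DeepSeek's actual code:

```python
import torch
import torch.nn as nn

class MTPModule(nn.Module):
    """One depth of a DeepSeek-V3-style MTP chain, wired as in the list above."""
    def __init__(self, d_model: int, block: nn.Module):
        super().__init__()
        self.norm_h = nn.RMSNorm(d_model)                         # normalise the incoming hidden state
        self.norm_e = nn.RMSNorm(d_model)                         # normalise the token embedding
        self.proj   = nn.Linear(2 * d_model, d_model, bias=False) # unshared projection M_k
        self.block  = block                                       # unshared transformer block

    def forward(self, prev_hidden, prev_token, embed, lm_head):
        # embed and lm_head are the model's *shared* embedding table and output head.
        e = embed(prev_token)                                     # embedding of the token just predicted
        x = torch.cat([self.norm_h(prev_hidden), self.norm_e(e)], dim=-1)  # concat on feature axis
        h = self.block(self.proj(x))                              # this depth's own hidden state
        logits = lm_head(h)                                       # shared output projection
        return h, logits.argmax(-1)                               # hidden for module k+1, candidate token
```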
The cost is modest. For DeepSeek-V3 — 671B total parameters, 7168 hidden dim, 61 trunk layers, 129280 vocab — adding one MTP module (D = 1) adds roughly 12–14B parameters: one 14336 × 7168 projection plus one full transformer block. That's ~2% of total parameters, in exchange for a ~1.8× inference speedup under self-speculative decoding. The paper ships with D = 1 and does not publish ablations against deeper chains, so the public evidence for how deeper chains would have performed is simply absent — the architecture supports them, but the trade-off wasn't measured in this work.
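A quick check of that arithmetic (the ~12B block figure is taken from the estimate above, not derived here):

```python
# Parameter cost of one DeepSeek-V3 MTP module, using the figures quoted above.
d_model      = 7168
proj_params  = (2 * d_model) * d_model        # unshared 14336 x 7168 projection ≈ 0.10B
block_params = 12e9                           # one full transformer block (the dominant term)
total_params = 671e9

module = proj_params + block_params
print(f"MTP module ≈ {module / 1e9:.1f}B params, {module / total_params:.1%} of the model")
# → roughly 12.1B and ~1.8%, i.e. the "~2% of total parameters" quoted above
```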
MTP mode: modules chain. Module k's input is module k−1's hidden state concatenated with the embedding of the token module k−1 just predicted. The chain lets the model reason about what it said before predicting the next token — which is why acceptance stays high further out.
Training: one trunk, D+1 losses
Training an MTP model looks almost exactly like training a plain LM, with one change. The total loss is the main next-token loss plus the average of the D MTP module losses, weighted by a schedule that decays over training:
L_total = L_main + (λ / D) · Σ_{k=1}^{D} L_MTP^(k)

where L_MTP^(k) is the cross-entropy of module k's logits against the true token at position t+k+1. The 1/D factor is a fairness knob: at D > 1 the MTP losses are averaged, not summed, so deeper chains don't overwhelm the main loss.
λ controls how much the model cares about MTP. DeepSeek-V3's schedule is specific: λ = 0.3 for the first 10T training tokens, then λ = 0.1 for the final 4.8T. The intuition is that early on the MTP signal provides a useful auxiliary objective — the model has to think slightly further ahead, which improves its representations. Later, as the main model crystallises, MTP is kept around primarily for inference speedup; a smaller weight prevents it from pulling the trunk toward optimising for the wrong distribution.
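As a sketch, using the published thresholds but with function names that are mine:

```python
def mtp_lambda(tokens_seen: float) -> float:
    """DeepSeek-V3's published schedule: 0.3 for the first 10T tokens, 0.1 after."""
    return 0.3 if tokens_seen < 10e12 else 0.1

def total_loss(main_loss, mtp_losses, tokens_seen):
    """L_total = L_main + (lambda / D) * sum_k L_MTP^(k), as defined above.
    mtp_losses is the list of D per-module cross-entropies (D = 1 for V3)."""
    lam = mtp_lambda(tokens_seen)
    return main_loss + lam / len(mtp_losses) * sum(mtp_losses)
```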
At inference: discard, or speculate
Now for the payoff. A trained MTP model can be served in one of two modes, and the choice matters.
Mode 1: discard. At serving time, ignore the MTP modules entirely; use just the main head for next-token prediction.
You still get the pretraining quality boost (the MTP loss regularised the trunk's representations), but no inference speedup.
Best when: latency is already fine, the batch is large, and the MTP modules were trained but you don't want the verification complexity in serving.
Mode 2: self-speculate. Run the full chain per forward pass. Each module produces a candidate token; the next forward pass verifies them against the main-head distribution (rejection sampling, exactly as in speculative decoding).
Accepted tokens advance the sequence; the first rejected candidate is replaced with a main-head sample and the rest are discarded. Zero quality loss. Speed gain scales with acceptance rate and D.
Best when: latency-bound, batch 1–8, H100/H200 or MI300X serving.
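A greedy-acceptance sketch of one verify step, assuming the D draft tokens came out of the MTP chain on the previous pass; production engines such as SGLang do full rejection sampling to preserve the sampling distribution, and the model interface and names here are illustrative:

```python
import torch

@torch.no_grad()
def self_speculative_step(model, seq, draft_tokens):
    # seq: [1, T] accepted ids so far; draft_tokens: [1, D] candidates from the MTP chain.
    candidate = torch.cat([seq, draft_tokens], dim=-1)            # [1, T + D]
    logits = model(candidate).logits                              # one big forward pass verifies all drafts

    # Main-head prediction for each position that has a drafted successor.
    preds = logits[:, seq.size(-1) - 1 : -1, :].argmax(-1)        # [1, D]
    matches = (preds == draft_tokens).long().cumprod(-1)          # accept until the first mismatch
    n_accept = int(matches.sum())

    # The first rejected position is replaced by the main head's own token;
    # if every draft was accepted, take the prediction after the last draft instead.
    if n_accept < draft_tokens.size(-1):
        corrected = preds[:, n_accept : n_accept + 1]
    else:
        corrected = logits[:, -1:, :].argmax(-1)
    return torch.cat([seq, draft_tokens[:, :n_accept], corrected], dim=-1)
```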
Production numbers from SGLang + H200 TP8 (July 2025 blog): 1.8× throughput at batch 1, 1.5× at batch 32. The speedup compresses at higher batch because the workload moves from memory-bound to compute-bound — there are enough tokens in flight to saturate the matrix cores even without MTP, so the “free” tokens the chain produces don't reduce the critical path as much. On AMD MI300X (SGLang benchmarks late 2025), MTP delivered 1.25–2.11× on random prompts — the wide range reflects sensitivity to prompt distribution: code and math get the upper end; chatty small-talk gets the lower end.
Qwen3-Next (Alibaba, September 2025) is the inflection point. It's the first major non-DeepSeek model to ship native MTP, and it ships with MTP on by default in the vLLM-hosted serving path. That matters: MTP is no longer a DeepSeek-specific quirk but a generic technique a serving engineer is expected to know. Expect more 2026 releases to quietly adopt it.
The serving take
MTP is the second answer, after speculative decoding, to the question “how do we decode more than one token per memory round-trip?” Speculative decoding answers with a draft model; MTP answers by modifying the target architecture. Both stack cleanly: you can run speculative decoding on top of an MTP-trained model, using its MTP modules as the draft. Some recent SGLang deployments do exactly that, chaining the gains.
For the SLM practitioner: if you are training from scratch and can afford the 1–2% parameter overhead plus the joint loss, MTP is close to a free speedup at serving time plus a small quality bump during pretraining. If you are wrapping an existing model and cannot retrain, Medusa is the cheaper compromise. And if the target model is someone else's pretrained checkpoint on HuggingFace, classical speculative decoding with a small draft remains the pragmatic choice. The three techniques are not rivals so much as points on a single curve: how much architectural commitment are you willing to trade for how much speedup.
With the DeepSeek-V3 preset at batch 1: expected accepted tokens per forward pass ≈ 2.89, effective speedup ≈ 2.51×. The memory wall didn't move, but we learned to ask it for more per knock.
Three tiers. Three ways to test the same ideas.
Recall checks the shapes and schedules. Apply runs the speedup math on new numbers. Reason transfers MTP to scenarios the lesson didn't cover.