MXFP8 training for MoEs. It’s real. And it’s faster.
TorchAO folks — PyTorch’s quantization whizzes — claim a 1.3x speedup over BF16 for Llama4 Scout on a 256-GPU GB200 beast. That’s +30.2% end-to-end, hitting 81% of theoretical roofline glory. Neat trick. But does it hold water beyond the lab?
They ran it on Crusoe Cloud, full TorchTitan glory. Sequence lengths at 8192, C4 dataset, tiny local batch of 1 for convergence checks. Loss curves? Identical to BF16 after 3k steps. No wobbles. No disasters.
Here’s their money shot:
> We recently demonstrated a +30.2% training speedup for Llama4 Scout with equivalent convergence to bfloat16, by using MXFP8 MoE training primitives in TorchAO!
Punchy, right? But read the fine print — they skipped MXFP8 on output projections (too finicky) and attention’s wk/wv (too puny for gains). Smart exclusions. Or signs of fragility?
Is MXFP8 Training for MoEs Actually Stable?
Convergence first. Always. They cranked a long run, 3k steps, on that 64-node cluster. FSDP sharded everywhere: degree 256 on attention, shared experts, and FFNs; degree 64 on the routed experts. EP=16 for the experts. Full-throttle activation checkpointing to shave memory. torch.compile on model and loss. MXFP8 only on the routed experts' grouped GEMMs.
Loss curves overlap like twins. No divergence. Mirrors their dense model tests. Good sign. But batch size? One. Tiny. Scale to real-world 16-microbatch (for GEMM efficiency), and perf jumps: 5317 tokens/sec BF16 to 6921 MXFP8. 30.2% boom.
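The arithmetic, at least, checks out against their own table:

```python
# Back-of-envelope check on the quoted throughput numbers (tokens/sec from the post).
bf16, mxfp8 = 5317, 6921
print(f"{mxfp8 / bf16:.3f}x")            # ~1.302x
print(f"+{mxfp8 / bf16 - 1:.1%}")        # ~+30.2%
```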
Single sentence verdict: Promising. But unproven at elephantine scales.
And here’s the engine: _to_mxfp8_then_scaled_grouped_mm. Quantizes activations and weights dynamically to MXFP8, runs scaled grouped GEMM, spits out original precision. Differentiable. ~1.8x faster than compiled BF16 for MoE shapes. Routed experts? 1.43x quicker. Whole model with shared experts? 1.3x e2e.
Roofline maxes at 1.37x for these GEMMs. They nabbed most of it. Impressive engineering. TorchAO 0.17.0, TorchTitan 0.2.1, Torch 2.11 dev. Cutting-edge stuff.
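Here's roughly where that primitive slots into a routed-experts forward. The import path, argument names, and tensor layout in this sketch are my assumptions, not the documented API; check the TorchAO prototype docs for the real signature.

```python
# Hedged sketch: where the TorchAO primitive sits in a routed-experts forward.
# Import path and argument names are assumptions for illustration only.
import torch
from torchao.prototype.moe_training import _to_mxfp8_then_scaled_grouped_mm  # assumed path

def routed_experts_forward(x: torch.Tensor, w_experts: torch.Tensor, offs: torch.Tensor):
    # x:         (total_tokens, hidden)      activations, already permuted so each
    #                                        expert's tokens are contiguous
    # w_experts: (num_experts, hidden, ffn)  stacked expert weight matrices
    # offs:      (num_experts,)              cumulative token offsets per expert group
    #
    # Everything goes in at bf16; the primitive quantizes both operands to MXFP8
    # on the fly, runs the scaled grouped GEMM, and hands back bf16. It's
    # differentiable, so autograd sees one op.
    return _to_mxfp8_then_scaled_grouped_mm(x, w_experts, offs, out_dtype=torch.bfloat16)
```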
Why Does MXFP8 Speed Up MoE Training So Much?
MoEs are sparse activation party animals. Routed experts mean grouped GEMMs: many small matrix multiplies batched together. BF16 chugs through them. MXFP8? Ditches precision for speed. Dynamic quantization on the inputs, then roughly 2x the peak matmul throughput of BF16.
Quant overhead? Minimal. Net win. The diagram in their post sketches the forward/backward flow: activations come in, get quantized to MXFP8 along with the weights, run through the scaled grouped GEMM, and come back out at the original precision. Backprop mirrors it. Smooth.
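If "quantized to MXFP8" sounds hand-wavy, here's a fake-quant sketch of the MX recipe as commonly described: 32-element blocks, one power-of-two scale per block, elements rounded through FP8 E4M3. Plain PyTorch simulation for intuition; the shipped kernels fuse all of this.

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
BLOCK = 32

def mxfp8_fake_quant(x: torch.Tensor) -> torch.Tensor:
    """Quantize-dequantize x along its last dim in 32-element blocks (illustration only)."""
    orig_shape, orig_dtype = x.shape, x.dtype
    blocks = x.reshape(-1, BLOCK).float()

    # One shared power-of-two scale per block, picked so the block max lands in E4M3 range.
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = torch.exp2(torch.floor(torch.log2(amax / E4M3_MAX)))

    # Scale down, round through FP8 E4M3 (saturating the stragglers), scale back up.
    q = (blocks / scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    return (q.float() * scale).reshape(orig_shape).to(orig_dtype)

x = torch.randn(8, 128, dtype=torch.bfloat16)
rel_err = (mxfp8_fake_quant(x) - x).abs().mean() / x.abs().mean()
print(f"mean relative quant error: {rel_err.item():.3%}")
```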
Microbenchmarks (appendix gems):
For routed expert shapes, 1.8x kernel pop. Shared linears add more juice.
TorchTitan makes it plug-and-play. Config tweaks in docs. Or raw TorchAO primitive if you’re framework-agnostic. Repro scripts ready. No excuses.
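Want to kick the kernel claim yourself before trusting the appendix? A rough baseline harness, with illustrative shapes and uniform routing assumed so torch.bmm can stand in for the grouped GEMM:

```python
import torch

def bench_ms(fn, iters=50, warmup=10):
    # CUDA-event timing with warmup; returns milliseconds per call.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Illustrative routed-expert shapes, not Llama4 Scout's actual ones.
experts, tokens_per_expert, hidden, ffn = 16, 2048, 5120, 8192
a = torch.randn(experts, tokens_per_expert, hidden, device="cuda", dtype=torch.bfloat16)
w = torch.randn(experts, hidden, ffn, device="cuda", dtype=torch.bfloat16)

print(f"bf16 baseline: {bench_ms(lambda: torch.bmm(a, w)):.2f} ms")
# Then swap in the MXFP8 grouped-GEMM primitive on Blackwell hardware and take the ratio.
```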
But wait. Theoretical 1.37x, real 1.3x. Close. Yet they cherry-picked layers. Output embeddings tank convergence in low prec — old song. Wk/wv too small. Fair. Still, full-model MXFP8? Dream on.
The Hype Trap – And My Historical Gut Punch
This smells like FP16's 2017 hype cycle. NVIDIA pushed mixed precision, everyone chased the 2x FLOPs, and convergence nightmares ensued. Tweaks piled up: loss scaling, better optimizers. MXFP8 sidesteps much of that; per-block scales and dynamic quantization do the range management that global loss scaling used to.
Unique twist: MoEs amplify it. Sparsity hides quant noise in routing? Maybe. But my bold call — it’ll stall at 10T+ params. Routers get hypersensitive; top-2 gating crumbles under 8-bit wobble. We’ve seen it in GQA flops. PyTorch PR spins ‘equivalent convergence,’ but that’s small-batch C4. Throw Synthia or real multilingual at it? Cracks show.
Corporate spin? TorchAO’s no Meta cash cow, but NVIDIA-adjacent (GB200 love). Crusoe Cloud shoutout feels earned — not AWS simps. Still, 30% ain’t ‘revolutionary.’ It’s iterative grind. Good. Necessary. But don’t quit your BF16 job.
Perf table seals it:
| Number of GPUs | BF16 tokens/sec | MXFP8 tokens/sec | MXFP8 speedup vs BF16 |
|---|---|---|---|
| 256 | 5317 | 6921 | 30.2% |
Scale to 1k GPUs? Multi-node black magic via FSDP/EP. They nailed it.
Devs, try it.
Repro? TorchTitan docs. Exclude FQNs: output, router.gate, etc. Batch to 16. Watch tokens fly.
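The exclusion rule is easy to express as a filter. This is a sketch of the idea, not the actual TorchTitan hook; the conversion call at the bottom is hypothetical, and the real knobs live in their docs.

```python
# Which modules get MXFP8: routed experts yes, fragile or tiny layers no.
# The FQN substrings mirror the exclusions in the post; actual FQNs depend on
# the model definition, and the conversion entry point below is assumed.
EXCLUDE = ("output", "router.gate", "wk", "wv")

def convert_to_mxfp8(fqn: str) -> bool:
    if any(skip in fqn for skip in EXCLUDE):
        return False
    return "experts" in fqn   # routed experts' grouped GEMMs only

# e.g. quantize_(model, moe_training_config, filter_fn=lambda mod, fqn: convert_to_mxfp8(fqn))  # names assumed
```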
TorchAO Primitives: Hacker’s Delight or Headache?
Prototype API: _to_mxfp8_then_scaled_grouped_mm. Docs pack rooflines, benches for Llama/GPT shapes. Plug into custom loops. No framework tax.
Overhead diagram? Quant kernels fast enough. Net 1.2-1.3x. Backward pass? Symmetric magic.
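What "symmetric" means in practice: the same quantize-then-matmul path runs in the forward GEMM and in both backward GEMMs. A toy autograd.Function below, with a per-tensor FP8 fake-quant standing in for real blockwise MXFP8; the structure is the point, not the numerics.

```python
import torch

def fp8_fake_quant(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for blockwise MXFP8: scale into E4M3 range, round through FP8, scale back.
    e4m3_max = torch.finfo(torch.float8_e4m3fn).max
    scale = t.abs().amax().clamp(min=1e-12) / e4m3_max
    return (t / scale).to(torch.float8_e4m3fn).float() * scale

class QuantizedMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, w):
        ctx.save_for_backward(a, w)
        return fp8_fake_quant(a) @ fp8_fake_quant(w)   # forward GEMM on quantized operands

    @staticmethod
    def backward(ctx, grad_out):
        a, w = ctx.saved_tensors
        g = fp8_fake_quant(grad_out)
        grad_a = g @ fp8_fake_quant(w).t()             # dL/dA through the same quant path
        grad_w = fp8_fake_quant(a).t() @ g             # dL/dW through the same quant path
        return grad_a, grad_w

a = torch.randn(64, 128, requires_grad=True)
w = torch.randn(128, 256, requires_grad=True)
QuantizedMM.apply(a, w).sum().backward()               # gradients flow, no special casing
```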
Critic hat: Why MXFP8 over plain tensor-scaled FP8 (E4M3)? Per-block scales buy wider effective dynamic range and less outlier pain. The BF16 fallback for fragile layers is smart. But GB200 lock-in? Hopper's FP8 muscle goes underused elsewhere.
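To see the outlier point concretely: with one scale for the whole tensor, a single huge value forces everything else toward zero; with a scale per 32-element block, the damage stays local. A contrived example, with the 1e5 outlier deliberately extreme:

```python
import torch

e4m3_max = torch.finfo(torch.float8_e4m3fn).max   # 448.0: per-element range of E4M3
x = torch.randn(128)
x[0] = 1e5                                        # one extreme outlier

# Per-tensor scale: the outlier sets the scale, everything else gets crushed.
s_tensor = x.abs().amax() / e4m3_max
err_tensor = ((x / s_tensor).to(torch.float8_e4m3fn).float() * s_tensor - x).abs().mean()

# Per-block (32-element) scales: only the outlier's block pays the price.
blocks = x.view(-1, 32)
s_block = blocks.abs().amax(dim=-1, keepdim=True) / e4m3_max
err_block = ((blocks / s_block).to(torch.float8_e4m3fn).float() * s_block - blocks).abs().mean()

# Blockwise comes out a few times smaller here: the outlier only poisons its own block.
print(err_tensor.item(), err_block.item())
```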
One para rant. Vendors chase FLOPs wars — AMD MI300X lurks, Grok-3 rumors. MXFP8 ties PyTorch to CUDA throne. Open source? Sure. Portable? Ha.
Imagine a Mixtral-8x22B retrain. Routed experts dominate the compute. A 30% shave? Months saved. Carbon footprint dips. But convergence on The Pile? Uncharted. Their C4 win teases; the real frontier awaits.
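Back-of-envelope, because a throughput gain and wall-clock saved aren't the same number; the 90-day run here is a made-up figure:

```python
speedup = 1.302                                   # the e2e throughput gain from the post
time_saved = 1 - 1 / speedup
print(f"wall-clock saved: {time_saved:.1%}")      # ~23.2%, not 30%
print(f"on a hypothetical 90-day run: {90 * time_saved:.0f} days back")   # ~21
```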
And full activation checkpointing: the memory hero. No OOMs at 8192 context. LR 1e-4, warmup 2k. Conservative. Push harder?
Will MXFP8 Replace BF16 in Production MoEs?
Not yet. Baseline holds. But momentum builds. TorchTitan’s compile + AO = speed demon. Fork it, tweak, ship.
Dry humor break: If your cluster’s idling, congrats — repro this tonight.
FAQ time.
Frequently Asked Questions
What is MXFP8 training for MoEs? MXFP8 dynamically quantizes the MoE experts' activations and weights to 8-bit floating point for faster grouped GEMMs, delivering roughly a 1.3x training speedup vs BF16 without hurting convergence.
How to enable MXFP8 in TorchTitan? Check TorchTitan docs: apply to routed experts, exclude sensitive layers like output/router.gate. Use TorchAO 0.17+ on GB200.
Does MXFP8 work on non-NVIDIA hardware? The prototype is CUDA-tuned for GB200-class hardware; portability is TBD. The roofline assumes NVIDIA's FP8 tensor cores.