"They told me to carry sixteen bits instead of thirty-two, and somehow the whole convoy moved twice as fast. Nobody mentioned I would have to whisper to keep the small numbers awake."
An All-Reduce That Has Seen Some Gradients
Mixed-precision training is a single-node efficiency technique, but it pays its biggest dividend out on the network: when each worker carries gradients in sixteen bits instead of thirty-two, every all-reduce moves half the bytes and every accelerator fits a larger local batch. This is a scale-up trick that directly serves scale-out. We keep the high-precision master weights and a few delicate reductions in FP32 for numerical safety, push the heavy matrix multiplies and the gradient exchange into FP16 or BF16, and recover the small gradients that would otherwise vanish with a technique called loss scaling. The collectives of Chapter 4 and the parallelism strategies of the next chapters remain the main event; mixed precision is the labeled enabler that makes them cheaper per step.
Everything in this chapter so far has treated a gradient as a vector of numbers and asked how to move it between machines correctly and quickly. We never asked how wide each of those numbers is. The default in deep learning was once a 32-bit float for every weight, activation, and gradient, and that default is wasteful: modern accelerators compute low-precision matrix multiplies several times faster than full-precision ones, and a 32-bit gradient costs twice the memory and twice the network bytes of a 16-bit one. Mixed-precision training keeps the parts that need accuracy in 32 bits and runs everything else in 16, capturing most of the speed and memory of pure low precision while preserving the numerical stability of full precision. It is a per-node technique, firmly on the scale-up side of the book's central distinction, but it is included here because it changes the arithmetic of distribution: smaller gradients mean smaller all-reduce payloads, and lower memory means bigger batches per worker, both of which improve the scaling efficiency we study in Section 15.9.
1. The Numerical Formats and Their Trade-Off (Beginner)
A floating-point number spends its bits on two things: an exponent that sets the dynamic range (how large or small a value it can represent) and a mantissa that sets the precision (how finely it resolves values within that range). The formats used in deep learning differ in how they divide a fixed bit budget between these two, and that split is the whole story. A 32-bit IEEE float (FP32) spends 8 bits on the exponent and 23 on the mantissa, giving both wide range and fine precision. To go smaller you must sacrifice one or the other.
The classic 16-bit format, IEEE half precision (FP16), keeps 10 mantissa bits but shrinks the exponent to 5, which narrows its range badly: its smallest positive normal value is about $6.1 \times 10^{-5}$, its smallest subnormal is about $6 \times 10^{-8}$, and anything below roughly half of that rounds to zero. Google's brain-float (BF16) makes the opposite choice, keeping the full 8-bit exponent of FP32 and accepting only 7 mantissa bits. BF16 therefore covers the same enormous range as FP32 (so gradients rarely underflow) at the cost of coarser precision. NVIDIA's TF32, used internally by Tensor Cores, is a 19-bit compute format with FP32's 8-bit exponent and 10 mantissa bits, a middle ground applied automatically inside matrix multiplies. The 2024-era frontier pushes further to 8-bit floats (FP8), available as E4M3 (4 exponent, 3 mantissa bits, more precision) and E5M2 (5 exponent, 2 mantissa bits, more range). Table 15.8.1 lays out the trade-off.
| Format | Total bits | Exponent | Mantissa | What it buys |
|---|---|---|---|---|
| FP32 | 32 | 8 | 23 | Wide range, fine precision; the master-weight baseline |
| TF32 | 19 | 8 | 10 | FP32 range at Tensor-Core speed, used inside matmuls |
| BF16 | 16 | 8 | 7 | FP32 range, coarse precision; no loss scaling needed |
| FP16 | 16 | 5 | 10 | Fine precision, narrow range; needs loss scaling |
| FP8 (E4M3) | 8 | 4 | 3 | Frontier compute format; more precision than E5M2 |
| FP8 (E5M2) | 8 | 5 | 2 | Frontier; more range, used for gradients |
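The exponent and mantissa columns of Table 15.8.1 are enough to derive each format's numeric limits. The sketch below does so under an assumption: IEEE-754-style conventions (exponent bias $2^{e-1}-1$, the top exponent code reserved for infinity and NaN, subnormals supported). The helper name `format_limits` is ours, and the two FP8 rows are left out because E4M3 departs from these conventions to reclaim extra range.

```python
def format_limits(exp_bits: int, man_bits: int) -> dict:
    """Numeric limits of an IEEE-754-style binary float (hypothetical helper)."""
    bias = 2 ** (exp_bits - 1) - 1
    return {
        "max": (2.0 - 2.0 ** -man_bits) * 2.0 ** bias,   # largest finite value
        "min_normal": 2.0 ** (1 - bias),                 # smallest positive normal
        "min_subnormal": 2.0 ** (1 - bias - man_bits),   # smallest positive subnormal
        "epsilon": 2.0 ** -man_bits,                     # gap between 1.0 and the next value
    }

# (exponent bits, mantissa bits) copied from Table 15.8.1
FORMATS = {"FP32": (8, 23), "TF32": (8, 10), "BF16": (8, 7), "FP16": (5, 10)}
limits = {name: format_limits(e, m) for name, (e, m) in FORMATS.items()}

for name, lim in limits.items():
    print(f"{name}: max={lim['max']:.3e}  min_normal={lim['min_normal']:.3e}  "
          f"min_subnormal={lim['min_subnormal']:.3e}  epsilon={lim['epsilon']:.3e}")
```

For FP16 this reproduces the thresholds quoted above (about $6.1 \times 10^{-5}$ and $6 \times 10^{-8}$), and it shows BF16 sharing FP32's smallest normal value while paying for it with a much larger epsilon.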
The relationship between bits and the dynamic range is exponential. With an exponent field of $e$ bits, the representable magnitudes span roughly
$$\text{range} \approx 2^{2^{e}-1}, \qquad \text{so FP16's } e=5 \text{ gives } \sim\!2^{31}, \quad \text{BF16's } e=8 \text{ gives } \sim\!2^{255}.$$

That gulf, $2^{31}$ against $2^{255}$, is exactly why a gradient that comfortably fits in BF16 can silently underflow to zero in FP16. The precision side scales the same way with mantissa bits: $m$ mantissa bits resolve values to a relative step of about $2^{-(m+1)}$, so BF16's 7 bits give coarser steps than FP16's 10. The practical lesson is that you choose the format by which failure mode you fear more. Fear vanishing gradients, and BF16's range saves you with no extra machinery; need to resolve nearby values finely, and FP16's mantissa serves better, provided you add loss scaling to protect its range.
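The two failure modes can be seen on single values. The sketch below uses only the standard library: `struct`'s `"e"` format rounds to IEEE half precision, and BF16 is emulated by rounding the FP32 encoding to its top 16 bits. The helper names `to_fp16` and `to_bf16` are ours, and the BF16 emulation is an assumption-level stand-in for real hardware that does not handle NaN or infinity.

```python
import struct

def to_fp16(x: float) -> float:
    """Round to IEEE half precision and back, using the struct 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x: float) -> float:
    """Round to BF16 by keeping the top 16 bits of the FP32 encoding
    (round-to-nearest-even; NaN and infinity are not handled in this sketch)."""
    bits = int.from_bytes(struct.pack(">f", x), "big")
    bits += 0x7FFF + ((bits >> 16) & 1)   # rounding bias, ties go to even
    bits &= 0xFFFF0000                    # drop the low 16 mantissa bits
    return struct.unpack(">f", bits.to_bytes(4, "big"))[0]

tiny_grad = 1e-8              # a range test: far below FP16's smallest subnormal
near_one = 1.0 + 2.0 ** -9    # a precision test: needs 9 mantissa bits

fp16_tiny, bf16_tiny = to_fp16(tiny_grad), to_bf16(tiny_grad)
fp16_near, bf16_near = to_fp16(near_one), to_bf16(near_one)

print("tiny gradient  FP16:", fp16_tiny, " BF16:", bf16_tiny)
print("1 + 2^-9       FP16:", fp16_near, " BF16:", bf16_near)
```

The tiny gradient vanishes in FP16 but survives in BF16, while the value just above one survives in FP16 but collapses to exactly 1.0 in BF16: range on one side, precision on the other.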
The speed and memory savings of 16-bit arithmetic are a single-node, scale-up benefit, the same family as the quantization and KV-cache paging treated as per-node efficiency in Chapter 22. What makes mixed precision belong in a chapter about distributed training is that the gradient is the thing you communicate. Halving its width halves the all-reduce payload on every single step, and the freed memory lets each worker hold a larger local batch, which raises the compute-to-communication ratio. The technique is scale-up in mechanism but scale-out in payoff.
2. The Mixed-Precision Recipe: Master Weights and FP32 Reductions (Intermediate)
Naively casting an entire training step to 16 bits diverges, because two operations genuinely need the extra bits. The first is the weight update itself. An update adds a tiny step (the learning rate times the gradient) to a comparatively large weight; in 16 bits that small addend rounds away entirely, so the weights stop moving. Mixed-precision training fixes this by keeping a master copy of the weights in FP32 that accumulates every update faithfully, and casting a 16-bit working copy from it for the forward and backward passes. The second is reductions over many terms, such as the sum inside a batch-normalization statistic or a large dot product, where accumulating thousands of 16-bit values compounds rounding error; these accumulate in FP32 even when their inputs are 16-bit. Figure 15.8.1 shows how these pieces fit together with the distributed all-reduce.
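Both failure modes named above can be reproduced in a few lines. This is a minimal sketch using only the standard library, with `struct` round trips standing in for 16-bit and 32-bit storage; the helper names, the step size of `1e-4`, and the loop counts are illustrative assumptions, not values from a real training run.

```python
import struct

def to_fp16(x: float) -> float:
    """Round to IEEE half precision and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    """Round to IEEE single precision and back."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

# --- Failure 1: a small update added to a large weight rounds away in FP16 ---
step = 1e-4                    # learning rate times gradient (hypothetical)
w_half = to_fp16(1.0)          # weight stored in FP16
w_master = to_fp32(1.0)        # FP32 master copy
for _ in range(1000):
    w_half = to_fp16(w_half + step)        # result rounded back to FP16
    w_master = to_fp32(w_master + step)    # result rounded back to FP32

# --- Failure 2: a long sum accumulated in FP16 stops growing ---
acc_half = 0.0
acc_fp32 = 0.0
for _ in range(4096):
    acc_half = to_fp16(acc_half + 1.0)     # 16-bit accumulator
    acc_fp32 = to_fp32(acc_fp32 + 1.0)     # 32-bit accumulator

print("weight after 1000 updates  FP16:", w_half, " FP32 master:", w_master)
print("sum of 4096 ones           FP16:", acc_half, " FP32:", acc_fp32)
```

The FP16 weight never leaves 1.0 because the step is smaller than half the spacing between adjacent FP16 values near one, and the FP16 accumulator stalls at 2048, where that spacing reaches 2. The FP32 copies track the intended results.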
For BF16 the picture simplifies: because its exponent matches FP32, gradients almost never underflow, so the two green scaling boxes in Figure 15.8.1 disappear and the loop is just cast-down, compute, all-reduce, apply. That simplicity is why BF16 has become the default on hardware that supports it. FP16 remains valuable on hardware without fast BF16, and on workloads where its three extra mantissa bits matter, but it carries the loss-scaling obligation we turn to next.
3. Loss Scaling: Keeping Small Gradients Alive in FP16 (Intermediate)
The narrow range of FP16 creates a specific, measurable failure: gradients whose magnitude falls below about $3 \times 10^{-8}$, half of FP16's smallest subnormal value ($2^{-24} \approx 6 \times 10^{-8}$), become exactly zero when cast to half precision, so the corresponding weights receive no update and the model fails to learn from them. Because backpropagated gradients in deep networks are routinely this small, especially in early layers, this loss of signal is real, not hypothetical. Loss scaling fixes it with a multiplication. Before the backward pass, multiply the loss by a large constant $S$ (a power of two such as $1024$ or $65536$); by the chain rule every gradient is then scaled by the same $S$, shifting the whole distribution up into FP16's representable range. After the all-reduce and before the optimizer step, divide the gradients back by $S$ in FP32. The update is mathematically unchanged; only the intermediate representation was protected.
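Written out, the claim that the update is unchanged is one line of the chain rule. In the sketch below, $L$ is the loss, $w$ a weight, $g = \partial L / \partial w$ its gradient, and $\eta$ the learning rate. These symbols are our notation, assumed for illustration, and a plain SGD step stands in for whatever optimizer is actually used.

$$\frac{\partial (S L)}{\partial w} = S\,\frac{\partial L}{\partial w} = S g, \qquad w \leftarrow w - \eta \cdot \frac{1}{S}\,(S g) = w - \eta\, g.$$

As a worked case, a gradient $g = 10^{-8}$ casts to zero in FP16, but with $S = 1024$ the value actually stored is $S g \approx 1.02 \times 10^{-5}$, which FP16 can represent as a subnormal. Dividing by $S$ in FP32 then returns approximately $10^{-8}$.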
The code below demonstrates both halves of this section in pure NumPy. It first shows what fraction of a batch of tiny gradients vanishes under a naive FP16 cast and how a scale of $S = 1024$ rescues every one of them, then quantifies the memory and communication halving of 16-bit versus 32-bit on a billion-parameter gradient.
```python
import numpy as np

# ---- Part 1: FP16 gradient underflow and how loss scaling recovers it ----
# FP16's smallest positive normal is ~6.1e-5 and its smallest subnormal ~6e-8;
# values below about half of that (~3e-8) round to 0.
rng = np.random.default_rng(0)

# A batch of realistic tiny gradient magnitudes (e.g. deep-net early layers).
grads_fp32 = (rng.uniform(1e-9, 2e-7, size=20000)).astype(np.float32)

naive_fp16 = grads_fp32.astype(np.float16)                  # cast straight to half
underflowed = np.sum(np.asarray(naive_fp16) == 0.0)         # how many vanished?

S = 1024.0                                                  # loss-scale factor
scaled_fp16 = (grads_fp32 * S).astype(np.float16)           # scale UP, then cast
recovered = np.asarray(scaled_fp16).astype(np.float32) / S  # unscale in FP32
underflowed_scaled = np.sum(np.asarray(scaled_fp16) == 0.0)

mask = grads_fp32 > 0
rel_err = np.mean(np.abs(recovered[mask] - grads_fp32[mask]) / grads_fp32[mask])

print("underflowed to 0 without scaling:", int(underflowed),
      f"({100*underflowed/grads_fp32.size:.1f}%)")
print("underflowed to 0 with S=1024 :", int(underflowed_scaled),
      f"({100*underflowed_scaled/grads_fp32.size:.1f}%)")
print("mean rel. error after recovery :", f"{rel_err:.2e}")

# ---- Part 2: memory and all-reduce volume, 16-bit vs 32-bit ----
P = 1_000_000_000                                           # 1-billion-param gradient
print("FP32 all-reduce payload (GB) :", f"{P*4/1e9:.2f}")
print("FP16 all-reduce payload (GB) :", f"{P*2/1e9:.2f}")
print("communication volume reduction :", f"{(P*4)/(P*2):.1f}x")
print("local memory saved (GB) :", f"{(P*4-P*2)/1e9:.2f}")
```
```text
underflowed to 0 without scaling: 2882 (14.4%)
underflowed to 0 with S=1024 : 0 (0.0%)
mean rel. error after recovery : 4.22e-04
FP32 all-reduce payload (GB) : 4.00
FP16 all-reduce payload (GB) : 2.00
communication volume reduction : 2.0x
local memory saved (GB) : 2.00
```
The two results in Output 15.8.1 are the whole argument of this section in numbers. The first block is the stability problem and its fix: a 14.4% loss of gradient signal, erased by one multiplication. The second block is the distributed payoff: every all-reduce that this chapter's collectives perform moves half the bytes, and every worker frees 2 GB it can spend on a larger batch. In production the scale factor is not fixed; a dynamic loss scaler raises $S$ when no overflow is seen for a while and halves it the moment a gradient becomes infinite, automatically tracking the largest safe value. We reproduce that controller in the exercises.
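The update rule of such a controller is small enough to sketch here; the drifting-gradient simulation and the plot are left to the exercise. The class name `DynamicLossScaler`, its field names, and the tiny demo sequence are hypothetical. Only the rule itself (halve and skip on overflow, double after a run of clean steps) comes from the description above.

```python
import math
from dataclasses import dataclass

@dataclass
class DynamicLossScaler:
    """Hypothetical minimal controller for the loss scale S."""
    scale: float = 2.0 ** 15       # starting value of S
    growth_interval: int = 2000    # clean steps required before doubling S
    clean_steps: int = 0           # consecutive steps without overflow

    def update(self, scaled_grads) -> bool:
        """Inspect the scaled gradients; return True if the step should be applied."""
        if any(not math.isfinite(g) for g in scaled_grads):
            self.scale /= 2.0      # overflow: back off and skip this update
            self.clean_steps = 0
            return False
        self.clean_steps += 1
        if self.clean_steps >= self.growth_interval:
            self.scale *= 2.0      # long clean run: probe a larger scale
            self.clean_steps = 0
        return True

# Tiny demo with a short growth interval so the doubling is visible.
scaler = DynamicLossScaler(growth_interval=3)
history = []
for grads in ([0.1, 0.2], [float("inf"), 0.1], [0.1], [0.2], [0.3]):
    applied = scaler.update(grads)
    history.append((applied, scaler.scale))
print(history)
```

The second batch contains an infinity, so its step is skipped and $S$ halves; three clean batches later $S$ doubles back.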
Loss scaling multiplies and later divides by the same $S$, so why choose a power of two like $1024$ rather than, say, $1000$? Because multiplying a float by a power of two only changes its exponent, never its mantissa: barring overflow or underflow, the operation is exact, with zero rounding error introduced by the scaling itself. The only rounding that happens is the unavoidable cast into FP16. Pick $1000$ and you would pay a tiny tax twice for nothing. The numerics quietly reward you for respecting binary.
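The exactness claim is easy to check on every FP16 value between 1 and 2. In this sketch the `struct` round trip stands in for FP16 storage and Python's double-precision floats stand in for the FP32 unscale; the helper names are ours.

```python
import struct

def to_fp16(x: float) -> float:
    """Round to IEEE half precision and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

# Every FP16 value in [1, 2): 1 + k/1024 for k = 0..1023.
values = [1.0 + k / 1024 for k in range(1024)]

def inexact_round_trips(scale: float) -> int:
    """Count values that do not come back unchanged after scale, cast, unscale."""
    return sum(1 for v in values if to_fp16(v * scale) / scale != v)

inexact_pow2 = inexact_round_trips(1024.0)
inexact_1000 = inexact_round_trips(1000.0)

print("values changed by S=1024:", inexact_pow2)
print("values changed by S=1000:", inexact_1000)
```

With $S = 1024$ every value returns bit-for-bit, because the scaled values are integers that FP16 still represents exactly. With $S = 1000$ the scaled values generally fall between FP16 grid points, so the cast rounds them and the division cannot undo it.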
4. Why This Matters for Distribution (Intermediate)
The connection to the rest of the chapter runs through two quantities that Chapter 3 taught us to track: bytes moved and bytes resident. On the communication side, the gradient all-reduce of Section 4.1 moves a volume proportional to the gradient's byte size. Halving the width of every element halves that volume directly, which in the alpha-beta cost model of Section 3.7 halves the bandwidth term of every step. For a job whose step time is dominated by that bandwidth term, this comes close to doubling throughput at no hardware cost. On the memory side, the activations cached for backprop, the working copy of the weights, and the gradients shrink to 16 bits; the FP32 master weights and the optimizer state do not, so the saving is largest where activations dominate. The freed memory buys a larger local batch. A larger local batch means more compute per all-reduce, raising the compute-to-communication ratio that determines whether adding workers helps or merely adds overhead, the central question of Section 15.9. Mixed precision thus pushes on both levers that govern scaling at once.
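The communication half of that argument can be put into numbers with the cost model $T_\text{comm} = \alpha + \beta B$. Every constant in the sketch below is a hypothetical placeholder chosen for illustration (a 0.5 s compute phase, 50 microseconds of latency, a 100 Gb/s link), the function name `step_profile` is ours, and the model ignores both the all-reduce algorithm's constant factors and any overlap of communication with computation.

```python
def step_profile(n_params, bytes_per_elem, t_comp, alpha, beta):
    """Return (payload bytes, comm seconds, fraction of the step spent communicating)."""
    payload = n_params * bytes_per_elem
    t_comm = alpha + beta * payload          # alpha-beta cost of one all-reduce
    comm_fraction = t_comm / (t_comp + t_comm)
    return payload, t_comm, comm_fraction

N_PARAMS = 1_000_000_000   # one-billion-parameter gradient
T_COMP = 0.5               # seconds of compute per step (hypothetical)
ALPHA = 50e-6              # fixed latency per all-reduce, seconds (hypothetical)
BETA = 1 / 12.5e9          # seconds per byte on a 100 Gb/s link (hypothetical)

fp32 = step_profile(N_PARAMS, 4, T_COMP, ALPHA, BETA)
half = step_profile(N_PARAMS, 2, T_COMP, ALPHA, BETA)

for label, (payload, t_comm, frac) in (("FP32  ", fp32), ("16-bit", half)):
    print(f"{label} payload={payload/1e9:.1f} GB  t_comm={t_comm:.3f} s  "
          f"comm share of step={100*frac:.1f}%")
```

Only the bandwidth term shrinks: the latency $\alpha$ is paid in full either way, which is why the benefit fades when messages are small enough for $\alpha$ to dominate.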
The manual recipe of Figure 15.8.1 (casting a working copy, choosing where to keep FP32, scaling the loss, and unscaling before the step) is provided by PyTorch's automatic mixed precision in two objects. torch.autocast picks the right precision per operation (16-bit matmuls, FP32 reductions) inside its context, and GradScaler runs a dynamic loss scaler with overflow detection. The roughly thirty lines of bookkeeping collapse to a wrapper and three calls, and it composes with DistributedDataParallel. One caveat: under autocast the parameters stay in FP32, so their gradients do too, and the all-reduce carries 16-bit gradients only if they are cast before communication, for example with one of DDP's gradient-compression communication hooks (fp16_compress_hook or bf16_compress_hook).
```python
# model already wrapped in DistributedDataParallel; per worker:
import torch

scaler = torch.amp.GradScaler("cuda")                     # dynamic loss scaler
for x, y in loader:
    optimizer.zero_grad()
    with torch.autocast("cuda", dtype=torch.bfloat16):    # or torch.float16
        loss = loss_fn(model(x), y)                       # forward in 16-bit
    scaler.scale(loss).backward()                         # scale up, backward, DDP all-reduce
    scaler.step(optimizer)                                # unscale in FP32, then step
    scaler.update()                                       # adjust S for next iteration
```
Code 15.8.2: torch.autocast plus GradScaler. With dtype=torch.bfloat16 the scaler is effectively a no-op because BF16's range removes the underflow; with torch.float16 it does the real work shown in Output 15.8.1.

Who: A deep-learning platform engineer scaling a 1.3-billion-parameter language model across 32 GPUs on a commodity 100 Gb/s Ethernet cluster.
Situation: Training ran in FP32 and spent nearly 40% of each step waiting on the gradient all-reduce; adding more GPUs made the wait worse, not the throughput better.
Problem: The interconnect was the binding ceiling. A 1.3-billion-parameter FP32 gradient is a 5.2 GB all-reduce every step, and the network could not move it fast enough to keep the accelerators busy.
Dilemma: Buy faster interconnect (a capital scale-up nobody had budget for) or shrink the bytes on the wire without changing the math, accepting the small risk of numerical instability.
Decision: Switch to BF16 mixed precision, because the hardware supported it and BF16's FP32-matching range meant no loss scaling and no tuning to get wrong.
How: They wrapped the forward in torch.autocast("cuda", dtype=torch.bfloat16) as in Code 15.8.2, kept the DistributedDataParallel wrapper in place, and let the all-reduce carry 16-bit gradients.
Result: The all-reduce payload fell from 5.2 GB to 2.6 GB per step, communication wait dropped below 20%, and the freed GPU memory let them double the local batch, which raised the compute-to-communication ratio further. End-to-end throughput rose about 1.8x with no measurable accuracy change.
Lesson: When communication binds, the cheapest lever is often the width of the numbers, not the speed of the network. A per-node format choice fixed a cluster-level bottleneck.
5. The FP8 Frontier (Advanced)
If 16 bits halve the cost of 32, the obvious next question is whether 8 bits halve it again, and the 2024-2026 research and hardware frontier is answering yes for an expanding fraction of the training step. The difficulty is that 8 bits leave almost no room: E4M3 resolves values to a handful of levels and E5M2 covers range with only two mantissa bits, so a single global scale factor that worked for FP16 is far too coarse. The breakthrough has been finer-grained scaling, and we look at where that is heading next.
NVIDIA's Hopper and Blackwell Tensor Cores and the Transformer Engine library brought FP8 (E4M3 and E5M2) into mainstream training, and large public runs now use it: DeepSeek-V3 (2024) trained a 671-billion-parameter mixture-of-experts model substantially in FP8, reporting the communication and memory savings as central to its cost. The key enabler is fine-grained, per-block scaling rather than one scale per tensor. The Open Compute Project's Microscaling (MX) formats, standardized in 2023 and adopted across vendors, attach a shared scale to small blocks of 32 elements (MXFP8, MXFP6, MXFP4), so each block lands in range independently; Blackwell hardware decodes these natively. Active questions for 2025-2026 are how low the gradient all-reduce itself can go (FP8 and even FP6 gradient compression), how to keep the optimizer state stable at these widths, and whether MXFP4 can train, not merely serve, without accuracy loss. The trajectory is clear: each halving of element width is a halving of the bytes this chapter's collectives must move.
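Why per-block scaling matters can be shown with a toy quantizer. The sketch below is a simplification under stated assumptions, not the MX specification: `quantize_e4m3` rounds to a grid with 3 mantissa bits, a smallest normal exponent of $-6$, and saturation at 448 (the E4M3 maximum); the shared scale is a power of two chosen from the largest magnitude in the group; and the data is a contrived tensor of two 32-element blocks whose magnitudes differ by a factor of a million. All function names are ours.

```python
import math

E4M3_MAX = 448.0   # largest finite E4M3 value
BLOCK = 32         # elements sharing one scale, as in the MX formats

def quantize_e4m3(x: float) -> float:
    """Round x to a simplified E4M3 grid (3 mantissa bits, saturating)."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), E4M3_MAX)                 # saturate instead of overflowing
    exp = max(math.frexp(a)[1] - 1, -6)       # clamp exponent: subnormal region
    step = 2.0 ** (exp - 3)                   # grid spacing at this exponent
    return math.copysign(round(a / step) * step, x)

def pow2_scale(values) -> float:
    """Largest power-of-two scale that keeps the group's maximum within range."""
    amax = max(abs(v) for v in values)
    return 2.0 ** math.floor(math.log2(E4M3_MAX / amax))

def quantize_with_scale(values, scale):
    """Scale up, quantize to the 8-bit grid, scale back down."""
    return [quantize_e4m3(v * scale) / scale for v in values]

def mean_rel_error(original, restored) -> float:
    return sum(abs(r - o) / abs(o) for o, r in zip(original, restored)) / len(original)

big = [(i + 1) / BLOCK for i in range(BLOCK)]   # magnitudes up to 1
small = [v * 1e-6 for v in big]                 # same shape, a million times smaller
tensor = big + small

per_tensor = quantize_with_scale(tensor, pow2_scale(tensor))   # one global scale
per_block = []
for start in range(0, len(tensor), BLOCK):                     # one scale per block
    block = tensor[start:start + BLOCK]
    per_block += quantize_with_scale(block, pow2_scale(block))

err_tensor = mean_rel_error(tensor, per_tensor)
err_block = mean_rel_error(tensor, per_block)
print(f"mean relative error, one scale per tensor: {err_tensor:.3f}")
print(f"mean relative error, one scale per block : {err_block:.3f}")
```

With a single scale the small block is flushed to zero entirely, so half the tensor loses all its signal. With a scale per block, both blocks land in the format's usable range and the error is bounded by the 3-bit mantissa alone.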
We have now placed mixed precision where it belongs: a per-node, scale-up technique whose payoff is felt out on the network and in the memory budget of every worker, sharpening the very scaling efficiency that the collectives and parallelism strategies aim for. With the gradient now half its former size on the wire, the next section confronts what is left: the practical bottlenecks, stragglers, and diminishing returns that decide how far data parallelism actually scales, in Section 15.9.
Using Table 15.8.1, explain in two or three sentences why BF16 needs no loss scaling while FP16 does, referring to the exponent-bit counts and the resulting dynamic ranges. Then describe a workload where FP16's extra mantissa bits (10 versus BF16's 7) would matter enough to justify the loss-scaling machinery, and one where BF16 is the clear choice. Why does the field default to BF16 on hardware that supports it?
Extend Code 15.8.1 into a dynamic loss scaler. Start with $S = 2^{15}$. On each simulated step, cast the scaled gradients to FP16 and check for overflow (any non-finite value); if overflow occurs, halve $S$ and skip the update, otherwise apply the update and, after a fixed number of clean steps (say 2000), double $S$. Run it on a stream of gradient batches whose magnitudes drift over time and plot $S$ against the step index. Explain how this controller automatically finds the largest safe scale without any manual tuning, and what happens if you make the growth interval too short.
A data-parallel step spends $T_\text{comp}$ on computation and $T_\text{comm} = \alpha + \beta \cdot B$ on the gradient all-reduce, where $B$ is the gradient byte size and $\beta$ is the inverse bandwidth from Section 3.7. Suppose FP32 training is communication-bound with $T_\text{comm} = 2\,T_\text{comp}$. Switching to BF16 halves $B$ (so halves the $\beta B$ term) and, by freeing memory, lets you double the local batch (so doubles $T_\text{comp}$ and the work done per step). Estimate the new ratio of communication to computation and the change in useful throughput per step, treating $\alpha$ as negligible. Which of the two effects, the smaller payload or the larger batch, helps more here, and how would your answer change if $\alpha$ dominated $\beta B$?