"They asked the central server to hold everyone's gradient. The central server, being one machine, politely explained that it had a bandwidth, and that bandwidth had a breaking point."
An All-Reduce That Has Seen Some Gradients
Synchronous SGD computes the same update on every worker only if their partial gradients are combined correctly, and the way that combination is wired into the network decides whether training scales to thousands of accelerators or chokes on a single bottleneck. The previous sections established that the data-parallel gradient is an exact sum of per-shard pieces. This section is about the combine step itself: whether you sum or average the pieces (and what that does to the learning rate), whether the pieces flow through one central server or through a symmetric all-reduce with no center at all, and how a single all-reduce can be made to cost almost nothing by hiding it behind the backward pass. It also shows how gradient accumulation lets one worker pretend to hold a much larger batch than its memory allows, while cutting how often the network has to run.
By the time a distributed training step reaches its combine phase, every worker is holding a vector of the same shape: its local gradient, computed on its own shard of the mini-batch. The arithmetic that fuses these vectors into one shared update is simple, but the systems question hiding behind it is not. Who adds the vectors together? Does the result pass through a coordinator, or do the workers exchange directly? How many bytes cross the network, and can any of that traffic happen while the accelerators are still busy computing? Those questions separate a training run that doubles its throughput when you double the workers from one that adds workers and gets slower. This section answers them for the pattern that became the default for deep learning: all-reduce SGD.
1. Sum or Mean, and the Learning Rate That Comes With It Beginner
Each worker $k$ holds a partial gradient computed on its local batch $\mathcal{B}_k$ of size $b_k$. There are two natural ways to define what the workers compute and then how to fuse the results, and confusing them is the most common source of a learning rate that is silently off by a factor of $K$. The first convention has each worker return an unnormalized sum of per-example gradients, and the combine step adds those sums and divides once by the global batch size $B = \sum_k b_k$:
$$g_{\text{sum}} = \sum_{k=1}^{K} \sum_{i \in \mathcal{B}_k} \nabla \ell(w; x_i, y_i), \qquad \bar{g} = \frac{1}{B}\, g_{\text{sum}}.$$
The second convention has each worker return its local mean gradient, and the combine step averages those means. When every worker has the same batch size $b_k = B/K$, the average of the means equals the global mean, so the two conventions agree:
$$\bar{g} = \frac{1}{K} \sum_{k=1}^{K} \left( \frac{1}{b_k} \sum_{i \in \mathcal{B}_k} \nabla \ell(w; x_i, y_i) \right) \quad \text{(equal-size case only)}.$$
The practical rule is that the update $w \leftarrow w - \eta\, \bar{g}$ should always be driven by the mean gradient $\bar{g}$ over the global batch, because the learning rate $\eta$ you tuned was tuned against a mean. If you accidentally feed it the sum $g_{\text{sum}}$, you have multiplied your effective learning rate by $B$, and training will diverge on the first step. The collective that frameworks expose, $\texttt{all\_reduce}$ with a SUM reduction, returns the sum; turning it into a mean is your one extra division, and forgetting it is a rite of passage. The unequal-size case matters too: the mean-of-means in the second equation is wrong when workers receive different batch sizes, exactly the failure that Chapter 4 warns about for weighted collectives, and only the sum-then-divide-by-$B$ form stays correct.
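A toy calculation makes the unequal-shard failure concrete. The sketch below uses made-up scalar gradients on two hypothetical workers, one holding a single example and the other holding three; it only illustrates the arithmetic and is not a distributed program.

```python
import numpy as np

# Hypothetical per-example gradients (scalars for readability) on two workers
# with unequal shards: worker 0 holds 1 example, worker 1 holds 3.
shards = [np.array([4.0]), np.array([0.0, 0.0, 0.0])]
B = sum(len(s) for s in shards)          # global batch size = 4
K = len(shards)                          # worker count = 2

g_sum = sum(s.sum() for s in shards)     # what an all-reduce with SUM returns
sum_then_divide = g_sum / B              # true global mean: 4 / 4
mean_of_means = np.mean([s.mean() for s in shards])  # (4/1 + 0/3) / 2

print("global mean (sum / B):", sum_then_divide)
print("mean of local means  :", mean_of_means)
print("undivided sum        :", g_sum, "(B times the mean)")
```

The sum-then-divide form gives 1.0, the true mean over all four examples, while the mean of means gives 2.0 because the lone example on worker 0 is weighted three times as heavily as each example on worker 1. Feeding the undivided sum 4.0 to the optimizer would scale the step by $B = 4$.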
Hardware collectives reduce with SUM because addition is associative and order-free, which is what lets a ring or tree combine partial results without a coordinator. The mean that your optimizer actually wants is your responsibility: when each worker contributes a sum of per-example gradients, divide the all-reduced result by the global batch size $B$; dividing by the worker count $K$ is correct only when each worker contributes its local mean and every worker carries an identical batch. Tie the learning rate to the mean gradient and the sum-versus-mean ambiguity disappears; ignore it and you have a hidden factor of $K$ (or $B$) waiting to blow up the first step.
2. All-Reduce SGD versus the Parameter Server Intermediate
There are two ways to wire the combine step across machines, and they have opposite shapes. The parameter-server pattern is centralized and asymmetric: every worker pushes its local gradient to one or more dedicated server nodes, the servers add the gradients, apply the update, and the workers pull back the fresh weights. The all-reduce pattern is decentralized and symmetric: there is no server, every worker plays the identical role, and the same collective leaves every worker holding the full summed gradient simultaneously. We treat the parameter-server design in depth in Chapter 11; here we contrast it with all-reduce to explain why all-reduce became the default for dense deep-learning models. Figure 10.5.1 shows the two topologies side by side.
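A single-process sketch of one synchronous parameter-server step makes the asymmetry visible. The function name is hypothetical, it assumes a single server rather than several, and each worker pushes the sum of its per-example gradients so the server can divide by the global batch size.

```python
import numpy as np

def parameter_server_step(w, worker_grad_sums, lr, global_batch):
    """One synchronous parameter-server step, simulated in a single process.

    Each worker pushes the SUM of its per-example gradients; the server adds
    them, divides by the global batch size, updates w, and every worker pulls
    the new weights. Also returns how many elements reached the server.
    """
    total = np.zeros_like(w)
    inbound = 0
    for g in worker_grad_sums:             # push: every gradient crosses the server's link
        total += g
        inbound += g.size
    w_new = w - lr * total / global_batch  # the server applies the mean-gradient update
    pulled = [w_new.copy() for _ in worker_grad_sums]  # pull: each worker fetches w_new
    return pulled, inbound

rng = np.random.default_rng(0)
K, P, B = 8, 1_000, 256
w = np.zeros(P)
grads = [rng.standard_normal(P) for _ in range(K)]
pulled, inbound = parameter_server_step(w, grads, lr=0.1, global_batch=B)
print(inbound)  # 8000 = K * P elements into the server on every step
```

The pull phase sends another $KP$ elements back out of the same server, so both directions of the central link scale with the number of workers.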
The decisive difference is bandwidth at the center. In the parameter-server picture, the server's network link must absorb a gradient from every worker on each step, so a single server receives $KP$ elements per step for a $P$-element gradient: traffic into the center grows linearly with $K$, and the central link saturates first. Ring all-reduce, the algorithm built in Section 4.4, has no center: the $P$-element gradient is split into $K$ chunks that circulate around the ring, and each worker sends and receives only $\frac{P}{K}$ elements per hop. Over the full reduce-scatter-then-all-gather schedule of $2(K-1)$ hops, every worker sends about $2\frac{K-1}{K}P$ elements, which approaches $2P$ from below and is nearly independent of $K$, so the per-worker bandwidth cost is essentially constant in the cluster size. That bandwidth optimality, not any change in the math, is why dense deep-learning training abandoned the central server for symmetric all-reduce. The same arc is flagged in Section 1.1: the all-reduce you perform by hand returns, scaled out, as the engine of every data-parallel trainer.
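The chunked schedule can be checked with a small simulation. The sketch below is a toy in which one process plays all $K$ workers and each hop is modeled as a simultaneous exchange with the right-hand neighbour, but it runs the reduce-scatter-then-all-gather pattern described above, verifies that every worker ends with the full sum, and counts the elements each worker sends. The function name is hypothetical.

```python
import numpy as np

def ring_allreduce(grads):
    """Simulate ring all-reduce (reduce-scatter, then all-gather) among K workers.

    grads is a list of K equal-length 1-D arrays, one per worker. Returns each
    worker's final buffer and the number of elements each worker sent.
    """
    K, P = len(grads), grads[0].size
    bufs = [np.array(g, dtype=float) for g in grads]         # private copy per worker
    sizes = [P // K + (1 if c < P % K else 0) for c in range(K)]
    offs = np.concatenate(([0], np.cumsum(sizes)))
    chunk = [slice(offs[c], offs[c + 1]) for c in range(K)]  # K chunks of ~P/K elements
    sent = [0] * K

    # Reduce-scatter: at hop s, worker k passes chunk (k - s) % K to its right
    # neighbour, which adds it into its own copy. After K-1 hops, worker k holds
    # the complete sum for chunk (k + 1) % K.
    for s in range(K - 1):
        msgs = [(k, (k - s) % K) for k in range(K)]
        payload = [bufs[k][chunk[c]].copy() for k, c in msgs]  # all sends happen at once
        for (k, c), data in zip(msgs, payload):
            bufs[(k + 1) % K][chunk[c]] += data
            sent[k] += data.size

    # All-gather: at hop s, worker k passes the finished chunk (k + 1 - s) % K
    # to its right neighbour, which overwrites its stale copy.
    for s in range(K - 1):
        msgs = [(k, (k + 1 - s) % K) for k in range(K)]
        payload = [bufs[k][chunk[c]].copy() for k, c in msgs]
        for (k, c), data in zip(msgs, payload):
            bufs[(k + 1) % K][chunk[c]] = data
            sent[k] += data.size
    return bufs, sent

rng = np.random.default_rng(0)
K, P = 8, 1_000
grads = [rng.standard_normal(P) for _ in range(K)]
bufs, sent = ring_allreduce(grads)
print(sent[0], 2 * (K - 1) * P // K)  # 1750 1750: each worker sends 2(K-1)/K * P

# Per-worker ring traffic stays below 2P as the ring grows, while a single
# parameter server's ingress would be K * P.
for k_workers in (2, 4, 8, 16):
    _, sent_k = ring_allreduce([np.ones(P) for _ in range(k_workers)])
    print(k_workers, "ring, per worker:", max(sent_k), "| server ingress:", k_workers * P)
```

Because every finished chunk is copied rather than re-summed during the all-gather, all workers end with bit-identical buffers, which is what lets them apply identical optimizer updates.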
A parameter server is the office printer of distributed training: perfectly fine when three people use it, a riot when three hundred queue up at 9 a.m. Ring all-reduce is the realization that nobody actually needs a printer in the middle if everyone just hands their page to the person on their left.
3. Gradient Accumulation: A Bigger Batch Than Memory Allows Intermediate
Suppose the batch size you want will not fit in accelerator memory, because the activations for that many examples overflow the device. Gradient accumulation buys you the larger batch without the memory: split the per-worker batch into $m$ smaller micro-batches, run forward and backward on each one in turn, and add its gradient into an accumulator buffer instead of stepping the optimizer. After $m$ micro-batches the accumulator holds the sum of all of them, which equals, up to floating-point rounding, the gradient you would have computed had the full batch fit at once. Only then do you run the all-reduce and the optimizer update. Because activations for one micro-batch are freed before the next begins, peak memory is set by the micro-batch, not by the effective batch.
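A single-worker sketch of the accumulator follows, assuming a toy linear least-squares model in place of a network (so there are no real activations to free). Each micro-batch's summed gradient is added into a buffer, and the buffer is compared with the full-batch gradient.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((64, 5))   # one worker's batch: 64 toy examples
y = rng.standard_normal(64)
w = rng.standard_normal(5)

def grad_sum(w, Xb, yb):
    """Sum (not mean) of per-example squared-error gradients on one micro-batch."""
    return 2.0 * Xb.T @ (Xb @ w - yb)

m = 8                                  # micro-batches per optimizer step
acc = np.zeros_like(w)                 # accumulator buffer, same shape as the gradient
for idx in np.array_split(np.arange(len(y)), m):
    # In a real network only this micro-batch's activations would be live here.
    acc += grad_sum(w, X[idx], y[idx])
full = grad_sum(w, X, y)               # the gradient if the whole batch had fit
print(np.max(np.abs(acc - full)))      # rounding-level difference only
```

Only after the loop would a data-parallel worker all-reduce acc and step the optimizer.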
There is a second payoff that is squarely a scale-out concern: the all-reduce now runs once per $m$ micro-batches instead of once per micro-batch, so the communication frequency drops by a factor of $m$. You are amortizing one fixed-cost network exchange over $m$ local computations, which is the same lever that local-update methods pull (and that Section 10.7 develops into full communication-efficient optimization). Figure 10.5.2 lays the timeline out.
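The amortization can be put into a deliberately crude cost model, assuming no overlap between compute and communication and a fixed all-reduce time per exchange. The function name and the millisecond figures are hypothetical, chosen only to show the shape of the saving.

```python
def step_time(m, t_micro, t_allreduce, accumulate):
    """Toy wall time for m micro-batches of work, with no compute/comm overlap.

    With accumulation a single all-reduce covers all m micro-batches;
    without it, every micro-batch pays for its own all-reduce.
    """
    n_allreduce = 1 if accumulate else m
    return m * t_micro + n_allreduce * t_allreduce

# Hypothetical timings in milliseconds, for illustration only.
for m in (1, 4, 8):
    print(m, step_time(m, 50.0, 40.0, accumulate=False),
          step_time(m, 50.0, 40.0, accumulate=True))
```

With these illustrative numbers, eight micro-batches cost 720 ms when each one pays for its own all-reduce and 440 ms when a single all-reduce is amortized across all eight; the gap widens as the all-reduce grows relative to the micro-batch compute.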
Who: A research engineer fine-tuning a vision-language model on a four-GPU workstation with modest per-card memory.
Situation: The recipe from the model's authors specified an effective batch of 512, but each GPU could fit only 16 examples before running out of memory on activations.
Problem: Four GPUs at 16 examples each give a batch of 64, eight times too small; the learning-rate schedule tuned for 512 was unstable at 64, and a bigger machine was not in budget.
Dilemma: Retune the entire schedule for the small batch (slow, and it drifts from the published result), rent larger accelerators (over budget), or simulate the large batch with gradient accumulation (free, but the loop must be written so the optimizer steps only after accumulation).
Decision: They used gradient accumulation with $m = 8$ micro-batches per GPU, recovering the effective batch of $4 \times 16 \times 8 = 512$ without new hardware.
How: They divided each step into eight micro-batches, summed the gradients locally, ran a single all-reduce after the eighth, then stepped the optimizer; the only subtlety was dividing by the true global count $512$, not by the micro-batch count.
Result: Loss curves matched the published run, and because the all-reduce fired once per eight micro-batches rather than after every micro-batch, network time per effective step of 512 examples dropped by roughly eight times, making the small cluster competitive.
Lesson: Gradient accumulation converts a memory ceiling into extra sequential compute and, as a bonus, cuts communication frequency; it is the first thing to reach for when the target batch will not fit.
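The case study's bookkeeping fits in a few lines. This sketch, with a hypothetical helper name, picks $m$ from the target batch and the per-GPU capacity and records the global count that the all-reduced sum should be divided by; it assumes the target batch is an exact multiple of the per-step batch.

```python
def micro_batches_needed(target_batch, n_gpus, per_gpu):
    """Accumulation steps m such that n_gpus * per_gpu * m == target_batch."""
    per_step = n_gpus * per_gpu
    if target_batch % per_step:
        raise ValueError("target batch must be a multiple of n_gpus * per_gpu")
    return target_batch // per_step

m = micro_batches_needed(512, n_gpus=4, per_gpu=16)
divisor = 4 * 16 * m   # divide the all-reduced SUM by this global count, not by m
print(m, divisor)      # 8 512
```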
4. The Combine Step in Code, and Why It Equals One Big Batch Intermediate
The claim that ties this section together is that mean-aggregated all-reduce SGD with gradient accumulation over $m$ micro-batches produces the identical update to a single step over one large batch, up to floating-point rounding. The code below verifies it directly. It performs one big-batch step over all $N$ examples as the reference, then reconstructs the same step the distributed way: $K$ workers, each splitting its rows into $m$ micro-batches, summing the micro-batch gradients locally, all-reducing the per-worker sums, dividing by the global count $B = N$, and taking one update. It also confirms the sum-then-divide form and the mean-of-means form agree when shards are equal.
```python
import numpy as np

rng = np.random.default_rng(7)
N, d = 4096, 12
X = rng.standard_normal((N, d))
w_true = rng.standard_normal(d)
y = X @ w_true + 0.1 * rng.standard_normal(N)
w0 = np.zeros(d)
lr = 0.05

def grad_sum(w, idx):
    # SUM (not mean) of per-example squared-error gradients on a subset of rows
    r = X[idx] @ w - y[idx]
    return 2.0 * (X[idx].T @ r)

# Reference: one big-batch step over ALL N examples, mean gradient.
big = w0 - lr * (grad_sum(w0, np.arange(N)) / N)

# All-reduce SGD with gradient accumulation: K workers, each with N/K rows;
# every worker splits its rows into m micro-batches, SUMS them locally,
# then ONE all-reduce (SUM) over workers, divide by the total count N, one update.
K, m = 8, 4
worker_rows = np.array_split(np.arange(N), K)
partials = []
for rows in worker_rows:
    micro = np.array_split(rows, m)
    local = sum(grad_sum(w0, mb) for mb in micro)  # accumulate m micro-batches
    partials.append(local)
allreduced = np.sum(partials, axis=0)  # all-reduce SUM across workers
acc = w0 - lr * (allreduced / N)       # mean-normalize, one update

print("examples N :", N)
print("workers K, micro m :", K, m)
print("max |w_acc - w_big| :", f"{np.max(np.abs(acc - big)):.2e}")
print("rel error of the step :", f"{np.linalg.norm(acc - big) / np.linalg.norm(big):.2e}")

# sum-then-divide vs mean-of-means: identical only because shards are equal-size.
sum_step = w0 - (lr / N) * allreduced
# Mean of means: each worker's sum divided by its own row count, then averaged.
mean_step = w0 - lr * np.mean([p / len(rows) for p, rows in zip(partials, worker_rows)], axis=0)
print("sum/N vs mean-of-means :", f"{np.max(np.abs(sum_step - mean_step)):.2e}")
```
The big-batch reference big and the accumulate-then-all-reduce update acc are compared element by element; the final line checks the sum-then-divide and mean-of-means conventions against each other.

```
examples N : 4096
workers K, micro m : 8 4
max |w_acc - w_big| : 2.50e-16
rel error of the step : 1.56e-15
sum/N vs mean-of-means : 0.00e+00
```
The relative error of $1.56 \times 10^{-15}$ is rounding, not approximation. Splitting one batch across eight workers, and each worker's work across four micro-batches, changed the resulting weight update by nothing that matters. That exactness is what licenses the entire engineering effort around all-reduce: the systems work goes into making the combine step cheap and reliable, never into recovering accuracy that the partitioning supposedly lost. The zero on the final line is a reminder that the sum-versus-mean conventions are interchangeable only under equal shards; treat them as identical on an unbalanced cluster and the mean-of-means silently overweights the examples on smaller shards, so the update no longer matches the big-batch step.
This section advances the book's spine by separating the math of the combine step (an exact, order-free sum) from its wiring (centralized server versus symmetric all-reduce). The math is settled and never changes; the wiring is where scale-out lives or dies. Ring all-reduce wins because it keeps per-worker bandwidth nearly constant in the cluster size, the same reasoning behind the collective of Section 4.4 that reappears as reduce-scatter and all-gather inside sharded training. Whenever a later method combines gradients, ask the two questions from here: sum or mean, and through a center or around a ring.
5. Where the All-Reduce Sits, and Hiding It Behind Backprop Advanced
Naively, the all-reduce happens after the entire backward pass finishes: compute all gradients, then combine, then step. That ordering leaves the network idle during backprop and the accelerators idle during the all-reduce, so the two costs add. The structure of backpropagation offers a better schedule. Backprop produces gradients layer by layer, from the output toward the input, and the gradient of a layer is final the moment that layer's backward computation completes; it does not depend on the layers still being processed. So the all-reduce for an early-finishing layer (one near the output) can start while the backward pass is still grinding through the layers nearer the input. The communication of one bucket of gradients overlaps with the computation of the next, and a large fraction of the all-reduce cost disappears behind work the accelerator was doing anyway.
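A toy timeline shows how much the layer-by-layer schedule can save. The sketch assumes one communication link that runs one all-reduce at a time, one all-reduce per layer (no bucketing yet), and made-up per-layer times listed from the output layer toward the input; the function name is hypothetical.

```python
def step_time_with_overlap(backward_ms, allreduce_ms):
    """Toy timeline: per-layer backward compute plus per-layer all-reduce.

    Both lists run from the output layer toward the input, the order in which
    backprop finalizes gradients. Returns (naive_ms, overlapped_ms).
    """
    naive = sum(backward_ms) + sum(allreduce_ms)  # finish backprop, then communicate
    compute_done = 0.0   # time at which the current layer's gradient is final
    link_free = 0.0      # time at which the link can start a new all-reduce
    for b, c in zip(backward_ms, allreduce_ms):
        compute_done += b
        start = max(compute_done, link_free)      # needs the gradient and an idle link
        link_free = start + c
    return naive, max(compute_done, link_free)

naive, overlapped = step_time_with_overlap([4, 4, 4, 4], [3, 3, 3, 3])
print(naive, overlapped)
```

With four layers of 4 ms backward and 3 ms all-reduce each, the naive schedule takes 28 ms and the overlapped one 19 ms: only the all-reduce of the layer nearest the input is left exposed. If communication per layer exceeded computation, the link would become the critical path and the saving would shrink.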
This overlap is the reason production data-parallel trainers group gradients into buckets and launch each bucket's all-reduce as soon as it is ready, rather than waiting for the whole backward pass. The mechanism, and the conditions under which the overlap is near-perfect or only partial, are exactly the communication-computation overlap analyzed in Section 4.10; with gradient accumulation, the overlap applies to the single all-reduce that fires after the last micro-batch, while the accumulation micro-batches run with no communication at all. The payoff is that a well-scheduled step can hide most of its communication, which is what makes data parallelism scale even when the gradient vectors are billions of numbers long.
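Bucketing can be sketched as a greedy pass over the gradients in the order backprop produces them. The helper below is hypothetical and simplified; it is not DistributedDataParallel's exact policy, although DDP does expose a bucket_cap_mb argument that plays the role of cap here (measured in megabytes rather than elements).

```python
def make_buckets(param_sizes, cap):
    """Group gradient tensors into buckets of at most about cap elements.

    param_sizes is in forward (input-to-output) order; gradients become ready
    in reverse, so buckets are filled from the last parameter backward. A
    tensor larger than cap ends up in a bucket of its own.
    """
    buckets, current, filled = [], [], 0
    for i in reversed(range(len(param_sizes))):
        if current and filled + param_sizes[i] > cap:
            buckets.append(current)   # close the bucket; its all-reduce can launch
            current, filled = [], 0
        current.append(i)
        filled += param_sizes[i]
    if current:
        buckets.append(current)
    return buckets

print(make_buckets([100, 300, 50, 50, 400, 20], cap=400))
# [[5], [4], [3, 2, 1], [0]]
```

The cap trades two costs against each other: larger buckets mean fewer collective launches and less per-message latency, while smaller buckets let the first all-reduce start earlier and overlap more of the backward pass.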
Because the combine step is the tax on data parallelism, recent work attacks it from both ends. On the squeezing side, low-bit and structured gradient compression in the lineage of PowerSGD and 1-bit Adam shrink the bytes per all-reduce, and 2024 to 2025 systems push fp8 and even lower-precision gradient reductions into the collective itself. On the hiding side, local-update methods such as DiLoCo (Douillard et al., 2024) and its descendants let workers take many steps between communications, turning the per-step all-reduce into a far rarer event and enabling training over slow or geo-distributed links; this is the same amortization that gradient accumulation performs, taken to its limit. A third thread fuses the all-reduce schedule with the compiler so that bucket boundaries and overlap windows are chosen automatically for a given model and interconnect. We give these methods their full treatment, with the cost models to compare them, in Section 10.7; for now, note that the field treats Output 10.5.1's exactness as fixed and spends all its ingenuity on making the combine step cheaper to move.
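The squeezing idea can be felt with a deliberately naive stand-in. NumPy has no fp8 type, so this sketch casts float32 partial gradients to float16 before summing, halving the bytes per element at the price of a small rounding error. It is not PowerSGD, 1-bit Adam, or a production fp8 collective, and the sizes and seed are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(3)
K, P = 8, 10_000
partials = [rng.standard_normal(P).astype(np.float32) for _ in range(K)]

exact = np.sum(np.stack(partials).astype(np.float64), axis=0)  # reference sum
# Toy low-precision reduction: cast each contribution to float16 (half the
# bytes of float32 on the wire) and accumulate in float16 as well.
low = np.zeros(P, dtype=np.float16)
for g in partials:
    low = low + g.astype(np.float16)
rel_err = np.linalg.norm(low.astype(np.float64) - exact) / np.linalg.norm(exact)
print(f"bytes per element: {partials[0].itemsize} -> {low.itemsize}")
print(f"relative error of float16 reduction: {rel_err:.1e}")
```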
Code 10.5.1 spelled out the accumulate-then-all-reduce-then-divide loop by hand. In PyTorch, DistributedDataParallel buckets the gradients, launches each bucket's all-reduce during the backward pass for the overlap of Section 5, and averages by the world size automatically. Gradient accumulation is a one-line guard that skips the all-reduce until the last micro-batch:
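```python
# Run with: torchrun --nproc_per_node=8 thisfile.py
import os
from contextlib import nullcontext
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun for each process
torch.cuda.set_device(local_rank)

# Assumed to exist already: `model` (whose forward returns the mean loss of a
# micro-batch), `optimizer` over the model's parameters, and `loader`.
model = DDP(model.cuda(local_rank), device_ids=[local_rank])  # buckets + overlapped all-reduce
ACCUM = 4  # micro-batches per optimizer step

for step, micro_batches in enumerate(loader):  # loader yields ACCUM micro-batches
    for j, (x, y) in enumerate(micro_batches):
        x, y = x.cuda(local_rank), y.cuda(local_rank)
        last = (j == ACCUM - 1)
        ctx = nullcontext() if last else model.no_sync()
        with ctx:  # skip all-reduce until the last micro-batch
            loss = model(x, y) / ACCUM  # scale so the SUM becomes the right mean
            loss.backward()             # gradients accumulate in .grad
    optimizer.step()
    optimizer.zero_grad()  # one update per ACCUM micro-batches
```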
no_sync() suppresses the all-reduce on all but the last micro-batch, the division by ACCUM turns the accumulated sum into the correct mean (assuming equal-size micro-batches), and DDP handles bucketing, the world-size average, and the backprop overlap internally.
6. Putting the Pieces Together Beginner
All-reduce SGD is one tight loop: each worker computes a local gradient (optionally accumulated over several micro-batches), the workers combine their gradients with a symmetric all-reduce that runs once per step and overlaps with the backward pass, each worker divides by the global batch size to recover the mean, and each worker applies the identical optimizer update. No node is special, no central link saturates, and the result matches, up to floating-point rounding, the large-batch step you would have taken on one impossibly large machine. The remaining question is what happens when a worker is slow or its gradient arrives late, because the synchronous all-reduce above waits for the slowest participant. That straggler problem, and the asynchronous and stale-gradient methods that trade exactness for the freedom not to wait, are the subject of Section 10.6.
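The whole loop can be simulated in one process to check the two claims above: every replica ends each step with identical weights, and those weights track a single large machine. The sketch assumes a toy least-squares model, full-dataset batches on every step, and $K$ simulated workers with fixed shards; it is an illustration rather than a distributed implementation.

```python
import numpy as np

rng = np.random.default_rng(5)
N, d, K, m, lr, steps = 1024, 6, 4, 2, 0.05, 20
X = rng.standard_normal((N, d))
y = X @ rng.standard_normal(d)

def grad_sum(w, idx):
    """Sum of per-example squared-error gradients over the rows idx."""
    return 2.0 * X[idx].T @ (X[idx] @ w - y[idx])

shards = np.array_split(np.arange(N), K)    # fixed data shard per worker
replicas = [np.zeros(d) for _ in range(K)]  # every worker keeps its own copy of w
w_single = np.zeros(d)                      # one large machine, for comparison

for _ in range(steps):
    # 1. Local gradients, each accumulated over m micro-batches.
    local = [sum(grad_sum(w, mb) for mb in np.array_split(rows, m))
             for w, rows in zip(replicas, shards)]
    # 2. All-reduce with SUM: every worker receives the same vector.
    total = np.sum(local, axis=0)
    # 3. Divide by the global batch size; 4. identical update on every worker.
    replicas = [w - lr * total / N for w in replicas]
    w_single = w_single - lr * grad_sum(w_single, np.arange(N)) / N

print(np.max(np.abs(replicas[0] - w_single)))  # rounding-level gap
```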
A colleague reports that their data-parallel run diverges on the very first step, but the single-GPU version of the same code trains fine. Their combine step calls an all-reduce with a SUM reduction and then applies the optimizer directly to the result with the learning rate they tuned on one GPU. State exactly what is wrong, what factor their effective learning rate is off by in terms of the worker count $K$ and the per-worker batch size, and the single line of code that fixes it. Then explain why the bug would have stayed hidden if they had used a MEAN-reducing collective instead.
Modify Code 10.5.1 so the workers receive very different batch sizes (for example, one worker gets half the rows and the rest split the remainder). Combine the per-worker gradients two ways: (a) the mean-of-means form, dividing each worker's sum by its own row count and averaging over workers; and (b) the sum-then-divide-by-$B$ form. Measure each against the big-batch reference big. Report which one stays exact, by how much the other is off, and explain why unequal shards break the average of means but not the sum-then-divide form. State what this implies for a real cluster where a straggler is handed a smaller batch to catch up.
A gradient has $P = 2 \times 10^{9}$ elements at 2 bytes each, and you train with $K$ workers on links of 25 gigabytes per second. For the parameter-server pattern, the central server's inbound link must carry one gradient from each of the $K$ workers per step; estimate that ingress time as a function of $K$. For ring all-reduce, each worker moves about $2\frac{K-1}{K}P$ elements regardless of $K$; estimate that time. Plot or tabulate both for $K \in \{8, 64, 512\}$ and identify the cluster size beyond which the central link is the clear bottleneck. Explain in one sentence why the ring time barely moves while the server time grows, and connect your answer to the alpha-beta cost model of Chapter 3.