Part I: Foundations of Distributed AI
Chapter 4: Communication Primitives for Distributed Training

Why Communication, Not Compute, Bounds Distributed Training

"I finish my gradient in four milliseconds and then sit on the barrier for forty-six, waiting for everyone else to agree on the answer. I am not slow. I am punctual to a meeting that will not start."

A GPU Idling on a Communication Barrier
Big Picture

A distributed training step has two halves: compute the gradient, which an accelerator does very fast, and then combine the gradients across all workers, which the network does much more slowly. As models and clusters grow, the second half dominates, so communication, not compute, sets the pace of training. This is the thesis of the entire chapter, and it inverts the intuition most practitioners carry from single-machine work, where a faster chip always means a faster run. Once you cross into the distributed regime, buying more accelerators can leave the step time almost unchanged, because computing the gradient was never the bottleneck; moving it was. This section makes that claim quantitative with a single ratio you can compute on the back of an envelope, shows the ratio crossing the break-even point of one as the worker count grows, shows why a larger model does not pull it back below one, and previews the collective operations that the rest of the chapter exists to make cheap.

In Section 1.1 we proved that data-parallel training is exact: split the examples across $K$ workers, sum their partial gradients, divide by the total count, and you recover the identical single-machine gradient. That proof settled correctness and deliberately left cost for later. Later is now. The summing step, harmless when the data sat in one process, becomes a network operation when the workers are separate machines, and the time it takes does not shrink no matter how fast the workers compute. This section measures that time against the compute it gates and shows that for the models and clusters people actually train, the network usually wins.

The structure of every data-parallel step is the same two-phase rhythm, and seeing it drawn once fixes the intuition for the whole chapter. Figure 4.1.1 lays out a single step on a timeline: a short compute phase where each worker runs the forward and backward pass on its shard, followed by a longer communication phase where the workers exchange and combine their gradients before any of them can take the optimizer step.

[Figure 4.1.1 diagram: a timeline of one data-parallel training step on each worker, with wall-clock time for one step on the horizontal axis. Three blocks in order: Compute (forward + backward; fast on the accelerator), All-reduce the gradient (exchange and sum one vector across all $K$ workers; slow over the network, and it does not shrink with a faster chip), Optimizer (apply update).]
Figure 4.1.1: The two-phase rhythm of a data-parallel step. The green compute phase runs on the accelerator and is fast; the orange all-reduce phase runs over the interconnect and is slow, and no worker can apply its optimizer update until the all-reduce finishes. The whole chapter is about shrinking the orange block. The ratio of orange time to green time is the quantity defined next.

1. The Two Halves of a Training Step (Beginner)

Write the wall-clock time of one synchronous data-parallel step as the sum of the two phases in Figure 4.1.1. The compute phase takes time $T_{\text{comp}}$: each worker runs the forward and backward pass over its local batch. A widely used estimate puts the cost of one training step at roughly six floating-point operations per parameter per example (two for the forward pass, four for the backward), so for a model of $P$ parameters, a local batch of $b$ examples, and an accelerator that sustains $R$ floating-point operations per second,

$$T_{\text{comp}} = \frac{6\,P\,b}{R}.$$

The communication phase takes time $T_{\text{comm}}$: the workers all-reduce the gradient, a vector of $P$ numbers. Chapter 3 modeled the cost of moving data across a network as latency plus size over bandwidth; here we keep only the bandwidth term, because for the large gradients of real models it dominates. A ring all-reduce, the algorithm Section 4.4 dissects, moves about $2(K-1)/K$ times the gradient's bytes across each worker's link. With $g$ bytes per parameter and an effective per-worker bandwidth $B$,

$$T_{\text{comm}} = \frac{2(K-1)}{K} \cdot \frac{g\,P}{B}.$$

The contrast between these two formulas is the whole story. The compute time carries the local batch $b$ in its numerator; the communication time does not. Add more workers to a fixed global batch and each worker's $b$ falls, so $T_{\text{comp}}$ shrinks while $T_{\text{comm}}$ holds nearly constant (the factor $2(K-1)/K$ creeps up toward $2$ and then stops). Grow the model and both times scale with $P$, so their ratio stays fixed at whatever the hardware dictates. We build the cost model carefully on the foundation laid in Section 3.8; the point here is the shape, not the third decimal place.
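The factor $2(K-1)/K$ in $T_{\text{comm}}$ can be sanity-checked without any networking. The sketch below is a toy, single-process simulation of a ring all-reduce, assuming the vector splits evenly into $K$ chunks; `ring_allreduce` is a hypothetical helper written for this illustration, not a library call, and Section 4.4 is where the algorithm is treated properly. It counts how many elements each worker sends and compares that count with the formula.

```python
def ring_allreduce(vectors):
    """Simulate a ring all-reduce (sum) over K equal-length lists.

    Returns (reduced_vectors, elements_sent_per_worker).
    """
    K = len(vectors)
    n = len(vectors[0])
    assert n % K == 0, "sketch assumes the vector splits evenly into K chunks"
    c = n // K                                  # elements per chunk
    data = [list(v) for v in vectors]           # each worker's private copy
    sent = [0] * K

    def chunk(r, idx):
        return data[r][idx * c:(idx + 1) * c]   # slicing copies the chunk

    # Phase 1, reduce-scatter: K-1 rounds; each worker passes one chunk to its
    # right-hand neighbour, which adds it into its own copy of that chunk.
    for s in range(K - 1):
        outgoing = [chunk(r, (r - s) % K) for r in range(K)]  # all sends happen at once
        for r in range(K):
            dst, idx = (r + 1) % K, (r - s) % K
            for j, x in enumerate(outgoing[r]):
                data[dst][idx * c + j] += x
            sent[r] += c

    # Phase 2, all-gather: K-1 rounds; each fully reduced chunk travels once
    # around the ring, overwriting the stale copies.
    for s in range(K - 1):
        outgoing = [chunk(r, (r + 1 - s) % K) for r in range(K)]
        for r in range(K):
            dst, idx = (r + 1) % K, (r + 1 - s) % K
            data[dst][idx * c:(idx + 1) * c] = outgoing[r]
            sent[r] += c

    return data, sent


K, n = 4, 8
vectors = [[10 * r + i for i in range(n)] for r in range(K)]
reduced, sent = ring_allreduce(vectors)
expected = [sum(col) for col in zip(*vectors)]
assert all(v == expected for v in reduced)      # every worker holds the full sum
print(sent[0] / n, 2 * (K - 1) / K)             # prints: 1.5 1.5
```

Each worker sends $2(K-1)$ chunks of $n/K$ elements, which is where the $2(K-1)/K$ multiple of the gradient size comes from, and why that multiple approaches but never reaches $2$.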

Key Insight: Compute Carries the Batch, Communication Carries the Model

Per-step compute scales with the local batch each worker processes; per-step communication scales with the size of the model, not the batch. That single asymmetry is why distributed training drifts toward being communication-bound. Strong scaling (a fixed global batch spread over more workers) shrinks the per-worker compute toward zero while the gradient you must exchange stays exactly as large. The network does not care that each worker now has less to do; it still has to move the whole model's worth of numbers every step.

2. The Communication-to-Computation Ratio (Intermediate)

Define the communication-to-computation ratio $\gamma$ as the time spent moving the gradient divided by the time spent producing it:

$$\gamma = \frac{T_{\text{comm}}}{T_{\text{comp}}} = \frac{2(K-1)}{K} \cdot \frac{g\,P}{B} \cdot \frac{R}{6\,P\,b} = \frac{2(K-1)}{K} \cdot \frac{g\,R}{6\,b\,B}.$$

The algebra delivers a small surprise: the parameter count $P$ cancels. At fixed batch and worker count, doubling the model size doubles both phases and leaves the ratio untouched, so a tiny model and a giant one sit at the same point on the curve. What does move $\gamma$ is the per-worker batch $b$ (in the denominator) and the worker count $K$ (through the factor that climbs toward $2$). Shrink $b$ by spreading a fixed global batch over more workers and $\gamma$ rises in inverse proportion. The break-even point is $\gamma = 1$: below it the step is compute-bound and adding workers still helps; above it the step is communication-bound and each new worker mostly adds waiting. The honest scaling question from Section 3.5, where Amdahl's law caps speedup at the reciprocal of the serial fraction, is exactly this: the all-reduce is the serial fraction of a training step, and $\gamma$ measures how large it has grown.
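Because $\gamma$ is inversely proportional to $b$, setting $\gamma = 1$ gives the smallest local batch that keeps a step compute-bound for a given worker count. The sketch below evaluates that threshold under this section's simplified cost model; the function names `gamma` and `break_even_local_batch` are made up for this illustration, and the default hardware constants are the same assumed values used in Code 4.1.1.

```python
def gamma(K, local_batch, bytes_per_param=2.0, flop_rate=312e12,
          bandwidth=600e9, flop_per_param_per_example=6.0):
    """Communication-to-computation ratio; the parameter count P does not appear."""
    ring_factor = 2.0 * (K - 1) / K
    return ring_factor * bytes_per_param * flop_rate / (
        flop_per_param_per_example * local_batch * bandwidth)


def break_even_local_batch(K, **hardware):
    """Local batch at which gamma equals 1 for K workers.

    gamma is proportional to 1/b, so the break-even b is gamma evaluated at b = 1.
    """
    return gamma(K, local_batch=1.0, **hardware)


for K in [4, 16, 64, 1024]:
    print(f"K = {K:>5d}   break-even local batch ~ {break_even_local_batch(K):6.1f}")
```

With these constants the threshold rises from 260 examples per worker at $K = 4$ toward a ceiling of $gR/(3B) \approx 347$ as $K$ grows. A local batch well above that is compute-bound in this model; a local batch well below it is communication-bound regardless of model size.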

The script below turns these formulas into numbers. It fixes the hardware (a fast accelerator, an NVLink-class interconnect) and a global batch, then sweeps the worker count $K$, splitting that global batch across the workers so each one's local batch shrinks as the cluster grows. It prints the compute time, the all-reduce time, and the ratio $\gamma$, flagging the rows where the step has gone communication-bound.

def compute_seconds(P, local_batch, flop_per_param_per_example, flop_rate):
    # One step costs about 6 FLOPs per parameter per example (forward + backward).
    return flop_per_param_per_example * P * local_batch / flop_rate

def allreduce_seconds(P, bytes_per_param, bandwidth, K):
    # Ring all-reduce moves about 2(K-1)/K times the gradient bytes per worker.
    return 2.0 * (K - 1) / K * (P * bytes_per_param) / bandwidth

flop_rate = 312e12   # 312 TFLOP/s sustained (modern accelerator, bf16)
bandwidth = 600e9    # 600 GB/s effective interconnect per worker (NVLink-class)
fppe = 6.0           # FLOPs per parameter per example (fwd + bwd)
bpp = 2.0            # bytes per parameter (bf16 gradient)

# Strong scaling: a FIXED global batch of 2048 examples is split across K workers,
# so each worker's local batch shrinks as we add workers. Compute per step falls,
# communication per step does not, so the ratio climbs and eventually crosses 1.
global_batch = 2048
P = 7e9              # a 7-billion-parameter model

print("Strong scaling: P = 7e9 params, global batch = 2048 split across K workers")
print(f"{'K':>6} {'local batch':>12} {'compute (ms)':>13} {'all-reduce (ms)':>16} {'gamma':>9}")
for K in [1, 4, 16, 64, 256, 1024]:
    local = max(global_batch // K, 1)
    tc = compute_seconds(P, local, fppe, flop_rate)
    ta = allreduce_seconds(P, bpp, bandwidth, K)
    flag = "  <- comm-bound" if ta > tc else ""
    print(f"{K:>6d} {local:>12d} {tc*1e3:>13.2f} {ta*1e3:>16.2f} {ta/tc:>9.3f}{flag}")
Code 4.1.1: A back-of-the-envelope estimator for the two phases of a data-parallel step. The strong-scaling sweep holds the global batch fixed and divides it among the workers, so the local batch (and thus compute time) falls while the all-reduce time stays put.
Strong scaling: P = 7e9 params, global batch = 2048 split across K workers
     K  local batch  compute (ms)  all-reduce (ms)     gamma
     1         2048        275.69             0.00     0.000
     4          512         68.92            35.00     0.508
    16          128         17.23            43.75     2.539  <- comm-bound
    64           32          4.31            45.94    10.664  <- comm-bound
   256            8          1.08            46.48    43.164  <- comm-bound
  1024            2          0.27            46.62   173.164  <- comm-bound
Output 4.1.1: The ratio crosses one between four and sixteen workers. At $K=4$ the step is still compute-bound ($\gamma = 0.51$); by $K=16$ communication takes more than two and a half times as long as compute, and at $K=1024$ it takes over a hundred times as long. The compute column shrinks toward zero while the all-reduce column barely moves: the network has become the pacemaker.

Read the columns side by side. The compute time falls by a factor of a thousand from one worker to a thousand workers, exactly as you would hope from spreading the batch. Once there is more than one worker, the all-reduce time barely changes, climbing only from $35$ milliseconds at $K = 4$ to $47$ at $K = 1024$, because the factor $2(K-1)/K$ is the only thing that moves. By $K = 1024$ the workers spend over $99\%$ of each step waiting on the network and under $1\%$ computing. That is not a pathology of these particular numbers; it is the generic fate of strong scaling, and it is why the rest of this chapter is a catalogue of techniques for making the orange block in Figure 4.1.1 shorter.

Thesis Thread: The Tax Named, the Chapter Justified

Section 1.1 promised that the complexity of the rest of the book is not about the correctness of the gradient but about "the cost and reliability of the combining step." Output 4.1.1 is that promise made numeric. The combining step is an all-reduce, and once $\gamma$ exceeds one, it is the thing your training run is actually doing most of the time. Every collective in this chapter (all-reduce in Section 4.3, reduce-scatter and all-gather in Section 4.5, all-to-all in Section 4.6) is a way to perform that combining step with less time on the wire, and Section 4.10 hides what remains behind the backward pass. This is where the book's communication-primitives spine begins.

3. Why Growing the Model Does Not Save You (Intermediate)

A natural hope is that very large models, the ones everyone wants to train, will be compute-bound simply because they do so much arithmetic. The cancellation of $P$ in the ratio already warns against this, and the second sweep makes it concrete. Holding the worker count and local batch fixed while the parameter count grows by nearly three orders of magnitude leaves $\gamma$ exactly constant: both the compute and the all-reduce columns scale in lockstep with $P$, so a seventy-billion-parameter model is no more compute-bound than a hundred-million-parameter one running on the same cluster shape.

Table 4.1.1: Growing the model at a fixed cluster shape ($K = 64$ workers, local batch $32$), computed from the same formulas as Code 4.1.1. The absolute times grow with $P$ but their ratio $\gamma$ is pinned by the hardware and batch, not the model size.
| Parameters $P$ | Compute (ms) | All-reduce (ms) | Ratio $\gamma$ |
| --- | --- | --- | --- |
| $10^{8}$ | 0.06 | 0.66 | 10.66 |
| $7 \times 10^{8}$ | 0.43 | 4.59 | 10.66 |
| $7 \times 10^{9}$ | 4.31 | 45.94 | 10.66 |
| $7 \times 10^{10}$ | 43.08 | 459.38 | 10.66 |
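
The second sweep is easy to reproduce. The sketch below re-declares the two formulas from Code 4.1.1 so that it runs on its own, with the same assumed hardware constants as defaults, and sweeps $P$ at the fixed cluster shape used in Table 4.1.1.

```python
def compute_seconds(P, local_batch, flop_rate=312e12, flop_per_param_per_example=6.0):
    # About 6 FLOPs per parameter per example (forward + backward).
    return flop_per_param_per_example * P * local_batch / flop_rate


def allreduce_seconds(P, K, bytes_per_param=2.0, bandwidth=600e9):
    # Ring all-reduce moves about 2(K-1)/K times the gradient bytes per worker.
    return 2.0 * (K - 1) / K * P * bytes_per_param / bandwidth


K, local_batch = 64, 32          # fixed cluster shape; only the model size varies
rows = []
print(f"{'P':>8} {'compute (ms)':>13} {'all-reduce (ms)':>16} {'gamma':>9}")
for P in [1e8, 7e8, 7e9, 7e10]:
    tc = compute_seconds(P, local_batch)
    ta = allreduce_seconds(P, K)
    rows.append((P, tc * 1e3, ta * 1e3, ta / tc))
    print(f"{P:>8.0e} {tc*1e3:>13.2f} {ta*1e3:>16.2f} {ta/tc:>9.2f}")
```

Both time columns grow by the same factor from one row to the next, so the last column does not move.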

The lesson of Table 4.1.1 is that model size is the wrong knob for escaping the communication wall. What sets $\gamma$ is the cluster shape (how many workers, how fast their links) and the per-worker batch, never the parameter count alone. This is also why the remedies that genuinely help are the ones that change those quantities: a fatter interconnect raises $B$, a larger per-worker batch raises $b$, gradient compression lowers the effective $g$, and overlap hides $T_{\text{comm}}$ behind $T_{\text{comp}}$ so that the two phases run at once instead of in sequence. We develop the systematic version of these levers in Chapter 10 on communication-efficient optimization, and put them to work inside a real training loop in Chapter 15 on data-parallel deep learning.
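
To see the relative size of these levers, the sketch below applies each one to the $K = 64$, local-batch-$32$ configuration of Table 4.1.1. It is an illustration under this section's simplified cost model, not a benchmark. The helper names are invented here, the one-byte gradient stands in for a hypothetical compression scheme, and the overlap case assumes the idealized best case in which communication hides completely behind compute.

```python
def step_times(P=7e9, K=64, local_batch=32, bytes_per_param=2.0,
               flop_rate=312e12, bandwidth=600e9):
    """Return (T_comp, T_comm) in seconds from the section's two formulas."""
    t_comp = 6.0 * P * local_batch / flop_rate
    t_comm = 2.0 * (K - 1) / K * bytes_per_param * P / bandwidth
    return t_comp, t_comm


def summarize(overlap=False, **changes):
    """Return (gamma, step time in ms) for the baseline with some quantities changed."""
    t_comp, t_comm = step_times(**changes)
    # ASSUMPTION: ideal overlap runs both phases at once, so the step costs
    # max(T_comp, T_comm) instead of their sum. Real overlap hides less.
    step = max(t_comp, t_comm) if overlap else t_comp + t_comm
    return t_comm / t_comp, step * 1e3


scenarios = {
    "baseline":                summarize(),
    "2x bandwidth (B)":        summarize(bandwidth=1200e9),
    "4x local batch (b)":      summarize(local_batch=128),
    "1-byte gradients (g)":    summarize(bytes_per_param=1.0),
    "ideal overlap":           summarize(overlap=True),
}
for name, (g, step_ms) in scenarios.items():
    print(f"{name:<24} gamma = {g:6.2f}   step = {step_ms:7.2f} ms")
```

Doubling $B$ and halving $g$ have the same effect on $\gamma$, as the formula predicts. The larger local batch lengthens the step but processes four times as many examples in it. Ideal overlap leaves $\gamma$ unchanged and can only shorten the step to the longer of the two phases, so it cannot rescue a configuration whose $\gamma$ is far above one.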

Fun Note: The Most Expensive Sum in Computing

Strip away the framework and a synchronous training step is, at heart, adding up some vectors and dividing by a count, the first thing anyone learns to do with numbers. We have built planet-spanning clusters, custom silicon, and dedicated fiber, and we still spend most of a training run on that addition. The arithmetic was never hard. Getting everyone to agree on the answer at the same instant is the hard part, and it always has been.

4. The Collectives This Chapter Covers (Beginner)

If communication is the binding constraint, then the operations that perform it deserve first-class study, and that is what the rest of this chapter provides. Every one of them is introduced through the AI computation that needs it, never as abstract networking. Table 4.1.2 is the roadmap: each collective, the training pattern that gives rise to it, and the section that develops it. You have already met the first one by hand in Section 1.1; the others are its relatives, and recognizing which collective a given parallelism strategy leans on is the single most useful lens for the chapters ahead.

Table 4.1.2: The collective operations of distributed training, each introduced in this chapter through the AI operation that calls for it. The interconnect itself is treated thinly, just enough to reason about cost.
| Collective | The AI operation that needs it | Where it is developed |
| --- | --- | --- |
| All-reduce | Synchronizing gradients in data-parallel SGD | Section 4.3 |
| All-gather, reduce-scatter | Sharded optimizer state in ZeRO and FSDP | Section 4.5 |
| All-to-all | Routing tokens to experts in mixture-of-experts | Section 4.6 |
| Broadcast, gather | Weight and experience movement in parameter servers and actor-learner RL | Section 4.7 |

Before any of those, Section 4.2 draws the substrate they run on: the distinction between point-to-point messages and collective operations, and a deliberately thin view of the interconnects (NVLink, PCIe, InfiniBand, RDMA) that supply the bandwidth $B$ in our ratio. We keep that hardware tour brief on purpose, because the goal is not to become network engineers but to reason about cost well enough to predict, before launching a job, whether it will land on the green side of $\gamma = 1$ or the orange one.

Library Shortcut: Profile $\gamma$ Without the Algebra

The ratio in Code 4.1.1 is an estimate; on real hardware you can measure it directly. PyTorch's profiler reports time per operator and per GPU kernel, so you can read off how much of a step went to compute versus communication (NCCL collectives) without deriving anything:

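import torch
from torch.profiler import profile, ProfilerActivity

# train_one_step, model, and batch are assumed to be defined by your training script.
# Record CPU as well as CUDA activity so that both the host-side collective ranges
# and the GPU kernels they launch show up in the table.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             record_shapes=False) as prof:
    train_one_step(model, batch)            # forward, backward, and the DDP all-reduce
    torch.cuda.synchronize()                # wait for the GPU stream to drain

# Rows whose name contains "nccl" (for example "nccl:all_reduce" and the NCCL
# kernels it launches) belong to the all-reduce; the rest is compute.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=8))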
Code 4.1.2: The empirical version of $\gamma$. Where Code 4.1.1 predicts the split from first principles, the profiler reports the measured split, and the nccl: rows are precisely the all-reduce time the rest of this chapter works to shrink.
Practical Example: The Cluster That Got Slower When They Added Nodes

Who: A platform engineer running pretraining for a small research lab on rented GPUs.

Situation: A seven-billion-parameter model was training on sixteen GPUs and the team had budget to jump to sixty-four, expecting roughly four times the throughput.

Problem: At sixty-four GPUs the step time barely fell, and tokens-per-second improved by less than thirty percent, far short of the four times they had paid for.

Dilemma: Either the new nodes were faulty and worth a support ticket, or the workload had crossed into a regime where extra workers stop helping, in which case the fix was algorithmic, not operational.

Decision: Before opening a ticket they estimated $\gamma$ with the formula behind Code 4.1.1 and then confirmed it with the profiler in Code 4.1.2.

How: They had kept the global batch fixed while quadrupling the workers, so the local batch had dropped from $32$ to $8$; the estimate put $\gamma$ near ten at sixteen workers and over forty at sixty-four, and the profiler showed more than ninety percent of each step in nccl: all-reduce.

Result: They raised the global batch to keep the local batch at $32$ and enabled gradient bucketing with computation overlap, which pulled the measured all-reduce share back under half and recovered most of the missing speedup.

Lesson: When more workers stop helping, the suspect is almost never broken hardware; it is a communication-to-computation ratio that quietly crossed one. Compute $\gamma$ before you scale, not after.

5. What This Buys Us for the Rest of the Book (Intermediate)

We now have a single number, $\gamma$, that decides whether a distributed training step is paced by the accelerator or by the network, and a clear picture of what moves it: the per-worker batch and the cluster shape, not the model size. That number reframes every later design decision. A choice of interconnect is a choice of $B$; a parallelism strategy is a choice of which collective runs and how large its payload is; a compression scheme lowers $g$; an overlap technique takes as much of $T_{\text{comm}}$ off the critical path as the compute phase can hide. None of these make sense without first accepting that communication is the thing worth optimizing, which is what this section established.

It is worth saying plainly that the goal is not zero communication; for synchronous data parallelism the all-reduce is mandatory, because the exactness proof of Section 1.1 depends on it. The goal is to keep $\gamma$ small enough that the workers spend their time computing rather than waiting, and to know, before committing a cluster, on which side of the break-even line a planned job will fall. With the thesis fixed and the ratio in hand, the next step is to look at the wires themselves: the point-to-point and collective communication model, and the thin slice of interconnect physics we need to populate $B$. That is the subject of Section 4.2, The Communication Substrate.

Research Frontier: Pushing the Break-Even Point Outward (2024 to 2026)

Because $\gamma$ governs whether scale helps, recent work attacks each of its levers. Local-update training in the DiLoCo line (Douillard et al., 2024) and its open replications let workers take many optimizer steps between communications, slashing how often the all-reduce fires and enabling training over ordinary internet links where $B$ is small; follow-on streaming and asynchronous variants in 2024 to 2025 push this toward genuinely geo-distributed runs. Gradient and optimizer-state compression (the PowerSGD and 1-bit-Adam lineage, and newer low-rank and quantized all-reduce schemes) shrink the effective $g$ so each collective moves fewer bytes. At the systems level, NCCL and its successors plus topology-aware collective scheduling raise the effective $B$ and overlap the all-reduce with the backward pass so aggressively that, in the best case, $T_{\text{comm}}$ nearly vanishes from the critical path. We return to these with the analytical machinery to compare them in Chapter 10; for now, read every one of them as an engineered assault on the ratio defined in this section.
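
The local-update lever fits the same back-of-envelope style. The sketch below is a deliberately crude amortization model, not the DiLoCo algorithm: it assumes each synchronization moves the same bytes as a per-step all-reduce and ignores any effect of local updates on convergence. The names `amortized_gamma` and `min_local_steps` are hypothetical.

```python
import math


def amortized_gamma(gamma_sync, local_steps):
    """Effective ratio when the collective fires once every `local_steps` steps.

    ASSUMPTION: each synchronization costs the same as a per-step all-reduce,
    so its time is spread evenly over `local_steps` compute phases.
    """
    if local_steps < 1:
        raise ValueError("local_steps must be at least 1")
    return gamma_sync / local_steps


def min_local_steps(gamma_sync, target=1.0):
    """Smallest number of local steps that brings the amortized ratio to `target` or below."""
    return max(1, math.ceil(gamma_sync / target))


gamma_sync = 10.664              # the K = 64 row of Output 4.1.1
for H in [1, 4, 16, 64]:
    print(f"local steps = {H:>3d}   amortized gamma = {amortized_gamma(gamma_sync, H):6.3f}")
print("local steps needed for gamma <= 1:", min_local_steps(gamma_sync))
```

Under these assumptions, synchronizing every eleven steps would be enough to bring the $K = 64$ configuration back to break-even. Whether the model still trains as well with that many local steps is a separate question that this cost model does not address.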

Exercise 4.1.1: Find the Break-Even Worker Count (Analysis)

Using the ratio $\gamma = \frac{2(K-1)}{K} \cdot \frac{g\,R}{6\,b\,B}$, hold the global batch $G$ fixed so that the local batch is $b = G/K$, and substitute to get $\gamma$ as a function of $K$ alone. Solve $\gamma = 1$ for $K$ with the values in Code 4.1.1 ($G = 2048$, $g = 2$ bytes, $R = 312 \times 10^{12}$, $B = 600 \times 10^{9}$). Confirm your answer falls between the $K = 4$ and $K = 16$ rows of Output 4.1.1, and explain in one sentence why the parameter count $P$ never entered your calculation.

Exercise 4.1.2: Add the Latency Term (Coding)

Code 4.1.1 keeps only the bandwidth term of communication cost. Extend allreduce_seconds to add a fixed latency floor of the form $2(K-1)\alpha$ for a per-message latency $\alpha$ (try $\alpha = 5$ microseconds), modeling the $2(K-1)$ message rounds of a ring all-reduce ($K-1$ for the reduce-scatter phase and $K-1$ for the all-gather phase). Re-run the strong-scaling sweep and identify the largest model size $P$ at which the latency term still matters at $K = 1024$. Explain why latency dominates for tiny gradients and bandwidth dominates for large ones, connecting your finding to the cost model of Section 3.8.

Exercise 4.1.3: Weak Scaling Versus Strong Scaling (Conceptual)

The sweep in Code 4.1.1 is strong scaling: a fixed global batch divided among more workers. Describe weak scaling instead, where the local batch $b$ is held constant and the global batch grows with $K$. Using the ratio formula, argue what happens to $\gamma$ as $K$ grows under weak scaling, and explain why weak scaling keeps a step compute-bound far longer than strong scaling does. Then state the catch: name one reason you cannot grow the global batch without limit, and point to where Amdahl-style reasoning from Section 3.5 reappears.