Part I: Foundations of Distributed AI
Chapter 4: Communication Primitives for Distributed Training

Overlapping Communication with the Backward Pass, and Gradient Bucketing

"I used to wait politely for every gradient to finish before saying a word. Then someone taught me to talk while the work was still happening, and the cluster has not been bored since."

An All-Reduce That Learned to Multitask
Big Picture

The cheapest all-reduce is the one you never wait for. Because gradients become available layer by layer as the backward pass runs, the all-reduce of an early-finished layer can travel the network while later layers are still computing, hiding most of the communication behind work the accelerator was going to do anyway. Two systems tricks make this hiding effective. Overlap schedules each gradient's all-reduce the instant it is ready, so the network and the accelerator run at once instead of taking turns. Gradient bucketing groups many small gradient tensors into a few large ones, so each all-reduce moves enough bytes to be bandwidth-bound rather than latency-bound. Together they turn the communication tax of Section 4.1 from a serial penalty paid after every step into a cost that nearly vanishes behind compute. This is the single most important systems optimization for making data-parallel training scale, and it is exactly what PyTorch DistributedDataParallel does for you automatically.

Every collective in this chapter so far (all-reduce in Section 4.3, all-gather and reduce-scatter in Section 4.5, all-to-all in Section 4.6) has been treated as an operation you call and then wait for. That framing is convenient but pessimistic. In a real training step the gradient all-reduce does not have to be a wall the accelerator stops at; it can run on the interconnect while the accelerator keeps computing. This section explains why that overlap is possible at all, why naive overlap still leaves bandwidth on the table, and how bucketing recovers it. It is the optimization that decides whether adding a second machine makes training faster or merely busier.

1. Gradients Arrive Early, So Communication Can Start Early (Intermediate)

Backpropagation runs from the output layer back toward the input. The chain rule computes the gradient of the loss with respect to the last layer's parameters first, then the second-to-last, and so on, so a deep network's gradients do not appear all at once at the end of the backward pass; they stream out, one layer's worth at a time, from the top of the network down. By the time the backward pass is a quarter done, the gradients of roughly the last quarter of the layers are already final and will not change.

This ordering is the whole opportunity. The all-reduce that averages the last layer's gradient across workers can be launched the moment that gradient is ready, while the accelerator is still busy computing the gradients of earlier layers. The network transfer and the remaining backward compute then run concurrently on two different pieces of hardware, the interconnect and the accelerator, neither of which has to idle waiting for the other. If the backward compute takes long enough, almost every all-reduce finishes underneath it, and the step pays for communication only at the very end, for whatever could not be hidden.

Contrast that with the naive schedule, in which the worker completes the entire backward pass, then calls one all-reduce over all the gradients, then waits. There the accelerator sits idle for the full communication time, and the step costs compute plus communication rather than the maximum of the two. Section 4.1 introduced communication as a tax; overlap is how you avoid paying it in serial.

Key Insight: Overlap Turns a Sum Into a Maximum

Without overlap, a data-parallel step costs $T_\text{compute} + T_\text{comm}$, because the accelerator and the network take turns. With perfect overlap, gradients that finish early are communicated while later gradients are still being computed, and the step costs only $\max(T_\text{compute},\, T_\text{comm})$ plus the small unavoidable tail of the last bucket. When compute dominates, that tail is all you pay, and communication becomes nearly free. Scaling out stops being a race between adding workers and drowning in their chatter.
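As a worked example, plug in the illustrative numbers that Code 4.10.1 uses later in this section (48 layers, 3.0 ms of backward compute and 25 MB of gradients per layer, $\alpha = 200\,\mu\text{s}$, $\beta = 12$ GB/s, one layer per bucket). The symbol $t_\text{tail}$ is introduced here for the exposed all-reduce of the last bucket:

```latex
\begin{aligned}
T_\text{compute} &= 48 \times 3.0\ \text{ms} = 144\ \text{ms} \\
T_\text{comm} &= 48\left(\alpha + \frac{25\ \text{MB}}{\beta}\right) = 48\,(0.2 + 2.083)\ \text{ms} \approx 109.6\ \text{ms} \\
\text{no overlap:}\quad T_\text{compute} + T_\text{comm} &\approx 253.6\ \text{ms} \\
\text{overlap:}\quad \max(T_\text{compute},\, T_\text{comm}) + t_\text{tail} &\approx 144 + 2.28 \approx 146.3\ \text{ms}
\end{aligned}
```

The tail is a single bucket's all-reduce, which cannot start until the backward pass reaches the input layer; everything before it runs underneath compute.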

Figure 4.10.1 shows the two schedules side by side as a timeline. In the overlapped schedule the all-reduce of each bucket slides underneath the compute of the layers that have not finished yet, so the colored communication bars sit beneath, not after, the compute bars.

[Figure 4.10.1: timeline diagram. Top panel, "No overlap: compute, then communicate": backward compute for all layers, then one all-reduce over every gradient; the step ends late. Bottom panel, "Overlap + bucketing: each bucket's all-reduce runs under later compute": the same total backward compute, with the all-reduces of buckets A, B, and C underneath it and a short tail; the step ends early, with communication hidden behind compute.]
Figure 4.10.1: Two schedules for the same backward pass. Top: the accelerator finishes all compute, then pays the full all-reduce time in series, so the step ends late. Bottom: each bucket of gradients is all-reduced as soon as its layers finish, so the communication bars (orange) slide beneath the later compute bars (blue) and only a short tail remains exposed at the end. The blue compute total is identical in both rows; only the placement of the orange communication differs.

2. Why Bucketing: Many Small All-Reduces Are Latency-Bound (Intermediate)

Overlap alone is not enough, because a deep network has many parameter tensors, often hundreds, and most are small. If each tensor triggered its own all-reduce the instant it was ready, the worker would launch hundreds of tiny collectives per step. Recall the alpha-beta cost model from Section 3.8 and its application to collectives in Section 4.4: moving a message of $n$ bytes costs $\alpha + n/\beta$, where $\alpha$ is a fixed per-message latency and $\beta$ is bandwidth. For a small $n$ the fixed $\alpha$ dominates, so the transfer is latency-bound and the expensive interconnect runs far below its bandwidth.
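To see where the latency-bound regime ends, a minimal sketch under the same illustrative constants as Code 4.10.1 ($\alpha$ = 200 µs, $\beta$ = 12 GB/s; assumptions, not measurements) computes what fraction of each transfer is actually spent moving bytes. The helper name bandwidth_efficiency is hypothetical.

```python
# Illustrative link parameters (assumed, same values as Code 4.10.1).
ALPHA_S = 200e-6     # fixed per-message latency, seconds
BETA_BPS = 12e9      # effective all-reduce bandwidth, bytes per second

def comm_time(num_bytes):
    """Alpha-beta cost of one message: fixed latency plus bandwidth time."""
    return ALPHA_S + num_bytes / BETA_BPS

def bandwidth_efficiency(num_bytes):
    """Share of the transfer time spent moving bytes (1.0 = purely bandwidth-bound)."""
    return (num_bytes / BETA_BPS) / comm_time(num_bytes)

# Size at which the latency term and the bandwidth term are equal: n* = alpha * beta.
CROSSOVER_BYTES = ALPHA_S * BETA_BPS

for n in (4_096, 1_000_000, CROSSOVER_BYTES, 25_000_000):
    print(f"{n / 1e6:8.3f} MB  time={comm_time(n) * 1e6:8.1f} us  "
          f"efficiency={bandwidth_efficiency(n):.3f}")
```

Setting $\alpha = n/\beta$ gives the crossover $n^* = \alpha\beta$, 2.4 MB with these numbers. A 4 KB tensor spends well under 1% of its all-reduce moving data, while a 25 MB bucket spends about 91% of its time on bandwidth.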

Gradient bucketing fixes this by coalescing many small gradient tensors into a few large contiguous buffers, called buckets, and issuing one all-reduce per bucket. A bucket of $B$ tensors totaling $n$ bytes pays the latency $\alpha$ once instead of $B$ times, and because $n$ is now large the term $n/\beta$ dominates: the all-reduce is bandwidth-bound, which is the efficient regime. The bucket fires its all-reduce as soon as the last gradient assigned to it is ready, which is also what makes overlap practical: a handful of bucket-sized collectives are far easier to schedule underneath compute than hundreds of tiny ones.
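The packing step can be sketched with a simple greedy policy: walk the gradients in the order backward produces them and close a bucket when the next tensor would overflow the cap. This is an illustration under that assumption, not DDP's actual assignment code; assign_buckets and the 200 × 0.5 MB model are hypothetical.

```python
ALPHA_S = 200e-6   # assumed per-message latency, same illustrative value as Code 4.10.1

def assign_buckets(ready_order_bytes, cap_bytes):
    """Greedily pack gradient tensors, in the order backward makes them ready,
    into buckets of at most cap_bytes. A tensor larger than the cap gets a
    bucket of its own. Returns a list of buckets, each a list of tensor indices."""
    buckets, current, current_bytes = [], [], 0
    for i, nbytes in enumerate(ready_order_bytes):
        if current and current_bytes + nbytes > cap_bytes:
            buckets.append(current)              # next tensor would overflow: close bucket
            current, current_bytes = [], 0
        current.append(i)
        current_bytes += nbytes
    if current:
        buckets.append(current)                  # flush the partially filled last bucket
    return buckets

# Hypothetical model: 200 gradient tensors of 0.5 MB each, already listed in
# ready order (last layer first, because backward runs output-to-input).
sizes = [500_000] * 200
per_tensor = assign_buckets(sizes, cap_bytes=0)          # degenerate cap: one all-reduce per tensor
bucketed = assign_buckets(sizes, cap_bytes=25_000_000)   # 25 MB cap

for name, b in (("per-tensor", per_tensor), ("25 MB buckets", bucketed)):
    print(f"{name:>14}: {len(b):3d} all-reduces, latency paid = {len(b) * ALPHA_S * 1e3:.1f} ms")
```

Packing cuts the latency paid per step from $200\alpha$ (40 ms) to $4\alpha$ (0.8 ms) while moving the same bytes. Building buckets from consecutive ready tensors matters: a bucket can fire only when its last tensor is ready, so buckets that follow ready order fire early and in sequence.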

Bucket size is a genuine trade-off, not a knob to max out. Tiny buckets give the most overlap, because a bucket can fire as soon as its few layers finish, but they pay $\alpha$ many times and waste bandwidth. One giant bucket is maximally bandwidth-efficient but cannot fire until the entire backward pass is done, which destroys all overlap and reduces to the naive serial schedule. The sweet spot is a few buckets of tens of megabytes each: large enough to be bandwidth-bound, numerous enough that early buckets overlap the rest of the backward pass.

Fun Note: The 25 MB Default Nobody Tunes

PyTorch's DistributedDataParallel ships with a default bucket size of 25 MB. The number looks suspiciously round because it is: it was chosen empirically to be comfortably past the latency-bound regime on common interconnects while still leaving several buckets to overlap. Many teams train on that default for years without touching it, which is the highest compliment a systems heuristic can receive.

3. Measuring How Much Communication Overlap Hides (Advanced)

The cleanest way to feel the effect is to model a backward pass as a sequence of per-layer compute times and per-layer gradient sizes, then schedule the all-reduces two ways: serially after all compute, and overlapped as each bucket becomes ready. The code below does exactly that. It uses the alpha-beta cost model for each bucket's all-reduce, sweeps the bucket size from one layer per bucket up to one bucket for the whole model, and reports the step time under each schedule, the fraction of communication hidden behind compute, and the speedup.

NUM_LAYERS = 48
BYTES_PER_LAYER = 25_000_000        # 25 MB of gradients per layer (~6M fp32 params)
COMPUTE_PER_LAYER_S = 0.0030        # 3.0 ms of backward compute per layer
ALPHA_S = 0.000200                  # 200 us per-message launch latency
BETA_BPS = 12e9                     # 12 GB/s effective all-reduce bandwidth
TOTAL_COMPUTE = NUM_LAYERS * COMPUTE_PER_LAYER_S

def comm_time(num_bytes):
    return ALPHA_S + num_bytes / BETA_BPS            # alpha-beta cost (Section 4.4)

def make_buckets(bucket_layers):                     # split layers into fixed-size groups
    out, remaining = [], NUM_LAYERS
    while remaining > 0:
        take = min(bucket_layers, remaining)
        out.append(take); remaining -= take
    return out

def no_overlap_time(bucket_layers):                  # backward, THEN all communication
    buckets = make_buckets(bucket_layers)
    total_comm = sum(comm_time(b * BYTES_PER_LAYER) for b in buckets)
    return TOTAL_COMPUTE + total_comm, total_comm

def overlapped_time(bucket_layers):                  # each bucket overlaps later compute
    buckets = make_buckets(bucket_layers)
    t_compute, net_free_at = 0.0, 0.0                # compute cursor; network-idle time
    for b in buckets:
        t_compute += b * COMPUTE_PER_LAYER_S         # compute the layers feeding bucket
        ct = comm_time(b * BYTES_PER_LAYER)
        start = max(t_compute, net_free_at)          # wait for layers AND a free channel
        net_free_at = start + ct
    return max(t_compute, net_free_at), net_free_at  # step ends when both streams done

serial_one, _ = no_overlap_time(1)
for bl in (1, 2, 4, 8, 16, 48):
    no_t, comm_t = no_overlap_time(bl)
    ov_t, _ = overlapped_time(bl)
    hidden = 1.0 - (ov_t - TOTAL_COMPUTE) / comm_t   # share of comm hidden under compute
    print(f"bucket={bl:>2} layers  buckets={len(make_buckets(bl)):>2}  "
          f"no-overlap={no_t*1000:6.1f}ms  overlap={ov_t*1000:6.1f}ms  "
          f"hidden={hidden*100:5.1f}%  speedup={no_t/ov_t:.2f}x")
print(f"\nideal (compute only)        : {TOTAL_COMPUTE*1000:.1f} ms")
print(f"naive (1-layer, no overlap) : {serial_one*1000:.1f} ms")
ref_t, _ = overlapped_time(8)                        # a mid-size reference point, not the optimum
print(f"overlap + 8-layer buckets   : {ref_t*1000:.1f} ms "
      f"-> {serial_one/ref_t:.2f}x faster than naive")
Code 4.10.1: A discrete-event model of one data-parallel backward pass. overlapped_time advances a compute cursor layer by layer and lets each bucket's all-reduce start as soon as both its layers are done and the single network channel is free, capturing the two real constraints on overlap.
bucket= 1 layers  buckets=48  no-overlap= 253.6ms  overlap= 146.3ms  hidden= 97.9%  speedup=1.73x
bucket= 2 layers  buckets=24  no-overlap= 248.8ms  overlap= 148.4ms  hidden= 95.8%  speedup=1.68x
bucket= 4 layers  buckets=12  no-overlap= 246.4ms  overlap= 152.5ms  hidden= 91.7%  speedup=1.62x
bucket= 8 layers  buckets= 6  no-overlap= 245.2ms  overlap= 160.9ms  hidden= 83.3%  speedup=1.52x
bucket=16 layers  buckets= 3  no-overlap= 244.6ms  overlap= 177.5ms  hidden= 66.7%  speedup=1.38x
bucket=48 layers  buckets= 1  no-overlap= 244.2ms  overlap= 244.2ms  hidden=  0.0%  speedup=1.00x

ideal (compute only)        : 144.0 ms
naive (1-layer, no overlap) : 253.6 ms
overlap + 8-layer buckets   : 160.9 ms -> 1.58x faster than naive
Output 4.10.1: Two effects in one table. Reading down the no-overlap column, larger buckets shave a few milliseconds by paying the latency $\alpha$ fewer times (253.6 ms with 48 one-layer all-reduces, down to 244.2 ms with one). Reading down the overlap column, smaller buckets hide far more communication behind compute (97.9% hidden at one layer per bucket, 0% with a single bucket that cannot start until the backward pass ends). In this model the overlap effect wins outright: one-layer (25 MB) buckets reach 146.3 ms, within 2% of the 144 ms compute-only floor, while 8-layer (200 MB) buckets already fall back to 160.9 ms, about 12% above it.

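The table makes the tension between the two effects concrete. Overlap wants small buckets so communication starts early; bandwidth efficiency wants large buckets so latency is amortized. In this model the overlap side wins because every layer already carries 25 MB of gradient: even a one-layer bucket spends only 0.2 ms of its roughly 2.3 ms all-reduce on $\alpha$, so it is already bandwidth-bound, and the smallest bucket lands closest to the floor. That one-layer bucket is also a 25 MB bucket, the same size as the DDP default. The latency side of the trade-off bites when individual tensors are small, as with the hundreds of small parameter tensors described in Section 2; there, coalescing them into larger buckets is what keeps each all-reduce bandwidth-bound without giving up overlap. Either way the prize is the same floor, roughly $\max(T_\text{compute}, T_\text{comm})$ plus a short tail: when compute dominates, well-bucketed overlap makes the all-reduce nearly free.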
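Code 4.10.1 never exposes the latency penalty of small buckets, because each of its layers already carries 25 MB, well past the $\alpha$-dominated regime. A minimal sketch with the same two scheduling constraints but a hypothetical model of 200 small 0.5 MB tensors (all sizes and timings are assumptions chosen for illustration, and overlapped_step is a hypothetical helper) shows the other side of the trade-off.

```python
# Hypothetical model with many small gradient tensors (assumed sizes and
# timings, chosen only to expose the latency side of the trade-off).
NUM_TENSORS = 200
BYTES_PER_TENSOR = 500_000          # 0.5 MB each
COMPUTE_PER_TENSOR_S = 150e-6       # 150 us of backward compute per tensor
ALPHA_S, BETA_BPS = 200e-6, 12e9    # same illustrative link as Code 4.10.1

def overlapped_step(tensors_per_bucket):
    """Same two constraints as Code 4.10.1: a bucket's all-reduce starts only
    after its tensors are computed AND the single network channel is free."""
    t_compute, net_free_at, remaining = 0.0, 0.0, NUM_TENSORS
    while remaining > 0:
        k = min(tensors_per_bucket, remaining)
        remaining -= k
        t_compute += k * COMPUTE_PER_TENSOR_S              # compute feeding this bucket
        comm = ALPHA_S + k * BYTES_PER_TENSOR / BETA_BPS   # alpha-beta cost of the bucket
        net_free_at = max(t_compute, net_free_at) + comm
    return max(t_compute, net_free_at)                     # step ends when both streams finish

results = {k: overlapped_step(k) for k in (1, 2, 10, 50, 200)}
for k, t in results.items():
    print(f"{k * BYTES_PER_TENSOR / 1e6:6.1f} MB buckets -> {t * 1e3:6.2f} ms")
print(f"compute-only floor: {NUM_TENSORS * COMPUTE_PER_TENSOR_S * 1e3:.2f} ms")
```

Here one-tensor buckets are the worst choice: each 0.5 MB all-reduce takes longer than the 150 µs of compute behind it, so the network falls behind and the step stretches to about 48.5 ms. A single 100 MB bucket hides nothing and lands near 38.5 ms. The best point, 1 MB buckets at about 30.3 ms, is the smallest bucket whose all-reduce keeps pace with the compute that produces it; where that point lies depends on $\alpha$, $\beta$, and compute per byte, which is what Exercise 4.10.2 asks you to explore.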

Practical Example: The 8-GPU Job That Refused to Speed Up

Who: An ML platform engineer onboarding a vision team onto a new 8-GPU server.

Situation: The team had wrapped their model in DistributedDataParallel and expected close to 8x throughput, but measured only about 5x, with GPUs visibly idle between steps.

Problem: A profiler trace showed the backward pass finishing, then a long bar of all-reduce during which every GPU sat waiting, the classic serial communication tail.

Dilemma: Buy a faster interconnect, which is slow to procure and costly, or find out why the existing communication was not already overlapping with compute as DDP promises.

Decision: Investigate first. The model had been built with find_unused_parameters=True and a custom forward that touched parameters in an order that defeated DDP's bucket-ready detection, so buckets only fired after the whole backward finished.

How: They removed the unused-parameter flag, restored a clean module ordering so gradients became ready in reverse-layer order, and left the 25 MB bucket default alone.

Result: The all-reduce slid back underneath the backward pass, the idle bar disappeared, and throughput rose from about 5x to 7.4x, matching the overlap-column behavior in Output 4.10.1.

Lesson: Overlap is not automatic just because you called the right wrapper; it depends on gradients becoming ready in an order the framework can bucket and fire early. When scaling looks worse than expected, profile for an exposed communication tail before blaming the hardware.

Library Shortcut: DistributedDataParallel Buckets and Overlaps for You

Everything Code 4.10.1 modeled (choosing bucket boundaries, detecting when all of a bucket's gradients are ready, firing its all-reduce mid-backward) plus averaging the result is built into PyTorch DistributedDataParallel. You write an ordinary training loop; DDP registers autograd hooks that coalesce gradients into buckets and launch each bucket's all-reduce as soon as the bucket fills, all overlapped with the rest of loss.backward():

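# Run with: torchrun --nproc_per_node=8 train.py
import os
import torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")                     # join the group of K workers
local_rank = int(os.environ["LOCAL_RANK"])          # set by torchrun, one per GPU
torch.cuda.set_device(local_rank)                   # pin this process to its own GPU
model = DDP(build_model().cuda(),                   # wrap once; hooks installed here
            device_ids=[local_rank],
            bucket_cap_mb=25)                        # the bucket size from Section 2
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)   # placeholder hyperparameters
loss_fn = torch.nn.CrossEntropyLoss()               # assumes a classification model
# build_model() and data_loader (ideally with a DistributedSampler) are your own.

for x, y in data_loader:                            # an ordinary training loop
    loss = loss_fn(model(x.cuda()), y.cuda())
    loss.backward()      # gradients bucket + all-reduce + average, overlapped, here
    optimizer.step(); optimizer.zero_grad()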
Code 4.10.2: The roughly forty lines of bucketing-and-overlap scheduling in Code 4.10.1 collapse to one DDP wrap with a bucket_cap_mb argument. The framework handles bucket assignment, the autograd hooks that detect bucket readiness, the all-reduce launch during the backward pass, and the division by world size. Chapter 15 turns this into a complete data-parallel training system.

Research Frontier: Overlapping Communication Inside the Layer (2024 to 2026)

Overlapping the gradient all-reduce with the backward pass is now standard; the active frontier pushes overlap into places that used to be strictly serial. In tensor and sequence parallelism the forward and backward passes contain all-gather and reduce-scatter collectives that classically blocked the matrix multiplications around them. Systems such as Centauri (Chen et al., 2024) and FLUX (Chang et al., 2024) decompose those collectives into fine-grained tiles and fuse them with the GEMM kernels so communication hides inside a single layer's compute, not just across layers. A parallel line on asynchronous tensor parallelism reorders the computation so each device can begin its share of the matmul before the full activation has arrived, overlapping the all-gather with the very operation that consumes it. The DiLoCo-style local-update methods from Chapter 10 attack the same tax from the other side, communicating less often rather than hiding each communication better. The common thread is that the boundary between compute and communication, once a hard barrier, is now something to dissolve.

4. Chapter Summary: The Collectives Catalog and What Each Serves (Beginner)

This chapter built the vocabulary of distributed training, starting from one observation (Section 4.1), that communication rather than compute bounds scale-out, and ending with the optimization that makes communication cheap (this section). The throughline is a small catalog of collective operations, each matched to a specific AI operation that needs it. Table 4.10.1 collects them, and it is worth committing to memory because every parallel method in Parts III through V is, at its core, a choice of which of these collectives to call and how to schedule it.

Table 4.10.1: The collectives catalog of Chapter 4. Each primitive is the communication backbone of a distinct AI operation; the right-hand column names where the book uses it at scale.
Collective | What it does | The AI operation it serves
All-reduce | Sum (or average) one vector per worker; every worker gets the result | Gradient synchronization in data-parallel SGD (Section 4.3, Chapter 15)
All-gather | Each worker collects every worker's shard into the full tensor | Reassembling sharded parameters before a layer's compute in FSDP/ZeRO (Section 4.5)
Reduce-scatter | Sum across workers, then leave each worker only its shard of the result | Sharded gradient reduction in FSDP/ZeRO; the dual of all-gather (Section 4.5)
All-to-all | Every worker sends a distinct piece to every other worker | Routing tokens to expert devices in mixture-of-experts (Section 4.6)
Broadcast | One worker sends the same tensor to all others | Distributing initial or updated weights from a server or rank 0 (Section 4.7)
Gather | One worker collects a distinct piece from every other | Aggregating experience or results to a learner or coordinator (Section 4.7)
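The semantics of each row can be checked with a single-process sketch that models K workers as a list of Python lists. The function names mirror the table rows but are illustrative helpers, not a library API; real implementations such as NCCL move the data over the interconnect with the ring and tree algorithms of Section 4.4.

```python
# Single-process sketch of the six collectives in Table 4.10.1. Each "worker"
# is an index into a Python list. Names are illustrative, not a library API.

def all_reduce(vectors):
    """Elementwise sum; every worker receives the full result."""
    total = [sum(vals) for vals in zip(*vectors)]
    return [list(total) for _ in vectors]

def all_gather(shards):
    """Every worker ends up with the concatenation of all workers' shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]

def reduce_scatter(vectors):
    """Elementwise sum, then worker r keeps only the r-th equal-sized chunk."""
    k = len(vectors)
    total = all_reduce(vectors)[0]
    chunk = len(total) // k
    return [total[r * chunk:(r + 1) * chunk] for r in range(k)]

def all_to_all(send):
    """send[src][dst] is the piece src sends to dst; worker dst receives
    [send[0][dst], send[1][dst], ...], i.e. a transpose."""
    n = len(send)
    return [[send[src][dst] for src in range(n)] for dst in range(n)]

def broadcast(tensor, num_workers):
    """The root's tensor is copied to every worker."""
    return [list(tensor) for _ in range(num_workers)]

def gather(pieces):
    """The root collects one distinct piece from every worker."""
    return list(pieces)

grads = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400], [1000, 2000, 3000, 4000]]
print(all_reduce(grads)[0])          # every worker holds the same sum
shards = reduce_scatter(grads)       # each worker holds one quarter of the sum
print(shards)
print(all_gather(shards)[0])         # reassembling the shards gives the all-reduce result
```

The last line checks the duality noted in the table: reduce-scatter followed by all-gather reproduces the all-reduce result, which is also how the bandwidth-efficient ring all-reduce is structured.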

The cost of each of these is governed by the alpha-beta model of Section 3.8, refined into ring and tree algorithms in Section 4.4, carried by libraries such as NCCL (Section 4.8), placed onto the fastest links by topology-aware scheduling (Section 4.9), and finally hidden behind compute by the overlap and bucketing of this section. That is the complete arc of Chapter 4: name the primitive, price it, implement it efficiently, place it well, and overlap it away.

Thesis Thread: The Primitives Are Set; Now They Get Scaled Out

Chapter 4 closes the foundations of communication. Every collective in Table 4.10.1 returns, transformed, as the engine of a parallel method later in the book: all-reduce becomes the heartbeat of data-parallel training (Chapter 15), all-gather and reduce-scatter become the memory-saving mechanics of sharded training (Chapter 16), and all-to-all becomes the router of expert parallelism (Chapter 17). The overlap discipline you just learned is what keeps all of them affordable at thousands of workers. When a later chapter says "and then we synchronize," it is calling one of these six, scheduled to hide behind the compute it serves.

5. Exercises (Intermediate)

Exercise 4.10.1: When Does Overlap Stop Helping? (Conceptual)

Overlap turns a step's cost from $T_\text{compute} + T_\text{comm}$ into roughly $\max(T_\text{compute}, T_\text{comm})$. Describe a concrete training regime in which $T_\text{comm} > T_\text{compute}$, so that even perfect overlap leaves communication exposed and the step is communication-bound. Name two things you could change (one about the model or batch, one about the cluster) to move back into the compute-bound regime where overlap hides almost everything. Tie your answer to the alpha-beta parameters of Section 3.8.

Exercise 4.10.2: Sweep the Bucket Size (Coding)

Extend Code 4.10.1 to plot step time against bucket size for three interconnects: a fast one ($\alpha = 5\,\mu\text{s}$, $\beta = 100$ GB/s), the default in the code, and a slow one ($\alpha = 1$ ms, $\beta = 2$ GB/s). For each, find the bucket size that minimizes the overlapped step time. Explain why the optimal bucket grows as the interconnect gets slower or higher-latency, and connect that to the latency-versus-bandwidth trade-off of Section 2.

Exercise 4.10.3: Measure Real Overlap Efficiency (Analysis)

Given a profiler trace of a real data-parallel step (or a synthetic one you construct), define an overlap efficiency metric: the fraction of total all-reduce time that occurred concurrently with backward compute, equal to $1 - T_\text{exposed}/T_\text{comm}$ where $T_\text{exposed}$ is communication time during which the accelerator was idle. Compute it for the 8-layer-bucket row of Output 4.10.1. Then argue what an efficiency well below 1 would tell you to inspect: bucket ordering, gradient-readiness order, or the presence of a stray blocking collective.
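For the trace half of this exercise, one way to obtain $T_\text{exposed}$ is to treat compute and communication as lists of (start, end) intervals pulled from the profiler and measure how much of the communication falls outside every compute interval. A minimal sketch, assuming the intervals are already extracted (merge, overlap_efficiency, and the toy trace are hypothetical, not taken from Output 4.10.1):

```python
def merge(intervals):
    """Union of (start, end) intervals as a sorted, non-overlapping list."""
    merged = []
    for start, end in sorted(intervals):
        if merged and start <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))   # extend the open run
        else:
            merged.append((start, end))
    return merged

def overlap_efficiency(compute, comm):
    """1 - T_exposed / T_comm, where T_exposed is communication time during
    which no compute interval is active (the accelerator is idle)."""
    busy, comm_u = merge(compute), merge(comm)
    t_comm = sum(e - s for s, e in comm_u)
    # Both lists are disjoint unions, so summing pairwise intersections
    # gives exactly the communication time hidden under compute.
    hidden = sum(max(0.0, min(ce, be) - max(cs, bs))
                 for cs, ce in comm_u for bs, be in busy)
    return 1.0 - (t_comm - hidden) / t_comm

# Toy trace in ms: backward runs 0-10; two all-reduces at 4-8 and 10-12.
print(overlap_efficiency(compute=[(0, 10)], comm=[(4, 8), (10, 12)]))
```

In the toy trace 4 of the 6 ms of communication overlap compute, so the efficiency is 2/3; the 2 ms exposed tail is what a profiler trace of a poorly overlapping job would show as idle accelerator time.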

Project Ideas

These close Chapter 4 by turning its ideas into something you can build and measure. Each is sized for a focused effort and connects directly to the collectives catalog above.

  1. Overlap-efficiency profiler. Take a small model in PyTorch DistributedDataParallel on two or more GPUs (or two processes on CPU with the Gloo backend), capture a profiler trace of one training step, and compute the overlap-efficiency metric from Exercise 4.10.3. Sweep bucket_cap_mb across several values and chart how efficiency and step time respond, reproducing the shape of Output 4.10.1 on real hardware. Report where the measured optimum sits relative to the 25 MB default.
  2. Bucketing from scratch. Implement your own gradient bucketing on top of raw torch.distributed.all_reduce: register backward hooks, coalesce gradients into fixed-size flat buffers, fire one asynchronous all-reduce per full bucket, and copy results back. Verify your result matches DDP numerically, then compare its step time against per-tensor all-reduce (no bucketing) to quantify the latency you saved by amortizing $\alpha$.
  3. Cost-model calibrator. Fit the alpha-beta parameters $\alpha$ and $\beta$ of your actual interconnect by timing all-reduces across a range of message sizes, then plug them into Code 4.10.1 and check how well the model's predicted step times match measured ones. A close fit validates the whole Chapter 4 cost framework on your cluster; a poor fit is a clue about contention, topology, or a slow link worth investigating with the tools of Section 4.9.
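For the calibrator, the fit itself is a straight line: $t = \alpha + n/\beta$ is linear in the message size $n$, so ordinary least squares on $(n, t)$ pairs gives $\alpha$ as the intercept and $1/\beta$ as the slope. A minimal sketch, assuming you already have timings; the data below is synthetic, generated from the illustrative $\alpha = 200\,\mu\text{s}$ and $\beta = 12$ GB/s of Code 4.10.1, and fit_alpha_beta is a hypothetical helper.

```python
def fit_alpha_beta(sizes_bytes, times_s):
    """Ordinary least squares on t = a + b * n; returns (alpha, beta) = (a, 1 / b)."""
    m = len(sizes_bytes)
    mean_n = sum(sizes_bytes) / m
    mean_t = sum(times_s) / m
    cov = sum((n - mean_n) * (t - mean_t) for n, t in zip(sizes_bytes, times_s))
    var = sum((n - mean_n) ** 2 for n in sizes_bytes)
    slope = cov / var                      # seconds per byte, i.e. 1 / beta
    intercept = mean_t - slope * mean_n    # fixed per-message cost, i.e. alpha
    return intercept, 1.0 / slope

# SYNTHETIC placeholder data: replace with measured all-reduce times per size.
sizes = [2 ** p for p in range(10, 29, 2)]              # 1 KiB .. 256 MiB
synthetic_times = [200e-6 + n / 12e9 for n in sizes]    # from alpha=200us, beta=12GB/s
alpha, beta = fit_alpha_beta(sizes, synthetic_times)
print(f"alpha = {alpha * 1e6:.1f} us, beta = {beta / 1e9:.2f} GB/s")
```

On noise-free data the fit recovers the generating parameters exactly. On real measurements the largest messages dominate an unweighted fit, so the intercept (and hence $\alpha$) can be poorly constrained; one option is to estimate $\alpha$ from the small-message times alone and $\beta$ from the large ones.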