Part I: Foundations of Distributed AI
Chapter 4: Communication Primitives for Distributed Training

All-Gather and Reduce-Scatter: The Primitives Behind ZeRO and FSDP

"I only ever hold one eighth of the weights. Right before the layer runs, I borrow the other seven eighths, do the math, and politely give them back."

A Shard That Believes It Is the Whole Model
Big Picture

All-reduce hides two simpler collectives inside it, and pulling them apart is what lets a model too big for one accelerator train anyway. Reduce-scatter sums one vector across all ranks and then hands each rank only its own slice of the result. All-gather does the reverse: each rank contributes its slice and ends up holding the full concatenation. Run reduce-scatter and then all-gather and you have performed an all-reduce, which is exactly how the ring algorithm of Section 4.4 works. The same two primitives, used on their own rather than back to back, are the engine of sharded data parallelism (ZeRO and FSDP): each rank stores only a shard of the parameters and optimizer state, all-gathers the full weights just in time to compute a layer, and reduce-scatters the gradients afterward so that it keeps only the shard it owns. This section builds both primitives from scratch, verifies the all-reduce identity in running code, and shows why FSDP trades extra communication for a large cut in memory.

In Section 4.3 we used all-reduce as a single indivisible operation: every worker arrives with a gradient vector, and every worker leaves with the elementwise sum. Section 4.4 then opened the box and showed that the bandwidth-optimal ring algorithm never moves the whole vector to one place; it works in two passes, a reduce phase and a broadcast phase, each touching only $1/K$ of the data per step. Those two phases are not improvised. Each one is a named collective in its own right, useful far beyond all-reduce, and this section gives them their names and their independent uses. The payoff is immediate: once you can reduce-scatter and all-gather separately, you can hold only a fraction of a model on each device and still train it, which is the central trick of memory-efficient large-model training.

1. Two Primitives, and the Identity That Links Them Intermediate

Fix a group of $K$ ranks. Suppose every rank $k$ holds a vector $g_k$ of length $P$, and we want the elementwise sum $s = \sum_{k=1}^{K} g_k$. All-reduce leaves the full vector $s$ on every rank. The two primitives of this section split that goal cleanly in half by deciding who ends up holding what.

Reduce-scatter performs the same elementwise reduction but then keeps each rank's appetite small: it partitions the length-$P$ result into $K$ contiguous slices and delivers slice $k$ to rank $k$ only. After a reduce-scatter, no single rank holds the whole sum; rank $k$ holds just $s$ restricted to its slice, a vector of length $P/K$. Writing $s^{(j)}$ for the $j$-th slice of the summed vector,

$$\text{reduce-scatter:}\quad \text{rank } k \text{ ends with } s^{(k)} = \Big(\sum_{i=1}^{K} g_i\Big)^{(k)}, \qquad k = 1, \dots, K.$$
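A tiny worked instance makes the bookkeeping concrete. The sketch below assumes two ranks ($K = 2$) holding length-4 vectors ($P = 4$) with made-up values. It indexes ranks from 1 to match the formula and computes everything in one process. A real reduce-scatter never assembles the full sum on any single rank.

```python
# Hypothetical inputs: rank k holds g[k] (K = 2 ranks, P = 4 entries each).
g = {1: [1.0, 2.0, 3.0, 4.0],
     2: [10.0, 20.0, 30.0, 40.0]}
K, P = len(g), 4
chunk = P // K                        # each rank will own P/K = 2 entries

# Reduce: elementwise sum across ranks, s = g_1 + g_2.
s = [sum(g[r][i] for r in g) for i in range(P)]

# Scatter: rank k keeps only the k-th contiguous slice s^(k).
owned = {k: s[(k - 1) * chunk : k * chunk] for k in range(1, K + 1)}
print(owned)  # {1: [11.0, 22.0], 2: [33.0, 44.0]}
```

Neither rank ends up with all four sums. Between them they hold the full vector $s$, split along the contiguous-slice partition.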

All-gather is purely about movement, with no arithmetic. Each rank $k$ starts holding one slice $a_k$ of length $P/K$, and the operation makes every rank end with the full concatenation of all slices in rank order:

$$\text{all-gather:}\quad \text{every rank ends with } a = \big[\,a_1 \,\Vert\, a_2 \,\Vert\, \cdots \,\Vert\, a_K\,\big].$$
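All-gather is the same kind of bookkeeping with the arithmetic removed. The sketch below uses a hypothetical group of three ranks, each starting with a length-2 slice, and simulates every rank in one process. It only shows which data each rank ends with. It does not show how the bytes travel.

```python
# Hypothetical inputs: rank k starts with slice a[k] of length P/K = 2 (K = 3).
a = {1: [1, 2], 2: [3, 4], 3: [5, 6]}
K = len(a)

def all_gather(slices):
    """Concatenate every rank's slice in rank order and give each rank a copy."""
    full = [x for k in sorted(slices) for x in slices[k]]   # no arithmetic, only movement
    return {k: list(full) for k in slices}

result = all_gather(a)
print(result[2])  # [1, 2, 3, 4, 5, 6]
```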

Now compose them. Start with the vectors $g_k$, reduce-scatter so that rank $k$ holds $s^{(k)}$, then all-gather those owned slices. Each rank contributes its $s^{(k)}$ and receives the concatenation $[s^{(1)} \Vert \cdots \Vert s^{(K)}]$, which is exactly the full sum $s$. Every rank now holds $s$. That is the definition of all-reduce, so

$$\boxed{\;\text{all-reduce} \;=\; \text{reduce-scatter} \;\to\; \text{all-gather}\;}$$

This is not a loose analogy; it is an equality of results, and it is precisely the two-phase structure of the ring all-reduce in Section 4.4. The reduce phase of the ring is a reduce-scatter (each rank ends owning the fully reduced value of one chunk), and the broadcast phase is an all-gather (each rank's owned chunk is propagated to everyone). Both phases move $P(K-1)/K$ elements per rank, so an all-reduce costs about twice a single all-gather, a fact the FSDP cost argument later in this section relies on.
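The per-rank traffic rule is simple enough to tabulate. The sketch below assumes the ring algorithm's cost of $P(K-1)/K$ elements sent per rank in each phase. The function names are ours, chosen for illustration.

```python
def ring_elements_per_rank(P, K):
    """Elements each rank sends in one ring reduce-scatter or one ring all-gather."""
    return P * (K - 1) / K

def collective_costs(P, K):
    rs = ring_elements_per_rank(P, K)   # reduce phase of the ring
    ag = ring_elements_per_rank(P, K)   # broadcast phase of the ring
    return {"reduce_scatter": rs, "all_gather": ag, "all_reduce": rs + ag}

print(collective_costs(P=12, K=4))
# {'reduce_scatter': 9.0, 'all_gather': 9.0, 'all_reduce': 18.0}
```

As $K$ grows, $(K-1)/K \to 1$, so the all-reduce approaches $2P$ elements per rank no matter how large the group is.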

Key Insight: All-Reduce Is a Composite, Not an Atom

Treat all-reduce as reduce-scatter followed by all-gather, and three things fall out at once. First, you understand the bandwidth-optimal ring algorithm, because it literally is those two phases. Second, you gain two reusable primitives that solve different problems: reduce-scatter when each rank needs only its own slice of a sum, all-gather when each rank needs everyone's slice. Third, you can break the composite apart in time, all-gathering weights at one moment and reduce-scattering gradients at another, which is the entire idea behind sharded data parallelism.

2. Verifying the Identity in Code Intermediate

The cleanest way to trust the identity is to implement both primitives on simulated shards and check that their composition reproduces a direct all-reduce, number for number. The code below represents the $K$ ranks as rows of a single array (so we can see everything at once on one machine), implements reduce-scatter and all-gather as plain array operations, and compares reduce-scatter then all-gather against a direct sum. It then plays the FSDP cycle in miniature: shard a full weight vector, all-gather it back to compute, and reduce-scatter a batch of per-rank gradients so each rank keeps only its slice.

import numpy as np

rng = np.random.default_rng(0)
K, P = 4, 12                      # workers, parameter-vector length
g = rng.standard_normal((K, P))   # g[k] = the gradient computed on worker k

# --- reduce-scatter: sum across workers, then each rank keeps only its slice ---
def reduce_scatter(g):
    K, P = g.shape
    assert P % K == 0
    total = g.sum(axis=0)                       # the summed (reduced) full vector
    slices = total.reshape(K, P // K)           # row k is the slice owned by rank k
    return slices                               # out[k] = rank k's owned shard

# --- all-gather: each rank contributes its shard; everyone ends with the concat ---
def all_gather(shards):
    return np.concatenate(shards, axis=0)       # length-P vector, identical on all ranks

# --- the identity: all-reduce == reduce-scatter then all-gather ---
owned = reduce_scatter(g)                       # K shards, one per rank
allgathered = all_gather(owned)                 # full summed vector, on every rank
allreduce_ref = g.sum(axis=0)                   # direct all-reduce (SUM)

print("workers K, params P        :", K, P)
print("reduce-scatter shard shape :", owned.shape, "(K shards of P/K)")
print("max abs diff vs all-reduce :", f"{np.max(np.abs(allgathered - allreduce_ref)):.2e}")
print("identity holds (RS+AG==AR) :", np.allclose(allgathered, allreduce_ref))

# --- FSDP cycle in miniature: shard -> all-gather to compute -> reduce-scatter grads ---
full_w = rng.standard_normal(P)                 # the layer's true full weight
w_shards = full_w.reshape(K, P // K)            # each rank stores only its shard
gathered_w = all_gather([w_shards[k] for k in range(K)])
print("FSDP all-gather rebuilds w :", np.allclose(gathered_w, full_w))
per_rank_grad = rng.standard_normal((K, P))     # each rank's grad for the FULL layer
kept = reduce_scatter(per_rank_grad)            # after backward, keep only your shard
print("each rank keeps P/K grads  :", kept.shape[1], "of", P)
Code 4.5.1: Reduce-scatter and all-gather built from array primitives, with the all-reduce identity checked directly and the FSDP gather-compute-scatter cycle played out on a single weight vector.
workers K, params P        : 4 12
reduce-scatter shard shape : (4, 3) (K shards of P/K)
max abs diff vs all-reduce : 0.00e+00
identity holds (RS+AG==AR) : True
FSDP all-gather rebuilds w : True
each rank keeps P/K grads  : 3 of 12
Output 4.5.1: The composition of reduce-scatter and all-gather matches the direct all-reduce exactly (zero difference), all-gather rebuilds the full weight from four shards, and after the backward step each rank retains only $P/K = 3$ of the $12$ gradient entries it computed.

The difference is not merely small; it is exactly zero. Our simulated reduce-scatter computes the sum with the same g.sum(axis=0) call as the reference, then only reshapes and concatenates it, so both paths add the same numbers in the same order. The last two lines are the heart of the next part. With $K = 4$, each rank stores a quarter of the weight, briefly borrows the rest to compute, and after the backward pass discards everything except its own quarter of the gradient. Nothing about correctness changed; only the peak memory each rank must hold changed, and dramatically.
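That exactness is a property of the simulation, not a guarantee of real collectives. A ring reduce-scatter adds the ranks' contributions in a different order for different chunks, and floating-point addition is not associative. A three-number sketch, with values chosen for illustration, shows that grouping alone can change the last bits:

```python
# Floating-point addition is not associative: the same three "rank contributions"
# summed with different groupings need not agree bit for bit.
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c      # one schedule: combine ranks 1 and 2 first
right = a + (b + c)     # another schedule: combine ranks 2 and 3 first
print(left, right, left == right)   # 0.6000000000000001 0.6 False
```

This is why checks on real multi-GPU collectives usually compare with a tolerance (as np.allclose does) rather than demanding exact equality.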

3. Sharded Data Parallelism: ZeRO and FSDP Advanced

Plain data parallelism (Chapter 15) replicates the entire model on every worker: every rank holds a full copy of the parameters, the gradients, and, most expensively, the optimizer state. For an Adam-trained model in mixed precision, the optimizer state and master weights can be several times the size of the parameters themselves. Replication therefore fills most of each device's memory with state that is identical on every other device. The insight of ZeRO (Zero Redundancy Optimizer) and its PyTorch sibling FSDP (Fully Sharded Data Parallel) is that this redundancy is unnecessary. With $K$ ranks, shard each of those tensors into $K$ pieces and have rank $k$ permanently own only piece $k$. No rank holds a full copy of anything large at rest.
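The "several times" claim can be checked with back-of-the-envelope arithmetic. The sketch below assumes a common mixed-precision Adam byte budget: 2 bytes each for fp16 parameters and gradients, plus 4 bytes each for fp32 master weights and the two Adam moments. Real frameworks and optimizers can differ, and activations are ignored.

```python
# Assumed byte budget per parameter for mixed-precision Adam training
# (a common accounting; individual frameworks and optimizers may differ).
BYTES_PER_PARAM = {
    "fp16 parameters": 2,
    "fp16 gradients": 2,
    "fp32 master weights": 4,
    "fp32 Adam first moment": 4,
    "fp32 Adam second moment": 4,
}
TOTAL = sum(BYTES_PER_PARAM.values())                   # 16 bytes per parameter
OPTIMIZER_SIDE = (TOTAL
                  - BYTES_PER_PARAM["fp16 parameters"]
                  - BYTES_PER_PARAM["fp16 gradients"])  # master weights + Adam moments

def state_gb(num_params, K):
    """(GB per rank if replicated, GB per rank if fully sharded over K ranks)."""
    per_replica = num_params * TOTAL / 1e9
    return per_replica, per_replica / K

print(OPTIMIZER_SIDE / BYTES_PER_PARAM["fp16 parameters"])  # 6.0
print(state_gb(13e9, 8))                                    # (208.0, 26.0)
```

Under this accounting the optimizer side is six times the fp16 parameters. A 13-billion-parameter model needs about 208 GB per full replica, but only about 26 GB per rank when fully sharded across eight ranks.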

The cost of holding only a shard is that the full weights of a layer do not exist on any single device when that layer needs to run. The two primitives of this section supply them on demand, and the rhythm is a tight cycle repeated layer by layer:

  1. All-gather the weights. Just before a layer's forward pass, the group all-gathers that layer's parameter shards so every rank momentarily holds the full layer weights and can compute the layer's output. The gathered copy is freed as soon as the layer finishes, so only about one layer's full weights (two, when the next layer is prefetched) sit in memory at any moment.
  2. All-gather again for backward. The same all-gather reconstructs the layer's full weights during the backward pass so each rank can compute the gradient with respect to the complete weight tensor.
  3. Reduce-scatter the gradients. After backward, each rank holds a full-length gradient for the layer, computed from its own microbatch. A reduce-scatter sums these across ranks. After dividing by $K$, rank $k$ holds only slice $k$ of the averaged gradient, which is exactly the slice whose optimizer state and parameters rank $k$ owns. Rank $k$ then updates its shard locally.
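The three-step cycle changes where data lives, not what the optimizer computes. The following minimal sketch checks that claim under toy assumptions chosen for brevity: one linear layer, a squared-error loss, plain SGD instead of Adam, and $K$ ranks simulated as rows of arrays. It compares the sharded update against plain data parallelism with an averaged all-reduce.

```python
import numpy as np

# Toy assumptions: one linear layer y = x @ W, loss 0.5 * sum((y - t)**2) / n,
# plain SGD, and K simulated ranks that each see a different microbatch.
rng = np.random.default_rng(1)
K, d_in, d_out, n, lr = 4, 3, 4, 5, 0.1
W = rng.standard_normal((d_in, d_out))        # the full weight (P = 12 entries)
P = W.size
shards = W.reshape(K, P // K).copy()          # at rest, rank k stores only shards[k]
xs = [rng.standard_normal((n, d_in)) for _ in range(K)]
ts = [rng.standard_normal((n, d_out)) for _ in range(K)]

def local_grad(W_full, x, t):
    """Gradient of 0.5 * sum((x @ W - t)**2) / n with respect to W."""
    return x.T @ (x @ W_full - t) / len(x)

# Steps 1-2: all-gather the shards so every rank can run forward and backward on the full W.
W_gathered = np.concatenate([shards[k] for k in range(K)]).reshape(d_in, d_out)
full_grads = np.stack([local_grad(W_gathered, xs[k], ts[k]).ravel() for k in range(K)])

# Step 3: reduce-scatter (sum), divide by K, and let rank k update only its own shard.
grad_shards = full_grads.sum(axis=0).reshape(K, P // K) / K
new_shards = shards - lr * grad_shards

# Reference: plain data parallelism, full replicas and an averaged all-reduce.
W_dp = W - lr * np.mean([local_grad(W, xs[k], ts[k]) for k in range(K)], axis=0)
print(np.allclose(new_shards.reshape(d_in, d_out), W_dp))  # True
```

Stitching the updated shards back together reproduces the replicated update, up to floating-point reordering.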

Figure 4.5.1 draws both halves of the story: the identity on top (reduce-scatter plus all-gather reconstituting an all-reduce) and the FSDP gather-compute-scatter cycle below. Notice that plain data parallelism's single gradient all-reduce per step has been replaced, for every layer, by two all-gathers (forward and backward) plus one reduce-scatter. An all-reduce already equals one all-gather plus one reduce-scatter. FSDP therefore moves three units of $P(K-1)/K$ traffic per rank where plain data parallelism moves two, roughly one-and-a-half times the volume. That extra traffic is the price. In return, each rank's resident memory for parameters, gradients, and optimizer state falls by a factor of about $K$.
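The text says the reduction is "about" $K$ rather than exactly $K$, because the gathered layer is held in full while it runs. A simplified sketch makes this concrete. It counts only parameter elements (no gradients, optimizer state, or activations), and it counts the gathered copy separately from the rank's own shards. The layer sizes are hypothetical.

```python
def peak_param_elements(layer_sizes, K, prefetch=False):
    """Peak parameter elements resident on one rank, in a simplified model.

    Replication keeps every layer in full. Full sharding keeps 1/K of every layer
    at rest plus the gathered full copy of the layer being computed (and of the
    next layer too when it is prefetched), counted separately from the shards.
    """
    replicated = sum(layer_sizes)
    at_rest = sum(layer_sizes) / K
    if prefetch:
        nxt = layer_sizes[1:] + [0]
        transient = max(a + b for a, b in zip(layer_sizes, nxt))
    else:
        transient = max(layer_sizes)
    return replicated, at_rest + transient

layers = [1_000_000] * 32     # hypothetical: 32 equal layers of 1M parameters
K = 8
print(peak_param_elements(layers, K))                 # (32000000, 5000000.0)
print(peak_param_elements(layers, K, prefetch=True))  # (32000000, 6000000.0)
```

In this example the peak falls 6.4-fold (about 5.3-fold with prefetching) rather than 8-fold. The gap closes as the number of layers grows, because the one or two gathered layers become a smaller fraction of the whole model.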

[Figure 4.5.1: two-panel diagram. Top panel: "Identity: reduce-scatter then all-gather equals all-reduce." Bottom panel: "FSDP cycle per layer: gather, compute, scatter."]
Figure 4.5.1: Top, the all-reduce identity: starting from a full gradient on each rank, a reduce-scatter sums and leaves every rank one slice of the result, then an all-gather propagates the slices so every rank holds the full sum, which is exactly an all-reduce. Bottom, the FSDP cycle: each rank stores only its weight shard at rest, all-gathers the full layer weights just in time for the forward and backward passes, then reduce-scatters the gradients so each rank keeps and updates only its own shard. The wrap-around arrow shows the cycle repeating with peak memory reduced by roughly a factor of $K$.
Thesis Thread: One All-Reduce, Split Apart in Time

The all-reduce of Section 4.3 synchronized gradients while every machine still held the whole model. Sharded data parallelism keeps the very same primitives but separates the halves in time: all-gather the weights when you need to compute, reduce-scatter the gradients when you are done. That separation is what converts a memory wall into a communication bill, and a communication bill is something the rest of this book knows how to manage. The parameter-server push-and-pull of Chapter 11 sharded parameters across servers for the same reason; FSDP shards them across the workers themselves and uses collectives instead of a central server.

Library Shortcut: torch.distributed Gives You Both Primitives, and FSDP Wires Them For You

Code 4.5.1 implemented the primitives by hand to expose the mechanics. In a real job you call them directly, or, more often, you never call them at all because FSDP inserts the all-gathers and reduce-scatters into the forward and backward passes automatically. The raw collectives are one line each:

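# Run with: torchrun --nproc_per_node=4 thisfile.py
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # one GPU per process (torchrun sets LOCAL_RANK)
rank, world = dist.get_rank(), dist.get_world_size()

# --- raw primitives ---
shard = torch.ones(3, device="cuda") * rank          # this rank's slice
full  = [torch.empty(3, device="cuda") for _ in range(world)]
dist.all_gather(full, shard)                          # every rank now has all slices

grad  = torch.randn(3 * world, device="cuda")        # full-length gradient
out   = torch.empty(3, device="cuda")
dist.reduce_scatter_tensor(out, grad, op=dist.ReduceOp.SUM)  # keep only my slice of the sum

# --- or let FSDP schedule both around every layer ---
my_model = torch.nn.Sequential(                       # placeholder model for illustration
    torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4)
).cuda()
model = FSDP(my_model)                                # shards params and grads across ranks
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)  # built on shards, so Adam state is sharded too
# forward/backward now all-gather weights and reduce-scatter grads under the hood
loss = model(torch.randn(8, 16, device="cuda")).sum()
loss.backward()
optim.step()

dist.destroy_process_group()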
Code 4.5.2: The two collectives as single torch.distributed calls, and the one-line FSDP wrapper that schedules them around every layer. Roughly twenty lines of manual shard bookkeeping per layer collapse into the wrapper, which also overlaps the gathers with computation and frees each gathered weight as soon as its layer finishes.
Practical Example: The 13-Billion-Parameter Model That Refused to Fit

Who: A research engineer at a startup fine-tuning a 13-billion-parameter language model on a node of eight 40 GB GPUs.

Situation: Plain data-parallel training (one full replica per GPU) ran out of memory before the first step finished; Adam's optimizer state and master weights alone needed far more than 40 GB per device.

Problem: A single replica's parameters, gradients, and optimizer state summed to roughly 200 GB in mixed precision, and replication forced all of it onto every one of the eight GPUs.

Dilemma: Rent scarce 80 GB GPUs at a steep premium, knowing that even those could not hold a 200 GB replica without some sharding or offloading. Or shard the model state across the eight 40 GB cards they already had and accept extra communication per step.

Decision: They sharded with FSDP, betting that the node's fast intra-node interconnect (Section 4.2) would keep the added all-gather and reduce-scatter traffic cheap relative to the memory it saved.

How: They wrapped the transformer blocks in FSDP with per-block wrapping so each block's weights were all-gathered only while that block ran and then freed. This kept only about one block's full weights (two with prefetching) resident at a time.

Result: Resident state per GPU fell from about 200 GB to roughly 25 GB, the model trained on the eight 40 GB cards, and the per-step slowdown from the extra collectives was about 20 percent, far cheaper than renting 80 GB hardware.

Lesson: When the binding ceiling is memory rather than throughput, sharding the model state and paying in communication is the move; the all-gather and reduce-scatter are the levers that make a model fit where it otherwise could not.

4. When the Trade Pays, and When It Does Not Intermediate

Sharding is not free, and the same cost discipline from Chapter 3 decides whether it helps. FSDP adds about half again the communication of plain data parallelism, so its benefit appears only when memory is the binding constraint. If a model already fits comfortably with room for the optimizer state, plain replication is faster because it communicates less. Once the model grows past one device's memory, replication simply stops working. FSDP's extra traffic then becomes the price of training at all, rather than a slowdown over a working alternative. The interesting regime is the middle, where the model nearly fits. There, partial sharding (shard the optimizer state but replicate the parameters, the lightest ZeRO stage) often wins, because the optimizer state is the largest piece of model state and sharding it adds essentially no communication volume. The gradient all-reduce simply becomes a reduce-scatter followed by an all-gather of the updated parameters, the same total traffic.
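The ZeRO stages differ only in which pieces of model state are divided by $K$. Stage 1 shards the optimizer state, stage 2 also shards gradients, and stage 3 also shards parameters, which is the full sharding FSDP performs. The sketch below assumes the 16-bytes-per-parameter mixed-precision Adam accounting (2 for fp16 weights, 2 for fp16 gradients, 12 for fp32 master weights plus Adam moments) and ignores activations. The model size and rank count are illustrative.

```python
def per_rank_gb(num_params, K, stage):
    """Per-rank model-state memory in GB under ZeRO-style sharding.

    stage 0: replicate everything (plain data parallelism)
    stage 1: shard optimizer state (fp32 master weights + Adam moments)
    stage 2: also shard gradients
    stage 3: also shard parameters (full sharding)
    """
    params, grads, optim = 2.0, 2.0, 12.0   # assumed bytes per parameter
    if stage >= 1:
        optim /= K
    if stage >= 2:
        grads /= K
    if stage >= 3:
        params /= K
    return num_params * (params + grads + optim) / 1e9

for stage in range(4):
    print(stage, round(per_rank_gb(7.5e9, 64, stage), 2))
# 0 120.0
# 1 31.41
# 2 16.64
# 3 1.88
```

Stage 1 alone removes most of the memory in this example, which is why the lightest stage is often enough when the model almost fits.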

Two levers govern the trade. The first is interconnect speed: the all-gathers and reduce-scatters are bandwidth-bound, so a fast intra-node fabric makes sharding nearly free while a slow cross-node link makes it punishing, which is why FSDP shines within a node and is often combined with other parallelism strategies across nodes (Chapter 16). The second is overlap: because each layer's weight all-gather can be issued before the layer's compute begins, a good implementation prefetches the next layer's weights while the current layer runs, hiding most of the communication behind computation, a theme Section 4.10 develops for gradients.
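The payoff of prefetching can be estimated with a deliberately simple timeline model. It assumes one communication stream and one compute stream that run concurrently, a single forward pass, and hypothetical per-layer timings. In this model each layer's weight all-gather is issued as soon as the previous layer starts computing. Real schedulers also contend for memory and bandwidth, which this ignores.

```python
def forward_time(comm, compute, prefetch):
    """Forward-pass time under a simplified two-stream model (times in ms).

    comm[i]    : time to all-gather layer i's weights
    compute[i] : time to run layer i once its weights are present
    """
    if not prefetch:
        # gather, then compute, strictly in sequence for every layer
        return sum(c + t for c, t in zip(comm, compute))
    gather_done = comm[0]          # the very first gather cannot be hidden
    compute_done = 0.0
    for i in range(len(compute)):
        start = max(gather_done, compute_done)   # need weights and a free compute stream
        compute_done = start + compute[i]
        if i + 1 < len(comm):
            gather_done = start + comm[i + 1]    # next gather overlaps this layer's compute
    return compute_done

comm, compute = [2.0] * 4, [3.0] * 4               # hypothetical per-layer timings
print(forward_time(comm, compute, prefetch=False))  # 20.0
print(forward_time(comm, compute, prefetch=True))   # 14.0
```

When each layer's compute outlasts its gather, only the first gather stays exposed. When gathers take longer than compute (a slow link), communication sets the pace. For example, with 4 ms gathers the prefetched pass takes 19 ms, not 12 ms.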

Fun Note: The Model That Is Never All in One Place

Under full sharding there is a strange moment to savor: at rest, the complete model does not exist anywhere. No single GPU, and indeed no single process, holds all the weights at once. The full layer materializes for a few milliseconds inside an all-gather, does its job, and dissolves back into shards. The model is less an object than a recurring agreement among the ranks to briefly assemble, compute, and disband.

Research Frontier: FSDP2, ZeRO++, and Sharding That Hides Its Own Cost (2024 to 2026)

The sharded-data-parallel line is actively engineered to shrink its communication tax. PyTorch's FSDP2 (the fully_shard API, stabilized through 2024 to 2025) rebuilds the original FSDP on per-parameter DTensor sharding. This gives cleaner composition with tensor and pipeline parallelism, lower memory fragmentation, and simpler mixed-precision and checkpointing semantics. On the DeepSpeed side, ZeRO++ (Wang et al., 2023, deployed and refined through this window) attacks the bandwidth directly with three changes. It quantizes weights during the all-gather (sending roughly half the bytes), keeps a hierarchical secondary shard within each node to avoid slow cross-node gathers, and quantizes gradients during the reduce-scatter. Together these report up to a fourfold cut in communication volume. Complementary work overlaps and prefetches the collectives so aggressively that, on fast fabrics, the all-gather nearly vanishes behind compute. The common thread is that the primitives of this section stay fixed; the research is about moving fewer bytes through them and hiding the rest. We return to these trade-offs with the full parallelism toolkit in Chapter 16.
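ZeRO++'s actual quantization scheme and kernels differ from what follows. The sketch assumes simple symmetric per-block int8 quantization with one fp16 scale per block of 64 values (both choices ours) and shows why a quantized all-gather sends roughly half the bytes of an fp16 one.

```python
import numpy as np

def quantize_int8(x, block):
    """Symmetric per-block int8 quantization: one fp16 scale per block of values."""
    blocks = x.reshape(-1, block).astype(np.float32)
    scale = np.abs(blocks).max(axis=1, keepdims=True) / 127.0
    scale[scale == 0] = 1.0                          # guard all-zero blocks
    q = np.round(blocks / scale).astype(np.int8)     # values now lie in [-127, 127]
    return q, scale.astype(np.float16)

def dequantize_int8(q, scale):
    return (q.astype(np.float32) * scale.astype(np.float32)).reshape(-1)

rng = np.random.default_rng(0)
shard = rng.standard_normal(1024).astype(np.float16)   # one rank's fp16 weight shard
q, scale = quantize_int8(shard, block=64)              # block size 64 is an arbitrary choice

fp16_bytes = shard.nbytes                # payload of a plain fp16 all-gather
int8_bytes = q.nbytes + scale.nbytes     # int8 values plus 16 fp16 scales
recovered = dequantize_int8(q, scale)    # what the receiving ranks would reconstruct
max_err = float(np.max(np.abs(recovered - shard.astype(np.float32))))
print(fp16_bytes, int8_bytes)            # 2048 1056
```

The scales add a small overhead on top of the halved payload. The price is a bounded rounding error of at most about half a quantization step per value.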

We now have all-reduce decomposed into its two working halves and have seen those halves carry the entire weight of memory-efficient large-model training. Reduce-scatter and all-gather move data along the contiguous-slice partition we chose; the next section asks what happens when the destinations are not slices of one vector but entire messages that must each go to a different rank, the all-to-all pattern that routes tokens to experts in a mixture-of-experts model. That is the subject of Section 4.6.

Exercise 4.5.1: Count the Bytes Analysis

A model has $P = 8 \times 10^{9}$ parameters trained across $K = 8$ GPUs in a single node. Using the ring-style rule that a reduce-scatter and an all-gather each move $P(K-1)/K$ elements per rank, compute the per-rank communication volume of one plain data-parallel step (a single gradient all-reduce) and of one FSDP step (per layer: two all-gathers plus one reduce-scatter, but treat the model as one big layer for the estimate). Express the FSDP volume as a multiple of the data-parallel volume, and state in one sentence the condition on the model's memory footprint under which paying that multiple is worthwhile.

Exercise 4.5.2: Implement a Ring Reduce-Scatter Coding

Extend Code 4.5.1 by replacing the one-shot reduce_scatter with a ring version that simulates $K$ ranks passing chunks around a logical ring for $K-1$ steps, accumulating into the chunk each rank will finally own (mirror the reduce phase of the ring all-reduce in Section 4.4). Verify it produces the identical owned slices as the array version, then compose it with your all-gather and confirm the all-reduce identity still holds bit for bit. Print the total number of element-transfers and check it equals $K \cdot P(K-1)/K = P(K-1)$.

Exercise 4.5.3: Replicate, Shard, or Half-Shard? Conceptual

You must train three models on a node of eight 80 GB GPUs: (a) a 1-billion-parameter model, (b) a 30-billion-parameter model, (c) a 7-billion-parameter model where the parameters fit replicated but Adam's optimizer state does not. For each, decide among plain data parallelism (replicate everything), full FSDP (shard parameters, gradients, and optimizer state), and optimizer-state-only sharding (the lightest ZeRO stage). Justify each choice in terms of the binding ceiling and the communication each option adds, and explain why choosing full sharding for case (a) would waste bandwidth for no memory benefit.