Part IV: Parallel Deep Learning and Large Models
Chapter 16: Model, Pipeline, and Sharded Parallelism

PyTorch FSDP

"They told me I was the whole model. Turns out I was one fourth of three layers, gathered into existence for a few microseconds at a time, then politely freed."

A Shard That Believes It Is the Whole Model
Big Picture

Fully Sharded Data Parallel (FSDP) is PyTorch's native way to train a model that does not fit on one device: it keeps every parameter sharded across the data-parallel workers at rest, and reconstitutes each layer's full weights only for the short window in which it is being used, then throws them away again. Section 16.4 described the idea abstractly as ZeRO-3: shard parameters, gradients, and optimizer state across the data-parallel group, and pay an all-gather to materialize a layer when you need it. FSDP is the implementation that turns that idea into roughly five lines of PyTorch. The whole method reduces to one repeated cycle, performed unit by unit through the forward and backward passes: all-gather the unit's full parameters, run compute, free the gathered copy, reduce-scatter the gradients. Everything else (the wrapping policy, the prefetch, the FSDP2 rewrite) is about making that cycle cheaper or its memory peak lower. This section makes the cycle concrete, runs it in miniature, and shows you the production call.

In Section 15.6 we used DistributedDataParallel (DDP), which replicates the entire model on every worker and synchronizes gradients with an all-reduce. DDP is the right tool exactly when the model, its gradients, and its optimizer state all fit comfortably on one device; the replication then costs only memory you already have, and the only communication is the gradient all-reduce. The moment the model stops fitting, replication is no longer an option, and Section 16.4 showed the escape: stop replicating and start sharding. FSDP is the PyTorch component that does the sharding for you. It wraps the model in a tree of units, shards the parameters of each unit across the data-parallel group, and inserts the gather-compute-free cycle automatically around every unit's forward and backward. You write almost the same training loop you wrote for DDP; FSDP changes what lives in memory and what crosses the network underneath it.

1. The Gather, Compute, Free Cycle Intermediate

Picture the model as an ordered list of FSDP units, each unit being a group of layers (often a single transformer block). At rest, unit $u$ does not hold its full weight matrix anywhere; instead each of the $K$ workers in the data-parallel group holds a $1/K$ slice of it. This is the ZeRO-3 state from Section 16.4: parameters, gradients, and optimizer state all sharded $K$ ways. The total parameter memory per worker is therefore the whole model divided by $K$, not the whole model. The catch is that no single worker can run a matrix multiply against a $1/K$ slice; it needs the full weight. FSDP supplies it just in time.

When the forward pass reaches unit $u$, FSDP issues an all-gather: every worker sends its slice of unit $u$ and receives the others, so each worker briefly holds the complete weights of unit $u$. The worker runs the unit's forward compute, producing activations, and then immediately frees the gathered full weights, returning to holding only its slice. The peak full-parameter memory is therefore one unit at a time, never the whole model. The backward pass mirrors this: FSDP all-gathers unit $u$ again to compute its gradient, frees the weights, and then issues a reduce-scatter so that each worker ends up holding the averaged gradient for only its own slice. Reduce-scatter is the collective that Chapter 4 introduced as the first half of a ring all-reduce (the all-gather is the second half); here it does double duty as both the gradient average and the re-sharding. Figure 16.5.1 traces the full cycle for one unit.

[Figure 16.5.1 diagram: sharded at rest (worker 1 to worker 4 slices) → all-gather → full unit u params resident (one unit only) → compute, forward/backward (acts, grads) → free full params (unit u back to slice) → reduce-scatter grads (grad slice per worker); a dashed "prefetch unit u+1" arrow launches the next unit's all-gather during this compute.]
Figure 16.5.1: The FSDP gather-compute-free cycle for one unit. Parameters live sharded across the $K$ workers (left). FSDP all-gathers the full unit just before compute, runs the forward or backward, frees the full copy, and reduce-scatters the gradients back to per-worker slices. The dashed orange arrow is the prefetch of Section 4: the next unit's all-gather is launched during the current unit's compute so the communication hides behind useful work.
Key Insight: Sharding Trades Memory for Communication, One Unit at a Time

FSDP never holds the whole model in full form on any worker. It holds one unit in full, for as long as that unit is computing, and only its own $1/K$ slice of every other unit. The memory bill drops from "whole model" to "whole model over $K$, plus one resident unit". The price is an extra all-gather per unit in the forward pass and another in the backward pass, on top of the reduce-scatter that replaces DDP's all-reduce. You are not getting the memory saving for free; you are paying for it in communication volume, and the rest of this section is about keeping that bill small.

2. The Wrapping Policy and Its Trade-Off Advanced

The single most consequential choice in FSDP is the wrapping policy: how the layers of your model are grouped into units. Each unit is the granularity at which FSDP all-gathers and frees. The two extremes bound the trade-off. If you wrap the entire model as one giant unit, FSDP all-gathers everything at once, which means the full model is briefly resident, and you have saved nothing on peak parameter memory; you have merely added communication. If you wrap every individual layer as its own unit, peak full-parameter memory is one layer, the smallest possible, but you pay an all-gather and a free for every single layer, and many small collectives are far less efficient than a few large ones because each carries a fixed latency cost (the $\alpha$ term of the $\alpha\text{-}\beta$ model from Chapter 3).

The principle is therefore: more, smaller units lower the memory peak but raise the communication cost; fewer, larger units do the reverse. The sweet spot for transformers is almost always to wrap one transformer block per unit, which is both a natural memory granularity and a large enough chunk that the all-gather is bandwidth-bound rather than latency-bound. PyTorch ships a transformer_auto_wrap_policy that does exactly this when you name the block class. Let $M$ be the whole-model parameter count, $K$ the number of workers, and $U$ the number of units; the per-worker peak parameter memory is approximately

$$\text{peak} \approx \underbrace{\frac{M}{K}}_{\text{resting shards}} + \underbrace{\max_u m_u}_{\text{one resident unit}},$$

where $m_u$ is the size of unit $u$. Shrinking the largest unit shrinks the second term; the first term is fixed by $K$. The miniature in Section 5 measures the second term directly and confirms that it stays at one unit.
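To make the formula concrete, the sketch below evaluates it for the four weight matrices of the Section 5 miniature (dims 16, 32, 32, 16, 8, so unit sizes of 512, 1024, 512, and 128 parameters) under three wrapping choices. The helper names are our own, and the all-gather count assumes exactly one gather per unit in the forward pass and one in the backward pass.

```python
def unit_sizes_from_dims(dims):
    """Parameter count of each weight matrix W_i with shape (dims[i], dims[i+1])."""
    return [a * b for a, b in zip(dims[:-1], dims[1:])]

def fsdp_peak(unit_sizes, K):
    """Approximate per-worker peak: resting shards M/K plus the largest resident unit."""
    M = sum(unit_sizes)
    return M / K + max(unit_sizes)

def gathers_per_step(unit_sizes):
    """One all-gather per unit in the forward pass and one in the backward pass."""
    return 2 * len(unit_sizes)

layers = unit_sizes_from_dims([16, 32, 32, 16, 8])   # [512, 1024, 512, 128]
K = 4
wrappings = {
    "one unit (whole model)": [sum(layers)],
    "two units of two layers": [layers[0] + layers[1], layers[2] + layers[3]],
    "one unit per layer": layers,
}
for name, units in wrappings.items():
    print(f"{name:26s} peak ~ {fsdp_peak(units, K):6.0f}   "
          f"all-gathers/step = {gathers_per_step(units)}")
```

The whole-model wrap peaks above the 2176-parameter model itself, because the formula counts the worker's resting shard alongside the full gathered copy; the per-layer wrap has the lowest peak but issues four times as many all-gathers.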

3. FSDP Versus DDP: A Decision, Not a Default Intermediate

FSDP and DDP solve the same problem (data-parallel training across $K$ workers) with opposite memory strategies, and the choice between them is a measurement, not a slogan. DDP replicates: every worker holds a full copy of the model, gradients, and optimizer state, and the only communication is one gradient all-reduce per step. FSDP shards: every worker holds $1/K$ of each, and pays all-gathers to reconstitute units on demand. Table 16.5.1 lays the two side by side.

Table 16.5.1: DDP and FSDP compared along the dimensions that decide which to use. The rule is simple: use the least sharding that fits.

Dimension | DDP (Section 15.6) | FSDP (this section)
Parameters per worker | full model | $M / K$ plus one unit
Optimizer state per worker | full | $1 / K$
Communication per step | one gradient all-reduce | all-gather per unit (forward and backward) plus reduce-scatter
When it is the right tool | model, grads, optimizer all fit on one device | they do not fit; you must shard
Communication volume | lower | higher (roughly $1.5\times$ the bytes)
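The "roughly $1.5\times$" entry follows from counting ring collectives. The sketch below assumes ring implementations, 16-bit values for both parameters and gradients, and a hypothetical 1-billion-parameter model on 8 workers; it counts only bytes each worker sends (in a ring, each worker receives the same amount).

```python
def ring_collective_bytes(n_values, K, bytes_per_value=2):
    """Bytes one worker sends in a ring all-gather or ring reduce-scatter of n_values."""
    return (K - 1) / K * n_values * bytes_per_value

def bytes_per_step(M, K, bytes_per_value=2):
    """Per-worker bytes sent per training step for DDP and FSDP (ring collectives)."""
    one = ring_collective_bytes(M, K, bytes_per_value)
    ddp = 2 * one    # all-reduce = reduce-scatter followed by all-gather
    fsdp = 3 * one   # forward all-gather + backward all-gather + gradient reduce-scatter
    return ddp, fsdp

ddp, fsdp = bytes_per_step(M=1_000_000_000, K=8)
print(f"DDP  : {ddp / 1e9:.2f} GB sent per worker per step")
print(f"FSDP : {fsdp / 1e9:.2f} GB sent per worker per step")
print(f"ratio: {fsdp / ddp:.2f}")
```

Splitting $M$ into more units leaves these byte totals unchanged; it only multiplies the number of collectives, and with it the latency term.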

The decision rule fits in one line: use DDP when the model fits on one device, and FSDP when it does not. DDP moves fewer bytes, so when replication is affordable it is faster; FSDP moves more bytes but makes training possible at all when the model is too large to replicate. There is a middle ground, hybrid sharding, where FSDP shards within a node and replicates across nodes, trading some memory saving for less inter-node traffic; it is the natural choice when a node's combined GPU memory holds the model but a single GPU does not.
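The "least sharding that fits" rule, including the hybrid middle ground, can be written as a small helper. This is a hypothetical rule of thumb, not a PyTorch API: it compares a model's training state (weights, gradients, optimizer state) in GB against one GPU and against one node, ignores activations and fragmentation, and returns strings that mirror the FULL_SHARD and HYBRID_SHARD members of FSDP1's ShardingStrategy enum.

```python
def choose_strategy(state_gb, gpu_gb, gpus_per_node):
    """Least sharding that fits: replicate, shard within a node, or shard everywhere."""
    if state_gb <= gpu_gb:
        return "DDP"               # a full replica fits on one GPU
    if state_gb / gpus_per_node <= gpu_gb:
        return "HYBRID_SHARD"      # shard inside a node, replicate across nodes
    return "FULL_SHARD"            # shard across every GPU (assumes enough nodes to fit)

print(choose_strategy(state_gb=20, gpu_gb=40, gpus_per_node=8))
print(choose_strategy(state_gb=208, gpu_gb=40, gpus_per_node=8))
print(choose_strategy(state_gb=600, gpu_gb=40, gpus_per_node=8))
```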

4. Prefetching: Hiding the All-Gather Behind Compute Advanced

The extra all-gathers are FSDP's central cost, and the central remedy is the overlap idea that runs through this whole book: launch the communication early, on a separate stream, so it finishes while the device is busy with compute it could do anyway. We first met this in Section 4.10 as overlapping the gradient all-reduce with the backward pass; FSDP applies the same principle to the all-gather. While unit $u$ is running its forward compute, FSDP issues the all-gather for unit $u+1$ in the background, so that by the time compute reaches unit $u+1$ its full parameters are already resident and the device never stalls waiting on the network. This is forward prefetch. The backward pass does the same in reverse, prefetching the all-gather of the next unit to be differentiated. The dashed arrow in Figure 16.5.1 is exactly this overlap.

When prefetch works perfectly, the all-gather of every unit is fully hidden behind the compute of the previous unit, and FSDP's wall-clock per step approaches DDP's despite moving more bytes. When it does not (because a unit's compute is too short to cover its successor's gather, or the interconnect is too slow), the exposed all-gather time is the price you pay for not fitting in DDP. Profiling the overlap, and resizing units so that compute covers communication, is the core tuning loop for FSDP, and it is the same overlap accounting that Chapter 4 taught for collectives in general.
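A toy timeline makes the overlap accounting concrete. The sketch below models the forward pass only, assuming one compute stream, one communication stream that runs one all-gather at a time, and a prefetch depth of one unit; times are in arbitrary units, and it ignores the real cost that a prefetched unit briefly keeps two full units resident.

```python
def step_time(a, c, prefetch):
    """Wall-clock of one forward pass over len(c) units.

    a[u]: all-gather time of unit u; c[u]: compute time of unit u.
    Without prefetch, unit u's gather is issued only after unit u-1 finishes computing.
    With prefetch, it is issued as soon as unit u-1 starts computing, provided the
    previous gather has left the communication stream.
    """
    gather_end = compute_start = compute_end = 0.0
    for u in range(len(c)):
        if prefetch and u > 0:
            start = max(gather_end, compute_start)   # overlaps unit u-1's compute
        else:
            start = max(gather_end, compute_end)     # device idles while gathering
        gather_end = start + a[u]
        compute_start = max(compute_end, gather_end)  # compute waits for its params
        compute_end = compute_start + c[u]
    return compute_end

for a_u in (1.0, 2.0):            # gather time equal to, then twice, the compute time
    a, c = [a_u] * 4, [1.0] * 4
    print(f"a={a_u}: no prefetch {step_time(a, c, False):.0f}, "
          f"prefetch {step_time(a, c, True):.0f}")
```

When each gather fits inside the previous unit's compute, only the first gather is exposed; when gathers are longer, every unit after the first exposes the difference between its gather and its predecessor's compute.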

Fun Note: The Parameter That Exists Only When Observed

There is something pleasingly quantum about an FSDP parameter. At rest it is smeared across four workers as four slices, none of which is the real weight. It snaps into full existence the instant a matrix multiply observes it, lives for the duration of that one operation, and dissolves back into slices the moment compute moves on. A whole transformer block is, most of the time, a superposition of fragments that have never met. The all-gather is the act of measurement, and like any good measurement it is expensive, which is why we try to do it while looking the other way.

5. The Cycle in Miniature, Verified Intermediate

FSDP needs several GPUs and a multi-process launch to run, so to see the gather-compute-free cycle we implement it in pure Python on one machine, simulating the $K$ workers as $K$ row-slices of each weight matrix held in a list. The model is a four-layer linear network, wrapped as four units (one per layer). At rest each unit is stored sharded; during the forward and backward pass we all-gather each unit just before compute, free it after, and reduce-scatter its gradients. We track the peak number of full parameters resident at any moment and check it against the whole-model count. To show that sharding changes nothing about the math, we run an ordinary unsharded baseline with identical initialization and compare the two loss curves.

import numpy as np

rng = np.random.default_rng(0)
K = 4                                  # simulated FSDP workers (shards)
dims = [16, 32, 32, 16, 8]             # 4 weight matrices -> 4 FSDP units
N, lr, steps = 256, 0.05, 40

# ... (synthetic task and shared init W0 elided; full script in the book repo) ...

def shard(W):        return np.array_split(W, K, axis=0)      # K row-slices
def all_gather(s):   return np.concatenate(s, axis=0)         # rebuild full unit
def reduce_scatter(g): return np.array_split(g, K, axis=0)    # grads back to slices

units = [shard(W.copy()) for W in W0]            # persistent SHARDED state
peak_full_params = 0
for _ in range(steps):
    acts, pre, h = [Xin], [], Xin
    for u in range(4):
        full_W = all_gather(units[u])            # ALL-GATHER unit u
        peak_full_params = max(peak_full_params, full_W.size)  # one unit resident
        z = h @ full_W; pre.append(z); h = np.maximum(z, 0); acts.append(h)
        del full_W                               # FREE gathered params
    g = (2.0 / target.size) * (h - target)
    for u in reversed(range(4)):
        g = g * (pre[u] > 0)
        full_W = all_gather(units[u])            # re-gather for backward
        full_grad = acts[u].T @ g; g = g @ full_W.T
        del full_W                               # FREE again
        grad_shards = reduce_scatter(full_grad)  # REDUCE-SCATTER grads
        units[u] = [s - lr * gs for s, gs in zip(units[u], grad_shards)]
Code 16.5.1: The FSDP cycle stripped to its essence. State lives sharded in units; every unit is gathered to full form only inside its own for u iteration and freed (del full_W) before the next, so peak_full_params can never exceed the largest single unit. The reduce-scatter re-shards each gradient onto its owning worker.
workers K                       : 4
whole-model parameters          : 2176
largest single FSDP unit        : 1024
peak resident full params (FSDP): 1024
peak / whole-model ratio        : 0.471
baseline final loss             : 4.771891e-02
FSDP     final loss             : 4.771891e-02
max loss-curve gap (all steps)  : 0.00e+00
Output 16.5.1: The sharded run and the unsharded baseline produce bit-identical loss curves (maximum gap $0$ across all forty steps), while peak resident full parameters stay at one unit (1024 of 2176, ratio 0.471). Sharding changed the memory profile and nothing about the result, exactly as the ZeRO-3 identity of Section 16.4 promises.

The gap of exactly zero is the same lesson as the gradient identity of Section 1.1: rearranging where the arithmetic happens, and when each weight is materialized, does not change the arithmetic. Peak full-parameter residency held at the largest unit (1024 numbers) rather than the whole model (2176). A model with twice as many units of this size would keep the same resident peak, and doubling $K$ alongside it would keep the resting shards $M/K$ unchanged too, so the per-worker budget stays flat as the model grows. That is the entire value proposition of FSDP, demonstrated without a GPU.
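For readers who want to run the comparison end to end, here is a self-contained variant of Code 16.5.1 with its baseline. The synthetic task and initialization are our own hypothetical stand-ins for the elided setup (not the book repo's script), so the loss values will differ from Output 16.5.1; the structural results (peak of one unit, a loss gap expected to be zero) should not.

```python
import numpy as np

K = 4                                   # simulated FSDP workers (shards)
dims = [16, 32, 32, 16, 8]              # 4 weight matrices -> 4 FSDP units
N, lr, steps = 256, 0.05, 40

# Hypothetical synthetic task and init (our own, NOT the elided book-repo setup).
rng = np.random.default_rng(0)
Xin = rng.standard_normal((N, dims[0]))
target = np.abs(Xin @ rng.standard_normal((dims[0], dims[-1]))) / np.sqrt(dims[0])
W0 = [rng.standard_normal((a, b)) / np.sqrt(a) for a, b in zip(dims[:-1], dims[1:])]

def shard(W):          return np.array_split(W, K, axis=0)   # K row-slices
def all_gather(s):     return np.concatenate(s, axis=0)      # rebuild full unit
def reduce_scatter(g): return np.array_split(g, K, axis=0)   # grads back to slices

def train_baseline():
    """Ordinary unsharded training: every weight is always resident in full."""
    Ws, losses = [W.copy() for W in W0], []
    for _ in range(steps):
        acts, pre, h = [Xin], [], Xin
        for W in Ws:
            z = h @ W; pre.append(z); h = np.maximum(z, 0); acts.append(h)
        losses.append(np.mean((h - target) ** 2))
        g = (2.0 / target.size) * (h - target)
        for u in reversed(range(len(Ws))):
            g = g * (pre[u] > 0)
            grad = acts[u].T @ g; g = g @ Ws[u].T     # uses the pre-update weights
            Ws[u] = Ws[u] - lr * grad
    return np.array(losses)

def train_fsdp():
    """Same math on sharded state: gather, compute, free, reduce-scatter per unit."""
    units, losses, peak = [shard(W.copy()) for W in W0], [], 0
    for _ in range(steps):
        acts, pre, h = [Xin], [], Xin
        for u in range(len(units)):
            full_W = all_gather(units[u])            # ALL-GATHER unit u
            peak = max(peak, full_W.size)            # only one unit is full at a time
            z = h @ full_W; pre.append(z); h = np.maximum(z, 0); acts.append(h)
            del full_W                               # FREE gathered params
        losses.append(np.mean((h - target) ** 2))
        g = (2.0 / target.size) * (h - target)
        for u in reversed(range(len(units))):
            g = g * (pre[u] > 0)
            full_W = all_gather(units[u])            # re-gather for backward
            full_grad = acts[u].T @ g; g = g @ full_W.T
            del full_W                               # FREE again
            grad_shards = reduce_scatter(full_grad)  # REDUCE-SCATTER grads
            units[u] = [s - lr * gs for s, gs in zip(units[u], grad_shards)]
    return np.array(losses), peak

base = train_baseline()
fsdp, peak = train_fsdp()
whole = sum(W.size for W in W0)
print(f"whole-model parameters          : {whole}")
print(f"peak resident full params (FSDP): {peak}")
print(f"peak / whole-model ratio        : {peak / whole:.3f}")
print(f"max loss-curve gap (all steps)  : {np.max(np.abs(base - fsdp)):.2e}")
```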

6. FSDP in Production, and FSDP2 Intermediate

In real PyTorch you do not write the cycle; you declare the wrapping policy and FSDP inserts the gathers, frees, and reduce-scatters for you. The code below shows the shape of a production wrap. It needs a multi-process launch (torchrun) and multiple GPUs, so we present it illustratively rather than running it here.

Library Shortcut: torch FSDP Wraps the Whole Cycle in One Call

Code 16.5.1 spelled out the gather, free, and reduce-scatter by hand across roughly twenty lines. PyTorch's FullyShardedDataParallel collapses all of it into a single wrap plus an auto-wrap policy; the all-gathers, frees, reduce-scatters, prefetch scheduling, and stream management are inserted for you and never appear in your training loop.

# Run with: torchrun --nproc_per_node=8 train_fsdp.py
import functools, os, torch, torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import BackwardPrefetch
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from mymodel import TransformerBlock, build_model   # hypothetical user module

dist.init_process_group("nccl")               # join the group of K workers
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # one GPU per process; torchrun sets LOCAL_RANK
model = build_model().cuda()

# Wrap one transformer block per FSDP unit: the memory/communication sweet spot.
policy = functools.partial(transformer_auto_wrap_policy,
                           transformer_layer_cls={TransformerBlock})
model = FSDP(model, auto_wrap_policy=policy,  # params now sharded across K workers
             backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # gather next unit early in backward
             forward_prefetch=True)            # ... and in forward

opt = torch.optim.AdamW(model.parameters(), lr=3e-4)  # create AFTER wrapping so it sees shards
for batch in loader:                           # loader: per-rank DataLoader (not shown)
    opt.zero_grad()
    loss = model(batch).loss                   # FSDP all-gathers each unit on demand
    loss.backward()                            # ... and reduce-scatters its grads
    opt.step()                                 # optimizer step on local shards only

dist.destroy_process_group()
Code 16.5.2: A production FSDP wrap. The only lines that differ from a DDP script (Section 15.6) are the import, the auto_wrap_policy, and the FSDP(...) call; the twenty-line cycle of Code 16.5.1 lives entirely inside that wrap. The optimizer sees only this worker's parameter shards, which is why optimizer state is sharded $K$ ways for free.

The original FSDP (now called FSDP1) wrapped each unit as a single flattened parameter, which complicated mixed precision, parameter freezing, and tensor-subclass features. FSDP2, the rewrite that became the recommended path in 2024, changes the representation: instead of one flat buffer per unit it shards each parameter individually as a DTensor (a distributed tensor that knows its own sharding), a design often called per-parameter sharding. This makes partial freezing, per-parameter dtypes, and composition with tensor parallelism (Section 16.2) clean rather than special-cased, and it removes much of FSDP1's bookkeeping. The gather-compute-free cycle is unchanged; what changed is the data structure the cycle operates on.
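The difference between the two representations can be seen on a toy unit. The NumPy sketch below is an illustration of the idea, not PyTorch's implementation: the flat version concatenates a hypothetical unit's weight and bias into one buffer, pads it to a multiple of $K$, and hands each worker a contiguous chunk (FSDP1's real FlatParameter adds further padding and bookkeeping we omit); the per-parameter version shards each tensor separately along dimension 0, the way FSDP2 shards each parameter's DTensor.

```python
import numpy as np

K = 4
# A hypothetical FSDP unit holding two parameters.
params = {"weight": np.arange(10.0).reshape(5, 2), "bias": np.arange(3.0)}

def flat_shard(params, K):
    """Flat-buffer sharding: concatenate all params, pad to a multiple of K, split."""
    flat = np.concatenate([p.ravel() for p in params.values()])
    flat = np.concatenate([flat, np.zeros((-flat.size) % K)])
    return np.split(flat, K)

def flat_owners(params, K):
    """For each parameter, which workers' flat chunks hold some of its elements."""
    total = sum(p.size for p in params.values())
    chunk = -(-total // K)                     # ceil(total / K) = padded chunk length
    owners, start = {}, 0
    for name, p in params.items():
        owners[name] = sorted({i // chunk for i in range(start, start + p.size)})
        start += p.size
    return owners

def per_param_shard(params, K):
    """Per-parameter sharding: split every parameter on its own along dim 0."""
    return {name: np.array_split(p, K, axis=0) for name, p in params.items()}

chunks = flat_shard(params, K)
print("flat chunk sizes :", [c.size for c in chunks])
print("flat owners      :", flat_owners(params, K))
shards = per_param_shard(params, K)
print("bias slices      :", [s.tolist() for s in shards["bias"]])
```

In the flat layout, worker 2's chunk mixes the tail of the weight with the head of the bias, so freezing only the bias or giving it its own dtype means special-casing a chunk shared with another parameter; with per-parameter sharding, every piece a worker holds belongs to exactly one parameter.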

Research Frontier: Per-Parameter Sharding and Composable Parallelism (2024 to 2026)

FSDP2's move to per-parameter DTensor sharding is part of a broader 2024-to-2026 push toward composable parallelism in PyTorch. Because every parameter now carries its own sharding metadata, FSDP2 composes cleanly with tensor parallelism and pipeline parallelism through the same DeviceMesh abstraction, the basis of the "torchtitan" reference stack for training large language models with 2D and 3D parallelism in native PyTorch (Liang et al., 2024). The per-parameter representation also unlocks cleaner low-precision training: FSDP2 supports gathering parameters in a reduced dtype and reduce-scattering gradients in another, and integrates with float8 training paths. The research direction is to make data, tensor, pipeline, and expert sharding orthogonal layers you can stack with a few lines rather than a bespoke framework, which is exactly the 3D-parallelism trade-off space that Chapter 3 taught you to cost out and that Section 16.6 reaches through DeepSpeed and Megatron-LM.

Thesis Thread: The Collective Returns, Re-Sharded

The all-reduce you computed by hand in Section 1.1 and met as DDP's gradient sync in Section 15.6 has now split into its two halves. FSDP's reduce-scatter is the first half of a ring all-reduce, and its all-gather is the second half, used not to share a finished sum but to materialize a sharded parameter on demand. The same primitives from Chapter 4 are being recombined to buy memory instead of just averaging gradients. When you reach expert parallelism in Chapter 17, the collective will mutate once more, into all-to-all; the lesson is that scale-out methods are mostly distinct recombinations of the same handful of collectives.

Practical Example: The 13-Billion-Parameter Model That Would Not Load

Who: An applied scientist at a language-tooling startup fine-tuning a 13-billion-parameter model on eight 40 GB GPUs.

Situation: The team's existing pipeline used DDP, which had worked fine for their earlier 1.3-billion-parameter model.

Problem: With DDP, every GPU tried to hold the full 13-billion-parameter model plus AdamW state (roughly 12 bytes per parameter for the optimizer alone), about 156 GB of optimizer state, and the job died with an out-of-memory error before the first step.

Dilemma: Buy 80 GB GPUs at far higher cost (scale up), or shard the model across the eight GPUs they already had with FSDP (scale out), accepting more communication for a chance at fitting.

Decision: They switched to FSDP, because the binding ceiling was memory, not throughput, and sharding addresses memory directly while replication cannot.

How: They wrapped the model with transformer_auto_wrap_policy on the block class (Code 16.5.2), enabled forward and backward prefetch, and left the training loop otherwise unchanged.

Result: Per-GPU memory dropped from "whole model" to roughly $M/8$ plus one transformer block, the job fit with headroom for a larger batch, and prefetch hid most of the extra all-gather time so the step was only modestly slower than DDP would have been if it had fit.

Lesson: When memory is the ceiling and the model cannot be replicated, FSDP is the tool; the extra communication is a price worth paying because the alternative is not slower training but no training at all.
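The memory arithmetic of this example can be checked with a rough estimate. The sketch assumes the common mixed-precision accounting of 16 bytes per parameter (2 for 16-bit weights, 2 for 16-bit gradients, and the 12 bytes of fp32 master weights plus two Adam moments mentioned above) and a hypothetical depth of 40 transformer blocks, which the example does not state; activations and temporary buffers are ignored.

```python
def per_gpu_gb(n_params, K, n_blocks, sharded, state_bytes=16, param_bytes=2):
    """Rough per-GPU memory (GB) for weights, grads, and AdamW state; no activations.

    state_bytes=16 assumes 2 (16-bit weights) + 2 (16-bit grads) + 12 (fp32 master
    weights and two Adam moments) bytes per parameter.
    """
    state = n_params * state_bytes
    if not sharded:                                        # DDP: every GPU holds it all
        return state / 1e9
    resident_block = (n_params / n_blocks) * param_bytes   # one gathered 16-bit block
    return (state / K + resident_block) / 1e9

n, k, blocks = 13e9, 8, 40        # blocks=40 is an assumed depth, not from the example
print(f"DDP  per GPU: {per_gpu_gb(n, k, blocks, sharded=False):6.1f} GB")
print(f"FSDP per GPU: {per_gpu_gb(n, k, blocks, sharded=True):6.1f} GB")
```

Under these assumptions the replicated state is several times a 40 GB card, while the sharded state leaves room for activations, which come on top of both figures.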

FSDP is the modern default for large-model training in PyTorch precisely because it makes the ZeRO-3 idea of Section 16.4 a five-line change to a DDP script. You now understand the cycle it runs, the wrapping policy that tunes its memory-versus-communication trade-off, the prefetch that hides its cost, and the rule for choosing it over DDP. The next section steps up to the two frameworks that pioneered and industrialized these ideas at the largest scales, DeepSpeed (which introduced ZeRO) and Megatron-LM (which introduced tensor parallelism), and shows how they compose sharding with the other parallelism axes. That story begins in Section 16.6.

Exercise 16.5.1: Count the Bytes on the Wire Analysis

A model has $M$ parameters wrapped into $U$ equal FSDP units across $K$ workers, trained in 16-bit precision (2 bytes per parameter). Per training step, write expressions for the total bytes each worker sends-plus-receives for (a) the forward all-gathers, (b) the backward all-gathers, and (c) the gradient reduce-scatter. Compare the sum to DDP's single gradient all-reduce (with a ring algorithm, each worker sends about $2M$ values, that is about $4M$ bytes, and receives the same amount). Confirm the "roughly $1.5\times$" claim of Table 16.5.1, and state what happens to each term as $U$ grows with $K$ fixed.

Exercise 16.5.2: Move the Wrapping Boundary Coding

Starting from Code 16.5.1, change the unit granularity: first wrap all four layers as a single unit (gather everything once), then wrap each layer as its own unit (the current code), and finally an intermediate grouping of two units of two layers each. For each, report peak_full_params and the number of all-gather calls per step. Verify that the loss curve is identical in all three cases and that peak memory and all-gather count move in opposite directions, demonstrating the trade-off of Section 2 numerically.

Exercise 16.5.3: When Does Prefetch Pay Off? Conceptual

Forward prefetch overlaps unit $u+1$'s all-gather with unit $u$'s forward compute. Suppose unit $u$'s compute takes time $c$ and its successor's all-gather takes time $a$. State the condition on $c$ and $a$ under which prefetch fully hides the communication, and the exposed (non-overlapped) time when it does not. Then argue qualitatively how making units larger (Section 2) affects both $c$ and $a$, and why there can be a unit size that minimizes total step time rather than either memory or communication alone.