Part IV: Parallel Deep Learning and Large Models
Chapter 16: Model, Pipeline, and Sharded Parallelism

When the Model No Longer Fits on One Device

"I held the whole model with room to spare. Then someone whispered the word 'optimizer,' and suddenly I was renting space I did not have."

A Shard That Believes It Is the Whole Model
Big Picture

Data parallelism makes training faster, but it makes no worker bigger; the instant one copy of the model plus its training state exceeds a single accelerator's memory, replication stops working and the model itself has to be cut into pieces that live on different devices. Training memory is not just the parameters. It is the parameters, a gradient for every parameter, the optimizer's bookkeeping (which for Adam is two extra full-size tensors), and the activations saved for the backward pass, and that sum is roughly an order of magnitude larger than the weights alone. This section counts that budget term by term, shows where it crosses the 40 GB and 80 GB lines of real accelerators, and lays out the four remedies this chapter develops: split the operators, split the layers, shard the state, and trade compute for memory.

Chapter 15 gave us a powerful move: replicate the model on every worker, hand each worker a different slice of the batch, and combine the gradients with an all-reduce that is mathematically exact. That move scales throughput beautifully, and for any model that fits on one device it is usually the right first tool. It has one hard limit, and the limit has nothing to do with how many workers you own. Data parallelism keeps a full copy of the model, its gradients, and its optimizer state on each and every device. If that full copy does not fit on one device, adding a thousand more devices does not help at all, because each of those devices must still hold the whole thing. You have hit the memory wall, and crossing it is what this chapter is about.

The confusion that traps newcomers is the gap between how big a model "is" and how much memory it takes to train. A model is often quoted by its parameter count, and a parameter in half precision is two bytes, so a seven-billion-parameter model sounds like a fourteen-gigabyte object that should slot comfortably into an 80 GB accelerator. It does not, and the reason is that training drags along several more tensors the size of the model, plus a pile of activations that grows with the batch and the sequence length. Before we name a single sharding technique, we have to see that budget clearly, because the budget is what forces every technique that follows.

1. The Wall Data Parallelism Cannot Cross (Beginner)

Recall the structure of data-parallel training. There are $K$ workers; worker $k$ holds an identical replica of the model, computes a gradient on its own shard of the batch, and the workers all-reduce their gradients so that every replica takes the same optimizer step. The elegance is that the replicas stay in lockstep and the math is exact. The cost, the one that matters here, is that the per-device memory footprint of data parallelism is independent of $K$. Eight workers and eight thousand workers each store one complete copy of the parameters, the gradients, and the optimizer state. Data parallelism buys speed by spreading the batch; it buys nothing on the memory axis.
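A minimal sketch of that independence, assuming the 16 bytes of model state per parameter that Section 2 derives. The helper names are hypothetical and exist only for illustration: the worker count divides the batch, but it never appears in the per-device state.

```python
# Sketch: per-device footprint under pure data parallelism.
# Assumes 16 bytes of model state per parameter (fp16 weights and grads,
# fp32 master weights and Adam moments), the figure Section 2 derives.
# Helper names are illustrative, not taken from any library.

def dp_per_device_state_bytes(num_params, num_workers):
    # Every worker holds a full replica, so num_workers never enters the result.
    return 16 * num_params

def dp_per_device_batch(global_batch, num_workers):
    # Only the batch is divided among the workers.
    return global_batch // num_workers

P = 6_700_000_000
for K in (8, 64, 8192):
    state_gb = dp_per_device_state_bytes(P, K) / 1e9
    print(f"K={K:>5}: state per device = {state_gb:.1f} GB, "
          f"samples per device = {dp_per_device_batch(8192, K)}")
```

The state column stays at 107.2 GB for every K while the per-device share of the batch shrinks, which is the wall in miniature.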

This is why a model that overflows one device cannot be rescued by data parallelism alone. The remedy has to reduce what a single device must hold, and there are only a few ways to do that: give each device a different part of the model rather than the whole thing, or split the per-device state itself so no device stores all of any tensor. Both ideas require the model to be physically partitioned across devices, which is the defining feature of every method in this chapter and the reason model parallelism is a genuinely different discipline from the data parallelism of the previous chapter.

Key Insight: Data Parallelism Scales Speed, Not Capacity

Adding data-parallel workers shrinks the time per epoch but never shrinks the memory each worker needs, because every worker stores a full replica of the model and its training state. The moment that replica exceeds one device's memory, you are out of the regime data parallelism can serve, and no number of additional workers changes that. Crossing the memory wall always means partitioning the model itself, never replicating it more.

2. The Real Memory Budget of Training (Intermediate)

Let a model have $P$ parameters and let us train it with mixed precision and the Adam optimizer, the standard recipe for large models. Four distinct things consume accelerator memory, and it pays to count each one in bytes per parameter. First, the parameters used in the forward pass are stored in half precision, $2$ bytes each. Second, a gradient is produced for every parameter, also in half precision, another $2$ bytes. Third, the optimizer state: mixed-precision training keeps a full-precision master copy of the weights ($4$ bytes) so that tiny updates are not lost to rounding, and Adam additionally tracks a first moment $m$ and a second moment $v$, each a full-precision tensor the size of the model ($4 + 4$ bytes). That is $12$ bytes per parameter of optimizer and master state, the part that surprises people. Summing the model-state pieces,

$$M_{\text{state}} = \underbrace{2P}_{\text{weights (fp16)}} + \underbrace{2P}_{\text{grads (fp16)}} + \underbrace{12P}_{\text{master + Adam } m,\,v \text{ (fp32)}} = 16\,P \text{ bytes}.$$

So the model state alone is about $16$ bytes per parameter, eight times the naive "two bytes per parameter" intuition. A seven-billion-parameter model therefore needs roughly $112$ GB just for state, before a single activation is stored, which already overflows an 80 GB accelerator. The fourth consumer is the activations: the intermediate tensors saved during the forward pass so the backward pass can reuse them. Unlike the first three, activation memory does not scale with $P$; it scales with how much data is in flight. For a transformer it grows with the batch size $B$, the sequence length $S$, the hidden width $H$, and the number of layers $L$,

$$M_{\text{act}} \approx c \cdot B \cdot S \cdot H \cdot L,$$

where the constant $c$, in bytes, collects the several intermediate tensors a transformer block keeps per token and per hidden unit (a common rule of thumb puts $c$ near $34$ for a standard half-precision block without recomputation, leaving aside attention-score tensors, whose size grows with $S^2$). The full training budget is $M_{\text{total}} = M_{\text{state}} + M_{\text{act}} = 16P + c\,BSHL$. The figure below draws this budget as four stacked bars, one per model size, against the capacity lines of real accelerators, so you can see at a glance which models cross which wall.

[Figure: stacked-bar chart of training memory (GB) for four model sizes: 125 M (6.6 GB), 1.3 B (44.9 GB), 6.7 B (167.8 GB), and 70 B (1383 GB). Legend: parameters (fp16), gradients (fp16), optimizer + master, activations. Dashed lines mark 40 GB and 80 GB; bars above the dashed lines cannot be trained on one device at all.]
Figure 16.1.1: The training-memory budget for four model sizes, each bar stacked into parameters, gradients, optimizer state, and activations, against the $40$ GB and $80$ GB capacity lines (dashed). The $125$ M model fits easily. The $1.3$ B model already overflows $40$ GB. The $6.7$ B and $70$ B bars (arrows mark where they run off the top) tower far above both lines, and notice how the orange optimizer slab dominates: it is the $12P$ bytes of master weights and Adam moments, not the parameters themselves, that drives the wall.
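Because the state term is fixed by $P$ while the activation term grows with the batch, the budget can be inverted to ask how large a batch a device leaves room for. The sketch below does that under the same $16P + 34\,BSHL$ model; the helper name `max_batch_that_fits` is hypothetical, and the configuration (24 layers, hidden width 2048, sequence length 2048) is the 1.3 B row this section uses in its calculator.

```python
# Sketch: largest batch that fits, using M_total = 16*P + c*B*S*H*L with c = 34.
# The helper name is illustrative; capacities are in binary gigabytes (GiB).
GIB = 1024 ** 3

def max_batch_that_fits(P, seq, hidden, layers, capacity_bytes, c=34):
    state = 16 * P                            # independent of batch size
    per_sample = c * seq * hidden * layers    # activation bytes per batch element
    free = capacity_bytes - state             # what is left after model state
    return max(free // per_sample, 0)         # 0 means the state alone overflows

# A 1.3 B-parameter model: 24 layers, hidden width 2048, sequence length 2048.
b40 = max_batch_that_fits(1_300_000_000, 2048, 2048, 24, 40 * GIB)
b80 = max_batch_that_fits(1_300_000_000, 2048, 2048, 24, 80 * GIB)
print(b40, b80)   # 6 19
```

Under these assumptions a batch of 8 cannot fit in 40 GiB but fits comfortably in 80 GiB. For a 6.7 B model the function returns 0 on both devices, because the state term alone exceeds the capacity before any activation is stored.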

3. Counting It Exactly, in Code (Intermediate)

The formula is more convincing when you watch it cross the device lines for real model sizes. The program below applies $M_{\text{total}} = 16P + 34\,BSHL$ to a ladder of transformer configurations, from a $125$ M model up to $70$ B, prints the four components in gigabytes, and flags whether each total fits on a $40$ GB or an $80$ GB accelerator. It closes by contrasting the naive two-bytes-per-parameter guess against the real training footprint for one model, to make the inflation factor concrete. It is pure Python and uses no accelerator; it is doing arithmetic on the same budget Figure 16.1.1 draws.

GB = 1024 ** 3   # binary gigabyte (GiB): results print about 7% below decimal-GB figures

def breakdown(P, batch, seq, layers, hidden, bytes_per_param=2):
    # Mixed-precision + Adam training-memory model, in bytes.
    params_b = P * bytes_per_param   # fp16 weights used in the forward pass
    grads_b  = P * 2                 # fp16 gradient, one per parameter
    optim_b  = P * 12                # fp32 master weight + Adam m + Adam v (4+4+4)
    # Activations of a transformer scale with batch * seq * hidden * layers.
    act_b    = 34 * batch * seq * hidden * layers
    total = params_b + grads_b + optim_b + act_b
    return params_b, grads_b, optim_b, act_b, total

configs = [   # name,    P,           layers, hidden, batch, seq
    ("125 M",   125_000_000,   12,   768,  8, 2048),
    ("1.3 B",  1_300_000_000,  24,  2048,  8, 2048),
    ("6.7 B",  6_700_000_000,  32,  4096,  8, 2048),
    ("13 B",   13_000_000_000, 40,  5120,  8, 2048),
    ("30 B",   30_000_000_000, 48,  7168,  8, 2048),
    ("70 B",   70_000_000_000, 80,  8192,  8, 2048),
]

for name, P, L, H, B, S in configs:
    p, g, o, a, t = breakdown(P, B, S, L, H)
    f40 = "yes" if t < 40 * GB else "NO"      # fits an A100-40GB?
    f80 = "yes" if t < 80 * GB else "NO"      # fits an A100/H100-80GB?
    print(f"{name:>7} | {p/GB:6.1f}G {g/GB:6.1f}G {o/GB:7.1f}G "
          f"{a/GB:7.1f}G {t/GB:7.1f}G |  {f40:>3}   {f80:>3}")

# The naive guess versus the truth, for the 6.7B model.
_, _, _, _, t = breakdown(6_700_000_000, 8, 2048, 32, 4096)
print(f"\nnaive 2B/param : {6_700_000_000*2/GB:.2f} GB")
print(f"real training  : {t/GB:.2f} GB  ({t/(6_700_000_000*2):.1f}x larger)")
Code 16.1.1: A pure-Python training-memory calculator. Each row sums the four consumers from Section 2 and tests the total against the two capacity lines; the closing block exposes how far the real footprint sits above the two-bytes-per-parameter intuition.
  125 M |    0.2G    0.2G     1.4G     4.8G     6.6G |  yes   yes
  1.3 B |    2.4G    2.4G    14.5G    25.5G    44.9G |   NO   yes
  6.7 B |   12.5G   12.5G    74.9G    68.0G   167.8G |   NO    NO
   13 B |   24.2G   24.2G   145.3G   106.2G   300.0G |   NO    NO
   30 B |   55.9G   55.9G   335.3G   178.5G   625.5G |   NO    NO
   70 B |  130.4G  130.4G   782.3G   340.0G  1383.1G |   NO    NO

naive 2B/param : 12.48 GB
real training  : 167.84 GB  (13.4x larger)
Output 16.1.1: Only the $125$ M model fits on both devices; the $1.3$ B model overflows $40$ GB but still fits in $80$ GB, and everything from $6.7$ B up overflows even $80$ GB by a wide margin. The $6.7$ B model that "should" be a $12.5$ GB object actually needs $167.8$ GB to train, $13.4$ times the naive guess, because optimizer state and activations dwarf the weights.

Read the columns and the wall comes into focus. The optimizer column is the single largest contributor for every model from $6.7$ B up, exactly the $12P$ bytes the formula predicted, which is why simply sharding the optimizer state across devices (the subject of Section 16.4) buys the largest single win. For those models activations are the second giant (at $1.3$ B and below, with this batch and sequence length, they are the largest term), and they are the one term you can shrink by spending extra compute rather than extra memory, which is the trade Section 16.8 makes. The parameters and gradients, the parts people think of as "the model," are the two smallest slivers in every row. A serious large-model training run is mostly storing the machinery of optimization, not the model.
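To read the columns as proportions rather than absolute sizes, the four terms can be turned into shares of the total. This is a small sketch using the same budget formula and two configurations from Code 16.1.1; the function name `shares` is hypothetical.

```python
# Sketch: each budget term as a fraction of the total, for two configurations.
def shares(P, batch, seq, layers, hidden):
    parts = {
        "parameters": 2 * P,                                  # fp16 weights
        "gradients": 2 * P,                                   # fp16 gradients
        "optimizer": 12 * P,                                  # fp32 master + Adam m, v
        "activations": 34 * batch * seq * hidden * layers,    # rule-of-thumb term
    }
    total = sum(parts.values())
    return {name: size / total for name, size in parts.items()}

for label, cfg in [("1.3 B", (1_300_000_000, 8, 2048, 24, 2048)),
                   ("6.7 B", (6_700_000_000, 8, 2048, 32, 4096))]:
    s = shares(*cfg)
    largest = max(s, key=s.get)
    print(label, {k: f"{v:.0%}" for k, v in s.items()}, "largest:", largest)
```

Because the activation term depends on $B$ and $S$ rather than on $P$ alone, the ranking of the terms can differ from one model size to the next, so it is worth checking for the configuration you actually plan to run.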

Fun Note: The Model Is the Smallest Thing in the Room

It feels backwards, but in Output 16.1.1 the actual weights of the $70$ B model take $130$ GB while the optimizer takes $782$ GB. If the parameters were a person, the Adam moments and master copy would be that person's moving truck, packing crates, and inventory spreadsheet. Training is less about carrying the model and more about carrying everything you need to keep improving it.

4. The Taxonomy of Fixes This Chapter Develops (Beginner)

Every term in the budget $16P + 34\,BSHL$ points at a different way to fit a model that does not fit. There are four families of remedy, and the rest of this chapter devotes a cluster of sections to each. Table 16.1.1 names them, ties each to the part of the budget it attacks, and points to where the book develops it. The unifying idea of the first three is that each partitions some tensor across devices so that no single device must hold all of it; the fourth avoids storing a tensor at all and recomputes it on demand.

Table 16.1.1: The four families of model-fitting remedy, the memory term each one attacks, and where this chapter develops it. Real large-model training stacks several of these together with the data parallelism of Chapter 15.
Remedy | What it partitions | Budget term it attacks | Where
Tensor (operator) parallelism | Each layer's matrices, split across devices | Parameters, grads, activations of a layer | Section 16.2
Pipeline parallelism | Whole layers grouped into stages on different devices | Parameters, grads, optimizer per stage | Section 16.3
Sharded state (ZeRO / FSDP) | Optimizer, gradients, then parameters, sharded across data-parallel ranks | The $12P$ optimizer state first, then all $16P$ | Sections 16.4 to 16.5
Activation checkpointing | Nothing; recompute activations instead of storing them | The $34\,BSHL$ activation term | Section 16.8

The first remedy splits the operators. A single matrix multiply inside a layer is cut so that each device computes a slice of the output, and the slices are stitched with a collective; this is tensor parallelism, and Section 16.2 builds it from the matrix algebra up. The second remedy splits the layers: assign the first group of layers to device one, the next group to device two, and stream micro-batches through the resulting pipeline so the stages stay busy. The third remedy keeps the data-parallel structure of Chapter 15 but refuses to replicate the state; instead it shards the optimizer state, then the gradients, then the parameters across the data-parallel ranks, gathering each tensor only when a layer needs it. This is the ZeRO idea behind FSDP, and it is the direct descendant of the parameter-server sharding from Chapter 11. The fourth remedy is different in kind: it stores almost no activations and recomputes them during the backward pass, paying extra compute to reclaim the $34\,BSHL$ term.
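The first remedy can be previewed on a toy example. The sketch below simulates a column-wise split of one weight matrix across two "devices" (plain list entries here, with concatenation standing in for the collective) and checks that the stitched result equals the unsplit multiply. It is an illustration with hypothetical helper names, not the construction Section 16.2 develops.

```python
# Sketch: column-parallel matrix multiply, simulated on one machine.
# "Devices" are list entries; the final concatenation stands in for the
# collective (an all-gather) that stitches the output slices together.

def matmul(X, W):
    # Plain product of two matrices stored as lists of rows.
    return [[sum(x * w for x, w in zip(row, col)) for col in zip(*W)] for row in X]

def split_columns(W, num_devices):
    # Give each device a contiguous block of W's columns
    # (assumes the column count divides evenly).
    width = len(W[0]) // num_devices
    return [[row[d * width:(d + 1) * width] for row in W] for d in range(num_devices)]

X = [[1, 2, 3], [4, 5, 6]]                        # 2 x 3 input activations
W = [[1, 0, 2, 1], [0, 1, 1, 3], [2, 1, 0, 1]]    # 3 x 4 weight matrix

shards = split_columns(W, num_devices=2)          # two 3 x 2 shards of W
partials = [matmul(X, W_d) for W_d in shards]     # each device: a 2 x 2 output slice
Y_parallel = [sum((p[i] for p in partials), []) for i in range(len(X))]

assert Y_parallel == matmul(X, W)                 # same result as the unsplit multiply
```

Each simulated device stores only half of the weight matrix and produces half of the output columns, which is where the per-device savings on parameters, gradients, and that layer's activations come from.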

Thesis Thread: Replication Gives Way to Partition

Chapter 15 distributed training by replicating the model and partitioning the data; the collective that held it together was all-reduce. This chapter turns the dial the other way. To cross the memory wall we partition the model and accept that no device holds the whole thing, which makes new collectives essential: reduce-scatter and all-gather to shard and regather state in ZeRO and FSDP (Section 16.4), and point-to-point sends to pass activations between pipeline stages (Section 16.3). The primitives first met in Chapter 4 are exactly the tools that make a partitioned model behave like one coherent model, just as all-reduce made replicated workers behave like one.
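One way to see how the new collectives relate to the old one: a reduce-scatter followed by an all-gather produces the same result as an all-reduce. The sketch below simulates the three collectives over per-rank lists on one machine. The function names mirror the collective names but are hypothetical stand-ins, not calls into any communication library.

```python
# Sketch: simulate reduce-scatter and all-gather over per-rank gradient lists.
# Each inner list is one rank's local gradient; nothing here touches a network.

def reduce_scatter(per_rank):
    # Rank r ends up with the summed r-th chunk of every rank's tensor.
    world = len(per_rank)
    chunk = len(per_rank[0]) // world       # assumes the length divides evenly
    return [
        [sum(t[i] for t in per_rank) for i in range(r * chunk, (r + 1) * chunk)]
        for r in range(world)
    ]

def all_gather(shards):
    # Every rank ends up with the concatenation of all ranks' shards.
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]

def all_reduce(per_rank):
    # Every rank ends up with the elementwise sum over ranks.
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]

grads = [[1, 2, 3, 4], [10, 20, 30, 40]]    # 2 ranks, 4 gradient elements each
shards = reduce_scatter(grads)              # [[11, 22], [33, 44]]
assert all_gather(shards) == all_reduce(grads)
```

Between the two halves each rank holds only its own summed shard, and that intermediate state is what a sharded scheme can exploit: a rank can keep and update just its slice rather than the full tensor.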

Library Shortcut: DeepSpeed's ZeRO Memory Estimator in One Call

The hand calculation in Code 16.1.1 is the same kind of arithmetic production tooling does for you before you launch. DeepSpeed ships memory estimators for its ZeRO stages; the ZeRO-3 one below reports, for a given model and node layout, how much GPU and CPU memory the model state would need per device under each offload option, so you can pick a sharding strategy without trial and error:

# pip install deepspeed
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live
import torch.nn as nn

# Note: this builds the full model (a few billion parameters) in CPU RAM first.
model = nn.Transformer(d_model=4096, num_encoder_layers=32, num_decoder_layers=0)
# Prints estimated per-GPU and per-CPU memory for ZeRO-3 on the given layout
# (here 1 node with 8 GPUs), for each optimizer/parameter offload option.
estimate_zero3_model_states_mem_needs_all_live(model, num_gpus_per_node=8, num_nodes=1)
Code 16.1.2: The roughly twenty lines of budget arithmetic in Code 16.1.1 collapse to a single DeepSpeed call that estimates per-GPU and per-CPU model-state memory under ZeRO-3 for the chosen node layout, with and without offload. The library handles the sharding math, the offload accounting, and a buffer allowance on top, topics that Sections 16.4 and 16.5 unpack; its per-parameter constants may differ slightly from the $16P$ used here.
Practical Example: The 7B Model That Refused to Load

Who: A research engineer at a startup fine-tuning an open-weights $6.7$ B language model on domain text.

Situation: The team had a single node with eight $40$ GB A100s and assumed that, since the half-precision weights were only about $13$ GB, the model would train comfortably on one GPU with the rest free for a big batch.

Problem: The job hit an out-of-memory error before the first step finished, even with a batch size of one. The $13$ GB model was demanding far more than $40$ GB the moment the optimizer was constructed.

Dilemma: Rent $80$ GB GPUs at a premium and hope one is enough, drop to a worse optimizer to save state, or partition the model across the eight GPUs they already owned.

Decision: They ran the Code 16.1.1 budget first and saw the real footprint was about $168$ GB, more than four $40$ GB GPUs even before a generous batch. No single $80$ GB GPU would have fixed it either. The numbers said partition, not upgrade.

How: They wrapped the model in FSDP (the ZeRO-3 idea of Section 16.5) to shard the $16P$ state across all eight GPUs, and turned on activation checkpointing to cut the $34\,BSHL$ term, exactly the two remedies Table 16.1.1 points at for the optimizer and activation columns.

Result: The sharded job fit with room for a batch of $16$, on the hardware they already had, at no extra rental cost. The per-GPU state dropped to roughly one eighth of the unsharded figure.

Lesson: Compute the full $16P + 34\,BSHL$ budget before you reach for a bigger GPU. The bottleneck is almost never the weights, so the fix is almost always to partition the optimizer state and activations, not to buy more memory per device.

5. Why This Is the Right Time to Pay the Cost (Intermediate)

Partitioning a model is not free, and it is worth being honest about what it costs so the remedies do not look like magic. Splitting a layer's matrices across devices (tensor parallelism) inserts a collective into the forward and backward pass of every layer, so it demands a fast intra-node interconnect and usually stays inside one machine. Splitting layers into pipeline stages introduces idle time, the "bubble" while the pipeline fills and drains, which Section 16.3 quantifies and works to shrink. Sharding the state with ZeRO and FSDP trades memory for extra all-gather traffic, because a parameter sharded away must be gathered back before the layer that needs it can run. Activation checkpointing trades memory for recomputation, typically adding around a third more compute. Every remedy moves load off the memory axis and onto the communication or compute axis, which is the same fundamental tradeoff the performance models of Chapter 3 taught you to weigh.
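The "around a third" figure for checkpointing can be recovered with a short estimate. It rests on a common assumption that this section does not state: the backward pass costs about twice the forward pass, and checkpointing reruns the whole forward pass once. With $F$ (a symbol introduced here for illustration) denoting the forward cost of one step,

$$C_{\text{plain}} = F + 2F = 3F, \qquad C_{\text{ckpt}} = F + F + 2F = 4F, \qquad \frac{C_{\text{ckpt}} - C_{\text{plain}}}{C_{\text{plain}}} = \frac{F}{3F} \approx 33\%.$$

Recomputing only some layers, or a backward pass that is cheaper or dearer than $2F$, moves the overhead below or above that figure.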

The reason to pay these costs is simply that the alternative is not training the model at all. For any model past a few billion parameters the budget in Output 16.1.1 makes data parallelism impossible on its own, so the question is never "partition or replicate?" but "which partition, and how do they combine?" Real foundation-model training, the subject of Chapter 19, stacks all four remedies of Table 16.1.1 together with the data parallelism of Chapter 15 into a single configuration, often called 3D or 4D parallelism. This chapter builds those remedies one at a time so that, by the end, that stacked configuration reads as a deliberate set of choices rather than a wall of acronyms.

Research Frontier: Pushing the Wall Outward (2024 to 2026)

Because the memory wall sets the ceiling on what a given cluster can train, several active lines work to lower the budget itself. Fully Sharded Data Parallel matured into PyTorch's native FSDP2 with per-parameter sharding and cleaner composition with tensor parallelism, narrowing the gap to DeepSpeed ZeRO. On the optimizer term specifically, memory-efficient optimizers such as GaLore (Zhao et al., 2024), which projects gradients into a low-rank subspace, and the Adam-Mini family report training large models while storing a small fraction of the $12P$ Adam state, attacking the orange slab that dominates Figure 16.1.1. On the activation term, FlashAttention-style fused kernels keep the attention activation footprint near linear in sequence length, and selective recomputation chooses which activations to checkpoint by cost. The throughline is that the $16P + 34\,BSHL$ budget of this section is treated as an engineering target to be driven down, not a constant of nature; we revisit the optimizer side with the machinery to evaluate it in Chapter 10.

We now have the wall (data parallelism cannot grow a device), the measurement (the $16P + 34\,BSHL$ budget, crossing $40$ GB and $80$ GB exactly where Output 16.1.1 says), and the map (the four remedies of Table 16.1.1). The first remedy reaches inside a single layer and splits its matrices across devices so that even one layer too large for one accelerator can run. That is tensor parallelism, and it begins in Section 16.2.

Exercise 16.1.1: Which Remedy for Which Term? (Conceptual)

Using the budget $16P + 34\,BSHL$ and Table 16.1.1, decide which remedy you would reach for first in each case, and say which budget term binds: (a) a model whose parameters alone are too large for one device even at a batch size of one; (b) a model that fits but only if you could halve the optimizer state; (c) a model that fits at a short sequence length but overflows when you quadruple the sequence length with the batch held fixed. Explain why applying the wrong remedy would leave the binding term untouched.

Exercise 16.1.2: Shard the Budget (Coding)

Extend Code 16.1.1 with a function per_device(P, batch, seq, layers, hidden, world_size, stage) that returns the per-device memory under ideal ZeRO sharding, where stage=1 shards only the $12P$ optimizer state across world_size devices, stage=2 additionally shards the $2P$ gradients, and stage=3 shards all of the $16P$ state (activations stay per device). For the $6.7$ B configuration, print the per-device total for stages $1$, $2$, $3$ at world_size in $\{8, 16, 64\}$, and report the smallest world size at each stage that brings the per-device total under $40$ GB. Comment on which stage first makes the model trainable on $40$ GB hardware.

Exercise 16.1.3: The Cost of Recomputation (Analysis)

Activation checkpointing replaces the $34\,BSHL$ activation term with a much smaller stored footprint (roughly the activations at the checkpointed layer boundaries) at the price of recomputing the forward pass during the backward pass, adding about $33\%$ more compute. For the $13$ B configuration in Output 16.1.1, estimate the new total memory if checkpointing cuts the activation term to one tenth of its stored value, and state whether that alone brings the model under $80$ GB. Then argue, using the compute-versus-memory framing of Chapter 3, when paying $33\%$ more compute to save activation memory is worthwhile and when it is not.