Part IV: Parallel Deep Learning and Large Models
Chapter 16: Model, Pipeline, and Sharded Parallelism

3D and 4D Parallelism

"I am tensor-parallel to my left, pipeline-parallel to my front, and data-parallel to my distant cousins three racks over. I contain a sixty-fourth of one layer, and I have never been more sure of my place."

A Device That Found Its Coordinates
Big Picture

No single parallelism strategy trains a frontier model; the largest runs compose three or more of them at once, arranging every device into a grid whose axes are deliberately mapped onto the interconnect tiers they each need. Tensor parallelism (Section 16.2) splits one layer across devices and demands the fastest link in the cluster. Pipeline parallelism (Section 16.3) splits the layer stack across stages and tolerates a slower link. Data parallelism (Chapter 15) replicates the whole arrangement and rides the slowest fabric of all. Stack them and you get 3D parallelism; add expert parallelism for mixtures-of-experts and sequence parallelism for long context and you reach 4D and beyond. This section shows how the axes compose, why their product must equal the device count exactly, and why choosing a good shape is itself an optimization problem that frontier labs now hand to automated search.

Each preceding section of this chapter gave you one tool. Tensor parallelism cut a matrix multiply across devices so a layer too wide for one accelerator could still run. Pipeline parallelism cut the layer stack into stages so a model too deep for one device could still fit. Sharded data parallelism (Section 16.5) split the optimizer state so a replica's memory footprint shrank. Sequence and context parallelism (Section 16.7) cut the activation along the token axis so a long sequence could still be held. Used alone, each tool runs out of room: tensor parallelism stops scaling once the all-reduce per layer leaves the fast intra-node link, pipeline parallelism wastes devices in its startup bubble once the stage count grows, and pure data parallelism cannot help at all when a single replica does not fit in memory. The frontier-model answer is not to pick one tool but to combine them so each covers the others' limits.

The combination is not arbitrary. The defining idea of this section is that the device grid is laid out to match the cluster's physical topology, the same topology-aware placement principle introduced for collectives in Section 4.9. Communication on each axis is mapped onto the interconnect tier that suits its volume and frequency. Get that mapping right and a thousand-device run keeps its accelerators busy; get it wrong and the same thousand devices spend most of their time waiting on the network.

1. Stacking Three Axes Into a Grid Intermediate

Picture the devices not as a flat list but as a three-dimensional grid. Along one axis sits the tensor-parallel group: a small set of devices, typically the eight inside a single node, that jointly hold one layer and all-reduce partial results at every layer of the forward and backward pass. That traffic is large and incessant, so it must ride the fastest link the cluster owns, the intra-node NVLink-class fabric. Along a second axis sits the pipeline: each point on this axis is a different stage of the layer stack, and stages exchange only the activations at their boundaries, a few times per micro-batch. That traffic is lighter, so the pipeline axis can span the medium-speed node-to-node links. Along the third axis sits data parallelism: each point is a full replica of the entire tensor-by-pipeline arrangement, and replicas talk only once per step, when they all-reduce gradients. That is the rarest communication of all, so the data axis is placed outermost, across the slowest cross-rack fabric.

The rule is a single sentence: put the tensor-parallel group on the fastest links, the pipeline across slower links, and data parallelism outermost. It falls directly out of matching communication frequency to bandwidth. Figure 16.9.1 draws the grid and the tier mapping; the demo in Section 4 will then put numbers behind why the ordering is the cheap one.
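One way to make the grid concrete is to map each global rank to its grid coordinates. The sketch below is illustrative only: it assumes a row-major layout with the tensor axis varying fastest, eight devices per node, and consecutive ranks sharing a node. The helper names `coords` and `node_of` are hypothetical and do not come from any framework.

```python
# Illustrative sketch: rank -> (d, p, t) coordinates in a row-major grid.
# Assumptions (not from any framework): consecutive ranks share a node,
# and the tensor axis is innermost, so it varies fastest with rank.
GPUS_PER_NODE = 8

def coords(rank, dp, pp, tp):
    """Return the (d, p, t) coordinates of a global rank; tensor axis innermost."""
    rest, t = divmod(rank, tp)
    d, p = divmod(rest, pp)
    return d, p, t

def node_of(rank):
    """Node index, assuming consecutive ranks fill one node before the next."""
    return rank // GPUS_PER_NODE

dp, pp, tp = 2, 4, 8          # 2 * 4 * 8 = 64 devices
world = dp * pp * tp

# Ranks that share (d, p) and differ only in t form one tensor-parallel group.
tp_group = [r for r in range(world) if coords(r, dp, pp, tp)[:2] == (0, 1)]
# Ranks that share (d, t) and differ only in p form one pipeline.
pipeline = [r for r in range(world)
            if (coords(r, dp, pp, tp)[0], coords(r, dp, pp, tp)[2]) == (0, 0)]

print("tensor group (d=0, p=1):", tp_group)
print("nodes touched by it    :", sorted({node_of(r) for r in tp_group}))
print("pipeline (d=0, t=0)    :", pipeline)
print("nodes touched by it    :", sorted({node_of(r) for r in pipeline}))
```

Under these assumptions the tensor group occupies ranks 8 through 15, all on one node, while the pipeline strides across four different nodes. That is the intended split between the fast and the medium tier.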

[Figure 16.9.1 diagram. Replica 1 (one full model): three tensor-parallel groups, one per pipeline stage. Tensor axis: fastest link (NVLink, intra-node), all-reduce every layer. Pipeline axis: medium link (node-to-node), stage-boundary activations per micro-batch. Data axis: slowest fabric (cross-rack), gradient all-reduce once per step; Replica 2 is an identical TP × PP grid holding its own data shard. Constraint: (tensor degree) × (pipeline degree) × (data degree) = total devices.]
Figure 16.9.1: The 3D device grid. Inside one replica, four devices form a tensor-parallel group sharing a layer on the fast NVLink tier (red), and three such groups stack into a pipeline of stages on the medium node-to-node tier (orange). Each replica is duplicated along the data axis on the slow cross-rack tier (green). The mapping puts the most frequent collective on the fastest link. The product of the three degrees must equal the device count exactly, the constraint formalized in Section 2.
Key Insight: Map the Hottest Collective to the Fastest Wire

The whole art of multi-axis parallelism reduces to one ordering principle. Rank the parallelism axes by how often and how much they communicate: tensor parallelism every layer (hottest), pipeline parallelism every micro-batch (warm), data parallelism every step (coolest). Rank the cluster's links by bandwidth: intra-node fastest, node-to-node medium, cross-rack slowest. Then pair them in order. Any other pairing forces a hot collective onto a cold wire and the accelerators stall. This is the same topology-aware placement that Section 4.9 taught for a single collective, now applied to three of them simultaneously.
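The pairing rule can be sketched in a few lines. The frequency ranks below are rough relative values chosen for illustration, and the bandwidths are the example figures used in this section's demo (Code 16.9.1); none of them is a measurement.

```python
# Illustrative sketch: pair the hottest collective with the fastest link.
# Frequency ranks are relative; bandwidths are assumed example values.
axes = {                     # axis -> relative communication frequency (higher = hotter)
    "data": 1,               # once per step
    "tensor": 3,             # every layer
    "pipeline": 2,           # every micro-batch
}
links = {                    # link tier -> bandwidth in bytes/sec
    "cross-rack": 25e9,
    "intra-node": 600e9,
    "node-to-node": 100e9,
}

hot_first = sorted(axes, key=axes.get, reverse=True)     # hottest axis first
fast_first = sorted(links, key=links.get, reverse=True)  # fastest link first
placement = dict(zip(hot_first, fast_first))             # pair them in order

for axis, link in placement.items():
    print(f"{axis:8s} -> {link}")
```

Sorting both lists and zipping them is the whole rule: tensor lands on the intra-node link, pipeline on node-to-node, and data on the cross-rack fabric.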

2. The Configuration Constraint and the Combinatorial Space Intermediate

A 3D configuration is a triple of parallel degrees: the tensor degree $t$, the pipeline degree $p$, and the data degree $d$. Because every device occupies exactly one cell of the grid, the degrees are bound by a hard identity. With $D$ total devices,

$$d \times p \times t = D.$$

This is an ordered factorization of $D$, and it constrains the search sharply: only triples whose product is exactly $D$ are legal. When you add expert parallelism with degree $e$ for a mixture-of-experts model (the subject of Chapter 17) and sequence parallelism with degree $s$ for long context (Section 16.7), the grid gains axes and the constraint generalizes to

$$d \times p \times t \times e \times s = D,$$

which is why practitioners speak of 4D and 5D parallelism. The number of legal configurations is the number of ordered factorizations of $D$ into the chosen count of axes, and it grows quickly. For the three-axis case the count is $\sum_{t \mid D} \tau(D/t)$, where $\tau(n)$ is the number of divisors of $n$: fix the tensor degree $t$, then count the ways to split the remaining $D/t$ devices between $p$ and $d$. For $D = 64$ this gives $7 + 6 + 5 + 4 + 3 + 2 + 1 = 28$, and for a device count with many factors it reaches into the dozens or hundreds. Each configuration has a wildly different memory footprint and communication cost, and only some are feasible at all, because a configuration that leaves each device holding more than its memory can fit simply does not run.
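The count of legal shapes can be checked directly. The following is a minimal sketch; the function name `count_shapes` is hypothetical, and it counts by brute-force recursion over divisors rather than any closed form.

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def count_shapes(D, axes):
    """Number of ordered tuples of `axes` positive integers whose product is D."""
    if axes == 1:
        return 1                      # the last axis is forced to equal D
    # Choose the first axis's degree k (a divisor of D), then count the rest.
    return sum(count_shapes(D // k, axes - 1)
               for k in range(1, D + 1) if D % k == 0)

print(count_shapes(64, 3))   # 3-axis shapes (d, p, t) for 64 devices       -> 28
print(count_shapes(64, 5))   # 5-axis shapes (d, p, t, e, s) for 64 devices -> 210
```

Adding two axes to the same 64 devices raises the count from 28 to 210, before memory feasibility removes any of them.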

So the engineer faces a discrete optimization: among all factorizations of $D$ that fit in memory, find the one that minimizes time per step. There is no closed form, because the cost of each axis depends on the model shape, the batch size, and the exact bandwidths of the cluster. This is why configuration search has become its own subfield, treated as the frontier in Section 5 and as the central concern of Section 16.10, the next section, which turns the search into a tuning methodology you can apply by hand.

Fun Note: Why 64 Is a Friendlier Number Than 60

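Clusters are built in powers of two for exactly this reason. With $D = 64 = 2^6$ you get a generous lattice of factorizations: $(d,p,t)$ can be $(1,8,8)$, $(2,4,8)$, $(4,4,4)$, $(8,8,1)$, and many more, and every one of them is built from powers of two, so the search has room to find a good shape. With $D = 60$ there are actually more ordered triples (54 against 28), but most of them contain a factor of 3 or 5, which neither fills an eight-device node nor evenly divides the power-of-two head counts and hidden sizes that transformer layers commonly use, so the usable shapes are few and may all be poor. Hardware vendors and cluster architects are doing you a quiet favor every time they ship nodes in multiples of eight.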

3. From 3D to 4D and Beyond Advanced

The third dimension is rarely the last. Two further axes appear in frontier stacks often enough that 3D is best read as a floor, not a ceiling. The first is expert parallelism. A mixture-of-experts layer replaces one feed-forward block with many experts and routes each token to a few of them; placing different experts on different devices turns the routing into an all-to-all collective and adds an axis whose degree $e$ multiplies into the device-count identity. Chapter 17 develops this axis in full, including why its all-to-all is the most placement-sensitive collective of them all. The second is sequence or context parallelism, which splits a single long sequence's activations across devices so that context lengths far beyond one device's memory become trainable; Section 16.7 introduced it, and in a 4D stack it sits as yet another factor in the product.

What matters conceptually is that none of these axes is special. Each is just another way of cutting the work, each carries a characteristic collective (tensor parallelism an all-reduce, pipeline a point-to-point send, data parallelism an all-reduce, expert parallelism an all-to-all, sequence parallelism an all-gather), and each must be mapped onto the interconnect tier its collective can afford. The device-count identity grows a factor; the placement rule stays the same. A frontier training stack is the disciplined composition of all of them, which is exactly the arrangement that Chapter 19 assembles end-to-end when it puts a real foundation-model run together.

Thesis Thread: The Whole Book Composes Here

This section is where the book's axes stop being separate chapters and become one grid. Data parallelism from Chapter 15, the model and pipeline splits from earlier in this chapter, the sharded optimizer state from Section 16.5, the expert routing of Chapter 17, and the topology-aware placement of Section 4.9 all meet in a single device grid governed by one multiplicative constraint. The thesis that scale-out is the composition of distinct, composable cuts (introduced with the six axes in Section 1.1) reaches its sharpest expression here: frontier models are trained by stacking the cuts, not by inventing a new one.

4. A Configuration Search You Can Run Intermediate

The cleanest way to feel the combinatorial space is to enumerate it. The program below takes a device count $D$, walks every legal factorization into $(d, p, t)$, estimates each one's per-device memory and per-step communication with the simple cost models this chapter has built, discards the infeasible shapes that would exceed device memory, and reports the cheapest survivor. It is a miniature of what production auto-parallelizers do, with the cost models kept transparent enough to read line by line. The tensor-parallel bandwidth is deliberately downgraded once the group spills past one node, which is the single most important piece of topology awareness in the model.

from itertools import product

D            = 64          # total devices (must equal dp * pp * tp)
P_params     = 70e9        # model parameters
bytes_state  = 16          # bytes/param for params+grad+Adam state (mixed precision)
L_layers     = 80          # transformer layers (limits useful pipeline depth)
act_per_dev  = 3.0e9       # activation bytes one device would hold (1 stage, full batch)
mem_per_dev  = 80e9        # HBM per device (e.g. 80 GB accelerator)
GPUS_PER_NODE = 8          # devices sharing the fast NVLink domain inside one node
microbatches = 16          # pipeline micro-batches per step (sets the bubble size)
t_stage      = 8.0e-3      # compute time per pipeline stage per micro-batch, sec

BW_fast      = 600e9       # NVLink-class, bytes/sec (only WITHIN a node)
BW_pipe      = 100e9       # node-to-node links (medium tier), bytes/sec
BW_data      = 25e9        # cross-rack fabric (slow tier), bytes/sec

def tensor_bw(tp):
    # tensor-parallel all-reduce stays on NVLink only while the group fits a node;
    # spilling past GPUS_PER_NODE drops it onto the slower node-to-node link.
    return BW_fast if tp <= GPUS_PER_NODE else BW_pipe

def divisors(n):
    return [k for k in range(1, n + 1) if n % k == 0]

def memory_per_device(dp, pp, tp):
    state = bytes_state * P_params / (dp * pp * tp)   # ZeRO-style sharding over dp
    act   = act_per_dev / (pp * tp)                   # pp cuts layers, tp cuts within
    return state + act

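def comm_cost(dp, pp, tp):
    # Toy model: the three bandwidth terms each move an (n - 1) / n share of their volume.
    layer_bytes = bytes_state * P_params / L_layers         # per-layer volume proxy
    t_tensor = (tp - 1) * layer_bytes / (tp * tensor_bw(tp)) if tp > 1 else 0.0
    t_pipe   = (pp - 1) * act_per_dev / (pp * BW_pipe)       if pp > 1 else 0.0
    # pipeline bubble: (pp - 1) idle stage-slots amortized over the micro-batches
    t_bubble = (pp - 1) / microbatches * t_stage            if pp > 1 else 0.0
    grad_bytes = bytes_state * P_params / (pp * tp)         # per-replica volume on the data axis
    t_data   = (dp - 1) * grad_bytes / (dp * BW_data)       if dp > 1 else 0.0
    return t_tensor + t_pipe + t_bubble + t_data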

ranked = []
for tp, pp in product(divisors(D), divisors(D)):
    if D % (tp * pp) != 0:
        continue
    dp = D // (tp * pp)
    if pp > L_layers:                         # cannot have more stages than layers
        continue
    mem = memory_per_device(dp, pp, tp)
    if mem > mem_per_dev:                      # infeasible: would OOM
        continue
    ranked.append((comm_cost(dp, pp, tp), dp, pp, tp, mem))
ranked.sort()

print(f"devices D                 : {D}")
print(f"feasible configurations   : {len(ranked)}")
cost, dp, pp, tp, mem = ranked[0]
print(f"best (dp, pp, tp)         : ({dp}, {pp}, {tp})   product = {dp*pp*tp}")
print(f"  memory / device         : {mem/1e9:6.2f} GB   (cap {mem_per_dev/1e9:.0f} GB)")
print(f"  comm time / step        : {cost*1e3:6.2f} ms")
print("\ntop 4 feasible shapes (dp, pp, tp -> comm ms, mem GB):")
for cost, dp, pp, tp, mem in ranked[:4]:
    print(f"  ({dp:2d}, {pp:2d}, {tp:2d})  ->  {cost*1e3:6.2f} ms   {mem/1e9:6.2f} GB")
three_d = [r for r in ranked if r[1] > 1 and r[2] > 1 and r[3] > 1]
if three_d:
    cost, dp, pp, tp, mem = three_d[0]
    print(f"\nbest genuine 3D shape     : (dp={dp}, pp={pp}, tp={tp})  ->  {cost*1e3:6.2f} ms")
Code 16.9.1: A standard-library configuration search over the $(d, p, t)$ factorizations of $D$ devices. The feasibility filter on memory_per_device and the topology-aware tensor_bw downgrade are the two lines that make the cheapest shape a realistic one rather than a degenerate one.
devices D                 : 64
feasible configurations   : 28
best (dp, pp, tp)         : (1, 8, 8)   product = 64
  memory / device         :  17.55 GB   (cap 80 GB)
  comm time / step        :  50.17 ms

top 4 feasible shapes (dp, pp, tp -> comm ms, mem GB):
  ( 1,  8,  8)  ->   50.17 ms    17.55 GB
  ( 1, 16,  4)  ->   53.12 ms    17.55 GB
  ( 1, 32,  2)  ->   56.23 ms    17.55 GB
  ( 1, 64,  1)  ->   61.03 ms    17.55 GB

best genuine 3D shape     : (dp=2, pp=4, tp=8)  ->  744.42 ms
Output 16.9.1: All twenty-eight legal factorizations of 64 are memory-feasible under these settings. The cheapest keeps the tensor group at exactly 8 (one node, so its all-reduce stays on NVLink) and uses the rest for the pipeline. Pushing the tensor degree past 8 or adding a data axis raises the step time because the heaviest collectives then cross slow wires, as the cheapest genuine 3D shape shows.

The winning shape, $(d, p, t) = (1, 8, 8)$, is exactly the textbook arrangement: a tensor-parallel group of eight devices confined to one node where the per-layer all-reduce stays on the fast link, and a pipeline of eight stages spanning the slower node-to-node fabric. The moment the search tries to widen the tensor group past a node, the tensor_bw downgrade fires and the cost climbs. The moment it introduces a data axis, the gradient all-reduce over the slow cross-rack fabric dominates everything, which is why even the cheapest genuine three-axis shape, $(2, 4, 8)$, costs roughly fifteen times more per step here. That is not a flaw in 3D parallelism; it is the search correctly telling you that at this device count and model size, the extra axis is not yet earning its communication. Add more devices or a larger model and the balance shifts. The point is that the answer is a search result, not a fixed recipe.
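To see where the time goes, it helps to split the total into its four terms. The sketch below reuses the constants and the toy cost model of Code 16.9.1; the helper name `breakdown` is hypothetical, and the numbers are model estimates, not measurements.

```python
# Illustrative breakdown using the same toy cost model and constants as Code 16.9.1.
P_params, bytes_state, L_layers = 70e9, 16, 80
act_per_dev, microbatches, t_stage = 3.0e9, 16, 8.0e-3
BW_fast, BW_pipe, BW_data, GPUS_PER_NODE = 600e9, 100e9, 25e9, 8

def breakdown(dp, pp, tp):
    """Return the four cost terms (in seconds) for one (dp, pp, tp) shape."""
    bw_t = BW_fast if tp <= GPUS_PER_NODE else BW_pipe   # topology-aware downgrade
    layer_bytes = bytes_state * P_params / L_layers
    grad_bytes = bytes_state * P_params / (pp * tp)
    return {
        "tensor": (tp - 1) * layer_bytes / (tp * bw_t) if tp > 1 else 0.0,
        "pipe":   (pp - 1) * act_per_dev / (pp * BW_pipe) if pp > 1 else 0.0,
        "bubble": (pp - 1) / microbatches * t_stage if pp > 1 else 0.0,
        "data":   (dp - 1) * grad_bytes / (dp * BW_data) if dp > 1 else 0.0,
    }

for shape in [(1, 8, 8), (2, 4, 8)]:
    terms = breakdown(*shape)
    parts = "  ".join(f"{k}={v * 1e3:7.2f} ms" for k, v in terms.items())
    print(f"{shape}: {parts}  total={sum(terms.values()) * 1e3:7.2f} ms")
```

Under these constants the data term for $(2, 4, 8)$ is 700 ms of its roughly 744 ms total, while the three terms of $(1, 8, 8)$ sum to about 50 ms with no data term at all.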

Library Shortcut: A Mesh Replaces the Bookkeeping

The grid in Code 16.9.1 is hand-built bookkeeping. Production frameworks expose the same grid as a first-class device mesh, and you declare the axes by name instead of tracking factor arithmetic. In PyTorch the mesh is a few lines, and the framework derives the process group for each named axis, so every collective can be issued on the right group:

# Run with: torchrun --nproc_per_node=8 --nnodes=8 thisfile.py   (64 devices)
from torch.distributed.device_mesh import init_device_mesh

# Name the three axes; the product of the sizes must equal the world size.
mesh = init_device_mesh("cuda", (1, 8, 8), mesh_dim_names=("dp", "pp", "tp"))

tp_group = mesh["tp"].get_group()   # the 8 devices that all-reduce each layer
pp_group = mesh["pp"].get_group()   # the 8 pipeline stages
dp_group = mesh["dp"].get_group()   # the data-parallel replicas (size 1 here)
Code 16.9.2: The same $(1, 8, 8)$ grid as Output 16.9.1, declared as a named device mesh. The process-group wiring that a hand-built grid would need collapses to one init_device_mesh call, and frameworks such as DeepSpeed and Megatron-LM organize their 3D engines around the same idea of a grid of process groups.
Practical Example: The Run That Was Slow Until Someone Reordered the Axes

Who: A distributed-training engineer bringing up a 70-billion-parameter model on a new 64-device cluster.

Situation: The first launch used $(d, p, t) = (8, 1, 8)$: eight data replicas, no pipeline, tensor-parallel of eight. It trained correctly but at roughly half the expected throughput.

Problem: Profiling showed the accelerators idle for a large fraction of each step, waiting on the gradient all-reduce that the eight data replicas exchanged across the slow cross-rack fabric.

Dilemma: Buy a faster cross-rack fabric, an expensive and slow procurement, or change the parallelism shape so the heavy collective never crosses the slow tier in the first place.

Decision: They reshaped to $(1, 8, 8)$, trading the data axis for a pipeline, after a cost-model search like Code 16.9.1 flagged it as the cheapest feasible shape.

How: The change was a one-line edit to the device-mesh declaration of Code 16.9.2 plus a pipeline-schedule wrapper; no model code changed.

Result: Step time fell by close to forty percent and accelerator utilization rose sharply, because the only remaining cross-node traffic was the light pipeline activations, not the heavy gradient all-reduce.

Lesson: The same devices and the same model can differ almost twofold in throughput on the choice of grid shape alone. The shape is a tuning knob, and the cost model tells you which way to turn it.

5. Letting a Machine Choose the Shape Advanced

Code 16.9.1 brute-forced 28 candidates because a 3D space over 64 devices is small. Real stacks face a far larger space: five or more axes, dozens of model-shape choices, micro-batch counts, recomputation policies, and overlap schedules, multiplied across hundreds or thousands of devices. Enumerating it by hand stops being possible, and the search itself becomes the engineering problem. This is the domain of auto-parallelization, and it is an active frontier precisely because a good configuration can mean the difference between a run that finishes in a week and one that finishes in two.

Research Frontier: Auto-Parallelization and Frontier Training Stacks (2024 to 2026)

The configuration search Code 16.9.1 does by brute force is automated by systems in the lineage of Alpa (Zheng et al., 2022), which splits the space into inter-operator (pipeline) and intra-operator (tensor) parallelism and solves each with a cost-model-driven search rather than human tuning. The idea has since been folded into production stacks: Megatron-LM and DeepSpeed expose 3D and 4D configurations whose degrees are increasingly chosen by cost-model search rather than fixed by hand, and recent work pushes toward fully automated mesh selection that accounts for real measured bandwidths. A parallel 2024-to-2026 thread tackles the heterogeneous and elastic case, where the device pool changes mid-run (the elastic-training concern of Chapter 18) and the optimal shape must be recomputed online. The throughline is that frontier-model training has turned parallel-configuration choice from a craft into a search problem with an explicit objective: minimize time per step subject to the memory-feasibility and device-count constraints of Section 2.

With the composition rule, the device-count constraint, and the search in hand, the remaining question is practical: faced with a specific model and a specific cluster, how do you actually pick and tune the degrees without running every candidate to convergence? That is the methodology of the next section. Section 16.10 turns the cost-model intuition of Code 16.9.1 into a step-by-step strategy for choosing and tuning a parallelism configuration, the capstone that closes this chapter.

Exercise 16.9.1: Count the Legal Shapes Conceptual

For $D = 64$ devices, the configuration constraint is $d \times p \times t = 64$. (a) How many ordered triples $(d, p, t)$ of positive integers satisfy it? (b) Now add an expert axis so $d \times p \times t \times e = 64$; argue qualitatively whether the count of legal four-tuples is larger or smaller than the count of triples, and why. (c) Explain in one sentence why a device count of $D = 64$ admits a more usable search than $D = 66$, referring to the factor structure of each number and to the eight-device node.

Exercise 16.9.2: Move the Bottleneck Coding

Modify Code 16.9.1 so that BW_data (the cross-rack fabric) is raised from 25 to 200 gigabytes per second, modeling a cluster with a fast flat network. Re-run the search and report how the best $(d, p, t)$ shape and its step time change. Then explain, in terms of which collective rides which tier, why a faster data-axis fabric makes shapes with a data degree greater than one competitive again. What does this tell you about why network topology, not just device count, dictates the optimal parallelism shape?

Exercise 16.9.3: Read the Feasibility Frontier Analysis

The state term in memory_per_device of Code 16.9.1 is sharded over all of $d \times p \times t$, so with $D$ fixed it does not depend on the shape. For this exercise, first change that term to shard over $p \times t$ only (no optimizer-state sharding across data replicas). Then derive by hand the smallest product $p \times t$ (model-sharding degree) that keeps memory_per_device under the 80 GB cap for a model of $P = 200$ billion parameters at 16 bytes of state per parameter, ignoring activations. Compare your hand-derived threshold to what the modified program reports when you set P_params = 200e9 and inspect which shapes survive the feasibility filter. Explain why raising the parameter count shrinks the feasible region and pushes the optimum toward more aggressive model sharding, connecting your answer to the alpha-beta communication-cost reasoning of Chapter 3.