Part IV: Parallel Deep Learning and Large Models
Chapter 15: Data-Parallel Deep Learning

PyTorch Distributed Data Parallel

"They wrapped me in one line of code and told everyone I was easy. Behind the curtain I was registering hooks, packing buckets, and racing the backward pass on every step. I never complained, because that is the job."

A DistributedDataParallel Wrapper Keeping Up Appearances
Big Picture

Everything this chapter built by hand (the exact gradient identity, the synchronous step, the ring all-reduce, the overlap of communication with computation) ships in PyTorch as a single wrapper called DistributedDataParallel, so that a distributed training loop looks almost exactly like a single-machine one. You run one process per GPU, wrap your model once, give each process a different slice of the data, and launch with torchrun. From that moment, every backward pass silently buckets the gradients, fires a NCCL all-reduce per bucket, and overlaps that network traffic with the rest of backpropagation, so by the time the optimizer steps, every replica already holds the averaged gradient. This section is the chapter's payoff: it shows the anatomy of a real DDP job, explains why DDP replaced the older single-process DataParallel, and reproduces DDP's core contract in a few dozen lines of pure Python so the magic is no longer magic.

By this point in the chapter you can derive data-parallel training from first principles. The gradient of an average loss is itself an average, so splitting a batch across $K$ workers and averaging their gradients reproduces the single-machine gradient exactly, a fact established back in Section 1.1. The combine step is an all-reduce, the collective built in Chapter 4, and the previous section showed how to hide its cost by overlapping it with the backward pass. What remains is to stop writing all of that yourself. PyTorch DistributedDataParallel (DDP) is the tool that packages the entire chapter into production code, and using it well is mostly a matter of understanding what it does on your behalf so you can reason about the cases where it needs help.

[Figure 15.6.1 diagram: GPU 0 (rank 0) and GPU 1 (rank 1), each reading its own DistributedSampler data shard; forward with identical weights; backward pass finishing bucket 2, then bucket 1; per-bucket NCCL all-reduce overlapped with backward; averaged gradient on every replica; optimizer.step() runs locally, no extra sync.]
Figure 15.6.1: What DistributedDataParallel does on every step. Two processes, one GPU each, run forward and backward on different data shards. As soon as a bucket of gradients finishes in the backward pass, DDP launches a NCCL all-reduce for that bucket (dashed orange arrows) while the rest of the backward pass keeps computing, so communication overlaps computation. When backward finishes, every replica already holds the identical averaged gradient and the optimizer steps locally with no further synchronization.

1. The Anatomy of a DDP Job (Beginner)

A DDP training script is built from five moving parts, and once you can name them the rest is detail. First, each process joins a process group so the workers can find each other and run collectives. Second, you wrap your model in DistributedDataParallel, which broadcasts the initial weights from rank 0 and registers the backward hooks that will do the gradient averaging. Third, you give the data loader a DistributedSampler so each process reads a disjoint slice of the dataset rather than all of it. Fourth, you launch the whole thing with torchrun, which spawns one process per GPU and sets the environment variables (rank, world size, master address) that the process group reads. Fifth, you guard checkpoint saving so that only rank 0 writes the file, because all replicas hold identical weights and writing from all of them at once would be wasteful at best and corrupt the file at worst.

The code below is the skeleton of essentially every single-node, multi-GPU DDP job. It needs a real multi-GPU machine and the NCCL backend to run, so we show it illustratively here and reproduce its core behavior in pure Python later in the section; read it for its shape rather than to execute it.

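# Save as train_ddp.py, launch with:  torchrun --nproc_per_node=4 train_ddp.py
# 4. torchrun starts one process per GPU and sets RANK, LOCAL_RANK, WORLD_SIZE,
#    MASTER_ADDR, and MASTER_PORT, which init_process_group reads below.
import os, torch, torch.nn as nn, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Placeholder stand-ins (hypothetical) so the skeleton is self-contained;
# replace them with your real model, dataset, loss, and schedule.
num_epochs = 3
loss_fn    = nn.MSELoss()
_gen       = torch.Generator().manual_seed(0)      # identical data on every rank
train_set  = TensorDataset(torch.randn(4096, 32, generator=_gen),
                           torch.randn(4096, 1, generator=_gen))

def MyModel():
    return nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 1))
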
def main():
    dist.init_process_group(backend="nccl")          # 1. join the K-worker group
    rank       = dist.get_rank()
    local_rank = int(os.environ["LOCAL_RANK"])        # which GPU on this machine
    torch.cuda.set_device(local_rank)                 # pin this process to one GPU

    model = MyModel().cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])       # 2. wrap: broadcast + hooks

    sampler = DistributedSampler(train_set)           # 3. disjoint shard per rank
    loader  = DataLoader(train_set, batch_size=64, sampler=sampler)
    opt     = torch.optim.AdamW(model.parameters(), lr=3e-4)

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)                      # reshuffle shards each epoch
        for x, y in loader:
            x, y = x.cuda(local_rank), y.cuda(local_rank)
            opt.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()                           # hooks fire all-reduce here
            opt.step()                                # every rank already averaged

        if rank == 0:                                 # 5. save once, from rank 0
            torch.save(model.module.state_dict(), "ckpt.pt")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
Code 15.6.1: The five-part skeleton of a DDP job: init_process_group, the DDP wrap, the DistributedSampler, the ordinary training loop where loss.backward() triggers the all-reduce, and the rank-0 checkpoint guard. The training loop differs from a single-machine loop by only the sampler and the rank guard.

The striking thing about Code 15.6.1 is how little of it is distributed-systems code. There is no explicit all-reduce, no manual gradient gathering, no bucket management; loss.backward() looks exactly as it does on one machine. That is the entire point of DDP, and it is why DDP is the default tool for the data-parallel case. The synchronization that earlier sections spelled out by hand has been folded into the backward pass itself, which is exactly the overlap strategy the previous section motivated.
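Part four of the skeleton, the launcher, talks to each process only through environment variables, and the code reads two different ranks from them: the global rank (which shard to read, and whether to write the checkpoint) and the local rank (which GPU on this machine). The helper below is a hypothetical standard-library sketch, not part of PyTorch, that makes the distinction concrete. The variable names RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT are the ones torchrun sets; the fallback defaults are an assumption of this sketch so the same script can also run as a single plain process.

```python
import os

def ddp_env(environ=os.environ):
    """Hypothetical helper: collect the per-process settings torchrun exports.

    Falls back to single-process defaults (an assumption of this sketch)
    when the variables are absent, e.g. when the script runs under plain python.
    """
    rank = int(environ.get("RANK", 0))              # global rank across all nodes
    local_rank = int(environ.get("LOCAL_RANK", 0))  # GPU index on this node
    world_size = int(environ.get("WORLD_SIZE", 1))  # total number of processes K
    addr = environ.get("MASTER_ADDR", "127.0.0.1")  # rendezvous host for the group
    port = environ.get("MASTER_PORT", "29500")      # rendezvous port (default assumed)
    return {
        "rank": rank,
        "local_rank": local_rank,
        "world_size": world_size,
        "master": f"{addr}:{port}",
        "is_main": rank == 0,                       # only this rank writes checkpoints
    }

# Example: two nodes with four GPUs each; this process is GPU 1 on the second node.
second_node_gpu1 = {"RANK": "5", "LOCAL_RANK": "1", "WORLD_SIZE": "8",
                    "MASTER_ADDR": "10.0.0.1", "MASTER_PORT": "29500"}
print(ddp_env(second_node_gpu1))
```

In the example, the process is global rank 5 (four ranks on the first node plus local index 1), so it reads shard 5 but pins itself to local device 1; using the global rank as a device index would point at a GPU that does not exist on a four-GPU node.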

Key Insight: DDP Is a Backward Hook, Not a New Training Algorithm

DDP does not change your model, your loss, or your optimizer. It registers a hook on each parameter that fires the moment that parameter's gradient is ready during the backward pass. The hook drops the gradient into a bucket; once every gradient in a bucket is ready, DDP launches a NCCL all-reduce for it while the backward pass keeps computing the next layers' gradients. By the time backward() returns, every replica holds the same averaged gradient, so each process can call optimizer.step() independently and stay bit-for-bit in sync. The algorithm is still ordinary synchronous SGD; DDP is just the plumbing that makes the all-reduce nearly free by hiding most of it behind computation.
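The hook-and-bucket mechanism can be replayed without autograd. The sketch below is a simplified model, not DDP's actual reducer: it assigns parameters to buckets in reverse registration order (because backward produces the last layer's gradients first), caps each bucket at a hypothetical element count `bucket_cap`, and then logs the moment each bucket becomes complete, which is when DDP would launch that bucket's all-reduce. Real DDP sizes buckets in megabytes through its bucket_cap_mb argument and handles details this sketch ignores.

```python
def simulate_bucketed_backward(param_sizes, bucket_cap):
    """Simplified model of DDP-style bucketing (hypothetical, not DDP's reducer).

    param_sizes[i] is the element count of parameter i in registration order.
    Buckets are filled in reverse order, the order backward produces gradients.
    Returns the bucket assignment and an event log of the backward pass.
    """
    order = list(reversed(range(len(param_sizes))))   # last layer's grads come first
    buckets, current, filled = [], [], 0
    for p in order:
        if current and filled + param_sizes[p] > bucket_cap:
            buckets.append(current)                    # close the full bucket
            current, filled = [], 0
        current.append(p)
        filled += param_sizes[p]
    if current:
        buckets.append(current)

    bucket_of = {p: b for b, members in enumerate(buckets) for p in members}
    pending = [len(members) for members in buckets]   # grads still missing per bucket
    events = []
    for p in order:                                    # replay backward: grad p is ready
        events.append(f"grad {p} ready")
        b = bucket_of[p]
        pending[b] -= 1
        if pending[b] == 0:                            # bucket complete: launch all-reduce
            events.append(f"all-reduce bucket {b} {buckets[b]}")
    return buckets, events

buckets, events = simulate_bucketed_backward([4, 4, 2, 6], bucket_cap=8)
for e in events:
    print(e)
```

In the log, bucket 0's all-reduce launches before the gradients of parameters 1 and 0 exist, so that transfer overlaps the rest of backward; only the last bucket's all-reduce is left exposed when backward ends.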

2. Why DDP Replaced DataParallel (Beginner)

PyTorch also ships an older, easier-looking class called DataParallel (note: no "Distributed"). It runs in a single process and, on each step, scatters the input batch across the local GPUs, replicates the model to each, runs the replicas in parallel threads, gathers the outputs back to the primary GPU, and computes the loss there; during the backward pass, every replica's gradients are summed back onto the primary GPU's copy of the model. It needs no launcher and no process group, which makes it tempting, but it is the wrong tool for almost every real job. Because it lives in one Python process, its worker threads contend for the Global Interpreter Lock, so the single driving process becomes a bottleneck that the extra GPUs cannot relieve. It re-replicates the model to every GPU on every forward pass, which is pure overhead. It routes all outputs, the loss, and the gradient reduction through one primary GPU, creating a memory and bandwidth hotspot that scales badly. And it cannot cross machine boundaries at all.

DistributedDataParallel fixes each of these by construction. It runs one process per GPU, so no interpreter lock is shared and no single thread drives every GPU. It replicates the model exactly once, at construction, and thereafter exchanges only gradients (plus, by default, a broadcast of module buffers such as batch-norm running statistics from rank 0). It never funnels activations through a primary device; each process keeps its own forward and backward entirely local, and the only heavy cross-device traffic is the gradient all-reduce. And it works identically across one machine or a thousand, because the process group abstraction does not care whether ranks share a host. The practical guidance from the PyTorch team is blunt and worth internalizing: use DistributedDataParallel, not DataParallel, even on a single machine with multiple GPUs.

Table 15.6.1: The two PyTorch data-parallel wrappers compared. DistributedDataParallel wins on every axis that matters for real training, which is why it is the default.
Property | DataParallel (legacy) | DistributedDataParallel
Processes | One, with one thread per GPU | One per GPU
GIL bottleneck | Yes, threads share one interpreter | No, independent processes
Model replication | Every forward pass | Once, at construction
Cross-GPU traffic | Scatter inputs, gather outputs, reduce gradients onto the primary GPU | Gradient all-reduce (plus buffer broadcast by default)
Multi-machine | No | Yes, identical code
Comm / compute overlap | No | Yes, bucketed during backward

3. Correctness Details DDP Quietly Enforces (Intermediate)

For the averaged gradient to equal the single-machine gradient, several invariants must hold, and DDP enforces or assumes each of them. Every replica must start from identical weights, which DDP guarantees by broadcasting rank 0's parameters and buffers when you wrap the model; this is why you do not need to seed every process identically yourself. Each process owns exactly one GPU, the configuration DDP is tuned for. The gradient is averaged, not merely summed, so that the effective update matches a single batch of size $K$ times the per-process batch; DDP divides the all-reduced sum by the world size for you. With $K$ processes each using a local batch of $B$, the global batch is $KB$, and the synchronized gradient is

$$\bar{g} = \frac{1}{K} \sum_{k=1}^{K} g_k, \qquad g_k = \frac{1}{B}\sum_{i \in \mathcal{B}_k} \nabla \ell(w; x_i, y_i),$$

which is exactly the mean gradient over the global batch when the shards are equal in size. Equal shard sizes matter: by default, DistributedSampler pads its index list by repeating a few samples so that every rank receives the same number of samples, and therefore runs the same number of batches, because a rank that runs out of data early would leave the others waiting forever at the next all-reduce. The repeated samples perturb the global average very slightly, a negligible price for keeping every rank in lockstep. The more consequential place this exactness leaks is batch normalization, whose statistics are computed per process over only the local batch of $B$ examples, not the global $KB$. When that local batch is small, swapping the layers for SyncBatchNorm (which all-reduces the mean and variance across replicas) restores the single-machine semantics at the cost of extra collectives in the forward pass.
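Both claims in the paragraph above can be checked numerically. The sketch below uses synthetic per-example gradients (NumPy, illustrative sizes) to show that averaging per-shard mean gradients recovers the global mean only when shards are equal. It then reimplements, in simplified form, the padding idea behind DistributedSampler: extend the index list by wrapping around to its start until the length divides evenly by the number of replicas, then deal indices to ranks round-robin. The helper names are hypothetical, and the real sampler also shuffles with a per-epoch seed when shuffle=True, which this sketch leaves out.

```python
import math
import numpy as np

rng = np.random.default_rng(0)
per_example = rng.standard_normal((12, 3))   # 12 synthetic per-example gradients, 3 params
global_mean = per_example.mean(axis=0)       # the single-machine gradient

def mean_of_shard_means(grads, shards):
    """What DDP computes: each rank's mean gradient, summed and divided by K."""
    return np.mean([grads[s].mean(axis=0) for s in shards], axis=0)

equal = np.array_split(np.arange(12), 3)                        # K = 3 shards of 4
unequal = [np.arange(0, 2), np.arange(2, 7), np.arange(7, 12)]  # sizes 2, 5, 5
print("equal shards:  ", np.allclose(mean_of_shard_means(per_example, equal), global_mean))
print("unequal shards:", np.allclose(mean_of_shard_means(per_example, unequal), global_mean))

def padded_shard_indices(n, num_replicas, rank):
    """Simplified, unshuffled sketch of equal-length sharding (hypothetical helper).

    Wraps around to the start of the index list until its length is a multiple
    of num_replicas, then deals indices to ranks round-robin.
    """
    total = math.ceil(n / num_replicas) * num_replicas
    indices = list(range(n))
    while len(indices) < total:              # repeat early samples as padding
        indices += indices[: total - len(indices)]
    return indices[rank:total:num_replicas]

for r in range(4):                           # 10 samples dealt over 4 ranks
    print(f"rank {r}: {padded_shard_indices(10, 4, r)}")
```

With 10 samples and 4 ranks, ranks 2 and 3 each receive one repeated sample (indices 0 and 1), so all four take three samples per epoch and execute the same sequence of all-reduces.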

Fun Note: The Deadlock That Looks Like a Hang

The most common DDP support ticket is not a crash; it is a job that simply stops. One rank takes a branch that skips a layer, or hits a shorter data shard, and never produces a gradient for some parameter. The other ranks reach that parameter's all-reduce and wait. Forever. There is no error, no stack trace, just every GPU pinned at 100 percent utilization doing nothing, because a collective is a barrier and a barrier with one absent participant never returns. The cure is to make every rank execute the same set of all-reduces every step, which is exactly why DistributedSampler pads shards to equal length and why find_unused_parameters=True exists for genuinely dynamic graphs.
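The "barrier with one absent participant" behavior can be reproduced with Python's threading.Barrier, treating each thread as a rank and each all-reduce as a barrier. This is an analogy, not NCCL: the short timeout below stands in for a wait that, in a real job, lasts until someone kills the job or a process-group timeout fires. The function name `run_collective` is hypothetical.

```python
import threading

def run_collective(world_size, arriving_ranks, timeout=0.5):
    """Model one all-reduce as a barrier for world_size ranks.

    Only the ranks in arriving_ranks reach the collective; the rest took a
    branch that skipped it. Returns {rank: outcome} for the ranks that arrived.
    """
    barrier = threading.Barrier(world_size, timeout=timeout)
    outcome = {}

    def rank_body(r):
        try:
            barrier.wait()                      # blocks until all world_size arrive
            outcome[r] = "done"
        except threading.BrokenBarrierError:    # raised once the wait times out
            outcome[r] = "stuck (timed out)"

    threads = [threading.Thread(target=rank_body, args=(r,)) for r in arriving_ranks]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return outcome

print(run_collective(4, [0, 1, 2, 3]))   # every rank arrives: the collective completes
print(run_collective(4, [0, 1, 2]))      # rank 3 skipped it: no rank gets through
```

No rank reports an error of its own making; the three present ranks simply cannot proceed, which is why the real symptom is a silent hang rather than a crash at the offending line.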

4. Gradient Accumulation and no_sync (Intermediate)

Sometimes the global batch you want is larger than the GPUs can hold even after splitting it $K$ ways. The standard remedy is gradient accumulation: run several forward and backward passes, summing gradients in the .grad buffers, and step the optimizer only once every few micro-batches. Under DDP this creates a subtlety. By default DDP all-reduces on every backward pass, so naive accumulation would synchronize gradients on every micro-batch, paying the communication cost many times for a single optimizer step. That is wasteful: you only need the averaged gradient on the step where you actually call optimizer.step(). DDP exposes the no_sync() context manager precisely for this. Inside it, the backward hooks still accumulate gradients locally but skip the all-reduce; you run all but the last micro-batch inside no_sync(), then run the final micro-batch outside it so a single all-reduce synchronizes the accumulated total.

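import contextlib

# Accumulate over `accum` micro-batches, all-reduce only on the last one.
# Assumes model (DDP-wrapped), loader, loss_fn, opt, local_rank, and accum
# are defined as in Code 15.6.1.
opt.zero_grad()
for i, (x, y) in enumerate(loader):
    x, y = x.cuda(local_rank), y.cuda(local_rank)
    is_last = (i + 1) % accum == 0
    ctx = contextlib.nullcontext() if is_last else model.no_sync()
    with ctx:                                     # skip all-reduce until the end
        loss = loss_fn(model(x), y) / accum       # scale so the result is a mean
        loss.backward()                           # local grad accumulation
    if is_last:
        opt.step()                                # the one all-reduce already fired
        opt.zero_grad()                           # start the next window clean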
Code 15.6.2: Gradient accumulation under DDP. Wrapping the non-final micro-batches in no_sync() suppresses their all-reduce, so the full accumulation costs one collective instead of accum of them. The loss is divided by accum so the accumulated gradient is a mean, not a sum.

This is the deep link back to communication-efficient optimization. Skipping the all-reduce on intermediate micro-batches is the same idea as the local-update schemes of Section 10.5: do more local computation between synchronizations so the network cost is amortized over more work. Gradient accumulation is the simplest, exact instance of that trade, because the local steps here are merely accumulating into one gradient rather than taking independent optimizer steps, so the result is identical to a single large batch.
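Dividing each micro-batch loss by accum is what makes the accumulated .grad buffer a mean rather than a sum. The NumPy sketch below (synthetic data, illustrative sizes, and the same mean-squared-error gradient used in Code 15.6.3) checks this for a single replica: four micro-batches of eight examples, each contributing its gradient divided by accum, add up to the gradient of one 32-example batch, while the unscaled sum is accum times too large. Under DDP with no_sync(), only the final micro-batch in this window would trigger an all-reduce.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((32, 5))      # 32 synthetic examples, 5 parameters
y = rng.standard_normal(32)
w = rng.standard_normal(5)

def mse_grad(w, Xb, yb):
    """Gradient of the mean-squared error over the examples in (Xb, yb)."""
    return (2.0 / len(yb)) * (Xb.T @ (Xb @ w - yb))

accum = 4
micro_batches = np.array_split(np.arange(len(y)), accum)   # 4 micro-batches of 8

grad_buffer = np.zeros_like(w)        # plays the role of the .grad buffer
for idx in micro_batches:
    # backward of (loss / accum) adds grad / accum, since the gradient is linear
    grad_buffer += mse_grad(w, X[idx], y[idx]) / accum

unscaled = sum(mse_grad(w, X[idx], y[idx]) for idx in micro_batches)
big_batch = mse_grad(w, X, y)         # one pass over all 32 examples

print("scaled accumulation matches big batch:", np.allclose(grad_buffer, big_batch))
print("unscaled sum equals accum * big batch:", np.allclose(unscaled, accum * big_batch))
```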

5. Reproducing DDP's Contract in Pure Python (Intermediate)

The clearest way to demystify DDP is to build its core contract without any GPUs or NCCL. The demonstration below defines a tiny MiniDDP wrapper around a plain parameter vector. It holds $K$ replicas initialized identically (DDP's broadcast from rank 0), and on every step it computes each replica's gradient on its own data shard, then averages the gradients across replicas by walking them in fixed-size buckets and all-reducing one bucket at a time, mirroring the per-bucket all-reduce that DDP's backward hooks launch (the overlap with backward is left out, since there is no real backward pass to hide behind). After training, it checks two things that DDP promises: that all replicas remain bit-for-bit identical, and that the synchronized trajectory matches an ordinary single-process run over the full batch.

import numpy as np

rng = np.random.default_rng(7)
N, d, K, STEPS, LR = 4096, 16, 4, 30, 0.05

X = rng.standard_normal((N, d))                  # one shared regression problem
w_true = rng.standard_normal(d)
y = X @ w_true + 0.1 * rng.standard_normal(N)
w_init = np.zeros(d)

def full_grad(w, Xb, yb):                         # mean-squared-error gradient
    n = len(yb)
    return (2.0 / n) * (Xb.T @ (Xb @ w - yb))

# Reference: ordinary single-process training over the whole batch.
w_ref = w_init.copy()
for _ in range(STEPS):
    w_ref -= LR * full_grad(w_ref, X, y)

class MiniDDP:
    """On each step, bucket the per-replica gradients and all-reduce (sum then
    divide by world size) so every replica applies the identical averaged grad,
    exactly as torch DDP does inside its backward hook."""
    def __init__(self, w0, world_size, bucket_size):
        self.world_size, self.bucket_size = world_size, bucket_size
        self.replicas = [w0.copy() for _ in range(world_size)]   # identical init

    @staticmethod
    def _all_reduce_buckets(grads, bucket_size):
        P = grads[0].size
        out = [np.empty_like(g) for g in grads]
        for start in range(0, P, bucket_size):                   # one all-reduce
            stop = min(start + bucket_size, P)                   # per bucket
            stacked = np.stack([g[start:stop] for g in grads])
            averaged = stacked.sum(axis=0) / len(grads)          # SUM / world_size
            for r in range(len(grads)):
                out[r][start:stop] = averaged
        return out

    def step(self, shards, lr):
        local = [full_grad(self.replicas[r], X[s], y[s])         # local grad only
                 for r, s in enumerate(shards)]
        synced = self._all_reduce_buckets(local, self.bucket_size)
        for r in range(self.world_size):
            self.replicas[r] -= lr * synced[r]

shards = np.array_split(np.arange(N), K)                         # disjoint shards
ddp = MiniDDP(w_init, world_size=K, bucket_size=5)               # 16 params, bk 5
for _ in range(STEPS):
    ddp.step(shards, LR)

spread = max(np.max(np.abs(ddp.replicas[r] - ddp.replicas[0])) for r in range(K))
err    = np.max(np.abs(ddp.replicas[0] - w_ref))
print("world size K              :", K)
print("steps                     :", STEPS)
print("max spread across replicas:", f"{spread:.2e}")
print("max |DDP - single-process|:", f"{err:.2e}")
print("final loss (single)       :", f"{np.mean((X @ w_ref - y) ** 2):.6f}")
print("final loss (mini-DDP)     :", f"{np.mean((X @ ddp.replicas[0] - y) ** 2):.6f}")
Code 15.6.3: A miniature DDP in pure Python. The _all_reduce_buckets method is the heart of it: flatten each replica's gradient, walk it in fixed-size buckets, and average each bucket across replicas, mirroring the bucketed NCCL all-reduce that real DDP fires from its backward hooks.
world size K              : 4
steps                     : 30
max spread across replicas: 0.00e+00
max |DDP - single-process|: 2.22e-16
final loss (single)       : 0.045429
final loss (mini-DDP)     : 0.045429
Output 15.6.3: The four replicas stay exactly identical (spread of zero) and the synchronized trajectory matches the single-process reference to machine epsilon ($2.2 \times 10^{-16}$), with indistinguishable final losses. The bucketed all-reduce reproduces single-machine training exactly, which is DDP's whole promise.

The two reported quantities are the two guarantees DDP makes. A spread of exactly zero across replicas means the workers never drift apart, because they apply the same averaged gradient every step from the same starting point. A difference of about $2.2 \times 10^{-16}$ from the single-process run, one unit of double-precision machine epsilon, is pure rounding: summing four shard gradients and dividing by $K$ adds the same terms in a different order than one full-batch mean, so the bucketing and averaging leave the answer unchanged up to floating-point error. Real DDP adds GPUs, NCCL, and overlap with the backward pass, but the contract it enforces is precisely the one this few-dozen-line wrapper demonstrates: many replicas, one trajectory, identical to training on one machine up to rounding.

Library Shortcut: DDP Is the One Idiom That Replaces the Whole Chapter

The miniature wrapper in Code 15.6.3 is only a few dozen lines, and a faithful real version would need process-group setup, NCCL bindings, bucket scheduling, overlap with the backward pass, gradient division, and unused-parameter handling, easily several hundred lines of careful systems code. PyTorch collapses all of it to a single wrap:

model = DDP(model, device_ids=[local_rank])   # the entire chapter, in one line
# ...then train normally; loss.backward() does the bucketed all-reduce for you.
Code 15.6.4: The DDP idiom. One constructor call replaces the manual replication, bucketing, and all-reduce of Code 15.6.3, and the library internally handles process-group transport, bucket scheduling, and the overlap of communication with the backward pass.

That single line is why loss.backward() in Code 15.6.1 needs no modification: DDP has already wired the gradient averaging into the autograd graph. You write almost-normal training code and get synchronous data parallelism that is exact up to floating-point rounding, with most of its communication hidden behind the backward pass.

6. When the Wrapper Is Not Enough (Advanced)

DDP solves the data-parallel case completely, and the data-parallel case is the one where the whole model fits on one GPU and only throughput binds. The moment the model itself no longer fits, DDP has nothing to offer, because it replicates the full model to every process. That is the boundary where this chapter ends and the next begins: sharded data parallelism and model parallelism, which split the parameters themselves across devices rather than replicating them. The collectives change too, from a single all-reduce to the reduce-scatter and all-gather pair that sharded methods rely on, a transition foreshadowed when this book first introduced collective communication. Within its lane, though, DDP is close to optimal, and the right instinct is to reach for it first and only graduate to heavier machinery when a memory ceiling, not a throughput ceiling, forces the move.
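A back-of-the-envelope calculation shows where that memory ceiling sits. The sketch below assumes plain fp32 training with Adam, so each parameter costs 4 bytes of weight, 4 of gradient, and 8 for Adam's two moment estimates, 16 bytes in total; it ignores activations, buffers, and framework overhead, and the 7-billion-parameter model on 8 GPUs is purely illustrative. Under DDP every process holds all of this state; a fully sharded scheme divides it by the number of processes.

```python
def per_gpu_state_gb(num_params, world_size, sharded, bytes_per_param=16):
    """Rough per-GPU memory (decimal GB) for weights, gradients, and Adam state.

    Assumes fp32 everywhere: 4 B weight + 4 B gradient + 8 B Adam moments = 16 B.
    Activations, buffers, and allocator overhead are deliberately ignored.
    """
    total_bytes = num_params * bytes_per_param
    per_gpu_bytes = total_bytes / world_size if sharded else total_bytes
    return per_gpu_bytes / 1e9

params, gpus = 7_000_000_000, 8
print(f"DDP (replicated): {per_gpu_state_gb(params, gpus, sharded=False):.0f} GB per GPU")
print(f"fully sharded   : {per_gpu_state_gb(params, gpus, sharded=True):.0f} GB per GPU")
```

Under these assumptions the replicated figure (112 GB) already exceeds a typical 80 GB accelerator before any activation is stored, and adding more GPUs does not lower it; only splitting the state, as in the sharded figure (14 GB), does.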

Practical Example: The DataParallel Job That Would Not Scale

Who: A computer-vision team at a robotics startup training an image segmentation model on an eight-GPU server.

Situation: They had wrapped their model in nn.DataParallel a year earlier because it was a one-line change and it worked on their four-GPU prototype box.

Problem: Moving to eight GPUs barely improved throughput; GPU 0 sat near full memory while the others idled, and one CPU core was pinned at 100 percent.

Dilemma: Buy a larger single GPU to fit the bottlenecked primary device, a scale-up move that did nothing about the idle GPUs, or rewrite the training loop for DistributedDataParallel, which meant a launcher, a process group, and a sampler change.

Decision: They migrated to DDP, because the symptoms (one busy CPU core, a memory hotspot on the primary GPU, poor scaling) were the textbook signature of the GIL bottleneck and output-gathering that DDP removes by design.

How: They replaced the DataParallel wrap with a DistributedDataParallel wrap, added a DistributedSampler, guarded the checkpoint save with a rank check, and launched with torchrun --nproc_per_node=8, about thirty lines following Code 15.6.1.

Result: Throughput scaled close to linearly with the eight GPUs, the memory hotspot vanished because each process kept its activations local, and per-epoch time dropped by roughly a factor of six over the old DataParallel run.

Lesson: DataParallel's one-line convenience hides a single-process bottleneck that gets worse with more GPUs. For anything beyond a quick experiment, start with DDP; the extra setup is small and the scaling is real.

Research Frontier: Compiled and Communication-Optimized DDP (2024 to 2026)

DDP's design is stable, but the machinery around it keeps getting faster. PyTorch 2.x integrates DDP with torch.compile, so the framework can reorder and fuse the all-reduce buckets against the compiled backward graph rather than scheduling them with runtime heuristics, tightening the overlap that Figure 15.6.1 depicts. A second thread pushes gradient compression and low-precision collectives into the DDP communication hook API, letting bf16 or quantized all-reduces cut the bytes on the wire with little accuracy loss, in the lineage of the communication-avoiding methods this book tracks. A third thread is the gradual handoff from DDP to FullyShardedDataParallel as the default for large models, where parameters, gradients, and optimizer state are sharded rather than replicated; the 2024 to 2026 releases blur the line by letting FSDP behave like DDP for models that still fit, so the same code path scales from one-GPU-fits to does-not-fit. The constant across all three is the principle from Chapter 4: the all-reduce is the cost to engineer down, and the wrapper is where that engineering hides.
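To see why compressed collectives are attractive, the NumPy sketch below emulates one plausible low-precision all-reduce on synthetic gradients: each rank divides its gradient by the world size, casts it to a 16-bit float, the compressed payloads are summed, and the result is cast back to fp32. NumPy has no bfloat16, so IEEE float16 stands in for it here, and the scheme is an illustration rather than a reproduction of any particular PyTorch communication hook. The payload halves while the averaged gradient stays close to the exact one.

```python
import numpy as np

rng = np.random.default_rng(3)
K, P = 4, 10_000                                     # ranks, parameters (illustrative)
grads = (rng.standard_normal((K, P)) * 1e-2).astype(np.float32)   # synthetic grads

exact = grads.mean(axis=0)                           # full-precision all-reduce average

# Compress: pre-divide by K (keeps the sum in range) and cast to fp16, then
# "all-reduce" by summing the compressed payloads and decompress to fp32.
payload = (grads / K).astype(np.float16)
approx = payload.sum(axis=0).astype(np.float32)

bytes_fp32 = grads[0].nbytes                         # bytes one rank sends in fp32
bytes_fp16 = payload[0].nbytes                       # bytes it sends when compressed
rel_err = np.linalg.norm(approx - exact) / np.linalg.norm(exact)

print("bytes per rank, fp32 vs fp16:", bytes_fp32, bytes_fp16)
print(f"relative error of averaged gradient: {rel_err:.1e}")
```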

You now have the practical tool that packages this entire chapter. The exact gradient identity, the synchronous step, the ring all-reduce, and the overlap with backpropagation all live inside a single wrapper that lets you write almost-normal training code. The next section turns to Horovod, an alternative framework that brought the same all-reduce-based data parallelism to TensorFlow, Keras, and PyTorch alike, and that pioneered some of the overlap techniques DDP now uses, so that you can recognize the shared idea wearing a different library's clothes.

Exercise 15.6.1: Trace the Hang (Conceptual)

A colleague reports that their four-GPU DDP job runs the first few hundred steps fine, then hangs with all GPUs at 100 percent utilization and no error. Their model has a branch (if x.mean() > 0: out = self.extra_layer(out)) that depends on the input. Using the deadlock mechanism described in Section 3, explain exactly why the job hangs, why it can run for a while before doing so, and name two distinct fixes (one at the model level, one at the DDP-construction level). Why is a missing all-reduce a silent hang rather than a crash?

Exercise 15.6.2: Add no_sync to MiniDDP (Coding)

Extend the MiniDDP class in Code 15.6.3 with a gradient-accumulation mode that mirrors no_sync(). Give each replica an internal gradient buffer; on accumulation steps, add the local gradient to the buffer and skip _all_reduce_buckets; on the final step of each accumulation window, add the local gradient, all-reduce the accumulated buffer once, apply the update, and clear the buffer. Verify that accumulating over accum micro-shards of size $B$ produces the same final weights as one all-reduced step over a shard of size accum * B, and count how many all-reduce calls each version makes. Confirm the bucketed all-reduce count drops by the factor you expect.

Exercise 15.6.3: Bucket Size Versus Overlap (Analysis)

Real DDP exposes a bucket_cap_mb knob that controls how large each gradient bucket grows before its all-reduce fires. Reason about the trade-off using the $\alpha + \beta n$ communication-cost model from Chapter 4, where $\alpha$ is per-message latency and $\beta$ is the per-byte cost. Argue why very small buckets pay the latency term $\alpha$ too many times, why very large buckets leave little of the backward pass to overlap with (so the all-reduce is exposed at the end), and sketch what an ideal bucket size depends on (model depth, layer sizes, network latency, backward-pass duration). Why is there no single best bucket size across all models?