Part IV: Parallel Deep Learning and Large Models
Chapter 18: Elastic and Fault-Tolerant Distributed Training

Distributed Checkpointing

"I crashed at hour forty. They restarted me at hour thirty-nine and fifty-five minutes, and nobody mentioned it again. That is the kindest thing a cluster has ever done for me."

A Worker Restored From Its Last Shard
Big Picture

A checkpoint is the saved state that lets a failed training run resume from minutes ago instead of from the beginning. At thousand-GPU scale the difficulty is not whether to checkpoint but how to write terabytes of state, scattered across thousands of devices, without stalling the very training you are trying to protect. The previous section established that failures are routine, not exceptional, and gave us the mean time between failures (MTBF) that governs how often something breaks. This section turns that number into an engineering decision: what state must be saved to make a restart exact, how to write it in parallel so the cost stays small, and how often to write it so you pay neither too much for checkpoints nor too much for lost work. The answer to "how often" is a single formula, the Young/Daly optimal interval, and we will compute it from the MTBF of Section 18.1.

When a single-machine training script crashes, you lose whatever progress sat in memory and you start over. For a script that finishes in minutes, that is an annoyance. For a foundation-model run that occupies thousands of accelerators for weeks (Chapter 19), starting over is a catastrophe measured in hundreds of thousands of dollars of wasted compute. The defense is the checkpoint: periodically, the run writes enough of its state to durable storage that, after any failure, it can reload that state and continue as if the crash had cost only the work done since the last save. Section 18.1 showed that at scale a crash is essentially guaranteed within hours; this section makes the recovery from it cheap.

The naive version is one line of code and a disaster in practice. Calling a blocking save every few minutes works on one GPU and collapses on a thousand, because the state is now sharded across every device, the volume is enormous, and a synchronous write freezes all workers while the bytes crawl to storage. Getting checkpointing right at scale means answering three questions precisely: what to save, how to write it in parallel, and how often. We take them in that order.

1. What Must Be in a Checkpoint (Beginner)

A correct checkpoint is one from which the resumed run produces the same trajectory it would have produced without the crash. That bar is higher than "save the weights": apart from the weights themselves, a missing component does not crash the resumed run but silently corrupts it, which is far worse. Five pieces of state are required, and Table 18.2.1 names each one and the consequence of forgetting it.

Table 18.2.1: The five components of a correct training checkpoint. Without the parameters there is nothing to resume; omitting any of the other four does not crash the restart but silently diverges the resumed run from the run that would have continued uninterrupted.
| Component | What it holds | If you forget it |
|---|---|---|
| Model parameters | Every weight tensor, in its training precision | The most obvious loss; resume is impossible without it |
| Optimizer state | Momentum and variance buffers (Adam holds two per parameter) | The optimizer cold-starts; loss spikes and convergence stalls |
| Learning-rate schedule | The step counter that positions you on the warmup/decay curve | The learning rate jumps backward, destabilizing training |
| RNG state | The pseudo-random generator state on every rank | Dropout masks and data shuffles diverge; runs are not reproducible |
| Data position | Which examples have been consumed (sampler offset, epoch) | Already-seen data is replayed or fresh data is skipped |

Parameters and optimizer state dominate the byte count: Adam keeps two buffers (momentum and variance) for every parameter, so a model with $P$ parameters checkpoints roughly $3P$ values, and more if mixed-precision training also keeps a full-precision master copy of the weights. Gradients are recomputed at every step and do not need to be saved. The schedule, RNG, and data position are tiny by comparison but are exactly the pieces a hurried implementation drops, which is why a run can "resume" and quietly never reach its target loss. The determinism this state protects is the subject of Section 18.3; here we only insist that all five be saved together as one atomic unit.
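The byte count can be estimated with a few lines of arithmetic. The sketch below is an assumption-laden estimate, not a measurement: the function name `checkpoint_bytes`, the 7-billion-parameter model, and the two precision layouts are all illustrative, and real checkpoints add small amounts of metadata on top.

```python
def checkpoint_bytes(num_params, param_bytes=2, optim_buffers=2,
                     optim_bytes=4, master_copy_bytes=0):
    """Rough checkpoint size for an Adam-style optimizer (assumed layout).

    param_bytes:       bytes per saved weight (2 for 16-bit)
    optim_buffers:     buffers per parameter (Adam: momentum and variance)
    optim_bytes:       bytes per optimizer value
    master_copy_bytes: bytes per parameter for an optional full-precision copy
    """
    per_param = param_bytes + optim_buffers * optim_bytes + master_copy_bytes
    return num_params * per_param

P = 7_000_000_000  # hypothetical 7-billion-parameter model

# Everything stored in 16 bits: weights plus two Adam buffers = 3P values.
all_16bit = checkpoint_bytes(P, param_bytes=2, optim_bytes=2)

# Mixed precision: 16-bit weights, 32-bit Adam buffers, 32-bit master weights.
mixed = checkpoint_bytes(P, param_bytes=2, optim_bytes=4, master_copy_bytes=4)

print(f"all 16-bit:      {all_16bit / 1e9:.0f} GB")
print(f"mixed precision: {mixed / 1e9:.0f} GB")
```

Under these assumptions the optimizer buffers, not the weights, account for most of the bytes, which is why dropping them to "save space" is tempting and why doing so cold-starts the optimizer.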

Key Insight: A Checkpoint Is a Resumption Contract, Not a Backup

The test of a checkpoint is not "did the file write" but "does loading it reproduce the run that never crashed." That demands parameters, optimizer state, the schedule step, every rank's RNG state, and the data cursor, saved atomically. The expensive components (parameters, optimizer state) are the ones everyone remembers; the cheap components (RNG, data position) are the ones whose absence corrupts the resumed run silently. Treat the checkpoint as the complete state needed to continue, and verify it by actually resuming, not by inspecting the file.
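The resume-and-compare test can be illustrated on a toy problem. The sketch below uses only the Python standard library: a single "weight" takes random steps, standing in for a training loop, and the checkpoint holds the weight, the step counter, and the RNG state. The function names `train`, `save_atomic`, and `load` are hypothetical, and the write-to-temp-then-rename pattern is one common way to make a single-file save atomic, assumed here rather than taken from any particular framework.

```python
import os
import pickle
import random
import tempfile

def train(steps, state=None):
    """Toy training loop: the weight takes one random step per iteration."""
    if state is None:
        state = {"weight": 0.0, "step": 0, "rng": random.Random(0).getstate()}
    rng = random.Random()
    rng.setstate(state["rng"])            # restore the generator exactly
    weight, step = state["weight"], state["step"]
    while step < steps:
        weight += rng.uniform(-1.0, 1.0)
        step += 1
    return {"weight": weight, "step": step, "rng": rng.getstate()}

def save_atomic(state, path):
    """Write to a temp file, then rename, so a reader never sees a partial file."""
    fd, tmp = tempfile.mkstemp(dir=os.path.dirname(path))
    with os.fdopen(fd, "wb") as f:
        pickle.dump(state, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp, path)

def load(path):
    with open(path, "rb") as f:
        return pickle.load(f)

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "ckpt.pkl")
    uninterrupted = train(100)
    save_atomic(train(40), path)                 # "crash" after step 40
    resumed = train(100, load(path))             # resume from the full checkpoint
    # Same checkpoint, but with the RNG state "forgotten" and re-seeded.
    broken = train(100, dict(load(path), rng=random.Random(123).getstate()))

print("full checkpoint matches:", resumed["weight"] == uninterrupted["weight"])
print("missing RNG matches:    ", broken["weight"] == uninterrupted["weight"])
```

The first comparison should report a match and the second should not, even though both resumes load without error and both reach step 100. That is the silent failure mode in miniature: only the comparison against the uninterrupted run reveals it.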

2. The Scale Problem: Sharded State, Synchronized Stalls (Intermediate)

On one device the checkpoint is one file and the only cost is the seconds it takes to write. Under sharded training with ZeRO or FSDP (Section 16.5), no single device holds the whole model: parameters, gradients, and optimizer state are partitioned so that rank $k$ owns only its slice. This is what makes trillion-parameter models trainable, and it is exactly what makes checkpointing hard. The state that must be saved is now spread across thousands of devices, the aggregate is terabytes, and the obvious implementation (gather everything to rank zero and let it write one file) funnels all that data through a single node and a single network path while every other worker sits idle.

The fix mirrors the structure of the training itself: if the state is sharded, save it sharded. Each rank writes its own slice, in parallel, directly to a distributed filesystem or object store (the storage layer of Section 8.2). A thousand ranks writing a thousand shards at once turns a serial bottleneck into an aggregate-bandwidth problem the cluster is built to handle. The catch is that even a parallel write is a synchronous write: while the bytes flush to storage, the accelerators that should be computing the next step are stalled at a barrier, burning the most expensive compute in the building on doing nothing. Figure 18.2.1 contrasts the synchronous sharded write with the asynchronous scheme that removes the stall.

[Figure 18.2.1. Top panel: "Sharded synchronous write: parallel, but stalls compute". Bottom panel: "Asynchronous: snapshot fast, flush in the background".]
Figure 18.2.1: Two ways to checkpoint sharded state. Top: every rank writes its own shard in parallel to the distributed store (Section 8.2), but all ranks stall at a barrier until the slowest flush finishes. Bottom: asynchronous checkpointing copies each rank's state to host memory in a fast device-to-host snapshot, training resumes immediately, and a background thread flushes host memory to durable storage off the critical path. The stall shrinks from the whole write to just the snapshot.
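The synchronous sharded write in the top half of the figure can be mimicked on one machine. In this sketch, threads stand in for ranks and a temporary directory stands in for the distributed store; both are simplifying assumptions, as are the names `write_shard`, `save_sharded`, and `load_sharded` and the manifest layout. The manifest is written last so that a directory without one can be treated as an incomplete checkpoint.

```python
import json
import os
import pickle
import tempfile
from concurrent.futures import ThreadPoolExecutor

def write_shard(ckpt_dir, rank, shard):
    """Each 'rank' writes only its own slice, to its own file."""
    name = f"shard_{rank:05d}.pkl"
    with open(os.path.join(ckpt_dir, name), "wb") as f:
        pickle.dump(shard, f)
    return name

def save_sharded(ckpt_dir, shards):
    os.makedirs(ckpt_dir, exist_ok=True)
    # Leaving the 'with' block waits for every writer: this is the barrier
    # at which all ranks stall in a synchronous sharded checkpoint.
    with ThreadPoolExecutor(max_workers=len(shards)) as pool:
        names = list(pool.map(lambda item: write_shard(ckpt_dir, *item),
                              enumerate(shards)))
    # Manifest goes last, via rename, so it appears only once all shards exist.
    tmp = os.path.join(ckpt_dir, "manifest.json.tmp")
    with open(tmp, "w") as f:
        json.dump({"world_size": len(shards), "shards": names}, f)
    os.replace(tmp, os.path.join(ckpt_dir, "manifest.json"))

def load_sharded(ckpt_dir):
    with open(os.path.join(ckpt_dir, "manifest.json")) as f:
        manifest = json.load(f)
    shards = []
    for name in manifest["shards"]:
        with open(os.path.join(ckpt_dir, name), "rb") as f:
            shards.append(pickle.load(f))
    return shards

world_size = 4
params = list(range(16))
shards = [params[r::world_size] for r in range(world_size)]  # rank r's slice

with tempfile.TemporaryDirectory() as d:
    ckpt = os.path.join(d, "step_100")
    save_sharded(ckpt, shards)
    restored = load_sharded(ckpt)
    files = sorted(os.listdir(ckpt))

print(files)
print(restored == shards)
```

No rank ever holds or writes another rank's data, so the write cost scales with the shard size rather than the model size. The remaining stall is the wait for the slowest writer, which is what the asynchronous scheme removes.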

3. Asynchronous and In-Memory Checkpointing (Intermediate)

The asynchronous scheme in the lower half of Figure 18.2.1 rests on one observation: copying a tensor from accelerator memory to host RAM is fast (a device-to-host transfer over a wide local bus), while flushing that same data to a distributed filesystem is slow (a network write to durable storage). Synchronous checkpointing pays both costs on the critical path. Asynchronous checkpointing pays only the first: each rank snapshots its shard to a pinned host-memory buffer, training resumes the instant that copy completes, and a background thread drains the host buffer to durable storage while the next steps run. The training stall collapses from "time to write to disk" down to "time to copy to host RAM," often a tenfold reduction or more.
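The snapshot-then-flush pattern is small enough to sketch with the standard library. Here `copy.deepcopy` stands in for the device-to-host copy and a local file stands in for durable storage; both are assumptions made so the example runs anywhere, and `async_checkpoint` is a hypothetical name. The point it demonstrates is the ordering: the snapshot is taken before the function returns, so training can mutate the live state while the flush is still in flight.

```python
import copy
import os
import pickle
import tempfile
import threading

def async_checkpoint(state, path):
    """Snapshot on the critical path; flush to storage on a background thread."""
    snapshot = copy.deepcopy(state)       # the only work training waits for

    def flush():
        tmp = path + ".tmp"
        with open(tmp, "wb") as f:
            pickle.dump(snapshot, f)      # slow write, off the critical path
        os.replace(tmp, path)

    worker = threading.Thread(target=flush)
    worker.start()
    return worker                         # join only when durability is required

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "ckpt.pkl")
    state = {"step": 10, "weights": [0.1, 0.2, 0.3]}

    pending = async_checkpoint(state, path)

    # Training continues and overwrites the live state during the flush.
    state["step"] = 11
    state["weights"][0] = 9.9

    pending.join()
    with open(path, "rb") as f:
        saved = pickle.load(f)

print(saved)
```

The saved file reflects step 10, not the later mutations, because the background thread only ever sees the snapshot. A production implementation would also cap the number of in-flight flushes so that host memory does not fill with pending snapshots.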

For the fastest recovery of all, some systems skip durable storage on the common path entirely and keep a checkpoint in the memory of a peer node. If a worker dies, its replacement pulls the last state from a neighbor's RAM over the interconnect rather than reading it back from a filesystem, recovering in seconds. The trade is durability: an in-memory or peer checkpoint vanishes if enough nodes fail at once, so production systems pair frequent fast in-memory checkpoints with occasional slow durable ones, a tiering that mirrors the storage hierarchy of Section 8.2. The frequent cheap checkpoints bound the lost work; the rare durable ones survive a correlated outage.
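The tiering reduces to a scheduling rule that decides, at each step, which kind of checkpoint to take. The sketch below is one plausible policy under assumed intervals (every 50 steps to peer memory, every 1,000 steps to durable storage); the function name `checkpoint_tier` and both intervals are illustrative, not values from any particular system.

```python
def checkpoint_tier(step, memory_every=50, durable_every=1000):
    """Return the checkpoint tier to write at this step, or None.

    A durable checkpoint takes precedence when both tiers are due,
    since it also bounds lost work at least as well as the fast one.
    """
    if step == 0:
        return None
    if step % durable_every == 0:
        return "durable"
    if step % memory_every == 0:
        return "memory"
    return None

# Sample the policy at a few steps across two durable intervals.
plan = [(s, checkpoint_tier(s)) for s in range(0, 2001, 250)]
for step, tier in plan:
    print(f"step {step:5d}: {tier}")
```

With these intervals, an isolated worker failure costs at most 50 steps of recompute, while a correlated outage that wipes the peer copies costs at most 1,000.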

Library Shortcut: torch.distributed.checkpoint Saves the Shards for You

Writing a sharded, parallel, resharding-aware checkpoint by hand means coordinating per-rank file layout, handling the case where you resume on a different number of GPUs, and overlapping the flush with compute. PyTorch's torch.distributed.checkpoint (DCP) does all of it: each rank saves its own shard in parallel, the format is independent of the world size so you can restart elastically on a different device count (Section 18.4), and async_save performs exactly the snapshot-then-background-flush of Figure 18.2.1.

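import torch
import torch.distributed.checkpoint as dcp

# model, optim, scheduler, step, epoch, and sampler_offset come from the training
# loop; epoch and sampler_offset are illustrative names for the data cursor.
rank = torch.distributed.get_rank()
state = {"model": model.state_dict(),       # sharded FSDP params
         "optim": optim.state_dict(),        # sharded optimizer state
         "step":  scheduler.last_epoch,       # schedule position
         "data":  {"epoch": epoch, "offset": sampler_offset},   # data position
         # RNG state is keyed by rank (an assumption of this sketch) so that
         # each rank's generator state is stored under its own entry.
         f"rng_rank{rank}": {"cpu":  torch.get_rng_state(),
                             "cuda": torch.cuda.get_rng_state()}}

# Each rank writes its OWN shard in parallel; returns a future immediately.
future = dcp.async_save(state, checkpoint_id=f"ckpt/step_{step}")
# ... training continues here while the flush happens in the background ...
future.result()    # only block if you must guarantee durability before proceeding
Code 18.2.1: A sharded asynchronous checkpoint in a handful of lines, covering all five components of Table 18.2.1. Roughly a hundred lines of manual per-rank file handling, world-size bookkeeping, and background-thread management collapse into dcp.async_save, which handles the parallel shard write, world-size-independent layout, and the snapshot-then-flush overlap.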

4. How Often to Checkpoint: The Young/Daly Optimal Interval (Advanced)

Even a cheap asynchronous checkpoint is not free, so checkpointing too often wastes time on writes, and checkpointing too rarely wastes time re-doing lost work after a crash. There is an interval that minimizes the total waste, and it is given by a classical result from high-performance computing, the Young/Daly formula. Let $C$ be the cost (in stalled training time) of one checkpoint and let $\text{MTBF}$ be the mean time between failures from Section 18.1. When checkpoints are cheap relative to the failure interval, the optimal compute time between checkpoints is

$$\tau^{*} = \sqrt{2\,C\,\cdot\,\text{MTBF}}.$$
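
A short first-order derivation shows where the square root comes from. It is a sketch under two assumptions: $C \ll \tau \ll \text{MTBF}$, and a failure lands, on average, halfway through an interval. The symbol $W(\tau)$ for the wasted fraction of wall-clock time is introduced here for illustration. Each interval pays $C$ for its checkpoint, and failures arrive at rate $1/\text{MTBF}$, each costing about $\tau/2$ of redone work:

$$W(\tau) \approx \frac{C}{\tau} + \frac{\tau}{2\,\text{MTBF}}, \qquad \frac{dW}{d\tau} = -\frac{C}{\tau^{2}} + \frac{1}{2\,\text{MTBF}} = 0 \;\Longrightarrow\; \tau^{*} = \sqrt{2\,C\cdot\text{MTBF}}.$$

At the optimum the two terms are equal, so the minimum waste is approximately

$$W(\tau^{*}) \approx \sqrt{\frac{2\,C}{\text{MTBF}}},$$

which falls only as the square root of $C$: cutting the checkpoint cost by a factor of 100 cuts the waste by roughly a factor of 10.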

The intuition is a balance: the term $C$ pushes the interval longer (fewer expensive writes), while the failure rate $1/\text{MTBF}$ pushes it shorter (less work to redo when a crash strikes). The square root is where the two forces meet. The fraction of wall-clock time lost to the combined overhead, checkpoint writes plus expected re-execution, is small when $\tau^{*}$ is chosen well, and it is this fraction, not the raw interval, that tells you whether your checkpointing strategy is healthy. The code below computes both for a synchronous and an asynchronous checkpoint against a three-hour MTBF, then compares the total stall each accumulates over a six-hour run.

import math

# Young/Daly optimal checkpoint interval: tau* = sqrt(2 * C * MTBF), C << MTBF.
MTBF    = 3.0 * 3600.0   # mean time between failures: 3 hours, in seconds
C_sync  = 90.0           # synchronous checkpoint: 90 s of stalled training
C_async = 4.0            # asynchronous: only the host-memory snapshot stalls

def young_daly(C, mtbf):
    tau = math.sqrt(2.0 * C * mtbf)                 # optimal interval between checkpoints
    waste = C / (tau + C) + (tau / 2.0) / mtbf      # write overhead + expected lost work
    return tau, waste

for label, C in [("synchronous", C_sync), ("asynchronous", C_async)]:
    tau, waste = young_daly(C, MTBF)
    print(f"{label:13s} C={C:5.0f}s  tau*={tau/60:6.2f} min  wasted={waste*100:5.2f}%")

RUN = 6.0 * 3600.0                                   # a 6-hour run
print()
for label, C in [("synchronous", C_sync), ("asynchronous", C_async)]:
    tau, _ = young_daly(C, MTBF)
    n_ckpt = RUN / tau
    print(f"{label:13s} interval={tau/60:6.2f} min  "
          f"checkpoints={n_ckpt:4.1f}  total stall={n_ckpt * C / 60:5.2f} min")
Code 18.2.2: The Young/Daly optimal interval and resulting wasted-time fraction, computed from a checkpoint cost $C$ and the MTBF, then turned into the total checkpoint stall accumulated over a six-hour run for synchronous versus asynchronous writes.
synchronous   C=   90s  tau*= 23.24 min  wasted=12.52%
asynchronous  C=    4s  tau*=  4.90 min  wasted= 2.70%

synchronous   interval= 23.24 min  checkpoints=15.5  total stall=23.24 min
asynchronous  interval=  4.90 min  checkpoints=73.5  total stall= 4.90 min
Output 18.2.2: The cheaper asynchronous checkpoint earns a much shorter optimal interval (4.9 versus 23.2 minutes), cuts the wasted-time fraction from 12.5% to 2.7%, and despite firing nearly five times as often accumulates less than a quarter of the total stall, because each write is so much cheaper.

The result is worth pausing on. Lowering the per-checkpoint cost does more than reduce the cost of each write; it lets you afford a much shorter interval, which shrinks the expected lost work too. The asynchronous run checkpoints almost five times as often yet stalls less than a quarter as long, and its overall wasted fraction is roughly a fifth of the synchronous run's. This is precisely why the engineering effort of subsection 3 (asynchronous and in-memory checkpointing) pays off: cheap checkpoints are not just cheaper writes, they unlock a fundamentally better point on the cost curve. The same $\tau^{*}$ feeds directly into the restart and elasticity machinery of Section 18.3 and Section 18.4.

Thesis Thread: Sharding Returns, This Time to Survive Failure

The same partitioning that lets a trillion-parameter model live across thousands of devices (Section 16.5) is what lets its checkpoint be written in parallel: shard the state, and each rank saves its slice independently, turning a serial bottleneck into aggregate bandwidth. Distribution is not only how the model is trained; it is how the run is made survivable. Every scale-out technique in this book eventually meets fault tolerance, and the meeting point is almost always a checkpoint whose structure mirrors the structure of the computation.

Practical Example: The Restart That Stopped Costing a Day

Who: An ML platform engineer running a 400-billion-parameter pretraining job on a 2,048-GPU cluster.

Situation: The run checkpointed synchronously every two hours by gathering all state to rank zero and writing one monolithic file.

Problem: Each checkpoint stalled all 2,048 GPUs for eleven minutes, and a hardware failure roughly once a day forced re-execution of up to two hours of work.

Dilemma: Checkpoint more often to cut lost work, but then the eleven-minute synchronous stalls would dominate; or checkpoint less often and lose more work per crash. The synchronous write made both directions painful.

Decision: Move to sharded asynchronous checkpointing so each rank writes its own slice and training resumes after a fast host-memory snapshot, then recompute the interval from the cluster's measured MTBF with Young/Daly.

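How: They switched to torch.distributed.checkpoint.async_save, layered a frequent in-memory peer checkpoint under an occasional durable one, and set the interval to the computed $\tau^{*}$ of about six minutes (at roughly one failure per day, that interval corresponds to a stall cost $C$ of about one second per checkpoint).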

Result: Per-checkpoint stall fell from eleven minutes to under thirty seconds, average lost work per crash dropped from roughly an hour (half of the two-hour interval) to a few minutes, and total wasted compute fell by roughly an order of magnitude, exactly the kind of gap Output 18.2.2 quantifies.

Lesson: Make the checkpoint cheap first, then let the Young/Daly interval ride down with it. The cost of a write and the frequency you can afford are the same problem, not two.

Research Frontier: Checkpointing at the Frontier (2024 to 2026)

As runs grew to tens of thousands of accelerators, checkpointing became a published systems problem in its own right. Google's Gemini technical report describes redundant in-memory copies of the model state that let training recover quickly from frequent hardware failures, raising goodput on its largest training job from 85% to 97% and treating fast recovery as a first-class throughput lever rather than an afterthought. PyTorch's Distributed Checkpoint added asynchronous and world-size-independent saving so that elastic restarts (Section 18.4) reshard the state automatically. Research systems such as GEMINI (Wang et al., 2023; unrelated to the Google model despite the name) and CheckFreq-style adaptive scheduling push checkpoint frequency toward the per-step limit by hiding the write behind compute. The common thread is that the 2024 to 2026 frontier no longer asks "how do we checkpoint without too much overhead" but "how do we checkpoint so often that a failure costs seconds," which turns the Young/Daly interval of Output 18.2.2 from minutes toward the duration of a single step.

Fun Note: The Save Point Was Always the Hard Part

Anyone who has lost an hour of a video game to a crash right before the next save point already understands Young/Daly in their bones. The game designer who spaces save points too far apart and the engineer who checkpoints too rarely are making the identical mistake, and both learn it from furious users. The difference is that the engineer can compute the optimal spacing in advance from $C$ and the MTBF, while the game designer just has to guess.

5. Putting It Together (Beginner)

A production checkpointing strategy is now three decisions, each answered above. Save all five components of Table 18.2.1 as one atomic unit, so a restart is exact and not merely plausible. Write the state sharded and asynchronously, so the per-checkpoint cost $C$ is small and the accelerators barely pause. Set the interval to the Young/Daly $\tau^{*}$ computed from that small $C$ and the cluster's MTBF, so you spend the minimum total time on writes plus re-execution. With those in place, a failure that the previous section showed to be inevitable costs a few minutes of recompute instead of the whole run.

What we have not yet done is make the restart itself correct. Loading the right bytes is necessary but not sufficient: the resumed run must replay its data in the right order and reproduce its random decisions, or it will diverge from the trajectory it was meant to continue. That is the determinism problem, and it is where Section 18.3 takes us next, turning the saved state of this section into a bit-faithful resumption.

Exercise 18.2.1: The Silently Broken Resume (Conceptual)

An engineer checkpoints only the model parameters and optimizer state, omitting the learning-rate schedule step, the RNG state, and the data position. The run resumes after a crash and does not error, but its final validation loss is worse than an uninterrupted run. For each of the three omitted components in Table 18.2.1, describe the specific way the resumed trajectory diverges from the uninterrupted one, and rank the three by how much damage each omission is likely to do. Why is a checkpoint that loads without error more dangerous than one that fails to load?

Exercise 18.2.2: Tune the Interval to the Cluster (Coding)

Extend Code 18.2.2 to sweep MTBF from 30 minutes to 12 hours and plot (or tabulate) the optimal interval $\tau^{*}$ and the wasted-time fraction for both the synchronous and asynchronous checkpoint costs. At what MTBF does the synchronous strategy's wasted fraction exceed 20%, and what does that imply for a cluster that is growing (and therefore, per Section 18.1, seeing its MTBF fall)? Then add a third row for an in-memory peer checkpoint with $C = 0.5$ s and comment on why a real system still keeps the slower durable checkpoint despite its much larger $C$.

Exercise 18.2.3: The Cost of Gathering to Rank Zero (Analysis)

A 200-billion-parameter model trained with Adam in 16-bit precision must checkpoint parameters plus two optimizer buffers per parameter. Estimate the total checkpoint size in bytes. If you gather everything to rank zero and write through that node's single 5 gigabytes-per-second link, estimate the synchronous write time. Now suppose instead that 1,024 ranks each write their shard in parallel, with the distributed store sustaining an aggregate 400 gigabytes per second. Estimate the sharded write time and the speedup. Using your sharded $C$ and a two-hour MTBF, compute $\tau^{*}$ from the Young/Daly formula and state the wasted-time fraction.