"I came back from the dead with the right weights and the wrong data cursor, and spent the next hour learning the same thousand examples I had already memorized."
A Worker That Resumed Too Eagerly
A checkpoint is only half of fault tolerance; the other half is resuming from it correctly. That means restoring not just the model weights but the exact data-stream position and the random state on every rank, then accepting that a distributed resume is statistically equivalent to the original run rather than bitwise identical. The previous section showed how to write a checkpoint cheaply and often. This section spends that checkpoint. A correct restart detects the failure, fences off the dead or hanging ranks, loads the latest consistent checkpoint on every rank, rewinds the data loader to where the checkpoint left off, and continues the loss curve as if nothing happened. Get any one of those pieces wrong (weights without data position, or data position without random state) and the loss curve visibly bends away from where it should be. We will see that bend, measure it, and explain why bitwise-identical resume across thousands of GPUs is usually impossible and rarely necessary.
In Section 18.2 we made checkpointing cheap enough to do often: sharded so each rank writes only its slice, asynchronous so the training step barely pauses. That investment pays off only at the moment of recovery. A thousand-GPU job that checkpoints every fifteen minutes but cannot reliably restart from those checkpoints has bought nothing but disk traffic. Restart is where the saved state becomes a running job again, and it is unforgiving: a single rank that loads the wrong shard, or a data loader that forgets which examples it had already served, turns a clean recovery into a silent corruption that shows up days later as a model that underperforms for no visible reason.
The restart flow has four steps that must happen in order, and a fifth concern, determinism, that hangs over all of them. We detect the failure; fence the stragglers and the dead ranks so they cannot poison the surviving group; restore the latest checkpoint that every rank agrees on, including the saved data-stream position we replay from; and resume. Figure 18.3.1 lays out the flow against the loss curve and contrasts the curve a correct resume produces with the one a broken replay produces.
1. The Restart Flow: Detect, Fence, Restore, Resume Beginner
A failure announces itself in one of a few ways, and each demands a different first response. A crashed process drops its connection to the collective group, so the surviving ranks see the next all-reduce hang until it times out; the NCCL watchdog in PyTorch raises after a configurable window. A hung rank is worse, because it holds its connection open while making no progress, so the whole synchronous group stalls behind it at the next barrier. A slow rank, the straggler of Section 18.5, is not a failure at all but degrades into one if it falls far enough behind. The detection layer's job is to turn all three into the same clean signal: this rank is out, tear down and restart.
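The unification the detection layer performs can be sketched in a few lines. This is a minimal, hypothetical sketch: the `Heartbeat` record, the thresholds, and the median-based straggler rule are assumptions for illustration, not PyTorch's watchdog (whose hang window is set through the `timeout` argument of `torch.distributed.init_process_group`). Three different symptoms go in; one verdict comes out.

```python
from dataclasses import dataclass
from enum import Enum
from statistics import median


class RankStatus(Enum):
    HEALTHY = "healthy"
    CRASHED = "crashed"      # process or connection gone
    HUNG = "hung"            # connected, but no completed step for too long
    STRAGGLER = "straggler"  # still progressing, but far slower than its peers


@dataclass
class Heartbeat:
    alive: bool                 # False once the process stops answering
    secs_since_progress: float  # time since this rank last finished a step
    step_time_s: float          # its recent average step time


def classify(hb: Heartbeat, median_step_s: float,
             hang_timeout_s: float, straggler_factor: float) -> RankStatus:
    if not hb.alive:
        return RankStatus.CRASHED
    if hb.secs_since_progress > hang_timeout_s:
        return RankStatus.HUNG
    if hb.step_time_s > straggler_factor * median_step_s:
        return RankStatus.STRAGGLER
    return RankStatus.HEALTHY


def ranks_to_evict(beats: dict[int, Heartbeat], hang_timeout_s: float = 600.0,
                   straggler_factor: float = 3.0) -> set[int]:
    """Collapse crashed, hung, and hopelessly slow ranks into one signal: this rank is out."""
    med = median(hb.step_time_s for hb in beats.values() if hb.alive)
    return {rank for rank, hb in beats.items()
            if classify(hb, med, hang_timeout_s, straggler_factor) is not RankStatus.HEALTHY}


beats = {
    0: Heartbeat(alive=True, secs_since_progress=2.0, step_time_s=1.0),
    1: Heartbeat(alive=False, secs_since_progress=0.0, step_time_s=1.0),   # crashed
    2: Heartbeat(alive=True, secs_since_progress=900.0, step_time_s=1.0),  # hung
    3: Heartbeat(alive=True, secs_since_progress=3.0, step_time_s=5.0),    # straggler
    4: Heartbeat(alive=True, secs_since_progress=1.0, step_time_s=1.1),
}
print(sorted(ranks_to_evict(beats)))  # [1, 2, 3]
```

Whatever the cause, the caller only needs the evicted set to decide to tear down and restart.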
Once a failure is detected, the surviving ranks must not simply press on. A half-dead rank that is still partly connected can corrupt a collective: an all-reduce that includes a rank with stale or garbage gradients silently averages garbage into everyone's update. Fencing means forcibly removing the suspect rank from the process group before any further collective runs, usually by tearing the whole group down and rebuilding it from the survivors. This is why restart in classic synchronous training is all-or-nothing: you stop every rank, not just the failed one, because a collective is only correct when its membership is agreed upon. Elastic training, in Section 18.4, relaxes this by rebuilding the group at a new size; here we assume the simpler fixed-size restart where every rank reloads and rejoins.
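One way to make "fence before the next collective" concrete is a generation number, a fencing token that every rebuild increments, so that contributions from a rank outside the current group are rejected rather than averaged in. The sketch below is hypothetical: `Membership` and `all_reduce_mean` are illustrative stand-ins, not a PyTorch API. In PyTorch the tear-down-and-rebuild itself is done by destroying the process group and initializing a new one.

```python
class Membership:
    """Hypothetical group registry; every rebuild bumps the generation (a fencing token)."""

    def __init__(self, ranks):
        self.generation = 0
        self.ranks = frozenset(ranks)

    def rebuild(self, survivors):
        # Tear the group down and re-form it; every older generation is now stale.
        self.generation += 1
        self.ranks = frozenset(survivors)
        return self.generation


def all_reduce_mean(contributions, group):
    """contributions maps rank -> (generation, value). Fenced or stale values never enter the sum."""
    accepted = {rank: value for rank, (gen, value) in contributions.items()
                if rank in group.ranks and gen == group.generation}
    if set(accepted) != group.ranks:
        # A collective is only correct when its membership is agreed upon: refuse rather than guess.
        raise RuntimeError(f"missing members {sorted(group.ranks - set(accepted))}")
    return sum(accepted.values()) / len(accepted)


group = Membership(range(4))
gen = group.rebuild(survivors=[0, 1, 3])                            # rank 2 is fenced out
grads = {0: (gen, 1.0), 1: (gen, 2.0), 2: (0, 1e9), 3: (gen, 3.0)}  # rank 2 still sends garbage
print(all_reduce_mean(grads, group))                                # 2.0
```

The half-dead rank's `1e9` never reaches the average, and a reduction with a live member missing fails loudly instead of silently averaging over the wrong set.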
Restoring is the step where sharded checkpointing from Section 18.2 pays back. Every rank loads its own shard of the model, optimizer, and learning-rate-scheduler state from the latest checkpoint that all ranks agree is complete. The agreement matters: if rank 3 wrote its shard for step 1000 but rank 7 crashed mid-write, the only consistent checkpoint is step 750, and every rank must roll back to it together. A checkpoint is usable only when its slowest writer finished, which is why the checkpoint metadata records a step number that the whole job commits to atomically. Resume is then almost anticlimactic: rebuild the process group, broadcast a sanity check that all ranks loaded the same step, and re-enter the training loop.
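The agreement rule reduces to a set intersection, and the atomic commit to a write-then-rename. A minimal sketch, assuming each rank reports which checkpoint steps it finished writing and that a coordinator (say, rank 0) publishes the restore point in a small marker file; the `latest.json` name and layout are hypothetical. `os.replace` is atomic on POSIX when source and destination sit on the same filesystem, which is why the temporary file is created in the checkpoint directory itself.

```python
import json
import os
import tempfile


def latest_consistent_step(completed: dict[int, set[int]]):
    """completed[rank] = steps whose shard that rank finished writing.
    The restore point is the newest step every rank finished, or None."""
    common = set.intersection(*completed.values())
    return max(common) if common else None


def commit_step(ckpt_dir: str, step: int) -> None:
    """Publish `step` as the job-wide restore point: write a temp file, then rename it over the marker."""
    fd, tmp_path = tempfile.mkstemp(dir=ckpt_dir)  # same directory, so the same filesystem
    with os.fdopen(fd, "w") as f:
        json.dump({"step": step}, f)
        f.flush()
        os.fsync(f.fileno())                       # data on disk before it becomes visible
    os.replace(tmp_path, os.path.join(ckpt_dir, "latest.json"))  # readers see old or new, never half


def read_committed_step(ckpt_dir: str) -> int:
    with open(os.path.join(ckpt_dir, "latest.json")) as f:
        return json.load(f)["step"]


# Rank 7 crashed while writing step 1000, so only step 750 is complete everywhere.
completed = {0: {500, 750, 1000}, 3: {500, 750, 1000}, 7: {500, 750}}
step = latest_consistent_step(completed)
with tempfile.TemporaryDirectory() as d:
    commit_step(d, step)
    print(read_committed_step(d))  # 750
```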
The model weights are the obvious state to restore, and the one every tutorial mentions. They are not enough. A correct resume restores three coupled pieces of state on every rank: the model and optimizer tensors (where learning is), the data-loader position (which examples come next), and the random-number generator state (what dropout, augmentation, and shuffling will do). Drop the data-loader position and you replay or skip examples. Drop the RNG state and the same batch takes a different stochastic path. A checkpoint that saves weights alone is a checkpoint that resumes to a different run than the one that crashed, and the loss curve will tell on you.
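The difference between restoring a seed and restoring a state is easy to see with NumPy's generator API: a seed names the start of a stream, while `bit_generator.state` records the current position in it. A small sketch; the seed and draw counts are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(999)
rng.standard_normal(5)                   # draws consumed before the checkpoint
saved_state = rng.bit_generator.state    # exact position in the stream (a plain dict)
expected_next = rng.standard_normal(3)   # what the uninterrupted run draws next

# Wrong: restoring the seed rewinds the stream to its first draw.
reseeded = np.random.default_rng(999).standard_normal(3)

# Right: restoring the bit-generator state continues exactly where the stream stopped.
restored = np.random.default_rng()
restored.bit_generator.state = saved_state
continued = restored.standard_normal(3)

print(np.array_equal(continued, expected_next))  # True
print(np.array_equal(reseeded, expected_next))   # False
```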
2. Replay: Re-establishing the Exact Data-Stream Position Intermediate
The hardest piece of the three to get right is the data-loader position, because in a distributed job the data stream is itself sharded and shuffled. Recall from Chapter 8 that each rank reads its own shard of the dataset, often from a streaming pipeline that reshuffles every epoch with an epoch-dependent seed. The loader's true position is therefore not a single integer but a small tuple: which epoch we are in, which seed produced that epoch's shuffle, and how many examples into the shuffled order this rank has already consumed. Replay means reconstructing that tuple exactly, so the next batch after resume is the batch that would have come next had the crash never happened.
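A minimal sketch of the position tuple, assuming a rank-strided sharding of each epoch's global shuffle (`order[rank::world_size]`); the pipeline of Chapter 8 may shard differently, and `rank_stream` and `next_batch` are hypothetical helpers. The point is that `(seed, epoch, consumed)` fully determines what this rank reads next.

```python
import numpy as np


def rank_stream(seed: int, epoch: int, rank: int, world_size: int, n_examples: int) -> np.ndarray:
    """This rank's slice of the epoch's global shuffle; strided slices never overlap."""
    order = np.random.default_rng((seed, epoch)).permutation(n_examples)  # epoch-dependent shuffle
    return order[rank::world_size]


def next_batch(position: tuple[int, int, int], rank: int, world_size: int,
               n_examples: int, batch: int) -> np.ndarray:
    """position = (seed, epoch, consumed): enough to reconstruct this rank's next batch."""
    seed, epoch, consumed = position
    return rank_stream(seed, epoch, rank, world_size, n_examples)[consumed:consumed + batch]


# Uninterrupted run on rank 1 of 4: six batches of 8 in epoch 2.
full = [next_batch((7, 2, 8 * i), rank=1, world_size=4, n_examples=1000, batch=8) for i in range(6)]
# Crash after four batches: the checkpoint stores (seed, epoch, consumed) = (7, 2, 32).
resumed = next_batch((7, 2, 32), rank=1, world_size=4, n_examples=1000, batch=8)
print(np.array_equal(resumed, full[4]))  # True
```

Because the shuffle is a pure function of `(seed, epoch)`, nothing about the order itself needs saving; only the three integers do.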
Two failure modes flank the correct behavior, and both are common. If you restart the data stream from the top, the resumed run re-trains on every example from the start of the epoch up to the checkpoint's cursor, examples the model has already seen, before it reaches new data. The model overweights those examples, which biases the gradient and shows up as the loss spike in Figure 18.3.1. If instead you advance the cursor too far, perhaps by counting global steps but forgetting that the crashed step never completed, you skip a band of examples entirely and the model never trains on them. Skipping is the quieter bug because it produces no spike, just a slightly worse model trained on a slightly smaller effective dataset. Both are avoided by treating the data-loader cursor as first-class checkpoint state, saved and restored alongside the weights.
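Both failure modes can be counted directly. The sketch below is a hypothetical audit under simplifying assumptions (one rank, a fixed epoch order, three committed steps): it compares the examples trained before the crash with those trained after resuming at three candidate cursors.

```python
def epoch_coverage(before: list[list[int]], after: list[list[int]], epoch_size: int) -> tuple[int, int]:
    """Return (duplicated, never_seen) example counts for one epoch across a resume."""
    seen = [i for batch in before + after for i in batch]
    distinct = len(set(seen))
    return len(seen) - distinct, epoch_size - distinct


EPOCH, BATCH, COMMITTED = 64, 8, 3   # the checkpoint holds 3 completed steps
order = list(range(EPOCH))           # stand-in for this epoch's shuffled order


def batches_from(cursor: int) -> list[list[int]]:
    return [order[s:s + BATCH] for s in range(cursor, EPOCH, BATCH)]


before = batches_from(0)[:COMMITTED]  # what the committed steps trained on
for label, cursor in [("correct", COMMITTED * BATCH),
                      ("restart from top", 0),                              # replay bug
                      ("counted crashed step", (COMMITTED + 1) * BATCH)]:  # skip bug
    print(f"{label:>20}: (duplicated, never seen) = {epoch_coverage(before, batches_from(cursor), EPOCH)}")
# correct: (0, 0); restart from top: (24, 0); counted crashed step: (0, 8)
```

Replay shows up as duplicates equal to the saved cursor; a skip shows up only as missing examples, which no loss curve reports.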
Modern data-loading libraries make this position serializable so you do not reconstruct it by hand. The principle is the same whichever tool you use: the loader exposes a small state dictionary, you write it into the checkpoint, and on resume you load it back before the first batch. We return to the library form in the shortcut below; the demonstration code next implements the cursor explicitly so you can see exactly what is being saved.
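The protocol itself is small. `CursorLoader` below is a hypothetical toy, not any library's class; it only mirrors the `state_dict()` / `load_state_dict()` convention, and JSON stands in for whatever format the checkpoint actually uses.

```python
import json


class CursorLoader:
    """Hypothetical minimal loader exposing the state_dict / load_state_dict protocol."""

    def __init__(self, data, batch_size):
        self.data, self.batch_size, self.cursor = data, batch_size, 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.cursor >= len(self.data):
            raise StopIteration
        batch = self.data[self.cursor:self.cursor + self.batch_size]
        self.cursor += len(batch)
        return batch

    def state_dict(self):
        return {"cursor": self.cursor}

    def load_state_dict(self, state):
        self.cursor = state["cursor"]


loader = CursorLoader(list(range(10)), batch_size=3)
next(loader)                                                   # trains on [0, 1, 2]
blob = json.dumps({"step": 1, "loader": loader.state_dict()})  # written into the checkpoint

fresh = CursorLoader(list(range(10)), batch_size=3)            # the restarted process
fresh.load_state_dict(json.loads(blob)["loader"])              # before the first batch
print(next(fresh))                                             # [3, 4, 5]
```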
A classic resume bug takes the cursor and the weights from different moments: the cursor is captured before the last completed step consumes its batch, while the weights are captured after that step's update, so every restart quietly replays one batch. On a job that is preempted every few minutes on cheap spot instances, that one extra batch per restart compounds: the job spends a measurable fraction of its life re-learning batches it has already seen, like a forgetful student who reopens the book to the same page after every interruption. The fix is a single line, but the symptom, a job that converges more slowly than its FLOP budget says it should, is maddening to diagnose without the cursor in the logs.
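One way this mismatch can arise, sketched as a hypothetical toy loop in which each step consumes one batch id: if the cursor stored with the checkpoint is the one captured at the top of the last completed step, rather than after its fetch, the resumed run trains that batch twice.

```python
from collections import Counter


def run_with_one_restart(snapshot_after_step: bool, crash_at: int = 5, total: int = 8) -> Counter:
    """Toy loop: each step consumes one batch id. Count how often each id is trained across one restart."""
    trained, cursor, ckpt = Counter(), 0, {}
    for step in range(crash_at):
        cursor_at_top = cursor                 # loader position when this step began
        batch_id, cursor = cursor, cursor + 1  # fetch: the cursor moves past this batch
        trained[batch_id] += 1                 # the optimizer step completes
        # The checkpoint holds weights after step + 1 updates; which cursor goes with them?
        ckpt = {"steps_done": step + 1,
                "cursor": cursor if snapshot_after_step else cursor_at_top}
    cursor = ckpt["cursor"]                    # crash, then resume from the checkpoint
    for _ in range(total - ckpt["steps_done"]):
        batch_id, cursor = cursor, cursor + 1
        trained[batch_id] += 1
    return trained


for after in (True, False):
    replayed = sorted(b for b, n in run_with_one_restart(after).items() if n > 1)
    print(f"snapshot_after_step={after}: batches trained twice = {replayed}")
# snapshot_after_step=True: batches trained twice = []
# snapshot_after_step=False: batches trained twice = [4]
```

In this toy each restart wastes one step; with a restart every 300 steps that is roughly one step in 300, about 0.3 percent of compute, on top of the restart's own overhead.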
3. Determinism: What You Can and Cannot Guarantee Advanced
It is tempting to demand that a resumed run be bitwise identical to the run that crashed: same weights, same losses, same final model down to the last bit. On a single deterministic CPU process this is achievable, and our demonstration achieves it. Across thousands of GPUs it usually is not, and understanding why separates a realistic recovery target from an impossible one. Three sources of nondeterminism conspire against bitwise reproducibility, and they are worth naming because each has a different remedy and a different cost.
The first is the reduction order of collectives. An all-reduce sums one vector per worker, and floating-point addition is not associative: $(a + b) + c$ can differ from $a + (b + c)$ in the last bits. A ring all-reduce that visits workers in a different order after a restart, because the surviving ranks were renumbered, produces a gradient that differs from the original in those last bits, and the difference compounds over thousands of steps. The second is nondeterministic GPU kernels: many fast CUDA kernels (certain convolutions, atomic-add reductions, some attention implementations) trade reproducibility for speed and return slightly different results run to run even on the same hardware. The third is RNG divergence across ranks: if each rank's random state is not seeded and restored per rank, dropout masks and augmentations differ after resume, and two runs that should agree no longer do.
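Non-associativity is visible with three ordinary Python floats. The `ring_sum` helper below is a simplified stand-in for a ring all-reduce (a real ring splits each vector into chunks and pipelines them); it is kept only to show that changing which worker the accumulation starts at, as renumbering after a restart can, changes the last bits of the sum.

```python
def ring_sum(values: list[float], start: int) -> float:
    """Accumulate one value per worker around a ring, beginning at worker `start`."""
    total = 0.0
    for k in range(len(values)):
        total += values[(start + k) % len(values)]
    return total


grads = [0.1, 0.2, 0.3]          # one scalar "gradient" per worker
print((0.1 + 0.2) + 0.3)         # 0.6000000000000001
print(0.1 + (0.2 + 0.3))         # 0.6
print(ring_sum(grads, start=0))  # 0.6000000000000001
print(ring_sum(grads, start=1))  # 0.6
```

A one-ulp difference is harmless on its own; it matters because every later step builds on it.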
What you can guarantee is more modest and more useful: statistical equivalence. A correct resume restores the per-rank RNG state, the data-loader position, and the optimizer state, so the resumed run is drawn from the same training distribution as the original and converges to a model of the same quality. It may not be bitwise identical, but the loss curves are indistinguishable within the noise the optimizer already contains. This is the same distinction Chapter 5 draws for reproducible measurement on clusters: you report results with seeds and environment fixed, and you accept run-to-run variation within a stated tolerance rather than chasing a determinism that the hardware will not give you. You can buy bitwise determinism if you truly need it, by forcing deterministic kernels and fixing reduction order, but it can slow training by a noticeable margin, and for almost all training jobs statistical equivalence is the right target.
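Per-rank RNG state can be handled with NumPy's `SeedSequence.spawn`, which derives independent child streams from one job seed. A minimal sketch, assuming each rank's generator state is stored in its own checkpoint entry; the helper name and the seed are illustrative. Framework RNGs (CUDA generators, for instance) need the same save-and-restore treatment.

```python
import numpy as np


def make_rank_rngs(base_seed: int, world_size: int) -> list[np.random.Generator]:
    """One independent, reproducible stream per rank, all derived from a single job seed."""
    children = np.random.SeedSequence(base_seed).spawn(world_size)
    return [np.random.default_rng(child) for child in children]


BASE_SEED, WORLD = 1234, 4
rngs = make_rank_rngs(BASE_SEED, WORLD)
for g in rngs:
    g.random(10)                               # pre-crash draws (dropout masks, augmentation, ...)
saved = [g.bit_generator.state for g in rngs]  # one entry per rank in the checkpoint
expected = [g.random(3) for g in rngs]         # what each rank would draw next without a crash

restored = make_rank_rngs(BASE_SEED, WORLD)    # the restarted job rebuilds its generators...
for g, state in zip(restored, saved):
    g.bit_generator.state = state              # ...then restores each rank's own position
resumed = [g.random(3) for g in restored]
print(all(np.array_equal(a, b) for a, b in zip(resumed, expected)))  # True
```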
Who: A research engineer training a 7-billion-parameter language model on a 256-GPU cluster.
Situation: A node failed at step 84,000; the job auto-restarted from the step-83,500 checkpoint and continued.
Problem: Right after every restart, the training loss jumped by roughly fifteen percent and took two thousand steps to recover, and it happened on every preemption.
Dilemma: The team first suspected the learning-rate scheduler resuming at the wrong value, then suspected a too-aggressive optimizer; both would have meant retuning hyperparameters and burning more compute.
Decision: Before changing any hyperparameter, they logged the data-loader cursor at save and at restore, and found the loader restarted each epoch's stream from index zero instead of the saved cursor.
How: They serialized the loader's state dictionary into the checkpoint and restored it on resume, so the stream continued from the saved position instead of replaying the consumed prefix.
Result: The post-restart spike vanished; the loss curve continued smoothly through restarts, exactly as the green curve in Figure 18.3.1, and the hyperparameters were never the problem.
Lesson: A loss spike at the resume point is almost always a replay bug, not a learning-rate bug. Log the data-loader cursor across the checkpoint boundary before you touch the optimizer.
4. A Working Resume, Correct and Broken Intermediate
The code below makes all of this concrete in plain Python with NumPy. It runs a tiny stochastic training loop three times: SGD on a linear-regression loss, fed by a shuffled, sharded-style data stream, with a per-step random perturbation standing in for dropout. The first run never crashes and serves as the reference. The second checkpoints everything at step 60 (the weights, the data-loader cursor, and the exact bit-state of the random generator) and then resumes correctly. The third resumes the buggy way: it restores the weights but reseeds the RNG and restarts the data stream from the top, the two mistakes Sections 2 and 3 warned about.
import numpy as np
D, TOTAL, STEPS, BATCH, LR = 8, 4000, 120, 32, 0.02
CRASH_AT = 60 # the step where the job dies
data_rng = np.random.default_rng(12345) # the fixed "dataset"
X = data_rng.standard_normal((TOTAL, D))
w_star = data_rng.standard_normal(D)
Y = X @ w_star + 0.05 * data_rng.standard_normal(TOTAL)
def loss(w, idx):
    r = X[idx] @ w - Y[idx]
    return float(r @ r / len(idx))
def make_epoch_order(seed, epoch): # deterministic per-epoch shuffle
    g = np.random.default_rng((seed, epoch))
    order = np.arange(TOTAL); g.shuffle(order)
    return order
def next_batch(state): # advance the data-loader cursor
    order = make_epoch_order(state["seed"], state["epoch"])
    if state["cursor"] + BATCH > TOTAL: # roll to a freshly shuffled epoch
        state["epoch"] += 1; state["cursor"] = 0
        order = make_epoch_order(state["seed"], state["epoch"])
    start = state["cursor"]
    idx = order[start:start + BATCH]
    state["cursor"] = start + BATCH
    return idx
def run(steps, w, loader, noise, record):
    for _ in range(steps):
        idx = next_batch(loader)
        g = (2.0 / len(idx)) * (X[idx].T @ (X[idx] @ w - Y[idx]))
        g = g + 0.01 * noise.standard_normal(D) # stochastic regularizer (RNG)
        w = w - LR * g
        record.append(loss(w, np.arange(TOTAL)))
    return w
# (1) reference run, no crash
w = np.zeros(D); loader = {"seed": 7, "epoch": 0, "cursor": 0}
noise = np.random.default_rng(999); ref = []
w = run(STEPS, w, loader, noise, ref)
# checkpoint EVERYTHING at the crash step, then "die"
w = np.zeros(D); loader = {"seed": 7, "epoch": 0, "cursor": 0}
noise = np.random.default_rng(999); pre = []
w = run(CRASH_AT, w, loader, noise, pre)
ckpt = {"w": w.copy(), "loader": dict(loader),
        "noise_state": noise.bit_generator.state} # exact RNG state, not a seed
# (2) correct resume: restore weights + cursor + RNG state
w_ok = ckpt["w"].copy(); loader_ok = dict(ckpt["loader"])
noise_ok = np.random.default_rng(); noise_ok.bit_generator.state = ckpt["noise_state"]
ok = list(pre); w_ok = run(STEPS - CRASH_AT, w_ok, loader_ok, noise_ok, ok)
# (3) buggy resume: weights only; RNG reseeded, data stream restarted
w_bad = ckpt["w"].copy(); loader_bad = {"seed": 7, "epoch": 0, "cursor": 0}
noise_bad = np.random.default_rng(999) # BUG: RNG reset
bad = list(pre); w_bad = run(STEPS - CRASH_AT, w_bad, loader_bad, noise_bad, bad)
print(f"final loss no-crash reference : {ref[-1]:.6f}")
print(f"final loss correct resume : {ok[-1]:.6f}")
print(f"final loss buggy resume : {bad[-1]:.6f}")
print(f"max |correct - reference| : {max(abs(a-b) for a,b in zip(ok, ref)):.2e}")
print(f"max |buggy - reference| : {max(abs(a-b) for a,b in zip(bad, ref)):.2e}")
print(f"\n{'step':>4} {'reference':>9} {'correct':>7} {'buggy':>5}")
for s in (59, 60, 61, 65, 119): # the resume boundary and beyond
print(f"{s:>4} {ref[s]:.5f} {ok[s]:.5f} {bad[s]:.5f}")
# the buggy loader was reset to cursor 0, replaying the consumed prefix this epoch
print(f"\nbuggy resume replays examples already trained on this epoch: "
f"cursor was {ckpt['loader']['cursor']}, reset to 0")
final loss no-crash reference : 0.002954
final loss correct resume : 0.002954
final loss buggy resume : 0.002994
max |correct - reference| : 0.00e+00
max |buggy - reference| : 1.85e-03

step reference correct buggy
  59 0.05652 0.05652 0.05652
  60 0.05183 0.05183 0.05368
  61 0.04852 0.04852 0.04979
  65 0.03755 0.03755 0.03734
 119 0.00295 0.00295 0.00299

buggy resume replays examples already trained on this epoch: cursor was 1920, reset to 0
The numbers match Figure 18.3.1 exactly. Before the crash at step 60 all three runs are identical, because they are the same run. From step 60 on, the correct resume tracks the reference to the bit, with a maximum difference of zero across all 120 steps. The buggy resume departs from the reference at exactly that step, $0.05368$ against the reference's $0.05183$, and never rejoins it; this is the post-resume jump of Figure 18.3.1 made numerical. The final losses, $0.002954$ for both the reference and the correct resume versus $0.002994$ for the buggy one, look close, but on a real model trained for a hundred thousand steps with a spike at every preemption, that gap accumulates into a measurably worse model. This single-process demo achieves bitwise equality precisely because it sidesteps the three nondeterminism sources of Section 3. The lesson that survives the move to a real cluster is the qualitative one: restore all three pieces of state, or watch the curve bend.
In Code 18.3.1 we tracked the data-loader cursor by hand in a dictionary and restored it explicitly. Production stacks expose this as a serializable loader state, so the whole cursor-and-shuffle position collapses into two calls. PyTorch's torchdata StatefulDataLoader and the MosaicML Streaming loader follow this pattern: state_dict() captures the position and load_state_dict() restores it. Check whether other streaming libraries, such as WebDataset, offer an equivalent before relying on one.
import torch
from torchdata.stateful_dataloader import StatefulDataLoader
# `dataset`, `model`, and `optimizer` are assumed to be defined by the surrounding training script.
loader = StatefulDataLoader(dataset, batch_size=32, num_workers=4)
# At checkpoint time, save the loader state alongside model and optimizer.
ckpt = {"model": model.state_dict(),
        "optim": optimizer.state_dict(),
        "loader": loader.state_dict(),  # exact stream position, per rank
        "rng": torch.get_rng_state()}   # CPU RNG only; also save CUDA RNG state and per-rank seeds
# On resume, restore all four before the first batch.
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optim"])
loader.load_state_dict(ckpt["loader"])  # replay lands on the right example
torch.set_rng_state(ckpt["rng"])
With the state_dict / load_state_dict pair, the roughly twenty lines of hand-rolled epoch-and-cursor bookkeeping become two calls, and the library handles per-worker sharding, shuffle seeds, and prefetch position internally.
5. Diagnosing Loss Spikes After Resume Intermediate
When a loss spike appears right after a restart, the cause is almost always in the resume, not in the model, and a short checklist isolates it fast. First, log the data-loader cursor at save and at restore and confirm they match; a mismatch is the replay bug of Section 2 and the single most common culprit, as the practical example showed. Second, confirm the learning-rate scheduler resumed at the checkpoint's step, not at step zero: a scheduler that restarts re-runs warmup and climbs back to the peak learning rate, far above the decayed rate the run had reached, which injects oversized updates into a partly trained model and spikes the loss. Third, confirm the optimizer's momentum and variance buffers were restored; an Adam optimizer that resumes with zeroed second-moment estimates takes oversized steps until the buffers refill. Fourth, confirm per-rank seeds were restored so the augmentation and dropout streams continue rather than restart.
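The checklist can be turned into an automatic audit that runs once after restore, before the first training step. A hypothetical sketch: the field names, the warmup-plus-cosine schedule, and its constants are assumptions chosen to mirror the practical example (a resume near step 83,500), not any framework's API. The `lr_at` helper shows why a scheduler that restarts at step zero hurts: once warmup ends it is back at the peak rate, many times the decayed rate the run had reached.

```python
import math


def lr_at(step: int, peak: float = 3e-4, warmup: int = 2000, total: int = 100_000) -> float:
    """Linear warmup, then cosine decay to zero (an assumed schedule for illustration)."""
    if step < warmup:
        return peak * (step + 1) / warmup
    progress = (step - warmup) / (total - warmup)
    return 0.5 * peak * (1.0 + math.cos(math.pi * progress))


def audit_resume(saved: dict, restored: dict) -> list[str]:
    """Compare what the checkpoint recorded with what the restarted process actually holds."""
    problems = []
    if restored["loader_cursor"] != saved["loader_cursor"]:
        problems.append(f"data cursor {restored['loader_cursor']} != saved {saved['loader_cursor']}")
    if restored["scheduler_step"] != saved["step"]:
        problems.append(f"scheduler at step {restored['scheduler_step']}, expected {saved['step']}")
    if saved["opt_state_norm"] > 0.0 and restored["opt_state_norm"] == 0.0:
        problems.append("optimizer moment buffers are zero: optimizer state was not restored")
    if restored["rank_rng_states"] != saved["rank_rng_states"]:
        problems.append("per-rank RNG states differ from the checkpoint")
    return problems


# "r0", "r1" are placeholders for serialized per-rank RNG states.
saved = {"step": 83_500, "loader_cursor": 1920, "opt_state_norm": 0.7, "rank_rng_states": ["r0", "r1"]}
good = {"loader_cursor": 1920, "scheduler_step": 83_500, "opt_state_norm": 0.7, "rank_rng_states": ["r0", "r1"]}
bad = {"loader_cursor": 0, "scheduler_step": 0, "opt_state_norm": 0.0, "rank_rng_states": ["r0", "r1"]}
print(audit_resume(saved, good))       # []
print(len(audit_resume(saved, bad)))   # 3
print(f"peak LR is {lr_at(2000) / lr_at(83_500):.1f}x the LR at step 83,500")
```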
Not every post-resume spike is a bug. A small, transient bump that decays within a few dozen steps is consistent with the floating-point nondeterminism of Section 3: the resumed run took a microscopically different path and the optimizer is re-settling. The diagnostic that separates benign from pathological is the recovery time. A spike that decays in tens of steps is the harmless statistical-equivalence wobble; a spike that takes thousands of steps to recover, or never fully recovers, is a state-restoration bug eating real training progress. Keeping the no-crash reference curve from a short clean run on hand, exactly as Code 18.3.1 does, turns this from a judgment call into a comparison.
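Recovery time is straightforward to measure once a reference curve exists. A hypothetical sketch: the 2 percent tolerance and the synthetic curves below are assumptions for illustration, not thresholds from any real run.

```python
import math


def steps_to_recover(resumed: list[float], reference: list[float],
                     resume_step: int, rel_tol: float = 0.02):
    """Steps after `resume_step` until the resumed loss stays within rel_tol of the reference
    for the rest of the recorded window; None if it never settles."""
    for k in range(resume_step, len(reference)):
        if all(abs(r - f) <= rel_tol * f for r, f in zip(resumed[k:], reference[k:])):
            return k - resume_step
    return None


reference = [1.0 / (1.0 + 0.01 * s) for s in range(400)]
# Benign wobble: a 15% bump at the resume step that decays with a 5-step time constant.
benign = [f * (1.0 + 0.15 * math.exp(-(s - 200) / 5)) if s >= 200 else f
          for s, f in enumerate(reference)]
# State-restoration bug: the same 15% bump, which never decays inside the window.
broken = [f * 1.15 if s >= 200 else f for s, f in enumerate(reference)]
print(steps_to_recover(benign, reference, resume_step=200))  # 11
print(steps_to_recover(broken, reference, resume_step=200))  # None
```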
As training runs grow to tens of thousands of accelerators, the cost of every restart, and every percent of throughput lost to spikes, has pushed resume from an afterthought to a research target. Recent training-system reports describe in-memory and peer-to-peer checkpointing that recovers in seconds rather than minutes by reading a healthy replica's state instead of remote storage (CheckFreq-style and Gemini-style designs, and the recovery machinery documented around the Llama 3 and OPT training runs). A parallel line attacks determinism directly: bitwise-reproducible collectives and deterministic kernel flags (PyTorch's use_deterministic_algorithms, NCCL configurations that fix reduction order) make a resumed step reproduce the original at a throughput cost teams are increasingly willing to pay for debuggability. A third thread builds self-healing supervisors that detect a failure, fence the dead rank, and trigger an elastic restart with no human in the loop, the subject of Section 18.4. The common thread is that correctness of resume, not just speed of checkpointing, is now treated as a first-class system property at frontier scale.
We now have a resume that continues the curve instead of bending it: detect the failure, fence the survivors into a clean group, restore the model, optimizer, data cursor, and per-rank RNG together, and accept statistical rather than bitwise equivalence. What we have assumed throughout is that the job comes back at the same size it left. The next section drops that assumption. Elastic training rebuilds the process group at whatever size the cluster can currently offer, so a job that loses a node does not wait for a replacement but shrinks and keeps going, and a job that gains nodes grows. That flexibility, and the rescaling of batch size and learning rate it forces, is the subject of Section 18.4.
Code 18.3.1's buggy resume replays data by restarting the epoch stream, and produces a visible loss spike. Describe the opposite bug, a resume that skips a band of examples (for instance, by advancing the cursor past the batch the crashed step never finished). Explain why skipping produces no loss spike at the resume point yet still yields a worse final model, and why skipping is therefore the more dangerous bug to ship. State one logging signal you would add to catch a skip bug that the loss curve alone will not reveal.
Extend Code 18.3.1 to use momentum SGD ($v \leftarrow \mu v + g$, then $w \leftarrow w - \eta v$ with $\mu = 0.9$). Add a fourth run that resumes correctly in every respect except that it restores the weights and the data cursor but resets the momentum buffer $v$ to zero. Measure the loss spike at the resume step and the number of steps it takes to rejoin the reference curve. Compare it to the fully-correct resume, and explain why a zeroed momentum buffer behaves like a temporary learning-rate change.
Section 3 lists three sources of nondeterminism: collective reduction order, nondeterministic GPU kernels, and per-rank RNG divergence. For a synchronous data-parallel job on 1024 GPUs, rank each source by how hard it is to eliminate and estimate the throughput cost of doing so (deterministic kernels, fixed reduction order, per-rank seed management). Then argue, using the statistical-equivalence target of Chapter 5, for which kinds of run (a production pretraining job, a published benchmark for a paper, a debugging session chasing a numerical bug) bitwise determinism is worth its cost and for which it is not.