Part IV: Parallel Deep Learning and Large Models
Chapter 16: Model, Pipeline, and Sharded Parallelism

Activation Checkpointing as a Per-Node Enabler

"I threw away everything I learned on the way forward, confident I could remember it all again on the way back. The accountant called this a memory saving. My clock called it overtime."

A Layer That Recomputes Itself
Big Picture

Every parallelism strategy in this chapter splits a model that is too big across devices, but each device still has to hold the activations of the layers it owns, and those activations, not the parameters, are usually what overflows first. Activation checkpointing is a per-device knob, not a distribution method: it stores only a sparse set of activation snapshots on the forward pass and recomputes the rest on demand during the backward pass. It trades one extra forward pass for a drop in activation memory from $O(L)$ to $O(\sqrt{L})$ in the depth $L$ of the network. That single trade is what lets the tensor, pipeline, sharded, and sequence parallelism of this chapter actually fit a real model and a real context length on the hardware you have. This section explains the trade, quantifies its compute cost, and shows where it sits relative to the distribution methods that surround it.

The previous section (Section 16.7) pushed the sequence dimension across devices so that very long contexts would fit, and in doing so it showed that activation memory grows with both batch size and sequence length, and that once the model is sharded it grows faster than the per-device parameter memory. Tensor parallelism (Section 16.2), pipeline parallelism (Section 16.3), and the ZeRO and FSDP family (Section 16.4) all attack the cost of parameters, gradients, and optimizer state. None of them, on its own, attacks the cost of the activations that every layer leaves behind for its backward pass. That is the gap activation checkpointing fills, and it fills it one device at a time.

It belongs in this chapter, rather than in the per-node serving chapter (Chapter 22) where the other scale-up techniques live, for one reason: without it, the distribution strategies of this chapter often do not fit at all. A model can be perfectly sharded across thirty-two devices and still fail to train because each shard's activations exceed its local memory. Checkpointing closes that final gap, so we treat it here as the per-node prerequisite that makes distribution practical: a single-device technique that supports the distributed methods around it rather than replacing them.

1. Where Activation Memory Comes From Beginner

To train a network by backpropagation, the backward pass needs the intermediate values that the forward pass produced. When you compute the gradient of the loss with respect to a layer's input, the chain rule asks for that layer's input (and often its output) as it was during the forward pass. The standard implementation therefore stores every layer's activation tensor as it goes forward, and keeps all of them alive until the matching backward step consumes them. For a network of $L$ layers, peak activation memory is proportional to $L$: every layer contributes one tensor that must survive from its forward step until its backward step.
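To make the $O(L)$ claim concrete, here is a minimal pure-Python sketch of that bookkeeping, assuming one same-sized activation tensor per layer. The function name train_step_peak_activations and the stand-in arithmetic are illustrative, not any framework's API; the point is only that the saved list peaks at one entry per layer before the backward pass starts freeing entries.

```python
def train_step_peak_activations(num_layers):
    """Peak number of activation tensors alive during one training step.

    Illustrative model: every layer saves one same-sized tensor on the
    forward pass, and that layer's backward step consumes and frees it.
    """
    saved = []      # activations kept alive for the backward pass
    peak = 0
    x = 0.0         # stand-in for the input activation
    for _ in range(num_layers):          # forward pass, first layer first
        saved.append(x)                  # this layer's input, needed by its backward
        peak = max(peak, len(saved))
        x = x + 1.0                      # stand-in for the layer's computation
    for _ in range(num_layers):          # backward pass, last layer first
        saved.pop()                      # consumed by this layer's gradient, then freed
    return peak

for L in (8, 32, 128):
    print(L, train_step_peak_activations(L))   # the peak equals L every time
```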

This activation stack is the cost that the parallelism strategies in this chapter do not reduce. ZeRO shards model state (the optimizer state, then the gradients, and at its highest stage the parameters themselves) across data-parallel workers, but each worker still runs a full forward and backward over its microbatch and so still materializes a full stack of activations. Tensor parallelism splits each layer's matrices, which shrinks the sharded part of that layer's activation by up to the tensor-parallel degree, but the number of layers, and therefore the depth of the activation stack, is unchanged. The depth term $L$ is exactly what checkpointing targets.

Key Insight: Activations, Not Parameters, Are the Memory You Forgot to Budget

The parallelism methods of this chapter shrink the per-device cost of parameters, gradients, and optimizer state. They leave the activation stack, whose peak grows linearly with network depth and with batch and sequence length, almost untouched. On large models with long contexts the activation stack routinely dwarfs the sharded parameter footprint, so a model that "fits" by parameter count can still refuse to train. Activation checkpointing is the one knob that attacks that stack directly, and it works on each device independently, which is why it composes with every distribution strategy at once.

2. The Trade: Store a Few, Recompute the Rest Intermediate

Activation checkpointing, also called gradient checkpointing, breaks the network into segments. On the forward pass it stores the activation only at the boundary between segments (the checkpoints) and discards everything inside each segment. When the backward pass reaches a segment, it reruns that segment's forward computation, starting from the stored boundary checkpoint, to regenerate the discarded activations just in time to use them, then frees them again. The parameters and the final loss are identical; the only thing that changes is when intermediate activations exist.
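The mechanism fits in a few dozen lines of plain Python. The sketch below is a toy, assuming a chain of scalar layers $y = \tanh(wx)$ whose final output is the loss; the helper names are hypothetical. It runs backpropagation twice, once storing every activation and once storing only segment-boundary inputs and recomputing each segment just before its backward step.

```python
import math

def layer_fwd(w, x):
    """One toy layer: y = tanh(w * x)."""
    return math.tanh(w * x)

def layer_bwd(w, x, y, g):
    """Chain rule for y = tanh(w * x); returns (dLoss/dx, dLoss/dw)."""
    d = g * (1.0 - y * y)
    return d * w, d * x

def grads_full(ws, x0):
    acts = [x0]                                  # store every activation: O(L)
    for w in ws:
        acts.append(layer_fwd(w, acts[-1]))
    g, gw = 1.0, [0.0] * len(ws)                 # loss = final output, so dLoss/dout = 1
    for i in reversed(range(len(ws))):
        g, gw[i] = layer_bwd(ws[i], acts[i], acts[i + 1], g)
    return gw, len(acts)

def grads_checkpointed(ws, x0, seg):
    L = len(ws)
    ckpts, x = {}, x0
    for i, w in enumerate(ws):                   # forward: keep only segment inputs
        if i % seg == 0:
            ckpts[i] = x
        x = layer_fwd(w, x)                      # intermediate outputs are discarded
    stored = len(ckpts)
    g, gw, recomputed = 1.0, [0.0] * L, 0
    for start in reversed(range(0, L, seg)):     # backward, last segment first
        end = min(start + seg, L)
        acts = [ckpts.pop(start)]                # restart from the stored checkpoint
        for i in range(start, end):              # regenerate the discarded activations
            acts.append(layer_fwd(ws[i], acts[-1]))
            recomputed += 1
        for i in reversed(range(start, end)):    # ordinary backward over this segment
            g, gw[i] = layer_bwd(ws[i], acts[i - start], acts[i - start + 1], g)
    return gw, stored, recomputed

ws = [0.9 - 0.01 * i for i in range(16)]
full, held = grads_full(ws, 0.5)
ckpt, stored, recomputed = grads_checkpointed(ws, 0.5, seg=4)
print(held, stored, recomputed)   # 17 4 16
print(full == ckpt)               # True
```

With 16 layers and segments of 4, the full version holds 17 values (the input plus 16 outputs), while the checkpointed version stores 4 boundary inputs and performs 16 recomputed layer forwards, one full extra forward pass. The weight gradients match exactly because the recomputation replays the same arithmetic on the same stored inputs.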

The memory and compute costs depend on how many checkpoints you keep. Suppose you place a checkpoint every $s$ layers, giving $L/s$ stored checkpoints. Peak memory is the stored checkpoints plus the one segment currently being recomputed, roughly $(L/s + s)$ activation tensors. That expression is minimized when the two terms balance, at $s = \sqrt{L}$, which gives a peak of about $2\sqrt{L}$ tensors instead of $L$. This is the classic square-root rule: with optimal checkpoint placement, activation memory drops from $O(L)$ to $O(\sqrt{L})$, while the extra compute is exactly one additional forward pass over the network, because every discarded segment is recomputed once.
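A brute-force check of the square-root rule, under the same counting assumption as the paragraph above (one stored checkpoint per segment, rounded up, plus one segment in flight), is a useful sanity test; the helper names are illustrative. For perfect-square depths the search lands exactly on $s = \sqrt{L}$ with a peak of $2\sqrt{L}$, and no choice of $s$ ever beats $2\sqrt{L}$, since $L/s + s \ge 2\sqrt{L}$ for every positive $s$.

```python
import math

def peak_tensors(L, s):
    """Stored checkpoints (one per segment, rounded up) plus one segment in flight."""
    return math.ceil(L / s) + s

def best_segment(L):
    """Segment length s that minimizes peak activation tensors, by exhaustive search."""
    return min(range(1, L + 1), key=lambda s: peak_tensors(L, s))

for L in (16, 64, 256, 1000):
    s = best_segment(L)
    print(f"L={L:5d}  best s={s:3d}  peak={peak_tensors(L, s):3d}  "
          f"2*sqrt(L)={2 * math.sqrt(L):.1f}")
```

For depths that are not perfect squares, several segment lengths near $\sqrt{L}$ tie for the minimum, so the exact placement matters less than staying in that neighborhood.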

Figure 16.8.1 contrasts the two regimes: the baseline that keeps the full activation stack alive against the checkpointed stack that keeps only sparse snapshots and recomputes the segments between them on the way back.

[Figure 16.8.1 placeholder: forward pass over layers 1 to 8 (L = 8). Top panel, full activations: every layer is stored, 8 tensors held, O(L) memory. Bottom panel, checkpointing: only segment boundaries (★) are stored, 3 checkpoints held at a time, O(√L) memory; the layers in between are recomputed during backward, at a cost of one extra forward pass.]
Figure 16.8.1: Full activation storage (top) keeps one tensor per layer alive, so peak memory grows as $O(L)$. Checkpointing (bottom) stores activations only at segment boundaries (the starred layers) and recomputes the dashed intermediate layers from the nearest stored checkpoint when the backward pass needs them. With boundaries every $\sqrt{L}$ layers, peak memory is $O(\sqrt{L})$ and the price is exactly one extra forward pass.

3. Modeling the Memory and the Overhead Intermediate

The square-root rule is easy to state but worth seeing as concrete numbers, because the memory win grows with depth while the compute cost stays flat. The code below models a stack of $L$ identical layers, charging one activation-memory unit and one forward-compute unit per layer. It compares storing every activation against the optimal $s = \sqrt{L}$ placement, and it turns the extra recomputation into a step-time overhead by assuming the backward pass costs roughly twice a forward pass, so a normal step is about three forward-units and a checkpointed step adds one more.

import math

# A transformer-style stack of L identical layers. Each layer, on the forward
# pass, produces one activation tensor of "act" bytes that the backward pass
# needs. We model only the activation memory (the part checkpointing attacks).
act = 1.0          # activation memory per layer, in arbitrary units
fwd = 1.0          # forward compute per layer, in arbitrary units

def no_checkpoint(L):
    # Store every layer's activation: O(L) memory, one forward + one backward.
    mem = L * act
    return mem, L * fwd, 0.0

def sqrt_checkpoint(L):
    # Optimal placement: keep activations only at sqrt(L) segment boundaries.
    seg = max(1, round(math.sqrt(L)))      # layers per segment
    n_ckpt = math.ceil(L / seg)            # number of stored checkpoints
    # Peak = stored checkpoints + one segment recomputed in flight.
    mem = (n_ckpt + seg) * act
    return mem, L * fwd, L * fwd           # recompute = one extra forward

print(f"{'L':>6} {'full mem':>9} {'ckpt mem':>9} {'mem saved':>10} "
      f"{'step overhead':>13} {'extra compute':>14}")
print("-" * 70)
for L in (4, 16, 64, 256, 1024):
    fm, ff, _ = no_checkpoint(L)
    cm, cf, cr = sqrt_checkpoint(L)
    base_step = ff + 2 * ff                 # forward + backward (~2x forward)
    ckpt_step = cf + 2 * cf + cr            # same, plus the recompute forward
    overhead = 100.0 * (ckpt_step - base_step) / base_step
    print(f"{L:>6} {fm:>9.0f} {cm:>9.0f} {100*(1-cm/fm):>9.0f}% "
          f"{'+%.0f%%' % overhead:>13} {cr:>13.0f}u")

# One very deep stack to show the sqrt scaling pull apart from linear.
deep = 10000
full_mem, _, _ = no_checkpoint(deep)
ckpt_mem, _, _ = sqrt_checkpoint(deep)
print(f"\nL = {deep}: full = {full_mem:.0f}u  ~ O(L)")
print(f"L = {deep}: ckpt = {ckpt_mem:.0f}u  ~ 2*sqrt(L) = {2 * math.sqrt(deep):.0f}u")
print(f"activation memory reduced {full_mem/ckpt_mem:.0f}x for a {deep}-layer stack")
Code 16.8.1: A from-scratch model of activation memory with and without checkpointing across network depth. The memory columns compare $O(L)$ storage against the $O(\sqrt{L})$ optimal placement; the step-overhead column turns the single recompute forward into a percentage step-time cost.
     L  full mem  ckpt mem  mem saved step overhead  extra compute
----------------------------------------------------------------------
     4         4         4         0%          +33%             4u
    16        16         8        50%          +33%            16u
    64        64        16        75%          +33%            64u
   256       256        32        88%          +33%           256u
  1024      1024        64        94%          +33%          1024u

L = 10000: full = 10000u  ~ O(L)
L = 10000: ckpt = 200u  ~ 2*sqrt(L) = 200u
activation memory reduced 50x for a 10000-layer stack
Output 16.8.1: As depth grows, the memory saving climbs from nothing at $L=4$ to 94 percent at $L=1024$ and a 50-fold reduction at $L=10000$, while the recompute overhead stays pinned at one extra forward pass. The memory shrinks like $\sqrt{L}$; the compute cost does not.

Two facts stand out. First, the saving is negligible for shallow stacks and dramatic for deep ones, which is exactly why checkpointing matters for the large models this chapter distributes and not for a small classifier. Second, the relative compute overhead is flat: it is always one extra forward pass, regardless of depth. In this idealized accounting that shows up as a +33 percent step time, because the recompute forward is one unit added to a three-unit step. Real training steps usually see less, for two reasons: the step also contains work that is never recomputed (the optimizer update, data loading, and in distributed runs any communication left exposed), and many implementations keep some expensive outputs, such as matrix-multiply results, instead of recomputing the whole block. Some recomputation can also overlap with communication. That is why measured wall-clock overhead in practice is typically in the 20 to 30 percent range rather than the worst-case third.
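The gap between the idealized third and measured numbers becomes visible once the step-time model gains one more term. The sketch below assumes, hypothetically, a block of step time that is never recomputed (called other here: optimizer update, data loading, exposed communication) and an optional fraction of the forward that is actually recomputed. Both parameters are illustrative, and the sketch does not model overlap of recomputation with communication.

```python
def checkpoint_overhead(fwd, bwd_ratio=2.0, other=0.0, recompute_frac=1.0):
    """Relative step-time increase caused by recomputation (a toy model).

    fwd            forward time of the network
    bwd_ratio      backward time / forward time (the text assumes about 2)
    other          hypothetical step time that is never recomputed, e.g. the
                   optimizer update, data loading, or exposed communication
    recompute_frac fraction of the forward that is recomputed (1.0 = all of it)
    Overlap of recomputation with communication is not modeled.
    """
    base = fwd * (1.0 + bwd_ratio) + other
    return fwd * recompute_frac / base

print(f"{checkpoint_overhead(1.0):.0%}")              # 33%, the idealized bound
print(f"{checkpoint_overhead(1.0, other=1.0):.0%}")   # 25% once other work is counted
```

Any non-recomputed work in the denominator, or any part of the forward that is kept instead of recomputed, pulls the overhead below one third.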

Fun Note: The Hard Drive That Learned to Forget on Purpose

There is an old joke that the fastest way to free up memory is to throw your data away and hope you do not need it. Activation checkpointing is that joke taken seriously, with a safety net: it throws the activations away, but it keeps the recipe (the stored checkpoint and the layer code) to bake them again on demand. The network ends up doing its homework twice, once on the way out and once on the way back, purely so it never has to carry the whole stack of completed homework around at once.

4. When the Recompute Cost Is Worth It Advanced

Checkpointing is a lever, not a default. You reach for it when memory is the binding ceiling and you have compute to spare, which is the common situation when training large models: the accelerators are not saturated because the batch had to be shrunk to fit, and trading 20 to 30 percent more compute to unlock a much larger batch or a much longer context is a clear win. You leave it off when compute is the ceiling and memory is comfortable, because then the extra forward pass is pure waste. The decision is the same memory-versus-compute trade that runs through the performance models of Chapter 3, applied to one device.

When even checkpointing is not enough, the next rung is to move activations off the accelerator entirely. Activation offload copies the stored checkpoints to CPU host memory, or in the most aggressive setups to NVMe storage, and streams them back when the backward pass needs them. This buys a great deal more headroom (host memory and disk are far larger than accelerator memory) at the cost of bandwidth: the activations now traverse the much slower host link or storage bus, and any transfer time that cannot be hidden behind compute lands directly on the step time. That is why offload is usually a last resort, used only when the model would otherwise not fit at all. The same offloading idea applies to parameters and optimizer state in the ZeRO-Infinity design discussed in Section 16.4; here it is the activations that take the trip.
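A back-of-the-envelope sketch of offload's cost follows. It assumes each stored checkpoint makes one trip to host memory on the forward pass and one trip back on the backward pass, and that some of that traffic can be hidden behind compute. The 4 GB checkpoint size, 25 GB/s link rate, and 0.2 s of overlappable compute are placeholder numbers chosen for illustration, not measurements; substitute your own.

```python
def offload_round_trip_s(ckpt_bytes, link_bytes_per_s):
    """Time to copy the stored checkpoints to host memory and back once."""
    return 2.0 * ckpt_bytes / link_bytes_per_s

def exposed_offload_s(ckpt_bytes, link_bytes_per_s, overlappable_compute_s):
    """Transfer time that cannot hide behind compute and so lands on the step."""
    transfer = offload_round_trip_s(ckpt_bytes, link_bytes_per_s)
    return max(0.0, transfer - overlappable_compute_s)

# Placeholder numbers (assumptions, not measurements).
ckpt_bytes = 4e9          # 4 GB of stored checkpoints per step
link = 25e9               # 25 GB/s effective host link
overlap = 0.2             # seconds of compute the copies can run alongside

print(f"round trip: {offload_round_trip_s(ckpt_bytes, link):.2f} s")       # 0.32 s
print(f"exposed:    {exposed_offload_s(ckpt_bytes, link, overlap):.2f} s")  # 0.12 s
```

Only the exposed part shows up on the step time, which is why offload is judged by how much traffic can be overlapped as well as by the raw link speed.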

Practical Example: Doubling the Context Window Without a Single New GPU

Who: A platform engineer extending a 13-billion-parameter model from a 4K to an 8K training context on an existing eight-GPU node.

Situation: The model was already sharded with FSDP and fit at 4K, but every attempt to double the sequence length hit an out-of-memory error during the backward pass.

Problem: As Section 16.7 made explicit, activation memory grows with sequence length, and at 8K the activation stack overflowed each shard even though the parameters fit comfortably.

Dilemma: Rent a node with larger-memory accelerators, doubling the hourly cost, or enable activation checkpointing and pay in compute time instead of dollars.

Decision: They wrapped each transformer block in checkpointing, because compute headroom existed (the GPUs were under-utilized at the smaller batch) while memory did not.

How: One call to wrap each block in torch.utils.checkpoint through the FSDP checkpoint wrapper, with no change to the model code itself.

Result: The 8K context fit on the same hardware, and step time rose by about 24 percent, below the idealized one-third bound. The run kept the same hourly cost, paying in wall-clock time (roughly a quarter longer) instead of in pricier hardware, which is the memory-for-compute trade that Output 16.8.1 models.

Lesson: When activation memory is the binding ceiling and the accelerators have spare cycles, checkpointing converts an out-of-memory failure into a modest slowdown, and it does so on the device you already have.

5. A Per-Device Knob That Makes Distribution Fit Intermediate

It is worth being precise about what checkpointing is and is not. It is not a distribution method: it moves no data between machines, runs no collective, and changes nothing about how the model is sharded. It is a per-device memory technique, in the same family as the mixed-precision and quantization tricks of Chapter 22. What earns it a place inside this chapter on parallelism is that it is frequently the difference between a distribution plan that fits and one that does not. Sharding answers "how do we split the model across devices?"; checkpointing answers "how do we make each device's share actually fit?", and large-model training needs both answers at once.

Thesis Thread: A Scale-Up Knob in Service of Scale-Out

This book leads with scale-out and treats single-node efficiency as a labeled prerequisite, not a main event. Activation checkpointing is the cleanest example of why the prerequisite matters: the tensor, pipeline, and sharded parallelism of this chapter are the distribution; checkpointing is the per-node enabler that lets the distributed plan land on real hardware. The combination is what makes foundation-model training feasible, which is why Chapter 19 assumes checkpointing is on by default before it composes the parallelism axes into a full training recipe. Scale-up here is not a competitor to scale-out; it is the floor that scale-out stands on.

Library Shortcut: torch.utils.checkpoint Does the Recompute for You

The from-scratch model in Code 16.8.1 counted tensors; a real implementation must intercept the autograd graph so that the discarded activations are regenerated at exactly the right moment. PyTorch handles all of that bookkeeping behind a single wrapper. You replace a direct call to a block with a checkpointed one, and the framework discards the block's internal activations on the forward pass and reruns the block during backward to recover them:

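import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper, CheckpointImpl)

# A small stand-in for one transformer block, just so the snippet runs.
block = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))
x = torch.randn(8, 512, requires_grad=True)

# Instead of: out = block(x)
out = checkpoint(block, x, use_reentrant=False)   # discard + recompute managed
out.sum().backward()                              # the block's forward reruns here

# For a whole stack, wrap each transformer block. With FSDP, the
# checkpoint wrapper composes directly with sharding:
block = checkpoint_wrapper(block, checkpoint_impl=CheckpointImpl.NO_REENTRANT)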
Code 16.8.2: The same store-a-few-recompute-the-rest trade from Output 16.8.1, now as one wrapper call. PyTorch manages the autograd graph surgery, the activation discard, and the just-in-time recomputation; the FSDP checkpoint_wrapper makes it compose with the sharding of Section 16.4 so a single device's shard fits.

6. Where It Sits in the Stack Beginner

Checkpointing slots in cleanly under everything else in this chapter. You first choose how to distribute the model (tensor, pipeline, sharded, sequence, or a mix), then turn on checkpointing per device to shrink the activation stack that the chosen distribution still leaves on each accelerator. Because it is orthogonal to the distribution axes, it composes with all of them, which is precisely the property that lets the next section stack several axes together. With the activation cost under control on each device, we can return to the question of how to combine data, tensor, pipeline, and expert parallelism into a single coherent training plan, which is the subject of Section 16.9 on 3D and 4D parallelism.
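The composition can be sketched as a per-device activation estimate. This is a deliberately rough model with a hypothetical function name, and its simplifications are stated in the code: tensor parallelism divides each layer's activation evenly by its degree (real layers keep some replicated activations), ZeRO and FSDP leave activations unchanged, the chosen distribution leaves some number of layers on the device (for pipelining, the layers of one stage, ignoring the extra microbatches a schedule keeps in flight), and checkpointing applies the $2\sqrt{d}$ rule to those local layers.

```python
import math

def per_device_activation(layers_on_device, per_layer_units, tp_degree=1,
                          checkpoint=False):
    """Rough activation stack on one device, in arbitrary units.

    Simplifying assumptions (illustrative only):
      * tensor parallelism divides every layer's activation evenly by tp_degree;
      * ZeRO/FSDP sharding of model state leaves activations unchanged, so it
        does not appear here at all;
      * checkpointing replaces the depth term d with the 2*sqrt(d) rule;
      * extra pipeline microbatches in flight are ignored.
    """
    per_layer = per_layer_units / tp_degree
    depth = 2 * math.sqrt(layers_on_device) if checkpoint else layers_on_device
    return depth * per_layer

d, a = 64, 8.0            # hypothetical: 64 local layers, 8 units per layer
print(f"{per_device_activation(d, a):.0f}")                                # 512
print(f"{per_device_activation(d, a, tp_degree=8):.0f}")                   # 64
print(f"{per_device_activation(d, a, tp_degree=8, checkpoint=True):.0f}")  # 16
```

Tensor parallelism shrinks the per-layer factor and checkpointing shrinks the depth factor, so in this model the two multiply rather than compete.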

Research Frontier: Selective and Compute-Aware Recomputation (2024 to 2026)

Plain checkpointing recomputes whole blocks, which can waste effort on cheap operations. Selective activation recomputation, introduced in the Megatron line of work (Korthikanti et al., 2023) and now standard in 2024 to 2026 training stacks, recomputes only the memory-heavy and compute-cheap parts (such as attention's intermediate tensors) while keeping the expensive matrix-multiply outputs, cutting the overhead well below the naive one-extra-forward bound. A parallel thread treats checkpoint placement as an optimization problem: tools in the lineage of Rockmate solve for which layers to checkpoint given a memory budget, rather than defaulting to the uniform $\sqrt{L}$ rule, and recent systems fold activation offload to CPU and NVMe into the same solver so that recomputation and offload are chosen per tensor. (FlashAttention applies the same recompute-instead-of-store idea inside the attention kernel itself, regenerating the attention matrix during backward rather than storing it, but it is a fused kernel, not a placement solver.) The frontier is no longer whether to checkpoint but how to spend a fixed memory budget across recompute, offload, and sharding so that the distributed plan of Section 16.9 trains as fast as the hardware allows.
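The per-tensor choice behind selective recomputation can be illustrated with a simple greedy rule: discard first the activations that free the most memory per unit of recompute work, and stop once what remains fits the budget. This is a hypothetical sketch, not the Megatron policy or the Rockmate solver. The operation names and their memory and compute costs are invented for illustration, and the sketch ignores dependencies between operations (recomputing attention scores, for instance, needs the projections they are computed from).

```python
def choose_recompute(ops, mem_budget):
    """Greedy sketch of selective recomputation (illustrative, not a real policy).

    ops maps an operation name to (activation_memory, forward_compute).
    Discard, and later recompute, the operations with the least compute per unit
    of memory freed until the activations still stored fit in mem_budget.
    """
    stored = sum(mem for mem, _ in ops.values())
    chosen, extra_compute = [], 0.0
    by_cheapness = sorted(ops.items(), key=lambda kv: kv[1][1] / kv[1][0])
    for name, (mem, cost) in by_cheapness:
        if stored <= mem_budget:
            break
        chosen.append(name)
        stored -= mem
        extra_compute += cost
    return chosen, stored, extra_compute

# Hypothetical per-op costs for one transformer block (arbitrary units).
block_ops = {
    "qkv_proj":     (3.0, 3.0),
    "attn_scores":  (8.0, 0.5),   # large but cheap to regenerate
    "attn_dropout": (4.0, 0.1),
    "out_proj":     (1.0, 1.0),
    "mlp_up":       (4.0, 4.0),
    "gelu":         (4.0, 0.2),
    "mlp_down":     (1.0, 4.0),
}
total_fwd = sum(cost for _, cost in block_ops.values())
chosen, stored, extra = choose_recompute(block_ops, mem_budget=10.0)
print(chosen)   # ['attn_dropout', 'gelu', 'attn_scores']
print(f"stored {stored:.0f} of 25 units, recompute {extra / total_fwd:.0%} of the forward")
```

In these made-up units, the greedy choice cuts stored activations from 25 to 9 while recomputing only the three cheap operations, a small fraction of the block's forward, whereas plain checkpointing would repeat the whole forward.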

Exercise 16.8.1: Read the Square-Root Rule Conceptual

Using the model behind Code 16.8.1, explain in words why placing a checkpoint every $s$ layers gives a peak of about $(L/s + s)$ activation tensors, and why that expression is minimized at $s = \sqrt{L}$. Then answer: if a colleague keeps a checkpoint every two layers on a 256-layer network, are they over-spending memory or over-spending compute relative to the optimum, and roughly by how much on each axis? State why the saving at $L = 4$ in Output 16.8.1 is zero.

Exercise 16.8.2: Add Activation Offload to the Model Coding

Extend Code 16.8.1 with a third strategy, offload_checkpoint(L), that stores the $\sqrt{L}$ boundary checkpoints in "host memory" (assume effectively unlimited) so on-accelerator peak memory is only the single segment being recomputed, about $\sqrt{L}$ tensors, but charges a transfer cost proportional to the number of offloaded tensors at a per-tensor bandwidth penalty you pick (for example, 5 compute-units per offloaded tensor moved each way). Tabulate on-device memory and total step cost for $L \in \{64, 256, 1024\}$ against the two existing strategies, and identify the depth at which offload's transfer cost exceeds the recompute cost of plain checkpointing.

Exercise 16.8.3: Decide Whether to Turn It On Analysis

A training run on one node has accelerators that are 55 percent utilized (memory-bound, batch shrunk to fit) and a step time of 800 ms. Enabling activation checkpointing would let you triple the microbatch, lifting per-step throughput, but adds a 25 percent recompute overhead to each step. Estimate the new step time and the new effective samples-per-second, and argue from these numbers whether checkpointing is worth it here. Then describe one regime, in terms of accelerator utilization and the memory headroom of Chapter 3, where the same change would be a net loss.