"They handed me a million tokens and asked me to attend to all of them. I attended to a quarter, mailed my keys to my neighbor, and we all agreed never to write the full matrix down."
A Device Holding Only Its Share of the Sequence
Long contexts blow up activation memory, not parameter memory, so the model can fit on your devices while a single forward pass still runs out of room. Sequence and context parallelism split the sequence dimension across devices, and Ring Attention lets each device see every position by passing key-value blocks around a ring instead of materializing the full attention matrix. The other axes in this chapter (tensor, pipeline, and sharded data parallelism) all shrink something that grows with the model. Sequence parallelism is the axis that shrinks something that grows with the input length. When a context reaches tens of thousands or millions of tokens, the attention activations alone exceed any single accelerator, and the only escape is to give each device a slice of the tokens and teach the devices to exchange just enough to compute attention correctly. This section shows why the pressure is real, why attention is the hard part, and how a ring of key-value exchanges resolves it exactly.
Every other form of parallelism in this chapter answers the question "the model is too big for one device, how do we split it?" Sequence parallelism answers a different question: "the model fits, but a single long input does not, how do we split the input?" The distinction matters because the two pressures scale with different quantities. Parameter memory is fixed once you choose an architecture; it does not change when a user pastes a longer document. Activation memory, the intermediate tensors a forward pass must keep so the backward pass can use them, grows with how many tokens you process at once. For most of deep learning's history sequences were short enough that this growth was a rounding error. Long-context language models changed that, and they changed it most sharply inside attention. Figure 16.7.1 previews the fix this section builds: split the sequence across devices and rotate key-value blocks around a ring so every query still attends to every position.
1. The Activation Memory That Grows With the Sequence Intermediate
Consider a single attention head over a sequence of length $S$ with head dimension $d$. The query, key, and value tensors each cost $S \cdot d$ numbers, which grows linearly in $S$ and is rarely the problem. The problem is the attention score matrix. To compute which positions attend to which, the head forms the product $Q K^\top$, a matrix of shape $S \times S$, applies a softmax over each row, and multiplies by $V$. A naive implementation stores that $S \times S$ matrix, and storing it costs
$$M_{\text{attn}} = b \cdot h \cdot S^2 \quad\text{elements per layer},$$
where $b$ is the batch size and $h$ the number of heads; multiplying by the bytes per element (2 in bf16) gives the memory, and either way the cost is quadratic in $S$. At $S = 2{,}000$ this is a nuisance; at $S = 128{,}000$ it is $64^2 \approx 4{,}000$ times larger, and at a million tokens it is hopeless on any single device. FlashAttention (covered as a per-node enabler in Chapter 22) removes the need to store the full matrix by streaming it in blocks with an online softmax, which removes the quadratic term from the stored activations altogether, but the queries, keys, values, and the layer's other activations still grow linearly in $S$, and linear growth across a million tokens is still more than one accelerator holds. The parameters, meanwhile, have not changed at all. This is the signature of an activation-bound regime: the model fits, the input does not.
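To put numbers on the quadratic term, the sketch below evaluates $M_{\text{attn}}$ for one layer alongside the linear cost of storing $Q$, $K$, and $V$. The batch size of 1, the 32 heads, the head dimension of 128, and the 2-byte (bf16) elements are illustrative assumptions, not the dimensions of any particular model.

```python
# Illustrative assumptions (not a specific model): batch 1, 32 heads,
# head dim 128, 2-byte bf16 elements.
BATCH, HEADS, HEAD_DIM, BYTES_PER_ELEM = 1, 32, 128, 2


def score_matrix_bytes(seq_len: int) -> int:
    """Bytes for the naive b * h * S^2 score matrix of one layer."""
    return BATCH * HEADS * seq_len**2 * BYTES_PER_ELEM


def qkv_bytes(seq_len: int) -> int:
    """Bytes for Q, K, V of one layer: three tensors of b * h * S * d elements."""
    return 3 * BATCH * HEADS * seq_len * HEAD_DIM * BYTES_PER_ELEM


for S in (2_000, 128_000, 1_000_000):
    print(f"S={S:>9,}  scores={score_matrix_bytes(S) / 1e9:12,.1f} GB  "
          f"QKV={qkv_bytes(S) / 1e9:8,.2f} GB")

# Growing S by 64x (2K -> 128K tokens) grows the score matrix by 64^2.
ratio = score_matrix_bytes(128_000) / score_matrix_bytes(2_000)
print(f"score growth from 2K to 128K tokens: {ratio:,.0f}x")
```

Under these assumptions the score matrix alone passes a terabyte per layer at 128K tokens while $Q$, $K$, and $V$ together stay near 3 GB; at a million tokens even those linear terms reach roughly 25 GB per layer, before counting feed-forward activations or the other layers.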
Tensor, pipeline, and sharded data parallelism all attack memory that scales with the model: parameters, gradients, optimizer state. None of them help when the model already fits and a single long input is what overflows. Activation memory scales with the number of tokens processed at once, and inside attention the naive score matrix scales as $S^2$. Sequence parallelism is the axis that splits along the token dimension, so each device holds the activations for only its slice of the sequence. Reach for it when the binding ceiling moves with context length rather than with model size.
2. Splitting the Sequence, and Why Attention Fights Back Intermediate
The idea is the same regrouping that made data parallelism exact in Section 1.1, applied to a new dimension. Cut the sequence of $S$ tokens into $P$ contiguous blocks and give block $p$ to device $p$ (we write $p$ for the device index because $d$ is already the head dimension). For any operation that acts on each token independently, this is trivially correct and embarrassingly parallel: a feed-forward layer, a layer normalization, or an element-wise activation touches token $i$ using only token $i$'s vector, so each device processes its own block and nothing needs to move. Each device now holds the activations for $S/P$ tokens, cutting the per-device activation memory by a factor of $P$ for those layers at zero communication cost.
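A quick numerical check of the claim that token-wise layers need no communication: the sketch below applies a layer norm, a small ReLU feed-forward layer, and a residual add once to the whole sequence and once to $P$ contiguous shards, then stitches the shards back together. The sizes and the ReLU feed-forward layer are illustrative choices, not any specific model's block.

```python
import numpy as np

rng = np.random.default_rng(0)
S, H, P = 16, 8, 4  # sequence length, hidden size, devices (illustrative)
x = rng.standard_normal((S, H))
W1 = rng.standard_normal((H, 4 * H))
W2 = rng.standard_normal((4 * H, H))


def layer_norm(z, eps=1e-5):
    """Normalize each token (row) using only that row's own statistics."""
    mu = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    return (z - mu) / np.sqrt(var + eps)


def token_wise_block(z):
    """LayerNorm -> ReLU feed-forward -> residual add; every op is per token."""
    h = layer_norm(z)
    return z + np.maximum(h @ W1, 0.0) @ W2


full = token_wise_block(x)                          # one device, whole sequence
shards = np.split(x, P, axis=0)                     # device p gets one contiguous block
per_device = [token_wise_block(s) for s in shards]  # no data exchanged
stitched = np.concatenate(per_device, axis=0)

print("rows per device    :", shards[0].shape[0])
print("max abs difference :", np.max(np.abs(stitched - full)))
```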
Attention is the operation that refuses to cooperate. By design, the output at position $i$ depends on the keys and values at every position $j$, because attention is an all-to-all interaction over the sequence. A device $p$ that holds only query block $Q_p$ can compute scores against its local keys, but to finish the softmax it needs the keys and values from every other block, which live on other devices. So sequence parallelism forces a communication pattern across the very dimension it just split, and the question becomes how to supply each device with all the keys and values without (a) gathering the entire key-value tensor onto every device, which would undo the memory saving, or (b) forming the full $S \times S$ score matrix anywhere, which is the thing we cannot afford.
Picture four students each assigned to summarize a quarter of a very long book, where the summary of any page may depend on any other page. Photocopying the entire book for all four defeats the point. Instead each student keeps the quarter they must summarize and, on a timer, passes the reference stack they are currently reading to the person on their right while receiving one from the left. After three passes (four reading rounds) everyone has read every page exactly once, and no one ever needed more than their own quarter plus the stack in hand. That timer-and-pass ritual is Ring Attention, and the students never needed a table big enough for the whole book.
3. Ring Attention: Passing Keys and Values Around a Ring Advanced
Ring Attention resolves the tension by arranging the $P$ devices in a logical ring and rotating the key-value blocks around it. Each device $p$ keeps its query block $Q_p$ fixed and, at each step, holds one key-value block. It computes the partial attention of its queries against the block currently in hand, then sends that block to the next device on the ring and receives a new block from the previous device. After $P$ compute steps (and $P-1$ transfers), each query block has met every key-value block exactly once, so every query has attended to every position. Because a device can receive the next block while it computes on the current one, much of the communication hides behind useful work. This ties directly to the ring and all-to-all collectives of Chapter 4: a ring all-reduce moves a fixed payload hop by hop around exactly this topology, and Ring Attention reuses the structure to move key-value blocks instead of gradient shards.
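The ring's bookkeeping is easy to check in isolation. The sketch below, plain Python with an illustrative $P$, tabulates which key-value block each device computes on at each step, assuming every device forwards its block to device $(p+1) \bmod P$. It confirms that $P$ compute steps need only $P-1$ transfers and that no block is ever duplicated. In a real implementation the receive for step $t+1$ runs while step $t$ computes, so each device briefly buffers a second block.

```python
P = 4  # devices on the ring (illustrative)


def ring_schedule(num_devices: int):
    """Return held[t][p]: the KV block device p computes on at step t.

    At step 0 every device holds its own block; between steps each device
    sends its block to (p + 1) % P and receives from (p - 1) % P.
    """
    held = [list(range(num_devices))]
    for _ in range(num_devices - 1):  # P - 1 transfers in total
        prev = held[-1]
        # Device p receives whatever its left neighbour held last step.
        held.append([prev[(p - 1) % num_devices] for p in range(num_devices)])
    return held


schedule = ring_schedule(P)
for t, row in enumerate(schedule):
    print(f"step {t}: " + "  ".join(f"dev{p}<-KV{b}" for p, b in enumerate(row)))

# Every device meets every KV block exactly once over the P steps ...
for p in range(P):
    assert sorted(schedule[t][p] for t in range(P)) == list(range(P))
# ... and at every step each block lives on exactly one device (no copies).
for row in schedule:
    assert sorted(row) == list(range(P))
```

The held block at step $t$ is $(p - t) \bmod P$, which is the `src` index used in the loop of Code 16.7.1.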
The piece that makes this exact rather than approximate is the online softmax, the same running-accumulation trick FlashAttention uses on a single device. A device cannot normalize its softmax until it has seen all the keys, but it does not need to wait. It maintains a running row-maximum $m$, a running denominator $\ell$, and a running weighted sum of values, and as each new block arrives it rescales the accumulation by $e^{m_{\text{old}} - m_{\text{new}}}$ to keep the numbers stable and the result identical to a single global softmax. Writing the softmax this way,
$$\text{attn}(Q_p, K, V)_i = \frac{\sum_{j} e^{s_{ij} - m_i}\, v_j}{\sum_{j} e^{s_{ij} - m_i}}, \qquad s_{ij} = \frac{q_i \cdot k_j}{\sqrt{d}},\quad m_i = \max_j s_{ij},$$
for each query $i$ in block $Q_p$, shows that the numerator and denominator are both plain sums over keys $j$, and a sum can be accumulated one block at a time regardless of the order the blocks arrive. The shift $m_i$ only keeps the exponentials in range; any shift cancels between numerator and denominator, which is why a running maximum plus rescaling gives the same answer as the global one. That is exactly the property that let data parallelism regroup a gradient sum across workers; here we regroup an attention sum across key blocks travelling on a ring. The code below makes the claim concrete: it splits a sequence across $P$ "devices," runs the ring exchange with an online softmax, and checks the result against full single-device attention.
import numpy as np
rng = np.random.default_rng(0)
S, d, P = 12, 8, 4  # sequence length, head dim, number of devices (ring)
assert S % P == 0
block = S // P  # tokens (queries) held per device
# One attention head over the FULL sequence (the single-device reference).
Q = rng.standard_normal((S, d))
Kmat = rng.standard_normal((S, d))
V = rng.standard_normal((S, d))
scale = 1.0 / np.sqrt(d)
def softmax_rows(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)
# Reference: materialize the full S x S attention matrix on one device.
ref = softmax_rows((Q @ Kmat.T) * scale) @ V
# --- Sequence/context parallelism with a ring KV exchange ---------------
# Each device owns query block Q_d and, initially, its own K_d, V_d block.
# We never form the full S x S matrix on any device. Instead we use the
# online-softmax (FlashAttention-style) running accumulation so a device can
# fold in one KV block at a time as the blocks travel around the ring.
def blocks(M):
    return [M[i*block:(i+1)*block] for i in range(P)]
Qb, Kb, Vb = blocks(Q), blocks(Kmat), blocks(V)
out = [None] * P
for dev in range(P):
    q = Qb[dev]  # this device's queries, never moves
    m = np.full((block, 1), -np.inf)  # running row-max
    l = np.zeros((block, 1))  # running denominator
    acc = np.zeros((block, d))  # running weighted sum of V
    # Ring: at step t this device holds the KV block that started t hops upstream.
    for t in range(P):
        src = (dev - t) % P  # which KV block is here after t passes
        k, v = Kb[src], Vb[src]
        s = (q @ k.T) * scale  # block scores, shape block x block
        m_new = np.maximum(m, s.max(axis=1, keepdims=True))
        alpha = np.exp(m - m_new)  # rescale prior accumulation
        p = np.exp(s - m_new)  # local softmax numerators
        l = alpha * l + p.sum(axis=1, keepdims=True)
        acc = alpha * acc + p @ v
        m = m_new
        # In a real system: send (k, v) to the next device, receive the next block.
    out[dev] = acc / l
ring = np.vstack(out)
print("sequence length S :", S)
print("ring devices P :", P)
print("queries per device :", block)
print("full S x S matrix on device: never (online softmax, block by block)")
print("max abs difference :", f"{np.max(np.abs(ring - ref)):.2e}")
print("relative error :", f"{np.linalg.norm(ring - ref) / np.linalg.norm(ref):.2e}")
The variable ref forms the full $S \times S$ matrix on one device; the ring path keeps each device's queries pinned and folds in one key-value block per step with an online softmax, so no device ever stores the full matrix, yet the two outputs must agree up to floating-point rounding.
sequence length S : 12
ring devices P : 4
queries per device : 3
full S x S matrix on device: never (online softmax, block by block)
max abs difference : 3.33e-16
relative error : 2.27e-16
The difference is zero up to floating-point rounding, exactly as in the gradient identity of Section 1.1. The point is not that the numbers are small; it is that splitting the sequence and exchanging key-value blocks is an exact reorganization of the same computation, not an approximation you tolerate for the sake of memory. The full attention matrix, the object whose $S^2$ size was the whole problem, is never built on any device. Each device's peak activation footprint for attention scales with $S/P$ queries against one block at a time, so $P$ devices in a ring stretch the affordable context length by roughly a factor of $P$.
Data parallelism (Section 1.1) was exact because a gradient is an average and an average regroups freely across workers. Ring Attention is exact for the same reason: an attention output is a normalized sum over keys, and a sum regroups freely across the key blocks that travel the ring. The recurring move of this book is to find the associative reduction hiding inside an expensive operation and stream it across devices. Whenever a later method claims to split an operation "for free," ask which sum it is regrouping and which collective carries the partial results; here the collective is the ring exchange of Chapter 4, now carrying keys and values instead of gradients.
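The regrouping claim can be checked directly, in a form that does not depend on the ring's particular order. The sketch below is an illustration, not a library kernel: it summarizes one device's queries against each key-value block as a partial state (row max $m$, denominator $\ell$, unnormalized numerator), merges two states with the $e^{m_{\text{old}} - m_{\text{new}}}$ rescaling, and folds the states in several arrival orders plus one tree-shaped grouping. Sizes and seeds are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(1)
n_q, S, d, P = 3, 12, 8, 4  # local queries, keys, head dim, KV blocks (illustrative)
q = rng.standard_normal((n_q, d))
K = rng.standard_normal((S, d))
V = rng.standard_normal((S, d))
scale = 1.0 / np.sqrt(d)


def partial_state(k, v):
    """Statistics of q against one KV block: (row max, denominator, numerator)."""
    s = (q @ k.T) * scale
    m = s.max(axis=1, keepdims=True)
    p = np.exp(s - m)
    return m, p.sum(axis=1, keepdims=True), p @ v


def merge(a, b):
    """Combine two partial states by rescaling both to the larger row max."""
    (m_a, l_a, acc_a), (m_b, l_b, acc_b) = a, b
    m = np.maximum(m_a, m_b)
    c_a, c_b = np.exp(m_a - m), np.exp(m_b - m)
    return m, c_a * l_a + c_b * l_b, c_a * acc_a + c_b * acc_b


# Reference: one softmax over all S keys.
s_full = (q @ K.T) * scale
w = np.exp(s_full - s_full.max(axis=1, keepdims=True))
ref = (w / w.sum(axis=1, keepdims=True)) @ V

states = [partial_state(k, v) for k, v in zip(np.split(K, P), np.split(V, P))]
results = []
for order in ([0, 1, 2, 3], [3, 1, 0, 2], [2, 0, 3, 1]):  # arrival orders
    state = states[order[0]]
    for i in order[1:]:
        state = merge(state, states[i])
    results.append(state[2] / state[1])
# Tree-shaped regrouping, as a reduction across devices might combine them.
tree = merge(merge(states[0], states[1]), merge(states[2], states[3]))
results.append(tree[2] / tree[1])

for r in results:
    print(f"max abs diff vs reference: {np.max(np.abs(r - ref)):.1e}")
```

Because merge is associative and commutative up to rounding, the same partial states could be combined along a ring, a tree, or any other collective pattern; the ring is chosen for its bandwidth use and compute overlap, not for correctness.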
4. Composing With Tensor Parallelism Advanced
Sequence parallelism is most useful in combination, not in isolation. Tensor parallelism, treated earlier in this chapter, splits the hidden dimension: each device holds a slice of every weight matrix and the devices all-reduce within a layer. It shrinks parameter memory and the activations inside the attention and feed-forward matrix multiplies, but it leaves every device holding the full sequence for the operations between those blocks, namely the layer norms, dropouts, and residual adds, whose activations are replicated across the tensor-parallel ranks and still scale with $S$. Megatron-style sequence parallelism closes that gap by splitting those non-attention activations along the sequence axis too, inserting reduce-scatter and all-gather operations (the sharded-collective cousins of all-reduce from Chapter 4) at the boundaries where the two partitionings meet. The hidden dimension is split where tensor parallelism wants it; the sequence dimension is split where sequence parallelism wants it; and the collectives convert between the two layouts.
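A single-process NumPy simulation of one feed-forward block makes the two layout conversions visible. It assumes the usual Megatron-style arrangement: the sequence split reuses the tensor-parallel ranks, the layer norm runs on sequence shards, an all-gather enters a column-then-row-parallel MLP, and a reduce-scatter leaves it. ReLU in place of GELU and the tiny sizes are illustrative choices, and lists indexed by r stand in for ranks; no real collectives run.

```python
import numpy as np

rng = np.random.default_rng(0)
S, H, T = 8, 6, 2  # sequence length, hidden size, tensor-parallel ranks (illustrative)
x = rng.standard_normal((S, H))
W1 = rng.standard_normal((H, 4 * H))
W2 = rng.standard_normal((4 * H, H))


def layer_norm(z, eps=1e-5):
    mu = z.mean(axis=-1, keepdims=True)
    return (z - mu) / np.sqrt(z.var(axis=-1, keepdims=True) + eps)


def relu(z):
    return np.maximum(z, 0.0)


# Reference: the whole block on one device.
ref = x + relu(layer_norm(x) @ W1) @ W2

# Sequence-parallel region: rank r holds only its token shard.
x_shards = np.split(x, T, axis=0)
ln_shards = [layer_norm(xs) for xs in x_shards]  # local, no communication

# Boundary 1 (all-gather along the sequence): every rank sees all tokens.
ln_full = np.concatenate(ln_shards, axis=0)

# Tensor-parallel region: column-split W1, row-split W2 (same chunking).
W1_cols = np.split(W1, T, axis=1)
W2_rows = np.split(W2, T, axis=0)
partials = [relu(ln_full @ W1_cols[r]) @ W2_rows[r] for r in range(T)]  # partial sums, (S, H)

# Boundary 2 (reduce-scatter): sum the partials, rank r keeps only its tokens.
y_shards = [sum(np.split(part, T, axis=0)[r] for part in partials) for r in range(T)]

# Back in the sequence-parallel region: the residual add is local.
out = np.concatenate([x_shards[r] + y_shards[r] for r in range(T)], axis=0)
print("max abs difference:", np.max(np.abs(out - ref)))
```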
The two axes are complementary because they cut orthogonal dimensions: tensor parallelism cuts the feature axis, sequence parallelism cuts the token axis, and context parallelism (the ring attention of Section 3) cuts the token axis inside attention itself, where every token must interact with every other. A production long-context training job typically stacks all of them with pipeline and sharded data parallelism, a layered scheme often called multi-dimensional or "4D" parallelism, where sequence and context parallelism are the dimensions that earn their keep precisely when the context is long. We assemble the full stack and reason about which axis to grow first under a fixed device budget in Chapter 19.
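Stacking axes means the parallel sizes multiply: the product of the tensor, context, and pipeline sizes must divide the device count, and what is left over becomes the data-parallel size. The sketch below maps each global rank to coordinates and lists the communication groups. Its rank ordering (tensor fastest, then context, then data, then pipeline) is a hypothetical choice that keeps the chattiest tensor-parallel partners on neighbouring ranks; real frameworks fix their own ordering, and the 64-device world is illustrative.

```python
from itertools import product


def rank_coords(world_size, tp, cp, pp):
    """Map every global rank to its (pp, dp, cp, tp) coordinates.

    Hypothetical ordering: tp varies fastest, then cp, then dp, then pp.
    """
    assert world_size % (tp * cp * pp) == 0, "parallel sizes must divide world"
    dp = world_size // (tp * cp * pp)  # whatever is left over is data parallel
    coords = {}
    for rank, (p, d, c, t) in enumerate(
            product(range(pp), range(dp), range(cp), range(tp))):
        coords[rank] = {"pp": p, "dp": d, "cp": c, "tp": t}
    return dp, coords


def groups(coords, axis):
    """Ranks that differ only along `axis` form one communication group."""
    buckets = {}
    for rank, c in coords.items():
        key = tuple(v for k, v in sorted(c.items()) if k != axis)
        buckets.setdefault(key, []).append(rank)
    return sorted(buckets.values())


dp, coords = rank_coords(world_size=64, tp=8, cp=4, pp=1)
cp_rings = groups(coords, "cp")
print("data-parallel size:", dp)
print("context-parallel rings:", len(cp_rings), "e.g.", cp_rings[0])
print("first tensor-parallel group:", groups(coords, "tp")[0])
```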
Code 16.7.1 spelled out the ring exchange and the online-softmax rescaling by hand, roughly thirty lines and the part most likely to harbor a numerical bug. Production frameworks expose the entire scheme as a parallelism dimension you size like any other. In a Megatron-style launcher the ring is a single argument, and a separate switch adds the Megatron-style sequence-parallel reduce-scatter/all-gather boundaries around the tensor-parallel layers:
# Megatron-LM style launch: split the sequence across a ring of 4 devices,
# composed with tensor parallelism over 8 devices (32 processes in total),
# for long-context training.
torchrun --nproc_per_node=32 pretrain_gpt.py \
--tensor-model-parallel-size 8 \
--sequence-parallel \
--context-parallel-size 4 \
--seq-length 131072 \
--use-flash-attn
--context-parallel-size 4 replaces the hand-written ring loop and online-softmax bookkeeping; the framework places the ring, overlaps the key-value sends with attention compute, and calls the FlashAttention kernel that Chapter 22 unpacks. Megatron-LM, DeepSpeed-Ulysses, and PyTorch's context-parallel APIs all expose a comparable parallelism dimension, although DeepSpeed-Ulysses implements it with a head all-to-all rather than a ring.
5. Why This Matters for Long-Context Models Intermediate
Sequence and context parallelism are not a curiosity; they are the reason today's frontier models can advertise context windows of hundreds of thousands or millions of tokens. A model that reasons over an entire codebase, a book-length document, an hour of transcribed audio, or a long agent trajectory must hold that whole input in attention at once, and the activation pressure of Section 1 makes that impossible on one device once the context grows long enough, often past a few tens of thousands of tokens in training. Context parallelism is what converts "the model architecturally supports a long context" into "a real cluster can actually train and run it." It is one of the load-bearing techniques behind training the foundation models of Chapter 19, and it reappears at serving time in the distributed LLM serving of Chapter 24, where a long prompt's key-value cache must be spread across several devices to fit.
Who: An ML systems engineer at a startup fine-tuning a code model to read whole repositories in one prompt.
Situation: The 13-billion-parameter model fit comfortably on a node of eight 80 GB GPUs with tensor parallelism, and training at a 4K context was stable.
Problem: Pushing the context to 256K tokens crashed with an out-of-memory error in the attention layers, even though the parameters, gradients, and optimizer state all still fit; only the activations had exploded.
Dilemma: Buy more or bigger GPUs to grow parameter-oriented sharding further, which the profiler showed would not help because parameters were not the ceiling, or split the sequence itself, which meant adding a parallelism axis the team had never used.
Decision: They added context parallelism, because the profiler placed every overflowing byte in the per-token attention and feed-forward activations, the exact memory that scales with $S$ and that splitting the sequence reduces by the ring degree.
How: They set --context-parallel-size 4 on top of their existing tensor-parallel-8 layout and enabled FlashAttention, spreading the job across four such nodes so that each node holds one of the ring's four sequence shards, with no model-code changes required.
Result: The 256K-token forward and backward passes fit with headroom, throughput dropped only modestly because the ring key-value sends overlapped the attention compute, and, because the ring is an exact reorganization of attention, the outputs match what a single device with unlimited memory would compute, up to floating-point rounding, exactly as Output 16.7.1 predicts.
Lesson: When the profiler says activations, not parameters, are overflowing, split the sequence, not the model. Growing a parameter-oriented axis would have spent hardware on the wrong ceiling.
The technique in Section 3 traces to Ring Attention with Blockwise Transformers (Liu, Zaharia, and Abbeel, 2024), which showed that a ring of devices can train and run contexts whose length scales with the device count, demonstrated at millions of tokens. DeepSpeed-Ulysses (Jacobs et al., 2024) takes an all-to-all route instead of a ring, scattering attention heads across devices so each device computes full attention for a subset of heads over the whole sequence, and the two patterns are now often combined into hybrid sequence-parallel schemes, for example a head all-to-all within a group of devices and a ring exchange across groups. A parallel line attacks the $S^2$ term itself with sparse, sliding-window, and linear-attention variants so that context parallelism and sub-quadratic attention compound, and 2024 to 2026 frontier releases advertising hundred-thousand to million-token windows lean on exactly this combination. The open questions are how to balance the ring's communication against the all-to-all's, how to keep load even when causal masking makes later query blocks cheaper, and how far sub-quadratic attention can push before accuracy on long-range dependencies degrades. We return to the long-context serving side of this story in Chapter 24.
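The head all-to-all that distinguishes DeepSpeed-Ulysses from the ring can be simulated in a few lines. In the sketch below, a single-process illustration rather than DeepSpeed's implementation, each of $P$ simulated devices starts with a sequence shard of $Q$, $K$, $V$ holding all heads; one all-to-all re-partitions to the full sequence for a subset of heads, attention runs locally per head, and a second all-to-all restores the sequence layout. It assumes the head count divides evenly by $P$, so this route cannot use more devices than there are heads.

```python
import numpy as np

rng = np.random.default_rng(0)
S, n_heads, d, P = 8, 4, 4, 2  # seq len, heads, head dim, devices (illustrative)
Q, K, V = (rng.standard_normal((S, n_heads, d)) for _ in range(3))


def attention(q, k, v):
    """Full softmax attention per head; inputs shaped (seq, heads, dim)."""
    s = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(q.shape[-1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return np.einsum("hqk,khd->qhd", w, v)


def seq_to_head(shards):
    """All-to-all: P sequence shards (S/P, H, d) -> P head shards (S, H/P, d)."""
    return [np.concatenate([np.split(x, P, axis=1)[r] for x in shards], axis=0)
            for r in range(P)]


def head_to_seq(shards):
    """Inverse all-to-all: P head shards (S, H/P, d) -> P sequence shards."""
    return [np.concatenate([np.split(x, P, axis=0)[r] for x in shards], axis=1)
            for r in range(P)]


ref = attention(Q, K, V)

# Start sequence-sharded, like the rest of a sequence-parallel model.
q_s, k_s, v_s = (np.split(t, P, axis=0) for t in (Q, K, V))
# All-to-all, then each device runs full-sequence attention on its heads.
out_h = [attention(q, k, v)
         for q, k, v in zip(seq_to_head(q_s), seq_to_head(k_s), seq_to_head(v_s))]
out = np.concatenate(head_to_seq(out_h), axis=0)
print("max abs difference:", np.max(np.abs(out - ref)))
```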
We now have the fourth distinct way this chapter splits a transformer: tensor parallelism cuts the hidden dimension, pipeline parallelism cuts the layer dimension, sharded data parallelism cuts the parameter and optimizer state, and sequence/context parallelism cuts the token dimension to tame attention's $S^2$ activation cost. All four reduce some flavor of memory, and all four pay for it in communication that the collectives of Chapter 4 carry. There is one more lever that is not a parallelism axis at all but a per-device technique that multiplies what every one of these axes can fit: trading stored activations for recomputed ones. That is the subject of the next section.
A team reports that their 7-billion-parameter model trains fine at an 8K context on a single 80 GB GPU but runs out of memory at 64K context on the same GPU. (a) Explain, in terms of the $M_{\text{attn}} = b \cdot h \cdot S^2$ scaling from Section 1, why the parameter count being unchanged is consistent with the crash. (b) State why adding more sharded data-parallel replicas (the FSDP axis from earlier in this chapter) would not fix the crash, and why sequence parallelism would. (c) The team asks whether FlashAttention alone, without any cross-device split, removes the need for sequence parallelism. Give the regime in which it does and the regime in which it does not.
Modify Code 16.7.1 to compute causal attention, where query position $i$ may attend only to key positions $j \le i$. First fix the single-device reference by masking the upper triangle of $Q K^\top$ before the softmax. Then change the ring loop so each device skips entire key-value blocks that lie strictly in the future of all its queries, and applies a triangular mask only on the one diagonal block where past and future meet. Verify the ring output still matches the masked reference to floating-point precision, and report how many block-pairs of work the causal version skips compared to the full version. Explain why this skipping makes load uneven across the ring and which devices end up with the least work.
Suppose each of $P$ devices holds a key-value block of $B = (S/P) \cdot d$ numbers per tensor (4 bytes each, two tensors for $K$ and $V$) and the ring links move data at bandwidth $\beta$ bytes per second. (a) Estimate the total bytes a single device sends over a full attention pass under the ring scheme, and compare it to the bytes it would send if instead every device all-gathered the entire key-value tensor onto itself. (b) For which ratio of compute time per block to per-hop transfer time does the ring's overlap of communication with computation hide the transfer entirely? (c) Argue from these two numbers why the ring is preferred over a naive all-gather for long sequences, and connect your reasoning to the alpha-beta communication-cost model of Chapter 3.