"They told me to score every token against every other token, so I built a giant table in the slow memory. By the time I finished writing it down, the accelerator had gone to sleep waiting for me."
An Attention Matrix That Never Fit in Cache
Attention is the single most expensive operation in a transformer, and the naive way to compute it wastes the accelerator by shuttling a quadratic intermediate through slow memory. FlashAttention rewrites the same exact computation as a fused, tiled kernel that keeps its working set in fast on-chip memory, cutting the memory footprint from quadratic to linear and turning a memory-bound operation into a compute-bound one. This is a scale-up technique: it lives inside one accelerator and changes nothing about the mathematics of attention. But it is a prerequisite for everything in this part, because the per-node throughput it unlocks is exactly the number that distribution then multiplies across a serving fleet. The same tiling-and-online-softmax idea also turns out to be the mechanism that lets attention itself be split across machines, which is why this kernel-level section sits at the foundation of distributed serving.
A transformer spends most of its time in two places: the large matrix multiplications of its feed-forward layers, and the attention operation that lets every token look at every other token. The feed-forward work is dense linear algebra that modern accelerators handle near their peak rate. Attention is different. Written the obvious way, it builds an $N \times N$ table of pairwise scores for a sequence of length $N$, writes that table to the accelerator's main memory, reads it back to normalize it, and reads it a third time to weight the values. The arithmetic is modest; the traffic to and from slow memory is enormous. The result is an operation that leaves the compute units idle while they wait on memory, and whose footprint grows with the square of the sequence length, so a long context can exhaust device memory before it exhausts compute. FlashAttention removes both problems at once without changing a single output value.
1. Why Naive Attention Is Memory-Bound (Intermediate)
Scaled dot-product attention takes three matrices: queries $Q$, keys $K$, and values $V$, each with one row per token. For a sequence of length $N$ and head dimension $d$, the definition is
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}}\right) V, \qquad S = \frac{QK^{\top}}{\sqrt{d}} \in \mathbb{R}^{N \times N}.$$

The intermediate score matrix $S$ has $N^2$ entries. For a context of $N = 8192$ tokens, that is roughly 67 million numbers per attention head, per layer, and a model has dozens of heads in dozens of layers. The naive kernel materializes $S$ in the accelerator's high-bandwidth memory (HBM), then reads it back to apply the softmax row by row, then reads it again to multiply by $V$. Counting bytes moved, the operation reads and writes on the order of $N^2$ words while performing on the order of $N^2 d$ floating-point operations. With $d$ in the low hundreds, the ratio of compute to memory traffic is small, and that ratio is precisely what the roofline model of Section 3.7 uses to predict whether an operation is limited by arithmetic or by bandwidth. Naive attention sits on the memory-bound side of the roofline: the accelerator's arithmetic units are starved, waiting on HBM.
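The roofline argument can be made concrete with a back-of-the-envelope estimate. The sketch below is only illustrative: the traffic count (the score matrix written and read once before the softmax and once after it) and the hardware figures used for the ridge point are assumptions, not measurements, and the function name is hypothetical.

```python
def naive_attention_intensity(n_tokens, head_dim, bytes_per_word=2):
    """Rough FLOP-per-byte estimate for one head of naive attention."""
    # Two matrix multiplies (Q K^T and P V), each about 2 * N^2 * d FLOPs.
    flops = 4 * n_tokens**2 * head_dim
    # Assumed HBM traffic in words: write S, read S, write P, read P (4 N^2),
    # plus reading Q, K, V and writing O (4 N d).
    words = 4 * n_tokens**2 + 4 * n_tokens * head_dim
    return flops / (words * bytes_per_word)

# Illustrative ridge point: about 312 TFLOP/s of FP16 compute over about
# 2 TB/s of HBM bandwidth (assumed round figures for a data-center GPU).
ridge = 312e12 / 2.0e12

intensity = naive_attention_intensity(8192, 64)
print(f"naive attention: {intensity:.1f} FLOP/byte")
print(f"assumed ridge  : {ridge:.0f} FLOP/byte")
print("memory-bound" if intensity < ridge else "compute-bound")
```

Under these assumptions the intensity settles near $d/2$ once $N \gg d$, so a longer sequence does not help: only cutting the $N^2$ traffic term moves the operation toward the ridge.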
The second cost is capacity, not just bandwidth. Because $S$ must exist in memory all at once, peak memory grows as $N^2$. Doubling the context quadruples the attention footprint, and at long context lengths this term, not the model weights, is what overflows the device. Both symptoms, the bandwidth waste and the quadratic capacity, trace to the same root cause: the full $N \times N$ matrix is written to slow memory only to be read straight back. If the matrix never had to leave the chip, neither cost would exist.
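A few lines of arithmetic show how quickly the capacity term grows. This is a minimal sketch assuming 2-byte (FP16) score entries and counting a single head in a single layer; the helper name is hypothetical.

```python
def score_matrix_bytes(n_tokens, bytes_per_entry=2):
    """Bytes needed to hold one N x N score matrix (one head, one layer)."""
    return n_tokens * n_tokens * bytes_per_entry

for n in (4096, 8192, 16384, 32768):
    mib = score_matrix_bytes(n) / 2**20
    print(f"N = {n:>6}: {mib:>6.0f} MiB per head")
```

Each doubling of $N$ multiplies the figure by four, and the total must still be multiplied by however many heads are resident at once.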
Naive attention is slow not because the multiplications are expensive but because it moves an $N \times N$ intermediate in and out of slow memory three times. The fix is not a faster matrix multiply or a cleverer approximation; it is to never write the intermediate to slow memory at all. Whenever an operation is memory-bound on the roofline, the highest-leverage move is to cut bytes transferred, and the cleanest way to cut bytes is to fuse the stages that produce and consume an intermediate so that intermediate lives only in fast on-chip memory.
2. The FlashAttention Idea: Tile and Stream (Intermediate)
FlashAttention, introduced by Dao and colleagues in 2022, is an IO-aware kernel. It treats the gap between the accelerator's tiny, fast on-chip memory (SRAM) and its large, slow off-chip memory (HBM) as the thing to optimize, and it organizes the computation so that the expensive intermediate never touches HBM. The mechanism is tiling. Split the queries into blocks of rows and the keys and values into blocks of rows. Load one query block into SRAM, then stream the key and value blocks past it one at a time, accumulating the attention output for that query block as you go. Each tile of the score matrix is computed, consumed, and discarded inside SRAM; only the final output, which is the same size as $Q$, is ever written back to HBM. The full $N \times N$ matrix is never formed anywhere.
The obstacle is the softmax. The standard softmax over a row needs the maximum and the sum of exponentials over the whole row, but a tiled kernel sees only one block of that row at a time. The classic answer is the online softmax: maintain a running maximum and a running normalizer, and rescale the partial result whenever a new block reveals a larger maximum. Figure 22.6.1 contrasts the two strategies, the naive full matrix resident in HBM against the FlashAttention tiles that stay in SRAM with the online-softmax state carried between them.
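To see why the softmax is the obstacle, it helps to try the obvious shortcut and watch it fail. The toy below (one score row split into two blocks, with names chosen only for illustration) normalizes each block on its own and compares the result with the softmax over the whole row.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)
row = rng.standard_normal(8)      # one row of attention scores
blocks = np.split(row, 2)         # the row as a tiled kernel sees it

full = softmax(row)
blockwise = np.concatenate([softmax(b) for b in blocks])

print("full row sums to     :", full.sum())
print("per-block result sums:", blockwise.sum())
print("matches full softmax :", np.allclose(full, blockwise))
```

Each block normalizes to one by itself, so the concatenated weights sum to the number of blocks rather than to one. The blocks need a shared maximum and a shared denominator, which is exactly the state the online softmax carries.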
3. Online Softmax, Made Exact (Advanced)
The online softmax is what makes the tiling exact rather than approximate. Process the score row in blocks. After seeing blocks up to the current one, carry three quantities: the running maximum $m$ of the scores seen so far, the running denominator $l = \sum \exp(s - m)$, and the running output accumulator $o = \sum \exp(s - m)\, v$. When a new block of scores $s^{(j)}$ with values $v^{(j)}$ arrives, compute its local maximum and update the running maximum, then correct the carried state by the factor that accounts for the shift in maximum:
$$m^{\text{new}} = \max\!\big(m,\ \max_t s^{(j)}_t\big), \qquad \alpha = e^{\,m - m^{\text{new}}},$$

$$l \leftarrow \alpha\, l + \sum_t e^{\,s^{(j)}_t - m^{\text{new}}}, \qquad o \leftarrow \alpha\, o + \sum_t e^{\,s^{(j)}_t - m^{\text{new}}}\, v^{(j)}_t.$$

The correction factor $\alpha$ rescales the previously accumulated denominator and output so they are expressed relative to the new, larger maximum; the new block is added in the same scale. After the final block, the attention output for the row is $o / l$. Because subtracting the running maximum keeps every exponent at most zero, the computation never overflows, and because the rescaling is algebraically exact, the result equals the softmax computed over the whole row at once. This is the crucial property: FlashAttention is exact, not an approximation. It returns the same answer as the naive kernel up to floating-point rounding, which is why it can be a drop-in replacement with no accuracy cost, unlike the approximate attention variants of Section 5.
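The recurrence is easiest to check on a single row. The sketch below applies the update to one row of deliberately huge scores, for which a direct `exp` would overflow, and compares the result with a stable softmax computed over the whole row. The function name, block size, and test data are illustrative choices, not part of any library.

```python
import numpy as np

def online_softmax_row(scores, values, block=4):
    """Softmax-weighted average of `values` for one score row, block by block."""
    m, l, o = -np.inf, 0.0, np.zeros(values.shape[1])
    for j in range(0, len(scores), block):
        s, v = scores[j:j + block], values[j:j + block]
        m_new = max(m, s.max())
        alpha = np.exp(m - m_new)     # exp(-inf) = 0 on the first block
        p = np.exp(s - m_new)         # every exponent is <= 0
        l = alpha * l + p.sum()
        o = alpha * o + p @ v
        m = m_new
    return o / l

rng = np.random.default_rng(1)
scores = 1000.0 + 50.0 * rng.standard_normal(16)   # exp(1000) overflows float64
values = rng.standard_normal((16, 3))

# Reference: stable softmax over the whole row at once.
w = np.exp(scores - scores.max())
reference = (w / w.sum()) @ values

result = online_softmax_row(scores, values)
print("all finite :", bool(np.all(np.isfinite(result))))
print("max abs err:", np.max(np.abs(result - reference)))
```

The blockwise result stays finite and agrees with the whole-row reference to rounding error, even though no block ever sees the global maximum in advance.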
The code below implements both kernels in pure NumPy and checks that they agree. The naive version forms the full $N \times N$ matrix; the tiled version carries the online-softmax state and never allocates anything larger than a single tile. Each reports the size of its largest live intermediate so the memory gap is concrete.
import numpy as np

rng = np.random.default_rng(0)
N, d = 4096, 64                      # sequence length, head dimension
scale = 1.0 / np.sqrt(d)
Q = rng.standard_normal((N, d)).astype(np.float64)
K = rng.standard_normal((N, d)).astype(np.float64)
V = rng.standard_normal((N, d)).astype(np.float64)

# Naive attention: materialize the full N x N score matrix in "HBM".
def naive_attention(Q, K, V):
    S = (Q @ K.T) * scale            # N x N scores, the quadratic intermediate
    S = S - S.max(axis=1, keepdims=True)
    P = np.exp(S)
    P = P / P.sum(axis=1, keepdims=True)
    O = P @ V                        # N x d output
    peak = S.size                    # floats held at once for the big intermediate
    return O, peak

# Tiled / online-softmax attention: never form the full N x N matrix.
# Each query tile streams over key/value tiles, carrying a running max m,
# a running denominator l, and a running output accumulator o.
def flash_attention(Q, K, V, Br=128, Bc=128):
    Nq, Nk = Q.shape[0], K.shape[0]
    d = Q.shape[1]                   # head dimension, taken from the input
    O = np.zeros_like(Q)
    peak = 0
    for i in range(0, Nq, Br):
        Qi = Q[i:i+Br]                              # Br x d query tile in "SRAM"
        m = np.full((Qi.shape[0], 1), -np.inf)      # running row max
        l = np.zeros((Qi.shape[0], 1))              # running softmax denominator
        o = np.zeros((Qi.shape[0], d))              # running weighted-value sum
        for j in range(0, Nk, Bc):
            Kj = K[j:j+Bc]                          # Bc x d key tile
            Vj = V[j:j+Bc]                          # Bc x d value tile
            Sij = (Qi @ Kj.T) * scale               # Br x Bc tile only
            m_new = np.maximum(m, Sij.max(axis=1, keepdims=True))
            p = np.exp(Sij - m_new)                 # rescaled to the new max
            alpha = np.exp(m - m_new)               # correction for the old blocks
            l = alpha * l + p.sum(axis=1, keepdims=True)
            o = alpha * o + p @ Vj                  # rescale, then add this block
            m = m_new
            # largest intermediate alive: one tile plus the carried state
            peak = max(peak, Sij.size + o.size + l.size + m.size)
        O[i:i+Br] = o / l                           # finalize the tile
    return O, peak
O_naive, peak_naive = naive_attention(Q, K, V)
O_flash, peak_flash = flash_attention(Q, K, V)
print("sequence length N :", N)
print("head dimension d :", d)
print("max abs difference :", f"{np.max(np.abs(O_flash - O_naive)):.2e}")
print("relative error :", f"{np.linalg.norm(O_flash - O_naive) / np.linalg.norm(O_naive):.2e}")
print("naive peak intermediate :", f"{peak_naive:,} floats (the full N x N matrix)")
print("flash peak intermediate :", f"{peak_flash:,} floats (one tile + carried state)")
print("memory reduction factor :", f"{peak_naive / peak_flash:.1f}x")
Code 22.6.1. The tiled kernel's largest live intermediate is one Br x Bc tile plus the small carried state, while the naive kernel holds the entire N x N score matrix at once.

sequence length N : 4096
head dimension d : 64
max abs difference : 6.87e-16
relative error : 2.18e-15
naive peak intermediate : 16,777,216 floats (the full N x N matrix)
flash peak intermediate : 24,832 floats (one tile + carried state)
memory reduction factor : 675.6x
The relative error is at the floor of floating-point arithmetic: tiling changes nothing about the answer. What it changes is the peak intermediate, from quadratic in $N$ to a small constant set by the tile size. On a real accelerator the same restructuring keeps that constant-size tile inside SRAM, so the kernel streams $Q$, $K$, and $V$ from HBM tile by tile and writes each row of $O$ once, and the score matrix never makes a round trip through HBM at all. The three passes over an $N \times N$ intermediate in the naive version collapse into a single fused pass. That is how an operation that was starved on memory bandwidth becomes one that runs much closer to the accelerator's compute peak.
The online softmax is the attention-flavored cousin of a habit every programmer already has: computing a running average over a stream without storing the stream. You keep a running total and a count and update them as numbers arrive. Online softmax keeps a running maximum and a running normalizer and rescales when a bigger number shows up. The hard part is purely bookkeeping, the $\alpha = e^{m - m^{\text{new}}}$ correction, and once you have it, an operation that looked irreducibly all-at-once becomes a tidy fold over blocks.
4. FlashAttention-2, FlashAttention-3, and FP8 (Advanced)
The original kernel proved the principle; later versions chase the hardware. FlashAttention-2 (Dao, 2023) reorganizes the work to reduce non-matrix-multiply operations, which run far slower than the matrix multiplies that tensor cores accelerate, and parallelizes across the sequence-length dimension so that even a single long sequence keeps every streaming multiprocessor busy. The reported effect is roughly a doubling of throughput over the first version, reaching a large fraction of the accelerator's theoretical peak. FlashAttention-3 (Shah and colleagues, 2024) targets the Hopper generation specifically, overlapping the softmax computation with the matrix multiplies through asynchronous, warp-specialized scheduling and exploiting the hardware's low-precision tensor cores.
That low precision is the FP8 story. Running the attention matrix multiplies in 8-bit floating point roughly doubles the achievable arithmetic rate again, but naive FP8 would lose too much accuracy in the score and probability values. FlashAttention-3 keeps the running softmax statistics in higher precision and applies block-level scaling so that the bulk of the multiply work runs in FP8 while the numerically sensitive accumulation stays accurate. The pattern is the same one this chapter returns to repeatedly: push the heavy arithmetic into the lowest precision the hardware accelerates, and protect the few quantities that cannot tolerate it. The result is a kernel that, on the right hardware, makes attention so cheap that it stops being the bottleneck at all.
FlashAttention-3 (Shah et al., 2024) brought asynchronous, warp-specialized FP8 attention to Hopper accelerators, reporting throughput well beyond FlashAttention-2 while preserving accuracy through block-level scaling and higher-precision softmax statistics. A parallel line targets the very different shape of inference decoding, where one new query attends over a long cached context: FlashDecoding and FlashDecoding++ split the key-value sequence into chunks processed in parallel and then combine them with the same online-softmax merge, recovering parallelism that the single-query shape would otherwise lose. The KV cache those kernels read from is the subject of Section 22.5, and the decode-time batching that surrounds them is the subject of Section 22.7. The frontier is a family of attention kernels, each specialized to one regime (training, prefill, or decode) and all sharing the tiled online-softmax core.
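The split-and-merge idea behind decode-time kernels can be sketched in NumPy for a single query. This is a toy illustration of the principle only, not the FlashDecoding implementation: the chunks run in a Python loop rather than in parallel, and the helper names `partial_attention` and `merge` are hypothetical.

```python
import numpy as np

def partial_attention(q, K, V):
    """Attention state (m, l, o) for one query over one chunk of keys/values."""
    s = K @ q / np.sqrt(q.shape[0])
    m = s.max()
    p = np.exp(s - m)
    return m, p.sum(), p @ V

def merge(a, b):
    """Combine two partial states with the online-softmax rescaling."""
    (m_a, l_a, o_a), (m_b, l_b, o_b) = a, b
    m = max(m_a, m_b)
    w_a, w_b = np.exp(m_a - m), np.exp(m_b - m)
    return m, w_a * l_a + w_b * l_b, w_a * o_a + w_b * o_b

rng = np.random.default_rng(2)
n_ctx, d, n_chunks = 1024, 64, 8
q = rng.standard_normal(d)               # the single new decode query
K = rng.standard_normal((n_ctx, d))      # cached keys
V = rng.standard_normal((n_ctx, d))      # cached values

# Each chunk is independent, so a real kernel could process them in parallel.
parts = [partial_attention(q, Kc, Vc)
         for Kc, Vc in zip(np.split(K, n_chunks), np.split(V, n_chunks))]
state = parts[0]
for part in parts[1:]:
    state = merge(state, part)
m, l, o = state
out_split = o / l

# Reference: the same query over the whole cache at once.
m_ref, l_ref, o_ref = partial_attention(q, K, V)
out_ref = o_ref / l_ref
print("max abs difference:", np.max(np.abs(out_split - out_ref)))
```

Because each chunk produces only a maximum, a denominator, and a $d$-vector, the final combine step is cheap compared with the per-chunk work it stitches together.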
You almost never write a fused attention kernel by hand. PyTorch ships a fused implementation behind scaled_dot_product_attention, which dispatches to a FlashAttention backend when the inputs qualify, and the standalone flash-attn package exposes the latest kernels directly. The roughly forty lines of tiling and online-softmax bookkeeping in Code 22.6.1 collapse to a single call that the library fuses, schedules, and runs in FP16 or BF16 on the available hardware:
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend  # PyTorch 2.3 or later

# Q, K, V shaped (batch, heads, seq_len, head_dim) on the GPU.
# Illustrative shapes; the FlashAttention backend expects FP16 or BF16 CUDA tensors.
Q, K, V = (torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

# PyTorch fuses this into a FlashAttention kernel when the shapes/dtype qualify;
# no N x N matrix is ever materialized.
out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

# Force the FlashAttention backend explicitly when you want to be sure:
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
Code 22.6.2. The standalone flash-attn package exposes FlashAttention-2 and FlashAttention-3 kernels directly for cases the built-in dispatcher does not yet cover.

5. Approximate Attention, and Why Exact Won (Intermediate)
Before FlashAttention, the dominant response to the quadratic cost of attention was to approximate it. Sparse-attention methods (Longformer, BigBird) compute only a chosen subset of the score entries, typically a local window plus a few global tokens, reducing the cost from $N^2$ to roughly $N$ at the price of a fixed attention pattern. Linear-attention methods (the Performer family, and later state-space models) replace the softmax with a kernel feature map that factorizes the computation so it never forms the $N \times N$ matrix at all, achieving linear cost but changing the function being computed. Sliding-window attention, used in several production models, simply caps each token's view to a recent window. Every one of these trades exactness for asymptotic savings.
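The trade these methods make can be seen in a small NumPy toy. The sketch below assumes a symmetric local window of illustrative width and applies it as a dense mask, so it still computes every score; a real sparse kernel would skip the masked entries entirely. It shows only that the windowed variant touches far fewer score entries and returns a different output.

```python
import numpy as np

def attention(Q, K, V, mask=None):
    S = Q @ K.T / np.sqrt(Q.shape[1])
    if mask is not None:
        S = np.where(mask, S, -np.inf)   # masked entries get zero weight
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(4)
N, d, window = 256, 32, 16
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))

# Token i attends only to tokens j with |i - j| <= window.
idx = np.arange(N)
mask = np.abs(idx[:, None] - idx[None, :]) <= window

exact = attention(Q, K, V)
windowed = attention(Q, K, V, mask)

print("score entries, exact   :", N * N)
print("score entries, windowed:", int(mask.sum()))
print("max abs output change  :", np.max(np.abs(exact - windowed)))
```

The entry count falls from $N^2$ to roughly $N$ times the window width, and the nonzero output change is the price: the function being computed is no longer full attention.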
FlashAttention reframed the problem. The quadratic term people were trying to approximate away was, for the sequence lengths most deployments actually use, not a compute wall but a memory-traffic wall, and the memory-traffic wall could be removed without any approximation. Once exact attention ran near hardware peak and in linear memory, the accuracy cost of the approximate methods no longer bought enough speed to justify itself for general use. The approximate variants did not vanish: sliding-window and sparse patterns remain valuable at extreme context lengths where even linear memory and exact compute become expensive, and linear-attention and state-space ideas drive their own research line. But for the mainstream case, exact FlashAttention became the default, a reminder that the right move is sometimes to make the exact computation cheap rather than to settle for an approximate one.
Who: An inference engineer at a document-analysis startup serving a summarization model.
Situation: A new feature needed to process whole contracts in one pass, pushing the context length from 4,000 tokens to 32,000.
Problem: With the stock attention kernel, the $N^2$ score matrix at 32,000 tokens overflowed the accelerator's memory, and the few requests that did fit ran far too slowly to meet the latency target.
Dilemma: Buy larger-memory accelerators and approximate the attention with a sliding window to cut the footprint, accepting a quality regression on long-range references, or change the kernel and keep attention exact.
Decision: They switched the attention path to a FlashAttention backend, which removed the quadratic memory term and kept the computation exact up to floating-point rounding, so the change introduced no approximation-driven quality regression.
How: The model already ran in PyTorch, so the change was routing attention through scaled_dot_product_attention with the FlashAttention backend, as in Code 22.6.2, plus a few lines to confirm the kernel was being selected.
Result: The 32,000-token context fit comfortably on the existing accelerators, per-request latency dropped enough to meet the budget, and because the kernel is exact, summaries were identical to a small reference run that used the naive path.
Lesson: When a long-context feature blocks on attention memory, the first thing to try is the exact fused kernel, not an approximation; approximate attention is a tool for the regime past where exact-and-fused still fits, not a default.
6. Why the Kernel Is a Distributed-Serving Lever (Intermediate)
Everything so far happens inside one accelerator, which is exactly why it belongs at the front of a part about distribution. The throughput of a single node is the multiplier that distribution scales. If a kernel-level change makes attention two or three times faster per node, the whole serving fleet of Chapter 23 needs proportionally fewer nodes to meet the same demand, and the per-token cost that Chapter 24 drives down across the fleet starts from a smaller base. Kernel efficiency is not a detail beneath the distribution layer; it sets the baseline that the distribution layer multiplies, which is the entire reason this chapter exists as a labeled prerequisite.
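How much of an attention speedup reaches the fleet depends on the share of node time that attention occupies. The sketch below applies Amdahl's law to estimate the end-to-end gain and the resulting node count. Every number in it is hypothetical, chosen only to make the arithmetic visible, and both helper names are illustrative.

```python
import math

def end_to_end_speedup(attention_fraction, attention_speedup):
    """Amdahl's law: only the attention share of node time gets faster."""
    return 1.0 / ((1.0 - attention_fraction)
                  + attention_fraction / attention_speedup)

def nodes_needed(demand_tokens_per_s, node_tokens_per_s):
    return math.ceil(demand_tokens_per_s / node_tokens_per_s)

# All figures below are hypothetical.
demand = 1_000_000           # fleet-wide tokens per second
baseline = 5_000             # tokens per second per node before the change
attention_fraction = 0.5     # share of node time spent in attention
speedup = end_to_end_speedup(attention_fraction, attention_speedup=3.0)

print(f"end-to-end speedup: {speedup:.2f}x")
print("nodes before      :", nodes_needed(demand, baseline))
print("nodes after       :", nodes_needed(demand, baseline * speedup))
```

The larger the attention share, which grows with context length, the closer the fleet-level saving comes to the kernel-level speedup.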
The connection runs deeper than a multiplier, though. The online-softmax merge that lets FlashAttention combine tiles within one device is the very same operation that lets attention be split across devices. In Ring Attention and the sequence-parallel methods of Section 16.7, each machine holds a slice of the sequence and computes attention over its local keys and values, then passes its partial softmax statistics, the same running maximum, normalizer, and accumulator, around a ring of machines, merging them with the identical $\alpha$-rescaling correction. A tile that crossed an SRAM boundary on one chip becomes a tile that crosses a network link between chips, and the math that recombines them is unchanged. The associativity that makes online softmax work is what makes attention distributable at all.
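A small simulation makes the order-independence tangible. The sketch assumes a hypothetical four-device ring in which each device owns a quarter of the keys and values and a partial state is carried from device to device, following the description above. It is a NumPy toy with illustrative names, not an implementation of Ring Attention, and it involves no real communication.

```python
import numpy as np

def local_state(Q, K, V):
    """Per-row (m, l, o) for queries Q over one device's shard of K and V."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    m = S.max(axis=1, keepdims=True)
    P = np.exp(S - m)
    return m, P.sum(axis=1, keepdims=True), P @ V

def merge(a, b):
    """Alpha-rescaling merge of two partial states."""
    m = np.maximum(a[0], b[0])
    wa, wb = np.exp(a[0] - m), np.exp(b[0] - m)
    return m, wa * a[1] + wb * b[1], wa * a[2] + wb * b[2]

rng = np.random.default_rng(3)
n_q, n_kv, d, n_dev = 16, 512, 32, 4
Q = rng.standard_normal((n_q, d))
K = rng.standard_normal((n_kv, d))
V = rng.standard_normal((n_kv, d))

# Device r owns the r-th quarter of K and V and computes its local state.
shards = [local_state(Q, Kr, Vr)
          for Kr, Vr in zip(np.split(K, n_dev), np.split(V, n_dev))]

def ring_pass(start):
    """Carry the state once around the ring, beginning at device `start`."""
    state = shards[start]
    for step in range(1, n_dev):
        state = merge(state, shards[(start + step) % n_dev])
    return state[2] / state[1]

m, l, o = local_state(Q, K, V)           # single-device reference
reference = o / l
outputs = [ring_pass(start) for start in range(n_dev)]

print("all starting points match:",
      all(np.allclose(out, reference) for out in outputs))
print("floats carried per query row per hop:", d + 2)
```

Whichever device the pass starts from, the merged result matches the single-device answer, and the state carried per query row is $d + 2$ numbers regardless of how long the sequence is.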
This section is a scale-up technique, single-node and kernel-level, yet it advances the book's scale-out spine twice over. First, the per-node throughput it unlocks is the baseline that distributed serving multiplies; the fleet sizing of Chapter 23 and the LLM serving economics of Chapter 24 both begin from this number. Second, the online-softmax merge that fuses tiles inside one accelerator is identically the merge that stitches attention back together when a sequence is split across machines in Section 16.7. A primitive that began as a memory-saving trick on one device returns, unchanged, as the mechanism of sequence parallelism across many. Make the per-node operation efficient and associative, and distribution inherits both properties for free.
The lovely thing about FlashAttention crossing into Ring Attention is that almost nothing changes except the distance the tile travels. On one chip, a tile that does not fit in SRAM is the boundary you tile around. Across a cluster, a tile that does not fit on one machine is the boundary you ring around. Same recurrence, same correction factor, a few orders of magnitude more latency on the hop. The kernel writer and the distributed-systems engineer turn out to be solving the same problem at different radii.
The thread from here runs straight into the rest of the chapter and the part. The KV cache of Section 22.5 is what the decode-time attention kernels read from; the continuous batching and speculative decoding of Section 22.7 are what keep those kernels fed with enough work to stay compute-bound. Each per-node lever compounds with the others, and together they set the throughput that the fleet then scales. We turn next to the batching and decoding techniques that surround the attention kernel, in Section 22.7.
Using the roofline reasoning of Section 3.7, explain why naive attention is memory-bound while the feed-forward matrix multiplications in the same transformer layer are compute-bound. Count the floating-point operations and the bytes of HBM traffic for each as functions of the sequence length $N$, head dimension $d$, and hidden width, and argue which side of the roofline ridge each sits on. Then state precisely which of those two quantities FlashAttention changes for attention and which it leaves alone, and why that is enough to move attention across the ridge.
Extend Code 22.6.1 to sweep the sequence length $N$ over several values (for example 1024, 2048, 4096, 8192) while keeping the tile sizes fixed, and record both kernels' peak intermediate size at each $N$. Plot or tabulate the two curves and confirm that the naive peak grows quadratically while the tiled peak stays flat. Then vary the tile sizes Br and Bc at a fixed $N$ and describe the trade-off you observe between the tiled kernel's peak memory and the number of tiles it must process, relating it to the SRAM-capacity constraint a real kernel faces.
The tiled kernel in Code 22.6.1 merges blocks of one sequence on one device. Sketch how the identical online-softmax merge would combine partial results from two devices, each holding half of the keys and values, that exchange only their running maximum $m$, normalizer $l$, and accumulator $o$. Write the merge of two such partial states explicitly and argue that the operation is associative, so that any number of devices arranged in a ring produce the same exact result. Connect your answer to the sequence-parallel methods of Section 16.7, and state what is communicated per ring step and how its size scales with the head dimension rather than the sequence length.