Part V: Distributed Inference and Serving
Chapter 25: Distributed Retrieval and Vector Search

Multi-Stage Retrieval and Distributed Reranking

"I read every passage in the shard before I admit the answer was in the first ten. The cross-encoder only reads ten, and somehow it is the one everyone calls accurate."

An ANN Index With Recall to Spare
Mental model: The Retrieve-then-Rank Funnel
A huge candidate pool is narrowed by a fast approximate retrieval stage, then by a precise but expensive reranking stage, into a tiny ordered shortlist. The same funnel powers search, RAG, and recommendation.
Big Picture

The most accurate way to score a query against a passage is also the most expensive, so you cannot afford to run it on the whole corpus; the retrieve-then-rerank cascade resolves this by spending cheap compute broadly and expensive compute narrowly. A first stage runs approximate nearest-neighbor search over the entire sharded index (the machinery of Section 25.5) and returns a few hundred candidates with high recall. A second stage runs a far more accurate scorer, a cross-encoder or a late-interaction model, over only those candidates and cuts to the final top-$k$. This is not a retrieval trick; it is a distributed resource-allocation decision, the same multi-fidelity logic that governs hyperparameter search. The funnel sketched in the mental model above is the shape to keep in mind throughout: a wide cheap stage narrows to a tiny precise shortlist. This section shows why the cascade reaches near-perfect final quality at a fraction of the cost of reranking everything, how to size the candidate set, and where the reranker lives as its own served model on the fleet.

By this point in the chapter, the first stage is built. Approximate nearest-neighbor indexes (Section 25.4) shard across machines, replicate for throughput, and answer a query in milliseconds by returning the passages whose embeddings are closest to the query embedding. That bi-encoder design is fast precisely because it is decomposable: every passage is embedded once, offline, into a single vector, and scoring the query against any passage is then a single dot product against the index. The price of that speed is that the query and the passage never meet until the dot product; the model never reads them together. A different family of models does read them together, scores far more accurately, and cannot be precomputed. The tension between those two facts is what the cascade exists to manage, and managing it well across a fleet is the subject of this section.

The cascade pattern is simple to state and consequential to size. A cheap, high-recall stage casts a wide net; a sequence of progressively more expensive, higher-precision stages tightens it. Each stage processes fewer items than the one before, so you are allowed to spend more compute per item the deeper you go. The whole structure is an explicit decision about where to put your inference budget, and the rest of this section makes that decision quantitative.

[Figure 25.7.1 diagram: Stage 1, cheap and broad ANN search (bi-encoder, one dot product per passage across all shards; scores N passages, millions, at about 1 unit each) → candidate set of top-C (C = a few hundred, C ≪ N; high recall, low precision) → Stage 2, expensive and narrow cross-encoder (joint query+passage transformer, batched over only the C items, about 1000 units each) → final top-k (k = 10, high precision).]
Figure 25.7.1: The retrieve-then-rerank cascade as a funnel. Stage 1 runs a cheap bi-encoder ANN search over all $N$ passages across the sharded index and returns the top-$C$ candidates with high recall. Stage 2 runs an expensive cross-encoder over only those $C$ items, where $C \ll N$, and cuts to the final top-$k$. Per-item compute (the orange annotations) rises by roughly three orders of magnitude from stage 1 to stage 2, but the item count falls by four orders of magnitude or more, so total reranking cost stays small. The candidate count $C$ is the one knob that trades recall against rerank cost.

1. Why One Scorer Cannot Do the Whole Job Beginner

Two model architectures sit at the heart of the cascade, and the cascade exists because neither one can serve the whole pipeline alone. The bi-encoder embeds a query and a passage independently into vectors $\mathbf{q}$ and $\mathbf{p}$ and scores them with a similarity such as the dot product $\mathbf{q}^\top \mathbf{p}$. Because the passage embedding does not depend on the query, you embed the entire corpus once, offline, and store the vectors in the ANN index. At query time the model runs once, on the query, and the index does the rest. This is what makes search over millions of passages feasible in milliseconds; it is also why stage 1 needs a decomposable scorer like the bi-encoder rather than a cross-encoder.
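The decomposability can be seen in a few lines. The sketch below is a minimal NumPy illustration, assuming random unit vectors as stand-ins for real bi-encoder embeddings and using exact brute-force search in place of an ANN index (the function name `top_k_by_dot` and all sizes are hypothetical): the passage matrix is built once, and each query costs one matrix-vector product plus a partial sort.

```python
import numpy as np

def top_k_by_dot(query_vec, passage_matrix, k):
    """Return indices of the k passages with the largest dot product, best first.

    passage_matrix is precomputed offline (one row per passage); only the
    query vector has to be computed at query time.
    """
    scores = passage_matrix @ query_vec              # one dot product per passage
    top = np.argpartition(-scores, k)[:k]            # unordered top-k in linear time
    return top[np.argsort(-scores[top])]             # order just those k

rng = np.random.default_rng(0)

# Offline, once: embed (here: fake) the whole corpus and normalize.
passages = rng.standard_normal((10_000, 64))
passages /= np.linalg.norm(passages, axis=1, keepdims=True)

# Online, per query: embed only the query, then score against the stored matrix.
query = rng.standard_normal(64)
query /= np.linalg.norm(query)

print(top_k_by_dot(query, passages, k=5))
```

A real ANN index replaces the brute-force product with a sublinear search, but the division of labor is the same: nothing on the passage side depends on the query.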

The cross-encoder makes the opposite trade. It concatenates the query and the passage into one input, feeds the pair through a transformer, and reads them jointly with full attention between every query token and every passage token. That joint reading is what lets it catch the relevance signals a bi-encoder must compress away: negation, exact term overlap, the difference between "a treatment that causes nausea" and "a treatment for nausea." The cost is that nothing precomputes. A cross-encoder score for a query-passage pair requires a fresh forward pass through a transformer, and there is no query-independent vector to cache. Scoring a query against $N$ passages means $N$ forward passes. For a corpus of millions, that is not a latency budget; it is an afternoon.
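A back-of-the-envelope calculation makes the "afternoon" concrete. The throughput figure below is a hypothetical assumption chosen only for illustration; measure your own reranker before relying on any number like it.

```python
def full_rerank_seconds(num_passages, pairs_per_second):
    """Wall-clock seconds to cross-encode ONE query against num_passages passages,
    one forward pass per query-passage pair, at a sustained pairs_per_second."""
    return num_passages / pairs_per_second

# Hypothetical figures: a 5-million-passage corpus and one accelerator that
# cross-encodes 1,000 query-passage pairs per second.
seconds = full_rerank_seconds(5_000_000, 1_000)
print(f"{seconds:,.0f} s = {seconds / 3600:.1f} hours for ONE query")

# The same accelerator over a short candidate list of 200 passages:
print(f"{full_rerank_seconds(200, 1_000)} s")
```

Under these assumptions, a full-corpus rerank takes well over an hour per query, while a few hundred candidates take a fraction of a second; that gap is what the cascade exploits.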

So the architectures are complementary by construction. The bi-encoder is cheap per item and decomposable, which is exactly what a corpus-wide first stage needs. The cross-encoder is expensive per item and accurate, which is exactly what a final-stage scorer over a short list needs. The cascade is the structure that lets each do the job it is suited to.

Key Insight: Decomposability Is What You Are Really Trading

The bi-encoder is fast not because it is a smaller model but because its score factorizes into a query part and a passage part that can be computed separately and combined by a cheap dot product. That factorization is what lets the passage side precompute and the index search a billion vectors. The cross-encoder is accurate because it refuses to factorize: it lets every query token attend to every passage token. You cannot have both at corpus scale. The cascade buys accuracy where it is affordable (a few hundred items) and decomposability where it is mandatory (the whole corpus).

2. The Cascade Is a Multi-Fidelity Allocation, Not a Retrieval Trick Intermediate

It is tempting to read the cascade as a domain-specific optimization for search. It is something more general: a multi-fidelity resource allocation, the same pattern that governs distributed hyperparameter search. In Section 21.3, successive halving evaluates many configurations cheaply at low fidelity (a few training epochs), keeps the promising ones, and spends expensive full-fidelity evaluation only on survivors. The retrieve-then-rerank cascade has the identical shape: evaluate many passages cheaply at low fidelity (a dot product), keep the promising ones, and spend expensive high-fidelity evaluation (a cross-encoder pass) only on survivors. Both are answers to the same question: given a fixed compute budget and a cheap-but-noisy proxy for an expensive-but-accurate score, how do you allocate the budget to maximize final quality?
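Because the shape is identical, one small function can express both. The sketch below is a hypothetical helper, not code from Section 21.3: it takes a list of `(scorer, keep)` stages ordered from cheapest to most expensive and filters the survivors through each in turn. For retrieval the items are passages and the scorers are a dot product, then a cross-encoder; for hyperparameter search they would be configurations scored by a short training run, then a full one. Real successive halving also re-spends budget across rounds, which this sketch omits.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

T = TypeVar("T")

def cascade(items: Sequence[T],
            stages: Sequence[Tuple[Callable[[T], float], int]]) -> List[T]:
    """Multi-fidelity filter: each stage scores the current survivors and keeps
    the best `keep` of them. Stages should run from cheapest to most expensive
    with shrinking `keep`, so costly scorers only ever see a short list."""
    survivors = list(items)
    for score, keep in stages:
        survivors = sorted(survivors, key=score, reverse=True)[:keep]
    return survivors

# Toy usage: a coarse, cheap score (x // 10) followed by an exact, "expensive" one (x).
final = cascade(range(1000), [(lambda x: x // 10, 50), (lambda x: x, 10)])
print(final)
```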

Stating the cost makes the allocation concrete. Let the corpus hold $N$ passages, let the cheap first stage cost $c_1$ per passage, let the expensive reranker cost $c_2$ per passage with $c_2 \gg c_1$, and let the cascade pass $C$ candidates from stage 1 to stage 2. The total cost per query is

$$\text{Cost}(C) = c_1 N + c_2 C.$$

Reranking the whole corpus costs $c_2 N$; the cascade replaces the second term's $N$ with $C$. The first-stage term $c_1 N$ is small next to $c_2 N$ because $c_2 \gg c_1$, so the comparison that matters is the expensive work, of which the cascade pays only the fraction $c_2 C / (c_2 N) = C / N$. With $C$ in the hundreds and $N$ in the millions, that fraction is well under one percent (200 of 2 million is 0.01%). The only thing you can lose is recall: a relevant passage that the cheap stage ranked below position $C$ never reaches the reranker, so it can never appear in the final top-$k$. The final quality is therefore capped by stage-1 recall at depth $C$,

$$\text{Recall@}C = \frac{\bigl|\{\text{relevant passages ranked} \le C \text{ by stage 1}\}\bigr|}{|\{\text{relevant passages}\}|}.$$
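The recall ceiling is straightforward to compute once you have relevance judgments for a sample of queries. A minimal sketch, assuming a stage-1 ranking given as a best-first list of passage ids and a hand-labeled set of relevant ids (the toy data and the helper name `recall_at_c` are illustrative):

```python
def recall_at_c(stage1_ranking, relevant, c):
    """Fraction of the relevant passages that stage 1 ranks within its top c.

    stage1_ranking: passage ids ordered best-first by the cheap stage.
    relevant:       the set of passage ids judged relevant for this query.
    """
    if not relevant:
        raise ValueError("Recall@C is undefined when no passage is relevant")
    top_c = set(stage1_ranking[:c])
    return len(top_c & set(relevant)) / len(relevant)

ranking = ["p7", "p2", "p9", "p4", "p1", "p8"]   # toy stage-1 order
relevant = {"p2", "p4", "p8"}
for c in (2, 4, 6):
    print(c, recall_at_c(ranking, relevant, c))
```

At $C = 2$ only one of the three relevant passages survives stage 1, so no reranker, however accurate, can place the other two in the final list.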

This is the whole tension in two equations. Cost rises linearly in $C$; the recall ceiling rises in $C$ too, but with sharply diminishing returns, because the cheap stage, while imprecise, is rarely so wrong that a truly relevant passage falls hundreds of places. The right $C$ is where the recall curve has flattened but the cost is still small, and you find it by measurement, not by guessing.
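Plugging illustrative numbers into the cost equation shows how lopsided the comparison is. The per-item costs below are assumptions that echo the roughly thousandfold gap in Figure 25.7.1, not measurements:

```python
def cascade_cost(n, c, c1, c2):
    """Per-query cost of the two-stage cascade: c1*N + c2*C."""
    return c1 * n + c2 * c

def full_rerank_cost(n, c2):
    """Per-query cost of running the expensive scorer on every passage: c2*N."""
    return c2 * n

# Assumed units: a dot product costs 1, a cross-encoder pass costs 1,000.
N, C, c1, c2 = 2_000_000, 200, 1, 1_000
print(f"cascade:     {cascade_cost(N, C, c1, c2):,} units")
print(f"full rerank: {full_rerank_cost(N, c2):,} units")
print(f"expensive-stage fraction C/N = {C / N:.4%}")
```

With these numbers the cascade's total is dominated by the cheap first stage; the expensive stage adds only 10% on top of it, while a full rerank costs roughly a thousand times more than the whole cascade.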

Thesis Thread: Multi-Fidelity Returns, Now in the Serving Path

The cheap-proxy-then-expensive-confirmation pattern you met as successive halving in distributed HPO (Section 21.3) returns here in the inference serving path, where the cheap proxy is an ANN dot product and the expensive confirmation is a cross-encoder forward pass. It is the same allocation logic seen from the serving side rather than the training side: broad cheap compute to filter, narrow expensive compute to decide. Whenever you meet a distributed AI pipeline that pairs a fast approximate filter with a slow exact scorer, ask which fidelity runs where and on how many items; the answer is almost always a cascade, and its quality is capped by the recall of its cheapest stage.

3. The Cascade in Code: High Precision at a Fraction of the Cost Intermediate

The argument above is worth seeing as numbers. The program below builds a small retrieval world in which the cheap stage is a bi-encoder dot product (the high-recall proxy) and the expensive stage is a more accurate scorer that reads a signal the bi-encoder cannot, standing in for what a cross-encoder catches. The "gold" answer is the top-$k$ that the expensive scorer would produce if it ran over the entire corpus, which is exactly what reranking everything gives you. We then sweep the candidate count $C$ and measure how close the cascade gets to that gold answer, and at what fraction of the full-rerank cost.

```python
import numpy as np

rng = np.random.default_rng(7)
N, d = 50_000, 64          # corpus size, embedding dimension
n_queries = 200            # held-out queries we evaluate over
k_final = 10               # final top-k we hand to the user

# A bi-encoder world: every passage and query is a vector. The cheap first-stage
# score is cosine similarity (one dot product), exactly what an ANN index returns.
def unit(v):
    return v / np.linalg.norm(v, axis=-1, keepdims=True)

corpus = unit(rng.standard_normal((N, d)))
queries = unit(rng.standard_normal((n_queries, d)))

# The "true" relevance the EXPENSIVE reranker would assign. We model the
# cross-encoder as a more accurate but costly scorer: it sees a richer signal
# (here, a second view of the passage the bi-encoder cannot see) plus the cheap
# similarity. This is the gap a cross-encoder closes over a bi-encoder.
hidden = unit(rng.standard_normal((N, d)))          # the view only rerank sees
q_hidden = unit(rng.standard_normal((n_queries, d)))

cheap_full = queries @ corpus.T                      # (Q, N) first-stage scores
# The expensive reranker mostly agrees with the cheap stage (so the cheap stage
# is a HIGH-RECALL filter) but corrects it with a signal the bi-encoder misses.
# A high-recall first stage is exactly what makes the cascade work.
rerank_full = 0.80 * cheap_full + 0.20 * (q_hidden @ hidden.T)  # (Q, N) expensive

# Gold = the top-k_final under the EXPENSIVE scorer over the WHOLE corpus.
gold = np.argsort(-rerank_full, axis=1)[:, :k_final]

def precision_at_k(pred):
    hits = sum(len(set(pred[i]) & set(gold[i])) for i in range(n_queries))
    return hits / (n_queries * k_final)

# Baseline A: rerank EVERYTHING (perfect quality, maximal cost).
prec_rerank_all = 1.0                                # equals gold by construction

# Sort the cheap scores once; every first-stage cut below is a prefix of it.
cheap_order = np.argsort(-cheap_full, axis=1)

# Baseline B: cheap-only, no rerank (cheapest, lowest quality).
cheap_topk = cheap_order[:, :k_final]
prec_cheap_only = precision_at_k(cheap_topk)

print(f"corpus N={N}, queries={n_queries}, k_final={k_final}")
print(f"{'strategy':<26}{'rerank calls/query':>20}{'precision@10':>16}")
print(f"{'rerank ALL (gold)':<26}{N:>20}{prec_rerank_all:>16.3f}")
print(f"{'cheap-only (no rerank)':<26}{0:>20}{prec_cheap_only:>16.3f}")
print("-" * 62)

# The cascade: cheap stage returns top-C candidates, expensive stage reranks
# only those C and keeps the best k_final. Sweep the candidate count C.
for C in (20, 50, 100, 200, 500, 1000):
    cand = cheap_order[:, :C]                        # first stage: C per query
    pred = np.empty((n_queries, k_final), dtype=int)
    for i in range(n_queries):
        c = cand[i]
        order = np.argsort(-rerank_full[i, c])[:k_final]  # rerank ONLY the C
        pred[i] = c[order]
    p = precision_at_k(pred)
    print(f"cascade C={C:<5}{'':<6}{C:>20}{p:>16.3f}   "
          f"({C / N:6.2%} of full rerank cost)")
```
Code 25.7.1: A NumPy retrieve-then-rerank cascade built from scratch. The cheap stage is a dot product over the whole corpus; the expensive stage reranks only the top-$C$ candidates. Sweeping $C$ traces the recall-versus-cost trade of Section 2 directly, with precision@10 measured against the gold answer that full reranking produces.
```text
corpus N=50000, queries=200, k_final=10
strategy                    rerank calls/query    precision@10
rerank ALL (gold)                        50000           1.000
cheap-only (no rerank)                       0           0.594
--------------------------------------------------------------
cascade C=20                           20           0.792   ( 0.04% of full rerank cost)
cascade C=50                           50           0.957   ( 0.10% of full rerank cost)
cascade C=100                         100           0.997   ( 0.20% of full rerank cost)
cascade C=200                         200           1.000   ( 0.40% of full rerank cost)
cascade C=500                         500           1.000   ( 1.00% of full rerank cost)
cascade C=1000                       1000           1.000   ( 2.00% of full rerank cost)
```
Output 25.7.1: The cascade reaches precision@10 of 0.997 by reranking just 100 of 50,000 passages, 0.20% of the cost of reranking everything, and matches the gold answer exactly at $C = 200$ (0.40% of the cost). The cheap stage alone scores only 0.594, so the reranker is doing real work; it is simply doing it on a short list. Note the sharply diminishing returns: doubling $C$ from 100 to 200 buys the last 0.003 of precision, and raising $C$ beyond 200 buys nothing.

The shape of that table is the entire lesson. Reranking everything is perfect and ruinously expensive. The cheap stage alone is fast and mediocre. A cascade that reranks a few hundred candidates captures essentially all of the quality for a fraction of a percent of the expensive-stage cost, because the cheap stage, though imprecise about exact ordering, almost never buries a relevant passage hundreds of places deep. The candidate count $C$ where the precision column flattens, here around 100 to 200, is the operating point you would ship, and you find it exactly as the code does: by sweeping it and watching where the curve bends.

4. The Reranker Is a Served Model on the Fleet Advanced

So far the reranker has been an abstraction. In a real system it is a transformer with its own weights, its own accelerator, and its own latency and throughput profile, served exactly like any other model in Chapter 24. That changes how you think about the candidate count $C$. The reranker scores $C$ query-passage pairs per query, and those $C$ pairs are independent, so they batch: a single forward pass over a batch of $C$ inputs is far more accelerator-efficient than $C$ sequential passes. The candidate count is therefore not only a recall knob but a batch size, and the reranker's throughput, in pairs scored per second, sets a hard ceiling on how large $C$ can be within a latency budget.
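The latency ceiling on $C$ and the number of forward passes per query can be estimated with two one-line helpers. This is a sketch under stated assumptions: a single replica serves each query, the pairs-per-second figure was measured at the batch size you actually deploy, and queueing and network time are ignored (the helper names and figures are hypothetical).

```python
import math

def max_candidates(pairs_per_second, budget_ms):
    """Largest C one replica can score inside the rerank latency budget.
    Ignores queueing and network time, so treat it as an upper bound."""
    return pairs_per_second * budget_ms // 1000

def forward_passes(c, batch_size):
    """Number of batched forward passes needed to score C query-passage pairs."""
    return math.ceil(c / batch_size)

# Hypothetical replica: 4,000 pairs/s at the serving batch size, 50 ms budget.
print(max_candidates(4_000, budget_ms=50))
print(forward_passes(200, batch_size=64))
```

Under these assumptions the budget allows at most 200 candidates, scored in four batched passes of up to 64 pairs each.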

This makes reranking a distributed-serving problem in its own right. Under load, one reranker replica saturates, so you replicate it behind a load balancer and route candidate batches across replicas, the same fleet-sizing arithmetic that Chapter 23 applies to any inference service: required replicas equals offered load in pairs per second divided by per-replica throughput, rounded up to a whole replica. The first stage and the reranker also scale independently, because they bottleneck on different resources. The ANN stage is memory-bandwidth-bound over a large sharded index; the reranker is compute-bound over a small batch. A well-tuned pipeline sizes the two fleets separately and places them so that the candidate set moves cheaply from the search shards to the reranker replicas.
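The fleet-sizing rule translates directly into code. A minimal sketch, with hypothetical load and throughput figures:

```python
import math

def rerank_replicas(qps, c, pairs_per_replica_per_s):
    """Replicas = ceil(offered pair load / per-replica throughput),
    where the offered pair load is queries/s times candidates per query."""
    offered_pairs_per_s = qps * c
    return math.ceil(offered_pairs_per_s / pairs_per_replica_per_s)

# Hypothetical service: 500 queries/s, C = 100, 10,000 pairs/s per replica.
print(rerank_replicas(500, 100, 10_000))
```

Because the offered load is queries per second times $C$, the replica count grows linearly with $C$; the ANN fleet's size does not depend on $C$ at all.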

A useful middle architecture sits between the bi-encoder and the cross-encoder: the late-interaction model, of which ColBERT is the canonical example. Instead of collapsing a passage to one vector, it keeps one vector per token and scores a query-passage pair by summing, over each query token, the maximum similarity to any passage token, the so-called MaxSim. This keeps most of the bi-encoder's precompute-and-index advantage (the per-token passage vectors are query-independent) while recovering much of the cross-encoder's token-level precision. The cost is storage and a heavier index: many vectors per passage instead of one. A late-interaction model can serve as the first stage itself, or as a middle tier between a bi-encoder first stage and a full cross-encoder final stage, in which case the pipeline becomes a three-tier funnel of rising fidelity and falling fan-out.
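The MaxSim rule is short enough to write out. A minimal NumPy sketch, assuming the per-token embeddings are already computed (the passage side offline) and, as is common for late-interaction models, L2-normalized so a dot product is a cosine similarity; the toy vectors are illustrative:

```python
import numpy as np

def maxsim(query_tokens, passage_tokens):
    """Late-interaction (MaxSim) score for one query-passage pair.

    query_tokens:   (n_q, d) array, one embedding per query token.
    passage_tokens: (n_p, d) array, one embedding per passage token,
                    precomputed offline and stored in the index.
    For each query token, take its best-matching passage token, then sum.
    """
    sim = query_tokens @ passage_tokens.T          # (n_q, n_p) token similarities
    return float(sim.max(axis=1).sum())

q = np.array([[1.0, 0.0], [0.0, 1.0]])             # two query tokens
p = np.array([[1.0, 0.0], [0.6, 0.8]])             # two passage tokens
print(maxsim(q, p))
```

The first query token matches the first passage token perfectly (1.0) and the second query token's best match is the second passage token (0.8), so the score is 1.8.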

Library Shortcut: A Production Reranker in Five Lines

Code 25.7.1 builds the cascade from scratch to expose the cost-recall trade. In production the expensive stage is a single call. The sentence-transformers CrossEncoder wraps a fine-tuned reranker and scores query-passage pairs in batched forward passes; Cohere's managed Rerank endpoint does the same as a hosted service. Either replaces a hand-rolled scoring loop, and both handle the batching, padding, and accelerator dispatch that the per-pair view in Code 25.7.1 hides:

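```python
# pip install sentence-transformers
from sentence_transformers import CrossEncoder

reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")   # a served model

# Placeholder inputs: in the pipeline, `query` is the user's query and
# `candidates` are the top-C passages from the ANN first stage (Section 25.5).
query = "how do I reset my password"
candidates = [
    "To reset your password, open Settings and choose Security.",
    "Billing questions are handled by the accounts team.",
    "Passwords must be at least 12 characters long.",
]

pairs = [(query, passage) for passage in candidates]              # C query-passage pairs
scores = reranker.predict(pairs, batch_size=64)                   # batched passes, up to 64 pairs each

top_k = [candidates[i] for i in scores.argsort()[::-1][:10]]      # cut C down to k=10
```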
Code 25.7.2: The expensive stage of the cascade as a single predict call over the $C$ candidates. The batch_size argument is the fleet knob from Section 4: it sets how many query-passage pairs the reranker scores per forward pass, trading latency against accelerator utilization. On Cohere's hosted service the same step is the Cohere client's rerank method, called with query=..., documents=candidates, top_n=10 and the name of a rerank model.
Practical Example: The Support Search That Reranked Its Way Out of a Bigger Index

Who: A search engineer on the help-center team at a SaaS company.

Situation: Semantic search over two million support articles returned answers users called "almost right," with the truly correct article often sitting at rank 4 or 5 instead of rank 1.

Problem: The bi-encoder embeddings were good enough to retrieve the right article into the top 50 but not precise enough to rank it first, and the click-through on position 1 was what the business measured.

Dilemma: Fine-tune a larger bi-encoder and re-embed the whole two-million-document corpus, a multi-day offline job that improves ranking only modestly, or add a cross-encoder reranker over the top candidates, which adds a second served model and per-query latency.

Decision: They added a reranker over the top $C = 100$ candidates, because the recall analysis showed the right article was already inside the top 100 for 98% of queries; the problem was ordering, not retrieval, and a cross-encoder fixes ordering.

How: They served a CrossEncoder on two GPU replicas behind the existing load balancer, batched the 100 candidates per query into one forward pass, and left the ANN index untouched.

Result: Position-1 accuracy rose sharply, added latency stayed within the budget because 100 pairs batch into a single pass, and the expensive stage touched 100 of two million documents per query, well under one ten-thousandth of a full-corpus rerank.

Lesson: When retrieval recall is already high but ranking precision is low, the fix is a reranker over a short candidate list, not a bigger or re-embedded index. Diagnose which stage is failing before you pay to scale the wrong one.

5. Tuning the Candidate Set, and the Rise of LLM Rerankers Advanced

The candidate count $C$ is the cascade's master knob, and Output 25.7.1 shows how to set it: sweep $C$, plot final precision and per-query cost against it, and choose the point where precision has plateaued but cost is still small. A larger $C$ raises the recall ceiling and so the achievable precision, but with diminishing returns and rising rerank cost; a smaller $C$ is cheaper and lower-latency but risks discarding relevant passages before the reranker ever sees them. The right value is workload-specific, because it depends on how well the cheap stage's ordering tracks the expensive stage's: a strong bi-encoder needs a smaller $C$, a weak one needs a larger $C$ to compensate, and the only honest way to find it is the measurement the code performs.
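The plateau rule can be written down so the choice is reproducible rather than eyeballed. The sketch below is one possible rule, not the only one: it picks the smallest $C$ whose precision is within a tolerance of the best measured value, with the tolerance as an assumed, workload-specific parameter. The sweep values are those reported in Output 25.7.1.

```python
def smallest_plateau_c(sweep, tol=0.005):
    """Smallest candidate count whose precision is within `tol` of the best seen.

    sweep: dict mapping candidate count C -> measured precision@k.
    """
    best = max(sweep.values())
    return min(c for c, p in sweep.items() if p >= best - tol)

# Precision@10 values from Output 25.7.1.
sweep = {20: 0.792, 50: 0.957, 100: 0.997, 200: 1.000, 500: 1.000, 1000: 1.000}
print(smallest_plateau_c(sweep))            # tolerance 0.005
print(smallest_plateau_c(sweep, tol=0.0))   # demand the best value exactly
```

A tolerance of 0.005 selects $C = 100$; demanding an exact match with the best value selects $C = 200$, bracketing the 100-to-200 operating range read off the table.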

The deepest tier of the cascade is increasingly a large language model. An LLM can be prompted to judge or to order the candidates, reading the query and each passage with the full reasoning capacity of a served generative model. The quality can exceed a fine-tuned cross-encoder, but the cost is another order of magnitude higher: an LLM rerank pass over $C$ candidates is $C$ generations or one long-context generation over all $C$, both far more expensive than a cross-encoder's single classification pass. That cost only makes sense at the very tip of the funnel, over a handful of candidates that a cheaper reranker has already shortlisted, which is why LLM rerankers sit at the end of a multi-tier cascade rather than replacing it. The cascade discipline, cheap compute broad and expensive compute narrow, is what keeps an LLM reranker affordable at all.

Research Frontier: Listwise and LLM Reranking (2024 to 2026)

The reranker tier is where retrieval research is moving fastest. Listwise LLM rerankers such as RankGPT (Sun et al., 2023) and the open RankVicuna and RankZephyr models (Pradeep et al., 2023) score the whole candidate list jointly rather than one passage at a time, so the model reasons about relative order, and a sliding-window strategy keeps the list within the context limit; distilling these listwise teachers into smaller students is an active 2024 to 2025 thread aimed at cutting their cost. Late-interaction retrieval has matured in parallel: ColBERTv2 (Santhanam et al., 2022) and the PLAID engine made per-token indexes practical at scale, and the 2024 generation of training recipes pushes their quality toward cross-encoder territory while keeping the index searchable. A third line treats reranking as just another served LLM call and studies how to batch, cache, and quantize it within a latency budget, folding the reranker into the same serving economics as Chapter 24. The common thread is that the most accurate scorers are now generative and listwise, which only sharpens the cascade's logic: you can afford a brilliant, slow judge precisely because the cheap stages already cut the list to a length the judge can read.
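The sliding-window strategy is simple to sketch. The code below is a minimal illustration, assuming a hypothetical `rank_window` callable that stands in for one listwise LLM call and returns its input reordered best-first; the window of 20 and step of 10 are illustrative values, not settings taken from any of the papers above. One back-to-front pass with overlapping windows carries the strongest candidates toward the front: with a perfect judge, the top `window - step` items end up first, in order.

```python
import random
from typing import Callable, List, TypeVar

T = TypeVar("T")

def sliding_window_rerank(items: List[T],
                          rank_window: Callable[[List[T]], List[T]],
                          window: int = 20,
                          step: int = 10) -> List[T]:
    """One back-to-front pass of listwise reranking over overlapping windows.

    rank_window receives at most `window` candidates and returns them
    reordered best-first. Consecutive windows overlap by (window - step)
    positions, which is what carries strong candidates forward.
    """
    if not 0 < step < window:
        raise ValueError("need 0 < step < window so windows overlap")
    items = list(items)
    end = len(items)
    start = max(0, end - window)
    while True:
        items[start:end] = rank_window(items[start:end])
        if start == 0:
            break
        end -= step
        start = max(0, start - step)
    return items

# Toy check with a "perfect" stand-in judge: a larger id means more relevant.
candidates = list(range(100))
random.Random(0).shuffle(candidates)
reranked = sliding_window_rerank(candidates, lambda xs: sorted(xs, reverse=True))
print(reranked[:10])
```

Each window fits in a bounded prompt, so the number of LLM calls grows with the candidate count divided by the step, another reason to let cheaper tiers cut the list first.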

Fun Note: The Hiring-Funnel Analogy Is Not a Metaphor

A company does not interview every applicant on the planet. A keyword filter on the resume (cheap, high recall, low precision) cuts millions to hundreds; a phone screen (more expensive) cuts hundreds to dozens; an on-site loop (very expensive, very accurate) cuts dozens to one. The cost per candidate rises at each stage and the count falls, so the total interviewer-hours stay sane. Swap "applicant" for "passage" and "on-site loop" for "cross-encoder" and you have rebuilt Figure 25.7.1 exactly. Multi-stage cascades are the universal answer to "an accurate judge is too slow to run on everyone."

With the cascade sized and the reranker placed on the fleet, the retrieval pipeline is fast where it must be broad and accurate where it must be precise. One bottleneck remains: every stage of this pipeline re-runs from scratch on every query, even queries it has answered before. The next section, Section 25.8, attacks that waste with distributed caching, so the fleet stops recomputing answers it already holds.

Exercise 25.7.1: Where Does the Cascade Break? Conceptual

The cascade in this section assumes the cheap first stage has high recall at depth $C$: the relevant passages are almost always somewhere in the top $C$, even if mis-ordered. Describe a retrieval workload where this assumption fails, so that relevant passages are routinely ranked below $C$ by the bi-encoder. State what happens to the final precision no matter how good the reranker is, and explain why raising the reranker's quality cannot fix it while raising $C$ (or improving the first stage) can. Connect your answer to the recall ceiling equation in Section 2.

Exercise 25.7.2: Add a Third Tier Coding

Extend Code 25.7.1 into a three-stage cascade: the cheap dot-product stage returns $C_1$ candidates, a "medium" reranker (model it as rerank_full with extra noise added, cheaper but less accurate) cuts those to $C_2 < C_1$, and the full rerank_full scorer cuts $C_2$ to the final top-$k$. Sweep $(C_1, C_2)$ and report, for each pair, the precision@10 and the total expensive-stage cost, counting the medium scorer as cheaper than the full scorer. Find a $(C_1, C_2)$ that matches the precision of the two-stage cascade at $C = 200$ but at lower total reranking cost, and explain why a third tier can help.

Exercise 25.7.3: Size the Reranker Fleet Analysis

A service receives 800 queries per second. The cascade reranks $C = 150$ candidates per query, and one reranker replica scores 6,000 query-passage pairs per second within the latency budget. Using the fleet-sizing arithmetic from Chapter 23 (required replicas equals offered load divided by per-replica throughput), compute how many reranker replicas you need, and how that number changes if a product decision raises $C$ to 300. Then argue, from the cost equation $\text{Cost}(C) = c_1 N + c_2 C$, why doubling $C$ doubles the reranker fleet but leaves the ANN search fleet unchanged.