Part III: Distributed Machine Learning
Chapter 11: Parameter Servers and Distributed Embeddings

Terabyte-Scale Embeddings

"I hold two billion rows and I have met perhaps four hundred of them this second. The rest are asleep on an SSD, dreaming of the day someone clicks."

A Sharded Embedding Table With a Long Tail
Big Picture

An industrial recommendation model is not one model on one machine; it is a terabyte of embedding rows sharded across many hosts that answer lookups in parallel, attached to a small dense network replicated on every worker. When a system must embed billions of users and billions of items into wide vectors, the embedding tables alone reach terabytes, far past the tens of gigabytes of memory on any single accelerator. The tables therefore must be split across machines: this is model parallelism, but for a lookup table rather than a matrix multiply. The dense layers that consume those embeddings are tiny by comparison and stay data-parallel on every worker. The looked-up rows are routed from the embedding shards to the dense replicas with an all-to-all, the same collective that routes tokens between experts. This hybrid, model-parallel embedding feeding a data-parallel dense network, is the standard architecture of essentially every large recommender, and this section shows why its cost is set by memory capacity and network bytes rather than by arithmetic.

The previous section built the sparse-model machinery: an embedding table is a lookup over a vocabulary, and only the rows named by a batch are read or updated, so the parameter server ships rows rather than dense gradients. That picture still assumed the table fit somewhere convenient. This section removes that assumption. A web-scale recommender assigns an embedding vector to every user, every item, every advertiser, every query token, and every cross-feature, and the vocabularies run to billions of entries. At a few hundred bytes per row, the table crosses one terabyte and keeps climbing. No single accelerator holds a terabyte of high-bandwidth memory, and a faster accelerator would not help, because the bottleneck is not arithmetic but the capacity to store sparse rows and the bandwidth to move them. The only way forward is to shard the table across many hosts and accept the communication that sharding forces, the same trade this book has made since Section 1.1.

[Figure: model-parallel embedding table sharded across hosts 1..S, each holding an HBM cache of hot rows, host RAM for warm rows, and an SSD cold tail; sparse ids are routed to the owning shards, an all-to-all gathers the looked-up rows to data-parallel dense workers 1..W, and gradients for hit rows flow back to the owning shard (push-pull of Section 11.2).]
Figure 11.7.1: The standard recommendation architecture. The embedding table is split across $S$ host parameter servers (model parallelism for the lookup), each backing its shard with a three-level memory hierarchy: an HBM cache of hot rows, host RAM for warm rows, and SSD for the cold tail. The sparse ids in a batch are routed to the owning shards, the looked-up rows are gathered to the data-parallel dense workers by an all-to-all, and gradients for the rows that were read flow back along the same path. The all-to-all is the collective introduced for routing tokens between experts in Section 4.6.

1. Why the Table Cannot Fit Beginner

Start with the arithmetic of size, because it is what forces every later decision. Suppose a recommender embeds a combined vocabulary of $R$ rows into vectors of width $D$, stored at $b$ bytes per element. The table occupies

$$M_\text{table} = R \cdot D \cdot b \text{ bytes}.$$

Plug in numbers that are ordinary for a large platform: $R = 2 \times 10^9$ rows across users, items, and cross-features, $D = 128$, and $b = 4$ bytes for fp32. The table is $2 \times 10^9 \cdot 128 \cdot 4 \approx 1.02 \times 10^{12}$ bytes, just over one terabyte. A high-end accelerator carries on the order of 80 gigabytes of HBM, so the table is more than ten times too large for the largest single device, and production tables at the biggest platforms run to tens of terabytes. The dense network sitting on top, a multilayer perceptron with feature interactions, is typically tens to low hundreds of millions of parameters: a rounding error against the table. This inversion, where the lookup table dwarfs the compute network, is the defining feature of the recommendation workload and the reason its distribution strategy differs from that of a language model.
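The sizing arithmetic above can be checked in a few lines of Python. This is a minimal sketch: the 80-gigabyte HBM figure and the 100-million-parameter dense network are assumed round numbers rather than measurements, and the device count is only a lower bound because it pretends HBM holds nothing but the table.

```python
import math

def table_bytes(rows: int, dim: int, bytes_per_elem: int) -> int:
    """M_table = R * D * b: bytes needed to store the full embedding table."""
    return rows * dim * bytes_per_elem

def devices_needed(total_bytes: int, hbm_bytes_per_device: int) -> int:
    """Lower bound on device count if HBM held nothing but the table."""
    return math.ceil(total_bytes / hbm_bytes_per_device)

R, D, b = 2_000_000_000, 128, 4      # values from the text: 2e9 rows, width 128, fp32
HBM = 80 * 10**9                     # assumption: ~80 GB of HBM per accelerator
DENSE_PARAMS = 100_000_000           # hypothetical dense MLP size

m_table = table_bytes(R, D, b)
print(f"M_table = {m_table / 1e12:.3f} TB")                                # 1.024 TB
print(f"devices for the table alone: {devices_needed(m_table, HBM)}")     # 13
print(f"dense / table parameter ratio: {DENSE_PARAMS / (R * D):.5f}")    # 0.00039
```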

Key Insight: The Embedding Is Model-Parallel, the Dense Network Is Data-Parallel

Because the embedding table exceeds any single device, it must be partitioned across machines by row range or by hashing the id space: this is model parallelism, applied to a lookup rather than a matmul. Because the dense network is small, it is replicated on every worker and trained data-parallel, exactly as in Section 1.1. The two halves meet at an all-to-all that ships the looked-up rows from the embedding shards to the dense replicas. This hybrid, model-parallel sparse part plus data-parallel dense part joined by all-to-all, is not one option among many; it is the architecture that essentially every terabyte-scale recommender converges on.
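The two partitioning rules can be sketched as owner functions that map a row id to its shard. The hash below is the splitmix64 finalizer, chosen only as an example of a well-mixed 64-bit hash; production systems use their own. The usage lines assume, as Code 11.7.1 does, that the most popular rows are the low ids 0 through 4,999, which exposes one reason to hash: when popular ids cluster, range sharding piles the whole hot head onto a single shard.

```python
from collections import Counter

MASK64 = (1 << 64) - 1

def owner_by_range(row_id: int, total_rows: int, num_shards: int) -> int:
    """Row-range sharding: shard s owns one contiguous block of ids."""
    rows_per_shard = -(-total_rows // num_shards)      # ceiling division
    return row_id // rows_per_shard

def mix64(x: int) -> int:
    """splitmix64 finalizer: scrambles nearby ids into unrelated 64-bit values."""
    x = (x + 0x9E3779B97F4A7C15) & MASK64
    x = ((x ^ (x >> 30)) * 0xBF58476D1CE4E5B9) & MASK64
    x = ((x ^ (x >> 27)) * 0x94D049BB133111EB) & MASK64
    return x ^ (x >> 31)

def owner_by_hash(row_id: int, num_shards: int) -> int:
    """Hash sharding: hash the id, then reduce it modulo the shard count."""
    return mix64(row_id) % num_shards

S, R = 16, 2_000_000_000
hot_ids = range(5000)   # assumption: the popular rows are the low ids, as in Code 11.7.1
range_load = Counter(owner_by_range(i, R, S) for i in hot_ids)
hash_load = Counter(owner_by_hash(i, S) for i in hot_ids)
print("range sharding, hot rows per shard:", dict(range_load))   # {0: 5000}
print("hash sharding, hot rows per shard :", sorted(hash_load.values()))
```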

2. The All-to-All Between Sparse and Dense Intermediate

Tracing one training step makes the communication concrete. A global batch is split across $W$ dense workers, and each sample names a set of sparse ids: user id, recent item ids, advertiser id, and engineered cross-features. Those ids are scattered across all $S$ embedding shards, because the id space was partitioned by hashing, not by which worker happens to hold the sample. So before any dense compute can run, every worker must send each id to the shard that owns it and receive back the corresponding row. Sending a different slice of ids from every worker to every shard, and receiving a different slice of rows in return, is precisely an all-to-all, the collective introduced for mixture-of-experts routing in Section 4.6. The looked-up rows are then pooled (summed or averaged per feature) and handed to the dense network; on the backward pass, the gradient for each row that was read travels back along the same all-to-all to the shard that owns it, where the push-pull update of Section 11.2 applies it.
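The three moves in that paragraph, bucketing ids by owning shard, returning rows, and pooling them per sample, can be simulated in a single process. This is a toy under stated assumptions: the shard count, the modulo ownership rule (a stand-in for a real hash), the tiny width, and fake rows whose elements equal their id are all hypothetical, and a real system would deduplicate ids and perform the exchange with a collective such as `torch.distributed.all_to_all` rather than with dictionaries.

```python
from collections import defaultdict

S = 4                       # hypothetical number of embedding shards
D = 4                       # tiny embedding width so results are easy to read

def owner(row_id):
    """Hypothetical ownership rule: modulo stands in for hashing the id space."""
    return row_id % S

def lookup_row(row_id):
    """Fake stored row: every element equals the id, so pooled sums are checkable."""
    return [float(row_id)] * D

# Each dense worker holds samples; each sample names several sparse ids.
worker_batches = {
    0: [[3, 8, 11], [5]],
    1: [[8, 2], [7, 12, 1]],
}

# Forward all-to-all, send side: every worker buckets its ids by owning shard.
send = defaultdict(list)    # (worker, shard) -> ids that worker sends to that shard
for w, samples in worker_batches.items():
    for sample in samples:
        for rid in sample:
            send[(w, owner(rid))].append(rid)

# Receive side: each shard returns the requested rows to the requesting worker.
received = {key: {rid: lookup_row(rid) for rid in ids} for key, ids in send.items()}

def pooled(w, sample):
    """Sum-pool the rows worker w received for one sample."""
    out = [0.0] * D
    for rid in sample:
        row = received[(w, owner(rid))][rid]
        out = [a + x for a, x in zip(out, row)]
    return out

for w in worker_batches:
    print(f"worker {w} ids per shard:", [len(send.get((w, s), [])) for s in range(S)])
for w, samples in worker_batches.items():
    for sample in samples:
        print(f"worker {w} sample {sample} -> pooled {pooled(w, sample)}")
```

The first two printed lines are the rows of the worker-by-shard send matrix, which is exactly the shape of an all-to-all: every worker sends a different, data-dependent slice to every shard.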

The cost of that step is not dominated by floating-point work. If a global batch of $B$ samples each fires $L$ sparse lookups, the step moves on the order of $B \cdot L \cdot D \cdot b$ bytes of embedding rows across the network in each direction: rows forward, row gradients backward. With $B = 65{,}536$, $L = 40$, $D = 128$, and $b = 4$, that is roughly $1.3$ gigabytes of row traffic per step before any caching or compression, while the arithmetic of the small dense network is modest by comparison and is split across all $W$ workers. The step is communication-bound, and the table itself is memory-capacity-bound. Neither bottleneck is arithmetic, which is why this section reasons in bytes and bandwidth rather than in flops.
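The byte count above, and its contrast with shipping a dense gradient for the whole table, take only a few lines to reproduce. The sketch counts every lookup, including repeated ids that a real system would deduplicate, and ignores caching, so it is an upper bound on the row traffic.

```python
B, L, D, b = 65_536, 40, 128, 4     # batch, lookups per sample, width, bytes (from the text)
R = 2_000_000_000                   # table rows

row_bytes_one_way = B * L * D * b   # rows forward, or row gradients backward
dense_table_grad = R * D * b        # what a dense all-reduce of the whole table would ship

print(f"row traffic per direction: {row_bytes_one_way / 1e9:.2f} GB")       # 1.34 GB
print(f"both directions          : {2 * row_bytes_one_way / 1e9:.2f} GB")   # 2.68 GB
print(f"dense gradient of table  : {dense_table_grad / 1e12:.3f} TB")       # 1.024 TB
print(f"sparse / dense ratio     : {row_bytes_one_way / dense_table_grad:.5f}")  # 0.00131
```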

Thesis Thread: The Same All-to-All, a Different Sparse Object

The all-to-all that routes tokens to experts in a mixture-of-experts layer (Section 4.6) and the all-to-all that routes looked-up rows from embedding shards to dense workers here are the same collective serving the same purpose: a sparse object lives partitioned across machines, and each step a data-dependent subset of it must be gathered to wherever the compute is. Expert parallelism and terabyte embeddings are two faces of one pattern, distributing a sparse parameter set and paying an all-to-all to assemble the active slice. We meet the expert-parallel face in Chapter 17; recognizing it as a relative of the embedding all-to-all is the point of the thread.

3. The Memory Hierarchy: HBM, Host RAM, SSD Intermediate

A terabyte table sharded across, say, sixteen hosts still leaves each host holding 64 gigabytes of rows. That is most of a high-end accelerator's roughly 80 gigabytes of HBM before any room is left for optimizer state, activations, or the dense replica, and it grows with every feature added; host RAM and SSD, by contrast, hold it comfortably. The systems trick that makes this affordable is the long tail of access. Item popularity in a recommender follows a heavy-tailed, roughly Zipfian distribution: a tiny head of items absorbs most of the impressions, while the vast majority of rows are touched rarely or never within a step. That skew turns a capacity problem into a caching problem. Each host keeps the hottest rows in HBM for the lowest latency, the warm rows in host RAM, and the enormous cold tail on SSD or even networked storage, paging rows up the hierarchy on demand and evicting cold ones. A row served from the HBM cache costs no cross-host network bytes at all; only a cache miss triggers the all-to-all fetch from a remote shard.
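The effect of skew on a three-level hierarchy can be sketched by drawing Zipf-distributed lookups and serving each from the tier its popularity rank is assigned to. Everything here is an assumption chosen for illustration: a 100,000-row vocabulary far smaller than the real table, tier capacities of 1,000 HBM rows and 20,000 RAM rows, and a static placement that knows every row's rank in advance, whereas a real system learns popularity online and pages rows on demand.

```python
import random

random.seed(0)
N = 100_000                       # hypothetical vocabulary, far smaller than the real table
ZIPF_S = 1.1                      # same exponent as Code 11.7.1
HBM_ROWS, RAM_ROWS = 1_000, 20_000   # hypothetical tier capacities, hottest rows first

def tier(rank: int) -> str:
    """Static placement by popularity rank: HBM, then host RAM, else SSD."""
    if rank < HBM_ROWS:
        return "HBM"
    if rank < HBM_ROWS + RAM_ROWS:
        return "RAM"
    return "SSD"

weights = [1.0 / (r + 1) ** ZIPF_S for r in range(N)]   # Zipf popularity of rank r
draws = random.choices(range(N), weights=weights, k=50_000)

counts = {"HBM": 0, "RAM": 0, "SSD": 0}
for r in draws:
    counts[tier(r)] += 1

sizes = {"HBM": HBM_ROWS, "RAM": RAM_ROWS, "SSD": N - HBM_ROWS - RAM_ROWS}
for name in ("HBM", "RAM", "SSD"):
    print(f"{name}: {sizes[name] / N:6.1%} of rows serve {counts[name] / len(draws):6.1%} of lookups")
```

The printout makes the inversion visible: the smallest tier by rows serves the largest share of lookups, while the SSD, which holds most of the rows, is rarely touched.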

The economics follow directly. Let $h$ be the cache hit rate, the fraction of lookups served locally from cached hot rows. The cross-host traffic per step falls from the full $B \cdot L \cdot D \cdot b$ to

$$M_\text{comm} = (1 - h)\, B \cdot L \cdot D \cdot b \text{ bytes},$$

so a cache that captures 85 percent of lookups cuts the network bytes by 85 percent and, as long as the exchange is bandwidth-bound rather than latency-bound, cuts the time spent in the all-to-all by about as much. Because the hot head is small, this hit rate costs only a few megabytes of HBM per host. The demo below models exactly this trade: it sizes a one-terabyte table, shards it, and shows the cache hit rate climbing and the cross-host bytes collapsing as the HBM cache grows from nothing to a few thousand rows.

import random

random.seed(0)

# ---- Table geometry -----------------------------------------------------
ROWS = 2_000_000_000        # 2 billion embedding rows (combined user, item, and cross-feature vocabulary)
DIM = 128                   # embedding width
BYTES_PER_ELEM = 4          # fp32 storage, 4 bytes per element
table_bytes = ROWS * DIM * BYTES_PER_ELEM
print(f"embedding rows         : {ROWS:,}")
print(f"embedding dim          : {DIM}")
print(f"full table size        : {table_bytes / 1e12:.3f} TB (fp32)")

# ---- Sharding across S hosts -------------------------------------------
S = 16                      # parameter-server hosts holding the table
rows_per_host = ROWS // S
bytes_per_host = table_bytes / S
print(f"hosts S                : {S}")
print(f"rows per host          : {rows_per_host:,}")
print(f"table bytes per host   : {bytes_per_host / 1e9:.1f} GB")

# ---- A training step: lookups for a global batch -----------------------
GLOBAL_BATCH = 65536        # samples in the global batch
LOOKUPS_PER_SAMPLE = 40     # sparse feature ids fired per sample
total_lookups = GLOBAL_BATCH * LOOKUPS_PER_SAMPLE

# Zipfian popularity: a small head of rows takes most of the traffic.
ZIPF_S = 1.1
HEAD = 5000                 # candidate hot rows we might cache
norm = sum(1.0 / (r ** ZIPF_S) for r in range(1, HEAD + 1))


def sample_row():
    """Draw a row id from a Zipf head, else from the uniform cold tail."""
    if random.random() < 0.85:                 # 85% of lookups hit the head
        x = random.random() * norm
        acc = 0.0
        for r in range(1, HEAD + 1):
            acc += 1.0 / (r ** ZIPF_S)
            if acc >= x:
                return r - 1                    # 0-indexed hot row
    return random.randint(HEAD, ROWS - 1)       # cold tail row


lookups = [sample_row() for _ in range(20000)]  # sampled estimate of the batch

# ---- Cache hit rate as a function of HBM cache capacity ----------------
def hit_rate(cache_rows):
    cached = set(range(cache_rows))             # cache the most popular rows
    hits = sum(1 for r in lookups if r in cached)
    return hits / len(lookups)


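# Communication model: a miss must be fetched from a remote shard via
# all-to-all. Bytes moved per miss = DIM * BYTES_PER_ELEM (the looked-up row).
# Hits are served from local HBM and cost no network bytes.
BYTES_PER_ROW = DIM * BYTES_PER_ELEM
NET_GBPS = 100.0            # link bandwidth in gigabytes/sec (not gigabits)
# Simplification: the whole step's miss traffic is charged to one link of this
# bandwidth, ignoring both link parallelism across the S hosts and latency.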

print()
print("cache_rows |  HBM_MB  | hit_rate | cross-host bytes/step | step comm (ms)")
print("-" * 74)
for cache_rows in (0, 100, 1000, 5000):
    hr = hit_rate(cache_rows)
    hbm_mb = cache_rows * BYTES_PER_ROW / 1e6
    misses = total_lookups * (1.0 - hr)
    comm_bytes = misses * BYTES_PER_ROW
    comm_ms = comm_bytes / (NET_GBPS * 1e9) * 1e3
    print(f"{cache_rows:>10} | {hbm_mb:>7.2f} | {hr:>7.3f}  | {comm_bytes/1e6:>17.1f} MB | {comm_ms:>10.2f}")

# ---- Quantization of the cold tail -------------------------------------
print()
base = table_bytes / 1e12
for bits, label in ((8, "int8"), (4, "int4")):
    q = ROWS * DIM * (bits / 8) / 1e12
    print(f"cold-tail as {label:<4}: {q:.3f} TB  ({base / q:.1f}x smaller than fp32)")
Code 11.7.1: A pure-Python cost model of a one-terabyte embedding sharded across sixteen hosts. It draws lookups from a Zipfian popularity head over a uniform cold tail, with the head's 85 percent share of lookups fixed as an input, then sweeps the HBM cache capacity to show the hit rate rising and the cross-host all-to-all traffic falling, and closes by sizing the full table under int8 and int4 storage, which approximates quantizing the cold rows because they make up nearly all of the table.
embedding rows         : 2,000,000,000
embedding dim          : 128
full table size        : 1.024 TB (fp32)
hosts S                : 16
rows per host          : 125,000,000
table bytes per host   : 64.0 GB

cache_rows |  HBM_MB  | hit_rate | cross-host bytes/step | step comm (ms)
--------------------------------------------------------------------------
         0 |    0.00 |   0.000  |            1342.2 MB |      13.42
       100 |    0.05 |   0.576  |             568.8 MB |       5.69
      1000 |    0.51 |   0.751  |             333.7 MB |       3.34
      5000 |    2.56 |   0.850  |             200.9 MB |       2.01

cold-tail as int8: 0.256 TB  (4.0x smaller than fp32)
cold-tail as int4: 0.128 TB  (8.0x smaller than fp32)
Output 11.7.1: Caching only 5,000 hot rows, about 2.6 megabytes of HBM per host, lifts the hit rate to 0.85, the full head share the model assumes, and cuts the cross-host bytes per step from 1,342 megabytes to 201 megabytes, an 85 percent reduction in all-to-all traffic for a trivial memory cost. Quantizing the cold tail to int8 or int4 shrinks the stored table by 4 to 8 times on top of that.

The numbers in Output 11.7.1 tell the whole capacity story. A few thousand cached rows, costing single-digit megabytes, absorb most of the lookup traffic because of the Zipfian skew, and the leftover misses are what actually traverse the network. The table that must live somewhere is a terabyte; the table that must move per step, after caching, is a couple hundred megabytes. Memory capacity and network bytes, not arithmetic, are the two quantities a recommender's systems team spends its life minimizing.

Fun Note: The 0.05-Megabyte Win

Notice the first jump in Output 11.7.1: caching just 100 rows, fifty kilobytes, already serves 58 percent of lookups. The popularity head is so steep that the cache pays for itself before it is large enough to register on a memory budget. Most of the terabyte exists to handle the rare row that someone, somewhere, clicks once a week.

4. Compressing the Rows Themselves Advanced

Caching reduces traffic; quantization reduces the bytes stored and moved per row. Because a recommender rarely needs full fp32 precision in its embedding rows, the cold tail can be stored at lower precision with little measurable loss in recommendation quality. The last block of Output 11.7.1 makes the saving explicit: storing rows as int8 instead of fp32 shrinks the table fourfold, from 1.02 terabytes to 256 gigabytes, and int4 shrinks it eightfold to 128 gigabytes, often enough to fit the warm rows entirely in host RAM and shorten the SSD tail. Production systems refine uniform quantization with row-wise scales, and compress further with schemes that share parameters across many ids (the hashing trick and compositional embeddings) and with learned mixed-dimension tables that give rare rows narrower vectors than popular ones. Each technique trades a controlled amount of representational precision for a large reduction in the two scarce resources, capacity and bandwidth.
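Row-wise int8 quantization, the simplest member of this family, fits in a few lines. This sketch assumes a symmetric scheme with one fp32 scale per row and uses a Gaussian-initialized row as stand-in data; production kernels may differ, for example by storing an offset alongside the scale. The per-row scale is why the real saving comes in slightly under the fourfold figure above.

```python
import random

def quantize_row_int8(row):
    """Symmetric row-wise int8: one fp32 scale per row, values in [-127, 127]."""
    max_abs = max(abs(x) for x in row)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(x / scale))) for x in row]
    return q, scale

def dequantize_row_int8(q, scale):
    """Recover approximate fp32 values from int8 codes and the row scale."""
    return [v * scale for v in q]

D = 128
random.seed(1)
row = [random.gauss(0.0, 0.05) for _ in range(D)]   # hypothetical embedding row
q, scale = quantize_row_int8(row)
restored = dequantize_row_int8(q, scale)

max_err = max(abs(a - b) for a, b in zip(row, restored))
fp32_bytes = D * 4
int8_bytes = D * 1 + 4           # int8 payload plus one fp32 scale per row
print(f"bytes/row: fp32 {fp32_bytes}, int8+scale {int8_bytes} "
      f"({fp32_bytes / int8_bytes:.2f}x smaller)")        # 512 vs 132, 3.88x
print(f"max abs error {max_err:.2e} <= scale/2 = {scale / 2:.2e}")
```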

Practical Example: The Recommender That Stopped Buying Bigger Boxes

Who: A systems engineer on the ranking-infrastructure team at a large video-sharing platform.

Situation: The main retrieval model's embedding tables had grown to 6 terabytes as new cross-features were added, and training throughput was falling each quarter.

Problem: The team had been sharding the table across more and more high-memory hosts, but cross-host all-to-all traffic now dominated step time and the host count was getting expensive.

Dilemma: Keep adding hosts to hold the raw fp32 table, which spread capacity thinner but did nothing for the communication bottleneck, or invest in an HBM hot-row cache plus int8 quantization of the cold tail, which required reworking the lookup path and accepting a small precision risk.

Decision: They added the cache and quantization, because the step was communication-bound, and Output 11.7.1 shows caching attacks exactly that bottleneck while quantization shrinks what remains.

How: They cached the top few thousand rows per shard in HBM, stored cold rows int8 with per-row scales, and left the dense network untouched in fp32, then measured offline ranking metrics to confirm no quality regression.

Result: Cross-host bytes per step fell by roughly 80 percent, the table footprint dropped fourfold so they could remove hosts rather than add them, and step throughput recovered with no measurable loss in recommendation quality.

Lesson: When a workload is memory- and communication-bound, the leverage is in caching the hot head and compressing the cold tail, not in renting more or bigger machines to hold the raw table.

5. The Cost Structure in One Picture Intermediate

Pulling the threads together, a terabyte-scale embedding system is governed by two budgets, neither of them arithmetic. The first is memory capacity: the table of $M_\text{table} = R \cdot D \cdot b$ bytes must be held somewhere across the $S$ hosts and their memory hierarchy, and quantization lowers $b$ while compositional tricks lower the effective $R$. The second is communication: each step moves $M_\text{comm} = (1 - h)\, B \cdot L \cdot D \cdot b$ bytes through the all-to-all, and the hot-row cache raises $h$ while quantization again lowers $b$. Every engineering lever in this section, sharding, tiering, caching, quantizing, and parameter sharing, pushes on one or both of these two terms. The dense network's flops, which would dominate a language model's cost, are a footnote here. This is the signature of the recommendation workload, and it is why Chapter 38 can build an entire industrial case study around getting these two budgets right at scale.
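To make "compositional tricks lower the effective $R$" concrete, here is a sketch of the quotient-remainder scheme, one member of the compositional-embedding family: id $i$ reads row $\lfloor i / M \rfloor$ of a quotient table and row $i \bmod M$ of a remainder table and combines the two. The bucket count $M$, the elementwise-product combiner, and the random initial values are assumptions of this sketch, not a prescription.

```python
import math
import random

R = 2_000_000_000        # ids in the original vocabulary (from the text)
M = 50_000               # hypothetical number of remainder buckets
D = 4                    # tiny width for illustration
Q = math.ceil(R / M)     # quotient buckets needed so every id gets a (q, r) pair

random.seed(0)
quotient_table = [[random.uniform(-0.1, 0.1) for _ in range(D)] for _ in range(Q)]
remainder_table = [[random.uniform(-0.1, 0.1) for _ in range(D)] for _ in range(M)]

def qr_embedding(row_id: int) -> list:
    """Combine one quotient row and one remainder row by elementwise product."""
    q, r = divmod(row_id, M)          # distinct ids map to distinct (q, r) pairs
    return [a * b for a, b in zip(quotient_table[q], remainder_table[r])]

stored_rows = Q + M
print(f"stored rows: {stored_rows:,} instead of {R:,} ({R / stored_rows:,.0f}x fewer)")
# stored rows: 90,000 instead of 2,000,000,000 (22,222x fewer)
print(qr_embedding(1_234_567_890))
```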

Library Shortcut: TorchRec Shards the Table for You

In Code 11.7.1 we modeled sharding, caching, and the cross-host fetch by hand to expose the cost structure. In production you do not implement the sharded table, the all-to-all, or the hot-row cache yourself; PyTorch's TorchRec library provides them as configured components. You declare the embedding tables and a sharding plan, and TorchRec places shards across hosts, inserts the all-to-all, and manages the lookup, collapsing hundreds of lines of distributed table management into a short declaration:

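# pip install torchrec ; run under torchrun across the host group
import os
import torch
import torch.distributed as dist
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
from torchrec.distributed import DistributedModelParallel

# torchrun sets LOCAL_RANK; bind this process to its GPU and join the group.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group(backend="nccl")

# Build on the "meta" device so the 1 TB table is never materialized on one
# host; DistributedModelParallel allocates only each rank's shard.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(name="item_id", embedding_dim=128,
                           num_embeddings=2_000_000_000, feature_names=["item"]),
    ],
    device=torch.device("meta"),
)

# DistributedModelParallel chooses a sharding plan (row-wise / table-wise),
# places shards across hosts, and wires the all-to-all automatically.
model = DistributedModelParallel(module=ebc, device=torch.device("cuda"))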
Code 11.7.2: The same sharded, all-to-all-connected embedding modeled in Code 11.7.1, now declared in a handful of lines with TorchRec. The library selects the sharding plan, places the two-billion-row table across the host group, and inserts the all-to-all and the embedding lookup that the cost model accounted for by hand.
Research Frontier: Terabyte Recommenders (2024 to 2026)

The terabyte-embedding problem is an active systems frontier. Meta's open DLRM lineage and its Neo and software-hardware co-design work target trillion-parameter recommendation tables with hierarchical HBM, host-memory, and SSD tiering, and ZionEX-style hardware moves the embedding all-to-all onto dedicated interconnects. TorchRec has become the reference framework for sharded embeddings, adding row-wise and table-wise sharding plans and quantized inference paths through 2024 to 2026. A parallel research line attacks the table size directly: compositional and hash-based embeddings, learned mixed-dimension tables that give rare ids narrower vectors, and low-precision (int8 and int4) training of the cold tail, all reported at RecSys and in industry engineering notes as the way to keep table growth sublinear in the vocabulary. The shared message of this work matches the cost structure of Section 5: progress comes from shrinking memory capacity and all-to-all bytes, not from adding flops.

We have now seen the extreme of the parameter-server story: a model so dominated by its sparse embedding that it must be sharded across hosts, tiered through a memory hierarchy, cached, and quantized, all to keep two budgets, capacity and communication, under control. One question remains before the chapter closes. A table spread over many hosts is a table where, at any moment, some host has probably failed; losing a shard cannot be allowed to lose the model. How a sharded parameter server survives the failure of the machines that hold it is the subject of Section 11.8.

Exercise 11.7.1: Which Budget Binds? Conceptual

For each change, state whether it primarily relieves the memory-capacity budget $M_\text{table}$, the communication budget $M_\text{comm}$, or both, and explain why using the two formulas in Section 5: (a) quantizing the cold-tail rows from fp32 to int4; (b) adding more rows to the HBM hot-row cache; (c) doubling the number of shard hosts $S$ while keeping the table size fixed; (d) replacing two billion independent ids with a compositional embedding that hashes them into ten million shared buckets. Which of these does nothing for the per-step all-to-all time?

Exercise 11.7.2: Tune the Cache Coding

Modify Code 11.7.1 so the cache is no longer the exact top-$k$ rows but an approximate, finite cache updated online: maintain a set of cached row ids, and on each lookup, if the row is absent, insert it and evict the least-recently-used entry when the cache is full. Measure the achieved hit rate against the idealized top-$k$ hit rate for the same capacity, and repeat the comparison as you vary the Zipf exponent $\texttt{ZIPF\_S}$. By how much does a realistic LRU cache fall short of the oracle, and how does that gap change as the popularity distribution flattens?

Exercise 11.7.3: All-to-All vs Replication Analysis

Using $M_\text{comm} = (1 - h)\, B \cdot L \cdot D \cdot b$ from Section 5 with the parameters of Code 11.7.1, compute the per-step cross-host bytes at hit rates $h = 0.5$, $0.85$, and $0.99$. Now consider an alternative design that fully replicates the 256-gigabyte int8 table of Output 11.7.1 on every dense worker, eliminating the all-to-all entirely but paying a one-time broadcast and per-step update synchronization. Argue from the byte counts under what conditions (table size, hit rate, number of workers, update frequency) replication beats the sharded all-to-all, and why the terabyte regime rules replication out. Relate your answer to the parameter-server-versus-all-reduce trade developed across this chapter.