"I am an embedding row, one of two billion, sitting on shard seventeen. Most days nobody looks me up. Then a celebrity follows me and forty thousand requests arrive in one second, all for me, all at once."
A Hot Key That Did Not Ask For Fame
In a web-scale recommender the model is not mostly the neural network on top; it is the embedding tables underneath, billions of learned vectors that together weigh terabytes and therefore cannot live on one machine. The dense layers that combine those vectors are tiny by comparison and replicate everywhere. The embeddings are the model, and because they do not fit in any single host's memory they are sharded across many parameter servers by construction. This is "distribute the model" in its purest, most data-driven form: not a clever decomposition we chose, but a partition forced on us by sheer parameter count. This section takes the parameter-server machinery built in Chapter 11 and applies it to the concrete object that motivated it, showing why the tables grow so large, how a sparse request touches only a handful of rows scattered across shards, and how popular items turn a balanced partition into a skewed one.
The previous section framed the recommendation problem and the production constraints a planet-scale feed lives under. We now open the model itself and find that almost all of its parameters sit in two lookup tables: one row per user, one row per item. A click model, a ranking network, a two-tower retriever, whatever sits on top, consumes dense vectors fetched from these tables and contributes only a few million weights of its own. The tables, by contrast, contain billions of rows. Understanding how those rows are stored, fetched, and trained across many machines is the heart of distributed recommendation, and it is a direct, concrete instance of the parameter-server abstraction of Chapter 11. We build on that abstraction here rather than rederive it.
1. The Embedding Tables Are the Model Beginner
A recommender represents each categorical entity (a user, an item, a hashtag, a device type) as a learned dense vector. The collection of those vectors for one feature is an embedding table: a matrix with one row per distinct ID and one column per latent dimension. A user table for a billion accounts with $d = 128$ latent dimensions is a matrix of one billion rows and 128 columns. The total parameter count of a table with $N$ rows is $N \cdot d$, and its memory footprint, at $b$ bytes per element, is
$$\text{table bytes} = N \cdot d \cdot b.$$
The numbers are not subtle. With $N = 2 \times 10^{9}$ rows, $d = 128$, and single-precision storage ($b = 4$), one table occupies $2 \times 10^{9} \cdot 128 \cdot 4 \approx 1.0 \times 10^{12}$ bytes, about a terabyte, before we add an item table, optimizer state, or a second feature. Industrial systems carry dozens of such tables and routinely reach tens of terabytes of embeddings. No accelerator holds that: a single accelerator's memory is roughly two orders of magnitude too small, and even a large host's DRAM falls well short. The model is sharded not as an optimization but because it physically cannot be otherwise, which is exactly the regime Chapter 16 calls model-parallel, here arriving for the simplest possible reason: the parameters do not fit.
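The footprint formula is easy to script, and doing so makes the role of $b$ concrete. A minimal sketch; the helper name `table_bytes` is ours rather than any library's, and it uses decimal units (1 TB $= 10^{12}$ bytes):

```python
def table_bytes(n_rows: int, dim: int, bytes_per_elem: int) -> int:
    """Memory footprint of one embedding table: N * d * b bytes."""
    return n_rows * dim * bytes_per_elem

# The example from the text: 2e9 rows, d = 128, single precision (b = 4).
fp32 = table_bytes(2_000_000_000, 128, 4)
print(f"fp32: {fp32:.3e} bytes = {fp32 / 1e12:.3f} TB")

# Same table at lower precision: only b changes, so the footprint scales with it.
for name, b in [("fp16", 2), ("int8", 1)]:
    print(f"{name}: {table_bytes(2_000_000_000, 128, b) / 1e12:.3f} TB")
```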
A recommender inverts the usual deep-learning intuition. In a language model nearly all parameters are dense matrices multiplied on every token. In a recommender the dense network is megabytes and replicates onto every worker, while the embedding tables are terabytes and must be partitioned. Each forward pass reads only the few rows the current request mentions, so the tables are enormous but accessed sparsely. That asymmetry, a vast sparsely-read parameter store plus a small densely-read network, is the defining shape of the model and dictates the entire distribution strategy: shard the sparse part across many hosts, replicate the dense part everywhere.
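The asymmetry can be checked with a parameter count. This is a sketch under hypothetical sizes: the two tables use the row counts from the TorchRec example later in this section, and the dense tower (concatenated user and item vectors feeding layers of width 1024, 512, and 1) is made up for illustration.

```python
# Hypothetical model: the two tables from the text plus a small MLP tower.
d = 128
user_rows, item_rows = 2_000_000_000, 500_000_000
embedding_params = (user_rows + item_rows) * d

# Made-up dense tower: concatenated user+item vectors (2*d) -> 1024 -> 512 -> 1.
layer_sizes = [2 * d, 1024, 512, 1]
dense_params = sum(n_in * n_out + n_out          # weights + biases per layer
                   for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))

print(f"embedding parameters: {embedding_params:.2e}")
print(f"dense parameters:     {dense_params:,} (~{dense_params * 4 / 1e6:.1f} MB in fp32)")
share = embedding_params / (embedding_params + dense_params)
print(f"embedding share of all parameters: {share:.5%}")

# A ranking request reads 1 user row plus a few hundred item rows.
rows_read = 1 + 512
print(f"fraction of rows read per request: {rows_read / (user_rows + item_rows):.1e}")
```

Under these assumptions the dense tower is a few megabytes and can be copied to every worker, while the tables hold essentially all the parameters and each request reads a vanishing fraction of their rows.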
2. Sharding the Table by ID Hash Intermediate
To spread a table across $S$ parameter servers we assign each row to a shard. The standard choice is hash partitioning of the ID: shard $s(\text{id}) = h(\text{id}) \bmod S$ for some hash $h$. Hashing, rather than range partitioning by ID, matters because IDs are not random: under range partitioning, sequential or time-correlated IDs would pile recent, active accounts onto one shard. A good hash scatters them, so each shard holds roughly $N / S$ rows and the storage cost per shard is
$$\text{shard bytes} \approx \frac{N \cdot d \cdot b}{S}.$$
With the terabyte table above split across $S = 64$ servers, each holds about 16 gigabytes, comfortably in host memory. This is the same push-pull sharded parameter store of Chapter 11, specialized to the case where the parameters are an addressable table and the "key" is an integer ID. Figure 38.2.1 shows the layout: one logical table, physically scattered as disjoint row sets across servers, with a single request reaching into several of them at once.
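Both claims, the 16 GB per shard and the danger of range partitioning, can be checked in a few lines. This is a toy sketch under a stated assumption: we pretend the most recently created 1% of IDs are the active ones. For $h$ we use Fibonacci (multiplicative) hashing, which keeps the low 32 bits of $\text{id} \cdot A$ and takes their top bits as the shard number; it is one common choice of hash, not necessarily what any production system uses.

```python
import numpy as np

# Shard bytes for the text's example: the 1.024 TB table over S = 64 servers.
N, d, b, S = 2_000_000_000, 128, 4, 64
print(f"bytes per shard: {N * d * b / S / 1e9:.1f} GB")

# Toy workload: assume (hypothetically) the newest 1% of ids are the active ones.
n_ids, log2_S = 1_000_000, 6                 # 2**6 = 64 shards
active = np.arange(n_ids - n_ids // 100, n_ids)

# Range partitioning: each shard owns one contiguous block of ids.
range_shard = active * (1 << log2_S) // n_ids

# Fibonacci hashing: low 32 bits of id * A, then their top log2_S bits.
A = 2654435761                               # ~2**32 / golden ratio
hash_shard = ((active * A) & 0xFFFFFFFF) >> (32 - log2_S)

def imbalance(shard_ids, num_shards):
    """Busiest shard's count divided by the mean count per shard."""
    load = np.bincount(shard_ids, minlength=num_shards)
    return load.max() / load.mean()

print(f"range partitioning: {imbalance(range_shard, 1 << log2_S):.2f}x")
print(f"hash partitioning:  {imbalance(hash_shard, 1 << log2_S):.2f}x")
```

Range partitioning sends every active ID to the last shard, while the hash spreads the same contiguous block almost evenly.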
3. Sparse Lookups and the All-Gather Intermediate
A single ranking request scores a few hundred candidate items for one user. It therefore touches one user row and a few hundred item rows, a vanishing fraction of the billions in the tables. This is the sparse-access pattern that makes the whole scheme viable: although the table is terabytes, the bytes moved per request are tiny. For a batch of $r$ distinct requested IDs the lookup transfers
$$\text{lookup bytes} = r \cdot d \cdot b,$$
independent of $N$. With $r = 512$ rows, $d = 128$, $b = 4$, that is $262{,}144$ bytes, roughly 260 kilobytes per request, regardless of whether the table holds a million rows or ten billion. The cost is in the round trips, not the volume. Because the requested IDs scatter across shards, the worker sends each shard the subset of IDs it owns, every shard reads its local rows in parallel, and the rows are collected back in request order. That collection step follows the all-gather pattern from Chapter 4, here gathering data rows rather than gradient slices; when many workers issue different ID sets at once, the exchange generalizes to an all-to-all. It belongs to the same family of collectives, alongside reduce-scatter, that drives sharded training in Chapter 16. The demo in Section 5 performs exactly this routed gather and checks it against a single-table reference.
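The routing step (group IDs by owner, send one batched request per shard, restore request order) can be sketched with a stable sort. This is an illustration under stated assumptions: `id % S` stands in for $h(\text{id}) \bmod S$, and the hypothetical `fetch` callback stands in for a request to one parameter server; here it simply indexes a local array.

```python
import numpy as np

def route_and_gather(ids, S, fetch):
    """Group ids by owning shard, issue one batched fetch per shard,
    then put the returned rows back in the original request order."""
    owner = ids % S                                    # stand-in for h(id) mod S
    order = np.argsort(owner, kind="stable")           # request positions grouped by shard
    bounds = np.searchsorted(owner[order], np.arange(S + 1))
    parts = [fetch(s, ids[order[bounds[s]:bounds[s + 1]]]) for s in range(S)]
    out = np.empty((ids.size, parts[0].shape[1]), dtype=parts[0].dtype)
    out[order] = np.concatenate(parts)                 # undo the grouping
    return out

d, b, S = 128, 4, 8
rng = np.random.default_rng(0)
table = rng.standard_normal((100_000, d)).astype(np.float32)
ids = rng.integers(0, table.shape[0], size=512)
# Hypothetical fetch: indexes the full table; a real system would ask shard s.
rows = route_and_gather(ids, S, lambda s, group: table[group])
assert np.array_equal(rows, table[ids])
r = np.unique(ids).size
print(f"{r} distinct rows -> {r * d * b:,} bytes moved, whatever the table size")
```

The bytes moved depend only on the number of distinct rows requested, which is why the table can grow to billions of rows without making a single lookup heavier.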
Writing the hash router, the per-shard storage, and the gather by hand is instructive, but you would never ship it. PyTorch's TorchRec library represents the whole collection of tables as one EmbeddingBagCollection; a planner then shards it across ranks, and the collective communication is inserted automatically. The sketch below assumes a recent TorchRec release, one GPU per rank, and a launch under torchrun:
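# pip install torchrec ; launch with torchrun --nproc_per_node=S
import os
import torch
import torch.distributed as dist
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
from torchrec.distributed import DistributedModelParallel
dist.init_process_group(backend="nccl")  # torchrun supplies rank and world size
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)
# Illustrative sizes: the ranks together need enough memory to hold the shards.
tables = [EmbeddingBagConfig(name="user", num_embeddings=2_000_000_000,
                             embedding_dim=128, feature_names=["user_id"]),
          EmbeddingBagConfig(name="item", num_embeddings=500_000_000,
                             embedding_dim=128, feature_names=["item_id"])]
# "meta" device: describe the tables without allocating a terabyte locally.
ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta"))
# The planner picks a sharding (row-wise / table-wise / column-wise, among others)
# per table, and DistributedModelParallel wires up the collectives for lookups.
model = DistributedModelParallel(module=ebc, device=device)  # sharded across all ranks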
DistributedModelParallel inserts the all-to-all routing and gather that Section 3 describes, so the from-scratch router, storage map, and gather of the Section 5 demo collapse into this single wrapper.
4. The Training-Time Challenge: Hot Keys and Sparse Gradients Advanced
Lookups are read-only and balance well in expectation. Training is harder, for two coupled reasons. First, the gradients are extremely sparse: a single example updates only the rows it touched, so the backward pass produces a tiny set of row-indexed updates that must be pushed back to the owning shards, exactly the push step of the parameter server in Chapter 11 and the sparse SGD of Chapter 10. Sparse gradients are cheap to compute and expensive to coordinate, because thousands of workers may all want to update the same popular row in the same step.
That is the second reason: skew. Item popularity follows a heavy-tailed (roughly Zipfian) law, so a few items appear in a large fraction of all requests. If the most popular ID lands on shard $s$, that shard receives far more lookups and far more gradient pushes than its peers. Quantify the imbalance with a skew factor: if shard $s$ serves load $L_s$ and the mean load is $\bar{L} = \frac{1}{S}\sum_s L_s$, then
$$\text{skew} = \frac{\max_s L_s}{\bar{L}} \ge 1,$$
where $1$ is perfect balance and larger values mean one shard is the bottleneck for the whole step. Storage stays balanced because every shard holds the same number of rows; it is the access load, driven by popularity rather than row count, that skews. Remedies include replicating the hottest rows onto several shards so reads spread out, hash-bucketing colliding cold IDs together, and pre-aggregating gradients for a hot row on each worker before the push so the owning shard receives one combined update instead of thousands. The skew this section measures in Code 38.2.2 is the recommendation-system face of the straggler problem traced from Chapter 10 onward.
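The sparse push and its per-worker pre-aggregation can be sketched in NumPy. A minimal sketch under toy assumptions: random upstream gradients stand in for a real backward pass, and plain SGD stands in for whatever optimizer the system actually uses. Repeated IDs in one worker's batch are summed into a single row-indexed update before the push, which is the pre-aggregation described above; `np.add.at` is used because it is unbuffered, so repeated indices accumulate instead of overwriting each other.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, lr = 1_000, 8, 0.1
table = rng.standard_normal((N, d)).astype(np.float32)

# One worker's looked-up ids (id 3 is "hot") and a made-up upstream gradient
# dL/d(row) for each occurrence, as a backward pass would produce it.
batch_ids = np.array([3, 7, 3, 42, 3, 7])
grad_per_occurrence = rng.standard_normal((batch_ids.size, d)).astype(np.float32)

# Pre-aggregate: one summed gradient per distinct row touched by this worker.
touched, inverse = np.unique(batch_ids, return_inverse=True)
row_grads = np.zeros((touched.size, d), dtype=np.float32)
np.add.at(row_grads, inverse, grad_per_occurrence)

# Sparse SGD push: only the touched rows change on the owning shard(s).
before = table.copy()
table[touched] -= lr * row_grads
print(f"{batch_ids.size} occurrences -> {touched.size} pushed rows out of {N}")
```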
Who: A platform engineer on the embedding-infrastructure team of a large video service.
Situation: Item embeddings were hash-sharded across 96 parameter servers, balanced by row count, and training throughput had been stable for months.
Problem: A single trailer went viral overnight; its item ID was suddenly in a third of all training batches, and the one shard owning that row saturated its network link while the other 95 sat idle.
Dilemma: Reshard the whole table to move the hot row, an expensive global operation, or leave balanced storage alone and attack the access skew directly.
Decision: They left storage untouched and added per-worker gradient pre-aggregation for hot rows plus read replication of the top few thousand IDs, so each worker pushed one combined update per hot row instead of one per occurrence.
How: A lightweight popularity counter flagged hot IDs each epoch; flagged rows were replicated to three shards for reads and their gradients were locally summed before the push.
Result: The skew factor on the hot shard fell from about $11\times$ to under $1.6\times$, and step time returned to its pre-launch level without resharding a terabyte of storage.
Lesson: Balanced storage does not imply balanced load. Popularity skew is an access problem, and it is cheaper to spread the reads and pre-aggregate the writes than to move the rows.
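The effect of pre-aggregation on the skew factor can be reproduced in miniature. This toy sketch assumes Zipf-like IDs with exponent 1.3 (as in the Section 5 demo), uses `id % S` in place of a real hash, and picks hypothetical sizes for workers and batches; it counts gradient pushes per shard when each occurrence is pushed and when each worker pushes one combined update per distinct ID.

```python
import numpy as np

def skew(load):
    """max_s L_s / mean_s L_s; 1.0 means perfectly balanced."""
    load = np.asarray(load, dtype=np.float64)
    return load.max() / load.mean()

rng = np.random.default_rng(0)
S, W, per_worker, N = 8, 16, 4_096, 100_000   # shards, workers, ids per worker, rows
# Assumed Zipf-like popularity; values past N wrap around instead of piling up.
batches = [(rng.zipf(1.3, size=per_worker) - 1) % N for _ in range(W)]

raw = np.zeros(S, dtype=np.int64)   # one push per occurrence
agg = np.zeros(S, dtype=np.int64)   # one push per distinct id per worker
for ids in batches:
    raw += np.bincount(ids % S, minlength=S)
    agg += np.bincount(np.unique(ids) % S, minlength=S)

print("pushes per shard, raw       :", raw.tolist(), f"skew {skew(raw):.2f}x")
print("pushes per shard, aggregated:", agg.tolist(), f"skew {skew(agg):.2f}x")
```

Aggregation collapses every hot ID to one push per worker, so both the total number of pushes and the skew toward the shard that owns the most popular ID drop.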
5. A Sharded Embedding Table You Can Run Intermediate
The code below builds a small embedding table, shards it by a multiplicative hash of the ID across $S$ shards, draws a heavy-tailed batch of requested IDs so a few are hot, and performs the routed sparse gather: each ID goes to its shard, the local row is read, and the rows are reassembled in request order. It reports the per-shard storage (balanced), the per-shard lookup load (skewed by popularity), the resulting skew factor, and the hottest IDs, then verifies that the gathered rows are bit-for-bit identical to a single-table reference. The verification is the point: sharding changes where the rows live, not which numbers come back.
import numpy as np
# A toy embedding table: N rows (one per user OR item ID), each of dimension d.
# In production N is billions and the table is terabytes; here it is small so we
# can hold a single-table REFERENCE and check the sharded version against it.
N, d, S = 20_000, 16, 4 # rows, embedding dim, parameter-server shards
rng = np.random.default_rng(7)
table = rng.standard_normal((N, d)).astype(np.float32) # the "true" full table
# Shard the table by hashing the ID. Row id -> shard (id * A) mod S.
# Each shard physically stores only the rows assigned to it.
A = 2654435761  # Knuth's multiplicative constant
def shard_of(idx):  # multiplicative hash, then mod S
    # Caveat: A % 4 == 1, so with S = 4 this is exactly idx % S. A true
    # multiplicative (Fibonacci) hash would keep the HIGH bits of idx * A.
    return (idx * A) % S
owner = shard_of(np.arange(N))
shards = [table[owner == s] for s in range(S)]  # disjoint partition
local_index = -np.ones(N, dtype=np.int64)  # global id -> local row
for s in range(S):
    ids_here = np.nonzero(owner == s)[0]
    local_index[ids_here] = np.arange(len(ids_here))
# A request batch touches only a few rows (sparse lookup). We build a skewed
# (Zipf-like) workload so a handful of "hot" IDs dominate, as popular items do.
B_req = 50_000
zipf = rng.zipf(1.3, size=B_req) - 1
batch_ids = np.minimum(zipf, N - 1).astype(np.int64)
# Sharded gather: route each requested id to its shard, look up the local row,
# then all-gather the rows back in request order.
req_owner = owner[batch_ids]
gathered = np.empty((B_req, d), dtype=np.float32)
per_shard_load = np.zeros(S, dtype=np.int64)
for s in range(S):
    mask = req_owner == s
    per_shard_load[s] = mask.sum()
    gathered[mask] = shards[s][local_index[batch_ids[mask]]]
# Single-table reference: gather the same rows straight from the full table.
reference = table[batch_ids]
print("rows N :", N)
print("shards S :", S)
print("requests in batch :", B_req)
print("rows per shard (store):", [len(x) for x in shards])
print("lookups per shard :", per_shard_load.tolist())
print("skew (max/mean load) :", f"{per_shard_load.max() / per_shard_load.mean():.2f}x")
uniq, counts = np.unique(batch_ids, return_counts=True)
order = np.argsort(counts)[-3:][::-1]
print("hottest 3 IDs :", uniq[order].tolist(), "-> hits", counts[order].tolist())
print("gathered == reference :", np.array_equal(gathered, reference))
print("max abs difference :", f"{np.max(np.abs(gathered - reference)):.2e}")
rows N : 20000
shards S : 4
requests in batch : 50000
rows per shard (store): [5000, 5000, 5000, 5000]
lookups per shard : [19843, 11367, 8924, 9866]
skew (max/mean load) : 1.59x
hottest 3 IDs : [0, 1, 2] -> hits [12911, 5023, 3113]
gathered == reference : True
max abs difference : 0.00e+00
Two facts from the output deserve emphasis. The storage line is flat, $5000$ rows on every shard, because the hash spreads row count evenly; the load line is not, because popularity, not row count, drives access. That gap between balanced storage and skewed load is the entire training-time problem of Section 4 made visible in nine lines of output. And the final comparison is exactly zero: distributing the table across four shards did not perturb a single returned value, just as data parallelism in Section 1.1 distributed the gradient without perturbing it. Distribution here is a relocation of where bytes live, not an approximation of what they are.
6. Memory-Tier Tricks: Fitting More Table in Less Space Advanced
Even sharded across many hosts, the tables strain memory budgets, so several tricks shrink them. Caching exploits the same skew that hurts training: the hot rows that overload one shard are also the rows most worth keeping in fast memory, so a small high-bandwidth cache of popular rows in front of a larger store on slower memory or even on the storage tier of Chapter 8 serves most requests from cache. Mixed precision stores embeddings in 16-bit or 8-bit formats, cutting $b$ in half or more in the $N \cdot d \cdot b$ budget at a usually negligible accuracy cost. The most distinctive trick is the hashing trick, in particular the quotient-remainder method, which represents each ID's vector as a combination of two much smaller tables indexed by $\text{id} \bmod m$ and $\lfloor \text{id} / m \rfloor$. Two tables of size $m$ and $N/m$ cover $N$ IDs with
$$\big(m + \tfrac{N}{m}\big)\, d \, b \ \ll\ N \, d \, b,$$
so a billion-row table compresses into the sum of two tables of a few tens of thousands of rows each, trading partial collisions (each remainder row is shared by about $N/m$ IDs and each quotient row by $m$ IDs, although no two IDs share both) for a large memory saving. These compressions are orthogonal to sharding; you shard the compressed tables. Together with the cache-the-hot-rows idea, they let a fixed cluster hold a table that nominally would not fit, which is why production embedding stacks combine all three.
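Two of these tricks, the hot-row cache and the quotient-remainder table, fit in one short sketch. Assumptions, all ours: Zipf-like popularity with exponent 1.3, toy sizes, a static cache filled from a warm-up sample, and an arbitrary $m$ that is not tuned. `zipf_ids` and `qr_embed` are hypothetical helpers, and `qr_embed` combines the two component rows by element-wise product; summing them is the other common choice.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 1_000_000, 16

# (1) Hot-row cache: keep the k most frequent ids of a warm-up sample in fast
# memory, then measure the fraction of fresh requests it would serve.
def zipf_ids(n):
    return (rng.zipf(1.3, size=n) - 1) % N      # assumed Zipf-like popularity
warm, counts = np.unique(zipf_ids(200_000), return_counts=True)
k = N // 1000                                   # cache 0.1% of the rows
cache = warm[np.argsort(counts)[-k:]]
hit_rate = np.isin(zipf_ids(200_000), cache).mean()
print(f"cache of {k:,} rows serves {hit_rate:.1%} of requests")

# (2) Quotient-remainder embedding: two small tables instead of one with N rows.
m = 500                                         # arbitrary choice for this sketch
n_quot = -(-N // m)                             # ceil(N / m) quotient rows
rem_table = rng.standard_normal((m, d)).astype(np.float32)
quot_table = rng.standard_normal((n_quot, d)).astype(np.float32)

def qr_embed(ids):
    # Row (id mod m) of one table times row (id // m) of the other.
    return rem_table[ids % m] * quot_table[ids // m]

qr_rows = m + n_quot
print(f"QR stores {qr_rows:,} rows instead of {N:,} ({N / qr_rows:.0f}x fewer)")

ids = np.arange(N)
pairs = (ids % m) * n_quot + ids // m           # unique code per (rem, quot) pair
assert np.unique(pairs).size == N               # no two ids share both components
print(qr_embed(np.array([0, 1, 999_999])).shape)
```

Either trick composes with hash sharding: the cache sits in front of the sharded store, and the two QR tables are themselves sharded like any other table.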
Of all the "distribute the model" stories in this book, embedding tables are the purest. There is no algorithmic cleverness in the decomposition and no choice to make about whether to shard; the parameter count, set by the number of users and items in the world, simply exceeds any one machine's memory, so the model is sharded the moment it exists. The collectives that move gradient slices in data-parallel training (Section 1.1) here move embedding rows, the parameter server of Chapter 11 becomes the literal storage substrate, and the sharded-parallel patterns of Chapter 16 reappear as table sharding. The recommender is the case where scale-out is not a strategy laid over the model but the only form in which the model can exist.
Because embeddings dominate both memory and cost, shrinking them is an active frontier. Learned and adaptive table sizing assigns more dimensions to frequent IDs and fewer to rare ones, in the lineage of mixed-dimension and frequency-aware embeddings, so capacity follows popularity rather than a uniform $d$. Compositional and quotient-remainder embeddings continue to be pushed toward higher compression with less collision cost, and hierarchical placement spreads a single logical table across GPU memory, host memory, and SSD with software-managed paging, as in Meta's TorchRec and NVIDIA's HugeCTR and Merlin stacks. A parallel thread treats the hot-key problem as a systems co-design question, combining popularity-aware caching, dynamic resharding, and gradient pre-aggregation into the embedding server itself. The throughline is that the embedding table, long treated as a passive lookup, is becoming a first-class distributed data structure engineered jointly for memory, skew, and accuracy.
The most expensive object in a billion-dollar recommender is, structurally, a dictionary: integer keys, vector values. Decades of database research went into B-trees, and the embedding store quietly settled on hash-and-shard, the same move a first-year student makes on a whiteboard. Scale did not make the data structure fancier; it made the unglamorous one mandatory.
7. Where the Rows Go Next Beginner
We now have the model's bulk accounted for: terabyte-scale embedding tables, sharded by ID hash across parameter servers, read by sparse all-gathers and updated by sparse pushes, with hot keys skewing the load and memory-tier tricks holding the footprint down. What the tables produce, the dense vectors gathered per request, does not sit idle; it feeds retrieval. The user vector becomes a query, the item vectors become the searchable corpus, and finding the best few hundred candidates among hundreds of millions is a distributed nearest-neighbor problem in its own right. That is the subject of Section 38.3, which takes the embeddings built here and asks how to search them at scale, connecting back to the distributed retrieval machinery of Chapter 25.
A service has $3 \times 10^{9}$ users and $8 \times 10^{8}$ items, each embedded in $d = 96$ dimensions. (a) Using $\text{table bytes} = N \cdot d \cdot b$, compute the memory for both tables in single precision ($b = 4$) and again in 8-bit ($b = 1$). (b) If each parameter server can dedicate 24 gigabytes to embeddings, how many shards $S$ does the single-precision case need from the $\text{shard bytes} \approx N d b / S$ relation? (c) The Adam optimizer keeps two extra state vectors per parameter. Recompute the single-precision footprint including optimizer state and explain why training memory, not serving memory, usually sets the shard count.
Extend Code 38.2.2. (a) Replicate the hottest $k$ IDs onto all $S$ shards (store their rows everywhere) and, for replicated IDs, route each lookup to a uniformly random shard instead of the hash-chosen one; recompute the skew factor as a function of $k$ and plot or print it. (b) Confirm that gathered still equals reference under replication, since every replica holds the same row. (c) Explain why read replication helps lookup skew but, on its own, does nothing for the gradient-push skew of Section 4, and what additional step closes that gap.
Consider compressing an $N = 10^{9}$ row table with the quotient-remainder trick into two tables of sizes $m$ and $N/m$, where the embedding for ID $i$ is the sum (or product) of row $i \bmod m$ of the first table and row $\lfloor i/m \rfloor$ of the second. (a) Minimize the combined size $(m + N/m)\,d\,b$ over $m$ and report the optimal $m$ and the compression ratio against the full table. (b) Two distinct IDs collide on a shared component when they agree in one coordinate; estimate the probability that two random IDs share their remainder coordinate. (c) Argue qualitatively when this collision cost is acceptable (hint: relate it to which IDs are frequent versus rare) and why the trick composes cleanly with the hash sharding of Section 2.