"I hold a vector for every user, every product, and every word anyone ever typed. Each training step asks me for nine of them. The other billion just sit there, judging."
An Embedding Table With Mostly Idle Rows
An embedding table assigns a learned vector to every categorical entity (a user, an item, a word, an ad), so the table can have billions of rows, yet each training example touches only a handful of them. That gap between the size of the model and the size of any single update is the whole story of this section. Because a step pulls a tiny keyed slice and writes back a tiny keyed slice, the parameter-server push/pull with sparse keyed access from Section 11.2 fits the workload like a glove, while all-reduce, which would synchronize the entire table on every step, is the wrong tool. This is the workload parameter servers were born for, and it is where their advantage over dense collective communication is most decisive.
Most of the parameters in a large recommendation or ranking model are not in its dense layers; they are in its embedding tables. A categorical feature such as "user id" has no natural numeric meaning, so we cannot feed the raw integer into a neural network and expect it to learn. Instead we give each distinct value its own short, dense vector and let training discover what that vector should be. A platform with two hundred million users, fifty million items, and a vocabulary of hashed cross-features can easily carry tens of billions of such vectors. The dense part of the same model, the multilayer perceptron that consumes the looked-up vectors, might be only a few million parameters. The embedding tables are the elephant, and they are sparse: any one training example references only the few entities it actually involves.
This section explains the sparse structure precisely, shows why the parameter-server primitives of this chapter are the natural home for it, and builds a tiny embedding table from scratch so you can watch the sparsity in action: only the rows in a batch change, and each row carries its own optimizer state that accumulates independently of every other row. We then tie embeddings forward to the retrieval and recommendation systems that consume them at serving time.
1. Embeddings: A Learned Vector per Entity Beginner
An embedding layer is, at heart, a lookup table. Let a categorical feature take values in a vocabulary of size $V$ (for example, $V$ distinct user ids). We allocate a matrix $E \in \mathbb{R}^{V \times d}$, where $d$ is a small embedding dimension, often between 8 and 256. The embedding of entity $i$ is simply row $i$ of that matrix, $E_i \in \mathbb{R}^{d}$. There is no matrix multiply in the lookup itself: feeding the one-hot vector $e_i$ through $E$ would give $e_i^\top E = E_i$, but materializing a $V$-dimensional one-hot and multiplying it by the whole table would be absurd when $V$ is in the billions. The lookup is a memory access, an indexing operation,
$$\text{lookup}(i) = E_i = E[i, :], \qquad i \in \{1, \dots, V\}.$$
A real example carries several categorical features, and a feature may be multi-valued (a user's list of recently clicked items). The model gathers the relevant rows and pools them, typically by summing or averaging, before handing the result to its dense layers. For a feature whose example holds a set $S$ of active keys, the pooled embedding is $\frac{1}{|S|}\sum_{i \in S} E_i$. The crucial property for distribution is that $|S|$ is tiny next to $V$: a user clicked a few dozen items, not a few billion. The number of rows any one example reads is bounded by the number of features it actually has, which is a property of the data, not of the table size.
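A minimal NumPy sketch of the lookup and the mean pooling above, assuming one multi-valued feature whose active keys are treated as a set $S$. The helper name pooled_lookup is hypothetical, and ids here are 0-based as in NumPy rather than 1-based as in the equation.

```python
import numpy as np

def pooled_lookup(table, keys):
    """Mean-pool the rows named by `keys`, treating the keys as a set S."""
    rows = np.asarray(sorted(set(keys)), dtype=np.int64)  # deduplicate: S is a set
    return table[rows].mean(axis=0)                       # (1/|S|) * sum_{i in S} E_i

rng = np.random.default_rng(0)
V, d = 10_000, 4
E = rng.standard_normal((V, d))

# The one-hot product and plain indexing agree, but indexing never builds a V-length vector.
i = 1234
one_hot = np.zeros(V)
one_hot[i] = 1.0
assert np.allclose(one_hot @ E, E[i])

clicked = [7, 42, 42, 9001]   # a multi-valued feature; 42 appears twice
g = pooled_lookup(E, clicked)
print(g.shape)                # (4,)
```

Only three rows of the 10,000 are read to form g, and that count would not change if V grew to a billion.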
An embedding table can hold tens of billions of parameters, yet a single training example reads and writes only the few rows for the entities it names. The active fraction per step is the ratio of touched rows to total rows, and it is often less than one in a million. This is the defining asymmetry of sparse models: the parameter count is set by the vocabulary, but the per-step communication is set by the batch's feature count. Any system that moves the whole table per step (as a naive all-reduce would) is paying for parameters it never touched.
2. The Sparse Gradient: Only Used Rows Move Intermediate
The forward sparsity carries straight through to the backward pass, and this is what makes the parameter-server fit exact rather than approximate. Suppose the loss $\mathcal{L}$ depends on the table only through the rows that were looked up. For a row $i$ that was not in the example, the loss does not depend on $E_i$ at all, so its gradient is exactly the zero vector. For a row that was looked up, the gradient flows back through the pooling and into that row alone:
$$\frac{\partial \mathcal{L}}{\partial E_i} = \begin{cases} \dfrac{\partial \mathcal{L}}{\partial g}\cdot \dfrac{1}{|S|} & i \in S \\[1.2ex] \mathbf{0} & i \notin S \end{cases}, \qquad g = \frac{1}{|S|}\sum_{j \in S} E_j.$$
The full gradient $\partial \mathcal{L} / \partial E$ is therefore a $V \times d$ matrix with at most $|S|$ nonzero rows. Representing it densely would waste billions of zeros; instead the framework returns a sparse gradient, a short list of (row index, gradient vector) pairs. This is precisely the shape that the push/pull interface of Section 11.2 consumes: pull the keyed rows you need, compute, push back the keyed gradients for exactly those rows. The keys that index the parameter server are the entity ids, and the sparsity of the gradient is what keeps each push small no matter how large the table grows.
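To check the case split numerically, the sketch below assumes a toy loss $\mathcal{L} = \tfrac{1}{2}\lVert g - y \rVert^2$, so that $\partial \mathcal{L} / \partial g = g - y$. It returns the gradient in the sparse (row index, gradient vector) form and compares one used row and one unused row against central finite differences. The functions loss, sparse_grad, and fd_row are illustrative names, not a framework API.

```python
import numpy as np

def loss(E, S, y):
    """Toy loss 0.5 * ||g - y||^2 on the mean-pooled embedding g."""
    g = E[S].mean(axis=0)
    return 0.5 * np.sum((g - y) ** 2)

def sparse_grad(E, S, y):
    """Gradient as (row ids, one d-vector per id); every other row is implicitly zero."""
    g = E[S].mean(axis=0)
    dL_dg = g - y
    return S, np.tile(dL_dg / len(S), (len(S), 1))

rng = np.random.default_rng(1)
V, d = 50, 3
E = rng.standard_normal((V, d))
S = np.array([2, 11, 30])   # distinct active keys
y = np.ones(d)
keys, grads = sparse_grad(E, S, y)

def fd_row(r, h=1e-6):
    """Central finite-difference gradient of the loss with respect to row r."""
    out = np.zeros(d)
    for k in range(d):
        Ep, Em = E.copy(), E.copy()
        Ep[r, k] += h
        Em[r, k] -= h
        out[k] = (loss(Ep, S, y) - loss(Em, S, y)) / (2 * h)
    return out

print(np.allclose(fd_row(11), grads[1], atol=1e-6))  # True: used row gets (dL/dg) / |S|
print(np.allclose(fd_row(5), 0.0))                   # True: unused row has zero gradient
print(grads.shape, "stored instead of", E.shape)
```

Because the loss is quadratic in E, the central difference matches the analytic gradient up to rounding, and perturbing row 5 leaves the loss bit-for-bit unchanged.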
All-reduce, the exact gradient combine that opened the book in Section 1.1 and powers data-parallel deep learning in Chapter 15, assumes every worker holds a dense gradient of the same shape, so summing them is meaningful and cheap relative to compute. An embedding table breaks that assumption: each worker's gradient is nonzero on a different handful of rows, and the rows themselves are too many to ever materialize densely on one device. Summing two sparse gradients with disjoint keys is a union, not an overlap, and a dense all-reduce over a billion-row table would move terabytes to combine kilobytes of real signal. The parameter server wins here because it indexes by key: it touches only the rows that moved. This is the chapter's central claim, that push/pull beats all-reduce exactly when the model is sparse and the table does not fit on one device.
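A sketch of the keyed combine, assuming each worker's sparse gradient arrives as a dictionary from row id to a float32 vector. The helper merge_sparse is hypothetical. Keys seen by only one worker pass through unchanged and shared keys are added, so the combined payload stays proportional to the number of distinct touched rows. The byte count at the end compares that payload, with an assumed 8-byte key per row, against one dense float32 copy of a billion-row, 64-dimensional table.

```python
import numpy as np

def merge_sparse(*grads):
    """Sum sparse gradients given as {row_id: vector} dicts (a keyed union)."""
    out = {}
    for g in grads:
        for k, v in g.items():
            # Shared keys are added; keys seen by one worker are copied through.
            out[k] = (out[k] + v) if k in out else v.copy()
    return out

d = 64
w0 = {3: np.full(d, 0.5, dtype=np.float32), 42: np.full(d, 1.0, dtype=np.float32)}
w1 = {42: np.full(d, 2.0, dtype=np.float32), 9_000_000: np.full(d, -1.0, dtype=np.float32)}
merged = merge_sparse(w0, w1)
print(sorted(merged))    # [3, 42, 9000000]
print(merged[42][0])     # 3.0

# Payload: 4-byte floats plus an assumed 8-byte key per row, versus a dense table.
sparse_bytes = sum(v.nbytes + 8 for v in merged.values())
dense_bytes = 1_000_000_000 * d * 4
print(f"keyed: {sparse_bytes} B, dense: {dense_bytes / 1e9:.0f} GB")  # keyed: 792 B, dense: 256 GB
```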
3. Sparse Optimizers: Per-Row State Intermediate
Adaptive optimizers such as Adagrad and Adam keep a running statistic per parameter: a sum of squared gradients for Adagrad, first and second moment estimates for Adam. For a dense layer those statistics live alongside the weights and are updated every step. For an embedding table, updating the optimizer state of every row every step would defeat the entire point of sparsity, since most rows received a zero gradient and their state should not change. A sparse optimizer therefore updates the per-row state only for the rows that appeared in the batch. Adagrad on row $i$ accumulates
$$G_i \leftarrow G_i + \left(\frac{\partial \mathcal{L}}{\partial E_i}\right)^{\!2}, \qquad E_i \leftarrow E_i - \frac{\eta}{\sqrt{G_i} + \epsilon}\,\frac{\partial \mathcal{L}}{\partial E_i},$$
applied elementwise, and only for $i$ in the batch. A frequently seen entity (a popular item) accumulates a large $G_i$ and so takes progressively smaller steps, while a rarely seen entity keeps a small $G_i$ and a relatively large effective learning rate. Each row's state evolves on its own clock, set by how often that entity appears, which is exactly the behavior you want when entity frequencies are wildly skewed. The next code makes a tiny table and shows both the sparsity of the update and this independent per-row accumulation directly.
```python
import numpy as np
rng = np.random.default_rng(7)
V, D = 1_000_000, 8  # one million rows (entities), 8-dim vectors each
table = np.zeros((V, D))  # the embedding table; row v is entity v's vector
G = np.zeros((V, D))  # per-row Adagrad accumulator (sparse optimizer state)
lr, eps = 0.1, 1e-8
touched = set()  # which rows have ever been pulled/updated

def step(batch_rows, target):
    # PULL: fetch only the rows this batch references (keyed sparse access).
    rows = np.array(sorted(set(batch_rows)))
    vecs = table[rows]  # tiny slice, not the whole table
    # toy gradient: pull each used row toward a shared target vector.
    grad = vecs - target  # gradient only for used rows
    # PUSH: write state and vectors back for exactly these rows.
    G[rows] += grad * grad  # accumulate per-row Adagrad state
    table[rows] -= lr * grad / (np.sqrt(G[rows]) + eps)  # per-row scaled update
    touched.update(rows.tolist())
    return rows

target = np.ones(D)
# Three batches, each referencing only a handful of the million rows.
b1 = step([3, 17, 17, 42, 256], target)
b2 = step([42, 99, 1000], target)
b3 = step([42, 256], target)  # row 42 seen in all three; 256 in two; 99 in one

rows_changed = np.flatnonzero(np.any(table != 0.0, axis=1))
print("table shape :", table.shape)
print("rows ever touched :", sorted(touched))
print("rows actually changed :", rows_changed.tolist())
print("untouched fraction :", f"{1 - len(touched)/V:.6f}")
print()
# Per-row optimizer state accumulates independently per row's visit count.
for r in [42, 256, 99, 1000]:
    print(f"row {r:>4}: visits-implied G[0]={G[r,0]:.4f} vec[0]={table[r,0]:+.4f}")
print()
# Row 7 was never in any batch: its vector and state remain exactly zero.
print("row 7 (never used) : vec norm =", np.linalg.norm(table[7]),
      " G norm =", np.linalg.norm(G[7]))
```
Code 11.6.1. The step function mimics the parameter-server pull/push: it touches only the rows named in the batch, updates their vectors, and accumulates their optimizer state in G independently of every untouched row.
Output 11.6.1
```text
table shape : (1000000, 8)
rows ever touched : [3, 17, 42, 99, 256, 1000]
rows actually changed : [3, 17, 42, 99, 256, 1000]
untouched fraction : 0.999994

row   42: visits-implied G[0]=2.5041 vec[0]=+0.2195
row  256: visits-implied G[0]=1.8100 vec[0]=+0.1669
row   99: visits-implied G[0]=1.0000 vec[0]=+0.1000
row 1000: visits-implied G[0]=1.0000 vec[0]=+0.1000

row 7 (never used) : vec norm = 0.0  G norm = 0.0
```
The output is the section in miniature. The table has a million rows, but the set of rows that ever changed is the union of the batch keys and nothing more, so the untouched fraction is 0.999994. The accumulator values rank exactly by how many batches each row appeared in, which is the per-row clock the optimizer runs on. A real system replaces the in-process array with a sharded parameter server, the slice indexing with keyed pull and push over the network, and the toy gradient with the backward pass of a deep model, but the access pattern is precisely the one shown here.
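Code 11.6.1 passes each batch through set(), so row 17, which appears twice in the first batch, is updated as if it appeared once. In the sparse gradient of a real model each occurrence of a key contributes its own term, and those terms are summed before the optimizer step; PyTorch's sparse tensors, for instance, sum duplicate indices when coalesced. A minimal sketch, assuming one gradient row per key occurrence; coalesce is a hypothetical helper.

```python
import numpy as np

def coalesce(keys, grad_rows):
    """Sum gradient rows that share a key; returns (unique keys, one summed row per key)."""
    uniq, inverse = np.unique(keys, return_inverse=True)
    summed = np.zeros((len(uniq), grad_rows.shape[1]), dtype=grad_rows.dtype)
    # np.add.at is unbuffered, so repeated indices accumulate;
    # summed[inverse] += grad_rows would keep only one contribution per key.
    np.add.at(summed, inverse, grad_rows)
    return uniq, summed

# The first batch of Code 11.6.1, with one toy gradient row per key occurrence (d = 2).
keys = np.array([3, 17, 17, 42, 256])
grad_rows = np.full((len(keys), 2), -1.0)
uniq, summed = coalesce(keys, grad_rows)
print(uniq.tolist())          # [3, 17, 42, 256]
print(summed[:, 0].tolist())  # [-1.0, -2.0, -1.0, -1.0]
```

After coalescing, the push still carries one row per distinct key, so the payload size is unchanged; only row 17's gradient now reflects both of its occurrences.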
Row 7 in Output 11.6.1 is the embedding nobody asked for: an entity that exists in the vocabulary but never appeared in training. Its vector is still the initial value, and at serving time it contributes a generic, uninformative signal. This is the cold-start problem wearing a numerical disguise, and it is why production systems fall back to content features or popularity priors for entities whose embeddings have not yet been trained. A billion-row table is, on any given day, mostly cold rows waiting for their first gradient.
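One practical cost hides in Code 11.6.1: the accumulator G has the same shape as the table, so per-element Adagrad doubles the memory of the embeddings. Row-wise Adagrad, which TorchRec exposes as RowWiseAdagrad, keeps a single scalar of state per row instead. The sketch below assumes a common formulation in which each touched row accumulates the mean of its squared gradient over the $d$ columns; it illustrates the idea and is not TorchRec's or FBGEMM's code, and rowwise_adagrad_step is a hypothetical name.

```python
import numpy as np

V, D = 1_000, 8
lr, eps = 0.1, 1e-8
table = np.zeros((V, D), dtype=np.float32)
G_row = np.zeros(V, dtype=np.float32)  # one scalar of optimizer state per row, not D

def rowwise_adagrad_step(rows, grad):
    """rows: unique row ids; grad: (len(rows), D) gradient for exactly those rows."""
    G_row[rows] += np.mean(grad * grad, axis=1)   # mean squared gradient of each row
    scale = lr / (np.sqrt(G_row[rows]) + eps)     # a single step size per row
    table[rows] -= scale[:, None] * grad          # broadcast over the D columns

rows = np.array([3, 42])
rowwise_adagrad_step(rows, table[rows] - 1.0)  # same toy pull-toward-ones gradient as Code 11.6.1
print("state bytes, per-element G:", table.nbytes, " row-wise G:", G_row.nbytes)  # 32000 vs 4000
```

The trade is a coarser adaptivity: all columns of a row share one step size, which is usually acceptable because a row's columns are updated together whenever its entity appears.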
4. Hashing Tricks and Feature Collisions Intermediate
Even a sparse table has a ceiling: you must allocate one row per distinct entity, and for open vocabularies (raw search queries, arbitrary cross-features, a feature store that grows every day) the number of distinct keys is unbounded and unknown in advance. The hashing trick caps the table at a fixed number of rows $B$ by mapping each entity through a hash function $h$ into a bucket, $\text{row}(i) = h(i) \bmod B$. Now the table size is a budget you choose, not a quantity the data dictates, and unseen entities map to some bucket without any vocabulary lookup. The cost is collisions: two distinct entities that hash to the same bucket share a single embedding and become indistinguishable to the model. With $B$ buckets and $n$ active entities, the expected number of colliding pairs grows like $n^2 / (2B)$, so you trade memory against the blurring of rare entities.
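A sketch of the bucket mapping with an empirical check of the $n^2/(2B)$ estimate, assuming string keys and a stable hash. It uses hashlib.blake2b because Python's built-in hash() for strings is salted per process and would send the same key to different rows on every run. The helper bucket and the key strings are hypothetical.

```python
import hashlib
import numpy as np

def bucket(key: str, B: int) -> int:
    """Map a string key to a row in [0, B) with a hash that is stable across runs."""
    digest = hashlib.blake2b(key.encode(), digest_size=8).digest()
    return int.from_bytes(digest, "little") % B

n, B = 20_000, 1_000_000
buckets = np.array([bucket(f"query-{i}", B) for i in range(n)])

# A bucket holding c keys contributes c * (c - 1) / 2 colliding pairs.
_, counts = np.unique(buckets, return_counts=True)
observed = int(np.sum(counts * (counts - 1) // 2))
expected = n * (n - 1) / (2 * B)   # about n^2 / (2B), roughly 200 pairs here
print(f"observed colliding pairs: {observed}, estimate: {expected:.1f}")
```

Even with fifty times more buckets than keys, a couple of hundred pairs collide, which is why the choice of $B$ is a fidelity decision and not just a memory one.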
In practice collisions are tolerable because they fall hardest on infrequent keys, whose embeddings were poorly trained anyway, while frequent keys, although they can still collide, dominate their bucket's gradient and so steer it toward their own meaning. Multi-hashing (the "hashing trick" generalized to several independent hashes whose buckets are summed) reduces the chance that two entities collide in every hash at once, recovering much of the lost capacity at a fixed memory budget. Choosing $B$ is the same kind of resource decision as choosing the number of workers in Chapter 10: bigger is more faithful but costs memory and network, and the right size is measured, not guessed.
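A sketch of multi-hashing at a fixed budget of $2B$ rows, assuming two hashes derived from blake2b with different person strings and a separate $B$-row table per hash whose rows are summed (a single shared table is another common choice). It compares the fraction of keys that share their bucket when one hash gets all $2B$ rows with the fraction that share both buckets under two hashes of $B$ rows each. All names and key strings are illustrative.

```python
import hashlib
from collections import Counter
import numpy as np

def bucket(key, B, salt):
    # A different `person` string gives an independent-looking hash of the same key.
    h = hashlib.blake2b(key.encode(), digest_size=8, person=salt)
    return int.from_bytes(h.digest(), "little") % B

def share_fraction(signatures):
    """Fraction of keys whose signature (a bucket or a bucket pair) is not unique."""
    c = Counter(signatures)
    return sum(1 for s in signatures if c[s] > 1) / len(signatures)

n, B, d = 50_000, 10_000, 4
keys = [f"item-{i}" for i in range(n)]

one_hash = [bucket(k, 2 * B, b"single") for k in keys]  # one hash over 2B rows
h0 = [bucket(k, B, b"hash-0") for k in keys]             # two hashes over
h1 = [bucket(k, B, b"hash-1") for k in keys]             # B rows each
single = share_fraction(one_hash)
double = share_fraction(list(zip(h0, h1)))
print(f"share a bucket, one hash: {single:.3f}; share both buckets, two hashes: {double:.4f}")

# The embedding of key i sums its rows from the two B-row tables.
rng = np.random.default_rng(0)
T0, T1 = rng.standard_normal((B, d)), rng.standard_normal((B, d))

def embed(i):
    return T0[h0[i]] + T1[h1[i]]
```

With one hash almost every key shares its row with someone, while with two hashes only a tiny fraction of keys are fully indistinguishable, since both coordinates of the bucket pair must coincide.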
Code 11.6.1 hand-rolled the sparse lookup, the keyed gradient, and the per-row optimizer. Production stacks give you all three as configured objects. TorchRec, PyTorch's library for large-scale recommendation, exposes an EmbeddingBagCollection whose tables its DistributedModelParallel wrapper shards across devices and hosts, with sparse all-to-all communication and fused sparse optimizers built in:
```python
# pip install torchrec
from torchrec import EmbeddingBagConfig, EmbeddingBagCollection
from torchrec.optim import RowWiseAdagrad  # sparse, per-row optimizer state

cfg = EmbeddingBagConfig(
    name="user_id",
    embedding_dim=64,
    num_embeddings=200_000_000,  # 200M rows, sharded for you
    feature_names=["user_id"],
)
# device="meta" records shapes only; the 200M x 64 table is not allocated here.
ebc = EmbeddingBagCollection(tables=[cfg], device="meta")
# DistributedModelParallel later shards `ebc` across the cluster and routes
# each batch's keys to the shard that owns them; you never index rows by hand.
```
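Code 11.6.2 uses TorchRec. In TensorFlow, tf.feature_column (including its hashed categorical columns) and the Keras StringLookup or Hashing layers feeding an Embedding layer offer a comparable lookup path, with the hashed variants covering open vocabularies.
5. Where These Vectors Go Next Beginner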
Embeddings are not an end in themselves; they are the input currency of retrieval and recommendation. Once trained, the item embeddings become a searchable index: to recommend for a user, you take the user's embedding and find the nearest item vectors, which is approximate nearest-neighbor (ANN) search rather than a scan over billions of items. The classical-ML treatment of ANN structures lives in Chapter 12, and the distributed vector databases that serve these lookups at scale, sharding the index the same way this chapter shards the table, are the subject of Chapter 25. The full recommendation pipeline, where the embedding table is trained as in this section and queried as in Chapter 25, is built end to end in the case study of Chapter 38.
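For reference, the exact computation that an ANN index approximates can be sketched as a brute-force scan. This assumes inner-product scoring, a common choice for two-tower recommenders (cosine or Euclidean distance would change only the scoring line), and random vectors in place of trained embeddings. The helper exact_top_k is hypothetical.

```python
import numpy as np

def exact_top_k(user_vec, item_table, k):
    """Brute-force maximum-inner-product search: score every item, keep the k best."""
    scores = item_table @ user_vec                 # one dot product per item
    top = np.argpartition(-scores, k)[:k]          # the k best, in arbitrary order
    return top[np.argsort(-scores[top])]           # sort just those k by score

rng = np.random.default_rng(3)
items = rng.standard_normal((100_000, 32)).astype(np.float32)
user = rng.standard_normal(32).astype(np.float32)
print(exact_top_k(user, items, k=5))
```

The scan costs one dot product per item, which is fine for a hundred thousand rows and hopeless for billions; ANN structures trade a little recall for skipping most of those products.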
This forward arc is also why the partitioning ideas of this chapter recur: the embedding table you shard for training in Section 11.2 becomes the vector index you shard for serving in Chapter 25, and the same key-routing logic decides which machine answers a given lookup. The next section pushes the table size to its limit, examining the terabyte-scale embeddings of industrial recommendation systems and the memory hierarchy (HBM, host RAM, SSD) that holds them when even the sharded servers run short of fast memory.
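The key-routing step can be sketched in a few lines, assuming modulo sharding in which row r lives on shard r mod N. Production planners may instead assign contiguous row ranges or whole tables to devices. The helper route_keys is hypothetical; its result is the per-shard key list that each shard would receive in the pull request or all-to-all.

```python
import numpy as np

def route_keys(keys, num_shards):
    """Group a batch's keys by owning shard under modulo sharding (shard = key % N)."""
    keys = np.unique(np.asarray(keys, dtype=np.int64))  # fetch each row once
    owner = keys % num_shards
    return {s: keys[owner == s].tolist() for s in range(num_shards) if np.any(owner == s)}

batch = [3, 17, 17, 42, 256, 1000, 99]
print(route_keys(batch, num_shards=4))
# {0: [256, 1000], 1: [17], 2: [42], 3: [3, 99]}
```

The same function works for training pulls and serving lookups: only the payload changes, from gradient rows to query vectors.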
Who: An ML platform engineer on the ads-ranking team at a large marketplace.
Situation: The ranking model carried a 40-billion-parameter embedding table for users, items, and cross-features, plus a small dense tower on top.
Problem: The team's first distributed version wrapped the entire model in all-reduce data parallelism, and every step synchronized the full table across workers, saturating the network while the dense tower sat idle.
Dilemma: Keep the uniform all-reduce path, simple but moving terabytes per step to combine a few megabytes of real gradient, or split the model so the sparse table used a parameter server with keyed push/pull while the dense tower kept all-reduce.
Decision: They split it, because the gradient was sparse on the table (each batch touched well under one in a million rows) and dense only on the small tower, so the two parts wanted opposite communication patterns.
How: They moved the embedding tables onto sharded parameter servers with sparse all-to-all routing (the TorchRec pattern of Code 11.6.2) and kept DistributedDataParallel only on the dense tower.
Result: Per-step network traffic fell by more than two orders of magnitude, step time dropped from communication-bound to compute-bound, and accuracy was unchanged because the math of the sparse update was identical; only the bytes moved differed.
Lesson: Match the communication pattern to the gradient's structure. Sparse tables want keyed push/pull; dense layers want all-reduce; forcing one pattern on both wastes the network on whichever part it does not fit.
6. When Sparse Tables Earn Their Complexity Intermediate
A parameter server for embeddings is not free: it adds a routing layer, a sharding scheme, and a fault-tolerance story (the table is now spread across machines that can fail, taken up in Section 11.8). It earns that complexity only when two conditions hold together: the table is too large to replicate on every worker, and the per-step access is genuinely sparse. If the table is small enough to fit on each device, replicating it and using all-reduce is simpler and often faster, because you avoid the routing round trip. If the access is dense (every step reads most rows), the parameter server loses its advantage, since it would move nearly the whole table anyway. The interesting regime, and the one industrial recommenders live in, is large-and-sparse, where the table cannot be replicated and any single step touches a vanishing fraction of it.
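A back-of-the-envelope cost model makes the sparsity condition concrete. It assumes a ring all-reduce, in which each worker sends about $2(N-1)/N$ times the buffer size, and a keyed push/pull in which each worker pulls its touched rows and pushes their gradients with an assumed 8-byte key per row. It ignores latency, compression, and the memory condition, the function names are illustrative, and the table size is an arbitrary example.

```python
def allreduce_bytes(num_rows, d, workers, bytes_per=4):
    """Per-worker bytes sent by a ring all-reduce over the whole table:
    about 2 * (N - 1) / N times the buffer (reduce-scatter plus all-gather)."""
    return 2 * (workers - 1) / workers * num_rows * d * bytes_per

def pushpull_bytes(touched_rows, d, bytes_per=4, key_bytes=8):
    """Per-worker bytes for keyed access: pull the touched rows and push their
    gradients, each row accompanied by its key (a rough model)."""
    return 2 * touched_rows * (d * bytes_per + key_bytes)

V, d, N = 500_000_000, 32, 16
for frac in [1e-6, 1e-3, 0.1, 0.5, 1.0]:
    touched = int(frac * V)
    ratio = allreduce_bytes(V, d, N) / pushpull_bytes(touched, d)
    print(f"active fraction {frac:>7g}: all-reduce / push-pull = {ratio:,.1f}x")
```

Under this model the advantage is roughly the inverse of the active fraction, and it disappears, slightly reversing, once every step touches most of the table.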
This is the same "match the remedy to the ceiling" discipline that the book has insisted on since Section 1.1: the binding constraint here is model memory (the table does not fit) combined with a communication structure (the gradient is sparse), and the parameter server is the remedy fitted to exactly that pair. Get either condition wrong and a simpler design wins. The skill is recognizing the large-and-sparse regime when you are in it, and the embedding table is its purest example.
Large embedding systems are an active engineering frontier. TorchRec (Meta) has become the open reference for sharded embedding training, pairing fused row-wise optimizers with planner-chosen table placement and sparse all-to-all, and recent releases push toward variable-batch and quantized communication to cut the all-to-all bytes. NVIDIA's HugeCTR and its Merlin stack target GPU-resident embeddings with a hierarchical parameter server that tiers rows across HBM, host memory, and SSD, the memory hierarchy that Section 11.7 examines. A parallel research line attacks the table size itself: compositional and hashing-based embeddings (the quotient-remainder trick and its descendants), learned mixed-dimension tables that give frequent entities longer vectors than rare ones, and frequency-aware caching that keeps hot rows in fast memory. The throughline of 2024 to 2026 work is treating the embedding table as a tiered, compressible, sharded data structure rather than a monolithic matrix, which is the systems view this chapter has been building toward.
A model has an embedding table of $V = 2 \times 10^{9}$ rows with $d = 64$ float32 entries per row, plus a dense tower of $10^{7}$ parameters. A training batch references on average 200 distinct embedding rows. (a) How many bytes is the full table, and how many bytes are the embedding gradients actually produced by one batch? (b) Estimate the per-step bytes a naive dense all-reduce would move for the table versus a keyed push/pull. (c) Argue from these numbers which part of the model belongs on a parameter server and which belongs in all-reduce, and connect your answer to the split in the Practical Example.
Extend Code 11.6.1 so that each batch's keys are drawn from a heavily skewed distribution (for example, Zipf, so a few rows appear constantly and most appear rarely). Run several hundred batches, then plot or print, for the ten most-frequent and ten least-frequent rows, their accumulated Adagrad state $G_i$ and the norm of their final vector. Confirm that frequent rows have large $G_i$ and small effective step size, and explain in two sentences why this per-row adaptivity is exactly what you want when entity frequencies are skewed. Then count what fraction of all rows were ever touched and relate it to the untouched fraction in Output 11.6.1.
Using the hashing trick of Section 4, take $n$ active entities hashed into $B$ buckets. (a) Derive the expected number of colliding pairs as a function of $n$ and $B$, and the expected fraction of entities that share their bucket with at least one other. (b) For $n = 10^{8}$ and $B = 10^{7}$, compute both. (c) Explain why a single hash with this $B$ may still be acceptable in practice despite the collision rate you found, and how adding a second independent hash (multi-hashing) changes the probability that two specific entities collide in every hash at once. Tie your conclusion back to the memory-versus-fidelity trade the section described.