Part III: Distributed Machine Learning
Chapter 11: Parameter Servers and Distributed Embeddings

Motivation for Parameter Servers

"All-reduce kept asking me to hand it my whole self every step. I have a billion coordinates and most of them did not change. So I moved into a server and let the workers come to me for the few that did."

A Parameter Server Under Mild Staleness
Big Picture

A parameter server keeps the model's parameters on dedicated, sharded server nodes, and lets workers pull the parameters they need and push gradients back, so that training a model too large to replicate, or too sparse to touch fully each step, becomes natural rather than wasteful. The all-reduce data parallelism of Section 10.5 assumes every worker can hold a full copy of the model and exchange the full gradient every step. That assumption breaks for the two workloads that dominate industrial machine learning: models whose parameters are larger than any single worker's memory, and models so sparse that each step reads a vanishing fraction of them. This section explains why a centralized, sharded store of parameters was invented to answer both, where it still wins today, and how the rest of the chapter builds it out.

By the end of Chapter 10 we had a clean and powerful picture of distributed training: every worker holds a full copy of the model, each computes a gradient on its own data shard, and a single all-reduce averages those gradients so that every worker steps in lockstep on the exact full-batch direction. That design is symmetric and elegant, and for dense models that fit comfortably on one device it is the right answer; Chapter 15 turns it into the production training loop. The parameter server exists because two enormous and very common model families violate the assumptions underneath that picture, and violate them so badly that all-reduce stops being merely suboptimal and becomes impossible.

1. Where All-Reduce Data Parallelism Runs Out Beginner

All-reduce SGD rests on two quiet assumptions. The first is that every worker can store a full replica of the model: the parameters, and during the backward pass the gradient of every parameter. The second is that the natural unit of communication is the dense full-model gradient, summed coordinate by coordinate across workers each step. Both assumptions are true for a ResNet or a modestly sized transformer. Both are false for the models that run the largest machine learning services on earth.

Consider the embedding table at the heart of a recommendation or ranking system. Every user identifier, every item identifier, every categorical feature value is mapped to a learned vector, and the number of distinct values runs into the hundreds of millions or billions. A table with $10^9$ rows and a $128$-dimensional embedding holds $1.28 \times 10^{11}$ parameters, roughly half a terabyte at four bytes each. No single accelerator, and few single servers, can hold that. Replicating it on every worker, as all-reduce demands, is out of the question before training even starts.

Now look at what one training step actually touches. A batch of a few thousand interactions references at most a few thousand distinct users and items, so it reads and updates a few thousand of the billion rows. The gradient for that step is overwhelmingly zero: it is exact, but it is sparse, with only a minute fraction of coordinates nonzero. All-reduce, which sums dense vectors coordinate by coordinate, would communicate the entire half-terabyte gradient (almost all of it zeros) on every step. That traffic would exceed the useful traffic by a factor of hundreds of thousands, roughly the $10^9$ rows of the table divided by the few thousand a batch touches. The two failures compound: the model is too big to replicate, and the gradient is too sparse to justify dense exchange.
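The size of that mismatch is simple arithmetic, sketched below under stated assumptions: 4-byte values, a hypothetical batch touching 2,000 rows, and an 8-byte (64-bit) row id sent alongside each sparse row. It counts only the gradient one worker sends in one step and ignores the roughly 2x extra traffic of a ring all-reduce, so it understates the dense cost if anything.

```python
def per_step_bytes(num_rows, dim, rows_touched, bytes_per_value=4, index_bytes=8):
    """Bytes of gradient one worker sends in one step, dense vs sparse.

    Dense: every coordinate of the table's gradient, zeros included.
    Sparse: only the touched rows, each carrying its values plus a row id
    (index_bytes=8 assumes a 64-bit id; real encodings vary).
    """
    dense = num_rows * dim * bytes_per_value
    sparse = rows_touched * (dim * bytes_per_value + index_bytes)
    return dense, sparse


dense, sparse = per_step_bytes(num_rows=10**9, dim=128, rows_touched=2_000)
print(f"dense gradient : {dense / 1e9:,.0f} GB")      # 512 GB
print(f"sparse gradient: {sparse / 1e6:,.2f} MB")     # 1.04 MB
print(f"ratio          : {dense / sparse:,.0f}x")     # 492,308x
```

Even after paying for the row ids, the sparse message is about half a million times smaller, and the ratio grows linearly with the table while the batch stays fixed.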

Key Insight: Two Independent Failures, One Shared Cure

All-reduce data parallelism assumes a model small enough to replicate on every worker and dense enough that exchanging the whole gradient each step is worthwhile. Huge embedding tables break the first assumption (they do not fit), and sparse access breaks the second (each step touches almost nothing). A parameter server cures both at once by never replicating the model and never moving coordinates a worker did not ask for: parameters live in one logical place, sharded across servers, and workers pull and push only the slices they touch.

2. The Parameter-Server Answer Beginner

The idea is to stop treating the model as something every worker owns a copy of, and instead treat it as a shared data structure that lives on its own machines. Dedicated server nodes hold the authoritative parameters. Because the full table is far too large for one server, it is partitioned across many: each server holds a disjoint shard of the rows, so that the combined memory of the server fleet, not of any one node, bounds the model size. Workers hold no permanent copy of the model at all. Instead, each step they perform two operations against the servers. They pull the specific parameters their current batch needs, a handful of embedding rows out of billions. They compute the forward and backward pass locally, producing a sparse gradient over exactly those rows. Then they push that sparse gradient back, and the servers apply it to their shards.
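The sharded store described above can be sketched in a few lines. This is a toy, single-process illustration with hypothetical names (ShardedParameterServer, owner, pull, push): it uses simple contiguous range partitioning, matching the row ranges drawn in Figure 11.1.1, and it assumes rows are created lazily on first update, a convenience of the sketch rather than a property of any real system. Section 11.3 treats real partitioning and routing.

```python
import bisect


class ShardedParameterServer:
    """Toy range-sharded store: server s owns rows [bounds[s], bounds[s+1])."""

    def __init__(self, num_rows, num_servers, dim, lr=0.1):
        per_server = -(-num_rows // num_servers)            # ceiling division
        self.bounds = [min(s * per_server, num_rows) for s in range(num_servers + 1)]
        # One dict per server; rows are materialized lazily on first push, so
        # rows no batch ever touches cost no memory in this sketch.
        self.shards = [{} for _ in range(num_servers)]
        self.num_rows, self.dim, self.lr = num_rows, dim, lr

    def owner(self, row):
        """Route a row id to the server whose range contains it."""
        if not 0 <= row < self.num_rows:
            raise IndexError(row)
        return bisect.bisect_right(self.bounds, row) - 1

    def pull(self, rows):
        """Return copies of just the requested rows (zeros if never updated)."""
        return {r: list(self.shards[self.owner(r)].get(r, [0.0] * self.dim)) for r in rows}

    def push(self, sparse_grad):
        """Apply a sparse gradient {row: vector} on the owning servers only."""
        for r, g in sparse_grad.items():
            vec = self.shards[self.owner(r)].setdefault(r, [0.0] * self.dim)
            for j, gj in enumerate(g):
                vec[j] -= self.lr * gj


ps = ShardedParameterServer(num_rows=1_000_000_000, num_servers=3, dim=4)
batch_rows = [7, 400_000_000, 999_999_999]       # one batch, three servers
print([ps.owner(r) for r in batch_rows])        # [0, 1, 2]
ps.push({r: [1.0] * 4 for r in batch_rows})     # gradient of ones on each row
print(ps.pull([7]))                             # {7: [-0.1, -0.1, -0.1, -0.1]}
print(sum(len(shard) for shard in ps.shards))   # 3 rows stored out of a billion
```

Note that a single batch can span every server, as Worker 3 does in the figure, so the worker fans its pull and push out by owner.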

This push-pull pattern, the subject of Section 11.2, is what makes both hard cases easy. Sparsity is handled for free: a worker pulls and pushes only the rows it touched, so communication scales with the access pattern of the batch rather than with the size of the model. Asynchrony is handled naturally too. Because the servers hold the single authoritative copy and apply each push as it arrives, there is no barrier forcing all workers to finish a step together. A fast worker can pull, compute, and push again while a slow worker is still on its previous step, exactly the asynchronous SGD whose convergence behavior we studied in Section 10.4. The structure that makes the parameter server attractive for large sparse models is the same structure that lets it tolerate stragglers without a global wait.
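The absence of a barrier has a measurable side effect: a slow worker's push was computed against parameters that other workers have since changed. The event-driven sketch below makes that concrete under hypothetical timings (a fast worker taking 1 time unit per cycle, a slow one taking 3). The server applies each push on arrival and bumps a version counter, and staleness is the number of updates that landed between a worker's pull and its push, the quantity whose effect on convergence Section 10.4 studied.

```python
import heapq


def simulate_async(step_times, horizon):
    """Event-driven sketch of asynchronous push/pull with no barrier.

    step_times maps each worker to how long one pull-compute-push cycle takes
    (hypothetical numbers). The server applies every push the moment it arrives
    and bumps a version counter; the worker immediately pulls again. Returns,
    per worker, the staleness of each push: how many updates from other workers
    landed between that worker's pull and its push.
    """
    version = 0
    staleness = {w: [] for w in step_times}
    # Heap entries: (push_time, worker, server version seen at pull time).
    # Ties on push_time are broken by worker name, which keeps the run deterministic.
    events = [(t, w, 0) for w, t in step_times.items()]
    heapq.heapify(events)
    while events:
        now, w, seen = heapq.heappop(events)
        if now > horizon:
            break
        staleness[w].append(version - seen)   # updates applied since this pull
        version += 1                          # server applies the push immediately
        heapq.heappush(events, (now + step_times[w], w, version))  # pull again, no wait
    return staleness


print(simulate_async({"fast": 1, "slow": 3}, horizon=6))
# {'fast': [0, 0, 0, 1, 0, 0], 'slow': [3, 3]}
```

The fast worker completes six pushes in the time the slow one completes two, and every slow push lands three updates late. A synchronous barrier would remove the staleness but hold the fast worker to the slow worker's pace.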

Sharded parameter servers: one logical model, split across nodes Server A rows 0 .. 0.33B Server B rows 0.33B .. 0.66B Server C rows 0.66B .. 1B Worker 1 batch -> few rows Worker 2 batch -> few rows Worker 3 batch -> few rows pull rows push sparse grad each worker touches only the rows its batch references; untouched rows never move
Figure 11.1.1: The parameter-server pattern. The model is one logical table sharded across servers A, B, and C, so total model size is bounded by the server fleet's memory, not any single node. Each worker pulls only the few rows its batch references (solid green), computes locally, and pushes a sparse gradient back (dashed orange). A worker may pull from several servers at once, as Worker 3 does. Communication scales with the access pattern, not the model size, and no global barrier forces the workers to step together.

Figure 11.1.1 makes the asymmetry visible. Unlike the symmetric all-reduce of Section 10.5, where there is no central store and every worker is identical, the parameter-server design has two distinct roles: stateless workers that compute, and stateful servers that hold the model. That split is the whole point. It decouples the size of the model (which lives on the servers and grows by adding servers) from the number of workers (which scales compute and grows by adding workers), and it lets the two scale independently.
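One way to see the decoupling is that the two tiers are sized by different formulas that share no inputs. The sketch below is illustrative only: the 64 GB of usable memory per server and 50,000 samples per second per worker are hypothetical figures, and real sizing would also reserve headroom for optimizer state, replicas, and request traffic.

```python
import math


def fleet_size(table_bytes, server_mem_bytes, target_samples_per_sec, worker_samples_per_sec):
    """Size the two tiers independently (illustrative sketch).

    Servers are sized by model bytes alone; workers are sized by compute
    throughput alone. Neither input appears in the other's formula.
    """
    servers = math.ceil(table_bytes / server_mem_bytes)
    workers = math.ceil(target_samples_per_sec / worker_samples_per_sec)
    return servers, workers


GB = 10**9
print(fleet_size(512 * GB, 64 * GB, 1_000_000, 50_000))     # (8, 20)
print(fleet_size(1024 * GB, 64 * GB, 1_000_000, 50_000))    # (16, 20): bigger table, same workers
print(fleet_size(512 * GB, 64 * GB, 2_000_000, 50_000))     # (8, 40): more traffic, same servers
```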

3. A Parameter Server in Thirty Lines Intermediate

The mechanics are simpler than the terminology suggests. The code below is a complete, single-process parameter server: a flat weight vector standing in for a tiny embedding table, a pull that returns only the requested coordinates, and a push that applies a sparse gradient to exactly the coordinates it names. Three workers take turns round-robin, each driving its own small slice of the vector toward a target (workers 0 and 1 share coordinate 4), and the point of the run is to watch which coordinates move and which do not.

# The parameter server: a flat weight vector of D coordinates, held in one place.
D = 20
weights = [0.0] * D          # the single source of truth, held on the server
lr = 0.5                     # the server applies updates with this step size

def pull(indices):
    """A worker asks the server for ONLY the coordinates it needs."""
    return {i: weights[i] for i in indices}

def push(grad):
    """A worker sends back a sparse gradient; the server updates in place.
    Only the coordinates present in `grad` are ever touched."""
    for i, g in grad.items():
        weights[i] -= lr * g

# Three workers, each touching only a tiny, different slice of the table.
active = {0: [1, 4], 1: [4, 9, 14], 2: [17]}
target = 1.0   # every active coordinate should converge toward this value

touched = set()
for step in range(12):
    w = step % 3                      # round-robin over the three workers
    idx = active[w]
    touched.update(idx)
    local = pull(idx)                 # PULL only the needed coordinates
    grad = {i: (local[i] - target) for i in idx}   # sparse local gradient
    push(grad)                        # PUSH the sparse gradient back
    if step < 4:
        print(f"step {step:2d}  worker {w}  pulled {idx}")
print("  ...")

print()
print("touched coordinates :", sorted(touched))
print("untouched stayed 0  :", all(weights[i] == 0.0 for i in range(D) if i not in touched))
print("final weights       :", "[" + " ".join(f"{x:+.2f}" for x in weights) + "]")
print("active near target  :", all(abs(weights[i] - target) < 0.1 for i in touched))
Code 11.1.1: A minimal parameter server. The server holds the only copy of weights; each worker pulls a sparse slice, forms a gradient over just that slice, and pushes it back. No worker ever holds or communicates the full vector.
step  0  worker 0  pulled [1, 4]
step  1  worker 1  pulled [4, 9, 14]
step  2  worker 2  pulled [17]
step  3  worker 0  pulled [1, 4]
  ...

touched coordinates : [1, 4, 9, 14, 17]
untouched stayed 0  : True
final weights       : [+0.00 +0.94 +0.00 +0.00 +1.00 +0.00 +0.00 +0.00 +0.00 +0.94 +0.00 +0.00 +0.00 +0.00 +0.94 +0.00 +0.00 +0.94 +0.00 +0.00]
active near target  : True
Output 11.1.1: Only the five coordinates any worker referenced (1, 4, 9, 14, 17) ever moved; the other fifteen stayed exactly zero, and every touched coordinate converged toward the target. This is sparse training in miniature: communication and computation both scale with what a batch touches, not with the size of the table.

The output is the whole argument in miniature. Coordinate 4, the one row two different workers both pulled, received pushes from both, eight updates instead of four, which is why it ends at +1.00 while the other touched coordinates end at +0.94. In this single-process toy the pushes arrive one at a time, but in a real system two workers can push to the same row concurrently, a hint of the contention questions Section 11.4 takes up. The fifteen untouched coordinates never moved at all, because no worker ever pulled or pushed them. Scale this picture up: replace twenty coordinates with a billion embedding rows and three workers with a thousand, shard the vector across many servers instead of one process, and replace the toy gradient with a real recommendation model, and you have the architecture that trains the world's largest sparse models. The logic does not change, only the scale.
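The exact values in Output 11.1.1 follow from a one-line recurrence. Each push applies $w \leftarrow w - \eta\,(w - t)$, with $\eta$ the lr of Code 11.1.1 and $t$ the target, so every update shrinks the remaining error by the factor $1 - \eta$. In twelve round-robin steps each worker gets four turns, so coordinates 1, 9, 14, and 17 receive four updates, while coordinate 4, shared by workers 0 and 1, receives eight:

$$w_{k+1} = w_k - \eta\,(w_k - t) \quad\Longrightarrow\quad w_k - t = (1-\eta)^k\,(w_0 - t)$$

$$\eta = 0.5,\; w_0 = 0,\; t = 1:\qquad w_4 = 1 - 0.5^4 = 0.9375 \approx +0.94, \qquad w_8 = 1 - 0.5^8 \approx 0.9961 \approx +1.00$$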

Thesis Thread: Sharding the Model Is the Other Half of Scale-Out

Data parallelism (Section 1.1) scaled out the computation by splitting examples across workers while the model stayed replicated. The parameter server scales out the model itself by splitting parameters across servers while the computation stays distributed. These are the two complementary axes of distributed training, and the sharding move you meet here as embedding-table partitioning returns transformed in Chapter 16, where ZeRO and FSDP shard a dense model's parameters and optimizer state with the same logic of "no node holds the whole model, every node holds a slice." When you meet a sharded-parameter system later, ask what plays the role of the server shard and what plays the role of the pull.
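To make the analogy concrete, the sketch below splits a flat parameter vector into equal contiguous slices, one per rank, and reassembles it on demand. It is an illustration of the idea only, not code from ZeRO or FSDP: the zero-padding of the last slice is one common choice for equal-size shards, and real implementations move shards with collectives and, depending on configuration, also shard gradients and optimizer state. Here a rank's slice plays the role of the server shard, and gather_full plays the role of the pull.

```python
def shard_flat(params, world_size):
    """Split a flat parameter list into world_size equal contiguous slices.

    The tail is zero-padded so every rank holds a slice of the same length.
    """
    chunk = -(-len(params) // world_size)                  # ceiling division
    padded = params + [0.0] * (chunk * world_size - len(params))
    return [padded[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def gather_full(shards, true_len):
    """The dense analogue of a pull: reassemble the full vector, drop padding."""
    return [x for shard in shards for x in shard][:true_len]


params = [float(i) for i in range(10)]
shards = shard_flat(params, world_size=4)
print(shards)   # [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 0.0, 0.0]]
print(gather_full(shards, len(params)) == params)   # True
```

The difference from the parameter server is the access pattern: a dense model gathers every shard each step, whereas a sparse model pulls only the rows a batch references.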

4. Where the Parameter Server Still Wins Today Intermediate

It would be a mistake to read this chapter as history. All-reduce won the dense-model contest: for training a large language model whose every parameter is touched every step, ring or tree all-reduce over a fast interconnect is faster and simpler than routing gradients through a separate server tier, and Section 11.9 makes that comparison carefully. But the parameter server did not lose; it specialized. It remains the dominant architecture wherever the model is dominated by huge sparse embedding tables, which is to say across most of the recommendation, ranking, and advertising systems that generate the bulk of industrial machine learning revenue. In those systems a hybrid is now standard: the dense neural-network layers are trained data-parallel with all-reduce, while the enormous sparse embedding tables live on a parameter-server tier and are accessed by pull and push, each part using the architecture suited to its access pattern.
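The hybrid layout can be sketched as one training step with two update paths. In this toy, with hypothetical function names and scalar embedding rows for brevity, the small dense weights are replicated on every worker and updated from an all-reduced (averaged) gradient, while embedding rows live in a shared store that receives only the rows each worker touched. The gradients are assumed to be already computed; the point is where each kind of gradient goes.

```python
def all_reduce_mean(grads):
    """Stand-in for all-reduce: every replica receives the coordinate-wise mean."""
    mean = [sum(col) / len(grads) for col in zip(*grads)]
    return [list(mean) for _ in grads]


def hybrid_step(dense_replicas, embedding_store, worker_grads, lr=0.1):
    """One step of the hybrid layout on a toy model.

    dense_replicas: one full copy of the small dense weights per worker.
    embedding_store: dict row_id -> float, standing in for the server tier.
    worker_grads: per worker, (dense_grad, {row_id: sparse_grad}), already computed.
    """
    # Dense side: all-reduce the full dense gradient, step every replica identically.
    averaged = all_reduce_mean([dense for dense, _ in worker_grads])
    for replica, g in zip(dense_replicas, averaged):
        for j, gj in enumerate(g):
            replica[j] -= lr * gj
    # Sparse side: each worker pushes only the rows it touched; untouched rows never move.
    for _, sparse in worker_grads:
        for row, g in sparse.items():
            embedding_store[row] = embedding_store.get(row, 0.0) - lr * g


replicas = [[1.0, 1.0], [1.0, 1.0]]     # two workers, two dense weights each
store = {}                              # embedding rows, scalar for brevity
grads = [([0.2, 0.4], {3: 1.0}),
         ([0.6, 0.0], {3: 1.0, 42: 2.0})]
hybrid_step(replicas, store, grads)
print(replicas)   # both replicas identical, approximately [0.96, 0.98]
print(store)      # only rows 3 and 42 exist, each approximately -0.2
```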

The reason is structural, not historical. An access pattern that touches a tiny, data-dependent fraction of an enormous model every step is fundamentally a sparse-lookup problem, and a sharded store that serves lookups and applies sparse updates is the natural data structure for it. No amount of interconnect bandwidth makes dense all-reduce of a half-terabyte mostly-zero gradient sensible. The embeddings-at-scale arc that begins here continues through approximate nearest neighbor search in Chapter 12 and vector databases in Chapter 25, and culminates in the full distributed recommendation case study of Chapter 38.

Research Frontier: Terabyte Embedding Systems (2024 to 2026)

The parameter server is very much a live research and engineering frontier, driven by recommendation models whose embedding tables now reach tens of terabytes. Meta's TorchRec and its underlying sharding planner (Ivchenko et al.) treat embedding-table placement across a heterogeneous GPU and host-memory hierarchy as an explicit optimization problem, and NVIDIA's HugeCTR and Merlin push embedding lookups through a GPU-resident cache backed by a parameter-server tier. A vigorous 2024 to 2026 line attacks the memory wall directly: learned hashing and compositional embeddings shrink the table, and hierarchical placement keeps hot rows in GPU memory while cold rows spill to host RAM or SSD, so that an effectively multi-terabyte model serves from a fraction of that in fast memory. The throughline is the one this section opened with: when the model is huge and the access is sparse, the winning systems still keep parameters in a sharded store and move only what a batch touches. We return to these systems with the full machinery in Chapter 38.
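The hot-row idea behind hierarchical placement can be sketched as a small cache in front of a large backing store. This is a hypothetical, simplified illustration, not the design of TorchRec or HugeCTR: the fast tier stands in for GPU memory, the slow tier for host RAM or SSD, and eviction is plain least-recently-used, whereas production systems add prefetching, batched write-back, and frequency-aware admission.

```python
from collections import OrderedDict


class TwoTierEmbeddingCache:
    """Hot rows in a small fast tier with LRU eviction; the rest in a large slow tier."""

    def __init__(self, capacity, slow_store):
        self.capacity = capacity
        self.fast = OrderedDict()      # row_id -> vector, most recently used last
        self.slow = slow_store         # row_id -> vector, the backing tier
        self.hits = self.misses = 0

    def lookup(self, row):
        if row in self.fast:
            self.hits += 1
            self.fast.move_to_end(row)            # mark as most recently used
            return self.fast[row]
        self.misses += 1
        vec = self.slow[row]                      # fetch from the slow tier
        self.fast[row] = vec
        if len(self.fast) > self.capacity:
            old_row, old_vec = self.fast.popitem(last=False)   # evict the LRU row
            self.slow[old_row] = old_vec          # write it back before dropping it
        return vec


slow = {i: [float(i)] * 4 for i in range(1_000)}   # 1,000 cold rows
cache = TwoTierEmbeddingCache(capacity=2, slow_store=slow)
for row in [7, 7, 7, 3, 7, 7, 500, 7, 7, 3]:       # skewed: row 7 is hot
    cache.lookup(row)
print(cache.hits, cache.misses, list(cache.fast))   # 6 4 [7, 3]
```

Because real access patterns are heavily skewed toward a few popular users and items, a fast tier far smaller than the table can serve most lookups.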

Library Shortcut: TorchRec Shards the Embedding Tables for You

Code 11.1.1 hand-rolled the store, the pull, and the push to expose the mechanics. In production you do not manage shards by hand. PyTorch's TorchRec library takes a declaration of your embedding tables and a planner, and it shards, places, and serves them across the available devices, generating the pull (lookup) and push (sparse gradient) traffic automatically:

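# pip install torchrec ; launch one process per GPU with torchrun
import os
import torch
import torch.distributed as dist
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
from torchrec.distributed import DistributedModelParallel

dist.init_process_group(backend="nccl")           # torchrun sets rank and world size
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# device="meta" declares the tables without allocating ~768 GB on this host;
# storage is created for each rank's shards when the model is sharded.
ebc = EmbeddingBagCollection(device=torch.device("meta"), tables=[
    EmbeddingBagConfig(name="user_id", num_embeddings=1_000_000_000,
                       embedding_dim=128, feature_names=["user_id"]),
    EmbeddingBagConfig(name="item_id", num_embeddings=500_000_000,
                       embedding_dim=128, feature_names=["item_id"]),
])
# With no explicit plan, the planner decides how to shard each table across
# devices; DistributedModelParallel wires up the lookup and gradient collectives.
model = DistributedModelParallel(ebc, device=torch.device("cuda"))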
Code 11.1.2: The same pull-push store as Code 11.1.1, now declared rather than implemented. The store, lookup, and sparse-update logic you would otherwise write and shard by hand collapses to a table declaration plus one DistributedModelParallel wrap; the library chooses the shard layout, routes lookups, and handles the sparse all-to-all that Section 11.6 unpacks.

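To build intuition for what a sharding planner decides, the sketch below places tables with a simple greedy heuristic. It is a hypothetical illustration, not TorchRec's planner, which enumerates sharding options (table-wise and row-wise among them) and scores them with a cost model. Here, largest tables go first; a table that fits on one device is placed whole on the device with the most free memory, and a table too large for any single device is split row-wise into equal pieces across all devices. The 400 GB per device and the creator_id table size are made-up figures.

```python
def greedy_plan(tables, device_free_bytes):
    """Toy placement heuristic; a hypothetical illustration, NOT TorchRec's planner.

    tables: dict name -> size in bytes. device_free_bytes: free bytes per device.
    """
    free = list(device_free_bytes)
    plan = {}
    for name, size in sorted(tables.items(), key=lambda kv: -kv[1]):
        best = max(range(len(free)), key=lambda d: free[d])
        if size <= free[best]:                              # fits whole: table-wise
            plan[name] = [(best, size)]
            free[best] -= size
        else:                                               # too big: row-wise split
            piece = -(-size // len(free))                   # ceiling division
            if any(piece > f for f in free):
                raise MemoryError(f"{name} does not fit even when split row-wise")
            plan[name] = [(d, piece) for d in range(len(free))]
            free = [f - piece for f in free]
    return plan


GB = 10**9
plan = greedy_plan({"user_id": 512 * GB, "item_id": 256 * GB, "creator_id": 40 * GB},
                   device_free_bytes=[400 * GB] * 4)
for name, pieces in plan.items():
    print(name, [(dev, size // GB) for dev, size in pieces])
# user_id [(0, 128), (1, 128), (2, 128), (3, 128)]
# item_id [(0, 256)]
# creator_id [(1, 40)]
```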
Practical Example: The Recommender That Would Not Fit on Any GPU

Who: A machine learning platform engineer at a streaming-media company rebuilding the homepage ranking model.

Situation: The new model added per-user, per-item, and per-creator embedding tables totaling 900 GB of parameters, while the dense ranking network on top was under 200 MB.

Problem: The team's existing all-reduce data-parallel pipeline required a full model replica on every GPU, and 900 GB fit on no accelerator they could buy.

Dilemma: Aggressively shrink the embeddings with hashing and lose ranking quality, or split the system so that the part that does not fit lives somewhere it can, at the cost of a more complex two-tier architecture.

Decision: They kept all-reduce for the tiny dense network and moved the embedding tables to a sharded parameter-server tier, because only the tables broke the replication assumption and only the tables had the sparse access pattern that pull and push exploit.

How: They declared the tables in TorchRec, let the planner shard them across host memory on the server nodes, and kept the dense network in standard DDP, so each step pulled the few thousand referenced rows and pushed sparse gradients to them.

Result: Training fit without shrinking a single table; per-step communication dropped by more than three orders of magnitude versus dense exchange of the full table, because only touched rows moved, exactly as Output 11.1.1 shows in miniature.

Lesson: Match the architecture to the access pattern. Dense parts want all-reduce; huge sparse tables want a sharded store with pull and push, and a real system uses both.

5. What This Chapter Builds Beginner

The rest of Chapter 11 turns the sketch above into a working architecture, one concern at a time. Section 11.2 formalizes the push-pull protocol and the worker-server interaction. Section 11.3 studies how to shard parameters across many servers and route requests to the right shard, the partitioning arc first met in Chapter 2. Section 11.4 contrasts synchronous and asynchronous update application, and Section 11.5 introduces bounded staleness, the controlled middle ground that caps how far behind any worker may fall. Sections 11.6 and 11.7 dive into sparse models and the terabyte-scale embedding tables of recommendation systems, the workload that justified the architecture in the first place. Section 11.8 confronts fault tolerance, because when the servers hold the only copy of the model, each shard is a single point of failure unless the store is engineered to survive losing it. Finally, Section 11.9 places the parameter server against all-reduce in modern systems, drawing the honest line between where each wins.
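As a preview of the bounded-staleness idea, the gate can be stated as a one-line check. This is a minimal sketch with hypothetical names, assuming each worker reports a clock equal to the number of steps it has completed; Section 11.5 develops the real protocol and its convergence consequences.

```python
def may_proceed(worker_clocks, worker, bound):
    """Bounded-staleness gate: may `worker` start its next step?

    worker_clocks maps each worker to the number of steps it has completed.
    A worker may proceed only if it is at most `bound` steps ahead of the
    slowest worker. bound=0 forces lockstep (synchronous) training; letting
    bound grow without limit recovers fully asynchronous training.
    """
    return worker_clocks[worker] - min(worker_clocks.values()) <= bound


clocks = {"w0": 5, "w1": 3, "w2": 4}
print(may_proceed(clocks, "w0", bound=2))   # True: 2 steps ahead of w1, within bound
print(may_proceed(clocks, "w0", bound=1))   # False: must wait for w1 to catch up
print(may_proceed(clocks, "w1", bound=0))   # True: the slowest worker never waits
```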

Fun Note: The Server That Remembers Everyone

An embedding table in a billion-user system is, in a real sense, a learned opinion about every single user and item that has ever appeared. The parameter server is where that institutional memory physically lives. When an engineer says a model "knows" a long-tail user, they mean one specific row, out of a billion, on one specific shard, holds a vector nobody has updated in weeks, waiting patiently to be pulled the next time that user comes back.

With the motivation settled, we can build the protocol. The next section makes the pull and push of Code 11.1.1 precise: what a worker requests, what a server returns, how updates are applied, and what guarantees each side may assume. That construction begins in Section 11.2.

Exercise 11.1.1: Which Architecture Fits? Conceptual

For each model, decide whether all-reduce data parallelism, a sharded parameter server, or a hybrid of both is the natural fit, and justify your choice from the two assumptions in Section 1 (does the model replicate, is the gradient dense): (a) a 7-billion-parameter language model where every parameter is updated every step; (b) a click-prediction model with a 4-billion-row embedding table and a 50 MB dense head; (c) a small image classifier fine-tuned on one GPU. Explain specifically which assumption each model honors or breaks.

Exercise 11.1.2: Make the Untouched Rows Move (and Then Stop Them) Coding

Start from Code 11.1.1. First replace the sparse push with a dense one that subtracts lr * (weights[i] - target) for every coordinate $i$ regardless of what any worker pulled, and confirm that previously untouched coordinates now drift. Then measure, for the sparse version, the total number of coordinate reads and writes across all twelve steps, and compare it to what a dense full-vector exchange would cost at $D = 20$ and again at $D = 10^9$. Report the ratio and explain why it grows without bound as the table grows while the access pattern stays fixed.

Exercise 11.1.3: The Bytes a Batch Actually Moves Analysis

An embedding table has $10^9$ rows of dimension $128$, stored as 4-byte floats. A training batch references $4{,}000$ distinct rows. Compute (a) the total size of the table, (b) the bytes a parameter-server worker pulls and pushes for one batch (rows referenced, two directions), and (c) the bytes a dense all-reduce of the full gradient would move per worker per step. Express (c) divided by (b) as the communication-saving factor of the parameter server for this workload, and state in one sentence why that factor is the quantitative reason the architecture exists. We make the underlying communication-cost model rigorous in Chapter 3.