"They told me to take a few steps on my own before checking in. I took a few hundred, found a lovely local optimum, and now nobody recognizes the model I came back with."
A Client That Drifted Too Far
Federated Averaging trains one shared model across machines that never share their data by replacing per-step gradient exchange with per-round model exchange: each round the server broadcasts the current model, every selected client improves it on its own data for several local steps, and the server folds the returned models back together with a size-weighted average. That single design choice, letting clients compute many local steps between communications, is what makes federated learning feasible over slow, intermittent links where exchanging a gradient every step would be hopeless. The cost is client drift: the more local work each device does, the further its model wanders toward its own data, and on non-identical data those wanderings do not average out cleanly. This section derives why the size-weighted average is the correct combine, measures the rounds-versus-drift trade-off in running code, and surveys the variants that buy back the accuracy that aggressive local computation gives away. The mental model above captures the whole loop in one image: broadcast, train locally on data that never leaves, average the returns.
The previous sections framed federated learning as training without centralizing data: hospitals, phones, and banks hold partitions that legal and physical constraints forbid moving to one place. That framing raises an immediate question. The data-parallel training of Section 1.1 assumed workers could exchange a gradient on every step over a fast interconnect. A fleet of phones on cellular links cannot do that; the network is slow, metered, and frequently unavailable. The whole apparatus of distributed training has to be rebuilt around a network that is the scarcest resource in the system. Federated Averaging, introduced by McMahan and colleagues in 2017, is the algorithm that performs that rebuild, and it remains the foundation that every later method modifies rather than replaces.
1. The Algorithm: One Round of FedAvg (Beginner)
FedAvg organizes training into communication rounds rather than gradient steps. At the start of round $t$ the server holds a global model $w^{(t)}$ and selects a subset $S_t$ of the available clients (on a fleet of millions of phones, only a small sample participates per round). It broadcasts $w^{(t)}$ to each selected client. Client $k$ then runs $E$ epochs of local stochastic gradient descent on its own dataset of $n_k$ examples, starting from the broadcast model, and arrives at an updated local model $w_k^{(t+1)}$. The clients send those models (not their data, and not raw gradients) back to the server, which combines them into the next global model. Figure 14.3.1 traces this loop end to end.
Two numbers control everything. The number of local epochs $E$ (which, together with the local batch size, fixes the number of local SGD steps per round; with full-batch steps the two coincide) sets how much computation each client does before it has to communicate. The number of rounds sets how many times the network is used. The art of FedAvg is spending local computation, which is cheap and private, to buy down communication rounds, which are expensive and unreliable. The rest of this section is about how far that trade goes before it breaks.
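The round described above, including the client-sampling step, can be sketched in a few lines. This is a minimal illustration rather than a reference implementation: the names `fedavg_round`, `make_local_update`, and `frac`, the least-squares local objective, and the toy client sizes are all assumptions made for the example.

```python
import numpy as np

def fedavg_round(w_global, client_data, local_update, frac, rng):
    """One FedAvg round: sample S_t, train locally, size-weighted average."""
    num_clients = len(client_data)
    m = max(1, int(round(frac * num_clients)))            # |S_t|
    selected = rng.choice(num_clients, size=m, replace=False)
    local_models, sizes = [], []
    for k in selected:
        Xk, yk = client_data[k]                           # data stays with client k
        local_models.append(local_update(w_global.copy(), Xk, yk))
        sizes.append(len(yk))
    weights = np.asarray(sizes, dtype=np.float64)
    weights /= weights.sum()                              # n_k / N, with N summed over S_t only
    w_next = np.tensordot(weights, np.stack(local_models), axes=1)
    return w_next, selected

def make_local_update(E, lr):
    """Hypothetical local trainer: E full-batch least-squares gradient steps."""
    def local_update(w, Xk, yk):
        for _ in range(E):
            w -= lr * Xk.T @ (Xk @ w - yk) / len(yk)
        return w
    return local_update

# Toy fleet: 10 clients of unequal size that share one noise-free linear model.
rng = np.random.default_rng(0)
w_true = np.array([2.0, -1.0, 0.5])
client_data = []
for n_k in rng.integers(20, 200, size=10):
    Xk = rng.standard_normal((n_k, 3))
    client_data.append((Xk, Xk @ w_true))

w = np.zeros(3)
local_update = make_local_update(E=5, lr=0.2)
for t in range(30):
    w, selected = fedavg_round(w, client_data, local_update, frac=0.3, rng=rng)
print("global model after 30 rounds:", np.round(w, 3))
```

Only 3 of the 10 clients participate in each round here, and the weights are normalized over the sampled set, matching the definition of $N$ as a sum over $S_t$.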
2. Why the Size-Weighted Average Is the Right Combine (Intermediate)
The server does not average the returned models equally; it weights each by the size of the client's dataset. With $N = \sum_k n_k$ total examples across the participating clients, the aggregation rule is
$$w^{(t+1)} = \sum_{k \in S_t} \frac{n_k}{N}\, w_k^{(t+1)}, \qquad N = \sum_{k \in S_t} n_k.$$
This weighting is not a heuristic; it is forced by the same gradient identity that justified data parallelism in Section 1.1. The federated objective is the average loss over all data, $L(w) = \frac{1}{N}\sum_k \sum_{i \in \mathcal{D}_k} \ell(w; x_i, y_i)$, which regroups by client into $L(w) = \sum_k \frac{n_k}{N} L_k(w)$, where $L_k$ is client $k$'s local average loss. Because the global loss is the size-weighted average of the local losses, its gradient is the size-weighted average of the local gradients: $\nabla L(w) = \sum_k \frac{n_k}{N}\nabla L_k(w)$. Consider the special case $E = 1$ with a single full-batch local step. Each client returns $w_k = w - \eta \nabla L_k(w)$, and the size-weighted average of those returned models is exactly $w - \eta \sum_k \frac{n_k}{N}\nabla L_k(w) = w - \eta \nabla L(w)$, one true gradient step on the global objective. FedAvg with one local step is centralized SGD; the size-weighting is what makes that equivalence hold, and a client with twice the data correctly pulls twice as hard. This is the same local-SGD-then-average primitive analyzed in Section 10.7, now run over a network where local work per round is the whole point rather than an optimization.
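The identity above is easy to check numerically. The sketch below assumes a least-squares loss and three clients with deliberately unequal sizes (both choices are illustrative, not from the text), and compares one pooled gradient step against the size-weighted and the plain unweighted average of one-step local models.

```python
import numpy as np

rng = np.random.default_rng(0)
d, eta = 5, 0.1
sizes = [500, 50, 10]                       # deliberately unequal n_k
Xs = [rng.standard_normal((n, d)) for n in sizes]
ys = [rng.standard_normal(n) for n in sizes]
N = sum(sizes)

def grad(w, X, y):
    """Gradient of the mean squared-error loss 0.5 * mean((Xw - y)^2)."""
    return X.T @ (X @ w - y) / len(y)

w = rng.standard_normal(d)                  # the broadcast model

# Centralized: one full-batch gradient step on the pooled data.
X_all, y_all = np.vstack(Xs), np.concatenate(ys)
w_central = w - eta * grad(w, X_all, y_all)

# Federated: one full-batch local step per client, then combine.
locals_ = np.stack([w - eta * grad(w, X, y) for X, y in zip(Xs, ys)])
w_weighted = np.tensordot(np.array(sizes) / N, locals_, axes=1)
w_unweighted = locals_.mean(axis=0)

print("size-weighted equals centralized:", np.allclose(w_weighted, w_central))
print("unweighted gap to centralized   :", np.linalg.norm(w_unweighted - w_central))
```

The size-weighted combine reproduces the centralized step up to floating-point error, while the unweighted mean gives the 10-example client the same vote as the 500-example one and lands somewhere else.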
FedAvg is exact only at $E = 1$ full-batch local step, where the size-weighted average of the local updates reconstructs one true gradient step on the global loss. For every $E > 1$ the clients optimize their own local objectives for a while before averaging, so the average of their endpoints is no longer the endpoint of any single global trajectory. That gap is client drift. It is zero when all clients hold identical data and grows with both $E$ and the dissimilarity of the data. Every FedAvg variant in this section is an attempt to enlarge the usable range of $E$ by shrinking that gap.
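A one-dimensional toy makes the gap concrete. The sketch assumes two equally sized clients with quadratic losses $L_k(w) = \frac{a_k}{2}(w - c_k)^2$; the curvatures $a_k$, minimizers $c_k$, and step size are made-up values chosen so that the clients disagree. It runs FedAvg to its fixed point for several values of $E$ and reports how far that fixed point sits from the true global minimizer.

```python
import numpy as np

a = np.array([1.0, 4.0])      # local curvatures (assumed values)
c = np.array([-1.0, 1.0])     # local minimizers: the clients disagree
p = np.array([0.5, 0.5])      # n_k / N for two equally sized clients
eta = 0.1

# Minimizer of the global loss sum_k p_k * L_k(w).
w_global_opt = np.sum(p * a * c) / np.sum(p * a)

def fedavg_limit(E, rounds=500):
    """Run FedAvg with E full-batch local steps until the global model settles."""
    w = 0.0
    for _ in range(rounds):
        w_k = np.full_like(c, w)              # broadcast to both clients
        for _ in range(E):
            w_k -= eta * a * (w_k - c)        # local gradient step on L_k
        w = float(np.sum(p * w_k))            # size-weighted average
    return w

gaps = {E: abs(fedavg_limit(E) - w_global_opt) for E in (1, 5, 20, 100)}
for E, gap in gaps.items():
    print(f"E={E:>3}: |FedAvg fixed point - global optimum| = {gap:.4f}")
```

In this toy, $E = 1$ lands on the global minimizer, and as $E$ grows each client converges to its own $c_k$, so the fixed point slides toward the plain average of the local minimizers rather than the minimizer of the average loss.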
3. The Communication-versus-Accuracy Trade-off of E (Intermediate)
Raising $E$ has two opposing effects. More local steps per round means the global model improves more per round, so fewer rounds are needed to reach a target, which is pure communication savings. But more local steps also let each client travel further toward its own local optimum, so the endpoints the server averages are more scattered, and on non-identical data that scatter degrades the averaged model. On identically distributed data the first effect dominates and large $E$ is close to free; on non-identical data the second effect eventually takes over, so the value of $E$ is settled empirically per deployment. The code below makes both effects visible at once. It runs FedAvg on identically distributed clients and counts the communication rounds needed to drive the global training loss to within $10^{-3}$ of the centralized optimum as $E$ increases, then measures how far the local models spread apart when $E$ is large.
import numpy as np

# Logistic-regression FedAvg on IID data. We count communication rounds needed to
# drive the GLOBAL training loss to within a small gap of the optimum, as a function
# of the number of local steps E. More local work per round means fewer rounds.
rng = np.random.default_rng(7)
N, d, C = 20_000, 30, 20  # examples, features, clients
w_star = rng.standard_normal(d)
X = rng.standard_normal((N, d))
y = (rng.random(N) < 1.0 / (1.0 + np.exp(-(X @ w_star)))).astype(np.float64)

# IID split: shuffle, then chunk into C equal client shards.
perm = rng.permutation(N)
shards = np.array_split(perm, C)
n_k = np.array([len(s) for s in shards], dtype=np.float64)  # local dataset sizes

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-np.clip(z, -30, 30)))

def loss(w):
    z = X @ w
    return float(np.mean(np.logaddexp(0.0, z) - y * z))

def local_train(w, idx, E, lr):
    """E full-batch gradient steps on one client's data, starting from the global w."""
    w = w.copy()
    Xk, yk = X[idx], y[idx]
    for _ in range(E):
        w -= lr * (Xk.T @ (sigmoid(Xk @ w) - yk) / len(idx))
    return w

# Reference optimum: many centralized steps on the full dataset.
w_opt = np.zeros(d)
for _ in range(4000):
    w_opt -= 0.5 * (X.T @ (sigmoid(X @ w_opt) - y) / N)
L_opt = loss(w_opt)

def run_fedavg(E, lr=0.5, gap=1e-3, max_rounds=500):
    w = np.zeros(d)  # global model
    for r in range(1, max_rounds + 1):
        locals_ = [local_train(w, s, E, lr) for s in shards]  # broadcast + local SGD
        # Weighted average: sum_k (n_k / N) w_k, the FedAvg aggregation rule.
        w = np.tensordot(n_k / n_k.sum(), np.stack(locals_), axes=1)
        if loss(w) - L_opt < gap:
            return r
    return max_rounds

print(f"optimal training loss L* = {L_opt:.4f}")
print(f"{'E (local steps)':>16} | {'rounds to L*+1e-3':>18}")
print("-" * 39)
for E in (1, 2, 5, 20):
    print(f"{E:>16} | {run_fedavg(E):>18}")

# Drift glimpse: with very many local steps each client moves far toward its own
# shard optimum, so the spread of local models around their average grows.
print()
for E in (1, 50):
    locals_ = np.stack([local_train(np.zeros(d), s, E, 0.5) for s in shards])
    drift = float(np.mean(np.linalg.norm(locals_ - locals_.mean(0), axis=1)))
    print(f"mean client drift at E={E:>2}: {drift:.3f}")
optimal training loss L* = 0.2309
E (local steps) | rounds to L*+1e-3
---------------------------------------
1 | 347
2 | 174
5 | 70
20 | 17
mean client drift at E= 1: 0.041
mean client drift at E=50: 0.354
The rounds-to-target numbers in Output 14.3.1 are the entire reason federated learning is practical: a 20-fold cut in communication rounds is the difference between a model that finishes training on a fleet of phones and one that never does. On these identically distributed clients the drift stays modest even at $E=50$, and at $E=20$ the averaged model still reaches the centralized optimum's loss to within $10^{-3}$, which is why large $E$ looks free here. The trouble starts when clients hold different distributions, the realistic case, where each local optimum sits in a different place, so the same local work pulls the client models apart and the average of scattered endpoints is a worse model. That failure mode is the entire subject of Section 14.4; the variants below are the defenses built against it.
FedAvg with a large $E$ is the distributed-systems version of a group project where everyone disappears for two weeks and then merges their work the night before. If everyone built the same thing (identically distributed data), the merge is painless and you saved two weeks of meetings. If everyone interpreted the assignment differently (non-identical data), you spend the merge night discovering that the four halves do not fit together, which is client drift wearing a hoodie.
4. Variants That Correct Client Drift (Advanced)
Because drift is what limits how large $E$ can be, the major FedAvg variants are each a different mechanism for keeping local training honest. Three families dominate. FedProx (Li et al., 2020) adds a proximal term $\frac{\mu}{2}\lVert w - w^{(t)}\rVert^2$ to each client's local objective, penalizing local models that wander far from the broadcast global model; this softly tethers every client to the consensus and lets the system tolerate both larger $E$ and clients that finish different amounts of work. SCAFFOLD (Karimireddy et al., 2020) attacks drift at its source with control variates: the server and each client maintain correction vectors that estimate the difference between the client's local gradient direction and the global one, and subtract it during local steps, so a client's drift is actively cancelled rather than merely penalized. The third family changes the server side rather than the client side: FedAdam and the broader FedOpt framework (Reddi et al., 2021) treat the average of client updates as a pseudo-gradient and feed it to an adaptive server optimizer such as Adam, which adds momentum and per-coordinate scaling across rounds and noticeably accelerates convergence on heterogeneous data. Table 14.3.1 lays the three families side by side.
| Method | What it changes | Mechanism against drift | Cost it adds |
|---|---|---|---|
| FedAvg | nothing (the baseline) | none; relies on small $E$ | none |
| FedProx | client objective | proximal pull toward the global model | one tuning knob $\mu$ |
| SCAFFOLD | client update direction | control variates cancel drift | extra control vector per client, doubled upload |
| FedAdam / FedOpt | server aggregation | adaptive momentum across rounds | server optimizer state |
The variants are not mutually exclusive; a production system might run FedProx on the clients and FedAdam on the server at once. What they share is the structure of Figure 14.3.1, which is why FedAvg is the right thing to understand first: every one of these methods is a modification of one of its three steps, not a replacement for the loop.
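Two of the rows in Table 14.3.1 reduce to a few lines each. The sketch below is illustrative only: `fedprox_local_train` adds the gradient of the proximal term $\frac{\mu}{2}\lVert w - w^{(t)}\rVert^2$ to an assumed least-squares local loss, and `FedAdamServer` applies an Adam-style update to the averaged client update in the general shape of the FedOpt framework. Function names, the toy data, and every hyperparameter value are assumptions chosen for the example, not recommended settings.

```python
import numpy as np

def fedprox_local_train(w_global, X, y, E, lr, mu):
    """E full-batch steps on L_k(w) + (mu/2) * ||w - w_global||^2 (least-squares L_k)."""
    w = w_global.copy()
    for _ in range(E):
        grad_loss = X.T @ (X @ w - y) / len(y)
        grad_prox = mu * (w - w_global)       # pulls w back toward the broadcast model
        w -= lr * (grad_loss + grad_prox)
    return w

class FedAdamServer:
    """Server-side Adam on the pseudo-gradient (averaged client update)."""

    def __init__(self, w0, server_lr=0.1, beta1=0.9, beta2=0.99, tau=1e-3):
        self.w = w0.copy()
        self.m = np.zeros_like(w0)            # first-moment estimate
        self.v = np.zeros_like(w0)            # second-moment estimate
        self.server_lr, self.beta1, self.beta2, self.tau = server_lr, beta1, beta2, tau

    def step(self, local_models, sizes):
        p = np.asarray(sizes, dtype=np.float64) / np.sum(sizes)
        # Pseudo-gradient: size-weighted average model minus the current global model.
        delta = np.tensordot(p, np.stack(local_models), axes=1) - self.w
        self.m = self.beta1 * self.m + (1 - self.beta1) * delta
        self.v = self.beta2 * self.v + (1 - self.beta2) * delta**2
        self.w = self.w + self.server_lr * self.m / (np.sqrt(self.v) + self.tau)
        return self.w

# FedProx: the proximal term keeps a long local run closer to the broadcast model.
rng = np.random.default_rng(1)
w_global = np.zeros(4)
X = rng.standard_normal((100, 4))
y = X @ np.array([3.0, -2.0, 1.0, 0.5])
dist = {}
for mu in (0.0, 1.0):
    w_local = fedprox_local_train(w_global, X, y, E=50, lr=0.1, mu=mu)
    dist[mu] = float(np.linalg.norm(w_local - w_global))
    print(f"mu={mu}: distance from broadcast model = {dist[mu]:.3f}")

# FedAdam: one server step on two returned client models of sizes 30 and 10.
server = FedAdamServer(np.zeros(2))
w_next = server.step([np.array([1.0, -1.0]), np.array([3.0, -3.0])], sizes=[30, 10])
print("global model after one FedAdam step:", np.round(w_next, 4))
```

Setting `mu=0.0` recovers plain FedAvg local training, which is one way to see FedProx as a modification of the local step rather than a new loop. The server class likewise leaves the clients untouched and changes only what happens to the averaged update.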
The size-weighted average of Section 1.1 returns here as the heart of FedAvg, but the network it crosses has changed character completely. In data-parallel deep learning the all-reduce runs every step over a fast interconnect; in federated learning the same average runs once per round over slow, unreliable, partial-participation links, which is exactly why the algorithm pushes so much computation into the local phase between communications. The combine is the same operation the whole book is built around; the engineering shifts from making it fast to making it rare. That rarity is also what lets a secure-aggregation protocol sit on top of the average so the server learns only the sum, a privacy property we develop in Section 14.6 and whose cost we weigh against the communication-cost models of Chapter 3.
Code 14.3.1 implemented broadcast, local training, and weighted aggregation by hand. Production federated frameworks supply that loop and the variants as configured strategies, so switching from FedAvg to FedAdam is a one-line change. In Flower, the client implements only local training and the server picks a strategy; the round orchestration, client sampling, and size-weighted aggregation are built in:
# pip install flwr
import flwr as fl

# set_weights, get_weights, train_local, model, and local_data are placeholders
# for your own training code; they are not part of Flower.
class Client(fl.client.NumPyClient):
    def fit(self, parameters, config):
        set_weights(model, parameters)
        # Assumes the server's fit config carries the key "E" (local epochs).
        train_local(model, local_data, epochs=config["E"])  # the only client code you write
        return get_weights(model), len(local_data), {}      # weight = n_k for the average

# FedAvg is the default; swap one line for FedAdam, FedProx, etc.
strategy = fl.server.strategy.FedAvg()  # or FedAdam(...), FedProx(...)
fl.server.start_server(strategy=strategy,
                       config=fl.server.ServerConfig(num_rounds=100))
The Flower version reduces to a fit method and a one-line strategy; Flower handles client selection, round management, the size-weighted average (it reads each client's returned example count as the weight), and the swap to drift-correcting variants. TensorFlow Federated (TFF) offers the same loop in a more research-oriented, functional API.
Who: An on-device machine learning team training a next-word prediction model across a fleet of consumer phones.
Situation: Each round cost real money and battery: phones train only while charging on Wi-Fi, and a round takes minutes of wall-clock as the server waits for enough clients to report in.
Problem: At one local epoch per round the model needed hundreds of rounds to converge, stretching training over weeks because so few rounds complete per day.
Dilemma: Raise the local epoch count $E$ to cut the number of rounds, mirroring the 20-fold reduction in Output 14.3.1, or keep $E$ small to avoid client drift, since each user's typing distribution is sharply non-identical.
Decision: They raised $E$ to a moderate value and adopted FedProx, using the proximal term to tether drifting clients so the larger $E$ did not wreck accuracy on the heterogeneous typing data.
How: They swept $E$ on a held-out shadow population, watched the global validation loss for the drift-induced plateau, and set the proximal weight $\mu$ just high enough to remove it.
Result: Rounds-to-target fell by roughly an order of magnitude while accuracy held, turning a multi-week training cycle into a few days without any user data leaving a device.
Lesson: On non-identical data, $E$ and the drift correction are tuned together: the proximal tether is what lets you spend the large $E$ that the communication budget demands.
5. Where FedAvg Sits in the Larger Story (Intermediate)
FedAvg is a star-topology algorithm: every round funnels through a central server that broadcasts and aggregates. That server is a single point of coordination and, for the size-weighted average, a single point of trust. The decentralized methods later in this chapter remove the server entirely and let clients average with neighbors over a gossip graph, trading the server's global view for fault tolerance and a flatter trust model. The drift you measured in Output 14.3.1 reappears there in a different guise, as the slowness of consensus across a sparse graph. For now, the server-based FedAvg loop is the workhorse: it is what runs on production phone fleets and federated medical networks, and its $E$ knob and drift-correcting variants are the levers every federated deployment actually pulls.
The pressing current question is how to run FedAvg-style training when the model is a billion-parameter foundation model and the clients are bandwidth-limited edge devices. Uploading a full model per round is infeasible, so 2024 to 2026 work pushes parameter-efficient federated tuning: clients train and exchange only low-rank LoRA adapters rather than full weights, slashing the per-round upload by orders of magnitude while the frozen backbone stays put. A parallel line attacks the heterogeneity that drift-correction only partly fixes, with personalized and clustered federated methods that let groups of similar clients share a head while a common body is averaged. On the systems side, asynchronous and buffered aggregation schemes such as FedBuff relax FedAvg's synchronous round barrier so that slow or dropped clients no longer stall a round, an idea that connects directly to the bounded-staleness optimization of Chapter 10 and the communication-cost models of Chapter 3. The throughline is unchanged: shrink what crosses the network per round, because the network is still the binding constraint.
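The upload saving from exchanging adapters instead of full weights is simple arithmetic. For a weight matrix of shape $d_{\text{out}} \times d_{\text{in}}$, a rank-$r$ LoRA adapter trains two factors of shapes $d_{\text{out}} \times r$ and $r \times d_{\text{in}}$, so a client uploads $r(d_{\text{out}} + d_{\text{in}})$ numbers instead of $d_{\text{out}} d_{\text{in}}$. The layer size and rank below are hypothetical values picked only to show the scale.

```python
def full_matrix_params(d_out: int, d_in: int) -> int:
    """Parameters uploaded if the client sends the whole weight matrix."""
    return d_out * d_in

def lora_adapter_params(d_out: int, d_in: int, rank: int) -> int:
    """Parameters uploaded if the client sends only the two low-rank factors."""
    # LoRA trains B (d_out x rank) and A (rank x d_in); the frozen matrix is never sent.
    return rank * (d_out + d_in)

# Hypothetical layer: one 4096 x 4096 projection with a rank-8 adapter.
d_out = d_in = 4096
rank = 8
full = full_matrix_params(d_out, d_in)
lora = lora_adapter_params(d_out, d_in, rank)
reduction = full / lora

print(f"full matrix : {full:,} parameters per round")
print(f"LoRA adapter: {lora:,} parameters per round")
print(f"upload reduction for this layer: {reduction:.0f}x")
```

The ratio is $d_{\text{out}} d_{\text{in}} / \bigl(r(d_{\text{out}} + d_{\text{in}})\bigr)$, so it grows with layer width and shrinks with rank; the per-model saving depends on which layers receive adapters.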
We now have the core algorithm, the proof that its size-weighted average is the correct combine, a measured account of the rounds-versus-drift trade-off of $E$, and the variants that extend its reach. Every one of those variants exists because of one assumption our experiment quietly kept and real deployments break: that the clients hold identically distributed data. Real federated clients never do. The next section confronts that head-on, showing exactly how non-identical data breaks the clean averaging picture and what the cure costs. That confrontation begins in Section 14.4.
Section 2 showed that FedAvg with one full-batch local step per round equals one step of centralized SGD on the global loss. State precisely which two conditions on the local update make that equivalence exact, and explain what specifically breaks it when (a) clients take more than one local step, and (b) the server uses a plain unweighted average instead of the size-weighted one. For case (b), describe a small two-client example, one client with many examples and one with few, where the unweighted average visibly pulls the global model toward the wrong place.
Modify Code 14.3.1 so the client shards are non-identically distributed: sort the examples by their label before splitting, so each client sees a skewed label mix instead of an IID sample. Rerun the rounds-to-target experiment for the same values of $E$ and report what happens to both the round counts and the final loss gap. Then re-measure client drift at $E=50$ and compare it to the IID value of 0.354 from Output 14.3.1. Explain in two sentences why raising $E$ helped under IID data but hurts here.
Suppose each communication round costs a fixed $T_{\text{comm}}$ seconds (broadcast plus upload of one model) and each local step costs $T_{\text{local}}$ seconds, so a round with $E$ local steps takes $T_{\text{comm}} + E\,T_{\text{local}}$. Using the rounds-to-target numbers from Output 14.3.1 (347, 174, 70, 17 for $E = 1, 2, 5, 20$), write the total wall-clock time to reach the target as a function of $E$, and find which $E$ minimizes it when $T_{\text{comm}} = 60$ seconds and $T_{\text{local}} = 1$ second. Then redo it for $T_{\text{comm}} = 2$ seconds and explain why the optimal $E$ drops when communication gets cheap. This is the same communication-versus-computation balance quantified in Chapter 3.