Part III: Distributed Machine Learning
Chapter 14: Federated and Decentralized Learning

FedAvg and Its Variants

"They told me to take a few steps on my own before checking in. I took a few hundred, found a lovely local optimum, and now nobody recognizes the model I came back with."

A Client That Drifted Too Far
Mental model: One Federated Round (FedAvg)
Mental model. The hub broadcasts the global model to devices, each trains locally on data that never leaves, and the hub averages the returned updates into a new model. Training distributes while data stays put.
Big Picture

Federated Averaging trains one shared model across machines that never share their data by replacing per-step gradient exchange with per-round model exchange: each round the server broadcasts the current model, every selected client improves it on its own data for several local steps, and the server folds the returned models back together with a size-weighted average. That single design choice, letting clients compute many local steps between communications, is what makes federated learning feasible over slow, intermittent links where exchanging a gradient every step would be hopeless. The cost is client drift: the more local work each device does, the further its model wanders toward its own data, and on non-identical data those wanderings do not average out cleanly. This section derives why the size-weighted average is the correct combine, measures the rounds-versus-drift trade-off in running code, and surveys the variants that buy back the accuracy that aggressive local computation gives away. The mental model above captures the whole loop in one image: broadcast, train locally on data that never leaves, average the returns.

The previous sections framed federated learning as training without centralizing data: hospitals, phones, and banks hold partitions that legal and physical constraints forbid moving to one place. That framing raises an immediate question. The data-parallel training of Section 1.1 assumed workers could exchange a gradient on every step over a fast interconnect. A fleet of phones on cellular links cannot do that; the network is slow, metered, and frequently unavailable. The whole apparatus of distributed training has to be rebuilt around a network that is the scarcest resource in the system. Federated Averaging, introduced by McMahan and colleagues in 2017, is the algorithm that performs that rebuild, and it remains the foundation that every later method modifies rather than replaces.

1. The Algorithm: One Round of FedAvg Beginner

FedAvg organizes training into communication rounds rather than gradient steps. At the start of round $t$ the server holds a global model $w^{(t)}$ and selects a subset $S_t$ of the available clients (on a fleet of millions of phones, only a small sample participates per round). It broadcasts $w^{(t)}$ to each selected client. Client $k$ then runs $E$ epochs of local stochastic gradient descent on its own dataset of $n_k$ examples, starting from the broadcast model, and arrives at an updated local model $w_k^{(t+1)}$. The clients send those models (not their data, and not raw gradients) back to the server, which combines them into the next global model. Figure 14.3.1 traces this loop end to end.

[Figure 14.3.1 diagram: the server holds the global model $w^{(t)}$ and (1) broadcasts it to Clients 1 through $K$; each client $k$ runs $E$ local SGD epochs on its $n_k$ private examples and returns $w_k^{(t+1)}$; (2) the clients upload their local models; (3) the server forms the weighted average $w^{(t+1)} = \sum_k (n_k / N)\, w_k^{(t+1)}$ and starts the next round.]
Figure 14.3.1: One round of Federated Averaging. The server broadcasts the global model $w^{(t)}$ (blue arrows), each selected client runs $E$ epochs of local SGD on data that never leaves the device (orange boxes), the clients upload only their updated models (orange arrows), and the server folds them into $w^{(t+1)}$ with the size-weighted average before starting the next round. Raw data and per-step gradients stay local; only models cross the network.
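
The loop in Figure 14.3.1, including the client-sampling step, can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not a production implementation: the clients hold synthetic least-squares data, local training is full-batch gradient descent, and the names `local_update` and `fedavg_round`, the sampling fraction `frac=0.3`, and all sizes are hypothetical choices made only for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
d, K = 5, 10                                   # features, total clients (toy values)
w_true = rng.standard_normal(d)

# Each client holds a private least-squares dataset of a different size n_k.
clients = []
for n in rng.integers(50, 200, size=K):
    Xk = rng.standard_normal((n, d))
    yk = Xk @ w_true + 0.01 * rng.standard_normal(n)
    clients.append((Xk, yk))

def local_update(w, Xk, yk, E, lr):
    """E full-batch gradient steps on one client's squared loss, from the broadcast w."""
    w = w.copy()
    for _ in range(E):
        w -= lr * Xk.T @ (Xk @ w - yk) / len(yk)
    return w

def fedavg_round(w, clients, frac, E, lr, rng):
    """Sample S_t, train locally, return the size-weighted average over S_t only."""
    m = max(1, int(frac * len(clients)))
    chosen = rng.choice(len(clients), size=m, replace=False)      # the subset S_t
    models = np.stack([local_update(w, *clients[k], E, lr) for k in chosen])
    sizes = np.array([len(clients[k][1]) for k in chosen], dtype=float)
    return (sizes / sizes.sum()) @ models                         # sum_k (n_k / N) w_k

w = np.zeros(d)
for t in range(50):
    w = fedavg_round(w, clients, frac=0.3, E=5, lr=0.1, rng=rng)
err = float(np.linalg.norm(w - w_true))
print(f"distance to w_true after 50 rounds: {err:.4f}")
```

Note that $N$ is recomputed each round from the sampled clients only, which matches the aggregation rule's sum over $S_t$; only the model vectors and the example counts leave `local_update`.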

Two numbers control everything. The number of local epochs $E$ (which, for a fixed batch size, determines the number of local SGD steps per round) sets how much computation each client does before it has to communicate. The number of rounds sets how many times the network is used. The art of FedAvg is spending local computation, which is cheap and private, to buy down communication rounds, which are expensive and unreliable. The rest of this section is about how far that trade goes before it breaks.

2. Why the Size-Weighted Average Is the Right Combine Intermediate

The server does not average the returned models equally; it weights each by the size of the client's dataset. With $N = \sum_k n_k$ total examples across the participating clients, the aggregation rule is

$$w^{(t+1)} = \sum_{k \in S_t} \frac{n_k}{N}\, w_k^{(t+1)}, \qquad N = \sum_{k \in S_t} n_k.$$

This weighting is not a heuristic; it is forced by the same gradient identity that justified data parallelism in Section 1.1. The federated objective is the average loss over all data, $L(w) = \frac{1}{N}\sum_k \sum_{i \in \mathcal{D}_k} \ell(w; x_i, y_i)$, which regroups by client into $L(w) = \sum_k \frac{n_k}{N} L_k(w)$, where $L_k$ is client $k$'s local average loss. Because the global loss is the size-weighted average of the local losses, its gradient is the size-weighted average of the local gradients: $\nabla L(w) = \sum_k \frac{n_k}{N}\nabla L_k(w)$. Consider the special case $E = 1$ with a single full-batch local step of size $\eta$. Each client returns $w_k = w - \eta \nabla L_k(w)$, and because the weights $\frac{n_k}{N}$ sum to one, the size-weighted average of those returned models is exactly $w - \eta \sum_k \frac{n_k}{N}\nabla L_k(w) = w - \eta \nabla L(w)$, one true gradient step on the global objective. FedAvg with one full-batch local step is therefore centralized gradient descent on the participating clients' pooled data (centralized SGD with the batch set to everything those clients hold); the size-weighting is what makes that equivalence hold, and a client with twice the data correctly pulls twice as hard. This is the same local-SGD-then-average primitive analyzed in Section 10.7, now run over a network where local work per round is the whole point rather than an optimization.
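
The one-step equivalence is easy to check numerically. The sketch below assumes a squared-loss model and three synthetic clients of deliberately unequal size (both are illustrative choices, not part of the argument above); it compares the size-weighted and the plain unweighted average of the one-step local models against a single gradient step on the pooled data.

```python
import numpy as np

rng = np.random.default_rng(1)
d, lr = 4, 0.1
sizes = [500, 50, 10]                          # deliberately unequal n_k (toy values)
data = [(rng.standard_normal((n, d)), rng.standard_normal(n)) for n in sizes]

def grad(w, X, y):
    """Gradient of the average squared loss (1 / 2n) * ||Xw - y||^2."""
    return X.T @ (X @ w - y) / len(y)

w = rng.standard_normal(d)                     # current global model
n_k = np.array(sizes, dtype=float)
p = n_k / n_k.sum()                            # the weights n_k / N

# One full-batch local step per client, all starting from the same broadcast w.
local_models = np.stack([w - lr * grad(w, X, y) for X, y in data])

# One centralized step on the pooled data.
X_all = np.vstack([X for X, _ in data])
y_all = np.concatenate([y for _, y in data])
w_central = w - lr * grad(w, X_all, y_all)

w_weighted = p @ local_models                  # FedAvg aggregation rule
w_uniform = local_models.mean(axis=0)          # plain unweighted average

print("weighted   matches centralized step:", np.allclose(w_weighted, w_central))
print("unweighted matches centralized step:", np.allclose(w_uniform, w_central))
```

The weighted average should agree with the centralized step to floating-point precision, while the unweighted one should not: it gives the ten-example client the same say as the five-hundred-example client.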

Key Insight: Equality at One Local Step, Approximation Beyond It

FedAvg is exact only at $E = 1$ full-batch local step, where the size-weighted average of the local updates reconstructs one true gradient step on the global loss. For every $E > 1$ the clients optimize their own local objectives for a while before averaging, so the average of their endpoints is no longer the endpoint of any single global trajectory. That gap is client drift. It is zero when all clients hold identical data and grows with both $E$ and the dissimilarity of the data. Every FedAvg variant in this section is an attempt to enlarge the usable range of $E$ by shrinking that gap.
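
A one-dimensional toy makes the gap concrete. The sketch below assumes two equal-size clients with quadratic losses $L_k(w) = \frac{1}{2} a_k (w - c_k)^2$; the curvatures `a`, the minimizers `c`, and the helper name `fedavg_limit` are hypothetical values chosen for illustration. It runs FedAvg to its limit point for several $E$ and reports the distance to the true global minimizer, once when the clients disagree about the minimizer and once when they agree.

```python
import numpy as np

# Toy assumption: client k has the 1-D loss 0.5 * a[k] * (w - c[k])**2 and both
# clients are the same size, so the global minimizer is sum(a * c) / sum(a).
a = np.array([1.0, 4.0])        # local curvatures (hypothetical values)
lr = 0.1

def fedavg_limit(c, E, rounds=500):
    """Run FedAvg with E local gradient steps per round and return where it settles."""
    w = 0.0
    for _ in range(rounds):
        local = np.full_like(a, w)                 # broadcast the global model
        for _ in range(E):
            local -= lr * a * (local - c)          # each client steps on its own loss
        w = local.mean()                           # equal sizes: plain mean = weighted mean
    return w

c_diff = np.array([0.0, 1.0])   # clients disagree about the minimizer
c_same = np.array([1.0, 1.0])   # clients agree about the minimizer

gaps = {}
for E in (1, 5, 50):
    gap_diff = abs(fedavg_limit(c_diff, E) - (a * c_diff).sum() / a.sum())
    gap_same = abs(fedavg_limit(c_same, E) - (a * c_same).sum() / a.sum())
    gaps[E] = (gap_diff, gap_same)
    print(f"E={E:>2}  gap, different minimizers: {gap_diff:.4f}   gap, same minimizer: {gap_same:.4f}")
```

In this toy the limit point coincides with the global minimizer at $E = 1$ and slides toward the unweighted mean of the local minimizers as $E$ grows, while the gap stays at zero whenever the clients share a minimizer. The latter is a looser condition than identical data, but it shows the same mechanism: drift costs nothing when every client is pulled toward the same place.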

3. The Communication-versus-Accuracy Trade-off of E Intermediate

Raising $E$ has two opposing effects. More local steps per round means the global model improves more per round, so fewer rounds are needed to reach a target, which is pure communication savings. But more local steps also let each client travel further toward its own local optimum, so the endpoints the server averages are more scattered, and on non-identical data that scatter degrades the averaged model. On identically distributed data the first effect dominates and large $E$ is close to free; on non-identical data the second can dominate, so in practice the value of $E$ is settled empirically per deployment. The code below makes both effects visible at once. It runs FedAvg on identically distributed clients and counts the communication rounds needed to drive the global training loss to within $10^{-3}$ of the centralized optimum as $E$ increases, then measures how far the local models spread apart when $E$ is large. In the code, $E$ counts full-batch local gradient steps, so each step is one pass over the client's shard.

import numpy as np

# Logistic-regression FedAvg on IID data. We count communication rounds needed to
# drive the GLOBAL training loss to within a small gap of the optimum, as a function
# of the number of local steps E. More local work per round means fewer rounds.
rng = np.random.default_rng(7)
N, d, C = 20_000, 30, 20            # examples, features, clients
w_star = rng.standard_normal(d)
X = rng.standard_normal((N, d))
y = (rng.random(N) < 1.0 / (1.0 + np.exp(-(X @ w_star)))).astype(np.float64)

# IID split: shuffle, then chunk into C equal client shards.
perm = rng.permutation(N)
shards = np.array_split(perm, C)
n_k = np.array([len(s) for s in shards], dtype=np.float64)   # local dataset sizes

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-np.clip(z, -30, 30)))

def loss(w):
    z = X @ w
    return float(np.mean(np.logaddexp(0.0, z) - y * z))

def local_train(w, idx, E, lr):
    """E full-batch gradient steps on one client's data, starting from the global w."""
    w = w.copy()
    Xk, yk = X[idx], y[idx]
    for _ in range(E):
        w -= lr * (Xk.T @ (sigmoid(Xk @ w) - yk) / len(idx))
    return w

# Reference optimum: many centralized steps on the full dataset.
w_opt = np.zeros(d)
for _ in range(4000):
    w_opt -= 0.5 * (X.T @ (sigmoid(X @ w_opt) - y) / N)
L_opt = loss(w_opt)

def run_fedavg(E, lr=0.5, gap=1e-3, max_rounds=500):
    w = np.zeros(d)                                       # global model
    for r in range(1, max_rounds + 1):
        locals_ = [local_train(w, s, E, lr) for s in shards]   # broadcast + local SGD
        # Weighted average: sum_k (n_k / N) w_k, the FedAvg aggregation rule.
        w = np.tensordot(n_k / n_k.sum(), np.stack(locals_), axes=1)
        if loss(w) - L_opt < gap:
            return r
    return max_rounds

print(f"optimal training loss L* = {L_opt:.4f}")
print(f"{'E (local steps)':>16} | {'rounds to L*+1e-3':>18}")
print("-" * 39)
for E in (1, 2, 5, 20):
    print(f"{E:>16} | {run_fedavg(E):>18}")

# Drift glimpse: with very many local steps each client moves far toward its own
# shard optimum, so the spread of local models around their average grows.
print()
for E in (1, 50):
    locals_ = np.stack([local_train(np.zeros(d), s, E, 0.5) for s in shards])
    drift = float(np.mean(np.linalg.norm(locals_ - locals_.mean(0), axis=1)))
    print(f"mean client drift at E={E:>2}: {drift:.3f}")
Code 14.3.1: A complete FedAvg loop in pure NumPy: broadcast, $E$ local steps per client, size-weighted aggregation, repeat. The first experiment counts rounds to a fixed loss target as $E$ grows; the second measures how far the local models drift apart when $E$ is large.
optimal training loss L* = 0.2309
 E (local steps) |  rounds to L*+1e-3
---------------------------------------
               1 |                347
               2 |                174
               5 |                 70
              20 |                 17

mean client drift at E= 1: 0.041
mean client drift at E=50: 0.354
Output 14.3.1: Communication rounds to reach the target collapse from 347 at $E=1$ to 17 at $E=20$, a roughly 20-fold reduction in network usage on identically distributed data. The drift measurement shows the cost waiting at large $E$: the average distance of a client model from the consensus grows from 0.041 at one local step to 0.354 at fifty, the seed of the degradation that bites once data stops being identical.

The rounds-to-target numbers in Output 14.3.1 are the entire reason federated learning is practical: a roughly 20-fold cut in communication rounds is the difference between a model that finishes training on a fleet of phones and one that never does. On these identically distributed clients the averaged model at $E=20$ still lands within $10^{-3}$ of the centralized optimum's loss, so whatever drift has accumulated (Output 14.3.1 measures it only at $E=1$ and $E=50$) is not yet hurting, which is why large $E$ looks free here. The trouble starts when clients hold different distributions, the realistic case, where the same local training pulls each client model toward a different local optimum and the average of scattered endpoints is a worse model. That failure mode is the entire subject of Section 14.4; the variants below are the defenses built against it.

Fun Note: The Group Project Failure Mode

FedAvg with a large $E$ is the distributed-systems version of a group project where everyone disappears for two weeks and then merges their work the night before. If everyone built the same thing (identically distributed data), the merge is painless and you saved two weeks of meetings. If everyone interpreted the assignment differently (non-identical data), you spend the merge night discovering that the four halves do not fit together, which is client drift wearing a hoodie.

4. Variants That Correct Client Drift Advanced

Because drift is what limits how large $E$ can be, the major FedAvg variants are each a different mechanism for keeping local training honest. Three families dominate. FedProx (Li et al., 2020) adds a proximal term $\frac{\mu}{2}\lVert w - w^{(t)}\rVert^2$ to each client's local objective, penalizing local models that wander far from the broadcast global model; this softly tethers every client to the consensus and lets the system tolerate both larger $E$ and clients that finish different amounts of work. SCAFFOLD (Karimireddy et al., 2020) attacks drift at its source with control variates: the server and each client maintain correction vectors that estimate the difference between the client's local gradient direction and the global one, and subtract it during local steps, so a client's drift is actively cancelled rather than merely penalized. The third family changes the server side rather than the client side: FedAdam and the broader FedOpt framework (Reddi et al., 2021) treat the average of client updates as a pseudo-gradient and feed it to an adaptive server optimizer such as Adam, which adds momentum and per-coordinate scaling across rounds and noticeably accelerates convergence on heterogeneous data. Table 14.3.1 lays the three families side by side.

Table 14.3.1: The main FedAvg variants, the part of the algorithm each one changes, and the drift problem each one targets. All keep the broadcast, local-training, aggregate structure of Figure 14.3.1.
| Method | What it changes | Mechanism against drift | Cost it adds |
| --- | --- | --- | --- |
| FedAvg | nothing (the baseline) | none; relies on small $E$ | none |
| FedProx | client objective | proximal pull toward the global model | one tuning knob $\mu$ |
| SCAFFOLD | client update direction | control variates cancel drift | extra control vector per client, doubled upload |
| FedAdam / FedOpt | server aggregation | adaptive momentum across rounds | server optimizer state |
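
The two client-side mechanisms can be compared on a toy problem. The sketch below is an illustration under stated assumptions rather than a reference implementation of either paper: two equal-size clients hold one-dimensional quadratic losses with different minimizers, every client participates in every round, the server step size is 1, and the values of `a`, `c`, `mu`, and the function name `run` are hypothetical. The SCAFFOLD branch refreshes its control variates from the net local movement, the cheaper of the two update options described by Karimireddy et al.

```python
import numpy as np

# Toy assumption: two equal-size clients with 1-D losses 0.5 * a[k] * (w - c[k])**2.
a = np.array([1.0, 4.0])                 # local curvatures (hypothetical)
c = np.array([0.0, 1.0])                 # local minimizers (hypothetical)
w_opt = (a * c).sum() / a.sum()          # global minimizer, 0.8 for these values
lr, E, rounds = 0.1, 50, 300

def run(method, mu=1.0):
    w = 0.0                              # global model
    ctrl_k = np.zeros_like(a)            # SCAFFOLD client control variates
    ctrl = 0.0                           # SCAFFOLD server control variate
    for _ in range(rounds):
        local = np.full_like(a, w)       # broadcast to both clients
        for _ in range(E):
            g = a * (local - c)          # each client's own gradient
            if method == "fedprox":
                g = g + mu * (local - w)         # proximal pull toward the broadcast model
            elif method == "scaffold":
                g = g - ctrl_k + ctrl            # subtract the estimated drift direction
            local -= lr * g
        if method == "scaffold":
            # Refresh client control variates from the net local movement,
            # then move the server control variate by their mean change.
            new_ctrl_k = ctrl_k - ctrl + (w - local) / (E * lr)
            ctrl += (new_ctrl_k - ctrl_k).mean()
            ctrl_k = new_ctrl_k
        w = local.mean()                 # equal sizes: plain mean = weighted mean
    return w

limits = {m: run(m) for m in ("fedavg", "fedprox", "scaffold")}
for m, w_lim in limits.items():
    print(f"{m:>8}: settles at {w_lim:.4f}   gap to optimum {abs(w_lim - w_opt):.4f}")
```

On this toy, plain FedAvg at $E = 50$ settles near the midpoint of the two local minimizers, the proximal term pulls the limit part of the way back toward the global minimizer, and the control variates remove the gap. Real models and partial participation will not behave this cleanly; the sketch only shows what each correction does to the local step.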

The variants are not mutually exclusive; a production system might run FedProx on the clients and FedAdam on the server at once. What they share is the structure of Figure 14.3.1, which is why FedAvg is the right thing to understand first: every one of these methods is a modification of one of its three steps, not a replacement for the loop.
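
A minimal sketch of that combination follows: FedProx-style clients feeding a FedAdam-style server. It assumes the same kind of toy as before (two equal-size clients with one-dimensional quadratic losses) and full participation; the class name `FedAdamServer`, the function `fedprox_client`, and every hyperparameter value are hypothetical choices for illustration. The server update follows the FedOpt form, in which the weighted mean of the client updates plays the role of a gradient and no bias correction is applied.

```python
import numpy as np

# Toy assumption: client k minimizes 0.5 * a[k] * (w - c[k])**2; both clients the same size.
a = np.array([1.0, 4.0])
c = np.array([0.0, 1.0])
sizes = np.array([1.0, 1.0])

def fedprox_client(w_global, k, E=50, lr=0.1, mu=1.0):
    """E local steps on L_k(w) + (mu / 2) * ||w - w_global||^2."""
    w = w_global.copy()
    for _ in range(E):
        w -= lr * (a[k] * (w - c[k]) + mu * (w - w_global))
    return w

class FedAdamServer:
    """Adam applied to the averaged client update; hyperparameters are illustrative."""
    def __init__(self, w0, lr=0.05, beta1=0.9, beta2=0.99, tau=1e-2):
        self.w = np.asarray(w0, dtype=float).copy()
        self.lr, self.b1, self.b2, self.tau = lr, beta1, beta2, tau
        self.m = np.zeros_like(self.w)   # first moment of the pseudo-gradient
        self.v = np.zeros_like(self.w)   # second moment of the pseudo-gradient

    def step(self, client_models, sizes):
        p = sizes / sizes.sum()
        delta = p @ (np.stack(client_models) - self.w)    # weighted mean of (w_k - w)
        self.m = self.b1 * self.m + (1 - self.b1) * delta
        self.v = self.b2 * self.v + (1 - self.b2) * delta**2
        self.w = self.w + self.lr * self.m / (np.sqrt(self.v) + self.tau)
        return self.w

server = FedAdamServer(w0=np.zeros(1))
for t in range(300):
    models = [fedprox_client(server.w, k) for k in range(len(a))]
    server.step(models, sizes)
print(f"global model after 300 rounds: {server.w[0]:.4f}")
```

The server optimizer changes how fast and how smoothly the global model moves, not where the client updates point: the run settles where the averaged FedProx update vanishes, which for these values is about 0.615 rather than the true minimizer 0.8. That is why the two mechanisms are complementary rather than interchangeable.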

Thesis Thread: The Combine Step, Now Over a Hostile Network

The size-weighted average of Section 1.1 returns here as the heart of FedAvg, but the network it crosses has changed character completely. In data-parallel deep learning the all-reduce runs every step over a fast interconnect; in federated learning the same average runs once per round over slow, unreliable, partial-participation links, which is exactly why the algorithm pushes so much computation into the local phase between communications. The combine is the same operation the whole book is built around; the engineering shifts from making it fast to making it rare. That rarity is also what lets a secure-aggregation protocol sit on top of the average so the server learns only the sum, a privacy property we develop in Section 14.6 and whose cost we weigh against the communication-cost models of Chapter 3.

Library Shortcut: Flower Runs the FedAvg Loop for You

Code 14.3.1 implemented broadcast, local training, and weighted aggregation by hand. Production federated frameworks supply that loop and the variants as configured strategies, so switching from FedAvg to FedAdam is a one-line change. In Flower, the client implements only local training and the server picks a strategy; the round orchestration, client sampling, and size-weighted aggregation are built in:

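# pip install flwr
import flwr as fl

# model, local_data, set_weights, get_weights, and train_local are placeholders for
# your own model object, on-device dataset, and NumPy <-> model conversion helpers.
class Client(fl.client.NumPyClient):
    def fit(self, parameters, config):
        set_weights(model, parameters)
        train_local(model, local_data, epochs=config["E"])   # the only client code you write
        return get_weights(model), len(local_data), {}       # weight = n_k for the average

# FedAvg is the default; swap one line for FedAdam, FedProx, etc.
# on_fit_config_fn supplies the per-round config dict the client reads; E=5 is illustrative.
strategy = fl.server.strategy.FedAvg(       # or FedAdam(...), FedProx(...)
    on_fit_config_fn=lambda server_round: {"E": 5})
# start_server is the classic entry point; newer Flower releases prefer the ServerApp API.
fl.server.start_server(strategy=strategy,
                       config=fl.server.ServerConfig(num_rounds=100))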
Code 14.3.2: The FedAvg loop of Code 14.3.1 in Flower. The roughly thirty lines of manual orchestration collapse to a client fit method and a one-line strategy; Flower handles client selection, round management, the size-weighted average (it reads each client's returned example count as the weight), and the swap to drift-correcting variants. TensorFlow Federated (TFF) offers the same loop in a more research-oriented, functional API.
Practical Example: Tuning E for a Mobile Keyboard Model

Who: An on-device machine learning team training a next-word prediction model across a fleet of consumer phones.

Situation: Each round cost real money and battery: phones train only while charging on Wi-Fi, and a round takes minutes of wall-clock as the server waits for enough clients to report in.

Problem: At one local epoch per round the model needed hundreds of rounds to converge, stretching training over weeks because so few rounds complete per day.

Dilemma: Raise the local epoch count $E$ to cut the number of rounds, mirroring the 20-fold reduction in Output 14.3.1, or keep $E$ small to avoid client drift, since each user's typing distribution is sharply non-identical.

Decision: They raised $E$ to a moderate value and adopted FedProx, using the proximal term to tether drifting clients so the larger $E$ did not wreck accuracy on the heterogeneous typing data.

How: They swept $E$ on a held-out shadow population, watched the global validation loss for the drift-induced plateau, and set the proximal weight $\mu$ just high enough to remove it.

Result: Rounds-to-target fell by roughly an order of magnitude while accuracy held, turning a multi-week training cycle into a few days without any user data leaving a device.

Lesson: On non-identical data, $E$ and the drift correction are tuned together: the proximal tether is what lets you spend the large $E$ that the communication budget demands.

5. Where FedAvg Sits in the Larger Story Intermediate

FedAvg is a star-topology algorithm: every round funnels through a central server that broadcasts and aggregates. That server is a single point of coordination and, for the size-weighted average, a single point of trust. The decentralized methods later in this chapter remove the server entirely and let clients average with neighbors over a gossip graph, trading the server's global view for fault tolerance and a flatter trust model. The drift you measured in Output 14.3.1 reappears there in a different guise, as the slowness of consensus across a sparse graph. For now, the server-based FedAvg loop is the workhorse: it is what runs on production phone fleets and federated medical networks, and its $E$ knob and drift-correcting variants are the levers every federated deployment actually pulls.

Research Frontier: Federated Learning of Foundation Models (2024 to 2026)

The pressing current question is how to run FedAvg-style training when the model is a billion-parameter foundation model and the clients are bandwidth-limited edge devices. Uploading a full model per round is infeasible, so 2024 to 2026 work pushes parameter-efficient federated tuning: clients train and exchange only low-rank LoRA adapters rather than full weights, slashing the per-round upload by orders of magnitude while the frozen backbone stays put. A parallel line attacks the heterogeneity that drift-correction only partly fixes, with personalized and clustered federated methods that let groups of similar clients share a head while a common body is averaged. On the systems side, asynchronous and buffered aggregation schemes such as FedBuff relax FedAvg's synchronous round barrier so that slow or dropped clients no longer stall a round, an idea that connects directly to the bounded-staleness optimization of Chapter 10 and the communication-cost models of Chapter 3. The throughline is unchanged: shrink what crosses the network per round, because the network is still the binding constraint.
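
The size of that saving is simple arithmetic. The sketch below uses a hypothetical square layer and adapter rank, and assumes 16-bit parameters; none of these numbers come from a specific model. It counts what one client would upload for a single layer with and without a low-rank adapter.

```python
# Hypothetical layer shape and adapter rank, chosen only to make the arithmetic concrete.
d_out, d_in, rank = 4096, 4096, 8
bytes_per_param = 2                                   # assuming 16-bit parameters

full_params = d_out * d_in                            # dense weight matrix W
lora_params = rank * (d_in + d_out)                   # factors B (d_out x r) and A (r x d_in)

full_mib = full_params * bytes_per_param / 2**20
lora_mib = lora_params * bytes_per_param / 2**20
print(f"full layer: {full_params:>12,} params, {full_mib:8.3f} MiB")
print(f"LoRA r={rank}  : {lora_params:>12,} params, {lora_mib:8.3f} MiB")
print(f"upload reduction for this layer: {full_params / lora_params:.0f}x")
```

For these values the adapter is 256 times smaller than the dense layer, and the ratio $d_{\text{in}} d_{\text{out}} / \bigl(r\,(d_{\text{in}} + d_{\text{out}})\bigr)$ grows as the layer gets wider or the rank gets smaller.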

We now have the core algorithm, the proof that its size-weighted average is the correct combine, a measured account of the rounds-versus-drift trade-off of $E$, and the variants that extend its reach. Every one of those variants exists because one assumption we have quietly kept in the experiments, that the clients hold identically distributed data, fails in practice: real federated clients never do. Section 14.4 confronts that head-on, showing exactly how non-identical data breaks the clean averaging picture and what the cure costs.

Exercise 14.3.1: When Is FedAvg Exactly Centralized SGD? Conceptual

Section 2 showed that FedAvg with one full-batch local step per round equals one step of centralized SGD on the global loss. State precisely which two conditions on the local update make that equivalence exact, and explain what specifically breaks it when (a) clients take more than one local step, and (b) the server uses a plain unweighted average instead of the size-weighted one. For case (b), describe a small two-client example, one client with many examples and one with few, where the unweighted average visibly pulls the global model toward the wrong place.

Exercise 14.3.2: Make the Data Non-IID and Watch It Break Coding

Modify Code 14.3.1 so the client shards are non-identically distributed: sort the examples by their label before splitting, so each client sees a skewed label mix instead of an IID sample. Rerun the rounds-to-target experiment for the same values of $E$ and report what happens to both the round counts and the final loss gap. Then re-measure client drift at $E=50$ and compare it to the IID value of 0.354 from Output 14.3.1. Explain in two sentences why raising $E$ helped under IID data but hurts here.

Exercise 14.3.3: The Communication Budget of E Analysis

Suppose each communication round costs a fixed $T_{\text{comm}}$ seconds (broadcast plus upload of one model) and each local step costs $T_{\text{local}}$ seconds, so a round with $E$ local steps takes $T_{\text{comm}} + E\,T_{\text{local}}$. Using the rounds-to-target numbers from Output 14.3.1 (347, 174, 70, 17 for $E = 1, 2, 5, 20$), write the total wall-clock time to reach the target as a function of $E$, and find which of the four values of $E$ minimizes it when $T_{\text{comm}} = 60$ seconds and $T_{\text{local}} = 1$ second. Then redo it for $T_{\text{comm}} = 2$ seconds and explain why the advantage of the largest $E$ over $E = 1$ shrinks when communication gets cheap, and whether the best of the four values changes. This is the same communication-versus-computation balance quantified in Chapter 3.