Part III: Distributed Machine Learning
Chapter 14: Federated and Decentralized Learning

Non-IID Data

"In the datacenter they shuffled my data until every shard looked the same. Out here, my whole world is one person's camera roll, and I am very confident about cats."

A Client Model Overfit to Its Owner
Big Picture

The defining hardship of federated learning is that nobody gets to shuffle the data. In datacenter training we deliberately randomize examples so every worker's shard is a faithful miniature of the whole distribution, which is what made the data-parallel gradient identity of Chapter 1 hold. Federated clients break that assumption by construction: a phone holds one person's photos, a hospital holds one region's patient mix, and no client may ship its raw data anywhere to be reshuffled. Each client therefore optimizes toward its own local minimum, and when those locally drifted models are averaged, the result can be worse than what plain gradient steps on the pooled data would have produced. This section names the kinds of heterogeneity, shows why running many local epochs makes the problem worse, and surveys the mitigations (fewer local steps, proximal and control-variate corrections, server momentum, and limited data sharing or distillation) that make federated optimization converge despite the skew.

In Section 14.3 we built FedAvg: each round, the server broadcasts the current model, a set of clients each run several local SGD epochs on their own data, and the server averages the returned models weighted by how many examples each client holds. We mentioned that FedProx and SCAFFOLD exist as variants without yet saying what disease they cure. This section is that diagnosis. The single assumption that FedAvg silently inherited from datacenter data parallelism, that each participant's data is an independent and identically distributed (IID) sample from one global distribution, is precisely the assumption that real federated populations violate. Once we drop it, the comforting "an average decomposes, so distribution is exact" argument loses its footing, and the gap it opens is the subject of the rest of the chapter.

1. Why Shuffling Was Doing All the Work (Beginner)

Recall what shuffling buys you in ordinary distributed training. You take a dataset, permute it uniformly at random, and hand contiguous blocks to each worker. Because the permutation is uniform, every worker's block is a representative sample: its label histogram, its feature statistics, and its difficulty profile all match the global ones up to sampling noise. The local gradient each worker computes is therefore an unbiased estimate of the global gradient, and averaging those estimates reduces variance without introducing bias. That is the entire statistical foundation of synchronous data parallelism, and it is why a hundred workers can act as one.

Federated learning removes the permutation step and cannot put it back. The data was generated where it lives and must stay there, for reasons of privacy, regulation, bandwidth, or simple physics: you are not going to upload every phone's camera roll to a central shuffler. So the partition into clients is not a random cut through one distribution; it is a partition by source, and sources differ. Formally, write client $k$'s local distribution as $\mathcal{D}_k$ and the global mixture as $\mathcal{D} = \sum_k \frac{n_k}{n}\,\mathcal{D}_k$, where $n_k$ is client $k$'s sample count. In the IID datacenter world every $\mathcal{D}_k \approx \mathcal{D}$. In the federated world the $\mathcal{D}_k$ are different distributions, and the size of that difference is what we now have to reason about.
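The following quick numerical check tests both halves of this argument. It assumes synthetic Gaussian class clusters in the style of Code 14.4.1, and all sizes and seeds are arbitrary choices. At a single shared model, the size-weighted average of shard gradients reproduces the global gradient for any partition, because the pooled mean loss decomposes as a weighted sum. What shuffling buys is that each individual shard gradient is close to the global one. After a label sort, individual shard gradients point far away, and that is what hurts once clients take many steps on their own.

```python
import numpy as np

rng = np.random.default_rng(1)
C, D, N, K = 4, 10, 4000, 8            # classes, features, examples, clients
means = rng.standard_normal((C, D))
y = rng.integers(0, C, size=N)
X = means[y] + rng.standard_normal((N, D))
W = 0.1 * rng.standard_normal((D, C))  # one shared model, as at the start of a round

def grad(W, Xb, yb):
    """Gradient of the mean softmax cross-entropy with respect to W."""
    Z = Xb @ W
    P = np.exp(Z - Z.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return Xb.T @ (P - np.eye(C)[yb]) / len(yb)

g_global = grad(W, X, y)

def shard_report(parts):
    sizes = np.array([len(p) for p in parts], dtype=float)
    grads = [grad(W, X[p], y[p]) for p in parts]
    # Size-weighted average of the shard gradients, as a server would form it.
    g_avg = sum((s / sizes.sum()) * g for s, g in zip(sizes, grads))
    # Mean distance of a single shard's gradient from the global one, relative to its norm.
    spread = np.mean([np.linalg.norm(g - g_global) for g in grads]) / np.linalg.norm(g_global)
    return g_avg, spread

iid_parts = np.array_split(rng.permutation(N), K)
sorted_parts = np.array_split(np.argsort(y, kind="stable"), K)
g_iid, spread_iid = shard_report(iid_parts)
g_sorted, spread_sorted = shard_report(sorted_parts)

print("weighted average equals global gradient:",
      np.allclose(g_iid, g_global), np.allclose(g_sorted, g_global))
print(f"relative shard-gradient spread: shuffled {spread_iid:.2f}, label-sorted {spread_sorted:.2f}")
```

Both partitions pass the averaging check. Only the shuffled one keeps each shard's gradient close to the global direction.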

Key Insight: The Local Optimum Is Not the Global Optimum

Under IID shards, every client's loss surface is a noisy copy of the same surface, so each client's local minimum sits near the global minimum and averaging local models lands near it too. Under non-IID data, client $k$ minimizes its own objective $F_k(w)=\mathbb{E}_{(x,y)\sim\mathcal{D}_k}[\ell(w;x,y)]$, whose minimizer $w_k^\star$ can be far from the global minimizer $w^\star$ of $F(w)=\sum_k \frac{n_k}{n}F_k(w)$. Many local steps pull each client toward its own $w_k^\star$; averaging those pulls does not generally point at $w^\star$. The bias is not noise that more clients average away; it is structural, and it survives no matter how many clients you add.
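A toy calculation shows the structural gap in closed form. It assumes scalar quadratic client losses $F_k(w)=\frac{a_k}{2}(w-b_k)^2$ with equal sample counts, an illustrative choice rather than a real dataset. Setting $F'(w)=\sum_k \frac{n_k}{n}a_k(w-b_k)=0$ gives $w^\star=\sum_k \frac{n_k}{n}a_k b_k \big/ \sum_k \frac{n_k}{n}a_k$. That is a curvature-weighted mean of the local optima $w_k^\star=b_k$, so the plain average of the local optima misses it whenever the curvatures differ.

```python
import numpy as np

# Hypothetical 1-D clients: F_k(w) = (a_k / 2) * (w - b_k)**2, equal sample counts.
a = np.array([1.0, 9.0])            # curvature of each client's loss
b = np.array([0.0, 1.0])            # each client's local minimizer w_k*
p = np.full(len(a), 1.0 / len(a))   # n_k / n

w_local = b                                  # argmin of each F_k
w_star = np.sum(p * a * b) / np.sum(p * a)   # argmin of F = sum_k p_k F_k
avg_of_local = np.sum(p * w_local)           # where fully converged local models would average to

print(f"global minimizer w*         = {w_star:.3f}")        # 0.900
print(f"average of local minimizers = {avg_of_local:.3f}")  # 0.500
```

Adding more clients with the same two profiles leaves both numbers unchanged, which is the sense in which the bias is structural rather than noise.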

2. The Four Faces of Heterogeneity (Beginner)

"Non-IID" is an umbrella over several distinct ways client distributions can differ, and they call for different remedies, so it pays to keep them separate. Table 14.4.1 names the four that dominate practice. Factor the joint distribution as $\mathcal{D}_k(x,y) = \mathcal{D}_k(y)\,\mathcal{D}_k(x\mid y)$ on each client, and most heterogeneity types are a statement about which factor varies.

Table 14.4.1: The four common forms of client heterogeneity in federated learning, the factor of the joint distribution each one perturbs, and a canonical example.
| Type | What differs across clients | Canonical example |
|---|---|---|
| Label-distribution skew | The class prior $\mathcal{D}_k(y)$ | One phone photographs mostly its owner's dog; a clinic sees mostly one disease stage |
| Feature skew | The conditional $\mathcal{D}_k(x\mid y)$ | The same digit written by different people; the same diagnosis on different scanners |
| Quantity skew | The sample count $n_k$ | A power user with 10,000 photos beside a new user with 12 |
| Concept drift | The mapping $\mathcal{D}_k(y\mid x)$, across clients or over time | The same word means different things to different communities; preferences shift |

Label-distribution skew is the form most studies stress-test, because it is both common and severe: the standard benchmark sorts a dataset by label and hands each client only one or two classes, a synthetic worst case that nonetheless mirrors reality where a user generates a narrow slice of the label space. Feature skew is subtler and often invisible to the label histogram: every hospital may see the full range of diagnoses, yet each one's images carry a scanner-specific signature that a model can latch onto. Quantity skew interacts with the weighting in FedAvg, since a few enormous clients can dominate the average. Concept drift is the hardest, because it means there is no single $w^\star$ that is simultaneously right for everyone, which is the motivation for the personalized methods of Section 14.7. This last form connects directly to the distribution-shift and concept-drift machinery we built for streaming systems in Chapter 9; federated non-IID data is distribution shift across space rather than across time.
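The one-or-two-classes benchmark described above takes only a few lines of NumPy. The sketch below sorts examples by label, cuts them into equal shards, and deals two random shards to each client. Its function name and sizes are illustrative assumptions. With balanced classes whose blocks line up with shard boundaries, every client ends up with one or two classes. The last lines show quantity skew feeding into the FedAvg weights $n_k/n$. They use the power-user numbers from Table 14.4.1; the 99 light users are an assumed population.

```python
import numpy as np

def shard_partition(labels, num_clients, shards_per_client, rng):
    """Pathological label skew: sort by label, cut into equal shards, deal a few per client."""
    order = np.argsort(labels, kind="stable")
    shards = np.array_split(order, num_clients * shards_per_client)
    deal = rng.permutation(len(shards))          # random shard-to-client assignment
    return [np.concatenate([shards[s] for s in deal[i::num_clients]])
            for i in range(num_clients)]

rng = np.random.default_rng(0)
labels = np.repeat(np.arange(10), 100)           # 10 balanced classes, 1,000 examples
parts = shard_partition(labels, num_clients=10, shards_per_client=2, rng=rng)
classes_seen = [int((np.bincount(labels[p], minlength=10) > 0).sum()) for p in parts]
print("classes per client:", classes_seen)      # every entry is 1 or 2

# Quantity skew: FedAvg weights n_k / n with one power user among 99 light users.
sizes = np.array([10_000] + [12] * 99, dtype=float)
weights = sizes / sizes.sum()
print(f"power user's share of the average: {weights[0]:.3f}")   # 0.894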

Fun Note: The Benchmark That Gave Every Client One Class

The pathological partition that the literature loves, sort by label and deal each client a single class, produces clients that are individually certain and collectively useless. A client holding only cats trains a model that screams "cat" at everything; a client holding only dogs does the mirror image. Average the two and you get a model unsure whether anything is an animal at all. It is the machine-learning version of a committee where every member is a single-issue zealot, and the minutes read accordingly.

3. Client Drift, and Why Many Local Epochs Amplify It (Intermediate)

The mechanism that turns heterogeneity into a convergence problem is client drift: during a round, each client moves its copy of the model away from the shared starting point and toward its own local optimum, so by aggregation time the client models have spread apart. A useful scalar measure of how badly a round drifted is the average distance of the client models from their aggregate,

$$\text{drift} = \frac{1}{K}\sum_{k=1}^{K}\bigl\lVert w_k - \bar{w}\bigr\rVert, \qquad \bar{w}=\sum_{k=1}^{K}\frac{n_k}{n}\,w_k,$$

where $w_k$ is client $k$'s model after its local steps and $\bar{w}$ is the FedAvg aggregate. Two knobs inflate this quantity. The first is heterogeneity itself: the more the local gradients $\nabla F_k$ disagree at the shared starting point, the faster the clients diverge. The second, and the one practitioners control directly, is the number of local epochs $E$ introduced in Section 14.3. Each extra local epoch lets a client travel further down its own loss surface before the server gets a chance to average, so larger $E$ means each $w_k$ sits closer to its private $w_k^\star$ and further from its neighbors. FedAvg uses large $E$ precisely to save communication, which is the tension Section 14.5 takes up: the same local computation that cuts your network bill is what lets clients drift.

[Figure: Each client's skewed labels pull the shared model a different way. Clients A, B, and C hold mostly class 0, 1, and 2 and pull the global model w toward their own optima w*(A), w*(B), w*(C); the average w̄ lands between the pulls, at none of the local optima.]
Figure 14.4.1: Client drift under label skew. Three clients with different class histograms each pull the shared model $w$ toward their own local optimum $w_k^\star$ (orange arrows). The server averages the drifted models into $\bar{w}$ (dashed circle), which sits in the middle and matches no client. The longer each client trains locally (larger $E$), the further the orange arrows reach and the worse the average becomes. The drift scalar of Section 3 measures the mean distance between the arrow tips (the drifted client models) and $\bar{w}$.

The runnable demo below makes both effects measurable on a tiny softmax classifier. It builds one synthetic dataset and partitions it two ways: an IID shuffle, and a label-sorted non-IID split where each client sees essentially one class. It then runs FedAvg at $E=1$ and $E=10$ local epochs. For each setting it reports the average per-round drift defined above and the final accuracy, measured on the pooled dataset the clients train on. You can watch heterogeneity and local epochs inflate drift while accuracy slips.

import numpy as np

rng = np.random.default_rng(0)
C, D, N = 6, 30, 12000           # classes, features, total examples
NC = 30                          # number of clients
ROUNDS = 30

means = rng.standard_normal((C, D)) * 0.45
y = rng.integers(0, C, size=N)
X = means[y] + rng.standard_normal((N, D))

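def softmax(Z):
    # Subtract the row max before exponentiating, for numerical stability.
    Z = Z - Z.max(axis=1, keepdims=True)
    expZ = np.exp(Z)
    return expZ / expZ.sum(axis=1, keepdims=True)

def grad(W, Xb, yb):
    # Gradient of the mean softmax cross-entropy on batch (Xb, yb) w.r.t. W.
    P = softmax(Xb @ W)
    Y = np.zeros_like(P)
    Y[np.arange(len(yb)), yb] = 1.0              # one-hot labels
    return Xb.T @ (P - Y) / len(yb)

def accuracy(W):
    # Accuracy on the pooled dataset: the same examples the clients train on,
    # so this is a training accuracy, not a held-out test accuracy.
    return float((softmax(X @ W).argmax(1) == y).mean())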

def iid_partition():
    return np.array_split(rng.permutation(N), NC)

def noniid_partition():
    # Sort by label, then chop into shards: each client sees one class
    # (extreme label-distribution skew, the standard worst case).
    return np.array_split(np.argsort(y, kind="stable"), NC)

def fedavg(parts, E, lr=1.5, batch=32):
    W = np.zeros((D, C))
    sizes = np.array([len(p) for p in parts], dtype=float)
    drift_acc = 0.0
    for _ in range(ROUNDS):
        locals_ = []
        for p in parts:
            Wl = W.copy()
            for _ in range(E):                       # E local epochs
                pp = rng.permutation(p)
                for b in range(0, len(pp), batch):
                    bi = pp[b:b + batch]
                    Wl -= lr * grad(Wl, X[bi], y[bi])
            locals_.append(Wl)
        stack = np.stack(locals_)
        Wbar = np.average(stack, axis=0, weights=sizes)   # FedAvg aggregation
        # Client drift: mean distance of each local model from the average.
        drift_acc += float(np.mean([np.linalg.norm(Wl - Wbar) for Wl in stack]))
        W = Wbar
    return drift_acc / ROUNDS, accuracy(W)

iid, non = iid_partition(), noniid_partition()
print(f"{'setting':<30}{'mean client drift':>18}{'train accuracy':>16}")
print("-" * 64)
for name, parts, E in [
    ("IID,     E=1",  iid, 1),
    ("IID,     E=10", iid, 10),
    ("non-IID, E=1",  non, 1),
    ("non-IID, E=10", non, 10),
]:
    drift, acc = fedavg(parts, E)
    print(f"{name:<30}{drift:>18.3f}{acc:>16.3f}")
Code 14.4.1: FedAvg on IID versus label-sorted non-IID clients, instrumented to report the per-round client-drift scalar of Section 3 alongside final accuracy. The only differences between runs are the partition (shuffle versus label-sort) and the number of local epochs $E$.
setting                        mean client drift  train accuracy
----------------------------------------------------------------
IID,     E=1                               2.107           0.825
IID,     E=10                              4.932           0.824
non-IID, E=1                               2.632           0.818
non-IID, E=10                              4.971           0.805
Output 14.4.1: Three readings stand out. First, holding $E$ fixed, the non-IID split drifts more than the IID split: 2.632 against 2.107 at $E=1$, with the gap nearly closing at $E=10$ (4.971 against 4.932). Second, holding the split fixed, raising $E$ from 1 to 10 roughly doubles the drift: 2.107 to 4.932 for IID, and 2.632 to 4.971 for non-IID. Third, the worst accuracy, 0.805, is the non-IID run with many local epochs. Cutting that run back to $E=1$ recovers it to 0.818, the cheapest non-IID mitigation there is.

The numbers tell the whole story of this section in one table. Non-IID data raises drift on its own, more local epochs raise it further, and the accuracy cost lands hardest exactly where both compound: many local epochs on skewed clients. The convex softmax model here is forgiving, so the accuracy gap is modest; in the non-convex deep networks of Chapter 15 the same drift can stall or destabilize training outright. The drift scalar is the early-warning signal that the final-accuracy number hides until it is too late.

4. Mitigations: From Free to Expensive (Intermediate)

There is a ladder of responses to client drift, ordered roughly by cost. The cheapest, demonstrated in Output 14.4.1, is to reduce the number of local steps: fewer local epochs means less room to drift, at the price of more communication rounds. The next rung adds a correction term to the local objective. FedProx, introduced in Section 14.3, augments each client's loss with a proximal penalty $\frac{\mu}{2}\lVert w - w^{\text{global}}\rVert^2$ that pulls local updates back toward the round's starting point, directly shrinking the drift scalar at the cost of a hyperparameter $\mu$ to tune. SCAFFOLD goes further with control variates: the server and each client maintain correction vectors that estimate the difference between the client's local gradient direction and the global one, and subtract it during local steps, cancelling the drift bias rather than merely penalizing it. A third rung lives on the server: applying momentum or an adaptive optimizer (FedAdam and relatives) to the sequence of aggregates smooths over the per-round noise that heterogeneity injects.
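A toy problem shows the difference between penalizing drift and cancelling it. It assumes three scalar quadratic clients with different curvatures, equal weights, full participation, and $E=10$ full-batch local steps per round. Plain FedAvg settles at a fixed point away from $w^\star$. The SCAFFOLD-style correction keeps the client models on track and converges to $w^\star$. It uses the control-variate update that Karimireddy et al. call option II, $c_k \leftarrow c_k - c + (x - y_k)/(E\eta)$. This is a sketch of the mechanism only: it has no client sampling and no separate global step size.

```python
import numpy as np

# Hypothetical scalar clients: F_k(w) = (a_k / 2) * (w - b_k)**2, equal weights.
a = np.array([1.0, 4.0, 0.5])
b = np.array([-2.0, 1.0, 3.0])
w_star = np.sum(a * b) / np.sum(a)      # minimizer of the average loss
eta, E, ROUNDS = 0.1, 10, 100           # local step size, local steps, rounds

def grad_k(k, w):
    return a[k] * (w - b[k])

def run(use_scaffold):
    x = 0.0                             # global model
    c = 0.0                             # server control variate
    c_k = np.zeros(len(a))              # client control variates
    for _ in range(ROUNDS):
        new_models, new_c_k = [], c_k.copy()
        for k in range(len(a)):
            y = x
            for _ in range(E):
                g = grad_k(k, y)
                if use_scaffold:
                    g = g - c_k[k] + c  # subtract this client's estimated bias
                y -= eta * g
            new_models.append(y)
            if use_scaffold:
                # Option II: the new c_k is the client's average gradient along its local path.
                new_c_k[k] = c_k[k] - c + (x - y) / (E * eta)
        x = float(np.mean(new_models))  # equal-weight averaging, global step size 1
        c_k = new_c_k
        c = float(np.mean(c_k))         # full participation: c is the mean of the c_k
    return x

err_fedavg = abs(run(False) - w_star)
err_scaffold = abs(run(True) - w_star)
print(f"w* = {w_star:.4f}")
print(f"FedAvg   |x - w*| = {err_fedavg:.2e}")
print(f"SCAFFOLD |x - w*| = {err_scaffold:.2e}")
```

At SCAFFOLD's fixed point each client's corrected gradient vanishes exactly at $w^\star$, so local steps no longer pull the clients apart. FedAvg's error, by contrast, does not shrink with more rounds.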

The most expensive rung relaxes the no-data-movement constraint, where policy allows. Sharing a small globally representative public dataset, or a server-held proxy set, lets clients or the server correct the skew directly; knowledge-distillation variants transmit soft predictions on shared unlabeled data instead of weights, which both sidesteps drift in weight space and shrinks the messages. These methods trade some of the privacy and bandwidth advantages that motivated federation in the first place, so they are a last resort reserved for the cross-silo settings of Section 14.2 where a small shared corpus is legally and practically available. Table 14.4.2 lays the ladder out.
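Here is a minimal sketch of the prediction-sharing idea. It makes several illustrative assumptions that do not describe any specific published method: linear softmax models, random stand-in weights for the locally trained clients, plain averaging of the uploaded soft labels, and a server that fits a student to them by gradient descent. Each upload is an $M \times C$ matrix of predictions on the shared set, so its size does not depend on the model's parameter count. The models in this toy are tiny, so the bandwidth saving only shows up for large models.

```python
import numpy as np

rng = np.random.default_rng(0)
C, D, M, K = 4, 10, 200, 5              # classes, features, public examples, clients

def softmax(Z):
    Z = Z - Z.max(axis=1, keepdims=True)
    P = np.exp(Z)
    return P / P.sum(axis=1, keepdims=True)

# Stand-ins for locally trained client models (hypothetical random weights here).
client_W = [rng.standard_normal((D, C)) for _ in range(K)]
X_pub = rng.standard_normal((M, D))     # shared unlabeled public set

# Each client uploads soft predictions on X_pub (an M x C message), not its weights.
uploads = [softmax(X_pub @ W) for W in client_W]
soft_targets = np.mean(uploads, axis=0) # the server averages predictions, not parameters

def distill_loss(W):
    P = softmax(X_pub @ W)
    return -np.mean(np.sum(soft_targets * np.log(P + 1e-12), axis=1))

# The server fits a student to the averaged soft labels by gradient descent on cross-entropy.
W_student = np.zeros((D, C))
loss_start = distill_loss(W_student)
for _ in range(300):
    g = X_pub.T @ (softmax(X_pub @ W_student) - soft_targets) / M
    W_student -= 0.5 * g
loss_end = distill_loss(W_student)
print(f"distillation loss: {loss_start:.3f} -> {loss_end:.3f}")
```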

Table 14.4.2: Mitigations for non-IID client drift, from cheapest to most invasive, with the mechanism each uses and the price it pays.
| Mitigation | Mechanism | Cost |
|---|---|---|
| Fewer local epochs | Less local travel before averaging | More communication rounds |
| FedProx | Proximal penalty toward the global model | One hyperparameter $\mu$ |
| SCAFFOLD | Control variates cancel the drift bias | Extra state and one more vector per message |
| Server momentum / FedAdam | Adaptive optimizer on the aggregates | Server-side state; momentum tuning |
| Data sharing / distillation | A shared public set corrects the skew | Relaxes the no-data-movement guarantee |

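The server-momentum rung fits in a few lines. This is a minimal NumPy sketch in the spirit of FedAdam (Reddi et al.). It treats the weighted FedAvg aggregate minus the current model as a pseudo-gradient, and Adam-style first and second moments, without bias correction, smooth it across rounds. The class name, method name, and default hyperparameters are illustrative assumptions, not the API of Flower or any other library.

```python
import numpy as np

class FedAdamServer:
    """Sketch of an adaptive server optimizer applied to FedAvg aggregates (illustrative)."""

    def __init__(self, w0, lr=0.1, beta1=0.9, beta2=0.99, tau=1e-3):
        self.w = np.asarray(w0, dtype=float).copy()
        self.m = np.zeros_like(self.w)        # first moment of the pseudo-gradient
        self.v = np.zeros_like(self.w)        # second moment of the pseudo-gradient
        self.lr, self.beta1, self.beta2, self.tau = lr, beta1, beta2, tau

    def step(self, client_models, sizes):
        w_bar = np.average(np.stack(client_models), axis=0, weights=sizes)  # FedAvg aggregate
        delta = w_bar - self.w                # pseudo-gradient, already a descent direction
        self.m = self.beta1 * self.m + (1 - self.beta1) * delta
        self.v = self.beta2 * self.v + (1 - self.beta2) * delta ** 2
        self.w = self.w + self.lr * self.m / (np.sqrt(self.v) + self.tau)
        return self.w

server = FedAdamServer(np.zeros(2))
w = server.step([np.array([1.0, -2.0]), np.array([1.0, -2.0])], sizes=[10, 30])
print(w)
```

Plain FedAvg corresponds to replacing the last update with w = w_bar. The moments let the server damp round-to-round swings in the aggregate that heterogeneous client sampling produces.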
Library Shortcut: Flower Swaps FedAvg for a Drift-Robust Strategy in One Line

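In Code 14.4.1 the aggregation rule and any drift correction are hand-written inside the training loop. Production federated frameworks make the aggregation strategy a pluggable object. In Flower, the strategy is passed to the server, so switching to a server-side method such as FedAdam is a one-line change that leaves the client code untouched. FedProx is also selected on the server, but its strategy only sends $\mu$ to the clients; each client's training loop must still add the proximal term to its own loss: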

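import numpy as np
import flwr as fl

# Plain FedAvg drifts under non-IID clients ...
strategy = fl.server.strategy.FedAvg(fraction_fit=0.2)

# ... FedProx: the server sends proximal_mu to the clients with each fit
# instruction; the client's own training loop must add the proximal term.
strategy = fl.server.strategy.FedProx(fraction_fit=0.2, proximal_mu=0.5)

# ... FedAdam: adaptive server optimizer; it requires the initial global weights.
initial_weights = [np.zeros((30, 6), dtype=np.float32)]   # e.g. the (D, C) softmax W of Code 14.4.1
strategy = fl.server.strategy.FedAdam(
    fraction_fit=0.2,
    initial_parameters=fl.common.ndarrays_to_parameters(initial_weights),
)

fl.server.start_server(config=fl.server.ServerConfig(num_rounds=30),
                       strategy=strategy)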
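Code 14.4.2: The server side of a federated run like Code 14.4.1, with the aggregation strategy chosen by name. The hand-rolled averaging and any server-momentum logic collapse to one constructor call, and Flower handles client sampling, weighted aggregation, and the server-side optimizer state internally. FedAdam must be given the initial global weights through initial_parameters. FedProx relies on each client to apply the proximal term using the proximal_mu value it receives.
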
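In Flower's FedProx, the proximal term is applied by the client's own training code, with the strategy supplying $\mu$. Below is a framework-free sketch of that client-side update, the kind of function a client's fit method could call after reading $\mu$ from its config. The least-squares model, the function name local_train_fedprox, and the step sizes are our own illustrative assumptions.

```python
import numpy as np

def local_train_fedprox(w_global, X, y, mu, lr=0.1, steps=500):
    """Full-batch local training of a least-squares model with a FedProx term (illustrative).

    Minimizes mean((X @ w - y)**2) / 2 + (mu / 2) * ||w - w_global||**2.
    """
    w = w_global.copy()
    for _ in range(steps):
        g = X.T @ (X @ w - y) / len(y)   # gradient of the local loss
        g = g + mu * (w - w_global)      # proximal pull toward the round's starting point
        w -= lr * g
    return w

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 5))
y = X @ rng.standard_normal(5) + 0.1 * rng.standard_normal(200)
w_global = np.zeros(5)                   # model broadcast by the server this round

for mu in (0.0, 0.5, 2.0):
    w_local = local_train_fedprox(w_global, X, y, mu)
    print(f"mu={mu:<4} distance from global model: {np.linalg.norm(w_local - w_global):.3f}")
```

Larger $\mu$ keeps the returned local model closer to the broadcast model. That is the direct handle FedProx gives on the drift scalar.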
Practical Example: The Keyboard Model That Learned Only Slang

Who: A machine learning engineer on a mobile keyboard team training a federated next-word predictor across millions of phones.

Situation: To save battery and radio, each phone ran twenty local epochs per round before uploading, a large $E$ chosen purely to cut communication.

Problem: The global model's quality plateaued early and its suggestions skewed toward whichever cohort of phones happened to be sampled, with formal-writing users getting slang and slang users getting formal completions.

Dilemma: Cut local epochs to reduce drift and pay for far more communication rounds and battery, or keep the cheap communication schedule and add an algorithmic correction whose hyperparameters nobody on the team had tuned before.

Decision: They did both, in stages. First they halved the local epochs as the cheapest fix and measured the drop in the drift scalar. Then they layered FedProx on top to claw back the rounds they had added.

How: They logged the per-round client-drift measure from Section 3 as a first-class training metric, swept the proximal $\mu$ on a held-out cohort, and adopted server-side momentum once drift was under control.

Result: The plateau lifted, cohort-dependent swings shrank, and total communication landed below the original twenty-epoch schedule once FedProx let them safely raise local epochs partway back.

Lesson: Treat client drift as a metric you watch, not a footnote. The cheapest mitigation (fewer local steps) and the algorithmic ones (FedProx, server momentum) compose, and you only know how to trade them off once drift is on the dashboard.

Research Frontier: Taming Heterogeneity (2024 to 2026)

Non-IID data remains the most active sub-field of federated learning. Recent work pushes past the FedProx and SCAFFOLD lineage in several directions. Server-side adaptive optimizers in the FedAdam and FedYogi family (Reddi et al.) are now standard baselines, and 2024 to 2026 work studies how they interact with client sampling under severe skew. Federated foundation-model fine-tuning has become a focus: parameter-efficient methods (federated LoRA and adapter tuning) cut the bytes per round so far that clients can afford fewer local epochs, attacking drift and communication together, while studies of client heterogeneity in federated instruction tuning show that skew in which clients hold which tasks reshapes the merged model. A parallel thread brings personalization (clustered and meta-learned federated models, the subject of Section 14.7) to bear on concept drift, accepting that one global model cannot fit genuinely conflicting client objectives. The honest current consensus is that no single mitigation dominates across heterogeneity types, which is why production stacks combine fewer local steps, a drift correction, and server momentum rather than betting on one.

5. Where Non-IID Data Reappears (Beginner)

The heterogeneity introduced here is not confined to this chapter; it is the constraint that shapes every later federated setting in the book. When federated learning moves to the network edge in Chapter 34, device populations are even more skewed and intermittently available, so the drift problem is compounded by clients that appear and vanish mid-round. In the federated medical case study of Chapter 37, feature skew across hospital scanners and label skew across regional disease prevalence are the central engineering reality, not a textbook caricature. The same averaging-of-divergent-updates tension also recurs whenever local computation is traded against communication, including the local-SGD methods we compared in Chapter 10. Non-IID federated optimization is local SGD with the additional cruelty that you cannot fix the data.

With heterogeneity diagnosed and its mitigations laid out, the natural next question is the one we kept deferring: how few bytes can a federated round actually cost, and how does that budget interact with the local-epoch knob that drives drift? That is the communication-constraint story, and it begins in Section 14.5.

Exercise 14.4.1: Name the Skew (Conceptual)

For each federated scenario, identify which of the four heterogeneity types in Table 14.4.1 dominates, and state which mitigation from Table 14.4.2 you would reach for first and why: (a) a wildlife-camera network where each camera sees mostly the species native to its location; (b) a federation of hospitals that all diagnose the same conditions but image them on different MRI machines; (c) a keyboard federation where a handful of power users generate a thousand times more text than the median user; (d) a sentiment model where the word "sick" means approval in one user community and illness in another. Explain why a single global model is hardest to justify in case (d).

Exercise 14.4.2: Add a Proximal Term (Coding)

Extend Code 14.4.1 with a FedProx proximal penalty. Inside the local update, replace the gradient with $\nabla \ell + \mu\,(W_{\text{local}} - W_{\text{global}})$ for a configurable $\mu$. Then re-run the non-IID, $E=10$ setting while sweeping $\mu \in \{0, 0.1, 0.5, 1.0\}$, and plot or tabulate the per-round drift scalar and the final accuracy against $\mu$. Report how drift changes as $\mu$ grows and identify the value of $\mu$ with the best final accuracy. Check whether very large $\mu$ over-damps the update and hurts accuracy by freezing each local model near the round's starting point, and explain the trade-off in terms of the drift definition in Section 3.

Exercise 14.4.3: Bound the Drift (Analysis)

Suppose every client runs $E$ full-batch local gradient steps of size $\eta$ starting from a shared model $w$. Assume that, at every point, each client's gradient differs from the global gradient by at most $\zeta$ in norm (a standard bounded-heterogeneity assumption). Argue that after one round of $E$ local steps the client models can differ from each other by a quantity that grows with $\eta$, $E$, and $\zeta$, and shrinks to zero when $\zeta = 0$ (the IID case). Use this to explain analytically why Output 14.4.1 shows drift rising both with the non-IID partition and with $E$, and why halving $E$ is a first-line mitigation. You do not need a tight constant; a scaling argument in $\eta$, $E$, and $\zeta$ suffices.