"A local optimum at each hospital, pulling the global model in five directions at once. I average them and please nobody."
An Aggregation Server Caught Between Five Populations
The defining technical difficulty of federated medical AI is that the data is not identically distributed across the hospitals that hold it, and that single fact destabilizes the averaging step at the heart of the method. Each site serves a different population, runs different scanners, codes diagnoses differently, and sees a different mix of diseases. When every site trains locally and the server averages the results, the local solutions have drifted toward different local optima, and their average is pulled away from the optimum of the pooled objective the federation actually wants to minimize. This section names the kinds of heterogeneity precisely, shows mathematically why naive averaging (FedAvg) slows and biases convergence under it, and then develops the three remedies that make federated learning robust to it: proximal regularization (FedProx), control variates that cancel the drift (SCAFFOLD, a variance-reduction idea you met in Chapter 10), and personalization that stops pretending one model fits every site. Heterogeneity is not a defect to be removed; it is the price of distributing learning across institutions that genuinely differ, and the algorithm must be built to survive it.
The previous section assembled the federation: a set of hospitals that will not move their patient records, a coordinating server that never sees raw data, and a training loop in which each round broadcasts a global model, lets each site take several local optimization steps, and averages the returned weights. That loop is FedAvg, and on data that is independent and identically distributed across sites it converges much like ordinary distributed SGD. Clinical data is the opposite of identically distributed. A pediatric hospital, a veterans' hospital, an oncology center, and a rural general practice differ in almost every statistic that matters, and the federation must learn one model (or a family of them) despite that. This section is about what that difference does to the math and what to do about it.
We proceed in three movements. First, a taxonomy: the distinct ways clinical data can be non-IID, because the remedies differ by type. Second, a diagnosis: a precise account of client drift, the mechanism by which heterogeneity corrupts FedAvg. Third, the treatments: FedProx, SCAFFOLD, and personalization, each attacking a different part of the drift, with a runnable demonstration that FedProx tightens a federation of five skewed hospitals.
1. The Four Faces of Non-IID Clinical Data (Beginner)
"Non-IID" is a single label for several genuinely different statistical phenomena, and conflating them leads to reaching for the wrong remedy. Let site $k$ hold a local data distribution $\mathcal{D}_k(x, y)$ over features $x$ and labels $y$. The federation behaves as though all data were drawn from a single pooled distribution $\bar{\mathcal{D}}$, but each $\mathcal{D}_k$ departs from $\bar{\mathcal{D}}$ in one or more ways. Four faces recur in medical settings, and Figure 37.5.1 will later show how each one tugs the local optima apart.
Label or prevalence skew. The marginal $P_k(y)$ differs across sites even when the conditional $P_k(x \mid y)$ is shared. A specialist oncology center sees malignant cases at twenty times the rate of a general clinic; a pediatric hospital almost never sees adult-onset disease. Formally $P_k(y) \neq P_{k'}(y)$. This is the most common and most damaging form in clinical federations, because a site that has seen almost no positives cannot locally learn the decision boundary for them.
Feature-distribution shift (covariate shift). The marginal $P_k(x)$ differs while the labeling rule $P_k(y \mid x)$ is shared. Two hospitals image the same pathology, but one uses a 1.5-tesla scanner and the other a 3-tesla machine, or different contrast-agent protocols, so the pixel statistics differ systematically. The disease is the same; the appearance is not.
Quantity skew. Sites hold wildly unequal amounts of data. A flagship academic center may contribute a hundred thousand records; a partner rural clinic, a few hundred. The federation must weight contributions so that the average reflects the pooled objective rather than the loudest site, a weighting question that connects directly to the size-weighted combination you saw in Section 1.1.
Concept shift. The conditional $P_k(y \mid x)$ itself differs: the same features map to different labels across sites. This is subtle and dangerous in medicine, because diagnostic criteria and coding practice vary. One hospital codes a borderline lab value as positive; another, following a different guideline, codes it negative. The underlying biology is identical but the label function is not, and no amount of feature alignment can reconcile it without harmonizing the labels first.
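Label skew and covariate shift can both be measured before any training, from per-site summaries that reveal no individual record. The sketch below is illustrative and rests on stated assumptions. The label counts are hypothetical, chosen to match the target prevalences of the demonstration later in this section. The single feature is simulated. Total-variation distance and the standardized mean shift are our choice of skew scores; the text does not prescribe them.

```python
import numpy as np

# --- Label (prevalence) skew: compare each site's P_k(y) with the pooled P(y). ---
# Hypothetical per-site label counts; columns are (negative, positive).
label_counts = np.array([[380, 20], [340, 60], [280, 120], [180, 220], [80, 320]])
P_k = label_counts / label_counts.sum(axis=1, keepdims=True)   # per-site P_k(y)
P_pooled = label_counts.sum(axis=0) / label_counts.sum()       # pooled P(y)
tv = 0.5 * np.abs(P_k - P_pooled).sum(axis=1)                  # total-variation distance per site

# --- Covariate shift: compare feature marginals using only (count, mean, variance). ---
rng = np.random.default_rng(0)
# Simulated feature (e.g. a mean image intensity) from scanners with different gain/offset.
features = [rng.normal(m, s, size=n) for m, s, n in [(100, 10, 500), (120, 12, 300), (90, 8, 200)]]
n = np.array([len(x) for x in features], dtype=float)
mu = np.array([x.mean() for x in features])
var = np.array([x.var() for x in features])                    # population variance (ddof=0)
pooled_mean = np.sum(n * mu) / n.sum()
pooled_var = np.sum(n * (var + (mu - pooled_mean) ** 2)) / n.sum()   # law of total variance
shift = (mu - pooled_mean) / np.sqrt(pooled_var)               # standardized mean shift per site

print("label-skew TV per site:", tv.round(2))
print("standardized feature shift per site:", shift.round(2))
```

Neither statistic can detect concept shift. Concept shift lives in $P_k(y \mid x)$, which these marginal summaries never touch.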
Label skew, covariate shift, quantity skew, and concept shift are four different failures wearing one name. Proximal and control-variate methods (FedProx, SCAFFOLD) directly counter the optimization damage of label and covariate skew. Quantity skew is fixed by correct size-weighted aggregation, not by a new algorithm. Concept shift cannot be averaged away at all; it demands either label harmonization upstream or an explicitly personalized model per site. Diagnosing which face you are dealing with is the first engineering decision, because applying a drift-correction method to a concept-shift problem will converge cleanly to the wrong answer.
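A single hypothetical lab value is enough to show why concept shift resists averaging. The sketch assumes both sites observe the same values. Site A codes $x \ge 1.0$ as positive, and site B, following a different guideline, codes $x \ge 1.5$ as positive. Each site's rule is fit perfectly by its own threshold. No single shared threshold fits both, because every shared choice mislabels the band between the two cut-offs at one site or the other.

```python
import numpy as np

# Hypothetical lab value observed identically at two sites (same P(x)).
x = np.linspace(0.0, 3.0, 301)
y_A = (x >= 1.0).astype(int)   # site A's guideline
y_B = (x >= 1.5).astype(int)   # site B's guideline: same biology, different label function

def error(threshold, y):
    """Misclassification rate of the rule 'positive iff x >= threshold'."""
    return float(np.mean((x >= threshold).astype(int) != y))

candidates = np.linspace(0.0, 3.0, 301)
pooled = np.array([0.5 * error(t, y_A) + 0.5 * error(t, y_B) for t in candidates])

print("site A with its own threshold:", error(1.0, y_A))
print("site B with its own threshold:", error(1.5, y_B))
print("best single shared threshold, pooled error:", pooled.min())
```

The pooled optimum is itself a compromise, so drift correction would only reach that compromise more efficiently. The remedies are the two the text names: harmonize the labels, or keep a per-site model.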
2. Client Drift: Why FedAvg Suffers (Intermediate)
The federation's true target is the pooled objective, a size-weighted average of the per-site losses. With $n_k$ examples at site $k$ and $n = \sum_k n_k$ total,
$$F(w) = \sum_{k=1}^{K} \frac{n_k}{n}\, F_k(w), \qquad F_k(w) = \mathbb{E}_{(x,y)\sim \mathcal{D}_k}\big[\ell(w; x, y)\big].$$
FedAvg approaches this objective indirectly. In each round the server broadcasts the current global weights $w^t$; each site runs $\tau$ steps of local SGD on its own $F_k$, producing $w_k^{t+1}$; the server averages them. The trouble is what happens during those $\tau$ local steps. Each site is descending its own loss surface, whose minimizer $w_k^\star = \arg\min_w F_k(w)$ sits at a different place than the global minimizer $w^\star = \arg\min_w F(w)$ whenever the data is heterogeneous. After $\tau$ steps the local iterate has moved toward $w_k^\star$, and the displacement from the broadcast point,
$$\delta_k^{t} = w_k^{t+1} - w^{t},$$
is the client drift of site $k$. On IID data the drifts point in compatible directions and their average is a good global descent step. On non-IID data the drifts point toward scattered local optima, and the averaged update $w^{t+1} = \sum_k \frac{n_k}{n} w_k^{t+1}$ lands at the size-weighted centroid of those local pulls, which need not coincide with a descent step of $F$. The more local steps $\tau$ you take to save communication, the further each site wanders and the worse the centroid approximates the true global step. Figure 37.5.1 makes this geometry explicit.
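The drift definition and the size-weighted aggregation can be checked on the smallest possible federation. The sketch assumes two hypothetical sites with scalar quadratic losses $F_k(w) = \tfrac{1}{2}h_k(w - a_k)^2$ and 300 and 100 records. It runs $\tau$ local gradient steps at each site, forms the drifts $\delta_k$, and aggregates with weights $n_k/n$. It then compares the result with $\tau$ gradient steps on the pooled $F$ from the same broadcast point. With $\tau = 1$ the two coincide, because the gradient of $F$ is the $n_k/n$-weighted sum of site gradients. With many local steps they part ways.

```python
import numpy as np

# Hypothetical sites: F_k(w) = 0.5 * h[k] * (w - a[k])**2, so grad F_k(w) = h[k] * (w - a[k]).
h = np.array([1.0, 4.0])          # curvature of each site's loss
a = np.array([0.0, 1.0])          # each site's local optimum w_k*
n_k = np.array([300.0, 100.0])    # records per site
p = n_k / n_k.sum()               # aggregation weights n_k / n
eta = 0.1                         # local learning rate

def fedavg_round(w, tau):
    """One FedAvg round from global weights w; returns (new global w, drifts delta_k)."""
    w_local = np.full(len(h), w)
    for _ in range(tau):
        w_local = w_local - eta * h * (w_local - a)    # local GD on F_k only
    delta = w_local - w                                  # client drift of each site
    return w + p @ delta, delta                          # equals sum_k (n_k/n) * w_k

def pooled_gd(w, tau):
    """tau gradient steps on the pooled objective F = sum_k (n_k/n) F_k."""
    for _ in range(tau):
        w = w - eta * (p @ (h * (w - a)))
    return w

w0 = 0.5   # broadcast point between the two local optima
for tau in (1, 20):
    w_fed, delta = fedavg_round(w0, tau)
    print(f"tau={tau:2d}  drifts={delta.round(3)}  FedAvg={w_fed:.4f}  pooled GD={pooled_gd(w0, tau):.4f}")
```

With twenty local steps the two drifts point in opposite directions. The averaged model then moves away from the pooled optimum $w^\star = 4/7 \approx 0.571$ instead of toward it.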
The cost is real and twofold. Convergence is slower, because each round makes less true global progress than its local effort suggests, and it is biased, because the fixed point of FedAvg under heterogeneity is not in general $w^\star$ but a point displaced by the residual drifts. Increasing $\tau$ to amortize the communication cost (the central economy of federated learning, since each round is an expensive wide-area exchange) makes both problems worse. This is the heterogeneity tax: the same trick that saves bandwidth amplifies the drift. The non-IID analysis of FedAvg in Chapter 14 bounds the resulting error in terms of a skew measure; a convenient one is the gradient dissimilarity, the worst-case size-weighted spread between the site gradients and the pooled gradient,
$$\zeta^2 = \sup_{w}\; \sum_{k=1}^{K} \frac{n_k}{n} \big\| \nabla F_k(w) - \nabla F(w) \big\|^2,$$
which is exactly zero when every site shares the pooled gradient (the IID case) and grows with the spread of the local objectives. Standard convergence bounds for FedAvg carry a drift term that grows with both $\zeta^2$ and the local-step count $\tau$. That term is the formal statement of the geometry in Figure 37.5.1.
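The bias of the fixed point, and its growth with $\tau$, can be read off a two-site toy. This sketch runs FedAvg to convergence for several local-step counts and compares the result with the pooled minimizer $w^\star = \sum_k p_k h_k a_k / \sum_k p_k h_k$. It uses hypothetical quadratic sites $F_k(w) = \tfrac{1}{2}h_k(w-a_k)^2$ with weights $p_k = n_k/n$, an illustration rather than a clinical model. For quadratics the FedAvg fixed point has the closed form $\sum_k p_k (1-r_k) a_k / \sum_k p_k (1-r_k)$ with $r_k = (1-\eta h_k)^\tau$. It equals $w^\star$ at $\tau = 1$ and tends to the plain weighted average of the local optima as $\tau$ grows.

```python
import numpy as np

# Hypothetical quadratic sites F_k(w) = 0.5 * h[k] * (w - a[k])**2 with weights p_k = n_k / n.
h = np.array([1.0, 4.0])
a = np.array([0.0, 1.0])
p = np.array([0.75, 0.25])
eta = 0.1
w_star = np.sum(p * h * a) / np.sum(p * h)     # minimizer of the pooled objective F

def fedavg_limit(tau, rounds=500, w=0.0):
    """Run FedAvg with tau local GD steps per round and return where it settles."""
    for _ in range(rounds):
        w_local = np.full(len(h), w)
        for _ in range(tau):
            w_local = w_local - eta * h * (w_local - a)
        w = p @ w_local                          # size-weighted aggregation
    return float(w)

def fedavg_fixed_point(tau):
    """Closed-form fixed point of the same iteration (valid for quadratic F_k)."""
    one_minus_r = 1.0 - (1.0 - eta * h) ** tau
    return float(np.sum(p * one_minus_r * a) / np.sum(p * one_minus_r))

for tau in (1, 5, 50, 200):
    w_fp = fedavg_limit(tau)
    print(f"tau={tau:3d}  FedAvg settles at {w_fp:.4f}  bias = {abs(w_fp - w_star):.4f}")
```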
3. FedProx: Anchor the Local Step (Intermediate)
The simplest remedy attacks the drift directly: forbid the local iterate from wandering too far from the broadcast model. FedProx replaces each site's local objective with a proximally regularized version,
$$h_k(w; w^t) = F_k(w) + \frac{\mu}{2}\,\big\| w - w^t \big\|^2,$$
and each site minimizes $h_k$ instead of $F_k$ during its local steps. The quadratic penalty $\tfrac{\mu}{2}\|w - w^t\|^2$ is a spring that pulls every local trajectory back toward the shared starting point; the strength $\mu$ tunes how far sites may roam. At $\mu = 0$ FedProx is exactly FedAvg. As $\mu$ grows, the local moves shrink (the green pull in Figure 37.5.1), the drifts $\delta_k$ shorten, and the averaged update lands closer to a true descent step of $F$. The trade is explicit: large $\mu$ buys stability and a less biased aggregate but slows per-round progress, because each site is allowed to do less local work. FedProx also tolerates sites that perform variable, partial amounts of local computation, which matters when a hospital's node is slow or preempted, a robustness property that pairs naturally with the elastic-training ideas of Chapter 18.
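The $\mu$ trade-off can be made concrete with hypothetical scalar quadratic sites. To keep the algebra exact, the sketch assumes each site solves its proximal problem exactly, whereas real FedProx clients take a finite, possibly variable, number of steps. The exact local solution is $(h_k a_k + \mu w^t)/(h_k + \mu)$. Iterating the aggregated update gives a fixed point that starts at the weighted average of the local optima when $\mu = 0$ and moves toward $w^\star$ as $\mu$ grows. The share of the previous global model that survives each round, $\sum_k p_k\, \mu/(h_k+\mu)$, grows with $\mu$ as well, which is the slower per-round progress.

```python
import numpy as np

# Hypothetical sites F_k(w) = 0.5 * h[k] * (w - a[k])**2, aggregation weights p_k = n_k / n.
h = np.array([1.0, 4.0])
a = np.array([0.0, 1.0])
p = np.array([0.75, 0.25])
w_star = np.sum(p * h * a) / np.sum(p * h)   # pooled minimizer (4/7)

def prox_local_solution(w_t, mu):
    """Exact argmin_w F_k(w) + (mu/2) * (w - w_t)**2 at every site (an idealization)."""
    return (h * a + mu * w_t) / (h + mu)

def fedprox(mu, rounds=200, w=0.0):
    for _ in range(rounds):
        w = p @ prox_local_solution(w, mu)     # size-weighted aggregation
    return float(w)

def fedprox_fixed_point(mu):
    """Closed form of the fixed point of the iteration above."""
    wts = p * h / (h + mu)
    return float(np.sum(wts * a) / np.sum(wts))

def surviving_fraction(mu):
    """d w^{t+1} / d w^t: share of the old global model carried into the next round."""
    return float(p @ (mu / (h + mu)))

for mu in (0.0, 0.1, 1.0, 10.0, 100.0):
    print(f"mu={mu:6.1f}  fixed point={fedprox_fixed_point(mu):.4f}  "
          f"(w*={w_star:.4f})  surviving fraction per round={surviving_fraction(mu):.3f}")
```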
The demonstration below builds a five-hospital federation with strong label skew (prevalence ranging from 5% to 80%) and a per-site feature shift, then runs FedAvg ($\mu = 0$) and FedProx ($\mu = 1$) on the identical data with identical local effort. It reports the global log-loss each method reaches and the mean client drift it leaves behind.
```python
import numpy as np

# Five hospitals, binary diagnosis. Each site has a DIFFERENT disease prevalence
# (label skew) plus a site-specific feature shift -> non-IID across clients.
rng = np.random.default_rng(7)
d = 6
prevalence = [0.05, 0.15, 0.30, 0.55, 0.80]  # per-site positive rate (label skew)
n_per_site = 400
w_star = rng.standard_normal(d)  # shared ground-truth signal

def make_site(p):
    y = (rng.random(n_per_site) < p).astype(float)
    shift = rng.standard_normal(d) * 0.6  # feature shift
    X = rng.standard_normal((n_per_site, d)) + np.outer(2*y - 1, w_star)*0.8 + shift
    return X, y

sites = [make_site(p) for p in prevalence]

sigmoid = lambda z: 1.0 / (1.0 + np.exp(-np.clip(z, -30, 30)))
grad = lambda w, X, y: X.T @ (sigmoid(X @ w) - y) / len(y)

def logloss_global(w):  # pooled objective F(w)
    tot, nn = 0.0, 0
    for X, y in sites:
        p = sigmoid(X @ w)
        tot += -np.sum(y*np.log(p + 1e-12) + (1 - y)*np.log(1 - p + 1e-12))
        nn += len(y)
    return tot / nn

def local_train(w_global, X, y, lr, steps, mu):
    w = w_global.copy()
    for _ in range(steps):
        g = grad(w, X, y) + mu * (w - w_global)  # mu>0 adds the FedProx pull
        w = w - lr * g
    return w

def federated(mu, rounds=40, local_steps=60, lr=0.5):
    w = np.zeros(d)
    for _ in range(rounds):
        locals_ = [local_train(w, X, y, lr, local_steps, mu) for X, y in sites]
        drift = np.mean([np.linalg.norm(wl - w) for wl in locals_])  # client drift (last round is returned)
        w = np.mean(locals_, axis=0)  # FedAvg aggregation (all n_k equal, so a plain mean)
    return logloss_global(w), drift

loss_avg, drift_avg = federated(mu=0.0)    # FedAvg
loss_prox, drift_prox = federated(mu=1.0)  # FedProx
print(f"site prevalences : {prevalence}")
print(f"FedAvg global logloss: {loss_avg:.4f} mean client drift: {drift_avg:.3f}")
print(f"FedProx global logloss: {loss_prox:.4f} mean client drift: {drift_prox:.3f}")
print(f"drift reduction : {100*(1-drift_prox/drift_avg):.1f}%")
print(f"objective improvement : {100*(loss_avg-loss_prox)/loss_avg:.1f}%")
```
Code 37.5.1: the proximal term mu * (w - w_global) is added to each local gradient; setting mu=0 recovers FedAvg, mu=1 gives FedProx.
```text
site prevalences : [0.05, 0.15, 0.3, 0.55, 0.8]
FedAvg global logloss: 0.2832 mean client drift: 0.710
FedProx global logloss: 0.2744 mean client drift: 0.115
drift reduction : 83.8%
objective improvement : 3.1%
```
The numbers track the geometry of Figure 37.5.1 closely. FedAvg's local solutions drift far (mean displacement 0.710 in the final round), and their average scores worse on the pooled objective. FedProx restrains them (drift 0.115), and the restrained average reaches a lower pooled log-loss. The improvement looks modest as a percentage because log-loss is already small here, but the drift collapse is the load-bearing result. Drift is what biases the fixed point, and on harder non-convex clinical models it separates a federation that converges from one that oscillates.
Who: An ML engineer building a federated early-sepsis-warning model across eight hospitals.
Situation: A vanilla FedAvg run with twenty local epochs per round (chosen to keep the wide-area communication bill low) refused to converge: validation AUROC swung several points between rounds and never stabilized.
Problem: Two academic centers contributed most of the ICU positives; several community hospitals had almost none. Label skew plus high local-step count produced exactly the drift of Figure 37.5.1, and the averaged model lurched between the academic and community optima each round.
Dilemma: Cut local epochs to tame the drift (tripling the communication cost across the wide-area link) or change the algorithm.
Decision: Keep the local-step count, add a FedProx proximal term with $\mu$ tuned on a held-out federation split.
How: One line in the local trainer, the same mu * (w - w_global) addition as Code 37.5.1, with $\mu$ swept over a small grid.
Result: The oscillation vanished, AUROC stabilized and rose, and the communication budget was untouched.
Lesson: When a federation oscillates rather than fails outright, suspect client drift before you suspect the model. The proximal anchor is the cheapest first intervention.
4. SCAFFOLD: Correct the Drift with Control Variates (Advanced)
FedProx shortens the drift but does not cancel it; it slows every site uniformly, including sites that were not drifting. SCAFFOLD is sharper. It asks: in which direction does each site systematically deviate from the global descent direction, and can we subtract that deviation off so local steps point the right way without being shortened? This is precisely a variance-reduction question, and the tool is the same control variate you met in stochastic optimization in Chapter 10. There, a control variate replaced a high-variance gradient estimate with one whose noise had been partly subtracted off using a correlated, known-mean quantity. SCAFFOLD applies that idea across sites rather than across minibatches.
It maintains a server control variate $c$ approximating the global gradient direction and a per-site control variate $c_k$ approximating site $k$'s gradient direction. During each local step on site $k$, the plain local gradient $\nabla F_k(w)$ is corrected, so the update becomes
$$w \leftarrow w - \eta\,\big( \nabla F_k(w) - c_k + c \big).$$
The term $(c - c_k)$ is the control variate: it estimates how far site $k$'s local gradient habitually points away from the global gradient and adds that difference back, steering the local step toward the global descent direction. If a site's gradient chronically leans toward its own optimum, $c_k$ captures that lean and the correction cancels it. When the control variates are accurate, the corrected local steps point along the global gradient regardless of heterogeneity. Under the smoothness assumptions of its analysis, SCAFFOLD provably converges at a rate that does not depend on the skew $\zeta^2$, a guarantee FedProx does not provide. The cost is communication: each round must exchange the control variates alongside the model, doubling the per-round payload. Each site must also hold persistent state, which is awkward when sites participate intermittently.
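Here is a minimal sketch of the SCAFFOLD update, under explicit assumptions. It uses two equal-sized hypothetical quadratic sites, so the equal-weight averaging of the original algorithm matches the pooled objective. It assumes full participation and a global step size of 1. It refreshes the control variates with $c_k \leftarrow c_k - c + (w^t - w_k^{t+1})/(\tau\eta)$, the rule that reuses the local trajectory instead of computing an extra gradient. It illustrates the correction; it is not a production implementation.

```python
import numpy as np

# Hypothetical equal-sized sites with quadratic losses F_k(w) = 0.5 * h[k] * (w - a[k])**2.
h = np.array([1.0, 4.0])
a = np.array([0.0, 1.0])
K = len(h)
w_star = np.sum(h * a) / np.sum(h)   # minimizer of the (equal-weight) pooled objective

def local_run(w, eta, tau, correction):
    """tau local GD steps from w at every site; correction[k] is added to site k's gradient."""
    y = np.full(K, w)
    for _ in range(tau):
        y = y - eta * (h * (y - a) + correction)
    return y

def fedavg(w=2.0, rounds=200, eta=0.1, tau=10):
    for _ in range(rounds):
        w = local_run(w, eta, tau, np.zeros(K)).mean()
    return float(w)

def scaffold(w=2.0, rounds=200, eta=0.1, tau=10):
    c_k = np.zeros(K)    # per-site control variates
    c = 0.0              # server control variate
    for _ in range(rounds):
        y = local_run(w, eta, tau, c - c_k)            # corrected step: grad F_k - c_k + c
        c_k_new = c_k - c + (w - y) / (tau * eta)      # refresh from the local trajectory
        c = c + (c_k_new - c_k).mean()                 # server folds in the control deltas
        c_k = c_k_new
        w = w + (y - w).mean()                         # global step size 1
    return float(w)

print(f"pooled optimum w* = {w_star:.4f}")
print(f"FedAvg   settles at {fedavg():.4f}")
print(f"SCAFFOLD settles at {scaffold():.4f}")
```

With ten local steps per round, FedAvg settles near 0.60 on this toy, a point displaced by the residual drift. SCAFFOLD's corrected steps reach $w^\star = 0.8$ with the same local effort. The extra state is the vector `c_k` each site must keep between rounds.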
The control variate is not new machinery invented for federated learning; it is the variance-reduction idea from single-machine stochastic optimization (Chapter 10), scaled out across institutions. In Chapter 10 it subtracted minibatch noise from a gradient estimate on one worker. Here the same algebra subtracts inter-site drift, the systematic disagreement between a hospital's gradient and the federation's. As a result, distributing the optimization across heterogeneous data leaves no drift-induced bias at convergence. The book's recurring move (take a primitive and run it across machines) applies as cleanly to a noise-cancellation trick as it did to the all-reduce of Section 1.1.
5. Personalization: Stop Pretending One Model Fits All (Intermediate)
FedProx and SCAFFOLD both assume a single global model is the right target and merely fight to reach it. When heterogeneity is severe, and especially when it is concept shift (Section 1), that assumption is wrong: no single weight vector serves a pediatric hospital and a geriatric oncology center equally well. Personalized federated learning abandons the one-model premise. The federation still collaborates to learn shared structure, but each site ends with a model adapted to its own distribution.
The simplest form is fine-tuning: run federated training to a good global model, then let each site take a few local gradient steps on its own data to specialize. More structured approaches split the network into shared and private parts, a common feature extractor trained federally plus a per-site head, so that sites share representation but keep their own decision layer. A complementary framing treats heterogeneity as a domain-generalization problem: each hospital is a domain, and the goal is a model whose shared component generalizes across domains while a light per-domain adaptation handles the rest. Personalization trades the clean promise of a single audited model for per-site accuracy, which raises an evaluation question the federation cannot dodge: a global metric can hide a model that is excellent on the two largest sites and useless on the smallest. Per-site evaluation, the stratified reporting discipline of Chapter 5, is therefore not optional in a heterogeneous federation; it is the only way to see whether the shared model serves every population or just the loudest ones.
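Local fine-tuning and per-site evaluation both fit in a few lines. The sketch uses three hypothetical sites with scalar quadratic losses and strong quantity skew; it is an illustration, not a clinical model. The pooled loss of the global model looks small, but the per-site report shows the small site paying for the compromise. Five local gradient steps of fine-tuning then lower every site's own loss.

```python
import numpy as np

# Hypothetical sites F_k(w) = 0.5 * h[k] * (w - a[k])**2 with strong quantity skew.
h = np.array([1.0, 4.0, 2.0])
a = np.array([0.0, 1.0, 3.0])           # the small third site has a very different optimum
n_k = np.array([5000, 3000, 200])
p = n_k / n_k.sum()

def per_site_loss(w):
    """Loss at each site; w may be one shared model (scalar) or one model per site (array)."""
    return 0.5 * h * (w - a) ** 2

w_global = np.sum(p * h * a) / np.sum(p * h)    # exact minimizer of the pooled objective

def fine_tune(w, k, steps=5, eta=0.1):
    """A few local gradient steps at site k, starting from the shared model."""
    for _ in range(steps):
        w = w - eta * h[k] * (w - a[k])
    return w

w_personal = np.array([fine_tune(w_global, k) for k in range(len(h))])
global_losses = per_site_loss(w_global)
personal_losses = per_site_loss(w_personal)

print("pooled loss of the global model:", round(float(p @ global_losses), 3))
print("per-site loss, global model:    ", global_losses.round(3))
print("per-site loss, after fine-tuning:", personal_losses.round(3))
```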
Code 37.5.1 hand-rolled the aggregation so the proximal term was visible. In production you would not normally reimplement FedAvg or FedProx; a federated-learning framework supplies such algorithms as pluggable server strategies. With Flower, moving from FedAvg to FedProx on the server is a one-line change of the strategy object. The FedProx strategy forwards proximal_mu to the clients, and each client's own training loop must still add the proximal term to its local loss:
```python
import flwr as fl

# FedAvg baseline ...
strategy = fl.server.strategy.FedAvg(fraction_fit=1.0)

# ... swap to FedProx by changing ONE line; proximal_mu is the mu of Code 37.5.1.
strategy = fl.server.strategy.FedProx(fraction_fit=1.0, proximal_mu=1.0)

fl.server.start_server(strategy=strategy,
                       config=fl.server.ServerConfig(num_rounds=40))
```
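On the client side, the proximal term is the same single line as in Code 37.5.1; only the source of $\mu$ changes. The sketch below does not import flwr, so it runs anywhere, and the class and its training loop are hypothetical. The `fit(parameters, config)` method returns `(parameters, num_examples, metrics)`, following the shape of Flower's `NumPyClient.fit`. Reading `config["proximal_mu"]` assumes the FedProx strategy puts $\mu$ in each client's fit configuration under that key. The Flower versions we are aware of do this; check the documentation for the version you deploy.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-np.clip(z, -30, 30)))

class ProxLogRegClient:
    """Hypothetical site client for logistic regression with an optional FedProx term."""

    def __init__(self, X, y, lr=0.5, local_steps=60):
        self.X, self.y = X, y
        self.lr, self.local_steps = lr, local_steps

    def fit(self, parameters, config):
        w_global = np.asarray(parameters[0], dtype=float)
        mu = float(config.get("proximal_mu", 0.0))       # missing key -> plain FedAvg
        w = w_global.copy()
        for _ in range(self.local_steps):
            g = self.X.T @ (sigmoid(self.X @ w) - self.y) / len(self.y)
            w = w - self.lr * (g + mu * (w - w_global))   # FedProx pull toward w_global
        drift = float(np.linalg.norm(w - w_global))
        return [w], len(self.y), {"drift": drift}

# Usage with simulated, linearly separable site data (illustration only).
rng = np.random.default_rng(1)
X = rng.standard_normal((200, 3))
y = (X @ np.array([1.0, -1.0, 0.5]) > 0).astype(float)
client = ProxLogRegClient(X, y)
w0 = [np.zeros(3)]
_, _, m_avg = client.fit(w0, {})                       # as a FedAvg round would call it
_, _, m_prox = client.fit(w0, {"proximal_mu": 1.0})    # as a FedProx round would call it
print("drift without / with proximal term:", m_avg["drift"], m_prox["drift"])
```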
Data heterogeneity remains the most active problem in federated medical AI. Recent work extends drift correction to the foundation-model era: federated and parameter-efficient fine-tuning (federated LoRA and adapter methods) shares only low-rank updates across hospitals, shrinking the payload that SCAFFOLD's full control variates inflate while retaining drift robustness. Clustered federated learning groups sites with similar distributions and learns one model per cluster, a middle ground between a single global model and full personalization that fits the multi-modal nature of clinical populations. On the analysis side, tighter bounds relate the skew measure $\zeta^2$ to achievable accuracy and inform how aggressively a federation can raise the local-step count before drift dominates. And because clinical federations must also be private, a parallel thread studies how differential-privacy noise (the secure-aggregation and DP machinery of Chapter 35) interacts with heterogeneity, since both inject error and the federation must budget for their sum. The throughline of the 2024 to 2026 literature is that heterogeneity is treated as a quantity to be measured and engineered against, not a nuisance to be wished away.
Averaging five hospital models that each overfit their own population yields a model with the bedside manner of a committee: competent nowhere, offensive to no one, and quietly worst at exactly the rare cases each specialist was good at. Client drift is the formal name for the disappointment you feel when you average five experts and get a mediocre generalist. FedProx is the gentle reminder that they all started from the same textbook this morning.
The arc of this section is the arc of the whole subject in miniature. Heterogeneity is the price of distributing learning across institutions that genuinely differ, just as communication is the price of distributing computation across machines. You do not escape the price; you engineer against it. Name the skew, measure the drift, anchor or correct or personalize, and evaluate per site. The next section turns from the optimization to the wire: how the federation moves these updates privately, the secure-aggregation and communication-compression layer that sits beneath every method here.
For each scenario, state which face of non-IID data (label skew, covariate shift, quantity skew, concept shift) dominates and which remedy from Sections 3 to 5 you would reach for first, justifying why the others would underperform: (a) ten hospitals image the same tumor type on scanners from four vendors with different intensity profiles; (b) one academic center holds 90% of the federation's labeled records and nine clinics hold the rest; (c) two hospitals follow different clinical guidelines, so identical lab panels receive opposite diagnostic codes; (d) a rare-disease registry where most sites have seen fewer than five positive cases. Explain specifically why applying SCAFFOLD to scenario (c) would converge cleanly to a model that is wrong for both hospitals.
Extend Code 37.5.1 to sweep $\mu$ over $\{0, 0.1, 0.5, 1, 2, 5, 10\}$ and, for each, record the final global log-loss, the mean client drift, and the number of rounds to reach within 1% of the best log-loss achieved across the sweep. Plot drift and log-loss against $\mu$. Identify the value of $\mu$ that minimizes log-loss and explain the U-shape (or its absence): why does very large $\mu$ eventually hurt the pooled objective even though it drives drift toward zero? Relate your answer to the per-round-progress trade described in Section 3.
Using the five sites of Code 37.5.1, estimate the gradient-dissimilarity skew measure $\zeta^2$ from Section 2 empirically at a single point, which gives a lower bound on the supremum in its definition: at the global FedAvg solution, compute each site's gradient $\nabla F_k$ and the pooled gradient $\nabla F$, and form the size-weighted mean squared difference. Then halve the prevalence spread (move every site's positive rate toward 0.5) and recompute $\zeta^2$ and the FedAvg-versus-FedProx drift gap. Verify empirically that the drift gap shrinks as $\zeta^2$ falls, and discuss how a federation could monitor $\zeta^2$ in practice without sharing data to decide whether drift correction is worth its communication cost.