"They keep sending everyone to me. I am flattered, I am overloaded, and seven of my colleagues have not seen a single token all morning."
A Popular Expert Holding Up the Whole Batch
In an expert-parallel model the gate decides not only which experts compute but, because experts live on different machines, which machines do the work; a gate that prefers a few popular experts therefore overloads a few devices and idles the rest, and the whole batch waits for the busiest one. This makes load balancing the central training problem of a mixture of experts: it is not a modeling nicety but a distributed-systems problem wearing a modeling disguise. A skewed router wastes hardware (most devices sit idle while one is saturated) and wastes capacity (experts that never receive tokens never learn). This section shows why the imbalance arises, ties it to two failure modes you have already met, the straggler and data skew, and develops the three mitigations the field actually uses: an auxiliary balancing loss, expert-choice routing that is balanced by construction, and the auxiliary-loss-free bias correction used in recent large models such as DeepSeek-V3. A from-scratch demo trains a tiny gate with and without the balancing loss and watches the busiest device go from saturated to nearly fair.
The previous section moved tokens to their chosen experts with an all-to-all exchange, taking the routing decision as given. Now we confront the decision itself. In a dense layer every device does an identical share of the work, so balance is automatic; nobody chooses who computes. A mixture of experts breaks that symmetry on purpose: the gate sends each token to a small number of experts, and the experts are scattered across devices for the memory reasons that motivated expert parallelism in the first place. The gate is now, whether or not its designers think of it this way, a load balancer for the cluster. When it does that job badly, the consequences land on the hardware immediately and on model quality more slowly, and the two failures reinforce each other.
The trouble is that nothing in the basic gating objective asks for balance. The gate is trained to route each token to whichever experts most reduce the loss, and real data is not uniform: some patterns are common, some are rare, and a few experts quickly become specialists for the common patterns. Left alone, the router discovers that sending most tokens to a handful of experts is locally optimal, a self-reinforcing collapse sometimes called routing collapse. Those experts improve because they see the most data, which makes the gate prefer them even more, while the starved experts stay weak and the gate learns to avoid them further. The result is a model that has paid for many experts but trained only a few.
1. Why Imbalance Is a Distributed-Systems Problem (Beginner)
Expert parallelism places a disjoint set of experts on each device. When a batch arrives, the gate's choices induce a load on each device equal to the number of tokens routed to the experts it holds. The forward pass through the expert layer cannot finish until every device has processed its share, so the wall-clock cost of the layer is set by the busiest device, not the average one. This is exactly the straggler effect introduced in Chapter 2: a synchronous step proceeds at the speed of its slowest participant, and one overloaded device makes everyone else wait at the next all-to-all barrier. Adding more experts and more devices does not help if the gate keeps funnelling tokens to the same few; you have bought parallel hardware and then serialized it through a hot spot.
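The accounting in this paragraph can be sketched in a few lines. This is a minimal illustration, assuming experts are laid out contiguously (expert `e` lives on device `e // experts_per_device`) and that compute is proportional to token count; the helper name `device_load` and the toy routing are hypothetical, not taken from any framework.

```python
import numpy as np

def device_load(expert_index, num_experts, experts_per_device):
    """Fraction of tokens landing on each device under a contiguous layout."""
    per_expert = np.bincount(expert_index, minlength=num_experts)
    per_device = per_expert.reshape(-1, experts_per_device).sum(axis=1)
    return per_device / len(expert_index)

# Hypothetical top-1 routing of 10 tokens over 4 experts on 2 devices.
expert_index = np.array([0, 0, 0, 1, 0, 0, 2, 0, 1, 3])
load = device_load(expert_index, num_experts=4, experts_per_device=2)

# A synchronous step waits on the busiest device, so the stretch relative to a
# perfectly balanced step is the max load divided by the fair share 1/devices.
slowdown = load.max() / (1.0 / len(load))
print(load, slowdown)
```

Here device 0 receives 8 of the 10 tokens, so under the proportional-compute assumption the layer takes about 1.6 times as long as a balanced one, even though the average load per device is unchanged.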
The mechanism that produces the hot spot is also familiar. A gate that concentrates tokens on popular experts is producing a skewed partition of the batch, the same pathology that wrecks a Spark shuffle when one key dominates, studied in Chapter 7. There, one popular join key sends a disproportionate share of rows to a single reducer; here, one popular expert receives a disproportionate share of tokens on a single device. The cause differs (a learned router rather than a data distribution) but the cure rhymes: detect the skew, and reshape the assignment so no single partition carries far more than its fair share.
Because experts are pinned to devices, the router's per-token choices are also per-device work assignments. The cost of the expert layer is governed by the busiest device, so a gate that is even slightly skewed converts expensive parallel hardware into a serial bottleneck. Load balancing is therefore not a regularizer you add for tidiness; it is the thing that determines whether expert parallelism delivers any speedup at all. Quality and throughput fail together: the overloaded experts dominate learning while the idle ones never train, so the cure for the systems problem is also the cure for the modeling problem.
2. The Auxiliary Load-Balancing Loss (Intermediate)
The classic fix, introduced with the sparsely gated mixture of experts and carried into Switch Transformer, is to add a term to the training objective that penalizes imbalance directly. For a batch of $T$ tokens and $E$ experts, define $f_e$ as the fraction of tokens dispatched to expert $e$ under the (hard, top-$k$) routing, and $P_e$ as the mean gate probability assigned to expert $e$ across the batch,
$$f_e = \frac{1}{T}\sum_{t=1}^{T}\mathbb{1}[\,\text{token } t \text{ routed to } e\,], \qquad P_e = \frac{1}{T}\sum_{t=1}^{T} g_e(x_t),$$
where $g_e(x_t)$ is the softmax gate probability of expert $e$ for token $t$. The auxiliary loss is the scaled dot product of these two vectors,
$$\mathcal{L}_{\text{aux}} = \alpha \, E \sum_{e=1}^{E} f_e \, P_e .$$
Two facts make this the right object. First, the dot product $\sum_e f_e P_e$ is minimized, subject to each vector summing to one, when both are uniform at $1/E$, so driving it down pushes usage toward the even split we want; the factor $E$ keeps the target value near a constant as $E$ changes, and $\alpha$ (a small weight such as $10^{-2}$) sets how hard balance competes with the task loss. Second, the gradient is well behaved: the hard counts $f_e$ are not differentiable, but $P_e$ is, so the gradient flows through the gate probabilities and gently lowers the logits of experts that are already popular ($f_e$ large) while raising the rest. The term nudges the router toward fairness without dictating any single token's destination, which is what lets the experts still specialize. Section 17.7 picks up the companion control, capacity factors and token dropping, that bounds the damage when balance is imperfect.
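A quick check of the two extremes shows what the factor $E$ buys. The lines below assume top-1 routing (so $f$ sums to one) and, for the collapsed case, a gate that is fully confident in a single expert; both are simplifying assumptions for illustration.

$$
\text{uniform: } f_e = P_e = \tfrac{1}{E} \;\Rightarrow\; \mathcal{L}_{\text{aux}} = \alpha\, E \sum_{e=1}^{E} \tfrac{1}{E}\cdot\tfrac{1}{E} = \alpha,
\qquad
\text{collapsed: } f_1 = P_1 = 1 \;\Rightarrow\; \mathcal{L}_{\text{aux}} = \alpha\, E \cdot 1 \cdot 1 = \alpha E .
$$

Under these assumptions the term sits at $\alpha$ when usage is perfectly even, whatever the number of experts, and grows to $\alpha E$ at total collapse, so the same $\alpha$ means roughly the same thing as $E$ changes.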
The code below trains a tiny gate over synthetic tokens that genuinely cluster toward a few popular experts, first with no balancing term and then with the auxiliary loss switched on. It reports per-expert usage, the load on the busiest of four devices (two experts each), and the max-to-mean ratio that quantifies the skew.
import numpy as np

rng = np.random.default_rng(7)
N, d, E = 6000, 16, 8  # tokens, feature dim, experts

# Tokens that genuinely cluster toward experts 0 and 1, so an unregularized
# gate piles onto them: a learned, self-reinforcing popularity skew.
pop = rng.standard_normal((E, d)); pop[0] *= 3.0; pop[1] *= 2.2
assign = rng.choice(E, size=N, p=[.40, .22, .10, .08, .07, .06, .04, .03])
X = pop[assign] + 0.6 * rng.standard_normal((N, d))

def softmax(z):
    z = z - z.max(axis=1, keepdims=True); e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def usage_fraction(W):  # top-1 routing, fraction per expert
    return np.bincount(np.argmax(softmax(X @ W), axis=1), minlength=E) / N

def train(aux_weight, steps=800, lr=0.5):
    W = 0.01 * np.random.default_rng(0).standard_normal((d, E))  # fixed init
    for _ in range(steps):
        g = softmax(X @ W)
        # Task pull: a rich-get-richer signal that nudges each token toward the
        # expert already most confident for it, so popular experts grow more
        # popular and an unregularized gate tends to collapse onto a few of them.
        hard = np.zeros((N, E)); hard[np.arange(N), np.argmax(g, axis=1)] = 1.0
        grad_task = X.T @ (g - hard) / N
        # Auxiliary balance loss aux = E * ||P||^2, the smooth surrogate for
        # E * sum_e f_e P_e, minimized at uniform usage P_e = 1/E. Its gradient
        # flows through the softmax mean P_e = mean_t g and pushes mass off hot experts.
        P = g.mean(axis=0)
        v = 2.0 * E * P
        grad_aux = X.T @ (g * (v - (g @ v)[:, None])) / N
        W -= lr * (grad_task + aux_weight * grad_aux)
    return W

for label, aw in [("no aux loss ", 0.0), ("with aux loss", 1.0)]:
    u = usage_fraction(train(aw))
    dev = u.reshape(4, 2).sum(axis=1)  # 4 devices, 2 experts each
    print(f"{label}: usage% = [{', '.join(f'{100*x:4.1f}' for x in u)}]")
    print(f" busiest device = {100*dev.max():4.1f}% (ideal 25.0%),"
          f" max/mean expert = {u.max()/u.mean():.2f}x")
Code 17.6.1: The grad_aux term, the gradient of the balance surrogate $\mathcal{L}_{\text{aux}} = E\sum_e P_e^2$ through the differentiable gate probabilities, pulls usage back toward even; only the relative weight aux_weight changes between the two runs.
Output 17.6.1:
no aux loss : usage% = [ 0.0, 55.4, 0.0, 0.0, 31.1, 13.5, 0.0, 0.0]
 busiest device = 55.4% (ideal 25.0%), max/mean expert = 4.43x
with aux loss: usage% = [13.1, 14.1, 13.5, 16.5, 10.3, 7.5, 10.4, 14.6]
 busiest device = 30.0% (ideal 25.0%), max/mean expert = 1.32x
The improvement is decisive here, and the honest lesson of the demo is why: left alone, the rich-get-richer pull lets a few experts swallow the batch and starves the rest into dead weight, exactly the routing collapse this section warns about. The auxiliary loss counters that feedback and brings the idle experts back into play. On real data the gate also carries genuine specialization pressure, so a practical $\alpha$ trades a little of that specialization for balance rather than flattening usage all the way; pushing $\alpha$ higher buys more balance at a rising cost to model quality, which is the central tension of the whole topic.
Expert parallelism is the sparse relative of the data parallelism from Chapter 15: instead of every worker holding the whole model and a slice of the data, every worker holds a slice of the model and the data flows to it. That inversion is what imports the skew problem. In data-parallel training the partition of work is fixed and even by construction; in expert-parallel training the partition is learned per batch by the gate, so it can collapse onto a few devices exactly the way a MapReduce or Spark shuffle collapses onto a hot key. Every parallel method in this book is ultimately a question of how work is partitioned across machines and what it costs to keep that partition balanced; the mixture of experts is the case where the partition is a trainable parameter, which is both its power and its hazard.
3. Balanced by Construction, and Balanced Without a Loss (Advanced)
The auxiliary loss treats imbalance as something to penalize after the fact. Two other strategies attack it more directly. The first is to change who does the choosing. In the token-choice routing of Section 17.3, each token picks its top-$k$ experts, and nothing stops a popular expert from being everyone's pick. Expert-choice routing inverts the selection: each expert picks the top-$T k / E$ tokens it wants from the batch. Because every expert selects exactly the same number of tokens, the load is uniform by construction, with no auxiliary loss required and no device able to become a hot spot. The cost is that a token may be chosen by several experts or by none, so expert-choice trades guaranteed device balance for variable per-token capacity, a different knob on the same trade-off that Section 17.7 formalizes.
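The inversion described above can be sketched directly. This is a minimal illustration, not a framework implementation: the helper name `expert_choice` is hypothetical, the affinity scores are random, and the sketch assumes $E$ divides $Tk$ so every expert gets an integer capacity.

```python
import numpy as np

def expert_choice(scores, k):
    """Each expert keeps its top T*k/E tokens by score.

    scores: (T, E) token-to-expert affinities.
    Returns a (T, E) boolean mask; mask[t, e] is True if expert e picked token t.
    """
    T, E = scores.shape
    capacity = T * k // E                          # same budget for every expert
    # Rows of `top` hold, per expert column, the indices of its best tokens.
    top = np.argsort(-scores, axis=0)[:capacity]   # shape (capacity, E)
    mask = np.zeros((T, E), dtype=bool)
    mask[top, np.arange(E)] = True
    return mask

rng = np.random.default_rng(0)
scores = rng.standard_normal((64, 8))              # 64 tokens, 8 experts
mask = expert_choice(scores, k=2)

tokens_per_expert = mask.sum(axis=0)   # equal for every expert, by construction
experts_per_token = mask.sum(axis=1)   # free to differ from token to token
print(tokens_per_expert, experts_per_token.min(), experts_per_token.max())
```

Every column of the mask sums to the same capacity, which is the balance guarantee; the row sums are not constrained, which is the variable per-token capacity the paragraph warns about.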
The second strategy keeps token choice but removes the auxiliary loss entirely. The auxiliary loss has a known side effect: its gradient perturbs the task objective, and an $\alpha$ large enough to enforce balance can measurably hurt quality, an interference the literature calls the balance-versus-specialization tension. The auxiliary-loss-free approach replaces the loss term with a per-expert bias $b_e$ added to the routing logits only for the top-$k$ selection. After each step the biases are nudged by a simple control rule: lower $b_e$ for experts that were overloaded, raise it for experts that were starved,
$$\text{route on } g_e(x_t) + b_e, \qquad b_e \leftarrow b_e + \gamma \,\big(\bar{f} - f_e\big),$$
where $\bar f = 1/E$ is the target fraction and $\gamma$ is a small update rate. The bias steers the routing toward balance, but because it is added only for selection and not to the gate value used in the weighted combination, it never appears in the task gradient, so it balances load without distorting what the experts learn. This is the mechanism DeepSeek-V3 uses at scale (its report describes stepping each bias by a fixed amount $\gamma$ in the direction of the error rather than in proportion to it), and it is currently a leading approach for keeping a large mixture of experts balanced while letting the experts specialize freely.
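To see the controller in isolation, the sketch below freezes a deliberately skewed set of gate probabilities and lets only the biases move, using the proportional rule as written above. Everything here is an assumption made for illustration: the synthetic logits, the sizes, the step count, and the helper name `load`. In real training the gate would keep learning and the unbiased $g_e$ would still weight the expert outputs.

```python
import numpy as np

rng = np.random.default_rng(0)
T, E, gamma = 4096, 8, 0.01

# Frozen, skewed gate: expert 0 gets a logit boost, so it wins most top-1 picks.
logits = rng.standard_normal((T, E))
logits[:, 0] += 1.5
g = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

def load(b):
    """Top-1 dispatch fraction per expert when selecting on g + b."""
    return np.bincount(np.argmax(g + b, axis=1), minlength=E) / T

b = np.zeros(E)
before = load(b)
for _ in range(500):
    # Raise the bias of starved experts, lower it for overloaded ones.
    b += gamma * (1.0 / E - load(b))
after = load(b)

print("max/mean before:", before.max() / before.mean())
print("max/mean after: ", after.max() / after.mean())
```

The bias only changes which expert wins the argmax; `g` itself is never modified, which mirrors the point that the correction stays out of the task gradient.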
The dominant 2024 to 2026 direction is to stop paying the quality tax of the auxiliary loss. Wang et al. (2024), "Auxiliary-Loss-Free Load Balancing Strategy for Mixture-of-Experts" (the Loss-Free Balancing method), introduced the per-expert bias update above and showed it reaches better perplexity than auxiliary-loss balancing at the same level of load balance. DeepSeek-V3 (DeepSeek-AI, 2024) adopts exactly this bias-based balancing across its 256 routed experts, pairing it with a tiny "sequence-wise" auxiliary term only to prevent extreme within-sequence collapse, and reports strong balance with negligible quality cost. A parallel line studies global balance across a long training run rather than per-batch, since a router that is balanced on average can still spike on individual batches, and the spikes are what drop tokens and stall devices. The throughline is that balance is increasingly enforced by lightweight controllers on the routing decision rather than by terms bolted onto the loss, which keeps the gate free to specialize. We return to the stability side of this story, capacity factors and token dropping, in Section 17.7.
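The gap between average balance and per-batch balance is easy to demonstrate with a contrived case. The sketch below uses two hypothetical batches over two experts, each fully collapsed but in opposite directions; the helper name `load_fractions` is illustrative.

```python
import numpy as np

def load_fractions(expert_index, num_experts):
    """Fraction of a batch's tokens dispatched to each expert."""
    return np.bincount(expert_index, minlength=num_experts) / len(expert_index)

# Batch A sends all 8 tokens to expert 0; batch B sends all 8 to expert 1.
batches = [np.zeros(8, dtype=int), np.ones(8, dtype=int)]
per_batch = np.stack([load_fractions(b, 2) for b in batches])

global_load = per_batch.mean(axis=0)   # averaged over the run: looks balanced
worst_share = per_batch.max(axis=1)    # busiest expert's share in each batch
print(global_load, worst_share)
```

Averaged over both batches each expert sees half the tokens, yet in every individual batch one expert takes all of them, which is the situation that drops tokens and stalls devices even when the long-run statistics look fair.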
In Code 17.6.1 we derived the auxiliary loss and its gradient by hand. In a real training stack the layer returns the balance loss alongside its output and the trainer adds it to the task loss; the $f_e \cdot P_e$ bookkeeping over the dispatched tokens is done for you, and frameworks such as DeepSpeed-MoE and Megatron-Core wire it in automatically when you declare a mixture-of-experts layer:
import torch, torch.nn.functional as F

def load_balance_loss(gate_logits, expert_index, num_experts):
    # gate_logits: (T, E) router scores; expert_index: (T,) the top-1 pick per token.
    P = F.softmax(gate_logits, dim=-1).mean(dim=0)  # mean prob per expert
    f = torch.bincount(expert_index, minlength=num_experts).float() / gate_logits.size(0)
    return num_experts * torch.dot(f, P)  # alpha applied by caller

# In the training step the layer hands this back; you just add it to the task loss:
#   loss = task_loss + alpha * aux_loss  # alpha ~ 1e-2 in Switch Transformer
Who: A distributed-training engineer bringing up a 64-expert language model on 16 GPUs, four experts per device.
Situation: Training ran, loss decreased, but throughput was roughly a third of the dense baseline the team had projected for the same parameter budget.
Problem: Per-device profiling showed two GPUs pinned near 100% utilization while the other fourteen idled below 30%, with the all-to-all barrier stalling on the busy pair every step.
Dilemma: Crank the auxiliary-loss weight $\alpha$ up hard enough to force balance, accepting a measurable hit to validation perplexity, or switch the balancing mechanism and keep quality intact.
Decision: They first raised $\alpha$ and confirmed the diagnosis: balance improved and the idle GPUs woke up, but perplexity regressed, the textbook balance-versus-specialization tension. They then replaced the loss with the auxiliary-loss-free per-expert bias update.
How: They removed the loss term, added a per-expert bias to the routing logits, and after each step nudged each bias by $\gamma(\bar f - f_e)$ from the measured dispatch counts, exactly the rule in Section 3.
Result: The max-to-mean device load fell from above 3x to near 1.2x within a few hundred steps, GPU utilization evened out across all sixteen devices, throughput roughly tripled to match the projection, and validation perplexity was no worse than the unbalanced run, the same direction Output 17.6.1 shows at toy scale.
Lesson: When throughput collapses on a mixture of experts, suspect the router before the network. Measure per-device load first, and prefer a balancing mechanism that does not fight the task gradient.
A collapsed router produces the saddest object in distributed deep learning: a fully provisioned expert, with its own weights, its own slice of a very expensive GPU, and exactly zero tokens to its name across the entire epoch. It is paid for, powered on, and never asked to think. The auxiliary loss and the bias term are, at heart, an inclusion policy for these neglected experts, a gentle institutional pressure to spread the invitations around so that everyone gets to learn something.
4. The Balance-Versus-Specialization Tension (Advanced)
Perfect balance and perfect specialization pull in opposite directions, and a good mixture of experts lives at a deliberate compromise between them. Push balance to its limit and you would route tokens to experts almost at random, guaranteeing a flat load but destroying the very specialization that makes a sparse model worth more than a smaller dense one; the experts would all learn the same average function. Push specialization to its limit and the gate collapses onto a few experts, which trains well but utilizes the cluster terribly and starves most of the model. The whole engineering art is to enforce just enough balance that no device becomes a straggler, while leaving the gate free to send genuinely different tokens to genuinely different experts.
This is why the field has drifted from heavy auxiliary losses toward lighter-touch controllers. The auxiliary loss enforces balance by adding a force to the loss landscape that every token feels, which inevitably bends the task objective. The bias-update approach enforces balance by adjusting only the selection threshold, leaving the gradient that shapes the experts untouched, so it buys balance at a lower price in specialization. Either way, balance is never free, and the right operating point depends on how skewed your data is and how tight your hardware budget is. The performance models of Chapter 3 give you the language to put a number on the throughput cost of a given imbalance, and that number is what tells you how hard to push balance for a given cluster.
Using Output 17.6.1, the no-aux run sends 55.4% of tokens to expert 1 and 0.0% to five other experts (0, 2, 3, 6, and 7), with a busiest-device load of 55.4% on four devices. (a) If the expert layer's compute is perfectly proportional to tokens, by what factor does the skewed run stretch the layer's wall-clock time compared to a perfectly balanced run, given that the synchronous step waits on the busiest device? (b) Explain why simply adding a fifth and sixth device would not fix the problem as long as the gate keeps its choices, and why the five dead experts mean the model wasted parameters it paid for. (c) Connect both answers to the straggler discussion in Chapter 2.
Extend Code 17.6.1 with a third training mode that uses no auxiliary loss at all. Keep a per-expert bias vector $b$ initialized to zero, route on $g_e(x_t) + b_e$ for the top-1 selection, and after each step update $b_e \leftarrow b_e + \gamma(\bar f - f_e)$ with $\bar f = 1/E$ and a small $\gamma$ (try $0.01$). Crucially, use the unbiased gate probabilities, not $g_e + b_e$, in the task gradient so the bias never enters the task objective. Report the final per-expert usage, the busiest-device load, and the max-to-mean ratio, and compare all three modes (none, auxiliary loss, bias). Which reaches the most even load, and does it do so without the auxiliary-loss term touching the task gradient?
A 32-expert layer is spread over 8 devices (4 experts each), and per-batch profiling shows the busiest device holds a fraction $m$ of the tokens while the fair share is $1/8$. (a) Write the throughput of the expert layer, relative to a perfectly balanced layer, as a function of $m$, assuming the step waits on the busiest device. (b) For $m = 0.30$, $0.20$, and $0.14$, compute the relative throughput and the fraction of aggregate device-time wasted on idle devices. (c) Suppose the auxiliary loss can reduce $m$ from $0.30$ to $0.16$ but costs $1.5\%$ in model quality on your eval. Using the cost language of Chapter 3, argue whether the trade is worth it for a training run that is throughput-bound versus one that is quality-bound, and state what additional number you would measure to decide.