Part III: Distributed Machine Learning
Chapter 10: Distributed Optimization

Large-Batch Training and Learning-Rate Scaling

"They gave me a thousand friends to share the work, then doubled my step size to match. For a while I sprinted. Then I tripped over my own enthusiasm and they invented warmup."

An Optimizer That Scaled Its Learning Rate
Big Picture

In synchronous distributed SGD, every worker you add enlarges the global batch, and the global batch is the single knob that couples your hardware budget to your optimizer's stability. Adding workers does not change the math of the gradient (it stays the exact average from Section 1.1); it changes the statistics of each update by averaging over more examples. A larger, lower-variance gradient invites a larger step, and the linear scaling rule makes that precise: raise the learning rate in proportion to the batch, warm it up so the early steps do not diverge, and a $K$-worker run tracks the small-batch run almost exactly, until you cross the critical batch size, past which more workers stop buying convergence speed. This section is about that coupling: how to ride it, and where it runs out.

The previous section reduced the bytes moved per synchronization step. This one asks a complementary question: as you add workers to a synchronous data-parallel job, what actually changes for the optimizer, and how do you keep it converging? The answer is entirely about batch size. With $K$ workers each computing a gradient over a local batch of $B_{\text{loc}}$ examples and an all-reduce averaging them, the optimizer sees one update built from a global batch of $B = K \, B_{\text{loc}}$ examples. Doubling the worker count doubles $B$. Everything stable and everything fragile about large-scale synchronous training follows from that one identity, so we start there and build outward to the rules that tame it.

1. More Workers Means a Bigger Batch, Which Is Weak Scaling (Beginner)

Hold the per-worker batch fixed and add workers: the global batch grows linearly, and the number of optimizer updates per pass over the data shrinks by the same factor. This is the defining signature of weak scaling, where the total problem size grows with the worker count rather than staying fixed, the regime introduced in Section 3.3. Strong scaling would keep the global batch constant and shrink each worker's share, which quickly starves the accelerators; the dominant practice in distributed training is instead to keep each accelerator well-fed at a fixed local batch and let the global batch expand. So the question "how many workers?" is, for the optimizer, identical to the question "how large a batch can I train with and still converge well?"
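The bookkeeping in this paragraph can be sketched in a few lines. The helper name `weak_scaling_table` and the numbers (20,000 examples, a local batch of 32) are illustrative assumptions, not values prescribed by the text.

```python
import math

def weak_scaling_table(n_examples, local_batch, worker_counts):
    """Return (K, global_batch, updates_per_epoch) rows for a fixed local batch."""
    rows = []
    for k in worker_counts:
        global_batch = k * local_batch                   # B = K * B_loc
        updates = math.ceil(n_examples / global_batch)   # optimizer steps per pass over the data
        rows.append((k, global_batch, updates))
    return rows

# Hypothetical job: 20,000 examples, 32 examples per worker.
for k, b, u in weak_scaling_table(n_examples=20_000, local_batch=32, worker_counts=[1, 2, 4, 8]):
    print(f"K={k:>2}  B={b:>4}  updates/epoch={u}")
```

Each doubling of `K` doubles the global batch and roughly halves the updates per epoch, which is the weak-scaling signature described above.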

That equivalence is why this topic belongs in a chapter on optimization rather than systems. The communication machinery (the all-reduce of Chapter 4) is indifferent to batch size; it averages whatever gradients it is given. The optimizer is not indifferent. A gradient averaged over $B$ examples is an estimate of the true gradient whose variance falls roughly like $1/B$. Larger batches give cleaner gradients, and a cleaner gradient is one you can trust to take a larger step along. The rest of this section turns that intuition into a rule, then shows where the rule breaks.
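The $1/B$ variance claim is easy to check numerically. The sketch below is a minimal illustration under stated assumptions: a synthetic least-squares problem, gradients evaluated at a single fixed point, and batches sampled with replacement. The names `per_example` and `batch_grad_variance` are made up for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 50_000, 10
X = rng.standard_normal((N, d))
y = X @ rng.standard_normal(d) + rng.standard_normal(N)   # noisy linear targets
w = np.zeros(d)                                           # point where gradients are evaluated

# Per-example gradients of 0.5 * (x.w - y)^2, shape (N, d).
per_example = X * (X @ w - y)[:, None]
full_grad = per_example.mean(axis=0)

def batch_grad_variance(B, trials=2000):
    """Mean squared distance between a size-B batch gradient and the full gradient."""
    idx = rng.integers(0, N, size=(trials, B))       # sample with replacement
    batch_grads = per_example[idx].mean(axis=1)      # shape (trials, d)
    return np.mean(np.sum((batch_grads - full_grad) ** 2, axis=1))

variances = {B: batch_grad_variance(B) for B in (8, 32, 128)}
for B, v in variances.items():
    print(f"B={B:>4}  variance={v:.4f}  B*variance={B * v:.2f}")
```

If the variance falls like $1/B$, the last column should stay roughly constant across batch sizes, up to sampling noise.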

Key Insight: The Worker Count Is a Batch-Size Knob in Disguise

In synchronous data-parallel SGD with a fixed per-worker batch, the global batch size is $B = K \, B_{\text{loc}}$, so choosing the number of workers $K$ is choosing the batch size. Communication averages the gradients exactly regardless of $K$; what changes with $K$ is the variance of each update and therefore the step size the optimizer can safely take. Every large-batch training decision is really a decision about how the optimizer should respond to a lower-variance gradient.

2. The Linear Scaling Rule and Why Warmup Is Mandatory (Intermediate)

The practical recipe that made large-batch training routine is the linear scaling rule of Goyal et al. (2017): when you multiply the batch size by a factor, multiply the learning rate by the same factor. Concretely, if a reference batch $B_0$ trains well at learning rate $\eta_0$, then a batch of size $B$ should use

$$\eta = \eta_0 \cdot \frac{B}{B_0}.$$

The heuristic behind it is short. Take $K$ consecutive small-batch SGD steps with learning rate $\eta_0$ on batches $\mathcal{B}_1, \dots, \mathcal{B}_K$. If the parameters barely move across those $K$ steps, the gradients are evaluated at nearly the same point, and the net update is approximately $-\eta_0 \sum_{j=1}^{K} \nabla \ell_{\mathcal{B}_j}(w)$. One large-batch step on the union $\mathcal{B}_1 \cup \dots \cup \mathcal{B}_K$ with learning rate $\eta$ produces $-\eta \cdot \frac{1}{K}\sum_{j=1}^{K}\nabla \ell_{\mathcal{B}_j}(w)$. Matching the two requires $\eta = K \eta_0$, exactly the linear rule. The rule is an approximation, and it names its own failure mode: it holds only while the assumption "the parameters barely move" holds.
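The matching argument can be checked numerically. The sketch below assumes a synthetic least-squares problem and equal-size batches; the function names `grad` and `mismatch` are illustrative. It compares $K$ sequential small-batch steps at $\eta_0$ with one step on the union at $K\eta_0$, and reports how far apart the two end points are as $\eta_0$ grows.

```python
import numpy as np

rng = np.random.default_rng(1)
N, d, K, b = 4096, 10, 8, 32
X = rng.standard_normal((N, d))
y = X @ rng.standard_normal(d)

def grad(w, idx):
    """Gradient of the mean 0.5 * squared error over the rows in idx."""
    Xb = X[idx]
    return Xb.T @ (Xb @ w - y[idx]) / len(idx)

def mismatch(eta0):
    """Relative gap between K small steps at eta0 and one big step at K * eta0."""
    w0 = np.zeros(d)
    batches = [rng.choice(N, size=b, replace=False) for _ in range(K)]
    w_small = w0.copy()
    for idx in batches:                        # K sequential small-batch steps
        w_small -= eta0 * grad(w_small, idx)
    union = np.concatenate(batches)            # one large-batch step on the union
    w_large = w0 - (K * eta0) * grad(w0, union)
    return np.linalg.norm(w_small - w_large) / np.linalg.norm(w_large - w0)

for eta0 in (0.001, 0.01, 0.1):
    print(f"eta0={eta0:<6} relative mismatch={mismatch(eta0):.3f}")
```

With a small $\eta_0$ the parameters barely move across the $K$ steps and the two updates nearly coincide; with a large $\eta_0$ the small-batch trajectory drifts away from the single large step, which is the failure mode the derivation names.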

Early in training the parameters move a great deal, the loss surface is steep and curved, and the linear-rule learning rate (now $K$ times larger) overshoots and can diverge on the very first steps. The fix is a warmup: start the learning rate small and ramp it linearly to the scaled value over the first few hundred updates, by which point the parameters have settled into a region where the matching argument is valid. Warmup is not a cosmetic safety margin; it is the repair for the one place the derivation is known to be invalid. Skip it at large $K$ and the run frequently diverges before it begins.
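A linear warmup toward the scaled rate is a small function. This sketch mirrors the ramp used in Code 10.8.1 (start near zero and reach the scaled rate on the last warmup step); other variants start the ramp at the reference rate $\eta_0$ instead. The function name `warmup_lr` and the example values are assumptions for illustration.

```python
def warmup_lr(step, base_lr, batch, base_batch, warmup_steps):
    """Learning rate at a 0-indexed step: linear ramp to the scaled rate, then constant."""
    target = base_lr * batch / base_batch            # linear scaling rule
    if warmup_steps > 0 and step < warmup_steps:
        return target * (step + 1) / warmup_steps    # reaches target on the last warmup step
    return target

# Hypothetical schedule: reference batch 32 at lr 0.05, scaled to batch 1024, 160 warmup steps.
schedule = [warmup_lr(s, base_lr=0.05, batch=1024, base_batch=32, warmup_steps=160)
            for s in (0, 79, 159, 160, 1000)]
print(schedule)
```

Setting `warmup_steps=0` recovers the unwarmed linear rule, which is convenient for comparing the two in a sweep.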

Fun Note: The Rule That Survived by Admitting Its Own Flaw

The linear scaling rule is unusual among deep-learning heuristics because it ships with the exact conditions under which it fails written into its derivation. Most rules of thumb pretend to be universal and quietly disappoint you. This one says, in effect, "I work whenever the weights are not moving much, which is everywhere except the start, so please warm me up at the start." Engineers trust it precisely because it does not oversell.

3. The Critical Batch Size: Where Adding Workers Stops Helping (Intermediate)

The linear scaling rule cannot hold forever, because averaging over more examples only removes noise: once the batch gradient is already close to the true gradient, the largest stable step is set by the curvature of the loss, and a bigger batch cannot raise it. Below some threshold, doubling the batch (and the learning rate) roughly halves the number of optimizer steps to a target loss, so you reach the target in about the same number of passes over the data and the extra workers translate into real speedup. Above that threshold, doubling the batch no longer halves the steps; the gradient is already clean enough that lowering its variance further does little, and you are in the regime of diminishing returns. McCandlish et al. (2018) named this threshold the critical batch size and gave it a predictor: it scales with the gradient noise scale, a measurable ratio of the variance of per-example gradients to the squared norm of the mean gradient, which tends to grow as training proceeds.
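A crude version of the noise-scale measurement can be sketched from per-example gradients. The estimator below computes $\operatorname{tr}(\Sigma) / \lVert G \rVert^2$, the "simple" noise scale in the sense of McCandlish et al., on a synthetic logistic-regression problem. The function names and the two evaluation points are assumptions for illustration, and the plug-in estimate of $\lVert G \rVert^2$ is biased when the true gradient is small.

```python
import numpy as np

def simple_noise_scale(per_example_grads):
    """tr(Sigma) / |G|^2 from an (n, d) array of per-example gradients."""
    G = per_example_grads.mean(axis=0)                           # estimate of the true gradient
    trace_sigma = per_example_grads.var(axis=0, ddof=1).sum()    # trace of per-example covariance
    return trace_sigma / np.dot(G, G)

rng = np.random.default_rng(2)
n, d = 20_000, 30
X = rng.standard_normal((n, d))
w_star = rng.standard_normal(d)
y = (rng.random(n) < 1 / (1 + np.exp(-(X @ w_star)))).astype(float)

def per_example_logistic_grads(w):
    """Per-example gradients of the logistic loss at w, shape (n, d)."""
    p = 1 / (1 + np.exp(-(X @ w)))
    return X * (p - y)[:, None]

early = simple_noise_scale(per_example_logistic_grads(np.zeros(d)))      # at initialization
late = simple_noise_scale(per_example_logistic_grads(0.9 * w_star))     # closer to the optimum
print(f"noise scale at initialization:   {early:.1f}")
print(f"noise scale closer to optimum:   {late:.1f}")
```

The second value is much larger than the first because the mean gradient shrinks near the optimum while the per-example spread does not, which matches the observation that the noise scale tends to grow as training proceeds.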

The critical batch size is the optimization counterpart of the serial bottleneck in Section 3.5: just as Amdahl's law caps speedup once the parallelizable fraction is exhausted, the critical batch size caps useful worker count once the gradient-variance reduction is exhausted. Past it, adding workers grows the batch beyond what the optimizer can convert into faster convergence: steps-to-target stop falling, so wall-clock per epoch keeps shrinking but epochs-to-target start rising, and total time-to-train flattens or worsens. Knowing roughly where this knee sits for your model is what separates a worker count that pays for itself from one that burns budget. Figure 10.8.1 sketches the shape.

[Figure 10.8.1 placeholder. x-axis: global batch size B (more workers →). y-axis: steps to target loss. Curves: ideal (steps ∝ 1/B); scaled LR (tracks the 1/B speedup, then more workers stop helping past the critical batch size); naive fixed LR (batch hurts convergence).]
Figure 10.8.1: Steps-to-target versus global batch size on log-log axes. With the linear scaling rule plus warmup (solid green), steps fall along the ideal $1/B$ line, so each added worker buys real speedup, until the curve bends up at the critical batch size and further workers stop helping. With a naive fixed learning rate (dashed red), enlarging the batch reduces the updates per epoch without compensating, so steps-to-target climb and convergence degrades. The knee, not the hardware budget, sets the useful worker count.

4. A From-Scratch Demonstration of the Knee (Intermediate)

The claims above are testable on a single machine with NumPy, without any cluster. The code below trains a logistic-regression model with plain mini-batch SGD on a fixed synthetic dataset, sweeping the global batch from 32 to 4096. For each batch it measures epochs-to-target-loss twice: once with a fixed learning rate (the naive baseline) and once with the linear scaling rule plus warmup. Counting epochs (passes over the data) is the right cost unit, because a larger batch means fewer optimizer updates per epoch, so a fixed learning rate makes less progress per pass while the scaled rule compensates.

import numpy as np

rng = np.random.default_rng(7)
N, d = 20_000, 30
X = rng.standard_normal((N, d))
w_star = rng.standard_normal(d)
p = 1.0 / (1.0 + np.exp(-(X @ w_star)))
y = (rng.random(N) < p).astype(np.float64)        # noisy binary labels

def loss_grad(w, Xb, yb):
    pr = 1.0 / (1.0 + np.exp(-(Xb @ w)))
    g = Xb.T @ (pr - yb) / len(yb)
    L = -np.mean(yb*np.log(pr+1e-12) + (1-yb)*np.log(1-pr+1e-12))
    return L, g

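def full_loss(w):
    """Loss over the entire dataset."""
    return loss_grad(w, X, y)[0]

# Approximate the optimum with 4000 full-batch gradient-descent steps (step size 1.0),
# then set a demanding target a little above it, so reaching it takes many epochs.
w_opt = np.zeros(d)
for _ in range(4000):
    w_opt -= loss_grad(w_opt, X, y)[1]
TARGET = full_loss(w_opt) + 0.010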

base_bs, base_lr, MAX_EPOCHS = 32, 0.05, 80          # reference batch and its tuned lr

def epochs_to_target(B, scale_lr):
    lr = base_lr * (B / base_bs) if scale_lr else base_lr   # linear scaling rule
    warmup = 5 * max(1, B // base_bs) if scale_lr else 0    # warmup grows with scaling
    w, steps = np.zeros(d), 0
    for epoch in range(1, MAX_EPOCHS + 1):
        perm = rng.permutation(N)
        for s in range(0, N, B):
            _, g = loss_grad(w, X[perm[s:s+B]], y[perm[s:s+B]])
            cur = lr * (steps+1)/warmup if warmup and steps < warmup else lr
            if not np.all(np.isfinite(g)): return None       # diverged
            w -= cur * g
            steps += 1
        if full_loss(w) <= TARGET:
            return epoch
    return None                                              # never reached target

for B in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
    ef = epochs_to_target(B, scale_lr=False)
    es = epochs_to_target(B, scale_lr=True)
    fmt = lambda e: ">80 (stalled)" if e is None else str(e)
    print(f"{B:>6} | fixed-lr {fmt(ef):>13} | scaled-lr {fmt(es):>13}")
Code 10.8.1: Pure-NumPy sweep of global batch size, comparing a fixed learning rate against the linear scaling rule with warmup. The cost metric is epochs-to-target so that fewer-updates-per-epoch at large batch is counted fairly. The same dataset and target are reused for every batch, so differences are attributable to the optimizer schedule alone.
    32 | fixed-lr             2 | scaled-lr             2
    64 | fixed-lr             4 | scaled-lr             2
   128 | fixed-lr             8 | scaled-lr             3
   256 | fixed-lr            16 | scaled-lr             3
   512 | fixed-lr            31 | scaled-lr             4
  1024 | fixed-lr            61 | scaled-lr             6
  2048 | fixed-lr >80 (stalled) | scaled-lr            11
  4096 | fixed-lr >80 (stalled) | scaled-lr            22
Output 10.8.1: The fixed learning rate degrades monotonically: epochs-to-target roughly double with every doubling of the batch, because the number of updates needed stays about constant, then exceed the 80-epoch budget past 1024, where the same trend predicts roughly 120 epochs. The scaled learning rate stays cheap and roughly flat through batch 512, the regime where each added worker buys speedup, then bends upward (6, 11, 22 epochs at 1024, 2048, 4096), the critical-batch knee where more workers stop helping even with correct scaling.

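Read the two columns together and the story of this section is in the numbers. The naive column shows why you cannot simply pour workers into a synchronous job and hope: without scaling the learning rate, a bigger batch costs strictly more epochs. The scaled column shows the linear rule doing its job, holding epochs-to-target between 2 and 4 while the batch grows sixteenfold (32 to 512); then the same column reveals the knee, where even correct scaling cannot recover the $1/B$ ideal because the gradient is already low-variance. One caveat when reading the knee: Code 10.8.1 lengthens warmup with the batch ($5B/32$ updates) while updates per epoch shrink (about $N/B$), so at batches 1024, 2048, and 4096 the ramp spans roughly 8, 32, and 128 epochs, longer than the epochs-to-target reported there. Those runs reach the target while the learning rate is still ramping, so part of the bend reflects the schedule and not gradient statistics alone. With that caveat, the bend past batch 512 is the demonstration, in miniature, that the useful worker count is bounded by the optimization problem, not by the cluster you can rent.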

Thesis Thread: Scale-Out Reaches an Optimization Ceiling, Not Only a Communication One

Earlier chapters argued that communication is the tax on scale-out. This section adds a second, independent ceiling that no faster network can lift: the critical batch size. Even with free, instant all-reduce, a synchronous data-parallel job stops converging faster once the global batch passes the knee, because the gradient is already clean enough that the optimizer gains nothing from averaging it further. Distribution buys speed only while two separate budgets last, the communication budget of Section 3.5 and the batch-size budget here, and a well-designed training run respects whichever binds first.

5. Layer-Wise Adaptive Methods: Pushing the Knee Outward (Advanced)

The critical batch size is not a hard constant; it depends on the optimizer. The linear scaling rule applies one global learning rate to every parameter, but in deep networks the ratio of update magnitude to weight magnitude varies enormously across layers. At very large batch a single global rate is simultaneously too large for some layers and too small for others, which is what makes naive scaling diverge well before the statistical knee. Layer-wise adaptive methods attack this directly by giving each layer its own effective step. LARS (You, Gitman, Ginsburg, 2017) rescales each layer's update so its norm is a fixed fraction of that layer's weight norm, which let ResNet-style networks train at batches in the tens of thousands. LAMB (You et al., 2019) carries the same trust-ratio idea into the Adam family and is the method that trained BERT to its target in 76 minutes, using batches of 32k and above (the 76-minute run mixed 64k and 32k batches), a regime in which a single global linearly scaled rate is prone to diverge.

The lesson for distributed training is that the optimizer choice sets how far you can scale out before the knee bites. Plain momentum SGD with linear scaling and warmup is robust to moderate batches; LARS and LAMB extend the usable range by an order of magnitude or more for the architectures they were designed for, by replacing one fragile global step with many self-calibrated per-layer steps. They do not abolish the critical batch size, they relocate it, which is exactly the leverage you want when the alternative is leaving expensive accelerators underused.
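The trust-ratio idea can be sketched in a few lines of NumPy. This is a simplified, LARS-style step written for illustration only: it omits the momentum and weight-decay terms of the published method, and the names `lars_step` and `trust_coef` and the toy layer shapes are assumptions.

```python
import numpy as np

def lars_step(layers, grads, lr, trust_coef=0.001, eps=1e-12):
    """One momentum-free, decay-free LARS-style update over a list of layer arrays."""
    updated = []
    for w, g in zip(layers, grads):
        w_norm, g_norm = np.linalg.norm(w), np.linalg.norm(g)
        # Trust ratio: rescale so the step norm is about trust_coef * lr * ||w|| for this layer.
        ratio = trust_coef * w_norm / (g_norm + eps) if w_norm > 0 and g_norm > 0 else 1.0
        updated.append(w - lr * ratio * g)
    return updated

rng = np.random.default_rng(3)
layers = [rng.standard_normal((4, 4)), 10.0 * rng.standard_normal((16, 4))]
# Gradients with wildly mismatched scales, as a stand-in for different layers of a deep net.
grads = [1e3 * rng.standard_normal((4, 4)), 1e-3 * rng.standard_normal((16, 4))]

new_layers = lars_step(layers, grads, lr=1.0, trust_coef=0.01)
relative_steps = [np.linalg.norm(n - w) / np.linalg.norm(w) for n, w in zip(new_layers, layers)]
print(relative_steps)
```

Both layers move by the same fraction of their own weight norm even though their raw gradients differ by six orders of magnitude, which is the self-calibration that a single global rate cannot provide.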

Practical Example: The Pretraining Run That Hit a Wall at 4096

Who: A research engineer pretraining a mid-size language model on a 64-GPU pod.

Situation: Throughput was excellent at a global batch of 1024, and the team wanted to use all 64 GPUs at a local batch of 64, pushing the global batch to 4096.

Problem: With the linear scaling rule alone, the run diverged in the first few hundred steps; with a lengthened warmup it converged but needed nearly four times the epochs, erasing the speedup.

Dilemma: Cap the job at 1024 and leave 48 GPUs idle, or find an optimizer that converts the larger batch into real progress rather than wasted passes.

Decision: They switched the optimizer to LAMB with a longer warmup, on the hypothesis that the divergence was a per-layer step-size mismatch, not a fundamental critical-batch limit at 4096.

How: They kept the data-parallel all-reduce unchanged and swapped only the optimizer and its warmup schedule, then measured epochs-to-target at batches 1024, 2048, and 4096 before committing to a full run.

Result: At 4096 the LAMB run converged in close to the same epoch count as the 1024 baseline, recovering most of the 4x throughput and keeping the pod fully utilized; the team had relocated the knee, not crossed it.

Lesson: When large-batch scaling stalls, the first question is whether the optimizer or the optimization problem is the limit. A layer-wise adaptive method often buys another factor of batch before the true critical batch size, the kind of knee seen in Output 10.8.1, applies.

Library Shortcut: Linear Scaling and Warmup in a Few Lines

Code 10.8.1 implemented the scaling rule and warmup by hand. In a real PyTorch job you set the scaled learning rate once from the world size and let a scheduler handle warmup, while DistributedDataParallel supplies the global batch by averaging gradients across workers. The whole recipe is a handful of lines:

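# Run with: torchrun --nproc_per_node=8 train.py
import os
import torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import LinearLR

dist.init_process_group("nccl")
K = dist.get_world_size()                         # number of workers == batch multiplier
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))   # LOCAL_RANK is set by torchrun
torch.cuda.set_device(device)

local_bs = 32                                     # per-worker batch (placeholder value)
model = torch.nn.Linear(30, 1).to(device)         # placeholder model; substitute your own
model = DistributedDataParallel(model, device_ids=[device.index])  # averages gradients -> global batch K*local_bs

base_lr, base_bs = 0.05, 32
lr = base_lr * (K * local_bs) / base_bs           # linear scaling rule from the world size
opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

# Linear warmup over the first 500 steps, then hold the scaled rate.
warmup = LinearLR(opt, start_factor=1/K, total_iters=500)
# In the training loop, call warmup.step() after each opt.step() so the ramp advances.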
Code 10.8.2: The linear scaling rule plus warmup expressed with stock PyTorch utilities. The roughly twenty lines of manual schedule logic in Code 10.8.1 collapse to a one-line learning-rate formula and a built-in LinearLR warmup; for LARS or LAMB you swap the optimizer for the implementation in NVIDIA Apex or a comparable library, with no change to the data-parallel loop.

Research Frontier: Predicting the Knee and Scaling the Optimizer (2024 to 2026)

Two recent lines of work sharpen this section's themes. The first is the maximal-update parametrization, muP, and its descendants, which reparametrize a network so that the optimal learning rate is stable across model widths; this lets practitioners tune the learning rate on a small proxy model and transfer it to a large one without re-searching, and recent work (Everett et al., 2024) extends such hyperparameter transfer jointly across width, depth, and batch size, directly easing the large-batch tuning problem this section describes. The second revisits the critical batch size itself for modern language-model pretraining: studies from 2024 and 2025 measure how the critical batch grows with dataset and model size and with the gradient noise scale of McCandlish et al., refining when extra data-parallel workers still pay. A parallel current explores optimizers such as Muon and Lion that claim better large-batch behavior than AdamW, effectively trying to push the knee of Output 10.8.1 further right. The open question is no longer whether a critical batch exists, but how to predict its location cheaply before committing a full pod to a run.

6. The Practical Ceiling and What Comes Next (Intermediate)

Putting the pieces together gives a sober rule for sizing a synchronous data-parallel job. Add workers while the global batch stays below the critical size, scaling the learning rate linearly and warming it up, and use a layer-wise adaptive optimizer if a single global rate begins to diverge before the statistical knee. Stop adding workers near the knee, because past it the extra hardware reduces per-epoch wall-clock while epochs-to-target rise to offset the gain, so total training time stops improving and cost rises with no return, an Amdahl-shaped ceiling on the optimization side to match the communication-side ceiling of Section 3.5. The useful worker count is the smaller of what your network can sustain and what your optimizer can absorb, and Output 10.8.1 shows the second limit can bind well before the first.
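The sizing rule can be turned into a back-of-the-envelope cost model. The sketch below assumes an idealized two-regime picture: below the critical batch the number of examples to target is constant, and above it the examples needed grow in proportion to the batch, so steps-to-target stop falling. The function name `time_to_train` and every number in the example are hypothetical.

```python
import math

def time_to_train(workers, local_batch, critical_batch, examples_below_knee,
                  compute_s, allreduce_s):
    """Seconds to target under a two-regime model (an illustrative assumption)."""
    batch = workers * local_batch
    # Below the knee the examples needed are constant; above it they grow with the batch.
    examples = examples_below_knee * max(1.0, batch / critical_batch)
    steps = math.ceil(examples / batch)
    return steps * (compute_s + allreduce_s)

# Hypothetical job: local batch 32, knee at 1024, 1e6 examples to target below the knee.
times = {k: time_to_train(k, local_batch=32, critical_batch=1024,
                          examples_below_knee=1_000_000,
                          compute_s=0.10, allreduce_s=0.02)
         for k in (8, 32, 128)}
for k, t in times.items():
    print(f"workers={k:>3}  time={t:8.1f} s  worker-seconds={k * t:10.1f}")
```

Under these assumptions, time-to-train falls until the global batch reaches the knee and is flat afterward, while worker-seconds (a proxy for cost) keep climbing, which is the optimization-side ceiling described above.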

We have treated communication as something to minimize and batch size as something to bound, but we have not yet asked how little communication is fundamentally required to optimize across machines, regardless of cleverness. That lower-bound question, how many bits and rounds any distributed optimizer must pay, is the subject of the next section. It turns the engineering trade-offs of this chapter into the theory that says which of them are avoidable and which are laws.

Exercise 10.8.1: Derive the Rule, Name Its Assumption (Conceptual)

Reproduce the matching argument of Section 2 in your own notation: write the net parameter change of $K$ small-batch SGD steps at learning rate $\eta_0$, write one large-batch step on the union at learning rate $\eta$, and show that equating them requires $\eta = K\eta_0$. State precisely the assumption that makes the two expressions comparable, and explain why that assumption is violated at the start of training but not in the middle. Use this to argue, without any experiment, why warmup is needed at the start specifically and not throughout.

Exercise 10.8.2: Find the Knee, Then Move It (Coding)

Extend Code 10.8.1 in two ways. First, for each batch size also record the final loss reached within the epoch budget, and plot epochs-to-target against batch size on log-log axes to locate the knee numerically for this dataset. Second, replace the global learning rate with a crude per-coordinate trust ratio (scale each coordinate's step by the ratio of the weight magnitude to the gradient magnitude in that coordinate, as a LARS-style proxy) and re-measure. Report whether the knee moves to a larger batch, and explain the result in terms of why a single global rate is too large for some directions and too small for others.

Exercise 10.8.3: Cost Out the Critical Batch (Analysis)

Suppose a model has a critical batch size of 2048 examples, each worker holds a fixed local batch of 64, one optimizer step takes 200 ms of compute plus a 50 ms all-reduce that does not change with worker count, and reaching the target needs $2 \times 10^5$ examples processed below the knee. Estimate total time-to-train at 8, 32, and 64 workers, assuming epochs-to-target stays flat up to the critical batch and then doubles for each further doubling beyond it (as in Output 10.8.1). Identify the worker count that minimizes time-to-train, and explain why buying more workers past it is wasted money even though per-epoch wall-clock keeps falling.