Part III: Distributed Machine Learning
Chapter 10: Distributed Optimization

Mini-Batch Stochastic Gradient Descent

"They asked me to estimate the gradient from a handful of examples. I did. They acted surprised that my handful disagreed with the whole truck."

A Mini-Batch With Commitment Issues
Big Picture

Mini-batch stochastic gradient descent is the optimizer that makes distributed training possible, because the mini-batch gradient is an unbiased estimate of the full-data gradient whose noise shrinks predictably as the batch grows, and because a synchronous step over many workers is, mathematically, one mini-batch step with a larger batch. The previous section framed training as minimizing an average loss over a dataset. Computing that average over every example before each update is too slow at scale, and using a single example per update is too noisy. Mini-batch SGD sits between those extremes: it samples a small batch, averages its gradients, and steps. This section establishes the one fact the rest of the chapter rests on, that combining $K$ workers each with a local batch of $b$ examples is identical to running mini-batch SGD with a global batch of $K b$, which is exactly why Section 10.3 can turn it into a distributed algorithm with no change to the math.

In Section 10.1 we wrote learning as empirical risk minimization: find the parameters $w$ that minimize the average loss over a training set. The clean way to do that is gradient descent, which repeatedly steps downhill along the full-data gradient. The trouble is that the full-data gradient is an average over every example, and at the scale this book cares about there are billions of them. Computing one exact gradient can take minutes, and a single optimizer step that takes minutes is a non-starter. The fix is to stop insisting on the exact gradient and accept a cheap, noisy estimate of it instead. That single concession, estimate the gradient rather than compute it, is the door through which both stochastic optimization and distribution enter.

1. Three Batch Sizes, One Update Rule Beginner

All three methods share the same gradient-descent skeleton: estimate a gradient $g_t$ at the current parameters and step against it, $w_{t+1} = w_t - \eta\, g_t$, where $\eta$ is the learning rate. They differ only in how many examples go into $g_t$. Full-batch gradient descent uses every example, so $g_t$ is the exact gradient $\nabla L(w_t)$; each step is maximally informative and ruinously expensive. Stochastic gradient descent, at the other extreme, uses a single example drawn at random, so $g_t = \nabla \ell(w_t; x_i, y_i)$; each step is cheap but jittery, and at a fixed learning rate the jitter does not vanish as training proceeds. Mini-batch SGD draws a small batch $B_t$ of $b$ examples and averages their gradients:

$$g_t = \frac{1}{b} \sum_{i \in B_t} \nabla \ell(w_t; x_i, y_i), \qquad w_{t+1} = w_t - \eta\, g_t.$$

The batch size $b$ is the dial between the two extremes. Set $b = N$ and you are doing full-batch descent; set $b = 1$ and you are doing pure SGD. Every practical training run lives somewhere in the middle, and the rest of this section is about where, and why. We refer to Figure 10.2.1 throughout: it shows a mini-batch being sampled from the dataset and the gradient noise it carries shrinking as the batch widens.
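To make the "one update rule, three batch sizes" point concrete, here is a minimal sketch in which a single step function serves all three methods and only the argument `b` changes. The toy dataset ($y = 2x$ plus small noise), the function name `sgd_step`, the learning rate, and the step count are all assumptions chosen for illustration, not part of the chapter's running example.

```python
import random

def sgd_step(w, data, b, lr, rng):
    """One step w <- w - lr * g, where g averages b per-example gradients."""
    # b == N means "use every example"; otherwise sample b examples with replacement.
    batch = data if b == len(data) else [rng.choice(data) for _ in range(b)]
    g = sum(2.0 * x * (w * x - y) for x, y in batch) / len(batch)
    return w - lr * g

rng = random.Random(0)
# Toy data (an assumption for illustration): y = 2x + small noise, minimizer near w = 2.
data = [(x, 2.0 * x + rng.gauss(0.0, 0.1))
        for x in (rng.gauss(0.0, 1.0) for _ in range(1000))]

final_w = {}
for name, b in [("full-batch", len(data)), ("mini-batch", 32), ("single-example", 1)]:
    w = 0.0
    for _ in range(200):
        w = sgd_step(w, data, b, lr=0.05, rng=rng)   # same rule, different b
    final_w[name] = w
    print("%-15s b=%-5d final w = %.4f" % (name, b, w))
```

All three runs should end near $w = 2$; what differs is the cost of each step (a pass over all $N$ examples versus $b$ of them) and how much the final value wobbles from run to run.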

[Figure 10.2.1 image. Left panel: a dataset of N examples with the sampled mini-batch of b examples highlighted in orange and averaged. Right panel: gradient estimate vs the true full-data gradient, with a wide spread at b = 1, a narrower spread at b = 16, and a tight spread around the truth at b = 256; the spread shrinks like 1 / sqrt(b).]
Figure 10.2.1: Mini-batch SGD in one picture. Left: each step samples a mini-batch of $b$ examples (orange) from the full dataset and averages their gradients. Right: the resulting estimate is centered on the true full-data gradient (green line) but spread around it; as $b$ grows the spread narrows like $1/\sqrt{b}$, so the estimate stays unbiased at every batch size while becoming progressively less noisy. Code 10.2.1 measures this exact shrinkage.

2. The Mini-Batch Gradient Is Unbiased, and Its Variance Falls Like 1/b Intermediate

The reason mini-batch SGD works is a two-part probability fact, and both parts matter for distribution. Suppose the batch $B_t$ is drawn uniformly at random from the $N$ examples. Then the expected value of the mini-batch gradient, over the random draw of the batch, equals the full-data gradient:

$$\mathbb{E}\!\left[ g_t \right] = \frac{1}{b} \sum_{i \in B_t} \mathbb{E}\!\left[ \nabla \ell(w_t; x_i, y_i) \right] = \frac{1}{N} \sum_{i=1}^{N} \nabla \ell(w_t; x_i, y_i) = \nabla L(w_t).$$

This is the unbiasedness property: on average the mini-batch points in the correct direction. It is not approximately correct; the expectation is exactly the full gradient. The second part concerns how far an individual draw strays from that average. If the per-example gradients have variance $\sigma^2$ (per coordinate) and the batch is sampled with replacement, the variance of the average over $b$ examples is

$$\operatorname{Var}\!\left[ g_t \right] = \frac{\sigma^2}{b}.$$

The variance falls inversely with the batch size, so the standard deviation, the typical size of the noise, falls like $1/\sqrt{b}$. Quadruple the batch and you halve the noise. This is the precise sense in which a larger batch gives a more accurate gradient. It is also the law that Figure 10.2.1 draws and that the demo below measures directly on a convex problem.
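The step from per-example variance to $\sigma^2 / b$ can be spelled out in one line. As a sketch under the stated assumptions (scalar or per-coordinate variance, and batch indices $i_1, \dots, i_b$ drawn independently and uniformly with replacement, a notation introduced here only for illustration), the $b$ terms are independent and identically distributed, so their variances add:

$$\operatorname{Var}\!\left[ g_t \right] = \operatorname{Var}\!\left[ \frac{1}{b} \sum_{j=1}^{b} \nabla \ell(w_t; x_{i_j}, y_{i_j}) \right] = \frac{1}{b^2} \sum_{j=1}^{b} \operatorname{Var}\!\left[ \nabla \ell(w_t; x_{i_j}, y_{i_j}) \right] = \frac{1}{b^2} \cdot b\, \sigma^2 = \frac{\sigma^2}{b}.$$

Sampling without replacement makes the draws negatively correlated, which only lowers the variance further, by the finite-population factor $(N - b)/(N - 1)$; for $b \ll N$ the difference is negligible.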

Key Insight: Unbiased Direction, Tunable Noise

The mini-batch gradient is an unbiased estimator of the full gradient (its expectation is exactly $\nabla L$) with a noise level you control through the batch size $b$ (its variance is $\sigma^2 / b$). These two properties are what let you trade computation for accuracy continuously: a tiny batch is cheap and noisy, a large batch is costly and precise, and neither one is biased. Every distributed-training method in this chapter exploits the first property to stay correct and tunes the second to balance compute against communication.
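Both properties can also be checked exactly, without any sampling, on a problem small enough to enumerate. The sketch below takes six made-up per-example gradient values (hypothetical numbers, chosen only so the arithmetic is easy to follow), lists every one of the $N^b$ equally likely with-replacement batches, and computes the mean and variance of the batch average in exact rational arithmetic.

```python
from fractions import Fraction
from itertools import product

# Hypothetical per-example gradients at some fixed w (any numbers would do).
grads = [Fraction(v) for v in (-7, -2, 0, 1, 3, 11)]
N = len(grads)

full_grad = sum(grads) / N                                # the full-data gradient
sigma2 = sum((g - full_grad) ** 2 for g in grads) / N     # per-example variance

def exact_moments(b):
    """Mean and variance of the batch average over ALL N**b equally likely batches."""
    ests = [sum(batch) / b for batch in product(grads, repeat=b)]
    mean = sum(ests) / len(ests)
    var = sum((e - mean) ** 2 for e in ests) / len(ests)
    return mean, var

for b in (1, 2, 3, 4):
    mean, var = exact_moments(b)
    print("b=%d  mean=%s  var=%s  sigma^2/b=%s" % (b, mean, var, sigma2 / b))
```

For every $b$ the enumerated mean equals `full_grad` and the enumerated variance equals `sigma2 / b` as exact fractions, which is the with-replacement case of the two facts above; the enumeration grows like $N^b$, so this is a check for toy sizes only.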

3. Why Mini-Batches, and Not the Extremes Beginner

Two forces pull the batch size up from one, and one force pulls it back down from $N$. The first upward force is the variance reduction we just derived: a bigger batch is a less noisy gradient, and less noise lets you use a larger learning rate and take steadier steps toward the minimum. The second upward force is hardware. A modern accelerator does not compute one gradient at a time; it computes a batch of them as a single matrix operation, and that vectorized form is dramatically more efficient per example than a loop over single examples. A batch of 256 often costs barely more wall-clock time than a batch of 1, because at batch 1 most of the accelerator sits idle. Mini-batches keep the hardware busy, which is a large part of why GPUs train neural networks quickly.

The downward force is diminishing returns. Because the noise falls like $1/\sqrt{b}$, each doubling of the batch doubles the compute but removes a smaller and smaller absolute amount of noise, so each example contributes less to progress. There is a regime where a larger batch is pure waste, and finding the edge of that regime is the subject of Chapter 4's cost models and of Section 10.8 on large-batch training. The practical sweet spot is the largest batch that fills the hardware while still leaving the gradient noisy enough to explore and each example doing real work. That sweet spot is almost never $1$ and almost never $N$.
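The diminishing-returns argument is easy to tabulate. The sketch below assumes a per-example gradient standard deviation of $\sigma = 1$ in arbitrary units (an assumption; only the shape matters) and prints, for several batch sizes, the current noise level and how much noise one further doubling would remove.

```python
import math

sigma = 1.0          # per-example gradient std, arbitrary units (an assumption)
rows = []
b = 1
while b <= 4096:
    noise_now = sigma / math.sqrt(b)
    noise_doubled = sigma / math.sqrt(2 * b)
    gain = noise_now - noise_doubled       # absolute noise removed by doubling b
    rows.append((b, noise_now, gain))
    b *= 4

print("%-6s %-12s %s" % ("b", "noise std", "noise removed by doubling b"))
for b, noise, gain in rows:
    print("%-6d %-12.4f %.4f" % (b, noise, gain))
```

Every doubling removes the same relative share of the noise (a factor of $1 - 1/\sqrt{2}$, about 29 percent), but the absolute amount removed shrinks like $1/\sqrt{b}$ while the number of extra examples needed for that doubling grows like $b$.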

Fun Note: The Noise Is a Feature

It is tempting to read "noise" as a defect to be eliminated, which would argue for the largest batch you can afford. But the jitter of a small batch is partly what lets SGD escape bad flat regions and shallow traps in non-convex loss landscapes; a perfectly exact gradient can get stuck where a slightly noisy one wanders free. Practitioners sometimes describe small-batch training as having a built-in regularizer that bigger batches lose. So the goal is not zero noise; it is the right amount of noise, which is one more reason the batch size is a hyperparameter you tune rather than a quantity you maximize.

4. Watching the Variance Shrink: A From-Scratch Demo Intermediate

The two claims of Section 2, unbiasedness and $1/\sqrt{b}$ noise decay, are easy to see directly. The program below sets up a convex least-squares problem in pure Python, fixes a point $w = 0$, and repeatedly samples mini-batches of growing size to measure the spread of the gradient estimate around the true full-data gradient. It then runs short mini-batch SGD at a small and a large batch size and prints the full loss along each trajectory, so the effect of batch size on the descent path is visible too.

import random, math

# A convex least-squares problem: minimize (1/N) sum_i (w*x_i - y_i)^2 over scalar w.
random.seed(0)
N, w_true = 20000, 3.0
X = [random.gauss(0.0, 1.0) for _ in range(N)]
Y = [w_true * x + random.gauss(0.0, 0.1) for x in X]   # y = 3x + small noise

def full_gradient(w):                                  # exact gradient over ALL N
    return sum(2.0 * x * (w * x - y) for x, y in zip(X, Y)) / N

def minibatch_gradient(w, idx):                        # estimate from a sampled batch
    return sum(2.0 * X[i] * (w * X[i] - Y[i]) for i in idx) / len(idx)

# Part 1: spread of the mini-batch estimate around the true gradient, vs batch size.
w_probe = 0.0
g_full, trials, prev = full_gradient(w_probe), 400, None
print("gradient-estimate quality at w = 0.0   (true full gradient g = %.4f)" % g_full)
print("%-8s %-14s %-14s" % ("batch b", "mean est.", "std of est."))
for b in [1, 4, 16, 64, 256, 1024]:
    ests = [minibatch_gradient(w_probe, [random.randrange(N) for _ in range(b)])
            for _ in range(trials)]
    mean = sum(ests) / trials
    std = math.sqrt(sum((e - mean) ** 2 for e in ests) / (trials - 1))
    ratio = "" if prev is None else "   (%.2fx vs prev)" % (std / prev)
    print("%-8d %-14.4f %-14.4f%s" % (b, mean, std, ratio)); prev = std

# Part 2: effect of batch size on the loss trajectory.
def run_sgd(b, steps, lr):
    w, out = 0.0, []
    for t in range(steps):
        w -= lr * minibatch_gradient(w, [random.randrange(N) for _ in range(b)])
        if t % max(1, steps // 5) == 0 or t == steps - 1:   # max() guards steps < 5
            out.append((t, sum((w * x - y) ** 2 for x, y in zip(X, Y)) / N, w))
    return out

print("\nloss trajectory (lr = 0.05, 60 steps), small vs large batch")
for b in [1, 256]:
    traj = run_sgd(b, 60, 0.05)
    tail = "  ".join("t=%-2d L=%.4f" % (t, L) for (t, L, _) in traj)
    print("batch b=%-4d -> %s  (final w=%.4f)" % (b, tail, traj[-1][2]))
Code 10.2.1: Pure-Python mini-batch SGD on a convex least-squares loss. Part 1 samples many mini-batches at a fixed point to measure the gradient estimate's mean (testing unbiasedness) and standard deviation (testing the $1/\sqrt{b}$ law); Part 2 runs the optimizer at two batch sizes and records the true loss along the way.
gradient-estimate quality at w = 0.0   (true full gradient g = -5.9093)
batch b  mean est.      std of est.
1        -5.4570        8.3614
4        -5.8701        4.2280           (0.51x vs prev)
16       -6.0549        2.0797           (0.49x vs prev)
64       -5.7980        1.0175           (0.49x vs prev)
256      -5.8951        0.5176           (0.51x vs prev)
1024     -5.9086        0.2555           (0.49x vs prev)

loss trajectory (lr = 0.05, 60 steps), small vs large batch
batch b=1    -> t=0  L=1.1680  t=12 L=0.0102  t=24 L=0.0100  t=36 L=0.0109  t=48 L=0.0100  t=59 L=0.0110  (final w=3.0321)
batch b=256  -> t=0  L=7.1306  t=12 L=0.6862  t=24 L=0.0583  t=36 L=0.0142  t=48 L=0.0103  t=59 L=0.0100  (final w=2.9958)
Output 10.2.1: The mean estimate hugs the true gradient $-5.9093$ at every batch size (unbiasedness), while the standard deviation halves each time the batch quadruples, $8.36 \to 4.23 \to 2.08 \to 1.02 \to 0.52 \to 0.26$, the $0.5\times$-per-$4\times$ signature of $1/\sqrt{b}$. The small batch reaches low loss fast but keeps jittering around it ($L$ wobbles near $0.010$); the large batch descends more smoothly and settles to a tighter final loss and a $w$ nearer the true $3.0$.

Two things in Output 10.2.1 are worth dwelling on. First, the mean column stays close to the true gradient: every batch size, from $1$ to $1024$, produces an estimate centered on $-5.9093$, and the visible deviations (for example $-5.4570$ at $b = 1$) are within roughly two standard errors of a 400-trial average, that is, on the order of the std column divided by $\sqrt{400} = 20$. That is unbiasedness, exactly as the expectation calculation promised. Second, the standard-deviation column is very nearly a geometric sequence with ratio one half per fourfold increase in $b$ (the measured ratios run from $0.49$ to $0.51$), which is the $1/\sqrt{b}$ law made visible. The trajectory output then shows the practical consequence: the single-example run is fast but never stops jittering, whereas the batch-of-256 run takes smoother steps and lands closer to the true parameter. This is the entire argument for mini-batches on one screen.
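The numbers in Output 10.2.1 can also be predicted in advance from the problem's construction. At $w = 0$ the per-example gradient is $-2 x y = -6 x^2 - 2 x \varepsilon$, with $x$ standard normal and $\varepsilon$ the label noise of standard deviation $0.1$. The sketch below works with population moments ($\mathbb{E}[x^2] = 1$, $\operatorname{Var}[x^2] = 2$), which is an approximation: it treats the 20000 stored examples as if they were the underlying distribution.

```python
import math

w_true, noise_std = 3.0, 0.1          # the constants used in Code 10.2.1

# Per-example gradient at w = 0:  g = -2*x*y = -2*w_true*x**2 - 2*x*eps.
# For standard normal x: E[x^2] = 1 and Var[x^2] = 2; x and eps are independent,
# so the covariance between the two terms is zero and their variances add.
mean_g = -2.0 * w_true
var_g = (2.0 * w_true) ** 2 * 2.0 + 4.0 * noise_std ** 2
sigma = math.sqrt(var_g)

predicted = {b: sigma / math.sqrt(b) for b in [1, 4, 16, 64, 256, 1024]}
print("population mean gradient: %.4f" % mean_g)
for b, s in predicted.items():
    print("b=%-5d predicted std = %.4f" % (b, s))
```

The predicted mean is $-6$ and the predicted per-example standard deviation is about $8.49$, against the measured $-5.9093$ and $8.36$; the predicted column $\sigma / \sqrt{b}$ tracks the measured one to within about five percent at every batch size, which is the agreement one would hope for from a finite dataset and 400 trials per row.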

5. The Bridge to Distribution: K Workers Make One Big Batch Intermediate

Here is the fact this section exists to deliver. Suppose you have $K$ workers, and each worker $k$ independently samples a local mini-batch $B^{(k)}$ of $b$ examples and computes its local average gradient $g^{(k)} = \frac{1}{b} \sum_{i \in B^{(k)}} \nabla \ell(w; x_i, y_i)$. If the workers average their local gradients,

$$g = \frac{1}{K} \sum_{k=1}^{K} g^{(k)} = \frac{1}{K} \sum_{k=1}^{K} \frac{1}{b} \sum_{i \in B^{(k)}} \nabla \ell(w; x_i, y_i) = \frac{1}{Kb} \sum_{i \in B^{(1)} \cup \cdots \cup B^{(K)}} \nabla \ell(w; x_i, y_i),$$

the result is exactly the mini-batch gradient over the combined batch of $K b$ examples, where the union is read as a concatenation, so an example drawn by two workers counts twice. A synchronous distributed step over $K$ workers with local batch $b$ is therefore identical to a single-machine mini-batch step with global batch $K b$; nothing about the optimizer changes, only where the examples live. By the variance law, this combined estimate has variance $\sigma^2 / (K b)$, so adding workers reduces gradient noise in precisely the way a bigger batch does. The combining step, summing one gradient vector per worker and sharing the result, is the all-reduce collective that Chapter 4 is built around, and the same identity drove the exact-gradient argument of Section 1.1.
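The identity is purely algebraic, so it can be checked on a single machine by simulating the workers as lists of indices. The sketch below uses a toy least-squares problem in the style of Code 10.2.1; the data, the probe point `w = 0.5`, and the choice $K = 8$, $b = 32$ are assumptions for illustration, and the Python `sum` over workers stands in for the all-reduce.

```python
import random

random.seed(1)
N, K, b, w = 10_000, 8, 32, 0.5
# Toy least-squares data (an assumption): y = 3x plus small noise.
X = [random.gauss(0.0, 1.0) for _ in range(N)]
Y = [3.0 * x + random.gauss(0.0, 0.1) for x in X]

def avg_grad(idx):
    """Average per-example gradient of (w*x - y)^2 over the index list idx."""
    return sum(2.0 * X[i] * (w * X[i] - Y[i]) for i in idx) / len(idx)

# Each of the K simulated workers samples its own local batch of b indices.
local_batches = [[random.randrange(N) for _ in range(b)] for _ in range(K)]

# Route 1: every worker averages locally, then the K results are averaged.
g_workers = sum(avg_grad(batch) for batch in local_batches) / K

# Route 2: one machine averages over the concatenation of all K*b indices.
combined = [i for batch in local_batches for i in batch]
g_global = avg_grad(combined)

print("average of worker gradients: %.12f" % g_workers)
print("global batch gradient      : %.12f" % g_global)
print("difference                 : %.2e" % abs(g_workers - g_global))
```

The two routes agree up to floating-point rounding. Note that the check relies on every worker using the same local batch size $b$, as the derivation above assumes.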

Thesis Thread: One Step, Many Machines, Same Math

The whole edifice of synchronous data-parallel training stands on the line of algebra above: averaging $K$ local mini-batch gradients yields the global mini-batch gradient over $K b$ examples, with no approximation. This is why scale-out training can be exact rather than lossy, and why the next section can promote mini-batch SGD to a distributed algorithm by adding only an all-reduce, with no new convergence theory. The same primitive returns as gradient synchronization in data-parallel deep learning (Chapter 15) and as the push-pull average in parameter servers (Chapter 11). When you meet a distributed optimizer later, ask what its global batch is; the answer is almost always workers times local batch.

6. Learning Rate, Momentum, and the Noise Scale Advanced

The batch size does not act alone; it interacts with the learning rate. Because a $K$-worker step is a $Kb$-batch step with $K$ times less gradient variance, the step is more trustworthy, and a more trustworthy gradient can support a larger learning rate. This is the origin of the learning-rate scaling rules that make large-batch and many-worker training converge, examined in detail in Section 10.8: when you grow the global batch, you generally grow the learning rate to match, often linearly, until the step becomes too large for the loss landscape to tolerate. Get this coupling wrong and a perfectly correct distributed gradient will still diverge, not because the math broke but because the step size no longer suits the now-quieter gradient.
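One way to write down the coupling just described is a linear scaling rule with a warmup ramp. The sketch below is illustrative only: the function name `scaled_lr`, its parameters, the base values, and the 1000-step warmup are all assumptions, and whether linear scaling is appropriate for a given model and batch size is an empirical question.

```python
def scaled_lr(step, base_lr, base_batch, global_batch, warmup_steps):
    """Linear-scaling rule with linear warmup (a sketch; all names are illustrative)."""
    target = base_lr * global_batch / base_batch        # grow lr in proportion to batch
    if step < warmup_steps:
        # Ramp linearly from base_lr up to the scaled target.
        return base_lr + (target - base_lr) * (step + 1) / warmup_steps
    return target

# Example: a recipe tuned at batch 256, moved to K = 32 workers with b = 256 each.
K, b = 32, 256
for step in (0, 99, 499, 999, 5000):
    lr = scaled_lr(step, base_lr=0.05, base_batch=256,
                   global_batch=K * b, warmup_steps=1000)
    print("step %-5d lr = %.4f" % (step, lr))
```

With these example numbers the learning rate climbs from just above $0.05$ to $32 \times 0.05 = 1.6$ over the warmup and then holds; the warmup exists because the scaled rate is often too aggressive for the first steps, when the parameters are far from any well-behaved region.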

Momentum complicates and helps. Plain SGD steps along the current noisy gradient; momentum steps along a running average of past gradients, which damps the per-step noise further and accelerates progress along consistent directions. In a distributed setting momentum is computed on the already-averaged global gradient, so it composes cleanly with the worker-averaging above. The useful mental model is a single noise scale that captures how jittery the effective gradient is, set jointly by the batch size, the number of workers, and the learning rate; the art of tuning a distributed optimizer is keeping that noise scale in the band where training is both stable and fast. Section 10.8 and Chapter 4's cost models give the quantitative form of this band.
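The ordering described above (average across workers first, then apply momentum once) can be sketched in a few lines. The code below uses one common heavy-ball form of the update; the toy objective $L(w) = (w - 3)^2$, the Gaussian perturbation standing in for mini-batch noise, and the function name `momentum_step` are assumptions for illustration.

```python
import random

def momentum_step(w, v, worker_grads, lr, mu):
    """Heavy-ball step driven by the average of the workers' gradients."""
    g = sum(worker_grads) / len(worker_grads)   # stands in for the all-reduce average
    v = mu * v + g                              # decayed running sum of past gradients
    return w - lr * v, v

# Toy objective (an assumption): L(w) = (w - 3)^2, whose exact gradient is 2*(w - 3).
# Each simulated worker sees that gradient plus independent noise.
rng = random.Random(0)
w, v, K = 0.0, 0.0, 4
for t in range(200):
    grads = [2.0 * (w - 3.0) + rng.gauss(0.0, 1.0) for _ in range(K)]
    w, v = momentum_step(w, v, grads, lr=0.02, mu=0.9)

print("w after 200 steps: %.3f (minimizer is 3.0)" % w)
```

Because the velocity `v` is a single shared quantity updated from the averaged gradient, every worker that applies this step from the same starting state ends up with the same parameters, which is what keeps synchronous replicas identical.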

Research Frontier: Sizing the Batch and the Noise (2024 to 2026)

How large a global batch helps before returns collapse is an active question, because every extra worker only pays off if the batch it adds still does useful work. The gradient-noise-scale framing introduced by McCandlish et al. has been refined and applied at frontier scale: 2024 to 2026 work on critical and adaptive batch sizes (for example, the adaptive batch-size schedules studied around large-language-model pretraining and the data-parallel scaling analyses in the lineage of the Chinchilla and follow-on scaling-law papers) tries to predict the largest batch that still trains efficiently, and to grow it over the course of training as the loss landscape flattens. A parallel thread re-examines linear learning-rate scaling and its breakdown at extreme batch sizes, proposing square-root and warmup-aware rules. The shared goal is to turn "how many workers should I add?" into a quantity you can compute from the noise scale rather than discover by trial, which connects directly to the communication-cost models of Chapter 4.

Library Shortcut: One Optimizer Object Does the Mini-Batch Step

Code 10.2.1 wrote the sample-average-step loop by hand to expose the variance behavior. In practice the mini-batch update, with momentum and weight decay, is a single optimizer object, and the batching is handled by the data loader; on $K$ workers you wrap the model and the per-worker batch becomes one shard of the global batch automatically:

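import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder data, model, and loss so the loop runs as written (illustrative only).
dataset = TensorDataset(torch.randn(1024, 10), torch.randn(1024, 1))
model = torch.nn.Linear(10, 1)
loss_fn = torch.nn.MSELoss()

loader = DataLoader(dataset, batch_size=256, shuffle=True)   # local mini-batch of b=256
opt = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)   # the update rule

for xb, yb in loader:                       # each xb is one mini-batch
    opt.zero_grad()                         # clear gradients left from the last step
    loss = loss_fn(model(xb), yb)
    loss.backward()                         # fills .grad with the mini-batch gradient
    opt.step()                              # w <- w - lr * (momentum-smoothed grad)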
Code 10.2.2: The mini-batch step of Code 10.2.1 as a stock PyTorch loop. The roughly fifteen hand-written lines of sampling, averaging, and stepping collapse to a DataLoader for batching plus an SGD object that owns the learning rate and momentum; Chapter 15 adds one wrapper so the same loop runs the global-batch step across $K$ workers.

Practical Example: The Batch Size That Wasted Half the Cluster

Who: An ML platform engineer scaling an image-classification training job from one node to many.

Situation: A model trained well at batch 256 on one GPU in twelve hours; the team moved to 32 GPUs hoping for a near-32x speedup.

Problem: Keeping the per-GPU batch at 256 made the global batch $32 \times 256 = 8192$, and at that batch the model's accuracy dropped and convergence stalled, so the run was both expensive and worse.

Dilemma: Shrink the per-GPU batch to hold the global batch fixed, wasting the extra memory and compute each GPU offered, or keep the large global batch and somehow recover convergence.

Decision: They kept the large global batch but treated it as a $Kb$-batch step (the identity in Section 5), so the cure was a hyperparameter change, not a code change: scale the learning rate up to match the quieter gradient and add a warmup.

How: They applied a linear learning-rate scaling rule with a short warmup, raising the rate roughly in proportion to the $32\times$ larger batch, exactly the coupling Section 6 describes, and left the optimizer otherwise untouched.

Result: Accuracy returned to the single-GPU baseline and wall-clock fell close to the hoped-for factor, because the 32 workers were now a correctly tuned global-batch step rather than an under-stepped one.

Lesson: Adding workers enlarges the global batch, which lowers the gradient noise, which demands a larger learning rate. The distributed gradient was exact all along; the step size, not the math, was the bug.

We now have the optimizer the rest of the chapter distributes, and the one identity that makes distributing it painless: $K$ workers with local batch $b$ are one mini-batch step with global batch $K b$, exact and unbiased, with noise that falls like $1/\sqrt{Kb}$. What remains is to add the machinery that lets the workers actually exchange their gradients each step and stay in lockstep. That machinery is synchronous distributed SGD, and it is the subject of Section 10.3.

Exercise 10.2.1: Read the Variance Off the Output Conceptual

Using only Output 10.2.1, answer the following without rerunning the code. (a) The standard deviation at $b = 16$ is about $2.08$; predict the standard deviation you would expect at $b = 4096$ and state the rule you used. (b) Explain why the mean estimate stays near $-5.9093$ at every batch size even though the spread changes by a factor of more than thirty across the table. (c) A colleague proposes $b = N = 20000$ to drive the noise to zero; using the $1/\sqrt{b}$ law and Section 3, give one statistical reason and one systems reason this is a poor default.

Exercise 10.2.2: Simulate the K-Worker Identity Coding

Extend Code 10.2.1 so that instead of one mini-batch of size $Kb$, you sample $K$ independent local batches of size $b$, average their per-worker gradients, and compare the result to a single mini-batch gradient of size $Kb$ drawn from the same pool. For $K = 8$ and $b = 32$, run many trials and report the mean and standard deviation of both estimators. Confirm that they match in distribution (same mean, same spread), and explain which line of algebra in Section 5 predicts this. Then break it: make the workers sample with different batch sizes $b_k$ and show that a plain unweighted average of the per-worker gradients is no longer the global mini-batch gradient, and state the size-weighted fix.

Exercise 10.2.3: When Does the Next Worker Stop Helping? Analysis

Assume the value of a gradient step scales inversely with the noise in its gradient estimate, so that it grows roughly with $1/\operatorname{std} = \sqrt{Kb}/\sigma$, while the wall-clock cost of a synchronous step is the per-worker compute (constant in $K$) plus a communication term that grows with $K$ (you may model it as linear in $K$ for this exercise, sharpened in Chapter 4). Write down value-per-unit-time as a function of $K$ and show that it has a maximum at finite $K$. Interpret that maximum in words: why does the $1/\sqrt{Kb}$ shape of the noise, rather than any failure of the math, guarantee that adding workers eventually stops paying off?