"They cut my weight matrix into eight strips and handed one to each of us. I only ever see an eighth of the answer; the all-reduce at the end is where we agree on what we computed."
A Shard That Believes It Is the Whole Model
Tensor parallelism splits the arithmetic of a single layer across devices: each holds a slice of the weight matrix, computes part of the matmul, and the partial results are summed with an all-reduce. Where data parallelism replicates the whole model and splits the batch, tensor parallelism keeps the batch whole and splits the operator itself, so a layer too large for one accelerator's memory becomes $T$ smaller pieces that no single device ever holds in full. The catch is the bill for the combining step: an all-reduce sits inside every layer of the forward pass and again in the backward pass, which is far more communication than data parallelism asks for. That cost dictates where tensor parallelism can live, on the fastest interconnect in the machine, and how far it can scale, rarely past the GPUs of one node. This section derives the Megatron scheme that makes the per-layer communication exactly two all-reduces, proves it numerically, and locates the technique on the interconnect it demands.
In Section 16.1 the binding ceiling changed. Data parallelism, the subject of Chapter 15, assumes the whole model fits on one accelerator and replicates it; once a single layer's weight matrix no longer fits, replication is impossible and the model itself must be cut apart. Tensor parallelism is the most fine-grained way to make that cut. It does not move whole layers to different devices (that is pipeline parallelism, the subject of Section 16.3); it reaches inside one layer and partitions the matrix multiply that defines it, so that the layer's parameters, its activations, and the work of computing them are divided $T$ ways at once.
The idea rests on a single algebraic fact about matrix multiplication: a large matmul can be expressed as a sum or concatenation of smaller matmuls on slices of the operands. Tensor parallelism is the disciplined exploitation of that fact across $T$ devices, and the discipline matters, because a careless partition forces a communication after every linear layer, while the right one (the Megatron-LM pattern) forces only one all-reduce per pair of layers. We build the right pattern from the algebra, then run it.
1. Splitting a Single Matmul Two Ways Intermediate
Consider one linear layer, $Y = XW$, with input activations $X$ of shape $B \times d_{\text{in}}$ and a weight matrix $W$ of shape $d_{\text{in}} \times d_{\text{out}}$. There are two natural ways to slice $W$ across $T$ devices, and they behave very differently at combining time. The first splits $W$ by columns, the second by rows, and the whole art of tensor parallelism is choosing which to use where.
Column-parallel. Partition $W$ into $T$ column blocks, $W = [\,W_1 \;\; W_2 \;\; \cdots \;\; W_T\,]$, where block $W_t$ has shape $d_{\text{in}} \times (d_{\text{out}}/T)$ and lives on device $t$. Give every device the full input $X$. Each device computes a slice of the output columns,
$$Y = X\,[\,W_1 \;\; \cdots \;\; W_T\,] = [\,XW_1 \;\; \cdots \;\; XW_T\,], \qquad Y_t = XW_t \in \mathbb{R}^{B \times (d_{\text{out}}/T)}.$$
No device holds the whole output, but no communication is needed to produce the slices: the result is the concatenation of what each device already has. The columns are simply distributed, $\frac{1}{T}$ of them per device.
Row-parallel. Now partition $W$ into $T$ row blocks, $W = [\,W_1^\top \;\; \cdots \;\; W_T^\top\,]^\top$, where block $W_t$ has shape $(d_{\text{in}}/T) \times d_{\text{out}}$. This time the input must also be split along its feature dimension, $X = [\,X_1 \;\; \cdots \;\; X_T\,]$ with $X_t$ of shape $B \times (d_{\text{in}}/T)$, and the product becomes a sum of partial products over the shared contraction dimension,
$$Y = XW = \sum_{t=1}^{T} X_t W_t, \qquad X_t W_t \in \mathbb{R}^{B \times d_{\text{out}}}.$$
Here every device produces a full-width output, but each is only a partial result; the true $Y$ is their sum. Summing one same-shape tensor held on each device and returning the total to all of them is exactly an all-reduce, the collective introduced in Section 4.3. So a column-parallel layer needs no communication to compute its output, and a row-parallel layer needs one all-reduce. The Megatron insight is to chain them so the column-parallel concatenation feeds the row-parallel sum without ever gathering anything in between.
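The two identities above can be checked in a few lines. What follows is a minimal NumPy sketch, not a distributed implementation: the "devices" are list entries, the shapes are arbitrary illustrative choices, and the variable names are introduced here for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
B, d_in, d_out, T = 4, 8, 12, 4           # illustrative sizes; T divides d_in and d_out
X = rng.standard_normal((B, d_in))
W = rng.standard_normal((d_in, d_out))
Y = X @ W                                  # single-device reference

# Column-parallel: every device sees all of X and owns d_out/T columns of W.
col_slices = [X @ W_t for W_t in np.split(W, T, axis=1)]
Y_col = np.concatenate(col_slices, axis=1)     # concatenation only, nothing is summed

# Row-parallel: device t sees only X_t and owns d_in/T rows of W.
X_parts = np.split(X, T, axis=1)
W_rows = np.split(W, T, axis=0)
row_partials = [X_t @ W_t for X_t, W_t in zip(X_parts, W_rows)]
Y_row = np.sum(row_partials, axis=0)           # stands in for the all-reduce

print(col_slices[0].shape, row_partials[0].shape)    # (4, 3) (4, 12)
print(np.allclose(Y_col, Y), np.allclose(Y_row, Y))  # True True
```

The shapes tell the story: a column-parallel device holds a narrow but finished slice, while a row-parallel device holds a full-width but unfinished partial.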
If a column-parallel layer is followed directly by a row-parallel layer, the output slices of the first ($Y_t = XW_t$, the $t$-th block of hidden units) are precisely the input slices the second needs ($X_t$ for its row block). The intermediate activation never has to be gathered onto one device or re-scattered; each device just hands its own slice forward locally. The only communication in the whole two-layer block is the single all-reduce that sums the row-parallel partials at the very end. One all-reduce buys you two parallelized matmuls.
2. The Megatron Transformer Block Advanced
A transformer block is built from exactly the structure the previous insight rewards: two consecutive linear maps with a nonlinearity wedged between them. In the MLP sub-block the map is $Y = \mathrm{GeLU}(XA)\,B$, and in the attention sub-block it is the projection of the concatenated heads followed by an output projection. Megatron-LM (Shoeybi et al., 2019) makes the first matrix column-parallel and the second row-parallel, and arranges the partition so the nonlinearity in between acts independently on each device's slice.
This last point is what makes the scheme exact rather than approximate. A nonlinearity does not distribute over a sum, so $\mathrm{GeLU}(\sum_t X_t W_t) \neq \sum_t \mathrm{GeLU}(X_t W_t)$; if the activation sat after a row-parallel layer it would force a gather, apply, re-scatter, three communications instead of one. By placing the column-parallel matrix first, each device holds a complete, correct slice of hidden units ($\mathrm{GeLU}(XA_t)$ is the true GeLU of those units, because column $t$ depends only on $A_t$), applies the nonlinearity locally, and feeds the result straight into its own row block of the second matrix. The communication is deferred to a single point: the all-reduce that sums the row-parallel output. Counting both sub-blocks, the forward pass of one transformer block costs exactly two all-reduces (one for attention's output projection, one for the MLP), and by symmetry the backward pass costs two more.
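The claim that GeLU does not distribute over the row-parallel sum is easy to confirm numerically. The sketch below assumes the same tanh approximation of GeLU used elsewhere in this section and simulates devices as list entries; sizes and names are illustrative.

```python
import numpy as np

def gelu(z):
    # tanh approximation of GeLU
    return 0.5 * z * (1.0 + np.tanh(0.7978845608 * (z + 0.044715 * z**3)))

rng = np.random.default_rng(0)
B, d_in, d_h, T = 4, 16, 32, 4             # illustrative sizes
X = rng.standard_normal((B, d_in))
A = rng.standard_normal((d_in, d_h))
ref = gelu(X @ A)                          # single-device hidden activations

# Column-parallel first: each device's local GeLU is a true slice of ref.
col = np.concatenate([gelu(X @ A_t) for A_t in np.split(A, T, axis=1)], axis=1)

# Row-parallel first: each device holds only a partial pre-activation.
partials = [X_t @ A_t
            for X_t, A_t in zip(np.split(X, T, axis=1), np.split(A, T, axis=0))]
wrong = np.sum([gelu(p) for p in partials], axis=0)   # GeLU applied before the sum
right = gelu(np.sum(partials, axis=0))                # sum (a communication) must come first

print(np.allclose(col, ref), np.allclose(right, ref), np.allclose(wrong, ref))
# True True False
```

Only the ordering that sums before the nonlinearity matches the reference, and that sum is exactly the extra communication the column-then-row arrangement avoids.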
Figure 16.2.1 makes the dataflow concrete: blue column blocks feed green local activations feed orange row blocks, and only at the bottom does anything cross between devices. The memory accounting is the payoff. Each device stores $\frac{1}{T}$ of $A$, $\frac{1}{T}$ of $B$, and $\frac{1}{T}$ of the hidden activations, so a block that needed $M$ bytes of weights and activations on one accelerator now needs about $\frac{M}{T}$ per device. That is precisely the relief Section 16.1 argued for: the layer that did not fit now fits, $T$ times over.
3. Tensor Parallelism From Scratch Intermediate
The code below implements the full column-parallel then row-parallel block on $T$ simulated devices, with the combining all-reduce, and checks the result against the single-device layer. There is no framework and no network; the "devices" are loop iterations, so the only thing being tested is the algebra of Section 1 and the locality of the nonlinearity from Section 2. The GeLU is applied per device, exactly as Megatron requires.
import numpy as np

rng = np.random.default_rng(0)
B, d_in, d_h, d_out, T = 4, 16, 32, 16, 4   # batch, in, hidden, out, devices
X = rng.standard_normal((B, d_in))
W1 = rng.standard_normal((d_in, d_h))       # column-parallel weight
W2 = rng.standard_normal((d_h, d_out))      # row-parallel weight

def gelu(z):
    # tanh approximation of GeLU
    return 0.5 * z * (1.0 + np.tanh(0.7978845608 * (z + 0.044715 * z**3)))

# Single-device reference: Y = gelu(X @ W1) @ W2
ref = gelu(X @ W1) @ W2

# Tensor-parallel across T devices.
# Column-parallel: split W1 by COLUMNS -> each device owns d_h/T hidden units,
# applies its own nonlinearity locally, NO communication needed yet.
# Row-parallel: split W2 by ROWS to match -> each device produces a PARTIAL
# output over the full d_out; the combining all-reduce SUMS the partials.
W1_shards = np.split(W1, T, axis=1)         # each (d_in, d_h/T)
W2_shards = np.split(W2, T, axis=0)         # each (d_h/T, d_out)

partials = []
for t in range(T):
    h_t = gelu(X @ W1_shards[t])            # local activation slice, device t
    y_t = h_t @ W2_shards[t]                # partial full-width output, device t
    partials.append(y_t)

allreduced = np.sum(partials, axis=0)       # the ONE forward all-reduce

print("devices T :", T)
print("hidden units / device:", d_h // T)
print("max abs difference :", f"{np.max(np.abs(allreduced - ref)):.2e}")
print("relative error :", f"{np.linalg.norm(allreduced - ref) / np.linalg.norm(ref):.2e}")
Code 16.2.1: Each simulated device takes a column slice of W1 and a row slice of W2, applies GeLU to its own hidden units, and produces a partial output; np.sum(partials, axis=0) is the single combining all-reduce. The result is compared against the single-device layer ref.
Output 16.2.1:
devices T : 4
hidden units / device: 8
max abs difference : 1.07e-14
relative error : 1.90e-16
The relative error is at the level of floating-point addition, the same signature of exactness seen for data parallelism in Section 1.1. Splitting the operator changed nothing about the answer; it only changed which device holds which piece and inserted one all-reduce to put the pieces back together. That all-reduce is the whole story of where tensor parallelism can run.
The collective you first performed by hand in Section 1.1 as gradient synchronization, and met formally in Section 4.3, returns here in a new role. In data parallelism the all-reduce fires once per training step, at the boundary between backward pass and optimizer update. In tensor parallelism it fires twice per transformer block, in the middle of the forward pass and again in the backward, because the partial sums it combines are not gradients but the layer's own outputs. Same primitive, far higher frequency, and that frequency is exactly why the next section places this technique on the fastest wire in the machine. Sharded data parallelism (Section 16.4) will summon two more relatives of the same operation, reduce-scatter and all-gather.
4. Why It Must Live on the Fast Interconnect Advanced
Count the communications. A model with $L$ transformer blocks fires $2L$ all-reduces in the forward pass and $2L$ in the backward, every single training step. For a large model $L$ is in the dozens or low hundreds, so tensor parallelism issues hundreds of all-reduces per step, each one blocking: the row-parallel layer cannot finish until the sum has returned to all devices. Data parallelism, by contrast, issues a handful of all-reduces per step (one per gradient bucket) and overlaps them with the backward computation. Tensor parallelism cannot hide its all-reduces nearly as well, because each one sits on the critical path between two matmuls that depend on it.
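The count is simple enough to tabulate. The helper below is a hypothetical function written for this illustration; it assumes one forward and one backward pass per step and counts only the per-block collectives described above, ignoring any collectives in the embedding or output layers.

```python
def tp_allreduces_per_step(num_blocks: int) -> dict:
    """All-reduces issued by tensor parallelism in one forward + backward pass."""
    forward = 2 * num_blocks    # attention output projection + MLP, per block
    backward = 2 * num_blocks   # the mirrored collectives in the backward pass
    return {"forward": forward, "backward": backward, "total": forward + backward}

for L in (24, 48, 96):          # illustrative depths
    print(L, tp_allreduces_per_step(L))
```

At 96 blocks that is 384 blocking collectives per pass, each sitting between two matmuls that depend on it.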
This volume of blocking, latency-sensitive communication has a hard consequence: tensor parallelism is viable only on the fastest interconnect available, which inside a modern accelerator node is NVLink (or its switched form, NVSwitch), offering an order of magnitude more bandwidth and far lower latency than the network between nodes. The topology argument of Section 4.9 applies in full force here: a collective is only as fast as the slowest link it crosses, so an all-reduce that must traverse the inter-node network would stall every layer on the slow wire. The practical rule that falls out is sharp and widely followed: keep the tensor-parallel degree $T$ within one node, at or below the number of GPUs sharing that node's NVLink fabric (commonly 8). Beyond one node, you switch to a coarser-grained parallelism, pipeline (Section 16.3) or sharded data parallel (Section 16.4), whose communication is rarer and easier to hide.
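A rough feel for the gap comes from the standard alpha-beta estimate for a ring all-reduce, which takes $2(T-1)$ steps that each move $n/T$ bytes. The sketch below is an assumption-laden illustration: the latency and bandwidth figures are hypothetical placeholders rather than specifications of any real hardware, and the cross-node case is a pessimistic bound that charges every ring step at the slow link's rate.

```python
def ring_allreduce_seconds(n_bytes: float, T: int, alpha: float, beta: float) -> float:
    """Alpha-beta estimate for a ring all-reduce.

    alpha: per-message latency in seconds; beta: link bandwidth in bytes/second.
    """
    if T == 1:
        return 0.0                      # nothing to combine on a single device
    return 2 * (T - 1) * (alpha + n_bytes / (T * beta))

n_bytes = 64 * 2**20                    # hypothetical 64 MiB activation tensor
fast = dict(alpha=2e-6, beta=300e9)     # hypothetical intra-node link
slow = dict(alpha=10e-6, beta=30e9)     # hypothetical inter-node link, 10x less bandwidth

t_node = ring_allreduce_seconds(n_bytes, 8, **fast)
t_cross = ring_allreduce_seconds(n_bytes, 16, **slow)   # ring limited by the slow link

print(f"T=8  within a node : {t_node * 1e3:.3f} ms")
print(f"T=16 across nodes  : {t_cross * 1e3:.3f} ms")
```

Because this cost is paid at every one of the per-block all-reduces, even a modest per-collective slowdown compounds across the whole step.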
Code 16.2.1 spelled out the column and row splits and the manual all-reduce. In practice you never partition a weight matrix yourself: Megatron-LM ships the column-parallel and row-parallel linear layers as drop-in modules that register the right collective on the right side of the matmul and handle the backward all-reduce automatically. Wrapping an MLP is a two-line change of layer type:
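# pip install megatron-core ; assumes the process group spans one node's GPUs
from torch.nn.functional import gelu
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear

# d_in, d_hidden, d_out are the layer sizes, defined elsewhere.
# Constructor arguments are abbreviated here: recent megatron-core releases also
# require keyword arguments such as config= and init_method=, so check the
# signature of the installed version before running.
# fc1 splits its OUTPUT (hidden) dim across the tensor-parallel group;
# fc2 splits its INPUT dim to match and all-reduces the partial outputs.
fc1 = ColumnParallelLinear(d_in, d_hidden, gather_output=False)   # no gather: feed slices on
fc2 = RowParallelLinear(d_hidden, d_out, input_is_parallel=True)  # all-reduce happens here

def mlp(x):
    h, _ = fc1(x)   # each rank holds its hidden-unit slice (layers return an (output, bias) pair)
    h = gelu(h)     # nonlinearity applied locally, exactly as in Code 16.2.1
    y, _ = fc2(h)   # RowParallelLinear fires the combining all-reduce
    return y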
gather_output=False and input_is_parallel=True are what keep the intermediate activation distributed (no gather between the matmuls), so the block still costs exactly one forward all-reduce, fired inside RowParallelLinear. Roughly a dozen lines of manual slicing collapse to a layer-type swap, and Megatron handles the process-group setup, the backward all-reduce, and the NCCL transport that Section 4.3 describes.
Who: An ML platform engineer bringing up training for a 70-billion-parameter language model on a cluster of 8-GPU nodes.
Situation: A single MLP weight matrix plus its optimizer state and activations no longer fit in one 80 GB accelerator, so data parallelism (which needs a full model replica per device) would not even start.
Problem: The model had to be cut so that no device ever held a full layer, but the team feared the cut would slow training to a crawl from communication.
Dilemma: Cut whole layers onto different devices (pipeline parallelism, coarse communication but pipeline bubbles and load-balancing pain), or cut each layer's matrices across devices (tensor parallelism, exact and load-balanced but an all-reduce inside every layer).
Decision: They used tensor parallelism with degree $T = 8$, the number of GPUs sharing each node's NVLink fabric, and went no further, reserving cross-node scaling for pipeline and sharded data parallelism.
How: They replaced the MLP and attention projections with Megatron's ColumnParallelLinear and RowParallelLinear, kept the tensor-parallel group strictly within a node, and verified that every all-reduce stayed on NVLink rather than the inter-node network.
Result: Each GPU held one eighth of the layer weights and activations; the layer fit, the all-reduces stayed on the fast wire, and the per-step overhead from tensor-parallel communication was a tolerable fraction of compute. Pushing $T$ to 16 across two nodes, which they tested, doubled the communication time because half the all-reduce now crossed the slow network, confirming the one-node cap.
Lesson: Tensor parallelism buys an exact memory split at the price of an in-layer all-reduce; that price is only affordable on NVLink, so bound $T$ to a single node and scale further with a coarser axis.
There is something pleasingly conspiratorial about a row-parallel layer. Every device finishes its matmul holding a full-width output tensor of exactly the right shape, looking for all the world like the finished result, and every one of them is wrong. None is the answer; each is a fragment that only becomes correct after the all-reduce adds the other fragments in. A bug that skips the all-reduce produces output that is the right shape, runs without error, and is quietly nonsense, which is why the combining collective is the first thing to check when a tensor-parallel model trains to garbage.
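The failure mode is worth seeing once. The sketch below simulates a row-parallel layer in NumPy and compares what a single rank would return if the all-reduce were skipped against the true output; sizes and names are illustrative, and the relative-error threshold in any real check would be a matter of judgment.

```python
import numpy as np

rng = np.random.default_rng(0)
B, d_h, d_out, T = 4, 32, 16, 4              # illustrative sizes
H = rng.standard_normal((B, d_h))            # hidden activations, already past the GeLU
W2 = rng.standard_normal((d_h, d_out))
ref = H @ W2                                 # true layer output

partials = [H_t @ W_t
            for H_t, W_t in zip(np.split(H, T, axis=1), np.split(W2, T, axis=0))]
no_allreduce = partials[0]                   # what rank 0 holds if the collective is skipped
rel_err = np.linalg.norm(no_allreduce - ref) / np.linalg.norm(ref)

print("same shape as reference:", no_allreduce.shape == ref.shape)
print(f"relative error without all-reduce: {rel_err:.2f}")
print("correct after summing   :", np.allclose(np.sum(partials, axis=0), ref))
```

Nothing about the skipped-collective output raises an error, so a shape assertion alone cannot catch this bug; comparing against a single-device reference on a small input can.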
5. Communication, Memory, and the Limits of the Technique Intermediate
It is worth stating plainly what tensor parallelism does and does not buy. It divides the weights, the optimizer state, and the activations of each parallelized layer by $T$, which is the relief that lets an oversized layer fit. It keeps the computation exact, with no approximation beyond floating-point rounding, as Output 16.2.1 showed. And it is perfectly load-balanced, since every device does an identical $\frac{1}{T}$ slice of the same matmul, unlike pipeline parallelism where some stages can idle. Against those gains stands one cost that dominates everything: communication volume. The all-reduce inside every layer makes tensor parallelism the most communication-intensive of the model-parallel strategies, which is why it never travels alone past one node and why it is almost always composed with the other axes rather than used by itself.
That composition is the subject of the rest of the chapter. Real large-model training stacks tensor parallelism within each node, pipeline parallelism across nodes, and sharded data parallelism across the whole cluster, an arrangement called 3D parallelism that Section 16.9 assembles. Each axis owns the scale at which its communication is affordable: tensor parallelism on NVLink inside the node, pipeline and data parallelism on the slower network between nodes. Getting that mapping right is the difference between a model that trains and one that idles its GPUs waiting on the wrong wire.
Because the per-layer all-reduce is what confines tensor parallelism to a single node, recent work attacks that collective directly. Sequence parallelism, now standard in Megatron-LM, splits the layer-norm and dropout activations along the sequence dimension so that each of the two all-reduces of the classic scheme becomes an all-gather plus a reduce-scatter of the same total volume, cutting the activation-memory footprint without adding communication; this is the bridge to the context-parallel methods of Section 16.7. A second line overlaps the tensor-parallel collective with the matmul that feeds it: fine-grained schemes that decompose the all-reduce and interleave its chunks with computation (in the lineage of work on overlapping communication and GEMMs, and of Google's collective-aware compilation) report hiding much of the in-layer communication behind the very matmul whose output it combines. A third pushes tensor parallelism cautiously across the fast scale-up domains of newer hardware (NVLink-connected multi-node "superpods"), testing whether the one-node cap can be relaxed when the inter-node fabric is itself NVLink-class. The common thread is that the all-reduce of Code 16.2.1 is treated as a quantity to be hidden or reshaped, never simply accepted.
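The decomposition that sequence parallelism relies on, an all-reduce expressed as a reduce-scatter followed by an all-gather, can be simulated directly. This is a NumPy sketch with devices as list entries and illustrative sizes; it shows only the arithmetic equivalence, not Megatron's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
T, S, d = 4, 8, 6                            # devices, sequence length, hidden width
partials = [rng.standard_normal((S, d)) for _ in range(T)]   # one partial per device

# Classic scheme: the all-reduce leaves the full (S, d) sum on every device.
all_reduced = np.sum(partials, axis=0)

# Reduce-scatter: device dst ends up with only its S/T rows of the sum.
chunks = [np.split(p, T, axis=0) for p in partials]          # chunks[src][dst]
scattered = [np.sum([chunks[src][dst] for src in range(T)], axis=0)
             for dst in range(T)]

# Between the two collectives each device holds an (S/T, d) slice,
# which is where the activation-memory saving comes from.
# All-gather: concatenating the slices restores the full tensor everywhere.
gathered = np.concatenate(scattered, axis=0)

print(scattered[0].shape, gathered.shape)    # (2, 6) (8, 6)
print(np.allclose(gathered, all_reduced))    # True
```

The work done between the reduce-scatter and the all-gather (layer norm, dropout) acts on each sequence position independently, which is why it can run on the scattered slices.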
We now have the finest-grained cut of a model: split the operator itself, pay one all-reduce per layer-pair, and keep it on NVLink. The next section steps up one level of granularity to pipeline parallelism, which cuts the model between layers rather than inside them, trades the in-layer all-reduce for a much rarer point-to-point handoff, and so reaches across nodes where tensor parallelism cannot. That story begins in Section 16.3.
Suppose an engineer implements the MLP block with the first matrix row-parallel and the second column-parallel, the opposite of Megatron's order. Trace the dataflow through the GeLU between them and explain exactly where a communication becomes unavoidable that the column-then-row order avoids. State how many all-reduces (or gathers) per block the reversed order costs, and why the nonlinearity is the deciding factor. Relate your answer to the Key Insight in Section 1.
Extend Code 16.2.1 to report, for each simulated device, the number of weight elements it stores (sum of its W1 and W2 shard sizes) and confirm it equals $\frac{1}{T}$ of the single-device total. Then sweep $T \in \{1, 2, 4, 8, 16\}$ (keep $d_h$ divisible by $T$), and for each $T$ print the per-device weight count and the relative error against ref. Verify the error stays at the floating-point floor for every $T$ while the per-device memory falls as $\frac{1}{T}$. Explain why exactness is independent of $T$.
A transformer block does roughly $C$ floating-point operations of compute and, per the scheme in this section, two all-reduces each moving about $A$ bytes of activations. On NVLink, bandwidth is $b_{\text{fast}}$; across the inter-node network it is $b_{\text{slow}} \approx b_{\text{fast}}/10$. Using the alpha-beta cost reasoning of Section 4.9, write an expression for the fraction of per-block wall-clock spent in communication on each fabric, and argue from it why doubling $T$ from 8 (one node) to 16 (two nodes) can increase total step time even though each device now does half the compute. Identify the break-even ratio of compute to communication at which crossing the node boundary stops being worthwhile.