"They asked me to recompute the same dataset for the fourth time today. I have a perfectly good copy in memory. Nobody asked."
A Cached Partition Tired of Being Ignored
Two levers move Spark performance more than any other: where the data lives, and how often it is recomputed. Partitioning decides how a dataset is split across the cluster, which fixes both the parallelism you get and the network traffic you pay when two datasets must be combined. Caching decides whether a dataset that several computations reuse is held in memory or rebuilt from scratch each time it is touched. Get partitioning right and a join costs nothing extra; get it wrong and the cluster spends its time moving rows instead of processing them. Cache a reused dataset and repeated work becomes nearly free; forget to, and Spark faithfully replays the entire lineage on every action. This section makes both levers concrete, measures them in pure Python, then shows the one-line PySpark calls that pull each one in production.
In the previous section we saw that Spark splits work into transformations, which are lazy and build up a plan, and actions, which force that plan to run. That distinction is exactly what makes partitioning and caching matter. Because transformations are lazy, the layout of the data when an action finally fires determines how much shuffling the action triggers, and because Spark rebuilds a dataset from its lineage whenever an action needs it, a dataset reused by many actions is recomputed many times unless you intervene. Both levers are direct descendants of the partitioning idea introduced in Section 2.3: a distributed dataset is a collection of partitions, and almost every performance question reduces to where those partitions sit and whether they get rebuilt.
1. Partitioning Sets Parallelism and Shuffle Beginner
A Spark dataset is physically a set of partitions, and a partition is the unit of parallel work: one task processes one partition on one core. The number of partitions therefore caps how much of the cluster a stage can use at once. With $P$ partitions and a cluster of $C$ cores, the stage runs at parallelism $\min(P, C)$. Too few partitions and most cores sit idle while a handful of tasks grind through oversized chunks; too many and the per-task overhead (scheduling, serialization, and bookkeeping) starts to dominate the useful work. The practical sweet spot is a partition count that is a small multiple of the cluster's core count, with each partition large enough to amortize task overhead but small enough to fit comfortably in a task's memory.
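The $\min(P, C)$ bound can be sketched in a few lines of plain Python. The helper names (`stage_parallelism`, `task_waves`) and the 64-core cluster are illustrative assumptions, not Spark APIs, and the wave count assumes equal-sized tasks that each occupy exactly one core.

```python
import math

CORES = 64  # assumed cluster size, for illustration only


def stage_parallelism(partitions, cores):
    """Tasks that can run at once: one task per partition, one core per task."""
    return min(partitions, cores)


def task_waves(partitions, cores):
    """Rounds of tasks needed to finish the stage, assuming equal-sized tasks."""
    return math.ceil(partitions / cores)


for partitions in (4, 64, 192, 20_000):
    busy = stage_parallelism(partitions, CORES)
    waves = task_waves(partitions, CORES)
    print(f"P={partitions:>6}  busy cores={busy:>2}  "
          f"idle cores={CORES - busy:>2}  waves={waves}")
```

With 4 partitions, 60 of the 64 cores sit idle. With 20,000 partitions every core is busy, but the stage needs 313 waves of tasks, which is where per-task overhead starts to matter.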
Partitioning matters even more when two datasets must be combined. A join on a key needs every row with a given key from one side to meet every row with that key from the other side. If the two datasets are partitioned by unrelated rules, matching rows live on different workers and Spark must repartition one or both sides by the join key, moving rows across the network. That movement is a shuffle, the single most expensive operation in distributed data processing and the subject of Section 7.7. The decisive observation is that a shuffle is avoidable: if both datasets are already partitioned by the join key, using the same partitioning function and the same number of partitions, matching keys are co-located and the join runs locally on each worker with no network traffic. This co-partitioning trick is the same idea that lets the MapReduce shuffle of Chapter 6 group keys once and reuse the grouping.
A join's cost is dominated by whether its two inputs are co-partitioned by the join key. If they are, the join is a local, per-partition operation and moves zero rows across the network. If they are not, Spark inserts a shuffle that repartitions data by the key before joining. Partitioning is therefore not a tuning afterthought; it is a decision about which future operations get to be free. Partition once by the key you will join, group, or aggregate on most often, and every later operation on that key inherits a shuffle that has already been paid for.
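Why the partition counts on the two sides must match can be checked directly. The sketch below is a simplified model rather than Spark's partitioner: it assumes integer keys and uses `key % num_partitions` as a stand-in for hash partitioning, and `partition_of` and `misplaced_keys` are hypothetical helper names.

```python
def partition_of(key, num_partitions):
    """Stand-in for hash partitioning of an integer key (an assumption)."""
    return key % num_partitions


def misplaced_keys(keys, left_partitions, right_partitions):
    """Count keys whose rows land at different partition indices on the two sides."""
    return sum(
        1
        for k in keys
        if partition_of(k, left_partitions) != partition_of(k, right_partitions)
    )


keys = range(5000)
same_count = misplaced_keys(keys, 8, 8)  # both sides split into 8 partitions
diff_count = misplaced_keys(keys, 8, 6)  # same key, different partition counts

print(f"8 vs 8 partitions: {same_count} keys misplaced")
print(f"8 vs 6 partitions: {diff_count} keys misplaced")
```

Partitioning both sides by the join key is not sufficient on its own: with 8 partitions on one side and 6 on the other, most keys sit at different partition indices, so one side would still have to be repartitioned before the join could run locally.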
2. Caching Avoids Recompute From Lineage Beginner
Spark does not store the intermediate results of a transformation chain. It stores the lineage, the recipe of transformations that produces a dataset from its source, and replays that recipe whenever an action needs the dataset. This is what makes Spark fault tolerant: a lost partition is simply recomputed from its lineage. The same mechanism, though, means that a dataset touched by three actions is built three times, scanning the source and rerunning every transformation on each pass. For a one-shot pipeline that is fine. For any workload that reuses a dataset, such as an iterative machine learning loop that sweeps the same features each epoch or an interactive session that runs query after query on one prepared table, the repeated recomputation is pure waste.
Caching breaks the cycle. When you mark a dataset as cached and then run the first action, Spark keeps the materialized partitions in memory across the cluster; every subsequent action reads from memory instead of replaying the lineage. The dataset is built exactly once. The win scales with reuse: $a$ actions on an uncached dataset that costs $T$ to materialize spend $a \cdot T$ on materialization, while the cached version spends $T$ once plus a negligible read per action, so the saving grows linearly in the number of reuses. The flip side is memory: cached partitions occupy cluster memory that the rest of the job could otherwise use, which is why Spark offers storage levels that trade memory for disk.
Storage levels name that trade-off. MEMORY_ONLY keeps partitions in memory and silently recomputes any that do not fit, which is fast but can quietly undo the benefit under memory pressure. MEMORY_AND_DISK spills partitions that overflow memory to local disk, so a reused dataset is never recomputed, only reread from disk in the worst case, which is the safe default for datasets larger than available memory. Serialized variants pack partitions more densely at the cost of CPU to deserialize on each read. The right level depends on whether memory or recomputation is the scarcer resource for your job, the same kind of resource-matching judgment that Chapter 3 turns into explicit cost models.
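The trade-off between these levels can be put in rough numbers with a toy cost model. Everything in it is an assumption made for illustration: `T` is the cost to materialize the whole dataset from lineage, `r` the cost to read it from memory, `d` the cost to read it from local disk, and `fit` the fraction of partitions that fit in memory. The first action pays `T`; each later action pays for a mix of memory hits and misses. Real Spark costs depend on eviction order and partition sizes, which this sketch ignores.

```python
def uncached_cost(actions, T):
    """Every action replays the full lineage."""
    return actions * T


def cached_cost(actions, T, r, fit=1.0, miss_cost=0.0):
    """First action materializes (cost T). Each later action reads the cached
    fraction from memory and pays miss_cost, scaled by the missing fraction."""
    per_later_action = fit * r + (1.0 - fit) * miss_cost
    return T + (actions - 1) * per_later_action


# Arbitrary units; all four values are assumed, not measured.
T, r, d, actions = 100.0, 1.0, 10.0, 5

no_cache = uncached_cost(actions, T)
all_fits = cached_cost(actions, T, r)
memory_only = cached_cost(actions, T, r, fit=0.5, miss_cost=T)      # misses recomputed
memory_and_disk = cached_cost(actions, T, r, fit=0.5, miss_cost=d)  # misses reread from disk

print(f"uncached               : {no_cache:.0f}")
print(f"cached, everything fits: {all_fits:.0f}")
print(f"MEMORY_ONLY, half fits : {memory_only:.0f}")
print(f"MEMORY_AND_DISK, half  : {memory_and_disk:.0f}")
```

Under these assumed numbers, MEMORY_ONLY with half the partitions missing recovers only part of the benefit, while MEMORY_AND_DISK stays close to the fully cached case. The ranking can flip if disk reads are slower than recomputation, which is why the choice depends on the job.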
A common first-day surprise: calling df.cache() appears to do nothing. No data moves, no memory fills, the call returns instantly. That is because cache() is itself lazy; it only marks the dataset as a caching candidate. The partitions are not stored until the next action actually materializes them. Run one action, and the second action is suddenly fast. The cache was waiting for someone to compute the thing worth caching.
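The lazy behavior of cache() can be imitated in plain Python. `LazyDataset` below is a hypothetical toy class, not part of PySpark; it only mimics the marking-then-materializing pattern described above, and `lineage_runs` is a counter added for illustration.

```python
class LazyDataset:
    """Toy stand-in for a Spark dataset: a source plus a lineage to replay."""

    def __init__(self, source, transform):
        self._source = source
        self._transform = transform
        self._cache_requested = False
        self._stored = None
        self.lineage_runs = 0  # how many times the lineage was replayed

    def cache(self):
        # Lazy: only records the intent. Nothing is computed or stored here.
        self._cache_requested = True
        return self

    def _materialize(self):
        if self._stored is not None:
            return self._stored  # served from the cache
        self.lineage_runs += 1
        rows = [self._transform(x) for x in self._source]
        if self._cache_requested:
            self._stored = rows  # stored by the first action that needs it
        return rows

    def count(self):  # an action
        return len(self._materialize())

    def total(self):  # another action
        return sum(self._materialize())


plain_ds = LazyDataset(range(1000), lambda x: x * 2)
cached_ds = LazyDataset(range(1000), lambda x: x * 2).cache()

runs_after_cache_call = cached_ds.lineage_runs  # cache() alone computes nothing

for ds in (plain_ds, cached_ds):
    ds.count()  # first action
    ds.total()  # second action

print(f"lineage runs right after cache(): {runs_after_cache_call}")
print(f"uncached, after two actions     : {plain_ds.lineage_runs}")
print(f"cached, after two actions       : {cached_ds.lineage_runs}")
```

The call to cache() replays nothing; the first action does the work and stores the result, and only the second action benefits.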
3. Measuring Both Levers From Scratch Intermediate
To see both levers without a cluster, we model the work Spark would do in plain Python. For caching, we run three actions over a dataset produced by a multi-stage transformation chain, first recomputing the chain on every action (the uncached path) and then materializing it once and reusing it (the cached path), and we count both the rows scanned and the wall-clock time. For partitioning, we join two datasets on a key and count how many rows must cross the network: once when the two sides use mismatched partitioning schemes, and once when both are hash-partitioned by the join key into the same number of partitions. The code below is the exact script that produced the output that follows it.
import time

N = 2_000_000
SOURCE = range(N)


def expensive_lineage(rows):
    """Stand-in for a map -> filter -> map transformation chain.
    Cost is proportional to the number of rows scanned."""
    out = []
    for x in rows:
        v = (x * 2654435761) & 0xFFFFFFFF  # a per-row map
        if v % 3:                          # a filter
            out.append(v % 1000)           # a final map to a key bucket
    return out


# --- Uncached: 3 actions each replay the lineage from the source ---
ACTIONS = 3
t0 = time.perf_counter()
uncached_rows = 0
for _ in range(ACTIONS):
    materialized = expensive_lineage(SOURCE)  # recompute every time
    uncached_rows += N
    _ = sum(materialized)                     # the action
uncached_time = time.perf_counter() - t0

# --- Cached: materialize ONCE, then reuse across all actions ---
t0 = time.perf_counter()
cached = expensive_lineage(SOURCE)            # one materialization
for _ in range(ACTIONS):
    _ = sum(cached)                           # each action reuses memory
cached_time = time.perf_counter() - t0

print("== Caching: cost of repeated actions ==")
print(f"actions : {ACTIONS}")
print(f"uncached rows recomputed : {uncached_rows:,}")
print(f"cached rows recomputed : {N:,}")
print(f"uncached wall-clock : {uncached_time:.3f} s")
print(f"cached wall-clock : {cached_time:.3f} s")
print(f"speedup from caching : {uncached_time / cached_time:.2f}x")

# --- Join shuffle: rows that must move when partitioning is / isn't matched ---
P, M = 8, 400_000
right = [(i % 5000, i * 3) for i in range(M)]  # (key, value)


def hash_part(key):
    return hash(key) % P            # the join-key scheme


def range_part(key):
    return (key * P) // 5000        # a mismatched (range-based) scheme


# A row must move when its current partition differs from the one the join needs.
unmatched = sum(1 for k, _ in right if range_part(k) != hash_part(k))
# Same scheme on both sides: the comparison can never differ, so this is zero.
matched = sum(1 for k, _ in right if hash_part(k) != hash_part(k))

print("\n== Partitioning: rows that must cross the network on a join ==")
print(f"partitions / workers : {P}")
print(f"right-side rows : {M:,}")
print(f"unmatched layout shuffled: {unmatched:,} rows")
print(f"matched layout shuffled : {matched:,} rows")
print(f"shuffle eliminated : {100 * (1 - matched / max(unmatched, 1)):.1f}%")
matched is always zero because both sides use the identical key-based scheme, so every key is co-located.
== Caching: cost of repeated actions ==
actions : 3
uncached rows recomputed : 6,000,000
cached rows recomputed : 2,000,000
uncached wall-clock : 1.164 s
cached wall-clock : 0.425 s
speedup from caching : 2.74x
== Partitioning: rows that must cross the network on a join ==
partitions / workers : 8
right-side rows : 400,000
unmatched layout shuffled: 349,440 rows
matched layout shuffled : 0 rows
shuffle eliminated : 100.0%
The two results are the section in miniature. Caching turned three full recomputations into one, and the speedup of 2.74 for three actions becomes roughly $a$-fold for $a$ actions, which is why iterative training loops cache their feature data once and reuse it every epoch. Matched partitioning turned a join that would shuffle nearly nine in ten rows into one that shuffles none. Neither lever changed a single output value; both changed only how much redundant work and network traffic the cluster performed to reach the same answer.
The pure-Python model above counted the work by hand. In PySpark each lever is a single method call, and the engine handles the cluster-wide bookkeeping, memory management, and shuffle scheduling that the model only simulated. Caching is persist (or cache, which is persist at the default level); choosing where partitions live is a StorageLevel; co-partitioning two datasets by a key is repartition:
from pyspark import StorageLevel
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# events, profiles, and train_one_epoch are assumed to be defined elsewhere.

# Caching: build the feature table once, reuse it across many actions.
features = spark.read.parquet("s3://bucket/features")
features = features.persist(StorageLevel.MEMORY_AND_DISK)  # spill to disk rather than recompute
features.count()  # first action materializes and stores partitions
for epoch in range(10):
    train_one_epoch(features)  # every epoch reads from cache, not from S3

# Partitioning: co-partition both sides by the join key so the join is local.
events = events.repartition(200, "user_id")      # hash-partition by user_id
profiles = profiles.repartition(200, "user_id")  # same key, same partition count
joined = events.join(profiles, "user_id")        # no extra shuffle at join time: keys are co-located

features.unpersist()  # free the cache when the loop is done
persist with MEMORY_AND_DISK replaces the manual materialize-once loop, so that, barring lost executors, the source is read a single time; matching repartition calls on both inputs replace the manual co-location check and let Spark skip the join shuffle. The dozen or so lines of hand-counting collapse to three calls.
Who: A data engineer supporting a recommendation team that retrains a ranking model nightly.
Situation: The training job read a prepared feature DataFrame, derived from a heavy chain of joins and aggregations over the day's logs, once per epoch for forty epochs.
Problem: Each epoch replayed the entire feature-building lineage from the raw logs in object storage, so the job spent most of its hours rebuilding inputs it had already computed.
Dilemma: Add a bigger cluster to make the redundant recomputation faster, or change the job so the recomputation never happens, at the risk of running out of cluster memory for the cached table.
Decision: They cached the feature table with persist(StorageLevel.MEMORY_AND_DISK), choosing the spill-to-disk level because the table was larger than memory and recomputing it forty times was the real cost, not the occasional disk read.
How: One persist call plus a count() to force materialization before the epoch loop, and a repartition(_, "user_id") on both the events and profile tables so the per-epoch join stopped shuffling.
Result: The feature lineage ran once instead of forty times and the join shuffle disappeared; nightly wall-clock fell from over five hours to under one, on the same cluster, with identical model accuracy.
Lesson: When a dataset is reused, the question is never how fast can we recompute it but how do we stop recomputing it. Caching plus matched partitioning attack the two redundant costs that dominate reuse-heavy Spark jobs.
4. Where These Levers Meet the AI Data Path Intermediate
Partitioning and caching are not only Spark concerns; they sit directly on the path that feeds a training job. The output of a Spark feature pipeline becomes the input dataset that an accelerator reads during training, and how that output is partitioned on disk decides how evenly the training workers can each pull a shard without contending for the same files. A feature table written as a few enormous partitions forces some data-loading workers to idle while others stream a giant file; the same table written as many balanced partitions lets every worker pull its share in parallel. The partitioning decisions of this section therefore set up the storage-layer and data-loading concerns that the next chapter takes on directly, where the loader, not the model, is often the bottleneck (Chapter 8).
Caching plays the mirror role on the read side. Within a single Spark session that prepares data for AI, caching the cleaned and joined table once means every downstream query (exploratory statistics, train and validation splits, repeated feature derivations) reads from memory rather than re-running the join. The boundary to watch is the handoff: caching helps while the data lives inside one Spark application, but once the prepared dataset is written out for a separate training process, the in-memory cache is gone and the on-disk partitioning is what carries forward. Knowing which lever applies on which side of that boundary (in-memory caching inside the pipeline, partitioning for the data written out) is the difference between a data path that keeps the accelerators fed and one that starves them.
The partition, introduced as a distributed-systems concept in Section 2.3 and made physical here as the unit of Spark parallelism, is the same primitive that returns as sharded training data in Chapter 8, as sharded parameters in the parameter servers of Part III, and as model shards in the sharded-parallelism methods of Part IV. The co-partitioning trick that makes a Spark join local, lining two datasets up by the same key so no data moves, is the data-side rehearsal of the alignment that lets a sharded all-reduce avoid moving more than it must. Whenever a later chapter asks how to split work so the combine step stays cheap, it is asking this section's question one level up.
5. Choosing Partition Counts and What to Cache Advanced
Two judgments recur. The first is how many partitions to use. The default after a shuffle is often a fixed number (200 in classic Spark) that ignores the actual data size, so on a small dataset it creates hundreds of tiny tasks whose overhead swamps the work, and on a large one it creates too few tasks, each oversized, that spill and straggle. The fix is to size the partition count to the data and the cluster: enough partitions that each holds a manageable slice (tens to a few hundred megabytes is a common target) and that there are at least as many as cluster cores so no core sits idle, but not so many that scheduling overhead dominates. Modern Spark's adaptive query execution can coalesce or split partitions at runtime based on observed sizes, which removes much of the guesswork but does not remove the need to understand what it is correcting for.
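The sizing rule above can be sketched as a small helper. `suggest_partitions` is a hypothetical function, not a Spark API, and the 128 MiB target and 32-core cluster are assumed values chosen only for illustration.

```python
import math

MIB = 1024 ** 2
GIB = 1024 ** 3


def suggest_partitions(data_bytes, target_partition_bytes, cores):
    """Enough partitions to hit the target size, but never fewer than the core count."""
    by_size = math.ceil(data_bytes / target_partition_bytes)
    return max(by_size, cores)


small = suggest_partitions(1 * GIB, 128 * MIB, cores=32)   # size alone would give 8
large = suggest_partitions(10 * GIB, 128 * MIB, cores=32)  # size gives 80

print(f"1 GiB dataset  -> {small} partitions")
print(f"10 GiB dataset -> {large} partitions")
```

For the small dataset the core count wins, so each partition ends up smaller than the target; for the large one the size target wins. In PySpark the resulting count would typically be passed to repartition or set through the spark.sql.shuffle.partitions configuration.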
The second judgment is what to cache, and the rule is reuse. Cache a dataset only when more than one action will read it, because a dataset read once gains nothing from caching and pays the memory cost for nothing. The strongest candidates are exactly the AI-flavored workloads: an iterative training or optimization loop that sweeps the same data every step, an interactive analysis session that fires many queries at one prepared table, and any branch point where one expensive intermediate feeds several different downstream computations. The cost to weigh against the saving is memory pressure: a cache that evicts useful partitions to make room, or that triggers spilling elsewhere, can cost more than the recomputation it avoids, which is why measuring, not guessing, is the discipline. The performance models of Chapter 3 give the framework for putting numbers on that trade-off rather than tuning by feel.
The partitioning question of this section has moved to the center of large-model data pipelines. The lakehouse table formats that dominate current practice, Delta Lake, Apache Iceberg, and Apache Hudi, treat physical layout as a first-class, tunable property: liquid clustering and Z-ordering reorganize partitions by the columns queries filter and join on, so the co-partitioning benefit this section demonstrates is applied automatically and continuously as data lands. On the read side, MosaicML's StreamingDataset and the Ray Data lineage have reframed shard layout as a determinant of training throughput, showing that how feature data is partitioned and chunked on object storage governs whether accelerators stay fed or stall on the loader. Recent work on data-loading bottlenecks for foundation-model training (2024 to 2025) reports that layout and caching decisions, not raw bandwidth, frequently set the ceiling on end-to-end training speed, which is precisely the handoff the next chapter formalizes. The lever is the same one in Output 7.6.1; the stakes are a fleet of idle accelerators.
With partitioning and caching in hand, we have the two levers that decide how much redundant work and network traffic a Spark job performs. Both, in the end, point at the same operation: the shuffle that partitioning aims to avoid and that a missing cache forces Spark to redo. The next section confronts the shuffle directly, including the data skew that defeats even well-chosen partitioning, and shows how to diagnose and repair it. That discussion begins in Section 7.7.
A prepared DataFrame costs $T$ to materialize from its lineage and a negligible amount to read once cached. For each scenario, state whether caching helps and why: (a) the DataFrame is read by exactly one action that writes it to disk; (b) it is read once per epoch across 25 training epochs; (c) it is read by three queries but is larger than total cluster memory under MEMORY_ONLY; (d) it feeds two downstream branches that run concurrently. For (c), explain which storage level changes the answer and what it trades away.
Extend Code 7.6.1 with a left dataset of $M$ rows keyed the same way as right. Implement the actual join two ways: first when both sides are hash-partitioned by the key into $P$ partitions (count rows moved, which should be zero), then when left is hash-partitioned but right is range-partitioned by the key (count the rows that must be repartitioned to the correct hash partition). Print the number of rows shuffled and the resulting joined-row count in both cases, and confirm the join output is identical even though one path shuffled and the other did not.
You have a 64 gigabyte dataset and a cluster of 100 cores, and you want each partition to hold roughly 128 megabytes. Compute the partition count that target implies, and compare it to the parallelism you would get with the default of 200 partitions and with a count equal to the core count. Argue from the numbers which of the three is the best default for this dataset, what goes wrong at 10 partitions and at 50,000 partitions, and how adaptive query execution would adjust a poor initial choice at runtime. Tie your reasoning to the $\min(P, C)$ parallelism bound from Section 1.