spark-cache

Ever had a Spark job that keeps re-reading the same data over and over? You might need caching. But cache at the wrong time, and you’ll actually slow things down. Here’s when caching helps, when it doesn’t, and how to use it right. This blog will be short but sweet!

🔍 What is Caching?

Think of caching like keeping your frequently used files on your desk instead of walking to the filing cabinet every time you need them.

Here’s what happens without caching in Apache Spark:

from pyspark.sql.functions import col

# Spark reads the file from disk every single time
df = spark.read.parquet("large_file.parquet")

result1 = df.filter(col("status") == "active").count()  # Reads file
result2 = df.groupBy("category").count()                # Reads file AGAIN
result3 = df.select("id", "name").collect()             # Reads file AGAIN

If you see a FileScan without an InMemoryTableScan section, like this, each time you print the execution plan, then Spark is scanning the same file all over again.

== Physical Plan ==
*(1) Filter (isnotnull(amount#58) AND (amount#58 > 5000.0))
+- FileScan json [amount#58,id#59L,provider_id#60L,timestamp#61,user_id#62L] Batched: false, DataFilters: [isnotnull(amount#58), 
 (amount#58 > 5000.0)], Format: JSON, 
  Location: InMemoryFileIndex(1 paths)[file:/home/vesko/Documents/Personal/Projects/dummy-data-rust/output/tr..., 
  PartitionFilters: [], PushedFilters: [IsNotNull(amount), GreaterThan(amount,5000.0)], 
  ReadSchema: struct<amount:double,id:bigint,provider_id:bigint,timestamp:string,user_id:bigint>
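
These plans are what explain() prints. To inspect your own queries the same way, here is a minimal sketch, assuming a DataFrame named df and that col is imported from pyspark.sql.functions as in the snippets above:

# Print the physical plan; look for FileScan vs InMemoryTableScan
df.filter(col("amount") > 5000.0).explain()

# Pass True to also see the parsed, analyzed, and optimized logical plans
df.filter(col("amount") > 5000.0).explain(True)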

With caching:

# Read once, keep in memory
df = spark.read.parquet("large_file.parquet").cache()
result1 = df.filter(col("status") == "active").count()  # Reads file + caches
result2 = df.groupBy("category").count()                # Uses cached data
result3 = df.select("id", "name").collect()             # Uses cached data

You will see something like this:

  == Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Filter (isnotnull(amount#58) AND (amount#58 > 100000.0))
   +- InMemoryTableScan [amount#58, id#59L, provider_id#60L, timestamp#61, user_id#62L], [isnotnull(amount#58), (amount#58 > 100000.0)]
         +- InMemoryRelation [amount#58, id#59L, provider_id#60L, timestamp#61, user_id#62L], StorageLevel(disk, memory, deserialized, 1 replicas)
               +- FileScan json [amount#58,id#59L,provider_id#60L,timestamp#61,user_id#62L] Batched: false, DataFilters: [], Format: JSON, Location: InMemoryFileIndex(1 paths)[file:/home/vesko/Documents/Personal/Projects/dummy-data-rust/output/tr..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<amount:double,id:bigint,provider_id:bigint,timestamp:string,user_id:bigint>

From these execution plans we can see that:

  • Plan 1 (no cache): the filters are pushed down into the file scan (PushedFilters) and whole-stage code generation is active (the *(1) prefix), but Spark goes back to disk for every action.
  • Plan 2 (cached): the file is scanned once and materialized as an InMemoryRelation; later queries hit the InMemoryTableScan and filter the already-cached data, with AQE (AdaptiveSparkPlan) optimizing the remaining stages.

The difference?

Instead of scanning the file from disk three times, you read it once and reuse the data from memory. This can turn a 10-minute job into a 2-minute job.
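
If you want to measure the difference on your own data, a rough timing harness is enough to see it. A minimal sketch, where the file path and column names are placeholders:

import time
from pyspark.sql.functions import col

df = spark.read.parquet("large_file.parquet")  # placeholder path

start = time.perf_counter()
df.filter(col("status") == "active").count()
df.groupBy("category").count().collect()
print(f"Without cache: {time.perf_counter() - start:.1f}s")

df = df.cache()
df.count()  # one action to materialize the cache

start = time.perf_counter()
df.filter(col("status") == "active").count()
df.groupBy("category").count().collect()
print(f"With cache: {time.perf_counter() - start:.1f}s")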

✅ When to Cache (The Good Stuff)

You’re Using the Same Data Multiple Times

This is the most common scenario. If you’re running several operations on the same DataFrame, cache it.

from pyspark.sql.functions import avg

# Perfect use case for caching
sales_data = spark.read.csv("sales.csv").cache()

# All of these reuse the cached data
daily_sales = sales_data.groupBy("date").sum("amount")
top_products = sales_data.groupBy("product").count()
regional_summary = sales_data.groupBy("region").agg(avg("amount"))
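
One detail that trips people up: cache() is lazy. Nothing is stored until the first action runs, so the first query still pays the full read cost. If you want the cache warmed up front, trigger a cheap action yourself; a minimal sketch:

sales_data = spark.read.csv("sales.csv").cache()
sales_data.count()  # first action scans the file and fills the cache

# Subsequent queries now read from memory
daily_sales = sales_data.groupBy("date").sum("amount")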

Expensive Operations You’ll Need Again

Got a complex join or heavy transformation? Cache the result if you’ll use it multiple times.

# Expensive join operation
enriched_data = transactions \
    .join(customers, "customer_id") \
    .join(products, "product_id") \
    .withColumn("profit", col("revenue") - col("cost")) \
    .cache()  # Cache this expensive result

# Now use it multiple times without recomputing
high_value = enriched_data.filter(col("profit") > 1000)
monthly_trends = enriched_data.groupBy("month").sum("profit")

Interactive Data Exploration

Working in a Jupyter notebook? Cache your base dataset so you can explore it quickly.

1
2
3
4
5
6
7
# Cache for interactive analysis
user_data = spark.read.json("users.json").cache()

# Now you can run multiple exploratory queries fast
user_data.describe().show()
user_data.filter(col("age") > 25).count()
user_data.groupBy("country").count().show()

🚫 When NOT to Cache (Avoid These)

One-Time Use Data

Don’t cache data you’ll only use once. It’s wasteful.

# BAD: Only using this once
df = spark.read.csv("temp_file.csv").cache()  # Pointless!
df.groupBy("category").count().write.parquet("output")

Linear ETL Pipelines

If your data flows in a straight line from source to destination, caching usually doesn’t help.

# BAD: Each step only used once
raw_data = spark.read.json("input.json")
cleaned = raw_data.dropna().cache()      # Unnecessary
transformed = cleaned.withColumn(...).cache()  # Unnecessary  
transformed.write.parquet("output")      # Just write it

When You’re Running Low on Memory

If your cluster is already struggling with memory, adding more cached data will make things worse.

# BAD: Your 4GB cluster can't handle caching 10GB of data
huge_df = spark.read.parquet("10gb_file.parquet").cache()  # Will cause problems

🛠️ How to Cache Right

Use cache() for Most Cases

# Simple and works for most situations
df = spark.read.parquet("file.parquet").cache()

Use persist() When You Need Control

from pyspark import StorageLevel

# Different storage options
df.persist(StorageLevel.MEMORY_ONLY)     # Fast, but limited by RAM
df.persist(StorageLevel.MEMORY_AND_DISK) # Spills to disk when memory full  
df.persist(StorageLevel.DISK_ONLY)       # Slower, but handles large data
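
To confirm that a DataFrame is actually marked for caching, and at which level, you can inspect it directly; a minimal sketch:

from pyspark import StorageLevel

df.persist(StorageLevel.MEMORY_AND_DISK)
print(df.is_cached)     # True once cache()/persist() has been called
print(df.storageLevel)  # shows the disk/memory/off-heap/deserialized/replication flags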

Don’t Forget to Clean Up

# Always unpersist when done
df.cache()
# ... do your work ...
df.unpersist()  # Free up memory
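
If you have lost track of what is cached, which is easy in a long notebook session, you can also clear everything at once; a minimal sketch:

# Remove one specific DataFrame and wait until the memory is actually freed
df.unpersist(blocking=True)

# Or drop every cached table and DataFrame in the current session
spark.catalog.clearCache()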

🎯 Quick Decision Guide

Cache when:

  • You’ll use the same DataFrame 2+ times
  • You have expensive joins or transformations
  • You’re doing interactive analysis
  • You have enough cluster memory

Don’t cache when:

  • You’ll only use the data once
  • You’re running a simple linear pipeline
  • Your cluster is low on memory
  • The data is too big to fit in memory

💡 Conclusion

Caching is simple: if you’re going to use the same data multiple times, and you have the memory for it, cache it. If not, don’t. The biggest mistake I see? People either cache everything or cache nothing. The sweet spot is caching strategically: only the data you’ll actually reuse. Try this in your next Spark job: identify one DataFrame you use multiple times and add .cache() to it. You might be surprised by the performance boost.