Mar 26, 2026

KV Caching - Making Transformers Actually Fast

Why naive transformer inference is O(N²), how KV caching fixes it, and what happens when the cache itself becomes the bottleneck.

The Transformer architecture revolutionized AI by looking at an entire sequence at once. Parallel processing across all tokens made training on internet-scale datasets feasible. But there's a catch.

Parallel processing is only possible during training, when you already have the entire document. During inference - when the model is deployed and talking to a user - it doesn't have the whole document. It generates text one token at a time, just like humans speak.

This process is called autoregressive generation. And if we use standard Transformer math to do it, it is painfully slow. This post explains why, and builds the engine that makes real-time generation possible: KV Caching.

The Autoregressive Loop

Let's say we prompt the model with "The cat sat" and want it to generate more text.

Step 1: Feed ["The", "cat", "sat"] into the model. It computes attention for all 3 tokens and predicts: "on".

Step 2: Append "on" to the input. Feed ["The", "cat", "sat", "on"] into the model. It computes attention for all 4 tokens and predicts: "the".

Step 3: Feed ["The", "cat", "sat", "on", "the"] - all 5 tokens - and predict: "mat".

Do you see the problem? At step 2, the model recalculates the Q, K, V projections for "The", "cat", and "sat" completely from scratch - even though nothing about those tokens has changed. At step 3, it recomputes all of them again, plus "on".

For a 1,000-token generation, the model recalculates token 1 a thousand times, token 2 nine hundred and ninety-nine times, and so on.

The total work is the sum $1 + 2 + 3 + \ldots + N = \frac{N(N+1)}{2}$, which is $O(N^2)$. As the generated text gets longer, the computation grows quadratically and generation grinds to a halt.
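The arithmetic is easy to sanity-check. A quick sketch counting how many per-token Q/K/V projections the naive loop performs:

```python
def projections_computed(n_tokens):
    """Count per-token projections in naive autoregressive decoding.

    At step t the model re-projects all t tokens from scratch, so the
    total is 1 + 2 + ... + N = N(N+1)/2, quadratic in sequence length.
    """
    return sum(t for t in range(1, n_tokens + 1))

projections_computed(1000)  # 500500: half a million for a 1,000-token reply
```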

The Mathematical Epiphany

To fix this, look closely at the attention formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

During generation, we only need the attention output for the newest token. What does that newest token need?

  • Its own Query ($Q$): "What am I looking for?" - This must be freshly computed because it's a new token.
  • All Keys ($K$): The new token compares its query against the keys of all previous tokens. But the keys for past tokens haven't changed - $K_i = W_K x_i$ depends only on the embedding $x_i$, which is fixed once token $i$ is generated.
  • All Values ($V$): Same story. $V_i = W_V x_i$ is static once token $i$ exists.

Here's the epiphany: the Keys and Values of past tokens never change. The word "cat" has the same K and V vector whether we're at step 3 or step 900. The only new computation needed is the Query for the latest token.

So instead of recalculating K and V for every past token at every step, we calculate them once, save them in GPU memory, and reuse them. This is the KV Cache.

How KV Caching Works

Prefill phase

When the user sends a prompt, the model processes all prompt tokens in parallel - just like training. This prefill step computes Q, K, V for every token simultaneously and populates the initial cache.

Decode phase

For each new token generated:

  1. Compute only $q_t = W_Q x_t$ - the query for the new token (a single vector, not a matrix)
  2. Compute $k_t = W_K x_t$ and $v_t = W_V x_t$ - the key and value for the new token
  3. Append $k_t$ and $v_t$ to the cached $K_{1:t-1}$ and $V_{1:t-1}$
  4. Compute attention: $\text{softmax}\left(\frac{q_t \cdot K_{1:t}^T}{\sqrt{d_k}}\right) V_{1:t}$

The Q matrix has shrunk from $(t \times d)$ to $(1 \times d)$. We've eliminated the re-projection of all past tokens. The per-step projection cost drops from $O(t \cdot d)$ to $O(d)$, and the total cost across all steps drops from $O(N^2 \cdot d)$ to $O(N \cdot d)$.


Implementation

The naive approach recomputes everything:

```python
import torch.nn.functional as F

def naive_attention(x, W_q, W_k, W_v):
    """Recomputes Q, K, V for ALL tokens every step."""
    Q = x @ W_q  # (seq_len, d)
    K = x @ W_k  # (seq_len, d)
    V = x @ W_v  # (seq_len, d)
    scores = Q @ K.T / (K.size(-1) ** 0.5)
    weights = F.softmax(scores, dim=-1)
    return weights @ V
```

With KV caching, we only process the new token:

```python
import torch
import torch.nn.functional as F

def cached_attention(x_new, past_kv, W_q, W_k, W_v):
    """Only project the new token. Reuse cached K, V."""
    # x_new: (1, d) - just the newest token
    q = x_new @ W_q
    k_new = x_new @ W_k
    v_new = x_new @ W_v
    if past_kv is not None:
        K = torch.cat([past_kv[0], k_new], dim=0)
        V = torch.cat([past_kv[1], v_new], dim=0)
    else:
        K, V = k_new, v_new
    scores = q @ K.T / (K.size(-1) ** 0.5)
    weights = F.softmax(scores, dim=-1)
    return weights @ V, (K, V)
```

The generation loop becomes:

```python
def generate(model, prompt_ids, max_new_tokens):
    # Prefill: process entire prompt in parallel
    logits, kv_cache = model(prompt_ids, kv_cache=None)
    next_token = logits[:, -1].argmax(dim=-1)
    generated = [next_token]
    for _ in range(max_new_tokens - 1):
        # Decode: one token at a time, reusing cache
        logits, kv_cache = model(
            next_token.unsqueeze(0), kv_cache=kv_cache
        )
        next_token = logits[:, -1].argmax(dim=-1)
        generated.append(next_token)
    return generated
```

The Memory Wall

KV caching solves the compute problem brilliantly. But in computer science, there's no free lunch. We traded a compute problem for a memory problem.

How big is the cache?

The KV cache stores two tensors (K and V) for every layer, every attention head, at every position in the sequence:

$$\text{Memory} = 2 \times L \times H \times d_h \times N \times B \times \text{sizeof(dtype)}$$

Where $L$ = layers, $H$ = heads, $d_h$ = head dimension, $N$ = sequence length, $B$ = batch size, and dtype is typically fp16 (2 bytes).

Concrete numbers:

```python
def kv_cache_memory_gb(n_layers, n_heads, d_head, seq_len, batch=1):
    """KV cache memory in GB (fp16)."""
    return 2 * n_layers * n_heads * d_head * seq_len * batch * 2 / (1024**3)

# Llama 2 7B at 4096 tokens
kv_cache_memory_gb(32, 32, 128, 4096)   # ~2.0 GB

# Llama 2 13B at 4096 tokens
kv_cache_memory_gb(40, 40, 128, 4096)   # ~3.1 GB

# Llama 2 70B at 4096 tokens (counting all 64 heads)
kv_cache_memory_gb(80, 64, 128, 4096)   # ~10.0 GB
```

2 GB for a 7B model seems manageable. But scale up the sequence length or batch size and things get dire fast:

At batch size 8 with a 32K context window, even a 7B model's KV cache exceeds the memory of a consumer GPU. A 70B model at 128K tokens? The cache alone needs hundreds of gigabytes - more than the model weights themselves.

This is the memory wall. The KV cache becomes the dominant memory consumer, and it determines how many users a single GPU can serve simultaneously.
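To make that concrete, here's a rough back-of-the-envelope sketch. The GPU capacity and weight footprint below are illustrative assumptions (an 80 GB accelerator and ~13 GB of fp16 weights for a 7B model), and it ignores activations and other overheads:

```python
def max_concurrent_seqs(gpu_gb, weights_gb, n_layers, n_heads, d_head, seq_len):
    """Crude upper bound on how many sequences fit in the KV cache (fp16).

    Ignores activation memory and framework overhead; the real number
    would be lower. per_seq = K and V, per layer/head/position, 2 bytes.
    """
    per_seq_gb = 2 * n_layers * n_heads * d_head * seq_len * 2 / (1024**3)
    return int((gpu_gb - weights_gb) // per_seq_gb)

# Hypothetical: 80 GB GPU, Llama-2-7B-like shape, 4K context
max_concurrent_seqs(80, 13, 32, 32, 128, 4096)  # 33 sequences at most
```

Even under these generous assumptions, a few dozen 4K-token conversations saturate an 80 GB card.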

Beyond Basic KV Caching

The memory wall spawned an entire field of engineering innovations:

Multi-Query and Grouped-Query Attention

Standard multi-head attention gives each head its own K and V projections. Multi-Query Attention (MQA) (Shazeer, 2019) shares a single K/V head across all query heads, slashing cache size by the number of heads. Grouped-Query Attention (GQA) (Ainslie et al., 2023) sits in between: groups of query heads share K/V heads. Llama 2 70B uses GQA with 8 KV heads shared across 64 query heads, reducing the cache by 8×.
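Under GQA, only the KV heads occupy cache, so the memory formula uses the KV head count rather than the query head count. A variant of the earlier calculator, plugging in Llama 2 70B's published head counts:

```python
def kv_cache_gb(n_layers, n_kv_heads, d_head, seq_len, batch=1):
    """KV cache in GB (fp16) - only KV heads are cached, not query heads."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * batch * 2 / (1024**3)

# Llama 2 70B at 4096 tokens
mha = kv_cache_gb(80, 64, 128, 4096)  # full MHA: all 64 heads cached
gqa = kv_cache_gb(80, 8, 128, 4096)   # GQA: only the 8 KV heads cached
mha / gqa                             # 8.0x smaller cache
```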

FlashAttention

FlashAttention (Dao et al., 2022) doesn't eliminate the cache, but makes the attention computation itself dramatically faster by being IO-aware. Standard attention materializes the full $N \times N$ attention matrix in GPU high-bandwidth memory (HBM). FlashAttention tiles the computation into blocks that fit in fast on-chip SRAM, never materializing the full matrix. The result: 2-4× faster attention with significantly less memory.
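The numerical trick that makes tiling possible is the online softmax: a running maximum and normalizer let the softmax-weighted sum be accumulated one block of keys at a time. A toy single-query sketch of that math (nothing like the real CUDA kernel, but the same recurrence):

```python
import torch

def blockwise_attention_row(q, K, V, block=4):
    """One query's attention output, computed block-by-block.

    Never holds the full score row at once: as each block of keys
    arrives, the running max m and normalizer l are rescaled. This is
    the online-softmax recurrence FlashAttention tiles onto GPU SRAM.
    """
    d = q.shape[-1]
    m = torch.tensor(float("-inf"))  # running max of scores
    l = torch.tensor(0.0)            # running softmax normalizer
    acc = torch.zeros(d)             # running weighted sum of values
    for i in range(0, K.shape[0], block):
        s = (q @ K[i:i + block].T) / d**0.5
        m_new = torch.maximum(m, s.max())
        scale = torch.exp(m - m_new)  # rescale old partial sums
        p = torch.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ V[i:i + block]
        m = m_new
    return acc / l
```

The result matches ordinary softmax attention for that query, but the peak working set is one block of scores instead of the whole row.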

PagedAttention

Traditional KV caching pre-allocates a contiguous block of memory for the maximum possible sequence length. If the actual sequence is shorter, the rest is wasted. With multiple concurrent requests of varying lengths, fragmentation is massive.

PagedAttention (Kwon et al., 2023), implemented in vLLM, treats KV cache like virtual memory - storing it in non-contiguous pages that are allocated on demand. This eliminates fragmentation, enables sharing cache blocks across requests with common prefixes (like system prompts), and can nearly double the throughput of a serving system.
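A toy sketch of the bookkeeping involved (a hypothetical class, not vLLM's actual API): fixed-size blocks are handed out on demand, and a per-sequence block table maps logical positions to physical blocks, so no sequence reserves memory it isn't using.

```python
class PagedKVCache:
    """Toy paged KV allocator: block tables over a shared physical pool."""

    def __init__(self, n_blocks, block_size):
        self.block_size = block_size
        self.free = list(range(n_blocks))  # pool of physical block ids
        self.tables = {}                   # seq_id -> [physical block ids]
        self.lengths = {}                  # seq_id -> tokens written

    def append_token(self, seq_id):
        table = self.tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:       # current block full: grab a new one
            table.append(self.free.pop())
        self.lengths[seq_id] = n + 1

    def blocks_used(self, seq_id):
        return len(self.tables[seq_id])

cache = PagedKVCache(n_blocks=64, block_size=16)
for _ in range(40):
    cache.append_token("req-0")
cache.blocks_used("req-0")  # 40 tokens need ceil(40/16) = 3 blocks
```

The worst-case waste per sequence is one partially filled block, versus an entire max-length reservation under contiguous allocation; prefix sharing would map several sequences' tables onto the same physical blocks.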

Quantized KV Caches

If fp16 K/V values can be quantized to int8 or int4 with minimal quality loss, the cache shrinks by 2-4×. This is an active area of research (Hooper et al., 2024), with some models tolerating aggressive quantization of cached values better than others.
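A minimal sketch of the simplest variant, symmetric per-token int8 (the function names are illustrative; production schemes use per-channel or group-wise scales, int4, and more):

```python
import torch

def quantize_kv_int8(kv):
    """Symmetric per-token int8 quantization of a (seq_len, d) KV tensor."""
    # One scale per token row, chosen so the largest value maps to 127
    scale = kv.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(kv / scale).to(torch.int8)
    return q, scale

def dequantize_kv(q, scale):
    return q.float() * scale

k = torch.randn(4096, 128)            # one layer's cached keys
q8, scale = quantize_kv_int8(k)
k_hat = dequantize_kv(q8, scale)
# q8 is 1 byte/element vs 2 for fp16, so the cache halves; the
# round-trip error is bounded by half a quantization step per token
```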

The Full Picture

Without KV caching, real-time text generation from large language models would be unusable. A 1,000-token response would require roughly 500,000 redundant matrix multiplications. With KV caching, it requires 1,000.

But the cache creates its own bottleneck - a memory wall that dominates GPU utilization in production serving. Every major LLM deployment today is shaped by this tradeoff: compute vs. memory, speed vs. capacity. FlashAttention, PagedAttention, GQA, and quantization are all different angles of attack on the same fundamental constraint.

The inference engine is built. Now the challenge is making it fit.