Mar 26, 2026

KV Caching - Making Transformers Actually Fast

Why naive transformer inference is O(N²), how KV caching fixes it, and what happens when the cache itself becomes the bottleneck.


Generating a single 1,000-token response with a naive transformer requires roughly 500,000 token projections per layer (1 + 2 + ... + 1,000), when only 1,000 of them produce anything new. That number is not a rounding error - it is the actual cost of recomputing the same vectors over and over, for tokens that have not changed, because nobody told the model it was allowed to remember them.

Transformers are built to process every position of a sequence in parallel, which works naturally during training, when you already have the entire document. During inference, the model generates one token at a time - it cannot know token 600 until it has committed to tokens 1 through 599. This is called autoregressive generation, and it creates a mismatch between how transformers are designed and how they are deployed.

The Autoregressive Loop

Start with a concrete example. The model receives "The cat sat" and must continue.

Step 1: Feed ["The", "cat", "sat"] into the model. Attention is computed for all 3 tokens. Output: "on".

Step 2: Feed ["The", "cat", "sat", "on"] - all 4 tokens. Attention is recomputed for every token. Output: "the".

Step 3: Feed ["The", "cat", "sat", "on", "the"] - all 5. Every token again. Output: "mat".

At step 2, the model recalculates the Q, K, V projections for "The", "cat", and "sat" from scratch. Nothing about those tokens has changed - their embeddings are identical to step 1. At step 3, all of them get recomputed again. Token 1 will be recomputed a total of $N$ times across an $N$-token generation.

The total work sums to $1 + 2 + 3 + \ldots + N = \frac{N(N+1)}{2}$, which is $O(N^2)$. Double the generation length, quadruple the compute. This is why, without caching, LLM inference at 4K+ contexts is impractically slow: at 4,000 tokens, the naive loop does about 2,000 times more projection work than necessary.
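A quick count makes the quadratic growth concrete. This sketch tallies how many token projections one layer performs under the naive loop versus a cache; the function names and the `prompt_len` parameter are illustrative, and it counts projections only, ignoring the attention arithmetic itself.

```python
def naive_projection_count(n_steps: int, prompt_len: int = 1) -> int:
    """Token projections when every step re-feeds the whole sequence."""
    # Step s feeds the prompt plus the s - 1 tokens generated so far.
    return sum(prompt_len + s - 1 for s in range(1, n_steps + 1))


def cached_projection_count(n_steps: int, prompt_len: int = 1) -> int:
    """With a cache, each token is projected exactly once."""
    # Prefill projects the prompt; each later step projects only the token fed back in.
    return prompt_len + (n_steps - 1)


print(naive_projection_count(1000), cached_projection_count(1000))  # 500500 1000
print(naive_projection_count(3, prompt_len=3))  # 12  ("The cat sat" example: 3 + 4 + 5)
```

With a 1-token prompt and 1,000 steps, the naive loop performs 500,500 projections against 1,000 with a cache, which is where the roughly 500,000 figure comes from.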

Keys and Values Do Not Change

To see why caching is possible, look at the attention formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

During generation, the only thing the model needs to produce is the attention output for the newest token. What does that token depend on?

  • Its own Query ($Q$): computed fresh because it is a new token.
  • All Keys ($K$): the newest token compares its query against the keys of every previous token. But $K_i = W_K x_i$ depends only on $x_i$, which is fixed once token $i$ has been generated. (In deeper layers, $x_i$ is the hidden state at position $i$; the causal mask means it depends only on tokens 1 through $i$, so it is fixed too.)
  • All Values ($V$): same argument. $V_i = W_V x_i$ does not change after token $i$ exists.

The word "cat" produces the same K and V vector whether we are at step 3 or step 900. Every step, we recompute them from scratch and use them once. The fix is obvious in retrospect: compute them once, store the result, and look it up. This is the KV Cache.
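A quick numerical check of the claim that cached K and V stay valid in deeper layers too, where the causal mask is what keeps earlier hidden states from changing. This numpy sketch uses one random single-head attention layer (all weights and sizes are arbitrary assumptions) and confirms that the second layer's keys for the first three positions are identical before and after a fourth token is appended.

```python
import numpy as np


def causal_attention(x, W_q, W_k, W_v):
    """Single-head causal self-attention: position i attends only to positions <= i."""
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = q @ k.T / np.sqrt(k.shape[-1])
    n = x.shape[0]
    future = np.triu(np.ones((n, n), dtype=bool), k=1)  # True above the diagonal
    scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


rng = np.random.default_rng(0)
d = 8
layer1 = [rng.standard_normal((d, d)) for _ in range(3)]  # W_q, W_k, W_v of layer 1
layer2_Wk = rng.standard_normal((d, d))                    # key projection of layer 2

x = rng.standard_normal((4, d))         # embeddings for 4 tokens
h_3 = causal_attention(x[:3], *layer1)  # layer-1 outputs at step 3
h_4 = causal_attention(x[:4], *layer1)  # layer-1 outputs after appending token 4

# Layer-2 keys for the first three positions are unaffected by the new token.
k2_before = h_3 @ layer2_Wk
k2_after = h_4[:3] @ layer2_Wk
print(np.allclose(k2_before, k2_after))  # True
```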

How KV Caching Works

Prefill phase

When the user sends a prompt, the model processes all prompt tokens in parallel - just like training. This prefill step computes Q, K, V for every prompt token simultaneously and populates the initial cache. The quadratic cost of attention still applies here, but it happens only once and scales with prompt length, not generation length.
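A minimal single-head, single-layer sketch of prefill in numpy (the `prefill` name and the shapes are assumptions for illustration): the whole prompt is projected with one matrix multiply, the causal $P \times P$ score matrix is built once, and the prompt's K and V are returned as the initial cache.

```python
import numpy as np


def softmax(scores):
    scores = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)


def prefill(x_prompt, W_q, W_k, W_v):
    """Process all P prompt tokens in one pass; return (outputs, (K, V) cache)."""
    Q = x_prompt @ W_q                           # (P, d): one big matmul, like training
    K = x_prompt @ W_k                           # (P, d): initial key cache
    V = x_prompt @ W_v                           # (P, d): initial value cache
    P = x_prompt.shape[0]
    scores = Q @ K.T / np.sqrt(K.shape[-1])      # (P, P): the quadratic part, paid once
    future = np.triu(np.ones((P, P), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)   # causal mask
    return softmax(scores) @ V, (K, V)


rng = np.random.default_rng(0)
d, P = 16, 5
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))
out, (K_cache, V_cache) = prefill(rng.standard_normal((P, d)), W_q, W_k, W_v)
print(out.shape, K_cache.shape, V_cache.shape)  # (5, 16) (5, 16) (5, 16)
```

The first prompt token can only attend to itself, so its output row is exactly its own value vector, a handy sanity check on the mask.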

Decode phase

For each new token generated:

  1. Compute only $q_t = W_Q x_t$ - the query for the new token (a single vector, not a matrix)
  2. Compute $k_t = W_K x_t$ and $v_t = W_V x_t$ - the key and value for the new token
  3. Append $k_t$ and $v_t$ to the cached $K_{1:t-1}$ and $V_{1:t-1}$
  4. Compute attention: $\text{softmax}\left(\frac{q_t K_{1:t}^T}{\sqrt{d_k}}\right) V_{1:t}$

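The Q matrix has shrunk from $(t \times d)$ to $(1 \times d)$, and no past token is ever re-projected. Per-step projection cost drops from $O(t \cdot d^2)$ to $O(d^2)$, so total projection cost across all steps falls from $O(N^2 \cdot d^2)$ to $O(N \cdot d^2)$. The attention scores still grow with $t$: the new query must be compared against all $t$ cached keys, which costs $O(t \cdot d)$ per step and $O(N^2 \cdot d)$ in total. Caching removes the redundant projections, not the attention itself.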
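A rough cost model separates the two kinds of work. This sketch counts multiply-accumulates for one single-head layer of width `d` (a simplifying assumption that ignores softmax, multiple heads, and the MLP). With a cache, projection work per step is constant, so it totals linearly in the number of steps, while the score computation still touches every cached key and totals quadratically.

```python
def decode_step_macs(t: int, d: int, cached: bool) -> dict:
    """Rough multiply-accumulate counts for one single-head layer at decode step t."""
    n_proj = 1 if cached else t        # tokens whose Q, K, V are projected this step
    n_query = 1 if cached else t       # query rows whose attention output is computed
    projections = 3 * n_proj * d * d   # Q, K, V: one (d x d) matvec each per token
    attention = 2 * n_query * t * d    # q . K^T scores plus weights @ V over t positions
    return {"projections": projections, "attention": attention}


def total_macs(n: int, d: int, cached: bool) -> dict:
    """Sum the per-step counts over steps t = 1..n."""
    totals = {"projections": 0, "attention": 0}
    for t in range(1, n + 1):
        for key, value in decode_step_macs(t, d, cached).items():
            totals[key] += value
    return totals


short = total_macs(1000, 128, cached=True)
long = total_macs(2000, 128, cached=True)
print(long["projections"] / short["projections"])  # 2.0
print(long["attention"] / short["attention"])
```

Doubling the length exactly doubles the cached projection total but roughly quadruples the attention total, which is why long-context decoding stays expensive even with a cache.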


Implementation

The naive approach recomputes everything:

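```python
import torch
import torch.nn.functional as F

def naive_attention(x, W_q, W_k, W_v):
    """Recomputes Q, K, V for ALL tokens every step."""
    Q = x @ W_q  # (seq_len, d)
    K = x @ W_k  # (seq_len, d)
    V = x @ W_v  # (seq_len, d)
    scores = Q @ K.T / (K.size(-1) ** 0.5)
    # Causal mask: each position may attend only to itself and earlier positions
    seq_len = x.size(0)
    future = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
    )
    scores = scores.masked_fill(future, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return weights @ V
```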

With KV caching, we only process the new token:

```python
import torch
import torch.nn.functional as F

def cached_attention(x_new, past_kv, W_q, W_k, W_v):
    """Only project the new token. Reuse cached K, V."""
    # x_new: (1, d) - just the newest token
    q = x_new @ W_q
    k_new = x_new @ W_k
    v_new = x_new @ W_v
    if past_kv is not None:
        # Append the new row to the cached keys and values
        K = torch.cat([past_kv[0], k_new], dim=0)
        V = torch.cat([past_kv[1], v_new], dim=0)
    else:
        K, V = k_new, v_new
    # No causal mask needed: the single query is the latest position,
    # so every cached key is in its past
    scores = q @ K.T / (K.size(-1) ** 0.5)
    weights = F.softmax(scores, dim=-1)
    return weights @ V, (K, V)
```

The generation loop becomes:

```python
import torch

@torch.no_grad()
def generate(model, prompt_ids, max_new_tokens):
    # Assumes model(ids, kv_cache=...) takes ids of shape (batch, seq_len)
    # and returns (logits, updated kv_cache)
    # Prefill: process entire prompt in parallel
    logits, kv_cache = model(prompt_ids, kv_cache=None)
    next_token = logits[:, -1].argmax(dim=-1)  # (batch,)
    generated = [next_token]
    for _ in range(max_new_tokens - 1):
        # Decode: one token at a time, reusing cache
        logits, kv_cache = model(
            next_token.unsqueeze(-1), kv_cache=kv_cache  # (batch, 1)
        )
        next_token = logits[:, -1].argmax(dim=-1)
        generated.append(next_token)
    return generated
```

Notice what kv_cache is: a list of (K, V) tensor pairs, one per transformer layer. Every layer maintains its own cache because every layer has its own K/V projections. The cache grows by one row per layer per generated token.
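To make the per-layer cache concrete, here is a hypothetical toy decoder in numpy that follows the same calling convention as the `model` assumed above: it takes new token ids plus an optional cache and returns `(logits, kv_cache)`, where the cache is a list of `(K, V)` pairs, one per layer. `ToyDecoder`, `generate_cached`, `generate_naive`, and all sizes are illustrative assumptions; the model has no MLPs, normalization, or positional encodings. Greedy decoding with the cache produces the same tokens as re-feeding the whole sequence every step.

```python
import numpy as np


def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)


class ToyDecoder:
    """Hypothetical attention-only decoder: embeddings -> n_layers causal attention -> logits."""

    def __init__(self, vocab=50, d=16, n_layers=3, seed=0):
        rng = np.random.default_rng(seed)
        self.embed = rng.standard_normal((vocab, d))
        # One (W_q, W_k, W_v) triple per layer, scaled to keep activations moderate.
        self.layers = [[rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3)]
                       for _ in range(n_layers)]
        self.unembed = rng.standard_normal((d, vocab))

    def __call__(self, ids, kv_cache=None):
        """ids: 1-D array of new token ids. Returns (logits for those ids, new cache)."""
        h = self.embed[ids]                                    # (n_new, d)
        past = 0 if kv_cache is None else kv_cache[0][0].shape[0]
        new_cache = []
        for i, (W_q, W_k, W_v) in enumerate(self.layers):
            q, k, v = h @ W_q, h @ W_k, h @ W_v
            if kv_cache is not None:                           # reuse this layer's K, V
                k = np.concatenate([kv_cache[i][0], k])
                v = np.concatenate([kv_cache[i][1], v])
            new_cache.append((k, v))                           # one (K, V) pair per layer
            scores = q @ k.T / np.sqrt(k.shape[-1])            # (n_new, past + n_new)
            # New token j sits at absolute position past + j; mask keys after it.
            q_pos = past + np.arange(h.shape[0])[:, None]
            k_pos = np.arange(k.shape[0])[None, :]
            scores = np.where(k_pos > q_pos, -np.inf, scores)
            h = h + softmax(scores) @ v                        # residual connection
        return h @ self.unembed, new_cache


def generate_cached(model, prompt_ids, max_new_tokens):
    logits, cache = model(np.asarray(prompt_ids))              # prefill
    out = [int(logits[-1].argmax())]
    for _ in range(max_new_tokens - 1):
        logits, cache = model(np.array([out[-1]]), cache)      # decode one token
        out.append(int(logits[-1].argmax()))
    return out, cache


def generate_naive(model, prompt_ids, max_new_tokens):
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits, _ = model(np.array(ids))                       # re-feed everything
        ids.append(int(logits[-1].argmax()))
    return ids[len(prompt_ids):]


model = ToyDecoder()
cached, cache = generate_cached(model, [3, 14, 15], max_new_tokens=6)
print(len(cache), cache[0][0].shape)  # 3 (8, 16)
```

The cache holds 3 prompt rows plus 5 decoded rows per layer: the token produced on the final step is never fed back, so it never gets cached.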

The Memory Wall

KV caching trades a compute problem for a memory problem. The compute savings are real and dramatic. The memory cost is also real and dramatic, and at scale, it becomes the binding constraint.

How big is the cache?

The KV cache stores two tensors (K and V) for every layer, every attention head, at every position in the sequence:

$$\text{Memory} = 2 \times L \times H \times d_h \times N \times B \times \text{sizeof(dtype)}$$

Where $L$ = layers, $H$ = key/value heads (the same as the attention heads in standard multi-head attention), $d_h$ = head dimension, $N$ = sequence length, $B$ = batch size, and dtype is typically fp16 (2 bytes).

Concrete numbers:

```python
def kv_cache_memory_gb(n_layers, n_heads, d_head, seq_len, batch=1):
    """KV cache memory in GB (fp16, 1 GB = 1024**3 bytes)."""
    return 2 * n_layers * n_heads * d_head * seq_len * batch * 2 / (1024**3)

# Llama 2 7B at 4096 tokens
kv_cache_memory_gb(32, 32, 128, 4096)  # 2.0 GB

# Llama 2 13B at 4096 tokens
kv_cache_memory_gb(40, 40, 128, 4096)  # ~3.1 GB

# A 70B-scale model with full multi-head attention (80 layers, 64 KV heads) at 4096 tokens
kv_cache_memory_gb(80, 64, 128, 4096)  # 10.0 GB

# Llama 2 70B itself uses GQA with 8 KV heads, which cuts this 8x
kv_cache_memory_gb(80, 8, 128, 4096)   # 1.25 GB
```

2 GB for a 7B model seems manageable. But scale up the sequence length or batch size and things get dire fast:

At batch size 8 with a 32K context window, even a 7B model's KV cache reaches about 128 GB, several times the memory of a consumer GPU. A 70B model with full multi-head attention at 128K tokens? About 320 GB for a single sequence - more than double the roughly 130 GB of fp16 model weights.
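The figures above come straight from the memory formula. A sketch that reproduces them, assuming fp16 and the 7B and 70B-class shapes used earlier, with the 70B case using full multi-head attention (64 KV heads) rather than GQA:

```python
def kv_cache_gb(n_layers, n_kv_heads, d_head, seq_len, batch=1, bytes_per_elem=2):
    """KV cache size in GB (1024**3 bytes): 2 tensors x layers x heads x d_head x tokens."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * batch * bytes_per_elem / 1024**3


# 7B-class shape (32 layers, 32 KV heads, d_head 128) at 32K context, batch 8
print(kv_cache_gb(32, 32, 128, 32_768, batch=8))  # 128.0

# 70B-class shape with full MHA (80 layers, 64 KV heads) at 128K context, batch 1
print(kv_cache_gb(80, 64, 128, 131_072))          # 320.0

# For comparison: 70B parameters stored in fp16
print(round(70e9 * 2 / 1024**3, 1))               # 130.4
```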

This is the memory wall. The KV cache becomes the dominant memory consumer, and it determines how many users a single GPU can serve simultaneously. A setup with just enough memory to hold the 70B weights might be able to serve zero concurrent users, because there is no memory left for their caches.
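The serving consequence is simple division. This sketch assumes every sequence reserves its full-length cache up front and ignores activations and runtime overhead; the 80 GB and 160 GB budgets (one or two 80 GB accelerators) and the `max_concurrent_sequences` helper are illustrative assumptions.

```python
def max_concurrent_sequences(gpu_gb, weights_gb, cache_gb_per_seq, overhead_gb=0.0):
    """How many full-length sequences fit once the weights (and overhead) are loaded."""
    free = gpu_gb - weights_gb - overhead_gb
    return max(0, int(free // cache_gb_per_seq))


weights_7b = 7e9 * 2 / 1024**3    # ~13 GB in fp16
weights_70b = 70e9 * 2 / 1024**3  # ~130 GB in fp16

print(max_concurrent_sequences(80, weights_7b, 2.0))     # 33  (7B, 4K context)
print(max_concurrent_sequences(80, weights_7b, 16.0))    # 4   (7B, 32K context)
print(max_concurrent_sequences(160, weights_70b, 10.0))  # 2   (70B full MHA, 4K context)
print(max_concurrent_sequences(160, weights_70b, 80.0))  # 0   (70B full MHA, 32K context)
```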

The compute problem was solved by storing more. The memory problem cannot be solved by storing more - it must be attacked by storing less, or storing smarter.

Attacking the Memory Wall

Multi-Query and Grouped-Query Attention

Standard multi-head attention gives each head its own K and V projections. Multi-Query Attention (MQA) (Shazeer, 2019) shares a single K/V head across all query heads. The cache shrinks by a factor of $H$ - the number of heads. The quality hit is real but often acceptable.

Grouped-Query Attention (GQA) (Ainslie et al., 2023) finds a middle ground: groups of query heads share K/V heads. Llama 2 70B uses GQA with 8 KV heads shared across 64 query heads, reducing cache size by 8x with less quality degradation than full MQA.

```python
# Standard MHA: each head has its own K, V projection
# KV cache shape: (batch, n_heads, seq_len, d_head)

# MQA: one shared K, V for all query heads
# KV cache shape: (batch, 1, seq_len, d_head)         <- 1/n_heads the size

# GQA: n_groups groups, each with their own K, V
# KV cache shape: (batch, n_groups, seq_len, d_head)  <- n_groups/n_heads the size
```

The cache is just the stored K and V tensors. Fewer KV heads means fewer tensors to store - the query heads still exist and still attend, they just all read from a shared set of keys and values.
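A minimal numpy sketch of one GQA decode step makes the sharing explicit. The `gqa_decode_step` name and the shapes are assumptions; the repeat here is done at compute time for clarity and does not enlarge what is stored. Query head `h` reads KV head `h // group`, and setting `n_kv_heads` equal to `n_heads` recovers standard MHA, while `n_kv_heads = 1` gives MQA.

```python
import numpy as np


def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)


def gqa_decode_step(q, K_cache, V_cache):
    """One decode step of grouped-query attention.

    q:        (n_heads, d_head)          queries for the new token, one per query head
    K_cache:  (n_kv_heads, seq, d_head)  cached keys, one slice per KV head
    V_cache:  (n_kv_heads, seq, d_head)  cached values
    """
    n_heads, d_head = q.shape
    n_kv_heads = K_cache.shape[0]
    assert n_heads % n_kv_heads == 0, "query heads must split evenly into groups"
    group = n_heads // n_kv_heads                   # query heads per KV head
    K = np.repeat(K_cache, group, axis=0)           # (n_heads, seq, d_head)
    V = np.repeat(V_cache, group, axis=0)
    scores = np.einsum("hd,hsd->hs", q, K) / np.sqrt(d_head)
    return np.einsum("hs,hsd->hd", softmax(scores), V)  # (n_heads, d_head)


rng = np.random.default_rng(0)
n_heads, n_kv_heads, seq, d_head = 64, 8, 10, 128   # Llama 2 70B-style head counts
q = rng.standard_normal((n_heads, d_head))
K_cache = rng.standard_normal((n_kv_heads, seq, d_head))
V_cache = rng.standard_normal((n_kv_heads, seq, d_head))
out = gqa_decode_step(q, K_cache, V_cache)
print(out.shape)  # (64, 128)
```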

FlashAttention

FlashAttention (Dao et al., 2022) does not eliminate the cache, but makes the attention computation itself dramatically faster by being IO-aware. Standard attention materializes the full $N \times N$ attention score matrix in GPU high-bandwidth memory (HBM). For long sequences, that matrix dominates memory bandwidth - you spend most of your time reading and writing it, not computing.

FlashAttention tiles the computation into blocks that fit in fast on-chip SRAM, never materializing the full matrix. The result is 2-4x faster attention and significantly lower peak memory, with mathematically identical output.
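The trick that lets tiling avoid the full score matrix is an online softmax: process keys block by block, keep a running maximum and running denominator, and rescale the partial weighted sum whenever the maximum grows. This is a simplified single-query CPU sketch of that idea in numpy, not the fused GPU kernel (which also tiles queries and manages SRAM explicitly); `tiled_attention_row` and `block` are illustrative names.

```python
import numpy as np


def tiled_attention_row(q, K, V, block=4):
    """Attention output for one query, streaming K/V in blocks with an online softmax.

    Only one block of scores exists at a time; the running max (m), running
    denominator (l), and running weighted sum (acc) carry state between blocks.
    """
    d = q.shape[-1]
    m = -np.inf
    l = 0.0
    acc = np.zeros(V.shape[-1])
    for start in range(0, K.shape[0], block):
        s = K[start:start + block] @ q / np.sqrt(d)  # scores for this block only
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)                     # rescale earlier partial sums
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ V[start:start + block]
        m = m_new
    return acc / l


rng = np.random.default_rng(0)
q = rng.standard_normal(16)
K = rng.standard_normal((37, 16))
V = rng.standard_normal((37, 8))

# Reference: ordinary softmax attention over all 37 keys at once
s = K @ q / np.sqrt(16)
w = np.exp(s - s.max())
w /= w.sum()
print(np.allclose(tiled_attention_row(q, K, V, block=5), w @ V))  # True
```

The result matches ordinary softmax attention up to floating-point rounding, regardless of block size.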

PagedAttention

Traditional KV caching pre-allocates a contiguous block of memory for the maximum possible sequence length. A request that terminates early leaves most of that allocation unused. With hundreds of concurrent requests, the fragmentation is massive - you might have 40% of your GPU memory reserved but unreachable.

PagedAttention (Kwon et al., 2023), implemented in vLLM, treats KV cache like virtual memory. Cache is stored in non-contiguous pages that are allocated on demand. Pages can be shared across requests that share a common prefix (system prompts, few-shot examples). Fragmentation drops to near zero. Throughput nearly doubles on realistic serving workloads - not because the math changed, but because the allocator got smarter.
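A toy block manager shows the bookkeeping idea: the KV pool is split into fixed-size blocks, each request keeps a table mapping its logical positions to whatever physical blocks were free, and shared prefixes are reference-counted. This is a hypothetical sketch in the spirit of PagedAttention, not vLLM's code or API; it tracks slots only (no K/V data or attention kernel), and models copy-on-write for a shared, partially filled block by reassigning it.

```python
class PagedKVAllocator:
    """Hypothetical toy block manager in the spirit of PagedAttention."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))    # physical blocks not in use
        self.ref_counts = [0] * num_blocks            # requests sharing each block
        self.block_tables: dict[str, list[int]] = {}  # request -> physical blocks, in order
        self.lengths: dict[str, int] = {}             # request -> tokens stored

    def _alloc(self) -> int:
        if not self.free_blocks:
            raise MemoryError("KV block pool exhausted")
        block = self.free_blocks.pop()
        self.ref_counts[block] = 1
        return block

    def append_token(self, req: str) -> tuple[int, int]:
        """Reserve a slot for one more token of `req`; returns (physical_block, offset)."""
        table = self.block_tables.setdefault(req, [])
        n = self.lengths.get(req, 0)
        offset = n % self.block_size
        if offset == 0:
            table.append(self._alloc())  # last block full (or none yet): take a new one
        elif self.ref_counts[table[-1]] > 1:
            # Copy-on-write: the partially filled last block is shared, so take a
            # private block (a real system would also copy the first `offset` slots).
            new_block = self._alloc()
            self.ref_counts[table[-1]] -= 1
            table[-1] = new_block
        self.lengths[req] = n + 1
        return table[-1], offset

    def fork(self, parent: str, child: str) -> None:
        """Let `child` share all of `parent`'s blocks, e.g. a common prompt prefix."""
        self.block_tables[child] = list(self.block_tables[parent])
        self.lengths[child] = self.lengths[parent]
        for block in self.block_tables[child]:
            self.ref_counts[block] += 1

    def free(self, req: str) -> None:
        """Release a finished request; blocks return to the pool once unshared."""
        for block in self.block_tables.pop(req):
            self.ref_counts[block] -= 1
            if self.ref_counts[block] == 0:
                self.free_blocks.append(block)
        del self.lengths[req]


pool = PagedKVAllocator(num_blocks=8, block_size=4)
for _ in range(6):
    pool.append_token("a")   # 6 tokens -> 2 blocks (2 of 8 slots idle)
pool.fork("a", "b")          # "b" shares both blocks, no new memory
pool.append_token("b")       # shared partial block -> copy-on-write
print(pool.block_tables, len(pool.free_blocks))  # {'a': [7, 6], 'b': [7, 5]} 5
```

Because blocks are allocated on demand, a request wastes at most `block_size - 1` slots in its last block, instead of an entire max-length reservation.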

Quantized KV Caches

The cache is in fp16 because attention arithmetic is sensitive to precision. But how sensitive, exactly? If K and V can be quantized to int8 or int4 with acceptable quality loss, the cache shrinks by 2-4x. Research (Hooper et al., 2024) shows models tolerate K quantization better than V quantization, because errors in V accumulate multiplicatively through the weighted sum. Some production systems run K at int8 and V at fp16 as a pragmatic compromise.
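As a concrete baseline, here is the simplest scheme: symmetric per-token int8 quantization with one fp16 scale per row. This is a generic illustration chosen for clarity, not the method from Hooper et al.; the function names and tensor sizes are assumptions. Storage drops to about half of fp16, and each element's reconstruction error is bounded by half its row's scale.

```python
import numpy as np


def quantize_int8_per_token(x):
    """Symmetric int8 quantization with one fp16 scale per token (row) of a K or V cache."""
    x32 = x.astype(np.float32)
    amax = np.abs(x32).max(axis=-1, keepdims=True)
    scale = np.where(amax == 0, 1.0, amax / 127.0).astype(np.float16)  # stored with the cache
    q = np.clip(np.round(x32 / scale.astype(np.float32)), -127, 127).astype(np.int8)
    return q, scale


def dequantize_int8(q, scale):
    """Reconstruct approximate float values from int8 codes and per-row scales."""
    return q.astype(np.float32) * scale.astype(np.float32)


rng = np.random.default_rng(0)
K = rng.standard_normal((1024, 128)).astype(np.float16)  # one head's cached keys
q, scale = quantize_int8_per_token(K)
K_hat = dequantize_int8(q, scale)
print(K.nbytes, q.nbytes + scale.nbytes)  # 262144 133120
print(float(np.abs(K.astype(np.float32) - K_hat).max()))
```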

What This Means in Practice

Every technique above is a different angle of attack on the same constraint: a fixed amount of GPU memory that must be shared between model weights, KV cache, and activations. The KV cache, unlike the weights, grows with sequence length and batch size - which means at high throughput, it crowds everything else out.

A 70B model serving 100 concurrent users at 32K context each needs roughly 8,000 GB of KV cache with full multi-head attention, and still about 1,000 GB with Llama 2 70B's 8 KV heads. No single GPU holds that. Production systems shard it across devices, evict cold pages to CPU memory, and batch requests that share prefixes to amortize cache cost. The inference engine is not just attention math running fast - it is a memory allocator that happens to also compute attention.

Understanding KV caching is understanding why inference infrastructure is hard. The compute was always tractable. The memory is what scales badly.