Mar 26, 2026

Flash Attention - Breaking the Memory Wall

How Flash Attention bypasses GPU memory bottlenecks by never materializing the N×N attention matrix, and why understanding hardware matters as much as understanding math.

We spend so much time thinking about the math - matrix multiplications, gradients, loss functions - that we forget about the physical reality of the computers running them. The Transformer is an algorithmic masterpiece. But algorithms don't run in a theoretical void. They run on silicon. And when you scale the elegant math of attention to read entire books, it collides violently with the physical limits of hardware.

This is the story of the Memory Wall, and the brilliantly simple engineering hack that broke through it: Flash Attention (Dao et al., 2022).

The $O(N^2)$ Nightmare

At the heart of the Transformer:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Beautiful math. But look at the dimensions. If $N$ is the sequence length, both $Q$ and $K$ have $N$ rows. When you multiply $Q$ by $K^T$, the result is an $N \times N$ matrix - the attention score between every token and every other token.

Some terrifying arithmetic for a single attention head:

  • $N = 1{,}000$: $1{,}000 \times 1{,}000 = 1\text{M}$ elements. Fine.
  • $N = 100{,}000$ (a short novel): $100{,}000 \times 100{,}000 = 10\text{B}$ elements.

At fp16 (2 bytes per element), that intermediate matrix alone requires 20 GB of memory. For one head, one sequence, one layer.
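That 20 GB figure takes one line of arithmetic to sanity-check:

```python
# Size of the intermediate N x N attention matrix at fp16
N = 100_000          # tokens in a short novel
bytes_per_elem = 2   # fp16

matrix_bytes = N * N * bytes_per_elem
print(f"{matrix_bytes / 1e9:.0f} GB")  # 20 GB - one head, one sequence, one layer
```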

Proving it

```python
import torch
import torch.nn.functional as F

def naive_attention(q, k, v):
    """Standard attention - materializes the full N×N matrix."""
    scores = torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
    attn_weights = F.softmax(scores, dim=-1)
    return torch.matmul(attn_weights, v)

batch, heads, head_dim = 1, 12, 64
for N in [1024, 4096, 16384]:
    torch.cuda.reset_peak_memory_stats()
    q = torch.randn(batch, heads, N, head_dim, device="cuda")
    k = torch.randn(batch, heads, N, head_dim, device="cuda")
    v = torch.randn(batch, heads, N, head_dim, device="cuda")
    out = naive_attention(q, k, v)
    peak_mb = torch.cuda.max_memory_allocated() / (1024 ** 2)
    print(f"N={N:5d} | Peak Memory: {peak_mb:,.0f} MB")
```

Memory quadruples every time you double the sequence length. At $N = 32{,}768$, a standard GPU throws OutOfMemoryError and crashes. The math works, but the hardware physically cannot hold the intermediate matrix.
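You don't even need a GPU to see the quadrupling - the theoretical size of the fp16 score matrix tells the same story:

```python
# Theoretical fp16 score-matrix size per head as N doubles (no GPU needed)
prev = None
for N in [1024, 2048, 4096, 8192, 16384, 32768]:
    mb = N * N * 2 / 2**20
    growth = f" ({mb / prev:.0f}x)" if prev else ""
    print(f"N={N:6d} | {mb:8,.0f} MB{growth}")
    prev = mb
```

By $N = 32{,}768$ that is 2 GB per head; multiply by 12 heads and the intermediate alone exceeds a 24 GB card.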

The Hardware Reality

To understand Flash Attention, you need to understand how a GPU is actually built. It's not a single bucket of memory - it has a hierarchy:

HBM (High Bandwidth Memory): The "main memory." Large (40–80 GB on an A100), but physically distant from the compute cores. Moving data in and out is slow.

SRAM (Shared Memory): The "working memory." Sits directly next to the compute cores. Blazingly fast, but tiny - roughly 192 KB per streaming multiprocessor, about 20 MB across the entire chip.

When PyTorch runs naive attention, it doesn't know our intent. It just executes operations sequentially:

  1. Read $Q$ and $K$ from slow HBM into fast SRAM
  2. Compute $S = QK^T$ in SRAM
  3. Write the massive $N \times N$ matrix $S$ back to slow HBM (bottleneck)
  4. Read $S$ back from HBM into SRAM (bottleneck)
  5. Compute softmax $P = \text{softmax}(S)$ in SRAM
  6. Write the massive $P$ matrix back to HBM (bottleneck)
  7. Read $P$ and $V$ from HBM into SRAM (bottleneck)
  8. Compute output and write to HBM

The compute cores sit idle - doing zero math - while gigabytes of intermediate data slowly travel back and forth. This is called being memory-bound. The bottleneck isn't computation; it's data movement.
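A rough back-of-the-envelope model makes "memory-bound" concrete. Assuming A100-class figures (~312 TFLOPS fp16 tensor-core peak, ~2 TB/s HBM bandwidth) and counting only the round trips of the intermediate matrix, the data movement already takes longer than all the math:

```python
# Rough estimate: compute time vs. HBM-transfer time for naive attention.
# Assumed hardware figures (A100-class, illustrative): 312 TFLOPS fp16, 2 TB/s HBM.
N, d = 16_384, 64
flops = 2 * N * N * d * 2        # ~2*N^2*d FLOPs each for QK^T and PV
s_bytes = N * N * 2              # one fp16 N x N intermediate
traffic = 3 * s_bytes            # write S, read S, write/read P (steps 3-7, roughly)

compute_ms = flops / 312e12 * 1e3
transfer_ms = traffic / 2e12 * 1e3
print(f"compute: {compute_ms:.2f} ms | HBM traffic: {transfer_ms:.2f} ms")
```

Even with this generous accounting, moving the intermediates takes several times longer than computing them - and real kernels rarely hit peak FLOPS anyway.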

The Solution: Tiling

Flash Attention's premise is a masterclass in IO-awareness: what if we never write the $N \times N$ matrix to HBM at all?

If we can compute attention scores and multiply by $V$ entirely inside the fast SRAM, we bypass the memory wall completely. But SRAM is only ~20 MB - we can't fit the full $N \times N$ matrix there either.

The solution is tiling. Break $Q$, $K$, and $V$ into small blocks (tiles) that fit in SRAM:

  1. Load a block of Queries into SRAM
  2. Iterate over blocks of Keys and Values:
    • Load a K block and V block into SRAM
    • Compute a small patch of the attention matrix (fits in SRAM)
    • Multiply by the V block
    • Accumulate the result
  3. Write only the final output back to HBM

The $N \times N$ matrix is never fully materialized - not in HBM, not in SRAM, not anywhere. Each small block is computed, used, and discarded before the next block is loaded.

The online softmax trick

There's one mathematical subtlety. Standard softmax requires the maximum value of the entire row:

$$\text{softmax}(x_i) = \frac{e^{x_i - \max(\mathbf{x})}}{\sum_j e^{x_j - \max(\mathbf{x})}}$$

When processing tiles, we don't have the full row - we only see one block at a time. Flash Attention uses online softmax (Milakov & Gimelshein, 2018): maintain a running maximum and a running sum of exponentials. As each new block is processed, rescale the previous accumulator to account for the updated maximum. The final result is mathematically identical to standard softmax, but computed incrementally.
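The whole trick - tiling plus the running rescale - fits in a few lines of NumPy. This is a minimal sketch (block size and shapes are illustrative, single head, no batching), verified against the standard formula:

```python
import numpy as np

def naive_attention(q, k, v):
    """Reference: materializes the full score matrix."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v

def tiled_attention(q, k, v, block=32):
    """Process K/V in blocks, keeping a running max (m), sum (l), and accumulator."""
    n, d = q.shape
    m = np.full(n, -np.inf)          # running row maximum
    l = np.zeros(n)                  # running sum of exponentials
    acc = np.zeros((n, d))           # running (unnormalized) output
    for start in range(0, n, block):
        s = q @ k[start:start+block].T / np.sqrt(d)   # small n x block patch
        m_new = np.maximum(m, s.max(axis=-1))
        alpha = np.exp(m - m_new)                     # rescale old state for new max
        p = np.exp(s - m_new[:, None])
        l = l * alpha + p.sum(axis=-1)
        acc = acc * alpha[:, None] + p @ v[start:start+block]
        m = m_new
    return acc / l[:, None]          # normalize once at the end

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((128, 64)) for _ in range(3))
print(np.allclose(naive_attention(q, k, v), tiled_attention(q, k, v)))  # True
```

No matter how the rows are split into blocks, the final output matches the all-at-once softmax to floating-point precision.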

Implementation: From PyTorch to GPU Kernels

Standard PyTorch abstracts away memory management. To control what goes into SRAM vs HBM, we need to write a GPU kernel. Triton (developed by OpenAI) lets us write GPU kernels in Python:

```python
import triton
import triton.language as tl

@triton.jit
def flash_attention_kernel(
    Q_ptr, K_ptr, V_ptr, Out_ptr,
    stride_qb, stride_qh, stride_qs, stride_qd,
    stride_kb, stride_kh, stride_ks, stride_kd,
    stride_vb, stride_vh, stride_vs, stride_vd,
    stride_ob, stride_oh, stride_os, stride_od,
    num_heads, seq_len,
    head_dim: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # Which block of Queries does this program handle?
    batch_idx = tl.program_id(0)
    head_idx = tl.program_id(1)
    block_m_idx = tl.program_id(2)

    # Calculate memory offsets for this batch/head
    q_off = batch_idx * stride_qb + head_idx * stride_qh
    k_off = batch_idx * stride_kb + head_idx * stride_kh
    v_off = batch_idx * stride_vb + head_idx * stride_vh
    o_off = batch_idx * stride_ob + head_idx * stride_oh

    # Load one block of Queries into SRAM
    offs_m = block_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, head_dim)
    q_ptrs = Q_ptr + q_off + offs_m[:, None] * stride_qs + offs_d[None, :] * stride_qd
    q_block = tl.load(q_ptrs)

    # Online-softmax state: running max, running sum, and output accumulator.
    # All of it lives in SRAM/registers - it never touches HBM.
    m_i = tl.full((BLOCK_M,), float("-inf"), dtype=tl.float32)
    l_i = tl.zeros((BLOCK_M,), dtype=tl.float32)
    acc = tl.zeros((BLOCK_M, head_dim), dtype=tl.float32)
    scale = 1.0 / (head_dim ** 0.5)

    # Iterate over blocks of K and V
    for block_n_idx in range(0, seq_len // BLOCK_N):
        offs_n = block_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)

        # Load K and V tiles from HBM to SRAM
        k_ptrs = K_ptr + k_off + offs_n[:, None] * stride_ks + offs_d[None, :] * stride_kd
        v_ptrs = V_ptr + v_off + offs_n[:, None] * stride_vs + offs_d[None, :] * stride_vd
        k_block = tl.load(k_ptrs)
        v_block = tl.load(v_ptrs)

        # Attention scores for this tile - entirely in SRAM
        qk = tl.dot(q_block, tl.trans(k_block)) * scale

        # Online softmax update: fold this tile into the running statistics
        m_new = tl.maximum(m_i, tl.max(qk, axis=1))
        alpha = tl.exp(m_i - m_new)        # rescale factor for the old accumulator
        p = tl.exp(qk - m_new[:, None])    # unnormalized probabilities for this tile
        l_i = l_i * alpha + tl.sum(p, axis=1)
        acc = acc * alpha[:, None] + tl.dot(p.to(v_block.dtype), v_block)
        m_i = m_new

    # Normalize once, then write the final output from SRAM back to HBM
    acc = acc / l_i[:, None]
    o_ptrs = Out_ptr + o_off + offs_m[:, None] * stride_os + offs_d[None, :] * stride_od
    tl.store(o_ptrs, acc.to(Out_ptr.dtype.element_ty))
```

Look closely: we never created a variable for the $N \times N$ matrix. The small qk block is computed, used to update acc, and overwritten on the next iteration. The full attention matrix never exists in memory.

The wrapper

```python
def triton_flash_attention(q, k, v):
    batch, heads, seq_len, head_dim = q.shape
    BLOCK_M, BLOCK_N = 64, 64
    out = torch.empty_like(q)
    grid = (batch, heads, triton.cdiv(seq_len, BLOCK_M))
    flash_attention_kernel[grid](
        q, k, v, out,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        out.stride(0), out.stride(1), out.stride(2), out.stride(3),
        heads, seq_len, head_dim,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
    )
    return out
```

The Results

Benchmarking our Triton kernel against naive PyTorch at $N = 16{,}384$:

Memory: Naive demands gigabytes for the intermediate $N \times N$ matrix. Flash Attention never allocates it, so its extra footprint grows only linearly with sequence length - effectively flat next to the quadratic blow-up.

Speed: With compute cores no longer waiting for HBM transfers, the kernel runs 2–4× faster wall-clock time.

Modern PyTorch includes torch.nn.functional.scaled_dot_product_attention, which automatically dispatches to an optimized Flash Attention kernel under the hood. But now you know exactly what it's doing.
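In practice that one call is all you need. A minimal CPU-runnable sketch (shapes are illustrative; on a supported GPU the same call dispatches to the Flash Attention backend):

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) - the layout SDPA expects
q = torch.randn(1, 12, 1024, 64)
k = torch.randn(1, 12, 1024, 64)
v = torch.randn(1, 12, 1024, 64)

# One call - PyTorch picks the best available backend
# (Flash Attention on supported GPUs, a fallback math kernel otherwise)
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 12, 1024, 64])
```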

Flash Attention 2 and Beyond

Flash Attention 2 (Dao, 2023) improved on the original by:

  • Reducing non-matmul FLOPs (the bookkeeping around online softmax)
  • Better parallelism across sequence length (not just batch and heads)
  • Achieving ~2× speedup over Flash Attention 1, reaching 50–73% of theoretical GPU throughput

Flash Attention 3 (Dao et al., 2024) targets the Hopper architecture (H100) specifically, exploiting:

  • Asynchronous execution - overlapping SRAM-to-register data movement with computation
  • FP8 tensor core support - halving memory bandwidth for compatible workloads
  • Hardware-aware warp scheduling - hiding instruction latency

Each version pushes closer to the theoretical hardware limit. The algorithm hasn't changed - it's still tiling with online softmax. What changes is how precisely the implementation matches the GPU's physical capabilities.

Why This Matters

Flash Attention is a reminder of a fundamental truth: algorithms don't exist in a vacuum. The most elegant formula is useless if it chokes the memory bus of the hardware it runs on.

By understanding the physical architecture of the GPU - the hierarchy of HBM and SRAM, the cost of data movement vs computation - researchers unlocked the era of million-token context windows. The math built the Transformer, but hardware awareness allowed it to conquer the world.

The implications extend far beyond attention:

  • KV cache management (PagedAttention) applies the same IO-aware thinking to inference
  • Activation checkpointing trades compute for memory during training
  • Quantization (INT8, INT4, FP8) reduces data movement by shrinking the data itself
  • Speculative decoding overlaps draft and verification to hide latency

Every frontier of LLM efficiency is, at its core, about matching the algorithm to the machine.