Mar 26, 2026

Flash Attention - Breaking the Memory Wall

How Flash Attention bypasses GPU memory bottlenecks by never materializing the N×N attention matrix, and why understanding hardware matters as much as understanding math.

Flash Attention - Breaking the Memory Wall hero image

At a sequence length of 100,000 tokens, computing standard attention requires a single intermediate matrix of 20 GB - and that's just one head, one layer, one sequence. The GPU runs out of memory before a single forward pass completes. The Transformer's math is elegant. Its memory requirements are catastrophic.

This is the story of the Memory Wall, and the brilliantly simple engineering insight that broke through it: Flash Attention (Dao et al., 2022).

The $O(N^2)$ Problem

At the heart of the Transformer:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Beautiful math. But look at the dimensions. If $N$ is the sequence length, both $Q$ and $K$ have $N$ rows. When you multiply $Q$ by $K^T$, the result is an $N \times N$ matrix - the attention score between every token and every other token.

The numbers escalate fast:

  • $N = 1{,}000$: $1{,}000 \times 1{,}000 = 1\text{M}$ elements. Fine.
  • $N = 100{,}000$ (a short novel): $100{,}000 \times 100{,}000 = 10\text{B}$ elements.

At fp16 (2 bytes per element), that intermediate matrix alone requires 20 GB of memory. For one head, one sequence, one layer.

You can watch this play out directly:

python
import torch
import torch.nn.functional as F

def naive_attention(q, k, v):
    """Standard attention - materializes the full N×N matrix."""
    scores = torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
    attn_weights = F.softmax(scores, dim=-1)
    return torch.matmul(attn_weights, v)

batch, heads, head_dim = 1, 12, 64

for N in [1024, 4096, 16384]:
    torch.cuda.reset_peak_memory_stats()
    q = torch.randn(batch, heads, N, head_dim, device="cuda")
    k = torch.randn(batch, heads, N, head_dim, device="cuda")
    v = torch.randn(batch, heads, N, head_dim, device="cuda")
    out = naive_attention(q, k, v)
    peak_mb = torch.cuda.max_memory_allocated() / (1024 ** 2)
    print(f"N={N:5d} | Peak Memory: {peak_mb:,.0f} MB")

Memory quadruples every time you double the sequence length. At $N = 32{,}768$, a standard GPU throws OutOfMemoryError and crashes. The math works, but the hardware physically cannot hold the intermediate matrix.

This curve tells you something important: the problem isn't the final output. The output matrix has size $N \times d$, which is linear in sequence length. The problem is the intermediate $N \times N$ matrix - the thing we compute and immediately throw away. What if we never materialized it?

The Hardware Reality

To understand why that question matters, you need to understand how a GPU is actually structured. It isn't a single bucket of memory - it has a hierarchy:

HBM (High Bandwidth Memory): The main memory. Large (40-80 GB on an A100), but physically distant from the compute cores. Moving data in and out is slow.

SRAM (Shared Memory): The working memory. Sits directly next to the compute cores. Blazingly fast, but tiny - about 192 KB per streaming multiprocessor, roughly 20 MB in total across an A100.
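
You can query the HBM side of this hierarchy directly from PyTorch; the SRAM figure below is hard-coded from the published A100 spec (192 KB per SM), since PyTorch doesn't expose it - treat it as an assumption for other GPUs:

python
import torch

props = torch.cuda.get_device_properties(0)
hbm_gb = props.total_memory / 1024**3

# PyTorch doesn't report on-chip SRAM; ~192 KB of shared memory per SM
# is the published A100 figure (assumption - adjust for your GPU).
sram_mb = props.multi_processor_count * 192 / 1024

print(f"{props.name}: {hbm_gb:.0f} GB HBM, "
      f"~{sram_mb:.0f} MB SRAM across {props.multi_processor_count} SMs")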

When PyTorch runs naive attention, it executes operations one at a time. Each operation reads inputs from HBM and writes outputs back to HBM:

  1. Read $Q$ and $K$ from slow HBM into fast SRAM
  2. Compute $S = QK^T$ in SRAM
  3. Write the massive $N \times N$ matrix $S$ back to slow HBM (bottleneck)
  4. Read $S$ back from HBM into SRAM (bottleneck)
  5. Compute softmax $P = \text{softmax}(S)$ in SRAM
  6. Write the massive $P$ matrix back to HBM (bottleneck)
  7. Read $P$ and $V$ from HBM into SRAM (bottleneck)
  8. Compute output and write to HBM

The compute cores sit idle while gigabytes of intermediate data travel back and forth across a slow bus. This is called being memory-bound: the bottleneck isn't computation, it's data movement. The GPU's arithmetic units could compute attention much faster - they're just waiting for data that never arrives quickly enough.

Profiling a naive attention forward pass confirms this. The GPU's compute utilization is embarrassingly low - most wall-clock time is spent on memory transfers, not on the matrix multiplications we actually care about.
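
A back-of-the-envelope estimate makes the imbalance concrete. This sketch uses rough, publicly quoted A100 figures (~2 TB/s HBM bandwidth, ~312 fp16 TFLOP/s), so treat the constants as assumptions rather than measurements:

python
# Rough HBM traffic vs. compute for naive attention, one head, fp16
N, d, bytes_per_el = 16_384, 64, 2

s_bytes = N * N * bytes_per_el      # one N×N score matrix in fp16
traffic = 4 * s_bytes               # write S, read S, write P, read P
flops = 4 * N * N * d               # QK^T plus PV (2*N*N*d each)

hbm_bw = 2.0e12                     # ~2 TB/s HBM bandwidth (A100 80GB, rough)
fp16_peak = 312e12                  # ~312 TFLOP/s fp16 tensor cores (A100)

print(f"intermediate HBM traffic: {traffic / 1e9:.1f} GB")
print(f"time moving that data:    {traffic / hbm_bw * 1e3:.2f} ms")
print(f"time doing the math:      {flops / fp16_peak * 1e3:.2f} ms")

Even with generous assumptions, the data movement takes several times longer than the arithmetic - exactly what the profiler shows.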

The fix isn't to make the transfers faster. It's to do fewer of them.

The Solution: Tiling

Flash Attention's core idea is deceptively simple: what if we never write the $N \times N$ matrix to HBM at all?

If we can compute attention scores, apply softmax, and multiply by $V$ entirely inside the fast SRAM, we bypass all those round trips. The problem is that SRAM totals only about 20 MB. The full $N \times N$ matrix won't fit there either.

The answer is tiling. Break $Q$, $K$, and $V$ into small blocks that do fit in SRAM, and process them one tile at a time:

  1. Load a block of Queries into SRAM
  2. Iterate over blocks of Keys and Values:
    • Load a K block and V block into SRAM
    • Compute a small patch of the attention matrix (fits in SRAM)
    • Multiply by the V block
    • Accumulate the result
  3. Write only the final output back to HBM

The $N \times N$ matrix is never fully materialized - not in HBM, not in SRAM, not anywhere. Each small block is computed, used, and discarded before the next block is loaded.
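
How small do the tiles need to be? A quick footprint check, using the block sizes the kernel later in this post will use (BLOCK_M = BLOCK_N = 64, head_dim = 64, fp16 inputs with an fp32 accumulator - numbers chosen here for illustration):

python
# Per-tile SRAM footprint for the block sizes used below
BLOCK_M, BLOCK_N, head_dim, fp16 = 64, 64, 64, 2

q_tile     = BLOCK_M * head_dim * fp16       # one block of Q
kv_tiles   = 2 * BLOCK_N * head_dim * fp16   # one block each of K and V
score_tile = BLOCK_M * BLOCK_N * fp16        # the small attention patch
acc_tile   = BLOCK_M * head_dim * 4          # fp32 output accumulator

total_kb = (q_tile + kv_tiles + score_tile + acc_tile) / 1024
print(f"~{total_kb:.0f} KB per tile set - fits easily in per-SM SRAM")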

The online softmax trick

There's one mathematical obstacle. Standard softmax requires the maximum value of the entire row before computing anything:

$$\text{softmax}(x_i) = \frac{e^{x_i - \max(\mathbf{x})}}{\sum_j e^{x_j - \max(\mathbf{x})}}$$

When processing tiles, we don't have the full row - we only see one block at a time. We can't compute the global maximum until we've seen everything. But by then, we've already discarded the earlier blocks.

Flash Attention uses online softmax (Milakov & Gimelshein, 2018) to resolve this: maintain a running maximum $m$ and a running sum of exponentials $d$. As each new block arrives, rescale the previous accumulator to account for the updated maximum, then fold in the new block.

python
def online_softmax_step(m_prev, d_prev, acc_prev, x_new, v_new):
    """
    Process one new tile of attention scores (x_new) and values (v_new),
    updating the running softmax state without needing the full row.
    """
    # New running maximum after seeing this tile
    m_new = torch.maximum(m_prev, x_new.max(dim=-1, keepdim=True).values)

    # Rescale previous accumulator: earlier exponentials used the old max,
    # so we multiply by exp(m_prev - m_new) to correct them
    scale = torch.exp(m_prev - m_new)
    d_new = scale * d_prev + torch.exp(x_new - m_new).sum(dim=-1, keepdim=True)

    # Update the weighted value accumulator with the same correction
    acc_new = scale * acc_prev + torch.exp(x_new - m_new) @ v_new

    return m_new, d_new, acc_new

# Final output: divide accumulator by the normalization constant
# output = acc / d -> mathematically identical to standard softmax

The final result is mathematically identical to standard softmax - the paper includes a formal proof. The trick is that we never need all $N$ scores at once. We process them tile by tile, correcting our running estimate each time. This is what makes tiling possible: it turns a two-pass algorithm (find max, then normalize) into a single pass.
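
Putting tiling and the online softmax together in plain PyTorch gives a reference loop - slow Python, not a kernel, but it never forms the full $N \times N$ score matrix. The tiled_attention name and block size are introduced here purely for illustration:

python
def tiled_attention(q, k, v, block_size=256):
    """Reference tiled attention: walks K/V in blocks, folding each tile
    into the running softmax state via online_softmax_step."""
    N, d = q.shape[-2], q.shape[-1]
    m = torch.full((*q.shape[:-1], 1), float("-inf"), device=q.device)
    den = torch.zeros_like(m)
    acc = torch.zeros_like(q)
    for start in range(0, N, block_size):
        k_blk = k[..., start:start + block_size, :]
        v_blk = v[..., start:start + block_size, :]
        scores = q @ k_blk.transpose(-2, -1) / (d ** 0.5)
        m, den, acc = online_softmax_step(m, den, acc, scores, v_blk)
    return acc / den

# Matches the naive implementation up to floating-point error
q = torch.randn(1, 2, 1024, 64, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
print(torch.allclose(tiled_attention(q, k, v), naive_attention(q, k, v), atol=1e-4))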

Implementation: From PyTorch to GPU Kernels

Standard PyTorch abstracts away memory management. To control what goes into SRAM vs HBM, we need to write a GPU kernel. Triton (developed by OpenAI) lets us write GPU kernels in Python:

python
import triton
import triton.language as tl

@triton.jit
def flash_attention_kernel(
    Q_ptr, K_ptr, V_ptr, Out_ptr,
    stride_qb, stride_qh, stride_qs, stride_qd,
    stride_kb, stride_kh, stride_ks, stride_kd,
    stride_vb, stride_vh, stride_vs, stride_vd,
    stride_ob, stride_oh, stride_os, stride_od,
    num_heads, seq_len,
    # head_dim must be a compile-time constant for tl.arange / tl.zeros
    head_dim: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # Which block of Queries does this program instance handle?
    batch_idx = tl.program_id(0)
    head_idx = tl.program_id(1)
    block_m_idx = tl.program_id(2)

    # Calculate memory offsets for this batch/head
    q_off = batch_idx * stride_qb + head_idx * stride_qh
    k_off = batch_idx * stride_kb + head_idx * stride_kh
    v_off = batch_idx * stride_vb + head_idx * stride_vh
    o_off = batch_idx * stride_ob + head_idx * stride_oh

    # Load one block of Queries into SRAM
    offs_m = block_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, head_dim)
    q_ptrs = Q_ptr + q_off + offs_m[:, None] * stride_qs + offs_d[None, :] * stride_qd
    q_block = tl.load(q_ptrs)

    # Online-softmax state: running max, running normalizer, and the output
    # accumulator. All of it lives in SRAM/registers - it never touches HBM.
    m_i = tl.full((BLOCK_M,), float("-inf"), dtype=tl.float32)
    d_i = tl.zeros((BLOCK_M,), dtype=tl.float32)
    acc = tl.zeros((BLOCK_M, head_dim), dtype=tl.float32)

    # Iterate over blocks of K and V
    # (assumes seq_len divisible by BLOCK_N; boundary masking omitted for clarity)
    for block_n_idx in range(0, seq_len // BLOCK_N):
        offs_n = block_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)

        # Load K and V tiles from HBM to SRAM
        k_ptrs = K_ptr + k_off + offs_n[:, None] * stride_ks + offs_d[None, :] * stride_kd
        v_ptrs = V_ptr + v_off + offs_n[:, None] * stride_vs + offs_d[None, :] * stride_vd
        k_block = tl.load(k_ptrs)
        v_block = tl.load(v_ptrs)

        # Attention scores for this tile - entirely in SRAM
        qk = tl.dot(q_block, tl.trans(k_block)) / (head_dim ** 0.5)

        # Online softmax update: fold this tile into the running state
        m_new = tl.maximum(m_i, tl.max(qk, axis=1))
        scale = tl.exp(m_i - m_new)
        p = tl.exp(qk - m_new[:, None])
        d_i = scale * d_i + tl.sum(p, axis=1)
        acc = scale[:, None] * acc + tl.dot(p.to(v_block.dtype), v_block)
        m_i = m_new

    # Normalize and write only the final output from SRAM back to HBM
    acc = acc / d_i[:, None]
    o_ptrs = Out_ptr + o_off + offs_m[:, None] * stride_os + offs_d[None, :] * stride_od
    tl.store(o_ptrs, acc.to(Out_ptr.dtype.element_ty))

Look closely: there is no variable for the $N \times N$ matrix. The small qk block is computed, used to update acc, and overwritten on the next iteration. The full attention matrix never exists in memory.

The wrapper

python
def triton_flash_attention(q, k, v):
    batch, heads, seq_len, head_dim = q.shape
    BLOCK_M, BLOCK_N = 64, 64
    out = torch.empty_like(q)
    grid = (batch, heads, triton.cdiv(seq_len, BLOCK_M))
    flash_attention_kernel[grid](
        q, k, v, out,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        out.stride(0), out.stride(1), out.stride(2), out.stride(3),
        heads, seq_len, head_dim,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
    )
    return out
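
Before benchmarking, it's worth checking the kernel against the naive implementation. A minimal sanity check, assuming an fp16-capable GPU and a sequence length divisible by the block size (the kernel above skips boundary masking):

python
q = torch.randn(1, 12, 4096, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

ref = naive_attention(q, k, v)
out = triton_flash_attention(q, k, v)

# fp16 inputs with fp32 accumulation inside the kernel: compare loosely
print(torch.allclose(out, ref, atol=1e-2, rtol=1e-2))
print(f"max abs diff: {(out - ref).abs().max().item():.4f}")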

Benchmarking the Difference

At $N = 16{,}384$, the gap becomes concrete:

Memory: Naive attention demands gigabytes for the intermediate $N \times N$ matrix. Flash Attention's memory footprint stays nearly flat regardless of sequence length - only the final output grows with $N$.

Speed: With compute cores no longer waiting for HBM round trips, the kernel runs 2-4x faster wall-clock time on the same hardware. The computation hasn't changed. The data movement has.

python
import time

def benchmark(fn, q, k, v, warmup=5, iters=20):
    for _ in range(warmup):
        fn(q, k, v)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(q, k, v)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1000  # ms

N = 16384
q = torch.randn(1, 12, N, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

naive_ms = benchmark(naive_attention, q, k, v)
flash_ms = benchmark(triton_flash_attention, q, k, v)
print(f"Naive: {naive_ms:.1f} ms")
print(f"Flash: {flash_ms:.1f} ms ({naive_ms / flash_ms:.1f}x faster)")

Modern PyTorch exposes this directly via torch.nn.functional.scaled_dot_product_attention, which automatically dispatches to an optimized Flash Attention kernel when the shapes, dtypes, and hardware allow it. But now you know exactly what it's doing and why.
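
In practice it's a single call. A minimal sketch - the sdpa_kernel context manager used to pin the Flash backend exists in recent PyTorch releases (2.3+), so treat that part as optional:

python
import torch
import torch.nn.functional as F

q = torch.randn(1, 12, 16384, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Dispatches to a fused Flash-Attention-style kernel when the shapes,
# dtypes, and hardware allow it; falls back to other backends otherwise.
out = F.scaled_dot_product_attention(q, k, v)

# Optional: pin the backend explicitly (PyTorch 2.3+; assumption for older releases)
from torch.nn.attention import sdpa_kernel, SDPBackend
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)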

Flash Attention 2 and 3

The original paper proved the idea. The follow-up papers pushed the implementation toward the GPU's physical limits.

Flash Attention 2 (Dao, 2023) achieved roughly a 2x speedup over Flash Attention 1 by cutting down the non-matmul work - the rescaling and bookkeeping around the online softmax - which runs far slower per FLOP than tensor-core matmuls. It also parallelized across the sequence-length dimension, not just across batches and heads, which matters when batch sizes are small but sequences are long. On an A100, it reaches 50-73% of theoretical peak throughput for attention.

Flash Attention 3 (Shah et al., 2024) targets the Hopper architecture (H100) specifically. The H100 introduced dedicated hardware for asynchronous data movement - you can start loading the next tile from HBM while the compute cores are still working on the current one. Flash Attention 3 exploits this by:

  • Overlapping asynchronous HBM-to-SRAM data movement with computation via warp specialization
  • Using FP8 tensor cores, halving the data volume for compatible workloads
  • Scheduling warps to hide instruction latency behind useful work

Each version pushes closer to the theoretical hardware limit. The algorithm hasn't changed - it's still tiling with online softmax. What changes is how precisely the implementation matches the GPU's physical capabilities.

This is a pattern worth recognizing: the most impactful systems work happens not by inventing new algorithms, but by mapping existing ones more faithfully onto hardware reality. Flash Attention didn't change the math. It changed the execution order, and that was enough to unlock million-token context windows.

The same IO-aware thinking now runs through nearly every corner of efficient LLM inference: KV cache management (PagedAttention), activation checkpointing, quantization strategies that reduce data volume before it hits the memory bus. Once you see the memory hierarchy clearly, you start seeing the same optimization opportunity everywhere.