At a sequence length of 100,000 tokens, computing standard attention requires a single intermediate matrix of 20 GB - and that's just one head, one layer, one sequence. The GPU runs out of memory before a single forward pass completes. The Transformer's math is elegant. Its memory requirements are catastrophic.
This is the story of the Memory Wall, and the brilliantly simple engineering insight that broke through it: Flash Attention (Dao et al., 2022).
The Problem
At the heart of the Transformer sits scaled dot-product attention:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Beautiful math. But look at the dimensions. If $N$ is the sequence length, both $Q$ and $K$ have $N$ rows. When you multiply $Q$ by $K^T$, the result is an $N \times N$ matrix - the attention score between every token and every other token.
The numbers escalate fast:
- $N = 1{,}000$: $10^6$ elements. Fine.
- $N = 100{,}000$ (a short novel): $10^{10}$ elements.
At fp16 (2 bytes per element), that intermediate matrix alone requires 20 GB of memory. For one head, one sequence, one layer.
You can watch this play out directly:
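A back-of-the-envelope script makes the quadratic blow-up concrete (pure Python; fp16 at 2 bytes per element, one head of one layer):

```python
# Memory needed for the N x N attention-score matrix at fp16 (2 bytes/element),
# for a single head of a single layer.
BYTES_PER_ELEMENT = 2  # fp16

def score_matrix_bytes(n: int) -> int:
    return n * n * BYTES_PER_ELEMENT

n = 1_000
while n <= 128_000:
    gb = score_matrix_bytes(n) / 1e9
    print(f"N = {n:>7,}: {gb:10.3f} GB")
    n *= 2
```

Each doubling of $N$ quadruples the footprint; at $N = 100{,}000$ the score matrix alone needs $2 \times 10^{10}$ bytes, i.e. the 20 GB quoted above.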
Memory quadruples every time you double the sequence length. At $N = 100{,}000$, a standard GPU throws OutOfMemoryError and crashes. The math works, but the hardware physically cannot hold the intermediate matrix.
This curve tells you something important: the problem isn't the final output. The output matrix has size $N \times d$, which is linear in sequence length. The problem is the intermediate $N \times N$ matrix - the thing we compute and immediately throw away. What if we never materialized it?
The Hardware Reality
To understand why that question matters, you need to understand how a GPU is actually structured. It isn't a single bucket of memory - it has a hierarchy:
HBM (High Bandwidth Memory): The main memory. Large (40-80 GB on an A100), but physically distant from the compute cores. Moving data in and out is slow.
SRAM (Shared Memory): The working memory. Sits directly next to the compute cores. Blazingly fast, but tiny - roughly 192 KB per streaming multiprocessor, about 20 MB in total on an A100.
When PyTorch runs naive attention, it executes operations one at a time. Each operation reads inputs from HBM and writes outputs back to HBM:
- Read $Q$ and $K$ from slow HBM into fast SRAM
- Compute $S = QK^T$ in SRAM
- Write the massive $N \times N$ matrix $S$ back to slow HBM (bottleneck)
- Read $S$ back from HBM into SRAM (bottleneck)
- Compute $P = \text{softmax}(S)$ in SRAM
- Write the massive matrix $P$ back to HBM (bottleneck)
- Read $P$ and $V$ from HBM into SRAM (bottleneck)
- Compute output $O = PV$ and write to HBM
The compute cores sit idle while gigabytes of intermediate data travel back and forth across a slow bus. This is called being memory-bound: the bottleneck isn't computation, it's data movement. The GPU's arithmetic units could compute attention much faster - they're just waiting for data that never arrives quickly enough.
Profiling a naive attention forward pass confirms this. The GPU's compute utilization is embarrassingly low - most wall-clock time is spent on memory transfers, not on the matrix multiplications we actually care about.
The fix isn't to make the transfers faster. It's to do fewer of them.
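One way to quantify "memory-bound" is arithmetic intensity: FLOPs performed per byte moved between HBM and the cores. A rough roofline-style estimate for the naive pipeline above (counting only the dominant $N^2$ terms; head dimension and A100 peak figures are assumptions taken from published specs, the constants are for intuition only) might look like:

```python
# Roofline-style estimate: is naive attention compute- or memory-bound?
# Counts only the dominant N^2 terms; constants are rough, for intuition only.

def naive_attention_intensity(n: int, d: int = 64) -> float:
    flops = 4 * n * n * d          # QK^T and PV matmuls: 2*N^2*d FLOPs each
    bytes_moved = 4 * n * n * 2    # S and P each written then read, at fp16
    return flops / bytes_moved

# A100 peak: ~312 TFLOPS fp16 vs ~2 TB/s HBM bandwidth.
ridge_point = 312e12 / 2e12        # FLOP/byte needed to be compute-bound

intensity = naive_attention_intensity(100_000)
print(f"naive attention:  {intensity:.0f} FLOP/byte")
print(f"A100 ridge point: {ridge_point:.0f} FLOP/byte")
print("memory-bound" if intensity < ridge_point else "compute-bound")
```

At roughly 32 FLOP/byte against a ridge point near 156, the arithmetic units can run at only a fraction of peak; the fix has to reduce bytes moved, not FLOPs.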
The Solution: Tiling
Flash Attention's core idea is deceptively simple: what if we never write the $N \times N$ score matrix to HBM at all?
If we can compute attention scores, apply softmax, and multiply by $V$ entirely inside the fast SRAM, we bypass all those round trips. The problem is that SRAM holds only about 20 MB. The full $N \times N$ matrix won't fit there either.
The answer is tiling. Break $Q$, $K$, and $V$ into small blocks that do fit in SRAM, and process them one tile at a time:
- Load a block of Queries into SRAM
- Iterate over blocks of Keys and Values:
- Load a K block and V block into SRAM
- Compute a small patch of the attention matrix (fits in SRAM)
- Multiply by the V block
- Accumulate the result
- Write only the final output back to HBM
The $N \times N$ matrix is never fully materialized - not in HBM, not in SRAM, not anywhere. Each small block is computed, used, and discarded before the next block is loaded.
The online softmax trick
There's one mathematical obstacle. Numerically stable softmax requires the maximum value of the entire row before computing anything (subtracting the max keeps the exponentials from overflowing):

$$\text{softmax}(x_i) = \frac{e^{x_i - \max_j x_j}}{\sum_j e^{x_j - \max_j x_j}}$$

When processing tiles, we don't have the full row - we only see one block at a time. We can't compute the global maximum until we've seen everything. But by then, we've already discarded the earlier blocks.
Flash Attention uses online softmax (Milakov & Gimelshein, 2018) to resolve this: maintain a running maximum $m$ and a running sum of exponentials $\ell$. As each new block arrives, rescale the previous accumulator by $e^{m_{\text{old}} - m_{\text{new}}}$ to account for the updated maximum, then fold in the new block.
The final result is mathematically identical to standard softmax - the paper includes a formal proof. The trick is that we never need all scores at once. We process them tile by tile, correcting our running estimate each time. This is what makes tiling possible: it turns a two-pass algorithm (find max, then normalize) into a single pass.
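The update rule fits in a few lines. Here is a NumPy sketch of online softmax over one row of scores, seen in blocks (the running maximum `m` and running sum `l` follow the description above; the implementation details are mine):

```python
import numpy as np

def online_softmax(row: np.ndarray, block: int = 4) -> np.ndarray:
    """Softmax over `row` while only ever seeing `block` scores at a time."""
    m = -np.inf              # running maximum
    l = 0.0                  # running sum of exponentials
    out = np.zeros(0)        # unnormalized numerators emitted so far
    for start in range(0, len(row), block):
        x = row[start:start + block]
        m_new = max(m, float(x.max()))
        scale = np.exp(m - m_new)        # correction for the updated maximum
        l = l * scale + np.exp(x - m_new).sum()
        out = np.concatenate([out * scale, np.exp(x - m_new)])
        m = m_new
    return out / l           # one normalization at the very end

# Identical to ordinary (stable) softmax:
row = np.array([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0])
ref = np.exp(row - row.max()) / np.exp(row - row.max()).sum()
assert np.allclose(online_softmax(row, block=3), ref)
```

In the real kernel the per-score numerators are never stored like this; the same rescaling is applied directly to the output accumulator, so only $O(N)$ running statistics survive between tiles.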
Implementation: From PyTorch to GPU Kernels
Standard PyTorch abstracts away memory management. To control what goes into SRAM vs HBM, we need to write a GPU kernel. Triton (developed by OpenAI) lets us write GPU kernels in Python:
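A production Triton kernel adds pointer arithmetic and `tl.load`/`tl.store` calls on SRAM tiles; to keep the structure visible, here is a NumPy stand-in for the kernel's inner loop, handling one tile of queries (the names `qk` and `acc` mirror common Flash Attention kernel conventions; block size and shapes are illustrative):

```python
import numpy as np

def attention_block(q, K, V, block: int = 64):
    """Attention output for one tile of queries q (shape Bq x d),
    streaming over K and V (shape N x d) one block at a time."""
    d = q.shape[1]
    m = np.full(q.shape[0], -np.inf)   # running row maxima
    l = np.zeros(q.shape[0])           # running softmax denominators
    acc = np.zeros_like(q)             # unnormalized output accumulator
    for start in range(0, K.shape[0], block):
        k = K[start:start + block]     # "load a K tile into SRAM"
        v = V[start:start + block]     # "load a V tile into SRAM"
        qk = q @ k.T / np.sqrt(d)      # small score patch, Bq x block
        m_new = np.maximum(m, qk.max(axis=1))
        scale = np.exp(m - m_new)      # rescale everything seen so far
        p = np.exp(qk - m_new[:, None])
        l = l * scale + p.sum(axis=1)
        acc = acc * scale[:, None] + p @ v
        m = m_new                      # qk and p die here; never stored
    return acc / l[:, None]            # normalize once, write to "HBM"

# Matches standard attention exactly:
rng = np.random.default_rng(0)
q, K, V = rng.normal(size=(8, 16)), rng.normal(size=(128, 16)), rng.normal(size=(128, 16))
s = q @ K.T / np.sqrt(16)
p = np.exp(s - s.max(axis=1, keepdims=True))
ref = (p / p.sum(axis=1, keepdims=True)) @ V
assert np.allclose(attention_block(q, K, V, block=32), ref)
```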
Look closely: there is no variable for the full $N \times N$ score matrix. The small `qk` block is computed, used to update `acc`, and overwritten on the next iteration. The full attention matrix never exists in memory.
The wrapper
On the host side, a thin wrapper allocates the output tensor, picks block sizes that fit in SRAM, and launches one kernel instance per query block (and per batch element and head).
Benchmarking the Difference
At long sequence lengths, the gap becomes concrete:
Memory: Naive attention demands gigabytes for the intermediate matrix. Flash Attention's memory footprint stays nearly flat regardless of sequence length - only the final output grows with $N$.
Speed: With compute cores no longer waiting for HBM round trips, the kernel runs 2-4x faster wall-clock time on the same hardware. The computation hasn't changed. The data movement has.
Modern PyTorch exposes this directly via torch.nn.functional.scaled_dot_product_attention, which automatically dispatches to an optimized Flash Attention kernel. But now you know exactly what it's doing and why.
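A minimal usage sketch (shapes are illustrative; on CPU this falls back to a math kernel, on supported GPUs PyTorch dispatches to a fused Flash Attention backend):

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) - SDPA picks the best available backend.
q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

out = F.scaled_dot_product_attention(q, k, v)

# Same result as the textbook formula - but fused backends never
# materialize the 1024 x 1024 score matrix in global memory.
scores = (q @ k.transpose(-2, -1)) / (64 ** 0.5)
ref = torch.softmax(scores, dim=-1) @ v
assert torch.allclose(out, ref, atol=1e-5)
```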
Flash Attention 2 and 3
The original paper proved the idea. The follow-up papers pushed the implementation toward the GPU's physical limits.
Flash Attention 2 (Dao, 2023) achieved roughly a 2x speedup over the original by cutting non-matmul work - the online-softmax bookkeeping - and by parallelizing across the sequence-length dimension, not just across batches and heads, which matters when batch sizes are small but sequences are long. On an A100, it reaches 50-73% of theoretical peak throughput for attention.
Flash Attention 3 (Dao et al., 2024) targets the Hopper architecture (H100) specifically. The H100 introduced dedicated hardware for asynchronous data movement - you can start loading the next tile from HBM while the compute cores are still working on the current one. Flash Attention 3 exploits this by:
- Overlapping data movement with computation via warp specialization: producer warps fetch the next tiles while consumer warps compute on the current ones
- Using FP8 tensor cores, halving the bytes moved per element for compatible workloads
- Scheduling warps to hide instruction latency behind useful work
Each version pushes closer to the theoretical hardware limit. The algorithm hasn't changed - it's still tiling with online softmax. What changes is how precisely the implementation matches the GPU's physical capabilities.
This is a pattern worth recognizing: the most impactful systems work happens not by inventing new algorithms, but by mapping existing ones more faithfully onto hardware reality. Flash Attention didn't change the math. It changed the execution order, and that was enough to unlock million-token context windows.
The same IO-aware thinking now runs through nearly every corner of efficient LLM inference: KV cache management (PagedAttention), activation checkpointing, quantization strategies that reduce data volume before it hits the memory bus. Once you see the memory hierarchy clearly, you start seeing the same optimization opportunity everywhere.
