We spend so much time thinking about the math - matrix multiplications, gradients, loss functions - that we forget about the physical reality of the computers running them. The Transformer is an algorithmic masterpiece. But algorithms don't run in a theoretical void. They run on silicon. And when you scale the elegant math of attention to read entire books, it collides violently with the physical limits of hardware.
This is the story of the Memory Wall, and the brilliantly simple engineering hack that broke through it: Flash Attention (Dao et al., 2022).
The Nightmare
At the heart of the Transformer:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

Beautiful math. But look at the dimensions. If $N$ is the sequence length, both $Q$ and $K$ have $N$ rows. When you multiply $Q$ by $K^\top$, the result is an $N \times N$ matrix - the attention score between every token and every other token.
Some terrifying arithmetic for a single attention head:
- $N = 1{,}000$: $N^2 = 10^6$ elements. Fine.
- $N = 100{,}000$ (a short novel): $N^2 = 10^{10}$ elements.
At fp16 (2 bytes per element), that intermediate matrix alone requires 20 GB of memory. For one head, one sequence, one layer.
Proving it
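A few lines of plain Python make the wall concrete (no GPU needed; the helper name attn_matrix_gb is mine, chosen for illustration):

```python
def attn_matrix_gb(n, bytes_per_elem=2):
    """Memory in GB for one N x N attention score matrix at fp16 (2 bytes)."""
    return n * n * bytes_per_elem / 1e9

for n in [1_000, 10_000, 100_000]:
    print(f"N = {n:>7,}: {attn_matrix_gb(n):>10.3f} GB")
# N =   1,000:      0.002 GB
# N =  10,000:      0.200 GB
# N = 100,000:     20.000 GB
```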
Memory quadruples every time you double the sequence length. Well before $N$ reaches novel length, a standard GPU throws OutOfMemoryError and crashes. The math works, but the hardware physically cannot hold the intermediate $N \times N$ matrix.
The Hardware Reality
To understand Flash Attention, you need to understand how a GPU is actually built. It's not a single bucket of memory - it has a hierarchy:
HBM (High Bandwidth Memory): The "main memory." Large (40–80 GB on an A100), but physically distant from the compute cores. Moving data in and out is slow - on the order of 2 TB/s.
SRAM (Shared Memory): The "working memory." Sits directly next to the compute cores. Blazingly fast - roughly an order of magnitude more bandwidth than HBM - but tiny: about 192 KB per streaming multiprocessor, only ~20 MB across the whole chip.
When PyTorch runs naive attention, it doesn't know our intent. It just executes operations sequentially:
- Read $Q$ and $K$ from slow HBM into fast SRAM
- Compute $S = QK^\top$ in SRAM
- Write the massive $N \times N$ matrix $S$ back to slow HBM (bottleneck)
- Read $S$ back from HBM into SRAM (bottleneck)
- Compute $P = \text{softmax}(S)$ in SRAM
- Write the massive $N \times N$ matrix $P$ back to HBM (bottleneck)
- Read $P$ and $V$ from HBM into SRAM (bottleneck)
- Compute output $O = PV$ and write to HBM
The compute cores sit idle - doing zero math - while gigabytes of intermediate data slowly travel back and forth. This is called being memory-bound. The bottleneck isn't computation; it's data movement.
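A back-of-the-envelope roofline check shows how lopsided this is. The numbers below are rough and illustrative: an A100 delivers about 312 TFLOPS of fp16 compute against about 2 TB/s of HBM bandwidth, and the "~5 ops per element" count for softmax is a loose estimate:

```python
# Rough roofline estimate for the softmax step of naive attention at fp16.
N, bytes_per_elem = 100_000, 2
flops = 5 * N * N                          # max, subtract, exp, sum, divide: ~5 ops/element
bytes_moved = 2 * N * N * bytes_per_elem   # read S from HBM, write P back to HBM
intensity = flops / bytes_moved            # FLOPs performed per byte moved

# FLOPs per byte the GPU needs to keep its compute units busy:
balance = 312e12 / 2e12                    # ~156 FLOPs/byte on an A100 (rough)

print(f"arithmetic intensity: {intensity:.2f} FLOPs/byte (need ~{balance:.0f})")
```

The softmax step performs about 1 FLOP per byte it moves, two orders of magnitude below what the hardware needs to stay busy.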
The Solution: Tiling
Flash Attention's premise is a masterclass in IO-awareness: what if we never write the $N \times N$ matrix to HBM at all?
If we can compute attention scores and multiply by $V$ entirely inside the fast SRAM, we bypass the memory wall completely. But SRAM is only ~20 MB - we can't fit the full $N \times N$ matrix there either.
The solution is tiling. Break $Q$, $K$, and $V$ into small blocks (tiles) that fit in SRAM:
- Load a block of Queries into SRAM
- Iterate over blocks of Keys and Values:
- Load a K block and V block into SRAM
- Compute a small patch of the attention matrix (fits in SRAM)
- Multiply by the V block
- Accumulate the result
- Write only the final output back to HBM
The $N \times N$ matrix is never fully materialized - not in HBM, not in SRAM, not anywhere. Each small block is computed, used, and discarded before the next block is loaded.
The online softmax trick
There's one mathematical subtlety. Standard softmax requires the maximum value of the entire row for numerical stability:

$$\text{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \qquad m = \max_j x_j$$
When processing tiles, we don't have the full row - we only see one block at a time. Flash Attention uses online softmax (Milakov & Gimelshein, 2018): maintain a running maximum and a running sum of exponentials. As each new block is processed, rescale the previous accumulator to account for the updated maximum. The final result is mathematically identical to standard softmax, but computed incrementally.
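The trick is easiest to see on a single vector. Here is a minimal NumPy sketch (the function name and block size are mine), carrying a running maximum m and a running exponential sum l across blocks:

```python
import numpy as np

def online_softmax(x, block=4):
    """Softmax computed one block at a time, without ever seeing the full row."""
    m, l = -np.inf, 0.0                       # running max, running exp-sum
    for j in range(0, len(x), block):
        xb = x[j:j + block]
        m_new = max(m, xb.max())
        # Rescale the old sum to the updated maximum, then add the new block.
        l = l * np.exp(m - m_new) + np.exp(xb - m_new).sum()
        m = m_new
    return np.exp(x - m) / l                  # identical to standard softmax
```

Flash Attention applies the same rescaling factor $e^{m_{\text{old}} - m_{\text{new}}}$ to its output accumulator instead of re-reading $x$, so the pass stays truly single-shot.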
Implementation: From PyTorch to GPU Kernels
Standard PyTorch abstracts away memory management. To control what goes into SRAM vs HBM, we need to write a GPU kernel. Triton (developed by OpenAI) lets us write GPU kernels in Python:
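The actual kernel needs a GPU to execute, but its loop structure can be sketched faithfully in NumPy. Below, one "program" handles one block of query rows, exactly as a Triton program instance would; the names qk, acc, m_i, and l_i mirror the conventions of the Triton fused-attention tutorial. This is a structural sketch under those assumptions, not the real kernel:

```python
import numpy as np

def attention_inner(q, K, V, block_k=16):
    """What one kernel instance does for one block of query rows q.
    The score tile qk exists only inside the loop body."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    acc = np.zeros((q.shape[0], V.shape[-1]))    # output accumulator ("in SRAM")
    m_i = np.full(q.shape[0], -np.inf)           # running row maxima
    l_i = np.zeros(q.shape[0])                   # running exp-sums
    for s in range(0, K.shape[0], block_k):
        k, v = K[s:s + block_k], V[s:s + block_k]   # load one K/V tile
        qk = (q @ k.T) * scale                      # small score tile only
        m_new = np.maximum(m_i, qk.max(axis=-1))
        alpha = np.exp(m_i - m_new)                 # rescale factor for old state
        p = np.exp(qk - m_new[:, None])
        acc = acc * alpha[:, None] + p @ v          # fold this tile into the output
        l_i = l_i * alpha + p.sum(axis=-1)
        m_i = m_new
    return acc / l_i[:, None]                       # final normalization

def flash_attention(Q, K, V, block_q=16):
    """The 'grid': launch one program per block of query rows."""
    return np.vstack([attention_inner(Q[i:i + block_q], K, V)
                      for i in range(0, Q.shape[0], block_q)])
```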
Look closely: we never created a variable for the $N \times N$ score matrix. The small qk block is computed, used to update acc, and overwritten on the next iteration. The full attention matrix never exists in memory.
The wrapper
The Results
Benchmarking our Triton kernel against naive PyTorch at long sequence lengths:
Memory: Naive demands gigabytes for the intermediate matrix. Flash Attention uses near-zero additional memory - the footprint stays flat regardless of sequence length.
Speed: With compute cores no longer waiting on HBM transfers, the kernel runs 2–4× faster in wall-clock time.
Modern PyTorch includes torch.nn.functional.scaled_dot_product_attention, which automatically dispatches to an optimized Flash Attention kernel under the hood. But now you know exactly what it's doing.
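Using it is one line. This snippet uses the real torch.nn.functional.scaled_dot_product_attention API (the tensor shapes are illustrative); on supported GPUs it dispatches to a Flash Attention kernel, and on CPU it falls back to a math implementation:

```python
import torch
import torch.nn.functional as F

# (batch, heads, sequence length, head dim)
q = torch.randn(2, 8, 1024, 64)
k = torch.randn(2, 8, 1024, 64)
v = torch.randn(2, 8, 1024, 64)

# Dispatches to the fastest available backend; the N x N score
# matrix is never exposed to the caller.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 1024, 64])
```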
Flash Attention 2 and Beyond
Flash Attention 2 (Dao, 2023) improved on the original by:
- Reducing non-matmul FLOPs (the bookkeeping around online softmax)
- Parallelizing over the sequence length dimension as well as batch and heads
- Achieving ~2× speedup over Flash Attention 1, reaching 50–73% of theoretical GPU throughput
Flash Attention 3 (Dao et al., 2024) targets the Hopper architecture (H100) specifically, exploiting:
- Asynchronous execution - overlapping SRAM-to-register data movement with computation
- FP8 tensor core support - halving memory traffic for compatible workloads
- Hardware-aware warp scheduling - hiding instruction latency
Each version pushes closer to the theoretical hardware limit. The algorithm hasn't changed - it's still tiling with online softmax. What changes is how precisely the implementation matches the GPU's physical capabilities.
Why This Matters
Flash Attention is a reminder of a fundamental truth: algorithms don't exist in a vacuum. The most elegant formula is useless if it chokes the memory bus of the hardware it runs on.
By understanding the physical architecture of the GPU - the hierarchy of HBM and SRAM, the cost of data movement vs computation - researchers unlocked the era of million-token context windows. The math built the Transformer, but hardware awareness allowed it to conquer the world.
The implications extend far beyond attention:
- KV cache management (PagedAttention) applies the same IO-aware thinking to inference
- Activation checkpointing trades compute for memory during training
- Quantization (INT8, INT4, FP8) reduces data movement by shrinking the data itself
- Speculative decoding overlaps draft and verification to hide latency
Every frontier of LLM efficiency is, at its core, about matching the algorithm to the machine.