Matmul to Flash Attention
- Author: Hengjie Wang
- Date: August 6, 2024
In this article, we describe how the implementation of Flash Attention can be
understood within the context of fast matrix multiplication for Ampere
hardware. By taking advantage of new asynchronous data transfer instructions,
data transfer overhead can be hidden by the computation within a single thread
block through pipelined execution. Compared to matrix multiplication,
Flash Attention addresses the global data dependence incurred by softmax via
tiling and running statistics.
How to write a fast matmul on Ampere
We begin with a discussion about how to write a fast matmul kernel on Ampere
hardware, drawing from the very good overview, "How to Optimize a CUDA Matmul
Kernel for cuBLAS-like Performance: a
Worklog," by Simon
Boehm.
Baseline Shared Memory Matmul
We begin with a baseline implementation of shared memory matmul. The outer
loop of the multiplication advances the address of the row-major matrix A and
the column-major matrix B by a tile size until the final matrix C is fully
calculated.

Baseline implementation of shared memory matmul. Source Boehm, Simon.
Each thread block loads tiles from A and B into shared memory and
multiplies the sub-matrices.
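Below is a minimal sketch of such a baseline kernel, assuming square TILE × TILE tiles, row-major A and C, and column-major B as described above; bounds checks are omitted and M, N, K are assumed to be multiples of TILE. The names are illustrative, not taken from the article.

```cuda
#define TILE 32

// Baseline: each block computes a TILE x TILE tile of C.
// A is row-major [M, K], B is column-major [K, N], C is row-major [M, N].
__global__ void sgemm_smem_baseline(const float* A, const float* B, float* C,
                                    int M, int N, int K) {
    __shared__ float As[TILE][TILE];
    __shared__ float Bs[TILE][TILE];

    int row = blockIdx.y * TILE + threadIdx.y;   // row of C
    int col = blockIdx.x * TILE + threadIdx.x;   // column of C
    float acc = 0.0f;

    // Outer loop advances the A and B tile addresses along K.
    for (int k0 = 0; k0 < K; k0 += TILE) {
        // Load one tile of A and one tile of B into shared memory.
        As[threadIdx.y][threadIdx.x] = A[row * K + (k0 + threadIdx.x)];
        Bs[threadIdx.y][threadIdx.x] = B[col * K + (k0 + threadIdx.y)];
        __syncthreads();                         // wait until the tile is resident

        // Multiply the two tiles held in shared memory.
        for (int kk = 0; kk < TILE; ++kk)
            acc += As[threadIdx.y][kk] * Bs[kk][threadIdx.x];
        __syncthreads();                         // wait before overwriting the tile
    }
    C[row * N + col] = acc;
}
```

The two __syncthreads() calls are the barriers discussed next: a tile must be fully loaded before any thread uses it, and fully used before it is overwritten.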
Why is it slow?
This baseline implementation has two key performance blockers:
- We are not using tensor cores.
- The barrier prevents us from overlapping data transfer and computation.

Strict serial ordering of load, store, and accumulation operations.
We first add tensor core matrix multiply-accumulate (MMA). Even so, as a consequence of the barrier synchronization, four sets of instructions have to be executed in strict serial order for each thread block:
- LDG: load from global memory to registers.
- STS: store registers to shared memory.
- LDSM: load from shared memory to registers for tensor core accumulation.
- MMA: tensor core accumulation using register values.
Each thread block wastes compute resources while transferring data, and vice versa.
Overlap computation and global memory access
To address these performance issues, the Nvidia Ampere architecture introduced a new asynchronous data transfer instruction, LDGSTS.
- The synchronous load from global memory to shared memory has two implicit steps:
  LDG  registers, [global_memory_address]
  STS  [shared_memory_address], registers
- In comparison, the asynchronous transfer is a single instruction and does not use registers:
  LDGSTS  [shared_memory_address], [global_memory_address]
This design relieves register pressure and allows for better overlap with computation.
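As a sketch of what this looks like in CUDA C++, the copy can be issued with the pipeline primitives, which compile to cp.async / LDGSTS on Ampere. The float4 element size and the 256-element buffer below are illustrative, not from the article.

```cuda
#include <cuda_pipeline.h>

__global__ void async_copy_sketch(const float4* gmem_tile) {
    __shared__ float4 smem_tile[256];   // assumes blockDim.x <= 256

    // Synchronous version (implicit LDG + STS through a register):
    //   smem_tile[threadIdx.x] = gmem_tile[threadIdx.x];

    // Asynchronous version: global -> shared directly, no intermediate register.
    __pipeline_memcpy_async(&smem_tile[threadIdx.x],
                            &gmem_tile[threadIdx.x],
                            sizeof(float4));
    __pipeline_commit();        // close the current batch of async copies
    __pipeline_wait_prior(0);   // wait until that batch has landed in shared memory
    __syncthreads();

    // ... compute on smem_tile ...
}
```

Waiting with __pipeline_wait_prior(0) right away gives no overlap yet; the next section pipelines these copies against computation.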
How do we overlap?
The key insights are to:
- Maintain multiple buffers to resolve RAW and WAR data dependences.
- Pipeline computation and data transfer.
Example of 3 pipeline stages:

Pre-fetching data concurrently with accumulation operation.
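A rough sketch of such a pipeline, assuming a ring of STAGES shared-memory tile buffers, THREADS threads per block, one float4 copied per thread per tile, and a placeholder compute_tile() standing in for the LDSM/MMA work. All of these names and sizes are illustrative.

```cuda
#include <cuda_pipeline.h>

#define STAGES  3    // pipeline depth (illustrative)
#define THREADS 64   // threads per block assumed by this sketch

// Hypothetical per-tile compute: stands in for the LDSM + MMA work on one tile.
__device__ void compute_tile(const float4* /*smem_tile*/) { /* ... MMA ... */ }

__global__ void pipelined_matmul_sketch(const float4* gmem, int num_tiles) {
    // Ring of STAGES tile buffers in shared memory.
    __shared__ float4 smem[STAGES][THREADS];

    // Prologue: start async loads for the first STAGES - 1 tiles (empty commits
    // keep the batch count consistent when num_tiles is small).
    for (int s = 0; s < STAGES - 1; ++s) {
        if (s < num_tiles)
            __pipeline_memcpy_async(&smem[s][threadIdx.x],
                                    &gmem[(size_t)s * THREADS + threadIdx.x],
                                    sizeof(float4));
        __pipeline_commit();
    }

    for (int t = 0; t < num_tiles; ++t) {
        int next = t + STAGES - 1;
        if (next < num_tiles) {
            // Prefetch a future tile into the ring while computing on an older one.
            __pipeline_memcpy_async(&smem[next % STAGES][threadIdx.x],
                                    &gmem[(size_t)next * THREADS + threadIdx.x],
                                    sizeof(float4));
        }
        __pipeline_commit();               // commit this iteration's (possibly empty) batch
        __pipeline_wait_prior(STAGES - 1); // the oldest tile (tile t) is now resident
        __syncthreads();                   // every thread's copy of tile t is visible

        compute_tile(smem[t % STAGES]);    // LDSM + MMA on the resident tile

        __syncthreads();                   // WAR: don't overwrite this buffer until
                                           // all threads are done reading it
    }
}
```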
Overlap computation and shared memory access
Similar to overlapping global memory access, we pipeline LDSM and MMA. Two stages are typically good enough in practice.
2-stage pipeline (double buffering) for shared memory transfer and computation.
3-stage pipeline for global memory transfer and computation.

Pipelines for global and shared memory transfer and computation.
Note how within both 2-stage and 3-stage pipelines, the global memory transfer
to shared memory (LDGSTS) for future computation happens concurrently with
the loading from shared memory (LDSM) and accumulation (MMA) operations.
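To picture the 2-stage register pipeline, here is a sketch using the WMMA API with double-buffered fragments; production kernels issue LDSM/MMA directly, and whether the loads and MMAs truly overlap is ultimately up to the compiler's instruction scheduling. The 16x16x16 tile shapes, layouts, and K_TILES are illustrative assumptions.

```cuda
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

// Double-buffered register fragments: load the fragments for step k+1 from
// shared memory while issuing the MMA for step k. As/Bs hold contiguous
// 16x16 tiles (row-major A, col-major B).
__device__ void smem_double_buffer_sketch(const half* As, const half* Bs,
                                          float* Cs, int K_TILES) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a[2];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b[2];
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c;
    wmma::fill_fragment(c, 0.0f);

    // Stage 0: fill the first register buffer (the LDSM step).
    wmma::load_matrix_sync(a[0], As, 16);
    wmma::load_matrix_sync(b[0], Bs, 16);

    for (int k = 0; k < K_TILES; ++k) {
        int cur = k & 1, nxt = cur ^ 1;
        if (k + 1 < K_TILES) {
            // Prefetch the next fragments into the other register buffer.
            wmma::load_matrix_sync(a[nxt], As + (k + 1) * 16 * 16, 16);
            wmma::load_matrix_sync(b[nxt], Bs + (k + 1) * 16 * 16, 16);
        }
        // Tensor core accumulation on the current buffer (the MMA step).
        wmma::mma_sync(c, a[cur], b[cur], c);
    }
    wmma::store_matrix_sync(Cs, c, 16, wmma::mem_row_major);
}
```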
The following figure illustrates the pipeline execution with matrix tiling and
architecture hierarchy.

Matrix view of fast multiplication. Source cutlass.
Split-K
Recap: we partition the M and N dimensions to create sub-matrices for thread blocks.

Recap - baseline implementation of shared memory matmul. Source Boehm, Simon.
What happens when M and N are small? E.g., M = 64, N = 3072, K = 3072, and tile_size = 128 ==> ceil(64/128) × ceil(3072/128) = 1 × 24 = 24 thread blocks. We are using < 25% of the SMs on an A100 (108 in total).
We need to partition the K dimension to create more tasks.
E.g., with M = 64, N = 3072, K = 3072, splitting K into 2 partitions gives 1 × 24 × 2 = 48 thread blocks, each accumulating over a K range of 1536; the partial results are then reduced to form C (see the sketch below).
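A sketch of the split-K scheme, with shared-memory tiling omitted for brevity: each (block_m, block_n, split) computes a partial C tile over its own K range into a workspace, and a second kernel reduces across splits. The workspace layout and kernel names are illustrative; libraries such as CUTLASS implement more sophisticated variants.

```cuda
// Each z-slice of the grid covers a different K range and writes its partial
// result to partials[split]. A is row-major, B is column-major.
__global__ void matmul_splitk_partial(const float* A, const float* B,
                                      float* partials,   // [splits, M, N]
                                      int M, int N, int K, int splits) {
    int split   = blockIdx.z;
    int k_chunk = K / splits;            // assume K divides evenly
    int k_begin = split * k_chunk;
    int k_end   = k_begin + k_chunk;

    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) return;

    float acc = 0.0f;
    for (int k = k_begin; k < k_end; ++k)        // partial dot product over this K range
        acc += A[row * K + k] * B[col * K + k];
    partials[(size_t)split * M * N + row * N + col] = acc;
}

// Reduce the per-split partial results into the final C.
__global__ void splitk_reduce(const float* partials, float* C,
                              int M, int N, int splits) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= M * N) return;
    float sum = 0.0f;
    for (int s = 0; s < splits; ++s)
        sum += partials[(size_t)s * M * N + idx];
    C[idx] = sum;
}
```

Launching matmul_splitk_partial with gridDim.z = splits is what multiplies the number of thread blocks; the extra cost is the workspace and the reduction pass.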
Flash Attention
With these optimizations in mind, we can now consider the implementation of Flash Attention.
Multi-Head Attention Block
We begin with a brief overview of multi-head attention:
Q, K, and V in our stack have shape [B, S, H, D], where:
- B: batch size
- S: sequence/context length (matmul dim)
- H: number of heads (similar to B)
- D: depth (matmul dim)

Multi-head attention block
For example, with Replit-3B, assuming an input sequence length of 1000:
- context encoding: Q: [B, 1000, 24, 128], K and V: [B, 1000, 8, 128], P: [B, 24, 1000, 1000]
- i-th token generation: Q: [B, 1, 24, 128], K and V: [B, 1000 + i, 8, 128], P: [B, 24, 1, 1000 + i]
Note: the sequence length (for Q) and the context length (for K, V) are two separate dims.
The above materializes an S x S intermediate matrix (for context encoding), which introduces huge memory usage and traffic for long sequences. For the context-encoding example above, P is [B, 24, 1000, 1000]; in fp16 that is 24 × 1000 × 1000 × 2 bytes ≈ 48 MB per batch element per layer.
Flash Attention Algorithm
Flash Attention is an algorithm that optimizes the computation of multi-head
attention by addressing the memory bottleneck that occurs when computing the
attention matrix P (of size S×S for context encoding, where S is sequence
length). Instead of materializing the entire attention matrix in memory at once,
Flash Attention uses a tiling approach where it processes small chunks of queries,
keys, and values sequentially, computing attention scores and applying softmax
in an "online" manner within each tile.
The key innovation is the online softmax technique that maintains running
statistics (row maximum and row sum) across tiles, allowing it to compute
numerically stable softmax without needing the entire row of attention
scores. This approach delivers significant performance gains by dramatically
reducing memory usage and memory traffic—eliminating the need to store the
large S×S intermediate matrix that grows quadratically with sequence
length—while maintaining mathematical equivalence to the standard attention
computation.

Flash attention
One of the primary challenges in implementing Flash Attention is computing a numerically stable softmax when only a small tile of the attention scores is available at a time.
Online Softmax: break down softmax and blend it with matmul
To address this challenge, for each tile of keys/values, Flash Attention
computes attention scores Q×K', then applies "online softmax" - a technique
that maintains running statistics (rowmax, rowsum) to compute numerically
stable softmax across tiles without needing the full attention matrix.
The key insight is that when a new tile produces a larger maximum value, it corrects both the running sum and the previous output using an exponential correction factor, ensuring the final result is mathematically equivalent to computing softmax over the entire sequence at once. This allows attention to be computed with constant memory usage regardless of sequence length.
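As a concrete sketch, here is the running-statistics update for a single query row processed over K/V tiles; the plain loops, float accumulators, and function name are illustrative stand-ins for the tensor core tiles in the real kernel.

```cuda
#include <math.h>

// Flash-attention-style forward pass for ONE query row, processed in K/V tiles.
// q: [D], K: [S, D] row-major, V: [S, D] row-major, out: [D].
__host__ __device__ void attention_one_row_tiled(const float* q, const float* K,
                                                 const float* V, float* out,
                                                 int S, int D, int tile) {
    float m = -INFINITY;   // running row max of the scores
    float l = 0.0f;        // running row sum of exp(score - m)
    for (int d = 0; d < D; ++d) out[d] = 0.0f;   // running output accumulator

    for (int j0 = 0; j0 < S; j0 += tile) {
        int j1 = (j0 + tile < S) ? (j0 + tile) : S;

        // 1. Scores for this tile: s_j = q . K_j (the Q x K' piece).
        float m_new = m;
        float s[128];                             // assumes tile <= 128
        for (int j = j0; j < j1; ++j) {
            float dot = 0.0f;
            for (int d = 0; d < D; ++d) dot += q[d] * K[j * D + d];
            s[j - j0] = dot;
            m_new = fmaxf(m_new, dot);
        }

        // 2. Rescale everything accumulated under the old max.
        float correction = expf(m - m_new);
        l *= correction;
        for (int d = 0; d < D; ++d) out[d] *= correction;

        // 3. Accumulate this tile's softmax numerator and weighted values.
        for (int j = j0; j < j1; ++j) {
            float p = expf(s[j - j0] - m_new);
            l += p;
            for (int d = 0; d < D; ++d) out[d] += p * V[j * D + d];
        }
        m = m_new;
    }

    // 4. Final normalization: matches softmax(q K') V computed in one shot.
    for (int d = 0; d < D; ++d) out[d] /= l;
}
```

Steps 2 and 3 are the correction described above: whenever a tile raises the running max, everything already accumulated is rescaled by exp(m_old - m_new), so the final normalized output is identical to computing softmax over the entire row at once.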
Creating tasks for thread blocks:
- The batch and head dims are independent (recall Q: [B, S, H, D]).
- Partition the sequence length, i.e. each thread block updates one Q tile.
What about token generation i.e. sequence length = 1?
E.g. sequence_length = 1, batch_size = 1, num_heads = 24 -> 24 thread
blocks. We are using < 25% SMs on A100 (108 in total).
The above is mitigated by batch_size > 1. Yet we are now faced with vector-matrix multiplication rather than matmul, which calls for additional optimizations.
Flash Decoding for Token generation
- Use more optimized vector-matrix multiplication and flat matmul kernels (out of scope here).
- Split-K over the context length; the partial results are then reduced (see the sketch below).
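A sketch of how the per-chunk partial results can be merged in this split-K reduction, assuming each context chunk i produced a normalized partial output o_i together with its running statistics (m_i, l_i); the layout and function name are illustrative.

```cuda
#include <math.h>

// Merge per-chunk partial attention results for one query.
// Chunk i of the K/V context produced:
//   m[i]: row max of its scores, l[i]: its sum of exp(score - m[i]),
//   o[i*D .. i*D+D): its normalized partial output.
__host__ __device__ void flash_decoding_combine(const float* m, const float* l,
                                                const float* o, int chunks, int D,
                                                float* out) {
    // Global max across chunks for numerical stability.
    float m_all = -INFINITY;
    for (int i = 0; i < chunks; ++i) m_all = fmaxf(m_all, m[i]);

    // Rescale each chunk's softmax weight into the common max, then average.
    float l_all = 0.0f;
    for (int d = 0; d < D; ++d) out[d] = 0.0f;
    for (int i = 0; i < chunks; ++i) {
        float w = l[i] * expf(m[i] - m_all);  // chunk's total softmax weight
        l_all += w;
        for (int d = 0; d < D; ++d) out[d] += w * o[i * D + d];
    }
    for (int d = 0; d < D; ++d) out[d] /= l_all;
}
```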
PyTorch has a super nice animation for flash decoding:

Flash decoding. Source PyTorch.
Flash Attention 3
The Hopper GPU architecture introduces asynchronous warpgroup MMA instructions (wgmma). In addition to overlapping computation and data transfer, FA3 further pipelines wgmma and softmax to overlap tensor core and CUDA core computation.

Flash attention 3