Multi-Head Flash Attention
- Author: Chris Elrod
- Date: May 7, 2025
This document describes the implementation of Multi-Head Attention (MHA) using Flash Attention 3.
Background
The self-attention mechanism is defined as:

Attention(Q, K, V) = softmax(Q K' / sqrt(depth)) V

where Q, K, and V are the set of queries, keys, and values for that
attention head, and K' denotes the transpose of K.
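For reference, here is a minimal NumPy sketch of a single attention head (the function name and shapes are illustrative, not taken from the kernel described here):

```python
import numpy as np

def attention_head(Q, K, V):
    # Q: (seq_len, depth), K: (num_keys, depth), V: (num_keys, depth)
    depth = Q.shape[-1]
    S = (Q @ K.T) / np.sqrt(depth)          # (seq_len, num_keys) score matrix
    S = S - S.max(axis=-1, keepdims=True)   # rowwise max subtraction (discussed below)
    P = np.exp(S)
    P = P / P.sum(axis=-1, keepdims=True)   # row-wise softmax
    return P @ V                            # (seq_len, depth) output
```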
Multi-head attention extends this, adding the parameters q_heads and
kv_heads, where q_heads % kv_heads == 0. Let group = q_heads // kv_heads.
Then, we have:

kv_head_idx = q_head_idx // group

Thus, we can index arrays using only the q_head index. We additionally have a
batch_idx, meaning the operation we want to perform is:

O[batch_idx, q_head_idx] = softmax(Q[batch_idx, q_head_idx] K'[batch_idx, q_head_idx // group] / sqrt(depth)) V[batch_idx, q_head_idx // group]
Thus, we must essentially do batch_size * num_q_head attention
evaluations, although we only need to load batch_size * num_kv_head unique
K and V matrices.
We can view Q, K, and V as 4-D (ragged) arrays, where seq_len and num_keys may vary per batch element:

Q: batch_size x q_heads x seq_len x depth
K: batch_size x kv_heads x num_keys x depth
V: batch_size x kv_heads x num_keys x depth
Indexing with batch_idx and q_head_idx, we:
- Multiply a seq_len x depth matrix Q with a depth x num_keys matrix K'.
- Evaluate the row-wise softmax of the seq_len x num_keys output matrix.
- Multiply the seq_len x num_keys matrix with the num_keys x depth matrix V.
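A naive host-side sketch of the full batched operation, reusing the attention_head helper from the sketch above (loop structure and names are illustrative; for simplicity it assumes uniform seq_len and num_keys rather than ragged batches):

```python
import numpy as np

def mha_reference(Q, K, V):
    # Q: (batch_size, q_heads, seq_len, depth)
    # K, V: (batch_size, kv_heads, num_keys, depth)
    batch_size, q_heads, seq_len, depth = Q.shape
    kv_heads = K.shape[1]
    group = q_heads // kv_heads
    O = np.empty_like(Q)
    for batch_idx in range(batch_size):
        for q_head_idx in range(q_heads):
            kv_head_idx = q_head_idx // group   # grouped-query indexing
            O[batch_idx, q_head_idx] = attention_head(
                Q[batch_idx, q_head_idx],
                K[batch_idx, kv_head_idx],
                V[batch_idx, kv_head_idx],
            )
    return O
```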
Flash Attention 2
This naive algorithm is costly in terms of data movement. depth tends to be
small, e.g. 128, while both seq_len and num_keys can be large, e.g. up to
8192 and 119132 respectively in llama3.3.70b. Thus, materializing an 8192
x 119132 matrix would impose a high memory bandwidth cost that the reduction
over 128 elements of depth cannot hide.
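As a rough back-of-the-envelope check (assuming 32-bit scores; the dtype here is an assumption, not stated above), the score matrix for a single batch element and query head would be:

```python
seq_len, num_keys = 8192, 119132
score_bytes = seq_len * num_keys * 4   # assuming fp32 scores
print(score_bytes / 1e9)               # ~3.9 GB written and re-read per (batch, head)
```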
The innovation of flash attention is to avoid materializing this array, holding
the output in registers, and performing an online softmax.
softmax(S)[i, j] = exp(S[i, j] - S[i, a]) / sum_k exp(S[i, k] - S[i, a])

where a is the column index of the rowwise maximum, so S[i, a] is the largest
value in row i. For 32-bit floating-point values x >= 88.72284,
exp(x) = Inf. To prevent overflow and preserve numerical accuracy, the
subtraction guarantees that the largest exponentiated value is 1.0.
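A quick NumPy check of that threshold (88.0 and 89.0 are chosen simply to bracket it):

```python
import numpy as np

print(np.exp(np.float32(88.0)))   # ~1.65e38, still finite
print(np.exp(np.float32(89.0)))   # inf: exceeds the float32 maximum of ~3.40e38
```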
The online algorithm allows us to split this into batches. First, we defer the
denominator: rather than dividing as we go, we apply it once to the final output array.
Thus, we focus only on updating the numerator, as well as the outputs computed
from previous tiles. Let b be the column index of the old rowwise maximum from
prior batches. To update old values, we simply scale them by the
correction factor exp(S[i,b]-S[i,a]).
This requires keeping track of the rowmax values throughout the online algorithm,
as well as the running exponential sum, which we use as the denominator at the end.
With this, we’re able to tile K' by columns. The Flash Attention 2 algorithm
is essentially:
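for each block of keys, compute S = Q @ K'[block], update the running rowwise max and exponential sum, rescale the previously accumulated output by the correction factor, and accumulate P @ V[block]; divide by the exponential sum at the end.

A minimal NumPy sketch of that loop for a single (batch, q_head) pair (the tile size and names are illustrative assumptions, not the kernel's actual structure):

```python
import numpy as np

def flash_attention_2(Q, K, V, key_tile=128):
    # Q: (seq_len, depth); K, V: (num_keys, depth); one (batch, q_head) pair.
    seq_len, depth = Q.shape
    num_keys = K.shape[0]
    scale = 1.0 / np.sqrt(depth)
    O = np.zeros((seq_len, depth))              # unnormalized output accumulator
    row_max = np.full((seq_len, 1), -np.inf)    # running rowwise maximum
    row_sum = np.zeros((seq_len, 1))            # running sum of exponentials
    for start in range(0, num_keys, key_tile):
        k = slice(start, min(start + key_tile, num_keys))
        S = (Q @ K[k].T) * scale                # (seq_len, tile) score block
        new_max = np.maximum(row_max, S.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)  # rescales prior tiles' results
        P = np.exp(S - new_max)                 # unnormalized probabilities
        row_sum = row_sum * correction + P.sum(axis=-1, keepdims=True)
        O = O * correction + P @ V[k]           # accumulate the numerator
        row_max = new_max
    return O / row_sum                          # apply the denominator once at the end
```

Up to floating-point rounding, this matches the attention_head reference above.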
Note that there is no communication across rows, so it is also natural to
block across seq_len. Doing this, the size of all temporaries is bounded; we
pick values such that all temporaries (row_max, row_sum, S, P, and O)
can be held in registers.
This avoids materializing the large array and moving it to and from memory.
Our only writes are the final answer at the end, and our only reads are the
inputs Q, K, and V.
An important special case is token generation. Using a KV-cache, we save
prior results, and each subsequent step uses seq_len=1, incrementally computing
only the new results.
With this, we index our arrays using kv_head_idx, operating on group rows
at a time. Otherwise, the algorithm is similar.
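A sketch of that decode step, reusing the flash_attention_2 sketch above: the group query heads that share a KV head are treated as the rows of one small Q matrix (names and layout are assumptions for illustration):

```python
import numpy as np

def decode_step(Q_new, K_cache, V_cache):
    # Q_new: (q_heads, depth) -- the single new token's queries, one row per query head.
    # K_cache, V_cache: (kv_heads, cached_keys, depth)
    q_heads = Q_new.shape[0]
    kv_heads = K_cache.shape[0]
    group = q_heads // kv_heads
    O = np.empty_like(Q_new)
    for kv_head_idx in range(kv_heads):
        rows = slice(kv_head_idx * group, (kv_head_idx + 1) * group)
        # The group query heads sharing this KV head form the rows of a small Q.
        O[rows] = flash_attention_2(Q_new[rows],
                                    K_cache[kv_head_idx],
                                    V_cache[kv_head_idx])
    return O
```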
Further optimizations include pipelining with buffering, e.g. using
asynchronous copies from global to shared memory. This can help hide latency:
while computing one iteration, we are already copying the data that will be used
num_pipeline_stages - 1 iterations in advance.
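A host-side Python model of that schedule (purely illustrative: num_pipeline_stages, load, and compute are stand-ins, and on the GPU the copies would be asynchronous global-to-shared transfers rather than Python calls):

```python
def pipelined_loop(tiles, num_pipeline_stages=3, load=lambda t: t, compute=print):
    # Ring buffer of in-flight tiles: stage i % num_pipeline_stages holds iteration i's data.
    stages = [None] * num_pipeline_stages
    # Prologue: issue copies for the first num_pipeline_stages - 1 iterations.
    for i in range(min(num_pipeline_stages - 1, len(tiles))):
        stages[i % num_pipeline_stages] = load(tiles[i])
    for i in range(len(tiles)):
        # Issue the copy for the iteration num_pipeline_stages - 1 ahead...
        ahead = i + num_pipeline_stages - 1
        if ahead < len(tiles):
            stages[ahead % num_pipeline_stages] = load(tiles[ahead])
        # ...while computing on the data that was copied earlier.
        compute(stages[i % num_pipeline_stages])
```

Calling pipelined_loop(list(range(8))) walks the tiles in order while each copy is issued num_pipeline_stages - 1 iterations before it is consumed.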
Flash Attention 3
FA3 is Flash Attention 3, the flash attention algorithm specialized for the
Hopper (sm90) architecture. Hopper adds wgmma instructions for computing
matrix multiply (@) operations asynchronously. It also adds
shared-memory barriers and dynamic register allocation/deallocation, which
allow for warp specialization.
Warp-group-specialized kernels are also often referred to as “ping pong” kernels.
One warp group specializes in launching asynchronous copies to shared memory (and deallocates most of its registers).
The others perform computation. The GPU’s execution can bounce, or “ping pong”,
between them while their operations run asynchronously in the background. By
running two compute warp groups, we allow their matrix multiplies to overlap
the softmax-related instructions such as exponentials, as these use separate
hardware units that can run fully in parallel.
To further optimize within a warpgroup (which is also the only option for
context decoding, where we don’t have enough rows of Q to divide it into two
warpgroups), we can pipeline the loop above:
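issue the Q @ K' matmul for the next key tile before the current tile’s P @ V accumulation is consumed, so that the matrix multiplies overlap the softmax exponentials.

A NumPy sketch of this reordering, using the same online update as flash_attention_2 above (serial here; on Hopper both matmuls would be issued as asynchronous wgmma operations, and the tile size and names are assumptions):

```python
import numpy as np

def flash_attention_pipelined(Q, K, V, key_tile=128):
    # Same math as flash_attention_2, restructured so the next tile's Q @ K'
    # is issued before the current tile's P @ V result is needed.
    seq_len, depth = Q.shape
    num_keys = K.shape[0]
    scale = 1.0 / np.sqrt(depth)
    O = np.zeros((seq_len, depth))
    row_max = np.full((seq_len, 1), -np.inf)
    row_sum = np.zeros((seq_len, 1))
    tiles = [slice(s, min(s + key_tile, num_keys)) for s in range(0, num_keys, key_tile)]
    S_next = (Q @ K[tiles[0]].T) * scale          # prologue: first score block
    for j, k in enumerate(tiles):
        S = S_next
        if j + 1 < len(tiles):
            S_next = (Q @ K[tiles[j + 1]].T) * scale  # issue next tile's matmul early
        new_max = np.maximum(row_max, S.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)
        P = np.exp(S - new_max)                   # exponentials (vector units)
        row_sum = row_sum * correction + P.sum(axis=-1, keepdims=True)
        O = O * correction + P @ V[k]             # P @ V accumulation (wgmma on GPU)
        row_max = new_max
    return O / row_sum
```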
Now, the P @ V[subset, :] wgmma instructions can overlap the
bulk of the vector instructions within the kernel, while all of these operations
additionally overlap the memory transfers handled by the other warp group.
This allows us to use a great deal of the hardware’s resources in parallel.