Multi-Head Latent Attention
- Author: Shouzheng Liu
- Date: February 19, 2025
Background
Multi-head Attention (MHA)
Attached is the compute graph of an MHA block, where square nodes represent
tensors materialized in memory and edges denote operations. Once Q,
K, and V are prepared, they are sent to the mha kernel. K and V can be
cached for reuse during the decoding stage.

Compute graph of multi-head attention
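For concreteness, here is a minimal PyTorch sketch of one MHA decoding step with a KV cache; the head count, head size, and projection weights are illustrative assumptions, not taken from any particular model.

```python
import torch

# Minimal sketch of one MHA decoding step with a KV cache.
# num_heads, head_dim, and the projection weights are hypothetical.
num_heads, head_dim = 8, 64
hidden = num_heads * head_dim

W_q = torch.randn(hidden, hidden)
W_k = torch.randn(hidden, hidden)
W_v = torch.randn(hidden, hidden)

def mha_decode_step(x, k_cache, v_cache):
    """x: [1, hidden] hidden state of the newly generated token."""
    q = (x @ W_q).view(1, num_heads, head_dim)
    k = (x @ W_k).view(1, num_heads, head_dim)
    v = (x @ W_v).view(1, num_heads, head_dim)

    # Append the new K/V so later tokens can reuse them during decoding.
    k_cache = torch.cat([k_cache, k], dim=0)  # [seq_len, num_heads, head_dim]
    v_cache = torch.cat([v_cache, v], dim=0)

    # Per-head attention of the new token over all cached positions.
    scores = torch.einsum("qhd,khd->hqk", q, k_cache) / head_dim ** 0.5
    probs = scores.softmax(dim=-1)
    out = torch.einsum("hqk,khd->qhd", probs, v_cache).reshape(1, hidden)
    return out, k_cache, v_cache

# Usage: start from empty caches and grow them one token at a time.
k_cache = torch.empty(0, num_heads, head_dim)
v_cache = torch.empty(0, num_heads, head_dim)
out, k_cache, v_cache = mha_decode_step(torch.randn(1, hidden), k_cache, v_cache)
```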
Multi-head Latent Attention (MLA)
The key idea of MLA is to use a "latent" vector (KV_lora) to store a
compressed representation of the K and V tensors during inference. Instead
of directly computing K and V at their full dimensionality, the hidden states are
first down-projected to a much lower-dimensional space and then
up-projected to the full dimension of head_dim * num_heads. K and V
share the same compressed representation KV_lora.

Original compute graph of multi-head latent attention
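To make the down-/up-projection concrete, the sketch below mirrors the original (un-optimized) MLA K/V path; the hidden size, latent rank, and head configuration are assumed values for illustration only.

```python
import torch

# Original (un-optimized) MLA K/V path: down-project to a shared latent,
# then up-project to full K and V. All dimensions here are assumptions.
hidden_dim, kv_lora_rank = 4096, 512
num_heads, head_dim = 16, 128
seq_len = 10

W_down = torch.randn(hidden_dim, kv_lora_rank)            # shared down-projection
W_k_up = torch.randn(kv_lora_rank, num_heads * head_dim)  # up-projection for K
W_v_up = torch.randn(kv_lora_rank, num_heads * head_dim)  # up-projection for V

x = torch.randn(seq_len, hidden_dim)        # hidden states
kv_lora = x @ W_down                        # [seq_len, 512], shared by K and V
k = (kv_lora @ W_k_up).view(seq_len, num_heads, head_dim)
v = (kv_lora @ W_v_up).view(seq_len, num_heads, head_dim)
```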
Rotary Position Embedding (ROPE) in MLA
Applying ROPE in MLA is somewhat non-trivial because not all elements of an attention head undergo rotary encoding.
- For Q: ROPE is only applied to the last 64 (rope_dim) elements of each attention head, which has a total size of 192 (nope_dim + rope_dim).
- For K: Instead of applying ROPE to the entire latent vector (KV_lora'), we extract the last 64 elements of each token, apply ROPE, and broadcast the result to all attention heads. The K tensor is then constructed by concatenating the roped part with the remaining dimensions that do not undergo rotary encoding.
When computing attention scores, the unroped elements of Q are multiplied
with the unroped elements of K, the roped elements of Q are multiplied with
the roped elements of K, and the two partial dot products are summed.
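A small sketch of how the two halves combine, assuming nope_dim = 128 and rope_dim = 64 and taking the roped tensors as already computed:

```python
import torch

# Attention scores as the sum of an un-roped and a roped dot product.
# Shapes follow the text: 128 un-roped + 64 roped = 192 per head.
nope_dim, rope_dim = 128, 64
seq_len, num_heads = 10, 16

q_nope = torch.randn(seq_len, num_heads, nope_dim)
q_rope = torch.randn(seq_len, num_heads, rope_dim)    # RoPE already applied
k_nope = torch.randn(seq_len, num_heads, nope_dim)
k_rope = torch.randn(seq_len, 1, rope_dim)            # one roped vector per token...
k_rope = k_rope.expand(seq_len, num_heads, rope_dim)  # ...broadcast to every head

scores = (
    torch.einsum("qhd,khd->hqk", q_nope, k_nope)      # unroped x unroped
    + torch.einsum("qhd,khd->hqk", q_rope, k_rope)    # roped x roped
) / (nope_dim + rope_dim) ** 0.5
```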
KV cache
While MLA aims to reduce computation and memory usage, its original
implementation does not actually reduce the KV cache size. This is because the
full K and V tensors are still materialized after up-projection, rather
than being stored in a compressed format throughout the process.
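As a rough back-of-the-envelope comparison (the head counts and per-head dimensions below are assumptions for illustration, not figures from the text), materializing full K and V dwarfs the compressed representation that the optimized design below caches instead:

```python
# Per-token KV cache size, counting stored values (not bytes).
# num_heads and the per-head dimensions are hypothetical.
num_heads, qk_head_dim, v_head_dim = 128, 192, 128

full_kv_per_token = num_heads * (qk_head_dim + v_head_dim)  # full K and V after up-projection
compressed_per_token = 512 + 64                             # KV_lora + K_roped (optimized design)

print(full_kv_per_token, compressed_per_token)  # 40960 vs. 576 values per token
```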
Optimized Attention Computation in MLA

Optimized compute graph of multi-head latent attention
We calculate attention scores as
\(p=q^{T}k\),
where
\(q=W_{qup}Q_{lora}\) and \(k=W_{kup}KV_{lora}\).
Since
\(p=(W_{qup}Q_{lora})^{T}(W_{kup}KV_{lora})=Q_{lora}^{T}\,W_{qup}^{T}W_{kup}\,KV_{lora}\),
we can equivalently compute
\(q=W_{kup}^{T}W_{qup}Q_{lora}\) and \(k=KV_{lora}\)
while keeping the results unchanged: the product \(W_{kup}^{T}W_{qup}\) is absorbed into the query projection, and the cached latent vector itself acts as the key.
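A quick numerical sanity check of this identity (the low-rank dimensions and head size below are assumptions):

```python
import torch

# Absorbing W_kup^T W_qup into the query side leaves the logit q^T k unchanged.
# q_lora_rank, kv_lora_rank, and head_dim are illustrative assumptions.
q_lora_rank, kv_lora_rank, head_dim = 1536, 512, 128

W_qup = torch.randn(head_dim, q_lora_rank, dtype=torch.float64)
W_kup = torch.randn(head_dim, kv_lora_rank, dtype=torch.float64)
q_lora = torch.randn(q_lora_rank, dtype=torch.float64)
kv_lora = torch.randn(kv_lora_rank, dtype=torch.float64)

p_original = (W_qup @ q_lora) @ (W_kup @ kv_lora)   # up-project both sides, then dot
p_absorbed = (W_kup.T @ W_qup @ q_lora) @ kv_lora   # fold W_kup into the query

print(torch.allclose(p_original, p_absorbed))       # True
```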
Similarly, instead of storing and passing the full V tensors, we can simply
reuse the compressed KV_lora tensor and apply the up-projection after the MHA
computation.
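In the notation above, with per-token values \(v_j=W_{vup}\,KV_{lora,j}\) and attention weights \(p_j\) over positions \(j\), the output of one head is
\(o=\sum_j p_j v_j=\sum_j p_j\,W_{vup}\,KV_{lora,j}=W_{vup}\Big(\sum_j p_j\,KV_{lora,j}\Big)\),
so the attention kernel can accumulate directly over the cached latent vectors and \(W_{vup}\) is applied once to the result.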
In this way, we only need to cache the KV_lora and the K_roped tensors.
(This reduces the KV cache to only 512 + 64 = 576 values per token!)
Detailed Design
New attention kernel for MLA
In the optimized MLA compute graph, the attention kernel effectively performs multi-query attention (MQA); a plain reference sketch of this computation appears after the notes below. Specifically:
- The Q input has a shape of [seq_len, num_heads, 576].
- The K input has a shape of [seq_len, 1, 576].
- The V tensor is derived by reusing K, where V = K[:, :, :512].
A few points:
- A head_dim of 576 is unusually large, although whether this will have a performance impact is unclear.
- Here, we have num_kv_heads=1. By default we parallelize across KV heads, but with only one KV head this would lead to severe hardware under-utilization. We can either apply split-k or parallelize across queries. In the latter case, different thread blocks will issue repeated loads of the same K, but since K is quite small, this might not be a major issue.
- We don't need to reserve shared memory or registers for V because V is K.
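Below is a plain-PyTorch reference of the MQA-style computation described above (not the actual kernel); the sequence length, head count, and softmax scale are illustrative assumptions.

```python
import torch

# Reference (non-kernel) sketch of the MQA-style attention described above.
# seq_len, num_heads, and the softmax scale are illustrative; 576 = 512 + 64.
seq_len, num_heads, latent_dim, v_dim = 10, 16, 576, 512

q = torch.randn(seq_len, num_heads, latent_dim)
k = torch.randn(seq_len, 1, latent_dim)   # a single KV head
v = k[:, :, :v_dim]                       # V is just a view into K

scores = torch.einsum("qhd,ksd->hqk", q, k) / latent_dim ** 0.5
# (the real kernel would also apply a causal mask here)
probs = scores.softmax(dim=-1)
out = torch.einsum("hqk,ksd->qhd", probs, v)  # [seq_len, num_heads, 512]
```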
Update KV cache
The KV cache manager needs to be updated to support models that only have K
cache. Currently, the implementation assumes the presence of both K and V
caches.
Multi-GPU
We can still distribute the workload across multiple devices by splitting across the query heads. However, since there is only one KV head, we cannot split it further, meaning the KV cache must be duplicated on every device.
We can also explore how MQA (multi-query attention) is optimized for multi-GPU setups to see if we can borrow some ideas.
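A minimal sketch of the query-head split, assuming a hypothetical tensor-parallel size of 4; every rank keeps a full copy of the single-head latent cache.

```python
import torch

# Split query heads across ranks; the single-KV-head cache is replicated everywhere.
# tp_size, num_heads, and shapes are illustrative assumptions.
num_heads, tp_size, latent_dim = 16, 4, 576
heads_per_rank = num_heads // tp_size

q = torch.randn(10, num_heads, latent_dim)   # full query tensor
kv_cache = torch.randn(10, 1, latent_dim)    # identical copy held by every rank

for rank in range(tp_size):
    # Each rank attends with its own slice of query heads against the full cache.
    q_local = q[:, rank * heads_per_rank:(rank + 1) * heads_per_rank, :]
    # ... run the MQA attention (as sketched above) on this rank's device ...
```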