Demystifying GQA — Grouped Query Attention for Efficient LLM Pre-training
The variant of multi-head attention powering LLMs such as LLaMA-2 and Mistral 7B
In the previous article on training large-scale models, we looked at LoRA. In this article, we will examine another strategy adopted by several large language models for efficiency: Grouped Query Attention (GQA). As we will see, its main payoff is faster decoder inference. In short, GQA is a generalization of multi-head attention (MHA) and multi-query attention (MQA), and each of them is a special case of GQA. Therefore, before we dive into GQA, let’s revisit traditional multi-head attention, proposed by Vaswani et al. in the seminal “Attention is All You Need” paper [3]. Following that, we will explore multi-query attention and how it addresses challenges with MHA. Finally, we will answer the questions “What is GQA?” and “How does it give us the best of both worlds?”
Multi-head attention is a critical component of Transformer models, enabling them to efficiently process and understand complex sequences in tasks like language translation, summarization, and more. To grasp its intricacies, we must delve into the mathematical underpinnings and understand how multiple heads in the attention mechanism function.
The basic attention mechanism computes a weighted sum of values, with weights dependent on a query and a set of keys. Mathematically, this is expressed as:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V$$
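This is referred to as scaled dot-product attention. In this equation, Q (Query) and K (Key) are matrices whose rows are the query and key vectors, and V (Value) is the matrix of value vectors. d_k is the dimensionality of the keys; dividing the dot products by √d_k keeps them from growing with d_k, which would otherwise push the softmax into regions with extremely small gradients [3].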
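Scaled dot-product attention, softmax(QKᵀ/√d_k)·V, maps almost line for line onto NumPy. The sketch below is a minimal, unbatched version with no masking; the function names are our own, chosen for illustration, and not taken from any particular library.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtracting the max does not change the result but avoids overflow in exp.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def scaled_dot_product_attention(Q, K, V):
    """Q: (n_queries, d_k), K: (n_keys, d_k), V: (n_keys, d_v) -> (n_queries, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)     # similarity of every query with every key
    weights = softmax(scores, axis=-1)  # each row is a probability distribution over keys
    return weights @ V                  # weighted sum of the value vectors


rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 8))
K = rng.standard_normal((6, 8))
V = rng.standard_normal((6, 16))
out = scaled_dot_product_attention(Q, K, V)
print(out.shape)  # (4, 16)
```

A quick sanity check: if all keys are identical, every weight equals 1/n_keys and each output row is simply the mean of the value vectors.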
Expanding with Multi-Head Attention (MHA)
Multi-head attention runs multiple attention ‘heads’ in parallel, enabling the model to attend to information from different representation subspaces. Each head has its own independent set of linear layers (projection matrices) for the query, key, and value (an important point that we will revisit in GQA). For head i, with i = 1, …, H:
$$\mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)$$
Concatenating Head Outputs
The outputs of the individual heads are concatenated and then linearly transformed.
$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \mathrm{head}_2, \dots, \mathrm{head}_H)\, W^O$$
Here $W^O$ is another weight matrix that linearly transforms the concatenated vector to the final output dimension.
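Putting the per-head projections, the concatenation and the output projection together gives the NumPy sketch below. It assumes self-attention (the Q, K and V inputs are the same sequence X), no batch dimension or masking, and a weight layout we chose for illustration: the H per-head projection matrices are stored side by side as the columns of one (d_model, H·d_head) matrix, so head i uses columns i·d_head to (i+1)·d_head. Real checkpoints may lay out their weights differently.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_attention(X, W_q, W_k, W_v, W_o, num_heads):
    """Self-attention over X of shape (seq_len, d_model).

    W_q, W_k, W_v: (d_model, num_heads * d_head), the per-head projections
    stacked side by side (an illustrative layout); W_o: (num_heads * d_head, d_model).
    """
    seq_len = X.shape[0]
    d_head = W_q.shape[1] // num_heads

    def split_heads(M):
        # (seq_len, H * d_head) -> (H, seq_len, d_head)
        return M.reshape(seq_len, num_heads, d_head).transpose(1, 0, 2)

    # Every head gets its own query, key and value projection.
    Q, K, V = split_heads(X @ W_q), split_heads(X @ W_k), split_heads(X @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)  # (H, seq_len, seq_len)
    heads = softmax(scores) @ V                          # (H, seq_len, d_head)
    # Concatenate the heads along the feature axis, then apply the output projection.
    concat = heads.transpose(1, 0, 2).reshape(seq_len, num_heads * d_head)
    return concat @ W_o


rng = np.random.default_rng(0)
d_model, H, d_head, seq_len = 16, 4, 4, 5
X = rng.standard_normal((seq_len, d_model))
W_q, W_k, W_v = (rng.standard_normal((d_model, H * d_head)) for _ in range(3))
W_o = rng.standard_normal((H * d_head, d_model))
print(multi_head_attention(X, W_q, W_k, W_v, W_o, num_heads=H).shape)  # (5, 16)
```

Note that W_k and W_v each hold H separate key and value projections; at inference time the keys and values of all H heads must be cached, which is where the cost discussed next comes from.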
The intuition behind multi-head attention is that by applying the attention mechanism multiple times in parallel, the model can capture different types of relationships in the data.
This lets MHA build a nuanced understanding of the relationships between different parts of the input. However, this capability comes at a cost: a significant demand on memory bandwidth, especially during decoder inference.
The Memory Bandwidth Challenge in Multi-Head Attention
The crux of the issue lies in the memory overhead. At each decoding step, an autoregressive Transformer decoder generates a single token, yet it must load the decoder weights along with the cached keys and values of all previous positions. Because so little computation is performed per byte loaded, decoding is limited by memory bandwidth rather than by arithmetic [2]. This overhead grows with model size, sequence length, and batch size, making scaling up an increasingly arduous task.
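To put rough numbers on this overhead, the sketch below computes the size of the key-value (KV) cache that an MHA decoder re-reads at every generation step; near the end of a long context, that is essentially the whole cache. The configuration (32 layers, 32 heads of dimension 128, 16-bit values, a 4,096-token context) is a hypothetical example picked for round numbers, not a measurement of any specific model, and the helper name is our own.

```python
GiB = 1024 ** 3


def kv_cache_bytes(num_layers, num_kv_heads, d_head, seq_len, batch_size=1, bytes_per_value=2):
    """Bytes needed to cache keys and values for an autoregressive decoder."""
    # Factor 2: each KV head stores one key vector and one value vector
    # per layer, per token, per sequence in the batch.
    return 2 * num_layers * num_kv_heads * d_head * seq_len * batch_size * bytes_per_value


# Hypothetical MHA configuration: every query head has its own KV head.
layers, heads, d_head, context = 32, 32, 128, 4096

single = kv_cache_bytes(layers, heads, d_head, context)
batched = kv_cache_bytes(layers, heads, d_head, context, batch_size=8)
print(single / GiB)   # 2.0
print(batched / GiB)  # 16.0
```

This cache is read on top of the model weights at every step and grows linearly with context length and batch size. The num_kv_heads factor is exactly what MQA and GQA reduce.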
Emergence of Multi-Query Attention (MQA)
Multi-query attention (MQA) [2] emerged as a solution to mitigate this bottleneck. The idea is simple yet effective: keep multiple query heads but use only a single key head and a single value head. This shrinks the key-value cache, and with it the data loaded at every decoding step, by a factor equal to the number of heads, which significantly speeds up inference. MQA has been employed in several large-scale models such as PaLM, StarCoder, and Falcon.
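In multi-query attention, all query heads share the same key head and value head. Conceptually, the single key-value head is replicated H times, where H is the number of query heads, so that every query head attends over identical keys and values. When an MQA model is created from an existing MHA checkpoint, this shared head is obtained by averaging (mean pooling) the original key and value heads.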
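A minimal NumPy sketch of MQA follows. It assumes self-attention on a single unbatched, unmasked sequence, a query projection stored as one (d_model, H·d_head) matrix, and a single (d_model, d_head) matrix each for keys and values; these layouts and the function name are illustrative choices. The explicit broadcast mirrors the "replicate the shared head H times" description.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_query_attention(X, W_q, W_k, W_v, W_o, num_heads):
    """W_q: (d_model, H * d_head); W_k, W_v: (d_model, d_head), one shared KV head."""
    seq_len = X.shape[0]
    d_head = W_k.shape[1]
    # H query heads, exactly as in MHA.
    Q = (X @ W_q).reshape(seq_len, num_heads, d_head).transpose(1, 0, 2)  # (H, seq_len, d_head)
    # A single key head and a single value head: this is all that has to be cached.
    K, V = X @ W_k, X @ W_v                                               # (seq_len, d_head)
    # Replicate the shared head H times (a view, no copy) so each query head sees the same K and V.
    K = np.broadcast_to(K, (num_heads, seq_len, d_head))
    V = np.broadcast_to(V, (num_heads, seq_len, d_head))
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)
    heads = softmax(scores) @ V
    concat = heads.transpose(1, 0, 2).reshape(seq_len, num_heads * d_head)
    return concat @ W_o


rng = np.random.default_rng(0)
d_model, H, d_head, seq_len = 16, 4, 4, 5
X = rng.standard_normal((seq_len, d_model))
W_q = rng.standard_normal((d_model, H * d_head))
W_k, W_v = rng.standard_normal((d_model, d_head)), rng.standard_normal((d_model, d_head))
W_o = rng.standard_normal((H * d_head, d_model))
print(multi_query_attention(X, W_q, W_k, W_v, W_o, num_heads=H).shape)  # (5, 16)
```

Per token, the cached K and V now hold d_head numbers each instead of H·d_head, which is the factor-of-H saving described above.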
An interesting question to ask here is: how does one convert an existing pre-trained multi-head attention model into a multi-query attention model? According to [1], the process has two steps: converting the model’s structure, and then briefly continuing pre-training (which the GQA paper calls uptraining).
Conversion of Checkpoint: This step transforms the structure of a multi-head model into a multi-query model. It is achieved by merging (mean pooling) the projection matrices (linear layers) for keys and values from the multiple heads of the original model into single projection matrices for keys and values. This approach of mean pooling is found to be more effective than either selecting one of the existing key and value heads or initializing new key and value heads from scratch. The resulting structure has a consolidated key and value projection, characteristic of the multi-query model.
Pre-Training the Converted Model: After the structural transformation, the model undergoes additional training. This training is far shorter than the original: only a fraction α of the original model’s training steps (the GQA paper [1] uses α = 0.05, i.e. 5% of the original pre-training compute). This phase lets the model adapt to its new, simplified attention mechanism, and it follows the same recipe as the original training, ensuring consistent learning dynamics.
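The checkpoint-conversion step can be sketched in a few lines of NumPy. We assume, as an illustrative layout, that a key (or value) projection stores its H per-head matrices side by side in one (d_model, H·d_head) matrix; the helper name is hypothetical, and the random matrices below merely stand in for pre-trained weights. The query and output projections are left untouched.

```python
import numpy as np


def mean_pool_heads(W, num_heads):
    """Merge H per-head projections into one: (d_model, H * d_head) -> (d_model, d_head)."""
    d_model, width = W.shape
    d_head = width // num_heads
    # Expose the head axis, then average the H projection matrices element-wise.
    return W.reshape(d_model, num_heads, d_head).mean(axis=1)


rng = np.random.default_rng(0)
d_model, H, d_head = 16, 4, 4
W_k = rng.standard_normal((d_model, H * d_head))  # stand-ins for pre-trained MHA weights
W_v = rng.standard_normal((d_model, H * d_head))
W_k_mqa, W_v_mqa = mean_pool_heads(W_k, H), mean_pool_heads(W_v, H)
print(W_k.shape, "->", W_k_mqa.shape)  # (16, 16) -> (16, 4)
```

Because the projection is linear, the new key for any token equals the average of the H keys the original heads would have produced, which gives some intuition for why mean pooling is a sensible starting point for uptraining.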
However, MQA is not without its drawbacks. Collapsing all keys and values into a single head reduces model capacity, which can lead to quality degradation and training instability [1].
Grouped Query Attention
Grouped-query attention (GQA) is a simple approach that blends elements of multi-head attention (MHA) and multi-query attention (MQA) to create a more efficient attention mechanism. The mathematical framework of GQA can be understood as follows:
Division into Groups: In GQA, the H query heads of a traditional multi-head model are divided into G groups of H/G heads each (so G must divide H). All query heads in a group share a single key (K) head and a single value (V) head. This configuration is denoted GQA-G, where G is the number of groups.
Special Cases of GQA:
- GQA-1 = MQA: With only one group (G = 1), GQA becomes equivalent to MQA, as there’s only a single key and value head for all query heads.
- GQA-H = MHA: When the number of groups equals the number of heads (G = H), GQA behaves like traditional MHA, with each query head having its unique key and value head.
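The NumPy sketch below implements GQA-G under illustrative assumptions: single-sequence self-attention without masking, projections stored as (d_model, n·d_head) matrices, and query heads assigned to groups in contiguous blocks (heads 0 to H/G−1 form group 0, and so on). The function name is our own. Each key and value head is repeated H/G times so that it lines up with the query heads of its group.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def grouped_query_attention(X, W_q, W_k, W_v, W_o, num_heads, num_groups):
    """GQA-G self-attention.

    W_q: (d_model, H * d_head); W_k, W_v: (d_model, G * d_head); W_o: (H * d_head, d_model).
    """
    if num_heads % num_groups != 0:
        raise ValueError("num_heads must be divisible by num_groups")
    seq_len = X.shape[0]
    d_head = W_q.shape[1] // num_heads
    heads_per_group = num_heads // num_groups

    def split(M, n):
        # (seq_len, n * d_head) -> (n, seq_len, d_head)
        return M.reshape(seq_len, n, d_head).transpose(1, 0, 2)

    Q = split(X @ W_q, num_heads)   # H query heads
    K = split(X @ W_k, num_groups)  # G key heads (this is what gets cached)
    V = split(X @ W_v, num_groups)  # G value heads
    # Query head h belongs to group h // heads_per_group, so repeating each KV head
    # heads_per_group times lines the KV heads up with their query heads.
    K = np.repeat(K, heads_per_group, axis=0)  # (H, seq_len, d_head)
    V = np.repeat(V, heads_per_group, axis=0)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)
    heads = softmax(scores) @ V
    concat = heads.transpose(1, 0, 2).reshape(seq_len, num_heads * d_head)
    return concat @ W_o


rng = np.random.default_rng(0)
d_model, H, G, d_head, seq_len = 32, 8, 2, 4, 5
X = rng.standard_normal((seq_len, d_model))
W_q = rng.standard_normal((d_model, H * d_head))
W_k, W_v = (rng.standard_normal((d_model, G * d_head)) for _ in range(2))
W_o = rng.standard_normal((H * d_head, d_model))
print(grouped_query_attention(X, W_q, W_k, W_v, W_o, H, G).shape)  # (5, 32)
```

The special cases fall out directly: with num_groups equal to num_heads, np.repeat is a no-op and the function is MHA; with num_groups=1, the single KV head is repeated H times, which is MQA.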
To convert a multi-head model into a GQA model, we mean-pool the key and value projection matrices of the original heads within each group. This averages the projection matrices of the heads in a group, resulting in a single key projection and a single value projection for that group; as with MQA, the converted model is then uptrained [1].
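Group-wise mean pooling is a small extension of the MQA conversion. The sketch assumes the illustrative (d_model, H·d_head) weight layout and contiguous groups of H/G heads; the helper name is hypothetical, and the random matrix only stands in for a pre-trained key projection.

```python
import numpy as np


def mean_pool_groups(W, num_heads, num_groups):
    """(d_model, H * d_head) -> (d_model, G * d_head): average the heads inside each group."""
    d_model, width = W.shape
    d_head = width // num_heads
    heads_per_group = num_heads // num_groups
    # Axes: (d_model, group, head within group, d_head); groups are contiguous runs of heads.
    grouped = W.reshape(d_model, num_groups, heads_per_group, d_head)
    return grouped.mean(axis=2).reshape(d_model, num_groups * d_head)


rng = np.random.default_rng(0)
d_model, H, d_head = 32, 8, 4
W_k = rng.standard_normal((d_model, H * d_head))  # stand-in for a pre-trained MHA key projection
for G in (8, 2, 1):
    print(G, mean_pool_groups(W_k, H, G).shape)
# 8 (32, 32)
# 2 (32, 8)
# 1 (32, 4)
```

With G = H the weights come back unchanged (MHA), and with G = 1 the result is the single mean-pooled head used for MQA.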
By using GQA, the model strikes a balance between the quality of MHA and the speed of MQA. Because there are fewer key and value heads, the KV cache and the data loaded at each decoding step shrink accordingly. The choice of G presents a trade-off: more groups (closer to MHA) give higher quality but slower inference, whereas fewer groups (closer to MQA) boost speed at the risk of sacrificing quality. Moreover, larger models usually have more heads, so MQA’s reduction to a single key and value head becomes an increasingly aggressive cut in both memory bandwidth and model capacity. GQA instead keeps the same proportional reduction in bandwidth and capacity as the model scales [1].
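The bandwidth side of this trade-off is easy to quantify. Assuming the head dimension d_head stays fixed, and counting only the cache (not the model weights, which are also read at every step), the KV cache of a model with n_layers layers, sequence length T and b bytes per stored value scales linearly with the number of key-value heads, which is G in GQA-G and H in MHA. The factor 2 counts keys and values:

```latex
\begin{aligned}
\text{KV cache}_{\text{GQA-}G} &= 2 \cdot n_{\text{layers}} \cdot G \cdot d_{\text{head}} \cdot T \cdot b, \\
\frac{\text{KV cache}_{\text{GQA-}G}}{\text{KV cache}_{\text{MHA}}} &= \frac{G}{H}.
\end{aligned}
```

For example, with H = 32 query heads, GQA-8 needs 8/32 = 1/4 of the MHA cache, while MQA (G = 1) needs 1/32. The key and value projection matrices shrink by the same factor G/H, from d_model × H·d_head to d_model × G·d_head entries each, which is the capacity cost that accompanies the bandwidth savings.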
Conclusion
In this post, we first looked at traditional multi-head attention (MHA) and its variant, multi-query attention (MQA). Then we looked at a more general formulation, GQA, which many LLMs now use to run more efficiently, particularly at inference time. GQA interpolates between MHA and MQA, providing a fair trade-off between quality and speed. By letting groups of query heads share key and value heads, GQA reduces memory bandwidth demands, making it well suited to scaling up models. GQA has been used in place of standard multi-head attention in recent models such as LLaMA-2 (in its 34B and 70B variants) and Mistral 7B.
References:
[1] GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. https://arxiv.org/pdf/2305.13245.pdf
[2] Fast Transformer Decoding: One Write-Head is All You Need (MQA). https://arxiv.org/abs/1911.02150
[3] Attention is All You Need (MHA). https://arxiv.org/abs/1706.03762
Demystifying GQA — Grouped Query Attention was originally published in Towards Data Science on Medium, where people are continuing the conversation by highlighting and responding to this story.