Multi-Head Attention
Fast Transformer Decoding: One Write-Head is All You Need
Noam Shazeer, Google
This November 2019 paper discusses an optimisation for the multi-head attention mechanism used in the Transformer model, a popular deep learning model for sequence-to-sequence tasks.
The authors introduce a variant called multi-query attention, which aims to address the inefficiencies in incremental inference by sharing keys and values across all attention heads.
This approach reduces the memory bandwidth required during inference, leading to faster decoding times with minimal impact on model quality.
The multi-head attention mechanism is an extension of the self-attention mechanism. It involves splitting the input embeddings into multiple parts, called heads, and applying self-attention to each part independently. This allows the model to learn different types of relationships between tokens simultaneously, enabling it to capture various aspects of the input more effectively.
Transformers use multi-head attention to enable parallel processing of sequence information, significantly speeding up training.
However, during incremental inference the model cannot exploit this parallelism, and decoding becomes slow, mainly because of the high cost of repeatedly loading the large key and value tensors from memory.
The key innovation is the multi-query attention mechanism, where instead of having separate keys and values for each attention head, they are shared across all heads. This strategy decreases the tensor sizes and, consequently, the memory bandwidth needed during decoding, improving inference speed.
Neural Attention Mechanism
The attention mechanism allows a model to focus on different parts of a sequence when producing an output. It takes a query vector and a set of key-value pairs, computes attention weights using the query and keys, and applies these weights to the values to produce an output.
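The paper illustrates this with a simple dot-product attention routine applied to a single query. The sketch below follows the paper's TensorFlow-style pseudocode; shapes follow its conventions (m memory positions, k key width, v value width), but the exact listing here is a reconstruction and the names are illustrative.

```python
import tensorflow as tf

def DotProductAttention(q, K, V):
    """Dot-product attention for a single query.

    q: query vector with shape [k]
    K: key matrix with shape [m, k]
    V: value matrix with shape [m, v]
    Returns a vector with shape [v].
    """
    logits = tf.einsum("mk,k->m", K, q)      # similarity of the query to each memory position
    weights = tf.nn.softmax(logits)          # normalise the scores into attention weights
    return tf.einsum("m,mv->v", weights, V)  # weighted sum of the values
```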
This function computes the attention output for a single query vector q using the key matrix K and value matrix V. The einsum function is used for efficient tensor contractions and broadcasting.
In the Transformer model, the attention mechanism is extended to multi-head attention, allowing the model to aggregate information from different representation subspaces at different positions.
Multi-Head Attention Code
Here's the code for multi-head attention:
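The sketch below follows the paper's TensorFlow-style pseudocode for a single query position; it is a reconstruction rather than a verbatim copy, and the argument names are illustrative (h heads, model dimension d).

```python
import tensorflow as tf

def MultiheadAttention(x, M, P_q, P_k, P_v, P_o):
    """Multi-head attention for a single query position.

    x: input vector with shape [d]
    M: memory matrix with shape [m, d]
    P_q, P_k: projection tensors with shape [h, d, k]
    P_v, P_o: projection tensors with shape [h, d, v]
    Returns a vector with shape [d].
    """
    q = tf.einsum("d,hdk->hk", x, P_q)        # per-head queries
    K = tf.einsum("md,hdk->hmk", M, P_k)      # per-head keys
    V = tf.einsum("md,hdv->hmv", M, P_v)      # per-head values
    logits = tf.einsum("hk,hmk->hm", q, K)    # attention scores per head
    weights = tf.nn.softmax(logits)           # softmax over memory positions
    o = tf.einsum("hm,hmv->hv", weights, V)   # per-head outputs
    return tf.einsum("hv,hdv->d", o, P_o)     # combine heads back to the model dimension
```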
This function extends the attention to multiple heads by projecting the input x and the memory M into multiple subspaces, calculating attention for each, and then recombining the results.
The multi-query attention modification offers a practical solution to enhance the efficiency of Transformer models during incremental inference by reducing the memory access requirements.
The proposed changes lead to faster inference times without significant sacrifices in model performance.
Batched multi-head attention is an extension of the multi-head attention mechanism that processes multiple sequences and multiple positions within those sequences simultaneously. This batching significantly improves computational efficiency and is crucial for leveraging modern hardware's parallel processing capabilities.
Inputs
X: The tensor representing a batch of input sequences, with shape [b, n, d], where b is the batch size, n is the number of positions (queries) in each sequence, and d is the feature dimension.
M: The tensor representing memory (key-value pairs), with shape [b, m, d], where m is the number of memory positions.
mask: A tensor used to prevent information leakage in autoregressive models. It has shape [b, h, n, m] and contains -inf in positions where attention should not be applied (to block backward information flow).
Learned Projections
P_q, P_k, P_v, P_o: These tensors are the learned linear projection weights for queries, keys, values, and outputs, respectively. They transform the input and memory into different representation subspaces for each attention head.
Query, Key, Value Projections
Queries (Q), keys (K), and values (V) are computed using batch matrix multiplication with the corresponding projection matrices. This operation is handled efficiently using tf.einsum, which allows complex tensor contraction operations to be specified succinctly.
Attention Weights
The logits (attention scores before the softmax) are computed as batched dot products between queries and keys. The mask is then added to the logits to apply the masking, followed by a softmax to obtain the attention weights.
Output Computation
The outputs (O) are computed as a weighted sum of the values (V), using the attention weights to combine the information from the values.
Final Output Projection
The final output (Y) is obtained by projecting the attention outputs (O) using the output projection matrix P_o, followed by a sum over the heads.
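Putting these steps together, a batched multi-head attention function might look like the following sketch, a reconstruction in the style of the paper's TensorFlow pseudocode (logit scaling is omitted for brevity, and the names are illustrative).

```python
import tensorflow as tf

def MultiheadAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
    """Batched multi-head attention.

    X:    [b, n, d]    batch of input sequences
    M:    [b, m, d]    batch of memories (key-value sources)
    mask: [b, h, n, m] additive mask, -inf where attention is disallowed
    P_q, P_k: [h, d, k]   P_v, P_o: [h, d, v]
    Returns Y with shape [b, n, d].
    """
    Q = tf.einsum("bnd,hdk->bhnk", X, P_q)        # project queries
    K = tf.einsum("bmd,hdk->bhmk", M, P_k)        # project keys
    V = tf.einsum("bmd,hdv->bhmv", M, P_v)        # project values
    logits = tf.einsum("bhnk,bhmk->bhnm", Q, K)   # batched dot products: attention scores
    weights = tf.nn.softmax(logits + mask)        # apply mask, normalise over memory positions
    O = tf.einsum("bhnm,bhmv->bhnv", weights, V)  # weighted sum of the values
    return tf.einsum("bhnv,hdv->bnd", O, P_o)     # output projection, summed over heads
```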
This function encapsulates the batched multi-head attention mechanism, allowing for efficient processing of batches of sequences in parallel.
Multi-query attention is an adaptation of the multi-head attention mechanism introduced in the seminal Transformer paper by Vaswani et al., 2017. The key distinction lies in the sharing of key (K) and value (V) matrices across all attention heads, rather than having separate sets for each head.
Differences in Multi-Query Attention
Shared Key and Value Matrices: In multi-query attention, all heads reference the same K and V matrices, reducing the memory footprint and potentially improving the efficiency of the attention mechanism, particularly in the context of incremental or online processing where fast access to keys and values is crucial.
Batched Attention Computation: The computation of queries, keys, values, and the final output remains largely similar to standard multi-head attention, but with the shared K and V.
Code Explanation
Batched Multi-Query Attention
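A sketch of the batched multi-query variant, again following the paper's pseudocode. The notable difference from the batched multi-head function above is that P_k and P_v no longer carry a head dimension, so a single K and V tensor is shared by all heads.

```python
import tensorflow as tf

def MultiqueryAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
    """Batched multi-query attention: one set of keys/values shared by all heads.

    X:    [b, n, d]   M: [b, m, d]   mask: [b, h, n, m]
    P_q:  [h, d, k]   P_k: [d, k]    P_v: [d, v]   P_o: [h, d, v]
    Returns Y with shape [b, n, d].
    """
    Q = tf.einsum("bnd,hdk->bhnk", X, P_q)        # per-head queries
    K = tf.einsum("bmd,dk->bmk", M, P_k)          # single shared key tensor
    V = tf.einsum("bmd,dv->bmv", M, P_v)          # single shared value tensor
    logits = tf.einsum("bhnk,bmk->bhnm", Q, K)    # scores of each head against the shared keys
    weights = tf.nn.softmax(logits + mask)        # apply mask, normalise over memory positions
    O = tf.einsum("bhnm,bmv->bhnv", weights, V)   # shared values weighted per head
    return tf.einsum("bhnv,hdv->bnd", O, P_o)     # output projection, summed over heads
```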
Incremental Multi-Query Self-Attention
For incremental attention, typically used in autoregressive generation, the function updates the keys and values based on the current input while maintaining a history of previous computations.
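A sketch of one decoding step in that style, following the paper's pseudocode: the function consumes the current input x, appends one new shared key and value row to the cached prev_K and prev_V, and returns the updated caches alongside the output (the names here are illustrative).

```python
import tensorflow as tf

def MultiquerySelfAttentionIncremental(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
    """One step of incremental multi-query self-attention.

    x:      [b, d]      current input position
    prev_K: [b, m, k]   cached shared keys from previous steps
    prev_V: [b, m, v]   cached shared values from previous steps
    P_q: [h, d, k]   P_k: [d, k]   P_v: [d, v]   P_o: [h, d, v]
    Returns (y, new_K, new_V) with shapes [b, d], [b, m+1, k], [b, m+1, v].
    """
    q = tf.einsum("bd,hdk->bhk", x, P_q)          # per-head queries for the current position
    # Append the new shared key and value for this position to the caches.
    new_K = tf.concat([prev_K, tf.expand_dims(tf.einsum("bd,dk->bk", x, P_k), axis=1)], axis=1)
    new_V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd,dv->bv", x, P_v), axis=1)], axis=1)
    logits = tf.einsum("bhk,bmk->bhm", q, new_K)  # scores against all cached positions
    weights = tf.nn.softmax(logits)               # no mask needed: the cache holds only past positions
    o = tf.einsum("bhm,bmv->bhv", weights, new_V) # per-head weighted sums of the shared values
    y = tf.einsum("bhv,hdv->bd", o, P_o)          # output projection, summed over heads
    return y, new_K, new_V
```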
Memory Efficiency: By sharing K and V, multi-query attention reduces the memory requirements compared to standard multi-head attention, which is beneficial for models with large d or where memory bandwidth is a bottleneck.
Computational Complexity: The computational complexity remains similar to standard multi-head attention, with most operations scaling with the dimensions of the input and model size.
Applicability: This approach is particularly useful in scenarios where incremental or online computation is necessary, as it allows for more efficient access and updates to the key and value representations.
The multi-query attention mechanism provides a viable alternative to traditional multi-head attention, especially in scenarios where memory bandwidth is limited. It achieves competitive quality metrics while offering substantial improvements in speed, particularly during incremental inference. This balance of efficiency and effectiveness makes it a promising option for performance-critical applications in sequence modeling.