# Multi Head Attention

Noam Shazeer, Google

This November 2019 paper discusses an optimisation for the multi-head attention mechanism used in the Transformer model, a popular deep learning model for sequence-to-sequence tasks.

The author <mark style="color:yellow;">introduces a variant called multi-query attention</mark>, which aims to address the inefficiencies in incremental inference by *<mark style="color:yellow;">**sharing keys and values across all attention heads**</mark>*.

This approach reduces the memory bandwidth required during inference, leading to faster decoding times with minimal impact on model quality.

The multi-head attention mechanism is an extension of the self-attention mechanism. It projects the input into several lower-dimensional subspaces, called heads, applies attention in each head independently, and then combines the per-head outputs. This allows the model to learn different types of relationships between tokens simultaneously, enabling it to capture various aspects of the input more effectively.

{% embed url="<https://arxiv.org/abs/1911.02150>" %}
Fast Transformer Decoding: One Write-Head is All You Need
{% endembed %}

### <mark style="color:blue;">Background and Problem Statement</mark>

Transformers use multi-head attention to enable parallel processing of sequence information, significantly speeding up training.

However, during incremental inference, the model can't leverage this parallelism, resulting in slower performance mainly due to the high cost of repeatedly accessing the large keys and values tensors in memory.

### <mark style="color:blue;">Proposed Solution: Multi-Query Attention</mark>

The key innovation is the multi-query attention mechanism, where *<mark style="color:yellow;">**instead of having separate keys and values for each attention head, they are shared across all heads.**</mark>* This strategy decreases the tensor sizes and, consequently, the memory bandwidth needed during decoding, improving inference speed.

### <mark style="color:blue;">Technical Details</mark>

#### <mark style="color:green;">**Neural Attention Mechanism**</mark>

The attention mechanism allows a model to focus on different parts of a sequence when producing an output. It takes a query vector and a set of key-value pairs, computes attention weights using the query and keys, and applies these weights to the values to produce an output.

```python
import tensorflow as tf

def DotProductAttention(q, K, V):
    """Dot-Product Attention on one query.
    Args:
        q: a vector with shape [k]
        K: a matrix with shape [m, k]
        V: a matrix with shape [m, v]
    Returns:
        y: a vector with shape [v]
    """
    logits = tf.einsum("k,mk->m", q, K)        # one score per memory position
    weights = tf.nn.softmax(logits)            # normalise scores so they sum to 1
    return tf.einsum("m,mv->v", weights, V)    # weighted sum of the value rows
```

This function computes the attention output for a <mark style="color:blue;">single query vector</mark> <mark style="color:yellow;">`q`</mark> using the key <mark style="color:yellow;">`K`</mark> and value <mark style="color:yellow;">`V`</mark> matrices. The <mark style="color:yellow;">`einsum`</mark> function is used for efficient tensor contractions and broadcasting.
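To make the three steps concrete, here is a minimal NumPy sketch of the same computation on toy inputs. The `softmax` helper, the lower-case function name, and the input values are illustrative assumptions, not part of the paper's code.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - np.max(z, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)

def dot_product_attention(q, K, V):
    """q: [k], K: [m, k], V: [m, v] -> y: [v]."""
    logits = np.einsum("k,mk->m", q, K)       # similarity of q to each key
    weights = softmax(logits)                 # attention weights over m positions
    return np.einsum("m,mv->v", weights, V)   # weighted sum of the values

# Toy inputs (made up for illustration): m = 2 memory positions, k = v = 2.
q = np.array([1.0, 0.0])
K = np.array([[1.0, 0.0],
              [0.0, 1.0]])
V = np.array([[10.0, 0.0],
              [0.0, 10.0]])

weights = softmax(K @ q)
y = dot_product_attention(q, K, V)
print(weights, y)
```

Because the query matches the first key more closely, the first value row receives the larger weight and dominates the output.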

#### <mark style="color:green;">**Multi-Head Attention**</mark>

In the Transformer model, the attention mechanism is extended to multi-head attention, allowing the model to aggregate information from different representation subspaces at different positions.

**Multi-Head Attention Code**

Here's the code for multi-head attention:

```python
def MultiheadAttention(x, M, P_q, P_k, P_v, P_o):
    """Multi-head Attention on one query.
    Args:
        x: a vector with shape [d]
        M: a matrix with shape [m, d]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [h, d, k]
        P_v: a tensor with shape [h, d, v]
        P_o: a tensor with shape [h, d, v]
    Returns:
        y: a vector with shape [d]
    """
    q = tf.einsum("d,hdk->hk", x, P_q)
    K = tf.einsum("md,hdk->hmk", M, P_k)
    V = tf.einsum("md,hdv->hmv", M, P_v)
    logits = tf.einsum("hk,hmk->hm", q, K)
    weights = tf.nn.softmax(logits)
    o = tf.einsum("hm,hmv->hv", weights, V)
    y = tf.einsum("hv,hdv->d", o, P_o)
    return y
```

This function <mark style="color:yellow;">extends the attention to multiple heads by projecting the input</mark> <mark style="color:yellow;">`x`</mark> <mark style="color:yellow;">and the memory</mark> <mark style="color:yellow;">`M`</mark> <mark style="color:yellow;">into multiple subspaces, calculating attention for each, and then recombining the results.</mark>
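One way to check this reading is to compare the einsum formulation with an explicit loop over heads. The following NumPy sketch does that on random inputs; the function names, the `softmax` helper, and the dimensions are assumptions made for illustration.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - np.max(z, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)

def multihead_attention(x, M, P_q, P_k, P_v, P_o):
    """Einsum version, mirroring the TensorFlow code above."""
    q = np.einsum("d,hdk->hk", x, P_q)
    K = np.einsum("md,hdk->hmk", M, P_k)
    V = np.einsum("md,hdv->hmv", M, P_v)
    logits = np.einsum("hk,hmk->hm", q, K)
    weights = softmax(logits)
    o = np.einsum("hm,hmv->hv", weights, V)
    return np.einsum("hv,hdv->d", o, P_o)

def multihead_attention_loop(x, M, P_q, P_k, P_v, P_o):
    """Same computation, one head at a time."""
    y = np.zeros_like(x)
    for i in range(P_q.shape[0]):
        q = x @ P_q[i]           # [k]    query in head i's subspace
        K = M @ P_k[i]           # [m, k] keys in head i's subspace
        V = M @ P_v[i]           # [m, v] values in head i's subspace
        w = softmax(K @ q)       # [m]    attention weights for head i
        o = w @ V                # [v]    head i's attention output
        y += P_o[i] @ o          # [d]    project back and sum over heads
    return y

# Hypothetical small dimensions for the check.
rng = np.random.default_rng(0)
h, d, k, v, m = 4, 8, 2, 2, 5
x = rng.normal(size=d)
M = rng.normal(size=(m, d))
P_q = rng.normal(size=(h, d, k))
P_k = rng.normal(size=(h, d, k))
P_v = rng.normal(size=(h, d, v))
P_o = rng.normal(size=(h, d, v))

y_einsum = multihead_attention(x, M, P_q, P_k, P_v, P_o)
y_loop = multihead_attention_loop(x, M, P_q, P_k, P_v, P_o)
print(np.allclose(y_einsum, y_loop))
```

The loop makes the recombination step explicit: each head's output is projected back to `d` dimensions by its slice of `P_o`, and the results are summed.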

The multi-query attention modification, described later on this page, changes this structure to reduce memory access during incremental inference, leading to faster inference without a significant sacrifice in model performance.

#### <mark style="color:green;">Batched Multi-Head Attention Explanation</mark>

Batched multi-head attention is an extension of the multi-head attention mechanism that processes multiple sequences and multiple positions within those sequences simultaneously. This batching significantly improves computational efficiency and is crucial for leveraging modern hardware's parallel processing capabilities.

#### <mark style="color:purple;">Components of Batched Multi-Head Attention</mark>

<mark style="color:purple;">**Inputs**</mark>

* <mark style="color:yellow;">`X`</mark><mark style="color:yellow;">:</mark> The tensor representing a batch of input sequences, with shape <mark style="color:blue;">`[b, n, d]`</mark>, where <mark style="color:yellow;">`b`</mark> is the batch size, <mark style="color:yellow;">`n`</mark> is the number of positions (queries) in each sequence, and <mark style="color:yellow;">`d`</mark> is the feature dimension.
* <mark style="color:yellow;">`M`</mark><mark style="color:yellow;">:</mark> The tensor representing memory (key-value pairs), with shape <mark style="color:blue;">`[b, m, d]`</mark>, where <mark style="color:yellow;">`m`</mark> is the number of memory positions.
* <mark style="color:yellow;">`mask`</mark><mark style="color:yellow;">:</mark> A tensor used to prevent information leakage in autoregressive models. It has shape <mark style="color:blue;">`[b, h, n, m]`</mark> and contains <mark style="color:blue;">`-inf`</mark> at positions that must not be attended to (those later than the query position) and `0` elsewhere, so that adding it to the logits drives the masked attention weights to zero.
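As an illustration of the `mask` input, the sketch below builds an additive causal mask with NumPy. It assumes self-attention, where the memory is the sequence itself and so `m == n`; the helper name `causal_mask` is hypothetical.

```python
import numpy as np

def causal_mask(b, h, n):
    """Additive mask with shape [b, h, n, n].

    Entry [.., i, j] is 0 when query i may attend to position j (j <= i)
    and -inf when position j lies after the query (j > i).
    """
    i = np.arange(n)[:, None]   # query positions as a column
    j = np.arange(n)[None, :]   # memory positions as a row
    mask_2d = np.where(j > i, -np.inf, 0.0)
    return np.broadcast_to(mask_2d, (b, h, n, n))

mask = causal_mask(b=2, h=3, n=4)
print(mask[0, 0])
```

The same `[n, n]` pattern is shared by every batch element and head, which is why broadcasting is enough here.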

<mark style="color:purple;">**Learned Projections**</mark>

* <mark style="color:blue;">`P_q`</mark><mark style="color:blue;">,</mark> <mark style="color:blue;">`P_k`</mark><mark style="color:blue;">,</mark> <mark style="color:blue;">`P_v`</mark><mark style="color:blue;">,</mark> <mark style="color:blue;">`P_o`</mark><mark style="color:blue;">:</mark> These tensors are the learned linear projection weights for queries, keys, values, and outputs, respectively. They transform the input and memory into different representation subspaces for each attention head.

#### <mark style="color:green;">Batched Attention Computation</mark>

<mark style="color:purple;">**Query, Key, Value Projections**</mark>

* Queries <mark style="color:blue;">(</mark><mark style="color:blue;">`Q`</mark><mark style="color:blue;">),</mark> keys <mark style="color:blue;">(</mark><mark style="color:blue;">`K`</mark><mark style="color:blue;">),</mark> and values <mark style="color:blue;">(</mark><mark style="color:blue;">`V`</mark><mark style="color:blue;">)</mark> are computed using batch matrix multiplication with the corresponding projection matrices. This operation is efficiently handled using <mark style="color:yellow;">`tf.einsum`</mark>, which allows specifying complex tensor contraction operations succinctly.

<mark style="color:purple;">**Attention Weights**</mark>

* The <mark style="color:blue;">logits</mark> (attention scores before <mark style="color:blue;">softmax</mark>) are computed by <mark style="color:blue;">batched dot products</mark> between queries and keys. The <mark style="color:blue;">`mask`</mark> is then added to the logits to apply the masking, followed by a <mark style="color:blue;">softmax</mark> to obtain the attention weights.

<mark style="color:purple;">**Output Computation**</mark>

* The outputs <mark style="color:yellow;">(</mark><mark style="color:yellow;">`O`</mark><mark style="color:yellow;">)</mark> are computed as a weighted sum of the values <mark style="color:yellow;">(</mark><mark style="color:yellow;">`V`</mark><mark style="color:yellow;">),</mark> using the attention weights. This operation combines the information from the values based on the attention weights.

<mark style="color:purple;">**Final Output Projection**</mark>

* The final output <mark style="color:yellow;">(</mark><mark style="color:yellow;">`Y`</mark><mark style="color:yellow;">)</mark> is obtained by projecting the attention outputs <mark style="color:yellow;">(</mark><mark style="color:yellow;">`O`</mark><mark style="color:yellow;">)</mark> using the output projection matrix <mark style="color:yellow;">`P_o`</mark>, followed by a sum over the heads.

```python
def MultiheadAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
    """Multi-head Attention for a batch of sequences.
    Args:
        X: a tensor with shape [b, n, d]
        M: a tensor with shape [b, m, d]
        mask: a tensor with shape [b, h, n, m]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [h, d, k]
        P_v: a tensor with shape [h, d, v]
        P_o: a tensor with shape [h, d, v]
    Returns:
        Y: a tensor with shape [b, n, d]
    """
    Q = tf.einsum("bnd,hdk->bhnk", X, P_q)
    K = tf.einsum("bmd,hdk->bhmk", M, P_k)
    V = tf.einsum("bmd,hdv->bhmv", M, P_v)
    logits = tf.einsum("bhnk,bhmk->bhnm", Q, K)
    weights = tf.nn.softmax(logits + mask)
    O = tf.einsum("bhnm,bhmv->bhnv", weights, V)
    Y = tf.einsum("bhnv,hdv->bnd", O, P_o)
    return Y
```

This function encapsulates the batched multi-head attention mechanism, allowing for efficient processing of batches of sequences in parallel.
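The effect of the mask can be checked numerically. The NumPy sketch below ports the batched function and confirms that, with a causal mask, changing the last position of the input leaves the outputs at earlier positions unchanged. The helper names, dimensions, and random inputs are assumptions for illustration.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - np.max(z, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)

def multihead_attention_batched(X, M, mask, P_q, P_k, P_v, P_o):
    Q = np.einsum("bnd,hdk->bhnk", X, P_q)
    K = np.einsum("bmd,hdk->bhmk", M, P_k)
    V = np.einsum("bmd,hdv->bhmv", M, P_v)
    logits = np.einsum("bhnk,bhmk->bhnm", Q, K)
    weights = softmax(logits + mask)            # masked positions get weight 0
    O = np.einsum("bhnm,bhmv->bhnv", weights, V)
    return np.einsum("bhnv,hdv->bnd", O, P_o)

# Hypothetical small dimensions.
rng = np.random.default_rng(0)
b, n, d, h, k, v = 2, 5, 8, 4, 2, 2
X = rng.normal(size=(b, n, d))
P_q = rng.normal(size=(h, d, k))
P_k = rng.normal(size=(h, d, k))
P_v = rng.normal(size=(h, d, v))
P_o = rng.normal(size=(h, d, v))

# Causal mask of shape [n, n]; it broadcasts over the b and h axes.
idx = np.arange(n)
mask = np.where(idx[None, :] > idx[:, None], -np.inf, 0.0)

# Self-attention: the memory is the input sequence itself.
Y = multihead_attention_batched(X, X, mask, P_q, P_k, P_v, P_o)

# Perturb only the last position and recompute.
X2 = X.copy()
X2[:, -1, :] += 1.0
Y2 = multihead_attention_batched(X2, X2, mask, P_q, P_k, P_v, P_o)

print(np.allclose(Y[:, :-1], Y2[:, :-1]))   # earlier positions are unaffected
```

All `n` positions are processed in one call here, which is the parallelism that training exploits and that incremental decoding cannot use.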

#### <mark style="color:green;">Analysis of Multi-Query Attention</mark>

Multi-query attention is an adaptation of the multi-head attention mechanism, which was introduced in the seminal Transformer paper by Vaswani et al. (2017). The key distinction lies in the sharing of key (`K`) and value (`V`) matrices across all attention heads, rather than having separate sets for each head.

<mark style="color:green;">**Differences in Multi-Query Attention**</mark>

* **Shared Key and Value Matrices:** In multi-query attention, all heads reference the same `K` and `V` matrices, reducing the memory footprint and potentially improving the efficiency of the attention mechanism, particularly in the context of incremental or online processing where fast access to keys and values is crucial.
* **Batched Attention Computation:** The computation of queries, logits, weights, and the final output is the same as in standard multi-head attention, except that the heads dimension `h` is dropped from `K`, `V`, and their projections `P_k` and `P_v`.

**Code Explanation**

<mark style="color:green;">**Batched Multi-Query Attention**</mark>

```python
def MultiqueryAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
    """Multi-query Attention for a batch of sequences.
    Shapes: X [b, n, d], M [b, m, d], mask [b, h, n, m],
    P_q [h, d, k], P_k [d, k], P_v [d, v], P_o [h, d, v].
    Returns Y with shape [b, n, d].
    """
    Q = tf.einsum("bnd,hdk->bhnk", X, P_q)  # Compute queries for each head.
    K = tf.einsum("bmd,dk->bmk", M, P_k)    # Compute shared keys.
    V = tf.einsum("bmd,dv->bmv", M, P_v)    # Compute shared values.
    logits = tf.einsum("bhnk,bmk->bhnm", Q, K)  # Compute attention scores.
    weights = tf.nn.softmax(logits + mask)      # Apply mask and softmax.
    O = tf.einsum("bhnm,bmv->bhnv", weights, V) # Compute weighted sum.
    Y = tf.einsum("bhnv,hdv->bnd", O, P_o)      # Final output projection.
    return Y
```
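One way to see the relationship between the two variants: multi-query attention gives the same result as multi-head attention whose key and value projections happen to be identical in every head. The NumPy sketch below checks this on random inputs; the function names, dimensions, and the all-zero mask are assumptions for illustration.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - np.max(z, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)

def multihead_attention_batched(X, M, mask, P_q, P_k, P_v, P_o):
    Q = np.einsum("bnd,hdk->bhnk", X, P_q)
    K = np.einsum("bmd,hdk->bhmk", M, P_k)      # per-head keys
    V = np.einsum("bmd,hdv->bhmv", M, P_v)      # per-head values
    logits = np.einsum("bhnk,bhmk->bhnm", Q, K)
    weights = softmax(logits + mask)
    O = np.einsum("bhnm,bhmv->bhnv", weights, V)
    return np.einsum("bhnv,hdv->bnd", O, P_o)

def multiquery_attention_batched(X, M, mask, P_q, P_k, P_v, P_o):
    Q = np.einsum("bnd,hdk->bhnk", X, P_q)
    K = np.einsum("bmd,dk->bmk", M, P_k)        # shared keys, no h axis
    V = np.einsum("bmd,dv->bmv", M, P_v)        # shared values, no h axis
    logits = np.einsum("bhnk,bmk->bhnm", Q, K)
    weights = softmax(logits + mask)
    O = np.einsum("bhnm,bmv->bhnv", weights, V)
    return np.einsum("bhnv,hdv->bnd", O, P_o)

# Hypothetical small dimensions.
rng = np.random.default_rng(0)
b, n, m, d, h, k, v = 2, 3, 5, 8, 4, 2, 2
X = rng.normal(size=(b, n, d))
M = rng.normal(size=(b, m, d))
mask = np.zeros((b, h, n, m))                   # no masking in this check
P_q = rng.normal(size=(h, d, k))
P_k = rng.normal(size=(d, k))
P_v = rng.normal(size=(d, v))
P_o = rng.normal(size=(h, d, v))

Y_mq = multiquery_attention_batched(X, M, mask, P_q, P_k, P_v, P_o)

# Copy the shared projections into every head and run multi-head attention.
P_k_tiled = np.broadcast_to(P_k, (h, d, k))
P_v_tiled = np.broadcast_to(P_v, (h, d, v))
Y_mh = multihead_attention_batched(X, M, mask, P_q, P_k_tiled, P_v_tiled, P_o)

print(np.allclose(Y_mq, Y_mh))
```

The heads still differ through `P_q` and `P_o`; only the keys and values are shared.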

<mark style="color:green;">**Incremental Multi-Query Self-Attention**</mark>

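For incremental attention, typically used in autoregressive generation, the function processes one position per call. It appends the key and value computed from the current input to the cached tensors `prev_K` and `prev_V` from earlier steps, rather than recomputing them, and returns the extended caches for use in the next step.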

```python
def MultiquerySelfAttentionIncremental(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
    """Multi-query self-attention for one decoding step.
    Shapes: x [b, d], prev_K [b, m, k], prev_V [b, m, v],
    P_q [h, d, k], P_k [d, k], P_v [d, v], P_o [h, d, v].
    Returns y [b, d], K [b, m+1, k], V [b, m+1, v].
    """
    q = tf.einsum("bd,hdk->bhk", x, P_q)  # Compute query for current step.
    # Append the new key and value along the position axis (axis 1 of [b, m, k]).
    K = tf.concat([prev_K, tf.expand_dims(tf.einsum("bd,dk->bk", x, P_k), axis=1)], axis=1)
    V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd,dv->bv", x, P_v), axis=1)], axis=1)
    logits = tf.einsum("bhk,bmk->bhm", q, K)  # Attention scores for current step.
    weights = tf.nn.softmax(logits)           # Apply softmax to get weights.
    o = tf.einsum("bhm,bmv->bhv", weights, V) # Weighted sum for current step.
    y = tf.einsum("bhv,hdv->bd", o, P_o)      # Output for current step.
    return y, K, V
```
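A decoding loop built on this step function can be sketched as follows. The NumPy version below assumes caches shaped `[b, m, k]` and `[b, m, v]` that grow along the position axis, starts from empty caches, and checks that stepping through a sequence gives the same result as the batched function with a causal mask. Function names, dimensions, and inputs are illustrative assumptions.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - np.max(z, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)

def multiquery_attention_batched(X, M, mask, P_q, P_k, P_v, P_o):
    Q = np.einsum("bnd,hdk->bhnk", X, P_q)
    K = np.einsum("bmd,dk->bmk", M, P_k)
    V = np.einsum("bmd,dv->bmv", M, P_v)
    logits = np.einsum("bhnk,bmk->bhnm", Q, K)
    weights = softmax(logits + mask)
    O = np.einsum("bhnm,bmv->bhnv", weights, V)
    return np.einsum("bhnv,hdv->bnd", O, P_o)

def multiquery_self_attention_incremental(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
    q = np.einsum("bd,hdk->bhk", x, P_q)
    new_k = np.einsum("bd,dk->bk", x, P_k)[:, None, :]   # [b, 1, k]
    new_v = np.einsum("bd,dv->bv", x, P_v)[:, None, :]   # [b, 1, v]
    K = np.concatenate([prev_K, new_k], axis=1)          # [b, m+1, k]
    V = np.concatenate([prev_V, new_v], axis=1)          # [b, m+1, v]
    logits = np.einsum("bhk,bmk->bhm", q, K)
    weights = softmax(logits)
    o = np.einsum("bhm,bmv->bhv", weights, V)
    y = np.einsum("bhv,hdv->bd", o, P_o)
    return y, K, V

# Hypothetical small dimensions.
rng = np.random.default_rng(0)
b, n, d, h, k, v = 2, 4, 8, 4, 2, 2
X = rng.normal(size=(b, n, d))
P_q = rng.normal(size=(h, d, k))
P_k = rng.normal(size=(d, k))
P_v = rng.normal(size=(d, v))
P_o = rng.normal(size=(h, d, v))

# Incremental pass: one position per step, caches start empty.
K_cache = np.zeros((b, 0, k))
V_cache = np.zeros((b, 0, v))
steps = []
for t in range(n):
    y_t, K_cache, V_cache = multiquery_self_attention_incremental(
        X[:, t, :], K_cache, V_cache, P_q, P_k, P_v, P_o)
    steps.append(y_t)
Y_incremental = np.stack(steps, axis=1)                  # [b, n, d]

# Batched pass over the whole sequence with a causal mask.
idx = np.arange(n)
mask = np.where(idx[None, :] > idx[:, None], -np.inf, 0.0)
Y_batched = multiquery_attention_batched(X, X, mask, P_q, P_k, P_v, P_o)

print(np.allclose(Y_incremental, Y_batched))
```

No mask is needed in the incremental step because the cache only ever contains the current and earlier positions.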

### <mark style="color:blue;">**Performance Considerations**</mark>

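* **Memory Efficiency:** By sharing `K` and `V`, multi-query attention shrinks the key and value tensors by a factor of `h`, from `[b, h, m, k]` and `[b, h, m, v]` to `[b, m, k]` and `[b, m, v]`. This matters most where memory bandwidth is the bottleneck, as it is when these tensors are re-read at every step of incremental decoding.
* **Computational Complexity:** The amount of arithmetic remains similar to standard multi-head attention, since the query projections, attention scores, and output projections are still computed per head. The saving comes from reading smaller `K` and `V` tensors rather than from doing less computation.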
* **Applicability:** This approach is particularly useful in scenarios where incremental or online computation is necessary, as it allows for more efficient access and updates to the key and value representations.
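The size reduction can be made concrete with a little arithmetic. The sketch below counts the bytes held by the cached keys and values for one attention layer. The helper `kv_cache_bytes` and the configuration values are hypothetical, chosen only for illustration, and 2 bytes per element assumes 16-bit storage.

```python
def kv_cache_bytes(batch, positions, heads, key_dim, value_dim, bytes_per_element=2):
    """Bytes held by the cached K and V tensors for one attention layer."""
    per_position = heads * (key_dim + value_dim)
    return batch * positions * per_position * bytes_per_element

# Hypothetical configuration (not taken from the paper).
b, m, h, k, v = 8, 1024, 8, 64, 64

multi_head = kv_cache_bytes(b, m, h, k, v)    # K: [b, h, m, k], V: [b, h, m, v]
multi_query = kv_cache_bytes(b, m, 1, k, v)   # K: [b, m, k],    V: [b, m, v]
ratio = multi_head // multi_query

print(multi_head, multi_query, ratio)
```

Under these assumptions the cache shrinks by exactly the number of heads, which is the amount of key and value data that no longer has to be read at each decoding step.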

### <mark style="color:blue;">**Conclusion**</mark>

The multi-query attention mechanism provides a viable alternative to traditional multi-head attention, especially in scenarios where memory bandwidth is limited. It achieves competitive quality metrics while offering substantial improvements in speed, particularly during incremental inference. This balance of efficiency and effectiveness makes it a promising option for performance-critical applications in sequence modeling.

