Multi-Head Attention

Fast Transformer Decoding: One Write-Head is All You Need

Noam Shazeer, Google

This November 2019 paper discusses an optimisation for the multi-head attention mechanism used in the Transformer model, a popular deep learning model for sequence-to-sequence tasks.

The authors introduce a variant called multi-query attention, which aims to address the inefficiencies in incremental inference by sharing keys and values across all attention heads.

This approach reduces the memory bandwidth required during inference, leading to faster decoding times with minimal impact on model quality.

The multi-head attention mechanism is an extension of self-attention. Rather than computing a single attention function, it projects the input embeddings into several lower-dimensional subspaces, called heads, and applies attention within each head independently. This allows the model to learn different types of relationships between tokens in parallel, capturing different aspects of the input more effectively.

Background and Problem Statement

Transformers use multi-head attention to enable parallel processing of sequence information, significantly speeding up training.

However, during incremental inference the model generates one token at a time and cannot exploit this parallelism; decoding is slow mainly because the large key and value tensors must be repeatedly loaded from memory at every step.

Proposed Solution: Multi-Query Attention

The key innovation is the multi-query attention mechanism, where instead of having separate keys and values for each attention head, they are shared across all heads. This strategy decreases the tensor sizes and, consequently, the memory bandwidth needed during decoding, improving inference speed.
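
To make the saving concrete, here is a rough back-of-the-envelope sketch (the batch size, sequence length, head count, head dimension and fp16 storage below are illustrative assumptions, not figures from the paper) comparing the per-layer key/value cache of multi-head attention with the shared cache of multi-query attention:

# Illustrative sizes (assumptions, not from the paper): batch b, memory length m,
# heads h, per-head key/value dimension k, fp16 storage (2 bytes per element).
b, m, h, k, bytes_per_elem = 8, 2048, 16, 64, 2

# Multi-head attention caches separate keys and values per head: [b, h, m, k] each.
kv_cache_mha = 2 * b * h * m * k * bytes_per_elem

# Multi-query attention caches one shared K and one shared V: [b, m, k] each.
kv_cache_mqa = 2 * b * m * k * bytes_per_elem

print(f"Multi-head KV cache:  {kv_cache_mha / 2**20:.1f} MiB")   # 64.0 MiB
print(f"Multi-query KV cache: {kv_cache_mqa / 2**20:.1f} MiB")   # 4.0 MiB
print(f"Reduction factor:     {kv_cache_mha // kv_cache_mqa}x")  # 16x, i.e. the number of heads h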

Technical Details

Neural Attention Mechanism

The attention mechanism allows a model to focus on different parts of a sequence when producing an output. It takes a query vector and a set of key-value pairs, computes attention weights using the query and keys, and applies these weights to the values to produce an output.

import tensorflow as tf

def DotProductAttention(q, K, V):
    """Dot-Product Attention on one query.
    Args:
        q: a vector with shape [k]
        K: a matrix with shape [m, k]
        V: a matrix with shape [m, v]
    Returns:
        y: a vector with shape [v]
    """
    logits = tf.einsum("k,mk->m", q, K)      # similarity of the query to each key
    weights = tf.nn.softmax(logits)          # normalise scores into attention weights
    return tf.einsum("m,mv->v", weights, V)  # weighted sum of the values

This function computes the attention output for a single query vector q using the key K and value V matrices. The einsum function is used for efficient tensor contractions and broadcasting.
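
A quick way to sanity-check the shapes (a minimal sketch, assuming TensorFlow 2.x and the function as defined above; the sizes are arbitrary):

m, k, v = 6, 4, 8                      # memory positions, key dim, value dim
q = tf.random.normal([k])              # a single query vector
K = tf.random.normal([m, k])           # one key per memory position
V = tf.random.normal([m, v])           # one value per memory position

y = DotProductAttention(q, K, V)
print(y.shape)                         # (8,): one output vector of size v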

Multi-Head Attention

In the Transformer model, the attention mechanism is extended to multi-head attention, allowing the model to aggregate information from different representation subspaces at different positions.

Multi-Head Attention Code

Here's the code for multi-head attention:

def MultiheadAttention(x, M, P_q, P_k, P_v, P_o):
    """Multi-head Attention on one query.
    Args:
        x: a vector with shape [d]
        M: a matrix with shape [m, d]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [h, d, k]
        P_v: a tensor with shape [h, d, v]
        P_o: a tensor with shape [h, d, v]
    Returns:
        y: a vector with shape [d]
    """
    q = tf.einsum("d,hdk->hk", x, P_q)
    K = tf.einsum("md,hdk->hmk", M, P_k)
    V = tf.einsum("md,hdv->hmv", M, P_v)
    logits = tf.einsum("hk,hmk->hm", q, K)
    weights = tf.nn.softmax(logits)
    o = tf.einsum("hm,hmv->hv", weights, V)
    y = tf.einsum("hv,hdv->d", o, P_o)
    return y

This function extends the attention to multiple heads by projecting the input x and the memory M into multiple subspaces, calculating attention for each, and then recombining the results.

The multi-query attention modification offers a practical solution to enhance the efficiency of Transformer models during incremental inference by reducing the memory access requirements.

The proposed changes lead to faster inference times without significant sacrifices in model performance.

Batched Multi-Head Attention Explanation

Batched multi-head attention is an extension of the multi-head attention mechanism that processes multiple sequences and multiple positions within those sequences simultaneously. This batching significantly improves computational efficiency and is crucial for leveraging modern hardware's parallel processing capabilities.

Components of Batched Multi-Head Attention

Inputs

  • X: The tensor representing a batch of input sequences, with shape [b, n, d], where b is the batch size, n is the number of positions (queries) in each sequence, and d is the feature dimension.

  • M: The tensor representing memory (key-value pairs), with shape [b, m, d], where m is the number of memory positions.

  • mask: A tensor used to prevent information leakage in autoregressive models. It has shape [b, h, n, m] and contains -inf in positions where attention must not be applied, so that a query position cannot attend to later memory positions.

Learned Projections

  • P_q, P_k, P_v, P_o: These tensors are the learned linear projection weights for queries, keys, values, and outputs, respectively. They transform the input and memory into different representation subspaces for each attention head.

Batched Attention Computation

Query, Key, Value Projections

  • Queries (Q), keys (K), and values (V) are computed using batch matrix multiplication with the corresponding projection matrices. This operation is efficiently handled using tf.einsum, which allows specifying complex tensor contraction operations succinctly.

Attention Weights

  • The logits (attention scores before softmax) are computed by batched dot products between queries and keys. The mask is then added to the logits to apply the masking, followed by a softmax to obtain the attention weights.

Output Computation

  • The outputs (O) are computed as a weighted sum of the values (V), using the attention weights. This operation combines the information from the values based on the attention weights.

Final Output Projection

  • The final output (Y) is obtained by projecting the attention outputs (O) using the output projection matrix P_o, followed by a sum over the heads.

def MultiheadAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
    """Multi-head Attention for a batch of sequences.
    Args:
        X: a tensor with shape [b, n, d]
        M: a tensor with shape [b, m, d]
        mask: a tensor with shape [b, h, n, m]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [h, d, k]
        P_v: a tensor with shape [h, d, v]
        P_o: a tensor with shape [h, d, v]
    Returns:
        Y: a tensor with shape [b, n, d]
    """
    Q = tf.einsum("bnd,hdk->bhnk", X, P_q)
    K = tf.einsum("bmd,hdk->bhmk", M, P_k)
    V = tf.einsum("bmd,hdv->bhmv", M, P_v)
    logits = tf.einsum("bhnk,bhmk->bhnm", Q, K)
    weights = tf.nn.softmax(logits + mask)
    O = tf.einsum("bhnm,bhmv->bhnv", weights, V)
    Y = tf.einsum("bhnv,hdv->bnd", O, P_o)
    return Y

This function encapsulates the batched multi-head attention mechanism, allowing for efficient processing of batches of sequences in parallel.
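
To run it autoregressively, the mask must forbid each query position from attending to later memory positions. Below is a minimal sketch of building such a causal mask and calling the batched function, assuming TensorFlow 2.x and self-attention (so M = X and m = n); all sizes are illustrative:

b, n, d, h, k, v = 2, 5, 16, 4, 8, 8    # illustrative sizes
m = n                                   # self-attention: memory positions = query positions

X = tf.random.normal([b, n, d])
M = X                                   # keys and values come from the same sequence
P_q = tf.random.normal([h, d, k])
P_k = tf.random.normal([h, d, k])
P_v = tf.random.normal([h, d, v])
P_o = tf.random.normal([h, d, v])

# Causal mask: 0 where attention is allowed, -inf where a query would look ahead.
allowed = tf.linalg.band_part(tf.ones([n, m]), -1, 0)          # lower-triangular matrix
mask = tf.where(allowed > 0, 0.0, float("-inf"))               # [n, m]
mask = tf.broadcast_to(mask[None, None, :, :], [b, h, n, m])   # [b, h, n, m]

Y = MultiheadAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o)
print(Y.shape)                                                  # (2, 5, 16)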

Analysis of Multi-Query Attention

Multi-query attention is an adaptation of the multi-head attention mechanism introduced in the seminal Transformer paper by Vaswani et al., 2017. The key distinction lies in the sharing of key (K) and value (V) matrices across all attention heads, rather than having separate sets for each head.

Differences in Multi-Query Attention

  • Shared Key and Value Matrices: In multi-query attention, all heads reference the same K and V matrices, reducing the memory footprint and potentially improving the efficiency of the attention mechanism, particularly in the context of incremental or online processing where fast access to keys and values is crucial.

  • Batched Attention Computation: The computation of queries, keys, values, and the final output remains largely similar to the standard multi-head attention but with the shared K and V.

Code Explanation

Batched Multi-Query Attention

def MultiqueryAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
    """Multi-query Attention for a batch of sequences.
    Shapes match the multi-head version above, except that the key and value
    projections are shared across heads: P_k has shape [d, k] and P_v has
    shape [d, v] (no head dimension).
    """
    Q = tf.einsum("bnd,hdk->bhnk", X, P_q)       # Per-head queries.
    K = tf.einsum("bmd,dk->bmk", M, P_k)         # Shared keys.
    V = tf.einsum("bmd,dv->bmv", M, P_v)         # Shared values.
    logits = tf.einsum("bhnk,bmk->bhnm", Q, K)   # Attention scores.
    weights = tf.nn.softmax(logits + mask)       # Apply mask and softmax.
    O = tf.einsum("bhnm,bmv->bhnv", weights, V)  # Weighted sum of values.
    Y = tf.einsum("bhnv,hdv->bnd", O, P_o)       # Final output projection.
    return Y

Incremental Multi-Query Self-Attention

For incremental attention, typically used in autoregressive generation, the function updates the keys and values based on the current input while maintaining a history of previous computations.

def MultiquerySelfAttentionIncremental(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
    """Multi-query self-attention for a single decoding step.
    x has shape [b, d]; prev_K ([b, m, k]) and prev_V ([b, m, v]) are the cached
    keys and values; returns the step output y ([b, d]) and the updated K and V.
    """
    q = tf.einsum("bd,hdk->bhk", x, P_q)  # Per-head query for the current step.
    # Append the new shared key and value to the cache along the memory axis.
    K = tf.concat([prev_K, tf.expand_dims(tf.einsum("bd,dk->bk", x, P_k), axis=1)], axis=1)
    V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd,dv->bv", x, P_v), axis=1)], axis=1)
    logits = tf.einsum("bhk,bmk->bhm", q, K)   # Attention scores for the current step.
    weights = tf.nn.softmax(logits)            # Softmax over cached memory positions.
    o = tf.einsum("bhm,bmv->bhv", weights, V)  # Weighted sum for the current step.
    y = tf.einsum("bhv,hdv->bd", o, P_o)       # Output for the current step.
    return y, K, V
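
A minimal sketch of driving this function in a step-by-step decoding loop, assuming TensorFlow 2.x; the input embeddings are random placeholders and the cache starts empty, so the cached keys and values grow by one position per step:

b, d, h, k, v, steps = 2, 16, 4, 8, 8, 6   # illustrative sizes

P_q = tf.random.normal([h, d, k])
P_k = tf.random.normal([d, k])             # shared key projection
P_v = tf.random.normal([d, v])             # shared value projection
P_o = tf.random.normal([h, d, v])

# Empty key/value cache: no memory positions yet.
K = tf.zeros([b, 0, k])
V = tf.zeros([b, 0, v])

x = tf.random.normal([b, d])               # placeholder embedding of the first token
for _ in range(steps):
    y, K, V = MultiquerySelfAttentionIncremental(x, K, V, P_q, P_k, P_v, P_o)
    x = tf.random.normal([b, d])           # placeholder for the next token's embedding
                                           # (in a real model, y feeds the rest of the layer stack)
print(K.shape, V.shape)                    # (2, 6, 8) (2, 6, 8): one cached position per step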

Performance Considerations

  • Memory Efficiency: By sharing K and V, multi-query attention reduces the memory requirements compared to standard multi-head attention, which is beneficial for models with large d or where memory bandwidth is a bottleneck.

  • Computational Complexity: The computational complexity remains similar to standard multi-head attention, with most operations scaling with the dimensions of the input and model size.

  • Applicability: This approach is particularly useful in scenarios where incremental or online computation is necessary, as it allows for more efficient access and updates to the key and value representations.

Conclusion

The multi-query attention mechanism provides a viable alternative to traditional multi-head attention, especially in scenarios where memory bandwidth is limited. It achieves competitive quality metrics while offering substantial improvements in speed, particularly during incremental inference. This balance of efficiency and effectiveness makes it a promising option for performance-critical applications in sequence modeling.

Fast Transformer Decoding: One Write-Head is All You Need, Noam Shazeer, 2019 (arXiv.org)