Transformer Feed-Forward Layers Are Key-Value Memories
This September 2021 paper by Geva et al. presents a novel perspective on the role of feed-forward layers in transformer language models, showing that they function as key-value memories, where each key captures specific input patterns and each value represents a distribution over the output vocabulary.
The key idea is that feed-forward layers emulate neural memories.
The input vector is multiplied by the keys (the rows of the first parameter matrix) and passed through the activation function to produce memory coefficients, which then weight the distributions over the output vocabulary stored in the values (the rows of the second parameter matrix).
The output of the feed-forward layer is the weighted sum of these values.
The diagram below illustrates this mechanism.
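To make the analogy concrete, here is a minimal sketch of a feed-forward sublayer written as a memory lookup. It assumes a ReLU activation with bias terms dropped (the form under which the key-value analogy is exact); the function and variable names are illustrative, not the paper's code.

```python
import torch
import torch.nn.functional as F

def ffn_as_memory(x, K, V):
    """A feed-forward sublayer read as a key-value memory (biases omitted).

    x: (d,) input vector at one position;
    K: (dm, d) keys, the rows of the first parameter matrix;
    V: (dm, d) values, the rows of the second parameter matrix.
    """
    coeffs = F.relu(K @ x)   # memory coefficients: how strongly each key fires
    return V.t() @ coeffs    # output: coefficient-weighted sum of value vectors
```

Each of the `dm` hidden dimensions is one memory cell: its key row scores the input, and its value row is what that score mixes into the layer's output.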
The authors' experiments reveal several interesting findings:
The learned patterns captured by the keys are human-interpretable, ranging from shallow patterns in lower layers to more semantic ones in upper layers. This suggests that feed-forward layers play a crucial role in capturing and encoding relevant features from the input sequence.
The values complement the keys by inducing output distributions that concentrate probability mass on tokens likely to appear immediately after each pattern, particularly in the upper layers. This indicates that feed-forward layers contribute to the model's ability to predict the next token based on the detected patterns.
The output of a feed-forward layer is a composition of its memories, which is then refined through the model's layers via residual connections to produce the final output distribution. This highlights the importance of the interactions between feed-forward layers and the overall model architecture in generating coherent and contextually appropriate outputs.
The authors' findings shed light on the previously under-explored role of feed-forward layers in transformer-based language models. By demonstrating that these layers function as pattern detectors and contribute to the composition of the final output distribution, the paper deepens our understanding of how transformers process and generate language.
Input Patterns Captured
The authors further investigate the role of keys in feed-forward layers by examining the input patterns they capture.
Through a well-designed experiment, they provide strong evidence that each key vector corresponds to specific patterns over the input sequence.
The experiment involves retrieving the training examples most associated with a given key (i.e., the input prefixes that yield the highest memory coefficients) and having human experts identify and classify patterns within these examples.
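In code, retrieving a key's trigger examples reduces to scoring every prefix representation against that key and keeping the top hits. A hedged sketch follows; the tensor shapes and the choice of `k` are illustrative, and the paper's exact retrieval pipeline may differ.

```python
import torch

def top_trigger_examples(prefix_reprs, key, k=25):
    """Rank input prefixes by their memory coefficient for a single key.

    prefix_reprs: (n, d) hidden states feeding the feed-forward layer,
                  one row per training prefix;
    key: (d,) one row of the layer's first parameter matrix;
    k: number of trigger examples to keep (illustrative).
    """
    coeffs = prefix_reprs @ key            # this key's coefficient per prefix
    return torch.topk(coeffs, k).indices   # indices of the top trigger examples
```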
The results are striking: for almost every key in the sample, experts could identify a small set of well-defined patterns that cover most of the associated examples. This finding strongly supports the hypothesis that keys act as pattern detectors over the input sequence.
Moreover, the analysis reveals an interesting trend across the layers of the transformer model.
In lower layers (1-9), the detected patterns are predominantly shallow, often involving prefixes that share the last word. In contrast, upper layers (10-16) are characterized by more semantic patterns, where prefixes come from similar contexts but lack clear surface-form similarities.
This observation aligns with recent findings that lower layers in deep contextualized models encode shallow features, while upper layers capture more semantic information.
To further validate this trend, the authors conduct an additional experiment where they apply local modifications to the top trigger examples of randomly sampled keys.
By removing either the first, last, or a random token from the input and measuring the change in the memory coefficient, they demonstrate that the model considers the end of an example as more salient than the beginning for predicting the next token.
Notably, removing the last token has less impact in upper layers, supporting the conclusion that upper-layer keys are less correlated with shallow patterns.
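The ablation itself is simple to express. In the sketch below, `coeff_fn` is a hypothetical callable standing in for a forward pass that returns the studied key's memory coefficient for a given prefix; it is not an API from the paper.

```python
import random

def ablate(tokens, position):
    """Return a copy of the prefix with one token removed."""
    position = position % len(tokens)   # so -1 addresses the last token
    return tokens[:position] + tokens[position + 1:]

def coefficient_drop(coeff_fn, tokens, position):
    """Relative drop in a key's memory coefficient after removing one token."""
    before = coeff_fn(tokens)
    after = coeff_fn(ablate(tokens, position))
    return (before - after) / before

# positions probed in the experiment: first, last, or a random token, e.g.
#   coefficient_drop(coeff_fn, tokens, 0)
#   coefficient_drop(coeff_fn, tokens, -1)
#   coefficient_drop(coeff_fn, tokens, random.randrange(len(tokens)))
```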
Information Stored in the Corresponding Values
After analyzing the input patterns captured by the keys, the authors turn their attention to the information stored in the corresponding values. They present compelling evidence that each value vector can be interpreted as a distribution over the output vocabulary, and that this distribution complements the patterns in the corresponding key, particularly in the upper layers of the model.
To convert value vectors into probability distributions, the authors multiply each value vector by the output embedding matrix and apply a softmax function. While the resulting distribution is uncalibrated due to the input-dependent memory coefficient, the ranking induced by the distribution remains informative.
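This conversion is a single projection followed by a softmax; a minimal sketch, assuming `E` is the model's output embedding matrix:

```python
import torch

def value_to_distribution(v, E):
    """Interpret a value vector as a distribution over the vocabulary.

    v: (d,) value vector; E: (vocab, d) output embedding matrix.
    The probabilities are uncalibrated (the memory coefficient is ignored),
    but the ranking they induce over tokens is still informative.
    """
    return torch.softmax(E @ v, dim=-1)
```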
The authors then investigate the relationship between the top-ranked token according to the value vector and the next token in the top-1 trigger example associated with the corresponding key. They calculate the agreement rate, i.e., the fraction of memory cells where the value's top prediction matches the next token of the key's top trigger example. The results reveal an interesting trend: in lower layers (1-10), the agreement rate is close to zero, but starting from layer 11, it rises quickly to 3.5%. This rate is far higher than would be expected from random token prediction, indicating that upper-layer memories possess non-trivial predictive power.
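Given those distributions, the agreement rate is a one-line comparison. In this sketch, `trigger_next_tokens` is assumed to hold the token id that follows each key's top-1 trigger prefix:

```python
import torch

def agreement_rate(value_dists, trigger_next_tokens):
    """Fraction of cells whose value's top token matches the next token
    of the key's top-1 trigger example.

    value_dists: (dm, vocab) distributions from value_to_distribution;
    trigger_next_tokens: (dm,) next-token id per key's top trigger prefix.
    """
    top_preds = value_dists.argmax(dim=-1)
    return (top_preds == trigger_next_tokens).float().mean().item()
```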
Furthermore, the authors examine the rank that the value vector's distribution assigns to the next token of a trigger example. They find that this rank tends to improve (i.e., decrease) through the layers, meaning the next token receives higher probability in the upper layers.
To automatically detect values with high agreement rates, the authors analyze the probability of the values' top predictions. They observe that distributions with higher maximum probabilities are more likely to agree with their key's top trigger example. By focusing on the 100 values with the highest probability across all layers and dimensions (mostly in the upper layers), they find that in almost half of the cases, there is at least one trigger example that agrees with the value's top prediction.
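Selecting those high-confidence values amounts to ranking memories by the peak of their distribution, as in this illustrative snippet:

```python
import torch

def most_confident_values(value_dists, k=100):
    """Indices of the k memories whose value distribution peaks highest;
    these are the cells most likely to agree with their trigger examples.

    value_dists: (dm, vocab) softmax-normalized value distributions.
    """
    peak_probs = value_dists.max(dim=-1).values
    return torch.topk(peak_probs, k).indices
```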
These findings support the hypothesis that values in the upper layers store information on how to directly predict the output (the distribution of the next word) from the input (patterns in the prefix). The fact that this clear correlation between keys' patterns and values' distributions is not observed in the lower layers suggests that these layers may not operate in the same embedding space as the upper layers.
Information from Multiple Memory Cells
In the final part of the paper, the authors investigate how information from multiple memory cells across different layers is aggregated to form the model's final prediction.
They provide evidence that each feed-forward layer combines multiple memories to produce a distribution that is qualitatively different from its component memory value distributions. These layer-wise distributions are then combined through residual connections in a refinement process, where each feed-forward layer updates the residual's distribution to ultimately form the model's output.
The authors first examine intra-layer memory composition by analyzing the behavior of randomly sampled prefixes from the validation set. They find that a typical example triggers hundreds of active memories per layer, but the majority of cells remain inactive. Interestingly, the number of active memories drops around layer 10, coinciding with the transition from shallow to semantic patterns observed in the expert annotations.
To determine if the feed-forward layer's output is dominated by a single memory cell or a composition of multiple memories, the authors count the instances where the layer's top prediction differs from all of its memories' top predictions. They find that in at least 68% of the examples, the layer's final prediction differs from every one of the memories' predictions, indicating that the output is typically a result of composing multiple memories.
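The check behind this count can be sketched as follows. Treating a cell as active when its memory coefficient is positive is an assumption of this sketch, consistent with the ReLU view used earlier:

```python
import torch

def is_true_composition(layer_logits, value_dists, coeffs):
    """Check whether a layer's top token differs from all active memories'.

    layer_logits: (vocab,) the layer's output projected onto the vocabulary;
    value_dists: (dm, vocab) per-memory value distributions;
    coeffs: (dm,) memory coefficients; a cell counts as active when positive.
    """
    layer_top = int(layer_logits.argmax())
    active_tops = value_dists[coeffs > 0].argmax(dim=-1)
    return bool((active_tops != layer_top).all())
```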
Next, the authors explore inter-layer prediction refinement, hypothesizing that the model uses the sequential composition through residual connections to refine its predictions from layer to layer. They measure how often the residual vector's probability distribution matches the model's final output and find that roughly a third of the predictions are determined in the bottom few layers, with this number growing rapidly from layer 10 onwards.
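A rough proxy for this measurement is to project the residual stream after each layer onto the vocabulary and find the first layer whose top token already matches the final prediction. The sketch below ignores layer normalization and other implementation details:

```python
import torch

def decision_layer(residuals, E, final_token):
    """First layer whose residual already 'votes' for the final prediction.

    residuals: (n_layers, d) residual stream read after each layer;
    E: (vocab, d) output embedding matrix;
    final_token: token id of the model's final prediction.
    """
    for layer, r in enumerate(residuals, start=1):
        if int((E @ r).argmax()) == final_token:
            return layer
    return None  # the prediction only emerges at the very end
```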
The authors also analyze the probability mass assigned by each layer's residual vector to the model's final prediction. The results show that not only the top prediction's identity but also the model's confidence in its decision is refined as the information progresses through the layers.
To understand the refinement process at each layer, the authors examine how often the residual's top prediction changes after interacting with the feed-forward layer and whether this change results from the feed-forward layer overriding the residual or from a true composition. They find that in most cases, the residual's top prediction ends up being the model's prediction, and when the residual's prediction does change, it rarely changes to the feed-forward layer's prediction. Instead, composing the residual's distribution with that of the feed-forward layer produces a "compromise" prediction, which is equal to neither.
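A simplified way to classify these cases is to compare the top token of the residual, the feed-forward output, and their combination in vocabulary space. Adding vocabulary logits (and ignoring layer norm) is an assumption of this sketch, not the paper's exact procedure:

```python
import torch

def composition_case(resid_logits, ffn_logits):
    """Classify the interaction between residual and feed-forward output.

    resid_logits, ffn_logits: (vocab,) vocabulary logits for the residual
    vector and the feed-forward output at the same position.
    """
    combined = int((resid_logits + ffn_logits).argmax())
    if combined == int(resid_logits.argmax()):
        return "residual kept"          # the most common outcome
    if combined == int(ffn_logits.argmax()):
        return "feed-forward override"  # rare
    return "compromise"                 # matches neither component's top token
```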
Finally, the authors manually analyze random cases of last-layer composition and find that the feed-forward layer can modify the residual output by shifting the prediction to either a semantically distant word or a related word, suggesting that feed-forward layers tune the residual predictions at varying granularity, even in the final layer.
These findings provide valuable insights into the complex dynamics of information aggregation and refinement in transformer-based language models. The authors demonstrate that the model's predictions are the result of a multi-step process involving intra-layer memory composition and inter-layer prediction refinement through residual connections.
The observation that the feed-forward layers act as an elimination mechanism to "veto" the residual's top prediction and shift probability mass towards other candidates highlights the importance of the interaction between the feed-forward layers and the residual connections in shaping the model's output.
Moreover, the manual analysis of last-layer composition reveals that feed-forward layers can modify predictions at different levels of granularity, from semantically distant words to related words, underscoring the versatility and flexibility of the model's composition mechanism.
Overall, this part of the paper provides a comprehensive analysis of how the transformer model aggregates and refines information from multiple memories across layers to generate its final predictions. The authors' findings contribute to our understanding of the internal workings of transformer-based language models and showcase the complex interplay between feed-forward layers, residual connections, and the model's output.
These insights can inform future research on model interpretability, as well as the development of more efficient and transparent architectures for natural language processing tasks. By shedding light on the composition and refinement processes within the model, the authors pave the way for further investigations into how to enhance the model's performance, robustness, and explainability.
Conclusion
This paper makes significant contributions to our understanding of the inner workings of transformer-based language models by unveiling the role of feed-forward layers.
The authors present compelling evidence that feed-forward layers function as key-value memories, where keys capture human-interpretable input patterns and values induce distributions over the output vocabulary that correlate with the next-token distribution of the corresponding key patterns.
The findings highlight the importance of feed-forward layers in the transformer architecture and shed light on the complex process of information aggregation and refinement that leads to the model's final predictions.
The observation that the model's output is formed through a combination of intra-layer memory composition and inter-layer prediction refinement via residual connections provides valuable insights into the mechanisms underlying the transformer's impressive performance on natural language processing tasks.
Moreover, the authors identify several important research directions that stem from their work. The question of how the embedding space transforms across layers and the interplay between feed-forward layers and self-attention layers is a fascinating avenue for future investigation. Understanding these dynamics could lead to the development of more efficient and interpretable transformer architectures.
The potential generalizability of the findings to other transformer-based models, such as BERT encoders and neural translation models, opens up exciting possibilities for exploring the role of feed-forward layers in various NLP tasks. By extending the analysis to these different settings, researchers can gain a more comprehensive understanding of how transformers process and represent language.
Furthermore, the practical implications of this work are far-reaching. The insights gained from studying feed-forward layers can inform the development of interpretability methods, help address concerns related to training-data privacy, and guide the design of novel architectures that overcome the limitations of current transformer models.
In summary, this paper represents a significant step forward in demystifying the operation of transformer-based language models. By illuminating the role of feed-forward layers as key-value memories and elucidating the processes of memory composition and prediction refinement, the authors provide a solid foundation for future research on model interpretability, efficiency, and robustness. As the field of NLP continues to advance at a rapid pace, understanding the mechanisms behind the success of transformers will be crucial in developing the next generation of language models that are not only powerful but also transparent, reliable, and aligned with human values.