Here is an explanation of the Query/Key/Value computation for people with a basic familiarity with machine learning. The ideas are fairly intuitive once you strip away the matrix manipulations.
Let's say that you are given a very good embedding of each English word as a vector of numbers. The idea of embeddings is that each dimension captures a different characteristic of the word. For example, dimension 37 might capture gender and dimension 56 might capture how royal the word is. So "king" and "queen" will have very different scores in dimension 37, but both words will have a high score in dimension 56. (In practice, such characteristics tend to be spread across combinations of dimensions rather than aligned with single ones, but the intuition is the same.) Embeddings like these have been available for many years, e.g., word2vec.
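To make the idea concrete, here is a hand-written sketch of a tiny embedding table. The vectors, and the use of dimension 0 for gender and dimension 1 for royalty (standing in for dimensions 37 and 56 above), are invented for illustration; real learned embeddings have hundreds of dimensions.

```python
# Hypothetical 3-dimensional embeddings, hand-written for illustration only.
# Dimension 0 ~ "gender" (+ male, - female), dimension 1 ~ "royalty",
# dimension 2 ~ "is a food". Real embeddings are learned, not hand-picked.
EMBEDDINGS = {
    "king":  [0.9, 0.95, 0.0],
    "queen": [-0.9, 0.95, 0.0],
    "apple": [0.0, 0.0, 0.9],
}

GENDER, ROYALTY = 0, 1

def score(word, dim):
    """Return the value of one embedding dimension for a word."""
    return EMBEDDINGS[word][dim]

# "king" and "queen" disagree on the gender dimension...
print(score("king", GENDER), score("queen", GENDER))    # 0.9 -0.9
# ...but both score high on the royalty dimension.
print(score("king", ROYALTY), score("queen", ROYALTY))  # 0.95 0.95
```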
The challenge is this: given a sentence with many words, how can you best encode the meaning of the sentence in a vector? The simplest approach is to take the embeddings for all the words and average them together to get a summary vector. This is a reasonable approach, and will work fine for simple tasks like assigning a positive or negative sentiment to the sentence. For example, it will do a good job of separating “I love this amazing product” from “I hate this terrible product”. This approach is analogous to the “bag of words” model.
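A minimal sketch of the averaging approach, assuming a made-up two-dimensional embedding table in which dimension 0 happens to encode sentiment. `EMB` and `sentence_vector` are hypothetical names, not part of any library.

```python
# Toy 2-dimensional embeddings (hypothetical): dimension 0 ~ sentiment,
# dimension 1 ~ "is about a product". Values are made up for illustration.
EMB = {
    "i": [0.0, 0.0], "this": [0.0, 0.0], "product": [0.0, 1.0],
    "love": [0.9, 0.0], "amazing": [0.8, 0.0],
    "hate": [-0.9, 0.0], "terrible": [-0.8, 0.0],
}

def sentence_vector(sentence):
    """Bag-of-words summary: the element-wise mean of the word embeddings."""
    vectors = [EMB[w] for w in sentence.lower().split()]
    dim = len(vectors[0])
    return [sum(v[d] for v in vectors) / len(vectors) for d in range(dim)]

pos = sentence_vector("I love this amazing product")
neg = sentence_vector("I hate this terrible product")
# The sign of the sentiment dimension separates the two sentences.
print(pos[0] > 0, neg[0] < 0)  # True True
```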
This simple model is missing two big things. First, when interpreting the meaning of each word, it uses the original embedding of that word without any regard for the context around the word. So “bank” will be assigned the same meaning in the sentences “we got money from the bank” and “we sat by the river bank.” Second, the model does not take into account the ordering of the words, so “the dog bit the man” and “the man bit the dog” will get exactly the same summary vector.
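Both blind spots can be checked directly with a mean-pooled summary. This sketch assumes random toy embeddings (seeded so it is reproducible); any fixed lookup table would behave the same way.

```python
import random

random.seed(0)
WORDS = ["the", "dog", "bit", "man", "we", "got", "money", "from", "bank",
         "sat", "by", "river"]
# Hypothetical 4-dimensional embeddings: one fixed random vector per word.
EMB = {w: [random.uniform(-1, 1) for _ in range(4)] for w in WORDS}

def sentence_vector(sentence):
    """Bag-of-words summary: the element-wise mean of the word vectors."""
    vectors = [EMB[w] for w in sentence.split()]
    return [sum(col) / len(vectors) for col in zip(*vectors)]

# 1. Word order is lost: both sentences contain the same multiset of words,
#    so their means agree (up to floating-point rounding).
a = sentence_vector("the dog bit the man")
b = sentence_vector("the man bit the dog")
print(all(abs(x - y) < 1e-12 for x, y in zip(a, b)))  # True

# 2. Context is ignored: the lookup gives "bank" the same vector in both sentences.
money_bank = [EMB[w] for w in "we got money from the bank".split()][-1]
river_bank = [EMB[w] for w in "we sat by the river bank".split()][-1]
print(money_bank == river_bank)  # True
```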
Said another way, our simple model lacks the expressiveness to distinguish meaningful differences between sentences. Transformers address these deficiencies by making the model more expressive, while keeping it computationally efficient and easy to train.
First, the transformer recognizes that we need to reinterpret each word based on the other words in the sentence. Each “layer” of the transformer can be seen as doing a reinterpretation of each word based on its context. Successive layers are applied to reach iteratively better reinterpretations.
In order to reinterpret the word “bank” in the sentence “we got money from the bank”, we first need to score all of the other words based on their relevance to “bank”. Obviously, “money” should get a higher relevance score than “from”. A natural way to get a relevance score is to take the dot product of each other word’s embedding with the embedding for “bank”. (The dot product of two vectors is a common metric for gauging their similarity.)
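A sketch of this naive relevance score, using raw embeddings only. The three dimensions (roughly finance, nature and "function word") and all the numbers are hypothetical, chosen so that “money” ends up closest to “bank”.

```python
# Hypothetical 3-d embeddings: dim 0 ~ finance, dim 1 ~ nature, dim 2 ~ function word.
EMB = {
    "we":    [0.0, 0.0, 0.8],
    "got":   [0.1, 0.0, 0.5],
    "money": [0.9, 0.0, 0.0],
    "from":  [0.0, 0.0, 0.9],
    "the":   [0.0, 0.0, 1.0],
    "bank":  [0.7, 0.3, 0.0],
}

def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

sentence = "we got money from the bank".split()
target = "bank"
# Score every other word by the similarity of its embedding to "bank".
scores = {w: dot(EMB[w], EMB[target]) for w in sentence if w != target}
print(max(scores, key=scores.get))       # money
print(scores["money"] > scores["from"])  # True
```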
However, this is not quite expressive enough. For example, in the sentence “the food tastes disgusting”, the meaning of the word “disgusting” is not very similar to the meaning of “food”, but “disgusting” is clearly very relevant to the interpretation of “food.” To account for this and make the model more expressive, the idea is to maintain two separate embeddings for each word, used only in the relevance-score dot product: a “query” and a “key”. So when reinterpreting the word “bank” in the sentence “we got money from the bank”, we grab the key embeddings for all the words and take the dot product of each one with the query embedding for “bank”. For example, the dot product of the key for “money” with the query for “bank” tells us how relevant “money” is for reinterpreting “bank.” Keys and queries need to be separate to break the symmetry of the dot product: with a single embedding per word, the relevance of “money” to “bank” would always equal the relevance of “bank” to “money”, which is wrong in general. In the phrase “Monopoly money”, the word “Monopoly” significantly changes our interpretation of “money”, but “money” does not significantly change our interpretation of “Monopoly.”
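The symmetry problem and its fix can be seen with just two words. The vectors below are hand-picked assumptions, not learned values; in a real transformer the query and key vectors come from learned parameters.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

# One shared embedding per word: the score is forced to be symmetric.
EMB = {"monopoly": [0.6, 0.8], "money": [0.8, 0.6]}
print(dot(EMB["monopoly"], EMB["money"]) == dot(EMB["money"], EMB["monopoly"]))  # True

# Separate (hypothetical) query and key vectors for each word.
QUERY = {"monopoly": [0.0, 0.5], "money": [1.0, 0.0]}
KEY = {"monopoly": [0.9, 0.0], "money": [0.0, 0.25]}

def relevance(word, other):
    """How relevant `other` is when reinterpreting `word`: query(word) . key(other)."""
    return dot(QUERY[word], KEY[other])

print(relevance("money", "monopoly"))  # 0.9: "Monopoly" matters a lot for "money"
print(relevance("monopoly", "money"))  # 0.125: "money" matters little for "Monopoly"
```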
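Now that we have these relevance scores, we normalize them to sum to 1 (in practice with a softmax, which also makes every weight positive), and then we reinterpret “bank” as a relevance-weighted average of the value vectors of all the words in the sentence, including “bank” itself. The values are a third set of per-word embeddings, describing what each word contributes to the reinterpretation once it has been judged relevant. This is called the attention mechanism, because when reinterpreting each word we selectively "pay attention" to the words that are most relevant to it.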
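Putting the pieces together for a single word gives a sketch like the following. It uses a softmax for the normalization step; the original Transformer additionally divides each score by the square root of the key dimension before the softmax, which is left out here to match the description. The key, value and query vectors are toy assumptions.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def softmax(xs):
    """Turn raw scores into positive weights that sum to 1."""
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Reinterpret one word as the softmax(query . key)-weighted average of the values."""
    weights = softmax([dot(query, k) for k in keys])
    dim = len(values[0])
    new_vec = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return new_vec, weights

# Hypothetical 2-d vectors for "we got money from the bank" (includes "bank" itself).
words = ["we", "got", "money", "from", "the", "bank"]
keys = [[0, 0], [0, 0], [3, 0], [0, 0], [0, 0], [1, 0]]
values = [[0, 0], [0, 0], [1, 0], [0, 0], [0, 0], [0, 1]]
query_bank = [1, 0]

new_bank, weights = attend(query_bank, keys, values)
print(words[max(range(len(words)), key=weights.__getitem__)])  # money
print(round(sum(weights), 6))  # 1.0
```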
There are a number of details omitted in this description, but hopefully it gives a general sense. The black magic of designing ML architectures is developing the right intuition for what is just expressive enough to capture meaningful relationships, while still being easy to compute and leveraging modern hardware.
It's a bit like deciding how many legs to put on a table. It's not so much that 4 legs is theoretically correct, but rather that 2 legs definitely doesn't work, 3 legs seems okay but feels a bit iffy if we put our weight in certain places and it's not too much more expensive to add a fourth leg anyway, and 5 legs definitely seems like overkill.
———————————————
Major omitted details:
- The word embeddings are not fixed, but learned from scratch as trainable parameters.
- The query, key and value vectors for each word are actually computed by multiplying the word's current vector (the input embedding in the first layer, the previous layer's output afterwards) by three different learned matrices, often written W_Q, W_K and W_V; the resulting query, key and value vectors are what get called Q, K and V. The reason for doing this is a bit subtle. In order to have successive layers of reinterpretation, you cannot keep using the same query vector for each word in subsequent layers, because you have reinterpreted what the word means. After the first layer, you no longer have the word "bank"; you just have a reinterpreted vector of numbers, so there is no way to look up its query vector. Multiplying the new vector by three different learned matrices is a clever way around this, since the matrices work on any vector, wherever it came from.
- Positional information is encoded by adding a positional vector to the word embedding (learned in many models; the original Transformer paper used fixed sinusoidal vectors), so that the embedding for “bank” will look a little different if it is at the beginning of the sentence vs. the end of the sentence.
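A compact sketch of these details in plain Python: token and positional vectors are summed, each layer projects whatever vectors it receives with its own matrices (`w_q`, `w_k`, `w_v`, hypothetical names), and two layers are stacked. The parameters are random rather than trained, and a real transformer layer also includes multiple attention heads, a feed-forward sublayer, residual connections and layer normalization, all omitted here.

```python
import math
import random

random.seed(0)
D = 4        # vector size (tiny, for illustration)
MAX_LEN = 8  # longest sentence supported

def rand_matrix(rows, cols, scale):
    return [[random.gauss(0, scale) for _ in range(cols)] for _ in range(rows)]

def matvec(v, m):
    """Row vector v (length rows) times matrix m (rows x cols)."""
    return [sum(v[i] * m[i][j] for i in range(len(v))) for j in range(len(m[0]))]

def softmax(xs):
    mx = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - mx) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

class AttentionLayer:
    """One self-attention layer with its own projection matrices (random, not trained)."""

    def __init__(self, d):
        self.w_q = rand_matrix(d, d, 0.5)
        self.w_k = rand_matrix(d, d, 0.5)
        self.w_v = rand_matrix(d, d, 0.5)

    def __call__(self, xs):
        # The projections accept any vector, including an already
        # reinterpreted one, so no per-word lookup is needed after layer 1.
        qs = [matvec(x, self.w_q) for x in xs]
        ks = [matvec(x, self.w_k) for x in xs]
        vs = [matvec(x, self.w_v) for x in xs]
        out = []
        for q in qs:
            weights = softmax([sum(a * b for a, b in zip(q, k)) for k in ks])
            out.append([sum(w * v[j] for w, v in zip(weights, vs))
                        for j in range(len(vs[0]))])
        return out

VOCAB = ["we", "got", "money", "from", "the", "bank", "sat", "by", "river",
         "dog", "bit", "man"]
token_emb = {w: rand_matrix(1, D, 1.0)[0] for w in VOCAB}  # trainable in practice
pos_emb = rand_matrix(MAX_LEN, D, 0.5)                      # trainable in practice

def embed(sentence):
    """Word vector plus the vector for its position in the sentence."""
    return [[t + p for t, p in zip(token_emb[w], pos_emb[i])]
            for i, w in enumerate(sentence.split())]

layers = [AttentionLayer(D), AttentionLayer(D)]  # successive reinterpretations

def encode(sentence):
    xs = embed(sentence)
    for layer in layers:
        xs = layer(xs)
    return xs

bank_money = encode("we got money from the bank")[-1]
bank_river = encode("we sat by the river bank")[-1]
print(bank_money != bank_river)  # True: "bank" now depends on its context
print(encode("the dog bit the man")[0] != encode("the man bit the dog")[0])  # True
```

Because of the positional vectors, the first “the” comes out different in “the dog bit the man” and “the man bit the dog”; without them, the attention outputs for the two sentences would just be permutations of each other.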