First, we take a sequence of words and represent it as a grid of numbers: each column of the grid is a separate word, and each row of the grid is a measurement of some property of that word. Words with similar meanings are likely to have similar numerical values on a row-by-row basis.
(During the training process, we create a dictionary of all possible words, with a column of numbers for each of those words. More on this later!)
This grid is called the "context". Typical systems will have a context that spans several thousand columns and several thousand rows. Right now, context length (column count) is rapidly expanding (1k to 2k to 8k to 32k to 100k+!!) while the dimensionality of each word in the dictionary (row count) is pretty static at around 4k to 8k...
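To make the "grid of numbers" concrete, here is a minimal sketch. The four-word dictionary and its 3-row vectors are made up for illustration (real models learn the values and use thousands of rows), and the names EMBEDDINGS, build_context and cosine are my own, not from any library.

```python
import math

# Toy dictionary: every word maps to one column of numbers.
# The values are invented; a trained model would have learned them.
EMBEDDINGS = {
    "the": [0.1, 0.0, 0.2],
    "cat": [0.9, 0.8, 0.1],
    "dog": [0.8, 0.9, 0.1],
    "sat": [0.2, 0.1, 0.9],
}

def build_context(words):
    """Return the context as a list of columns, one column per word."""
    return [EMBEDDINGS[w] for w in words]

def cosine(a, b):
    """Similarity of two columns: close to 1.0 means they point the same way."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

context = build_context(["the", "cat", "sat"])
similar = cosine(EMBEDDINGS["cat"], EMBEDDINGS["dog"])
different = cosine(EMBEDDINGS["cat"], EMBEDDINGS["sat"])
```

With these made-up numbers, "cat" and "dog" come out far more similar than "cat" and "sat", which is the row-by-row similarity described above.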
Anyhow, the Transformer architecture takes that grid and passes it through a multi-layer transformation algorithm. The functionality of each layer is identical: receive the grid of numbers as input, then perform a mathematical transformation on the grid of numbers, and pass it along to the next layer.
Most systems these days have around 64 or 96 layers.
After the grid of numbers has passed through all the layers, we can use it to generate a new column of numbers that predicts the properties of some word that would maximize the coherence of the sequence if we add it to the end of the grid. We take that new column of numbers and comb through our dictionary to find the actual word that most-closely matches the properties we're looking for.
That word is the winner! We add it to the sequence as a new column, remove the first column if the context is already full, and run the whole process again! That's how we generate long text-completions one word at a time :D
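The loop above can be sketched like this. Everything here is a stand-in: predict_next_column just averages the columns instead of running a real stack of layers, and DICTIONARY, MAX_CONTEXT and the function names are hypothetical. Real systems also score every dictionary entry and sample from those scores rather than always taking the single closest match.

```python
import math

# Hypothetical toy dictionary (word -> column of numbers).
DICTIONARY = {
    "a": [1.0, 0.0],
    "b": [0.0, 1.0],
    "c": [0.7, 0.7],
}
MAX_CONTEXT = 4  # assumed context length, tiny for illustration

def predict_next_column(columns):
    """Stand-in for the stack of layers: here just the average of the columns."""
    n = len(columns)
    return [sum(col[i] for col in columns) / n for i in range(len(columns[0]))]

def closest_word(column):
    """Comb through the dictionary for the word with the highest cosine similarity."""
    def cosine(u, v):
        dot = sum(x * y for x, y in zip(u, v))
        return dot / (math.hypot(*u) * math.hypot(*v))
    return max(DICTIONARY, key=lambda w: cosine(DICTIONARY[w], column))

def generate(words, steps):
    """Append one predicted word per step, keeping only the newest columns."""
    words = list(words)
    for _ in range(steps):
        window = words[-MAX_CONTEXT:]  # drops the oldest words once the context is full
        columns = [DICTIONARY[w] for w in window]
        words.append(closest_word(predict_next_column(columns)))
    return words

completion = generate(["a", "b"], steps=3)
```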
So the interesting bits are located within that stack of layers. This is why it's called "deep learning".
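The mathematical transformation in each layer is called "self-attention", and it involves a lot of matrix multiplications and dot-product calculations with a learned set of "Query, Key and Value" matrices. (Each layer also runs a feed-forward sub-network after the attention step, but attention is the distinctive part.)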
It can be hard to understand what these layers are doing linguistically, but we can use image-processing and computer-vision as a good metaphor, since images are also grids of numbers, and we've all seen how photo-filters can transform that entire grid in lots of useful ways...
You can think of each layer in the transformer as being like a "mask" or "filter" that selects various interesting features from the grid, and then tweaks the image with respect to those masks and filters.
In image processing, you might apply a color-channel mask (chroma key) to select all the green pixels in the background, so that you can erase the background and replace it with other footage. Or you might apply a "gaussian blur" that mixes each pixel with its nearest neighbors, to create a blurring effect. Or you might do the inverse of a gaussian blur, to create a "sharpening" operation that helps you find edges...
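As a tiny illustration of the blur and sharpen filters, here is a sketch on a single row of pixels. It assumes a 3-tap kernel with weights 1/4, 1/2, 1/4 (a common small approximation of a gaussian), and the sharpening step is an "unsharp mask" rather than a literal inverse of the blur. The function names are my own.

```python
def blur_row(pixels):
    """Mix each pixel with its two nearest neighbors using weights 1/4, 1/2, 1/4."""
    out = []
    last = len(pixels) - 1
    for i in range(len(pixels)):
        left = pixels[max(i - 1, 0)]      # repeat the edge pixel at the borders
        right = pixels[min(i + 1, last)]
        out.append(0.25 * left + 0.5 * pixels[i] + 0.25 * right)
    return out

def sharpen_row(pixels, amount=1.0):
    """Unsharp mask: push each pixel away from its blurred version."""
    blurred = blur_row(pixels)
    return [p + amount * (p - b) for p, b in zip(pixels, blurred)]

row = [0.0, 0.0, 4.0, 0.0, 0.0]
blurred = blur_row(row)      # [0.0, 1.0, 2.0, 1.0, 0.0]
sharpened = sharpen_row(row) # [0.0, -1.0, 6.0, -1.0, 0.0]
```

The bright pixel gets smeared into its neighbors by the blur, and exaggerated against them by the sharpen.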
But the basic idea is that you have a library of operations that you can apply to a grid of pixels, in order to transform the image (or part of the image) for a desired effect. And you can stack these transforms to create arbitrarily-complex effects.
The same thing is true in a linguistic transformer, where a text sequence is modeled as a matrix.
The language-model has a library of "Query, Key and Value" matrices (which were learned during training) that are roughly analogous to the "Masks and Filters" we use on images.
Each layer in the Transformer architecture attempts to identify some features of the incoming linguistic data, and then, having identified those features, it adds what it found back onto the matrix (a "residual" connection), so that the next layer sees the original plus the transformation.
We don't know exactly what each of these layers is doing in a linguistic model, but we can imagine it's probably doing things like: performing part-of-speech identification (in this context, is the word "ring" a noun or a verb?), reference resolution (who does the word "he" refer to in this sentence?), etc, etc.
And the "dot-product" calculations in each attention layer are there to make each word "entangled" with its neighbors, so that we can discover all the ways that each word is connected to all the other words in its context.
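Here is a rough sketch of how those dot products turn into "how much does each word care about each other word". It assumes the standard scaled dot-product form (dot product divided by the square root of the vector length, then a softmax per row); the toy vectors and the name attention_weights are mine.

```python
import math

def softmax(scores):
    """Turn arbitrary scores into positive weights that sum to 1."""
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(queries, keys):
    """Row i holds how strongly word i attends to every word in the context."""
    d = len(keys[0])
    rows = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
        rows.append(softmax(scores))
    return rows

# Three toy words; the first two point in similar directions.
vectors = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
weights = attention_weights(vectors, vectors)
```

Each row of weights is then used to blend the Value vectors, so a word's new column becomes a weighted mix of the words it is most connected to.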
So... that's how we generate word-predictions (aka "inference") at runtime!
But why does it work?
To understand why it's so effective, you have to understand a bit about the training process.
During inference, data always flows in the same direction. That's why it's called a "feed-forward" network.
But during training, there's another step called "back-propagation".
For each document in our training corpus, we go through all the steps I described above, passing each word into our feed-forward neural network and making word-predictions. We start out with a completely randomized set of QKV matrices, so the results are often really bad!
During training, when we make a prediction, we KNOW what word is supposed to come next. And we have a numerical representation of each word (4096 numbers in a column!) so we can measure the error between our predictions and the actual next word. Those "error" measurements are also represented as columns of 4096 numbers (because we measure the error in every dimension). (In practice the prediction is first turned into a score for every word in the dictionary, and the error is measured on those scores, but the idea is the same.)
So we take that error vector and pass it backward through the whole system! Each layer takes the back-propagated error matrix, works out how much each of its Query, Key, and Value matrices contributed to the error, and passes the remaining error backward to the previous layer. Then we make tiny corrections on all 96 layers, and eventually to the word-vectors in the dictionary itself!
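The "measure the error, then nudge the weights" cycle can be sketched on something much smaller than a transformer. This assumes a single weight matrix, a squared-error measure, and plain gradient descent, all chosen to keep the example short; real training uses a different error measure and a fancier optimizer. The starting weights are fixed stand-ins for random values.

```python
def predict(weights, x):
    """Forward pass: one matrix multiplication."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]

def training_step(weights, x, target, learning_rate=0.01):
    """One forward pass, one backward pass, one tiny weight adjustment."""
    y = predict(weights, x)
    error = [yi - ti for yi, ti in zip(y, target)]  # one error value per dimension
    loss = sum(e * e for e in error)
    # The gradient of the squared error with respect to weights[i][j]
    # is 2 * error[i] * x[j], so we step a little way against it.
    new_weights = [
        [w - learning_rate * 2 * e * xj for w, xj in zip(row, x)]
        for row, e in zip(weights, error)
    ]
    return new_weights, loss

weights = [[0.5, -0.5], [0.3, 0.8]]  # stand-in for a random start
x, target = [1.0, 2.0], [1.0, 0.0]
losses = []
for _ in range(20):
    weights, loss = training_step(weights, x, target)
    losses.append(loss)
```

The recorded losses shrink on every step, which is the whole point: lots of tiny corrections, each one making the next prediction slightly less wrong.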
Like I said earlier, we don't know exactly what those layers are doing. But we know that they're performing a hierarchical decomposition of concepts.
The "entanglement" part intuitively makes sense to me, but one bit I always get caught up on is the key, query, and value matrices. In every self-attention explanation I've read or watched, they tend to get thrown out there, similar to what you did here, with their usage and purpose left a little vague.
Would you mind trying to explain those in more detail? I've heard the database analogy where you start with a query to get a set of keys which you then use to look up a value, but that doesn't really compute with my mental model of neural networks.
Is it accurate to say that these separate QKV matrices are layers in the network? That doesn't seem exactly right since I think the self-attention layer as a whole contains these three different matrices. I would assume they got their names for a reason that should make it somewhat easy to explain their individual purposes and what they try to represent in the NN.
I'm still trying to get a handle on that part myself... But my ever-evolving understanding goes something like this:
The "Query" matrix is like a mask that is capable of selecting certain kinds of features from the context, while the "Key" matrix focuses the "Query" on specific locations in the context.
Using the Query + Key combination, we select and extract those features from the context matrix. And then we apply the "Value" matrix to those features in order to prepare them for feed-forward into the next layer.
There are multiple "Attention Heads" per layer (GPT-3 had 96 heads per layer), and each Head performs its own separate QKV operation. After applying those 96 Q+K->V attention operations per layer, the results are merged back into a single matrix so that they can be fed-forward into the next layer.
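A rough sketch of the "several heads, then merge" step, with tiny sizes. It assumes the usual layout: each head has its own Q, K and V matrices that project a word down to a smaller vector, and a final output matrix (w_out here) merges the concatenated heads. The random weights stand in for learned ones, the masking used by GPT-style models is left out, and all names are mine.

```python
import math
import random

def matvec(matrix, vector):
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]

def softmax(scores):
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

def head_attention(xs, wq, wk, wv):
    """One head: project every word to q, k, v, then mix values by dot-product score."""
    qs = [matvec(wq, x) for x in xs]
    ks = [matvec(wk, x) for x in xs]
    vs = [matvec(wv, x) for x in xs]
    d_head = len(wq)
    out = []
    for q in qs:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_head) for k in ks]
        weights = softmax(scores)
        out.append([sum(w * v[i] for w, v in zip(weights, vs)) for i in range(d_head)])
    return out

def multi_head_attention(xs, heads, w_out):
    """Run every head separately, concatenate per word, then merge with w_out."""
    per_head = [head_attention(xs, wq, wk, wv) for wq, wk, wv in heads]
    merged = []
    for t in range(len(xs)):
        concat = [value for head in per_head for value in head[t]]
        merged.append(matvec(w_out, concat))
    return merged

rng = random.Random(0)
d_model, n_heads = 8, 2
d_head = d_model // n_heads
heads = [tuple(random_matrix(d_head, d_model, rng) for _ in range(3)) for _ in range(n_heads)]
w_out = random_matrix(d_model, n_heads * d_head, rng)
xs = [[rng.uniform(-1, 1) for _ in range(d_model)] for _ in range(3)]  # 3 words
ys = multi_head_attention(xs, heads, w_out)
```

The output has the same shape as the input (one 8-number column per word), which is what lets it be passed straight into the next layer.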
Or something like that...
I'm still trying to grok it myself, and if anyone here could shed more light on the details, I'd be very grateful!
I'm still trying to understand, for example, how many QKV matrices are actually stored in a model with a particular number of parameters. For example, in a GPT-NeoX-20B model (with 20 billion params) how many distinct Q, K, and V matrices are there, and what is their dimensionality?
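For a rough count, here is a back-of-the-envelope sketch. It assumes standard multi-head attention, where every head in every layer has its own Q, K and V matrix of shape d_model x d_head with d_head = d_model / n_heads, and it ignores biases. The GPT-NeoX-20B figures plugged in (44 layers, hidden size 6144, 64 heads) are my assumption and worth checking against the published config. Implementations often store the per-head matrices fused into one big matrix per layer, so "how many matrices" depends on how you slice them.

```python
def qkv_shapes(d_model, n_layers, n_heads):
    """Count per-head Q, K, V matrices under the assumptions stated above."""
    d_head = d_model // n_heads
    matrix_count = 3 * n_heads * n_layers  # one Q, one K, one V per head per layer
    qkv_params = matrix_count * d_model * d_head
    return {
        "matrix_shape": (d_model, d_head),
        "matrix_count": matrix_count,
        "qkv_params": qkv_params,
    }

# Assumed GPT-NeoX-20B configuration; verify before relying on it.
neox = qkv_shapes(d_model=6144, n_layers=44, n_heads=64)
```

Under those assumptions the Q, K and V matrices account for roughly 5 billion of the 20 billion parameters, with the rest living mostly in the feed-forward sub-layers, the attention output matrices, and the word dictionary.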
EDIT:
I just read Imnimo's comment below, and it provides a much better explanation about QKV vectors. I learned a lot!
It's basically almost the same as convolution in image processing. For example, you take the 3-channel RGB value of a single pixel, do some math on it with the values of the surrounding pixels with weights, which gives you some value(s). Depending on the dimensions of everything, you can end up with a smaller-dimension output, like a single 3-channel RGB value, or a higher-dimension output (e.g., for a 5x5 kernel, you can end up with a 9x9 output).
The confusing part that doesn't get mentioned is that the input vectors (Q, K, V) are weighted, i.e., they are derived from the input with the standard linear transformation y = A*x+b, where x is the input word, A is the linear layer matrix, and b is the bias. Those weights are the things that are learned through the training process.
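In code, that y = A*x+b step looks something like the sketch below. The three (A, b) pairs are hand-picked numbers standing in for learned weights, and the names are mine; the point is only that the same input word x yields three different vectors q, k and v.

```python
def linear(matrix, bias, x):
    """y = A*x + b for one input vector x."""
    return [sum(a * xi for a, xi in zip(row, x)) + b for row, b in zip(matrix, bias)]

x = [1.0, 2.0]  # one input word

# Made-up weights; in a real model these are learned during training.
A_q, b_q = [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]
A_k, b_k = [[0.0, 1.0], [1.0, 0.0]], [0.5, 0.5]
A_v, b_v = [[2.0, 0.0], [0.0, 2.0]], [0.0, -1.0]

q = linear(A_q, b_q, x)  # [1.0, 2.0]
k = linear(A_k, b_k, x)  # [2.5, 1.5]
v = linear(A_v, b_v, x)  # [2.0, 3.0]
```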
That was incredible. Thank you! If you made it into an article with images showing the mask/filter analogy, it might be one of the best/most unique explanations I've seen.
Love the ground-up approach beginning with data's shape.
Reminded me of the style of a book on machine learning. If anyone liked this explanation, you may appreciate this book:
If it only generates one word at a time and then repeats the process again, how does it know when to stop?
It feels like this method would create endless ramblings. But we all know you can ask ChatGPT to “summarize in one sentence” and it pulls it off. When speaking yourself, you sort of have to think about how to finish a sentence before you start it in order to explain something cohesively. Surely there must be something similar in the AI?
I have a very dumb question, I'll just throw it here: I understand word embeddings and tokenisation, and the value of each, but how can the two work together? Are embeddings calculated for tokens, and in that case, how useful are they, given that each token is just a fragment of a word, often with little or no semantic meaning?
I've heard that nowadays subword/token embeddings are learned during the training phase, and that they are useful for reconstructing the embeddings of words that contain them, and in fact allow the model to handle typos like "aple" (instead of "apple").
The way transformers operate is by transforming the embedding space through each layer. You could say that all the "understanding" is happening in that high dimensional space - that of a single token, but multiplied by the number of tokens. Seeding the embedding space with some learned value for each token is helpful. Think of it as just a vector database: token -> vector.
Decoder-only architectures (such as GPT) mask the token embedding interaction matrix (attention) such that each token embedding and all subsequent transformations only have access to preceding token embeddings (and transforms). This means that on output, only the last transformed token embedding has the full information of the entire context, and only it is capable of making predictions for the next token.
This is done so that during training, you can simultaneously make 1000s (context size) of predictions - every final token embedding transform is predicting the next token. The alternative (Encoder architecture, where there is no masking and the first token can interact with the final token) would result in massively inefficient training for predicting the next token as each full context can only make a single prediction.
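A small sketch of that masking and of why it makes training efficient. It assumes the attention scores have already been computed as a square grid (row i = token i looking at every token), and the function names are hypothetical. Hiding everything to the right of the diagonal means token i only ever sees tokens 0..i, so every position can be asked to predict its own next token in a single pass.

```python
import math

def causal_attention_weights(scores):
    """Softmax each row after hiding every column to the right of the diagonal."""
    weights = []
    for i, row in enumerate(scores):
        visible = row[: i + 1]  # token i sees tokens 0..i only
        peak = max(visible)
        exps = [math.exp(s - peak) for s in visible]
        total = sum(exps)
        hidden = [0.0] * (len(row) - i - 1)  # later tokens get zero weight
        weights.append([e / total for e in exps] + hidden)
    return weights

def training_pairs(tokens):
    """Every position predicts the token that follows it, all from one sequence."""
    return [(tokens[: i + 1], tokens[i + 1]) for i in range(len(tokens) - 1)]

masked = causal_attention_weights([[0.0, 0.0, 0.0]] * 3)
pairs = training_pairs(["the", "cat", "sat"])
```

With equal scores, the first token puts all its weight on itself, the second splits evenly across two tokens, and the third across three. A sequence of n tokens gives n - 1 training predictions at once.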
Hope that helps!