Here is an attempt that sacrifices technical precision for intuition.

We're going to represent words as vectors (sequences of numbers). We'd like the numbers to reflect the meaning of the words: words that mean similar things should be near each other. We also want to represent higher-level ideas, ones that take multiple words to express, in the same way. You can think of the set of all possible vectors as the entire space of ideas.
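
To make "near each other" a little more concrete, here is a toy sketch with made-up three-dimensional vectors (real models learn these, and use hundreds or thousands of dimensions), using cosine similarity as one common notion of "nearness":

    import numpy as np

    # Made-up 3-dimensional vectors purely for illustration.
    vectors = {
        "money": np.array([0.9, 0.1, 0.0]),
        "bank":  np.array([0.8, 0.2, 0.1]),
        "river": np.array([0.1, 0.9, 0.2]),
    }

    def cosine(a, b):
        # 1.0 means "pointing the same way", near 0 means unrelated.
        return a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

    print(cosine(vectors["money"], vectors["bank"]))  # high: related meanings
    print(cosine(vectors["river"], vectors["bank"]))  # lower: less related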

To begin with, though, we just have a vector for each word. This is insufficient - does the word "bank" mean the edge of a river or a place to store money? Is it a noun or a verb? In order to figure out the correct vector for a particular instance of this word, we need to take into account its context.

A natural idea might be to look at the words next to it. This works okay, but it's not the best. In the sentence "I needed some money so I got in my car and took a drive down to the bank", the word that really tells me the most about "bank" is "money", even though it's far away in the sentence. What I really want is to find informative words based on their meaning.

This is what transformers and attention are for. The process works like this: for each word, I compute a "query" - in hand-wavy terms, this says "I'm looking for any other words out there that are X". X could be "related to money" or "near the end of the sentence" or "adjectives". Next, for each word I also compute a "key"; this is the counterpart of the query, and it says "I have Y". For each query, I compare it to all the keys and find which ones are most similar. This tells me which words (via their queries) should pay attention to which other words (via their keys). Finally, for each word I compute a "value". Whereas the "key" was a sort of advertisement for the kind of information the word has, the "value" is the information itself. Under the hood, the "query", "key" and "value" are all just vectors, and a query and a key match if their vectors are similar.
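
To make "a query and a key match if their vectors are similar" concrete, here is a tiny made-up example where "similar" just means a large dot product (which is what real implementations typically use):

    import numpy as np

    query = np.array([1.0, 0.0, 1.0])          # "I'm looking for X"
    keys = {
        "money": np.array([0.9, 0.1, 1.0]),    # "I have something X-like"
        "car":   np.array([0.0, 1.0, 0.1]),
    }

    # A larger dot product means a better match, i.e. more attention paid.
    for word, key in keys.items():
        print(word, query @ key)               # money: 1.9, car: 0.1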

So, as an example, suppose that my sentence is "Steve has a green thumb", and we want to understand the meaning of the word "thumb". Perhaps a useful step for understanding any noun would be to look for adjectives that modify it. We compute a "query" for "thumb" that says "I'm looking for words near the end of the sentence that are adjectives". When computing a "key" for the word "green", maybe we get "I'm near the end of the sentence, I'm a color, I'm an adjective or a noun". These match pretty well, so "thumb" attends to "green". We then compute a "value" for "green" that communicates its meaning.

By combining the information we got from the word "green" with the information for the word "thumb", we have a better understanding of what "thumb" means in this particular sentence. If we repeat this process many times, we can build up a stronger understanding of the whole sentence. We could also have a special empty word at the end that represents "what might come next?", and use that to generate more text.

But how did we know which queries, keys and values to compute? How did we know how to represent a word's meaning as numbers at all? The answers to these seemingly impossible questions are what gets "learned". How exactly that happens would require an equally big explanation of its own.

Keep in mind that this explanation is very fuzzy, and is only intended to convey the loose intuition of what is going on. It leaves out many technical details and even gets some details intentionally wrong to avoid confusion.



Thank you for this explanation. I've found that the QKV concepts are some of the most glossed-over parts of attention and, I'll be honest, some of the most confusing. Would you mind actually going into more detail on the questions you asked towards the end of your post? I vaguely understand how an embedding might get learned over time, but I don't understand how queries are "constructed", how these three separate matrices get learned, or what kind of information gets encoded in them when they're learned. I also don't really understand how the value matrix itself is used.

Any further detail, even if it gets into some technical details would be very helpful and appreciated!


Thank you for asking this question and pushing for clarification. I literally have a tab open with a question to GPT asking it to explain the positional encoding and QKV concepts of transformers. After going through Karpathy's Zero to Hero and reading/watching a few other tutorials on modern NN architectures, I feel I mostly have a grasp on the main topics (e.g. backpropagation). But the Key/Query matrices just stick out like a sharp thorn. These are clearly the most important features of the transformer architecture, and it is frustrating not to have an intuitive understanding of their function.


Sure, so to see how these things can be learned, we should be a little more precise about how they work.

Each token is a vector, and from that vector we compute three things - a query, a key and a value. Each of these is typically computed by multiplying the token's vector by a matrix (aka a linear projection). It's the values in these matrices that we need to learn.
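
As a rough sketch of that step (toy sizes, with random numbers standing in for the learned parameters):

    import numpy as np

    d_model, d_head = 4, 3                        # arbitrary toy sizes
    rng = np.random.default_rng(0)

    x = rng.standard_normal(d_model)              # one token's vector
    W_q = rng.standard_normal((d_model, d_head))  # learned query matrix
    W_k = rng.standard_normal((d_model, d_head))  # learned key matrix
    W_v = rng.standard_normal((d_model, d_head))  # learned value matrix

    # Each is just a linear projection of the same token vector.
    q, k, v = x @ W_q, x @ W_k, x @ W_v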

When performing an attention step, for a given token we compare its "query" to every token's "key" (including its own key - a token can attend to itself). This gives us a score for how important we think that token is. We normalize those scores to sum to one (typically via a softmax operation). Essentially, we have one "unit" of attention, and we're going to spread it across all the tokens. Some we will pay a lot of attention to, and others very little.

But what does it mean to pay a lot of attention, or only a little, to other tokens? At the end of this whole procedure, we're going to arrive at a new vector that represents our new understanding/meaning for the token we're working on. This vector is computed as a weighted sum of the values from all the tokens we're attending to, where the weights are our attention scores (determined by the query-key similarities).
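
Putting those two steps together, one token's attention step looks roughly like this (a minimal sketch; real implementations also scale the scores and process all tokens at once):

    import numpy as np

    def softmax(s):
        e = np.exp(s - s.max())
        return e / e.sum()

    def attend(q, K, V):
        # q: the query of the token we're working on, shape (d_head,)
        # K: every token's key,   shape (n_tokens, d_head)
        # V: every token's value, shape (n_tokens, d_head)
        scores = K @ q              # compare this query to every key
        weights = softmax(scores)   # one "unit" of attention, spread over tokens
        return weights @ V          # weighted sum of values = new representation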

So as a simple example, suppose I have three tokens, A, B and C, and let's focus on the attention operation for A. Say A's query vector is [1 2 -1], A's key vector is [3 -1 0], B's key vector is [3 -1 -3] and C's key vector is [0 1 -3]. Taking the dot product of A's query with each key gives raw attention scores of 1 for A (attending to itself), 4 for B, and 5 for C. Rather than take a messy softmax, let's just normalize these to 0.1, 0.4 and 0.5 for simplicity.

Now that we have our attention weights, we also need each token's value. Let's say they are [1 0 1] for A, [-1 2 0] for B, and [1 1 1] for C. So our final output for this attention step is 0.1 * [1 0 1] + 0.4 * [-1 2 0] + 0.5 * [1 1 1] = [0.2 1.3 0.6], and this will be our new representation of A for the next step (in practice there are some additional network layers that do more processing).
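
That arithmetic is easy to check in a few lines (using the example's simple sum-to-one normalization instead of a softmax):

    import numpy as np

    q_A = np.array([1, 2, -1])                # A's query
    K = np.array([[3, -1,  0],                # A's key
                  [3, -1, -3],                # B's key
                  [0,  1, -3]])               # C's key
    V = np.array([[ 1, 0, 1],                 # A's value
                  [-1, 2, 0],                 # B's value
                  [ 1, 1, 1]])                # C's value

    scores = K @ q_A                          # [1, 4, 5]
    weights = scores / scores.sum()           # [0.1, 0.4, 0.5]
    print(weights @ V)                        # ~[0.2, 1.3, 0.6]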

Okay, so how can we learn any of the matrices that go from a token vector to a query, a key and a value? The important thing is that all of this is just addition and multiplication - it's all nicely differentiable. And because the attention is "soft" (meaning we always attend at least a little bit to everything, as opposed to "hard" attention where we ignore some items entirely), we can even compute gradients through the attention scores.

Put more simply, I can ask "if I had included a bit more of A's value and a bit less of B's value, would my final output have been closer to the target?". To include a bit less of B, I need to make A's query and B's key a little further apart (a lower dot product). And to make them a little further apart, I need to adjust the numbers in the matrices that produce them. Similarly, I can ask "if C's value had been a little larger in the first slot and a little smaller in the third, would my final output have been closer to the target?", and adjust the value matrix in the same way. Even if the attention to another token is very low, there's at least a small sliver of contribution from it, so we can still learn that we would've been better off with more (or less) of it.
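
Here is a hedged numerical illustration of that "what if I nudged this weight?" question: a finite-difference check on one entry of a toy query matrix, against a made-up target. Real training computes these slopes analytically via backpropagation; this just shows that, thanks to the soft attention, nudging a weight changes the loss.

    import numpy as np

    def softmax(s):
        e = np.exp(s - s.max())
        return e / e.sum()

    def attention_output(W_q, W_k, W_v, X):
        # One attention step for the first token, attending over all tokens X.
        q = X[0] @ W_q
        K, V = X @ W_k, X @ W_v
        return softmax(K @ q) @ V

    rng = np.random.default_rng(0)
    X = rng.standard_normal((3, 4))                 # three toy token vectors
    W_q, W_k, W_v = (rng.standard_normal((4, 4)) for _ in range(3))
    target = rng.standard_normal(4)                 # made-up "correct" output

    def loss(W_q):
        return np.sum((attention_output(W_q, W_k, W_v, X) - target) ** 2)

    # "If this one weight were slightly bigger, would the output be closer
    # to the target?" A negative slope says yes; gradient descent would then
    # increase that weight a little.
    eps = 1e-5
    W_nudged = W_q.copy()
    W_nudged[0, 0] += eps
    print((loss(W_nudged) - loss(W_q)) / eps)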

Learning the initial embeddings (the vectors that represent each word in the vocabulary, before any processing) is done in the same way - you trace all the way back through the network and ask "if the embedding for the word 'bank' had been a little more like this would my answer have been closer?", and adjust accordingly.

Understanding what exactly the queries, keys and values represent is often very difficult - sometimes we can look at which words attend to which other words and make up a convincing story: "Oh, in this layer, the verb is attending to the corresponding subject", or whatever. But in practice, the real meaning of a particular internal representation is going to be very fuzzy, without a single clear concept behind it.

There is no explicit guidance to the network like "you should attend to this word because it's relevant to this other word." The only guidance is what the correct final output is (usually for LLMs the training task is to predict the next word, but it could be something else). And then the training algorithm adjusts all the parameters, including the embeddings, the QKV matrices, and all the other weights of the network, in whatever direction would make that correct output more likely.


This was an excellent explanation, thank you for taking the time to write it out!



