Notes: CS25 Introduction to Transformer with Andrej

HK
Oct 19, 2023

Video: CS25 | Stanford Seminar — Transformers United 2023: Introduction to Transformers w/ Andrej Karpathy

Transformers: A Brief Introduction

Let’s start with the attention timeline. It all started with the paper “Attention Is All You Need” by Vaswani et al. in 2017, which marked the beginning of Transformers. Before that, we had the prehistoric models such as RNNs, LSTMs and simpler attention mechanisms.

Starting in 2017, we saw an explosion of transformers in NLP, where people started to use them for everything.

From 2018 to 2020, transformers spread into other fields, such as computer vision (ViT) and biology (AlphaFold).

In 2021 came the generative era, with a lot of generative modeling such as Codex, GPT, DALL-E and Stable Diffusion. A lot happened in generative modeling.

Then, at the end of 2022, we started to scale up and have models like ChatGPT, Whisper, etc. And that’s great.

Before 2017 — The pre-Transformer Era

Once, there were RNNs such as LSTMs and GRUs, used in architectures like Seq2Seq. They can encode history, but they are bad at encoding (1) long sentences and (2) context.

Consider the example of predicting the last word in a text. Here, you need to understand the context to predict “French”. The attention mechanism is very good at that, whereas an LSTM alone does not work as well.

Transformers are also good at context prediction, such as finding attention maps. Given a word like “it”, which words does it correlate with? The transformer can output this attention probability distribution as one of its activations, and this works better than earlier mechanisms.

2021 — Taking off and the beginning of generative models

Where were we in 2021? We were on the verge of take-off, starting to realize the potential of Transformers in different fields. We solved a lot of long-sequence problems such as protein folding (AlphaFold) and offline reinforcement learning (RL). We started to see few-shot and zero-shot generalization. We also saw multimodal tasks and applications such as generating images from language, which was DALL-E.

2022 — The emergence of LLM

This is where we went from 2021 to 2022. We had taken off and saw applications in audio generation, art, music and storytelling. We started to see new capabilities such as commonsense, logical and mathematical reasoning. We also saw better human alignment and interaction: models can be trained with reinforcement learning from human feedback, which is how ChatGPT was trained to perform so well. We now have many mechanisms for controlling toxicity, bias and ethics. There were also many developments in other areas, such as diffusion models.

Beyond 2022 — The future

There are many more applications that we can enable.

  • One big example is video understanding and generation, which everyone is interested in. We also hope to see many models in other areas this year, including finance and business.
  • It would be exciting to have GPT author a novel, but that requires solving very long sequence modeling, and transformer models are still limited in context length (2K tokens for GPT-3, 4K for GPT-3.5 and up to 32K for GPT-4). We need to make them generalize better on long sequences.
  • We also want to have generalized agents that can do multi-task and multi-input predictions. One recent development is in the field of Vision-Language Models.
  • Finally, we also want domain-specific foundation models. You might want a GPT model that is good at health, like a DoctorGPT model, or a LawyerGPT trained only on legal data. Today, we have GPT models trained on everything, but we might start to see more niche models that are good at one task. We could have a mixture of experts: just as you would normally consult different experts on different topics, you could go to different AI models for different needs.

There are still many missing ingredients to make all this successful.

  • The first is external memory. Current LLMs like ChatGPT have no long-term memory: they cannot store conversations for the long term. This is something we want to fix.
  • Second is reducing computational complexity. The attention mechanism is quadratic in the sequence length, which is slow. We want to reduce this and make it faster.
  • Another is enhancing model controllability. LLMs are stochastic; for example, you get a different output each time you regenerate a response in ChatGPT. We want mechanisms to control what sorts of outputs we get.
  • We also want to align our state-of-the-art models with how the brain works. We are seeing a surge of interest here, but more research is needed on how to make them more aligned.

Transformer — A historical context

Traditional machine learning

Why does the transformer even exist? If you worked on AI back in 2012, you would not say that you worked on AI or deep learning; it was just machine learning. Now, it is fine to say you do. Do you realize how lucky you are to be entering this area in 2023?

Back in 2011, when we worked on computer vision, the pipeline looked like the one in the paper above. There was a zoo of different feature descriptors. At a computer vision conference, every paper had its favorite feature descriptor to propose. You would take notes on which ones you wanted to incorporate into your pipeline, extract all of them and put an SVM on top. Each of these features came with its own complicated code that you had to collect and run. It was a total nightmare. On top of that, it did not work very well.

The diagram above shows the kind of predictions you got at that time. You would get predictions like this once in a while, and when that happened, you just shrugged and accepted it. Today, you would be looking for a bug.

Worse than that, every field using AI had its own completely separate vocabulary. If you read NLP papers, you would wonder what part-of-speech tagging, morphological analysis, syntactic parsing, co-reference resolution, etc. were. The vocabulary and everything else were completely different, so you could not read papers across different areas.

AlexNet

That changed in 2012, when Alex Krizhevsky and his colleagues demonstrated that if you scale a large neural network on a large dataset, you can get very strong performance. Up till then, there was a lot of focus on algorithms, but AlexNet showed that neural networks scale very well, so you need to worry more about compute and data. That recipe copied and pasted very well across many areas of AI, and since 2012 neural networks have popped up everywhere: in computer vision, NLP, speech, translation and RL. Everyone started to use the same kind of modeling toolkit and framework.

Now, when you go to NLP and start reading papers on machine translation, you see sequence-to-sequence networks like the one in the figure above, and you recognize the pieces: this is a neural network, these are its parameters, there is an optimizer, and so on. This lowers the barrier to entry across different areas tremendously.

The beginning of Transformers

The big change came when the transformer was proposed in 2017 in an unassuming machine translation paper. Since then, not only are the toolkit and the neural networks similar, but we have converged on one architecture that you can copy and paste to everything.

We can just copy and paste this architecture and use it everywhere. What changes are the details of the data, how it is chunked, and how you feed it in. Papers now look even more similar because everyone is using the transformer. This convergence has been remarkable to watch over the last decade.

Interestingly, the transformer hints that we may be converging on something like what the brain is doing. The brain is very homogeneous and uniform across the entire sheet of the cortex. Some details differ, but these differences feel more like hyperparameters of a transformer: your auditory cortex, your visual cortex and everything else look very similar. So maybe we are converging on some kind of uniform, powerful learning algorithm.

Where did the Transformer come from?

Neural Probabilistic Language Model (2003)

This 2003 paper was the first popular application of neural networks to language modeling, i.e., predicting the next word in a sequence, which lets you build generative models over text. It used a very simple multi-layer perceptron that took 3 words and predicted the probability distribution over the 4th word in the sequence.
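A minimal sketch of that idea in pure Python, assuming a toy five-word vocabulary, random (untrained) weights, hypothetical sizes, and no bias terms or direct input-to-output connections: the three previous words are looked up in an embedding table, concatenated, passed through one tanh hidden layer, and mapped to a softmax over the vocabulary for the 4th word.

```python
import math
import random

random.seed(0)

VOCAB = ["the", "cat", "sat", "on", "mat"]   # toy vocabulary (assumption)
EMB, HIDDEN, CONTEXT = 4, 8, 3               # hypothetical sizes

def rand_matrix(rows, cols):
    return [[random.gauss(0, 0.1) for _ in range(cols)] for _ in range(rows)]

C = rand_matrix(len(VOCAB), EMB)             # word embedding table
H = rand_matrix(HIDDEN, CONTEXT * EMB)       # input -> hidden weights
U = rand_matrix(len(VOCAB), HIDDEN)          # hidden -> output weights

def matvec(m, v):
    return [sum(w * x for w, x in zip(row, v)) for row in m]

def softmax(z):
    top = max(z)
    exps = [math.exp(v - top) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def next_word_distribution(context_words):
    """Probability of each vocabulary word given the previous 3 words."""
    assert len(context_words) == CONTEXT
    # look up and concatenate the embeddings of the context words
    x = [v for w in context_words for v in C[VOCAB.index(w)]]
    h = [math.tanh(a) for a in matvec(H, x)]   # tanh hidden layer
    return softmax(matvec(U, h))               # distribution over the 4th word

probs = next_word_distribution(["the", "cat", "sat"])
print({w: round(p, 3) for w, p in zip(VOCAB, probs)})
```

Training would adjust C, H and U by gradient descent on the negative log-likelihood of the true next word; this sketch only shows the forward pass.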

Sequence-to-Sequence Model (2014)

That brings us to the influential 2014 Sequence-to-Sequence paper. The problem here is that we don’t just want to take 3 words and predict the 4th; we want to translate an English sentence into a French one. The key difficulty is that there can be an arbitrary number of words in English and an arbitrary number of words in French.

So, how do you get an architecture that can process this variable-size input? Here, they used LSTMs. An encoder LSTM consumes one word at a time and builds up a context of what it has read. That context then acts as a conditioning vector for the decoder LSTM, which simply generates the next word at each timestep, translating the English into French.

The big problem with this, which people realized fairly quickly, is the encoder bottleneck: the entire English sentence we are trying to condition on is packed into a single vector, the context vector, which goes from the encoder to the decoder.

That is potentially too much information to maintain in a single vector, so people looked for ways to alleviate the encoder bottleneck. This brings us to Dzmitry Bahdanau’s 2014 paper on Neural Machine Translation.
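To make the bottleneck concrete, here is a minimal sketch, assuming a plain tanh RNN (not the LSTM used in the paper), random weights and hypothetical dimensions: however long the input is, the encoder hands the decoder a single fixed-size vector.

```python
import math
import random

random.seed(0)
EMB, HID = 3, 5   # hypothetical input and hidden sizes

W_xh = [[random.gauss(0, 0.5) for _ in range(EMB)] for _ in range(HID)]
W_hh = [[random.gauss(0, 0.5) for _ in range(HID)] for _ in range(HID)]

def matvec(m, v):
    return [sum(w * x for w, x in zip(row, v)) for row in m]

def encode(sequence):
    """Consume one input vector at a time; return the final hidden state."""
    h = [0.0] * HID
    for x in sequence:
        h = [math.tanh(a + b) for a, b in zip(matvec(W_xh, x), matvec(W_hh, h))]
    return h  # the single "context vector" handed to the decoder

short = [[random.gauss(0, 1) for _ in range(EMB)] for _ in range(3)]
long_ = [[random.gauss(0, 1) for _ in range(EMB)] for _ in range(50)]
print(len(encode(short)), len(encode(long_)))  # both are HID-dimensional
```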

Neural Machine Translation (2014)

Dzmitry proposed a way for the decoder to look back at the words coming from the encoder, using a “soft search”. As the decoder generates words, it is allowed to look back at the encoder’s words via this soft attention mechanism. This is the first time we see the attention mechanism.

The context vector coming from the encoder is now a weighted sum of the encoder’s hidden states for the words. The weights of the sum come from a softmax over the compatibility between the current decoder state and each hidden state generated by the encoder.

The core equation of modern attention first appeared in this paper.
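A minimal sketch of that context-vector computation, assuming random weights and hypothetical sizes. The compatibility score here is the additive form from the paper, score(s, h) = v^T tanh(W s + U h); modern transformers replace it with a (scaled) dot product.

```python
import math
import random

random.seed(0)
HID, ATT = 4, 6   # hypothetical state size and attention size

def rand_matrix(rows, cols):
    return [[random.gauss(0, 0.5) for _ in range(cols)] for _ in range(rows)]

W = rand_matrix(ATT, HID)   # projects the decoder state
U = rand_matrix(ATT, HID)   # projects each encoder hidden state
v = [random.gauss(0, 0.5) for _ in range(ATT)]

def matvec(m, x):
    return [sum(a * b for a, b in zip(row, x)) for row in m]

def softmax(z):
    top = max(z)
    e = [math.exp(x - top) for x in z]
    s = sum(e)
    return [x / s for x in e]

def context_vector(decoder_state, encoder_states):
    """Weighted sum of encoder states, weighted by softmaxed compatibility."""
    ws = matvec(W, decoder_state)
    scores = []
    for h in encoder_states:
        uh = matvec(U, h)
        scores.append(sum(vi * math.tanh(a + b) for vi, a, b in zip(v, ws, uh)))
    weights = softmax(scores)
    context = [sum(w * h[i] for w, h in zip(weights, encoder_states))
               for i in range(HID)]
    return context, weights

enc = [[random.gauss(0, 1) for _ in range(HID)] for _ in range(5)]
s = [random.gauss(0, 1) for _ in range(HID)]
ctx, alphas = context_vector(s, enc)
```

A fresh context vector is computed at every decoding step, so the decoder is no longer limited to one fixed summary of the source sentence.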

Transformer — “Attention is All You Need” (2017)

That brings us to the 2017 paper “Attention Is All You Need” by Vaswani et al. Attention was just one small component of Dzmitry’s paper, alongside others such as the bidirectional RNN encoder and the RNN decoder. The transformer paper basically says that what really works is attention by itself, so we can remove everything else and keep just the attention.

What’s remarkable about this paper is that most papers are incremental: you add one thing and show that it is better. This paper, by contrast, combined multiple ideas at the same time.

  • Attention operates over sets and has no notion of space by itself, so you need to positionally encode your inputs.
  • They adopted the residual network structure from ResNet.
  • They interspersed attention with multi-layer perceptrons.
  • They used layer norm, which came from a different paper.
  • They introduced the concept of multiple heads of attention that were applied in parallel.
  • They gave us a fairly good set of hyperparameters that are still used to this day. For example, the hidden layer of the multi-layer perceptron is 4x wider than the model dimension, and this setting has stuck around.

All these components are combined in a unique way that lands in a very good local minimum in architecture space. This really is a remarkable landmark paper.

A number of papers have since tried to play around with all kinds of details to improve the transformer, but most changes do not stick because the original version already works very well. The only things that have stuck are:

  • The reshuffling of the layer norms into the pre-norm arrangement. In the original paper, the layer norms came after the multi-head self-attention and the MLP; now they are placed before them instead.
  • Innovations in positional encoding: it is now more common to use rotary, relative and other positional encodings.

Apart from these changes, the GPTs and other LLMs you see today are basically the same architecture proposed in 2017, about 5 years ago. Although everybody tries to improve it, the architecture has proven remarkably resilient and remains mostly unchanged.
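The two layer-norm placements can be compared with a stand-in sublayer. This is a minimal sketch, assuming a hypothetical function `f` in place of attention or the MLP, and a layer norm without learned gain and bias:

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned gain/bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def add(a, b):
    return [x + y for x, y in zip(a, b)]

def post_norm_block(x, f):
    # original 2017 layout: sublayer, residual add, then layer norm
    return layer_norm(add(x, f(x)))

def pre_norm_block(x, f):
    # pre-norm layout (e.g. GPT-2, nanoGPT): layer norm, sublayer, residual add
    return add(x, f(layer_norm(x)))

def f(v):  # hypothetical stand-in for attention or the MLP
    return [0.5 * t for t in v]

x = [1.0, 2.0, 3.0, 4.0]
print(post_norm_block(x, f))
print(pre_norm_block(x, f))
```

In the pre-norm layout the residual stream itself is never normalized inside the block, leaving a clean additive path from input to output; this is commonly cited as the reason it trains more stably.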

Self-Attention = Message Passing on Directed Graph

Now, let’s look at the self-attention mechanism. The transformer interleaves two phases, and attention occurs in the first, the communication phase.

  1. Communication phase
    The communication phase is multi-headed self-attention which performs data dependent message passing on directed graphs. At each node of the directed graph, you are storing a vector which communicates with each other during this phase.
  2. Computation phase
    The computation phase is just a multi-layer perceptron (MLP), which acts on each node individually.

But how do the nodes talk to each other in this directed graph?

Message Passing on Directed Graph

The Python code above shows one round of communication using attention as the message-passing scheme. Here, we have a graph made up of nodes (Graph.nodes, line 5) and some random edges (Graph.edges, lines 7–8).

Each node (Node) has a private data vector (Node.data), which you can think of as the private information of the node. It can emit a key (Node.key, lines 12–14), a query (Node.query, lines 16–18) and a value (Node.value, lines 20–22), each obtained by a linear transformation of Node.data.

  • The query is what I am looking for.
  • The key is what I have.
  • The value is what I communicate.

When you perform message passing (Graph.run), the following happens:

  1. Get the query vector. When the nodes communicate, we loop over all the nodes individually in some order (line 13); the order does not matter because the updates are applied at the end. Imagine that you are the node. You get your query vector q via a linear transformation (line 16): this is what you are looking for.
  2. Get related nodes and their keys. Then, we look at all the inputs that point to you, the query node (line 19). The inputs broadcast what they have, i.e., their keys (line 23). So, I have the query, while they pass me their keys.
  3. Compute the affinities. The query and each input’s key interact via a dot product to get scores (line 25). The dot product gives an unnormalized measure of how interesting each pointing node’s information is relative to what I am looking for. Then, the scores are normalized with a softmax so that they sum to one (lines 27–28) and become a probability distribution.
  4. Compute the weighted sum of the values. Then, you get the values of all inputs (line 30) and do a weighted sum of the values to get your update (line 31).

To summarize: I have a query and the other nodes have keys. We take dot products to get the affinities and apply a softmax to normalize them. The weighted sum of the connected nodes’ values then flows to me as my update. This happens for each node individually, and all updates are applied at the end. This kind of message-passing scheme is at the heart of the transformer. In practice, it is done in a vectorized, batched way and interspersed with layer norm to make training better.
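A self-contained sketch of this scheme in pure Python, assuming 10 nodes, 40 random directed edges and 20-dimensional vectors; the sizes and the exact `Node`/`Graph` layout are illustrative rather than the lecture’s exact code. Updates are collected first and applied at the end, and only nodes that actually receive messages are updated.

```python
import math
import random

random.seed(0)
DIM = 20  # hypothetical vector size

def rand_matrix(n):
    return [[random.gauss(0, 1 / math.sqrt(n)) for _ in range(n)] for _ in range(n)]

def matvec(m, v):
    return [sum(w * x for w, x in zip(row, v)) for row in m]

class Node:
    def __init__(self):
        self.data = [random.gauss(0, 1) for _ in range(DIM)]  # private vector
        self.wkey = rand_matrix(DIM)
        self.wquery = rand_matrix(DIM)
        self.wvalue = rand_matrix(DIM)

    def key(self):    # what do I have?
        return matvec(self.wkey, self.data)

    def query(self):  # what am I looking for?
        return matvec(self.wquery, self.data)

    def value(self):  # what do I communicate?
        return matvec(self.wvalue, self.data)

class Graph:
    def __init__(self, n_nodes=10, n_edges=40):
        self.nodes = [Node() for _ in range(n_nodes)]
        # each edge is (source, destination)
        self.edges = [(random.randrange(n_nodes), random.randrange(n_nodes))
                      for _ in range(n_edges)]

    def run(self):
        updates = {}
        for i, node in enumerate(self.nodes):
            q = node.query()
            inputs = [self.nodes[src] for src, dst in self.edges if dst == i]
            if not inputs:
                continue  # nobody points to this node
            scores = [sum(a * b for a, b in zip(m.key(), q)) for m in inputs]
            top = max(scores)
            weights = [math.exp(s - top) for s in scores]
            total = sum(weights)
            weights = [w / total for w in weights]  # softmax: sums to 1
            values = [m.value() for m in inputs]
            updates[i] = [sum(w * v[d] for w, v in zip(weights, values))
                          for d in range(DIM)]
        # apply all updates at the end, as a residual connection
        for i, u in updates.items():
            self.nodes[i].data = [a + b for a, b in zip(self.nodes[i].data, u)]
        return updates

g = Graph()
before = [list(n.data) for n in g.nodes]
updated = g.run()
```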

Parallel processing in Transformer

In the communication phase of the transformer, message passing happens in parallel at every head and in series at every layer with different weights each time. That’s as far as the multi-headed attention goes.

If you look at the encoder-decoder models, you can think of them in terms of the connectivity of the nodes in a graph.

  • Encoder: All the tokens in the encoder, which we want to condition on, are fully connected to each other, so they communicate fully when computing their features.
  • Decoder: Because the decoder is a language model, we need to mask future tokens: communicating with them would give away the answer at each step. Hence, each decoder token is connected to all encoder states (through cross-attention), but only to the past and current decoder tokens (through self-attention). This gives the triangular structure in the graph, and that is the message-passing scheme the decoder implements.

The cross-attention in the decoder consumes features only from the top of the encoder. In the encoder, all the nodes look at each other many times (once per encoder layer) and really figure out what’s in there. The decoder then looks only at the top nodes of the encoder network. That’s roughly the message-passing scheme.
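This connectivity can be written down as adjacency lists. A minimal sketch, assuming position indices stand in for tokens; each list names the nodes a given node may receive messages from.

```python
def encoder_self_attention(n_enc):
    # every encoder token attends to every encoder token
    return {i: list(range(n_enc)) for i in range(n_enc)}

def decoder_self_attention(n_dec):
    # causal: decoder token i attends only to tokens 0..i (the past and itself)
    return {i: list(range(i + 1)) for i in range(n_dec)}

def decoder_cross_attention(n_dec, n_enc):
    # every decoder token attends to all top-layer encoder states
    return {i: list(range(n_enc)) for i in range(n_dec)}

print(decoder_self_attention(4))
# {0: [0], 1: [0, 1], 2: [0, 1, 2], 3: [0, 1, 2, 3]}
```

The decoder’s self-attention lists grow by one at each position, which is exactly the triangular structure described above.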

nanoGPT

nanoGPT is a minimal but complete implementation of a transformer. It is a pretty serious one, reproducing GPT-2 on OpenWebText using a single node of 8 A100 40GB GPUs in 38 hours, yet it is very readable: the implementation is about 300 lines.

Implementing your own GPT

Shakespeare toy dataset

Let’s build a decoder-only transformer, which means it is a language model: it tries to model the next word or the next character in a sequence. The data we train on is always some kind of text. Here, it is the tiny Shakespeare dataset: you take all of Shakespeare and concatenate it into a 1MB file. Then, you can train language models on it and get infinite Shakespeare.

Tokenization

The first thing to do is convert the text into a sequence of integers, because transformers natively process numerical input. In the simplest case, every character is converted into an integer, so “hii there” becomes [46, 47, 47, 1, 58, 46, 43, 56, 43].
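A character-level tokenizer in this spirit is only a few lines. The sketch below builds its vocabulary from whatever text it is given (a short stand-in string here, not the Shakespeare file), so its integer IDs will differ from the ones quoted above, which depend on the full dataset’s character set.

```python
text = "hii there, first citizen"   # stand-in corpus (assumption)
chars = sorted(set(text))            # vocabulary: every distinct character
stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer
itos = {i: ch for ch, i in stoi.items()}      # integer -> character

def encode(s):
    return [stoi[c] for c in s]

def decode(ids):
    return "".join(itos[i] for i in ids)

ids = encode("hii there")
print(ids, decode(ids))
```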

Batching the data

Then you can encode every single character as an integer and get a massive sequence of integers. You just concatenate them into one large, long 1-D sequence. And then you can train on it.

Here, we only have a single document. If you have multiple independent documents, people like to splice special end-of-text tokens between them to create boundaries. These boundaries have no special modeling role, though: the transformer is supposed to learn via backpropagation that the end-of-text token means it should wipe its memory.

Then we produce batches. These batches of data just mean that we go back to the one-dimensional sequence, and take out chunks of this sequence.

If the block size is 8, we have up to 8 characters of context to predict the 9th character. The block size is the maximum context length that the transformer will process.

The batch size indicates how many samples are processed in parallel. We want it to be as large as possible to take full advantage of the GPU. In the code above, we use a batch size of 4 and a block size of 8. Every row is an independent example, a small chunk of the sequence that we train on, and we have both the inputs and the targets at every position.

To fully spell out what this 4x8 batch contains for the transformer, we have unpacked it at the top left of the figure.

  • When the input is [47], the target is 58.
  • When the input is [47, 58], the target is 1.
  • When the input is [47, 58, 1], the target is 51.

So, a single 4x8 batch actually contains lots of individual examples that we expect the transformer to learn from in parallel. The rows of the batch are learned from completely independently, and the time dimension (horizontally) is also trained on in parallel.

So your real batch size is more like B x T; it’s just that the context grows linearly for the predictions made along the T direction. These are all the examples the model learns from in this single batch.
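A minimal sketch of this batching, assuming the data is already a flat list of integers (a made-up range here), a batch size of 4 and a block size of 8. `get_batch` is used as an illustrative helper name; the key idea is that the targets are the inputs shifted by one.

```python
import random

random.seed(0)
data = list(range(100, 200))     # stand-in for the encoded text
batch_size, block_size = 4, 8

def get_batch():
    # pick batch_size random starting offsets into the 1-D sequence
    starts = [random.randrange(len(data) - block_size) for _ in range(batch_size)]
    x = [data[s:s + block_size] for s in starts]          # inputs
    y = [data[s + 1:s + block_size + 1] for s in starts]  # targets, shifted by one
    return x, y

x, y = get_batch()
# unpack the B x T individual (context, target) examples hidden in the batch
examples = [(x[b][:t + 1], y[b][t])
            for b in range(batch_size) for t in range(block_size)]
for context, target in examples[:3]:
    print(f"when input is {context} the target is {target}")
```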

The GPT class

Now, this is the GPT class. Because this is a decoder-only model, we are not going to have an encoder because we are not trying to condition it on some other external information, e.g., the input sentence in the original language translation task. We are just trying to produce a sequence of words that follow each other.

In the forward pass:

  1. Line 20: Generate the token embeddings. We encode the identity of the indices (idx) via an embedding lookup table: each integer indexes into the table of vectors wte (line 8), and we pull out the vector for that token.
  2. Line 21: Add positional encoding. Because the transformer natively processes sets, we also need to positionally encode these vectors, so that we have both the token identity and its place in the sequence, from 1 to block size. The information of what and where is combined additively: the token embeddings and the positional embeddings are just added.
  3. Line 22: Optional dropout. The resulting x contains the set of tokens and their positions, and it feeds into the blocks of the transformer.
  4. Lines 23–24: Forward through the series of decoder blocks of the transformer.
  5. Lines 25–26: Generate the logits. The output of the decoder blocks goes through a layer norm, then lm_head (short for language model head), which is just a linear projection, produces the logits for the next integer in the sequence.
  6. Lines 27–29: Compute the loss. If targets are given (they are produced in the data loader and are just the inputs offset by one in time), they are fed to the cross-entropy loss, which is just the negative log-likelihood, the typical classification loss.
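These steps can be sketched without a deep learning framework. This is a minimal sketch, assuming tiny hypothetical sizes and random weights, and skipping the decoder blocks and the final layer norm to keep it short: token and position embeddings are added, a linear `lm_head` produces logits, and the loss is the average negative log-likelihood of the targets.

```python
import math
import random

random.seed(0)
vocab_size, block_size, n_embd = 10, 8, 4   # hypothetical sizes

def rand_matrix(rows, cols):
    return [[random.gauss(0, 0.5) for _ in range(cols)] for _ in range(rows)]

wte = rand_matrix(vocab_size, n_embd)      # token embedding table
wpe = rand_matrix(block_size, n_embd)      # position embedding table
lm_head = rand_matrix(vocab_size, n_embd)  # linear map back to the vocabulary

def forward(idx, targets=None):
    # token identity ("what") + position ("where"), combined by addition
    x = [[a + b for a, b in zip(wte[tok], wpe[pos])] for pos, tok in enumerate(idx)]
    # (decoder blocks and the final layer norm would go here; omitted in this sketch)
    logits = [[sum(w * h for w, h in zip(row, xt)) for row in lm_head] for xt in x]
    if targets is None:
        return logits, None
    loss = 0.0
    for row, tgt in zip(logits, targets):
        top = max(row)
        log_z = top + math.log(sum(math.exp(v - top) for v in row))
        loss += log_z - row[tgt]           # -log softmax(row)[tgt]
    return logits, loss / len(targets)

idx = [1, 5, 2, 7]
targets = [5, 2, 7, 3]                     # the inputs shifted by one in time
logits, loss = forward(idx, targets)
```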

The Decoder Block

The code on the right of the figure above shows the implementation of a decoder block. As mentioned, there is a communicate phase and compute phase.

  • Communicate phase. Since GPT is a decoder-only model, the decoder block has no cross-attention, only masked self-attention.
    In the self-attention module, the nodes talk to each other subject to the causal mask. If our block size is 8, there are 8 nodes in this graph.
    - The 1st node is pointed to only by itself.
    - The 2nd node is pointed to by the 1st node and itself.
    - The 3rd node is pointed to by the 1st and 2nd nodes and itself.
    - And so on…
    Given the input x, you apply a layer norm and then the self-attention, in which these 8 nodes communicate. Keep in mind that the batch size is 4: each of the 4 sequences has its own graph of 8 nodes, and there is no communication across batch samples. The output of the self-attention is then added to x through the residual pathway.
  • Compute phase.
    Then the result goes through another layer norm and is processed by the multi-layer perceptron in the compute phase, whose output is likewise added back to x through the residual pathway.

Feed forward (MLP)

The MLP is pretty straightforward: it processes each node individually, transforming the feature representation at that node. It is a two-layer neural net with a GELU nonlinearity, a smooth nonlinearity similar to ReLU.
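A minimal sketch of this per-node MLP, assuming random weights, no biases, and the exact (erf-based) form of GELU. The hidden layer is 4x wider than the embedding, the expansion factor mentioned earlier.

```python
import math
import random

random.seed(0)
n_embd = 4
hidden = 4 * n_embd   # the 4x expansion factor

W1 = [[random.gauss(0, 0.5) for _ in range(n_embd)] for _ in range(hidden)]
W2 = [[random.gauss(0, 0.5) for _ in range(hidden)] for _ in range(n_embd)]

def gelu(x):
    # exact GELU: x * Phi(x), with Phi the standard normal CDF
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def matvec(m, v):
    return [sum(w * x for w, x in zip(row, v)) for row in m]

def mlp(node):
    # expand to 4 * n_embd, apply GELU, project back to n_embd
    return matvec(W2, [gelu(h) for h in matvec(W1, node)])

# applied to each node (token position) independently: no communication
tokens = [[random.gauss(0, 1) for _ in range(n_embd)] for _ in range(3)]
out = [mlp(t) for t in tokens]
```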

Multi-head Self-Attention

This is the causal self-attention part, the communication phase, and the most complicated part. It is complicated because of batching and the implementation detail of masking the connectivity in the graph so that a token cannot obtain any information from the future when predicting the next token; otherwise, the answer would be given away. For example, if I am the 5th token, I attend to the 5th, 4th, 3rd, 2nd and 1st tokens while trying to predict the next token. That next token is the next element over in the time dimension of the same input, so the answer is sitting right there, and I must not get any information from it. That’s why it is tricky.

  • (Line 13) The input to the attention module is x, with shape (B, T, n_embd), where n_embd is the embedding dimension.
  • (Line 14) In the forward pass, we compute the queries q, keys k and values v from the input x. Each has shape (B, T, n_embd).
  • (Lines 15–16) Split the features into nh (number of heads) heads. The feature dimension of each head is hs = n_embd / nh, so q, k and v are temporarily reshaped to (B, T, nh, hs). Then the time and head dimensions are swapped, so the final shapes of q, k and v are (B, nh, T, hs).
  • (Line 18) Then, we compute the attention att by multiplying the queries with the keys (q @ k.transpose(-2, -1)), scaled by 1/sqrt(hs) in nanoGPT. This computes the dot products of all queries with all keys, in parallel for every head (nh) and every sample in the batch (B).
    - The shape of q is (B, nh, T, hs)
    - The shape of k.transpose(-2, -1) is (B, nh, hs, T)
    - Matrix multiplication @ acts on the last two dimensions, i.e., (T × hs) @ (hs × T) → (T × T)
    - The leading dimensions are broadcast, so the output shape is (B, nh, T, T)
  • (Line 19) Then we apply a masked fill, setting the attention scores between nodes that are not supposed to communicate to negative infinity. After the softmax, negative infinity becomes zero attention on those elements. The result holds the affinities between nodes.
  • (Line 20) Apply a softmax so that each row of att becomes a probability distribution.
  • (Lines 21–22) Then apply some dropout. y = att @ v gathers information according to the affinities we just calculated: it is a weighted sum of the values at all the connected nodes.
    - att has the shape (B, nh, T, T)
    - v has the shape (B, nh, T, hs)
    - y has the shape (B, nh, T, hs)
  • (Line 23) Then combine the outputs from all heads.
    - Transpose to swap the nh and T dimensions so that the head dimension (nh) and the feature dimension (hs) are the last two dimensions. The shape changes from (B, nh, T, hs) → (B, T, nh, hs).
    - Merge the last two dimensions: (B, T, nh, hs) → (B, T, C), where C = nh × hs = n_embd.
  • (Line 25) Perform a linear projection back to the residual pathway.
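For readers without PyTorch at hand, here is the same sequence of steps for a single sequence (B = 1) in pure Python. It is a sketch, assuming random weights, no biases and no dropout; `c_attn` and `c_proj` mirror nanoGPT’s names for the fused q/k/v projection and the output projection, but are plain matrices here.

```python
import math
import random

random.seed(0)
T, n_embd, n_head = 5, 8, 2
hs = n_embd // n_head                        # head size

def rand_matrix(rows, cols):
    return [[random.gauss(0, 0.3) for _ in range(cols)] for _ in range(rows)]

c_attn = rand_matrix(3 * n_embd, n_embd)     # produces q, k, v in one projection
c_proj = rand_matrix(n_embd, n_embd)         # output projection

def matvec(m, v):
    return [sum(w * x for w, x in zip(row, v)) for row in m]

def causal_self_attention(x):
    """x: T rows of n_embd floats -> T rows of n_embd floats."""
    qkv = [matvec(c_attn, xt) for xt in x]
    q = [r[:n_embd] for r in qkv]
    k = [r[n_embd:2 * n_embd] for r in qkv]
    v = [r[2 * n_embd:] for r in qkv]
    y = [[0.0] * n_embd for _ in range(len(x))]
    for h in range(n_head):
        sl = slice(h * hs, (h + 1) * hs)     # this head's slice of the features
        for i in range(len(x)):
            # scaled dot products with the keys; positions j > i are masked to -inf
            scores = [sum(a * b for a, b in zip(q[i][sl], k[j][sl])) / math.sqrt(hs)
                      if j <= i else -math.inf for j in range(len(x))]
            top = max(scores)
            e = [math.exp(s - top) for s in scores]   # exp(-inf) == 0.0
            total = sum(e)
            att = [w / total for w in e]
            # weighted sum of this head's values, written into the head's slice
            for d in range(hs):
                y[i][h * hs + d] = sum(att[j] * v[j][h * hs + d] for j in range(len(x)))
    return [matvec(c_proj, yt) for yt in y]   # project back to the residual stream

x = [[random.gauss(0, 1) for _ in range(n_embd)] for _ in range(T)]
out = causal_self_attention(x)
```

Writing each head into its own slice of `y` plays the role of the transpose-and-merge step; the loops trade the speed of batched matrix multiplies for readability.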

Training

Then you can train this transformer. And then you can generate infinite Shakespeare.

Result

Because our block size is 8, generation works as follows. You start with a <START> token. It communicates only with itself, because it is a single node, and you get the probability distribution for the first character in the sequence. You sample a character from it and map it back to an integer, and now you have the second token. It gets the positional encoding for the second position and goes into the sequence, where it communicates with the first token and itself. You keep plugging the samples back in. Once the sequence exceeds the block size of 8, you start cropping, because the way the transformer was trained means it can never take more than 8 tokens of context in the time dimension. All transformers have a finite block size, or context length. In typical models, this is 1024 or 2048 tokens, and the tokens are usually BPE, SentencePiece or WordPiece tokens; there are many different types of encoding.
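The sampling loop can be sketched as follows, assuming a hypothetical `model` function that maps a list of token IDs to next-token logits (here a fixed random table that only looks at the last token, standing in for a trained transformer). The key detail is the crop: only the last `block_size` tokens are ever fed back in.

```python
import math
import random

random.seed(0)
vocab_size, block_size = 5, 8
table = [[random.gauss(0, 1) for _ in range(vocab_size)] for _ in range(vocab_size)]
seen_lengths = []   # records how much context the model was given at each step

def model(idx):
    # hypothetical stand-in for a trained transformer: logits for the next token
    seen_lengths.append(len(idx))
    return table[idx[-1]]

def generate(idx, max_new_tokens):
    idx = list(idx)
    for _ in range(max_new_tokens):
        idx_cond = idx[-block_size:]           # crop to the last block_size tokens
        logits = model(idx_cond)
        top = max(logits)
        probs = [math.exp(z - top) for z in logits]
        total = sum(probs)
        probs = [p / total for p in probs]
        nxt = random.choices(range(vocab_size), weights=probs)[0]  # sample
        idx.append(nxt)                        # feed the sample back in
    return idx

out = generate([0], max_new_tokens=20)         # start from a single token
```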

Some issues when increasing the context length:

  • The cost of attention grows quadratically with the context length. For example, going from a 2K context (GPT-3) to 4K (GPT-3.5) increases the cost of the attention computation by roughly 4x rather than 2x.
  • Liu et al. show that LLMs perform better when the relevant context is placed at the beginning or the end of a prompt rather than in the middle. It is unclear how this observation will carry over to larger context windows.

Encoder-Only Transformer

If you want to implement an encoder-only transformer, all you need to do is disable the masking by deleting the highlighted line. This allows all the nodes to communicate with each other, so information flows between all of them.
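The effect of that one line is visible in the attention weights alone. A minimal sketch, assuming all raw scores are equal (zero) so the weights are easy to read: with the causal mask, the first token can only attend to itself; without it, every token spreads its attention over all positions.

```python
import math

def attention_weights(scores, causal):
    T = len(scores)
    weights = []
    for i in range(T):
        # the causal mask sets scores for future positions (j > i) to -inf
        row = [scores[i][j] if (j <= i or not causal) else -math.inf for j in range(T)]
        top = max(row)
        e = [math.exp(s - top) for s in row]
        total = sum(e)
        weights.append([w / total for w in e])
    return weights

scores = [[0.0] * 3 for _ in range(3)]
print(attention_weights(scores, causal=True)[0])   # [1.0, 0.0, 0.0]
print(attention_weights(scores, causal=False)[0])  # each weight is 1/3
```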

Cross-Attention Module

If you want to implement cross-attention for a full encoder-decoder transformer, you need to modify the Block module to support it:

  • Add a new input, i.e., the features from the top of the encoder.
  • Add a cross attention layer (self.cross_attn) in the middle between self.attn and self.mlp. The query comes from x but the key and value should come from the top of the encoder.
  • There would be information flowing from the encoder to all the nodes inside x.
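A single-head sketch of cross-attention, assuming random weights and a hypothetical embedding size shared by encoder and decoder: queries are computed from the decoder states `x`, while keys and values are computed from the encoder’s top-layer output `enc`. There is no causal mask here, because every decoder position may look at the whole source sequence.

```python
import math
import random

random.seed(0)
C = 4   # hypothetical embedding size shared by encoder and decoder

def rand_matrix(n):
    return [[random.gauss(0, 0.5) for _ in range(n)] for _ in range(n)]

Wq, Wk, Wv = rand_matrix(C), rand_matrix(C), rand_matrix(C)

def matvec(m, v):
    return [sum(w * x for w, x in zip(row, v)) for row in m]

def cross_attention(x, enc):
    q = [matvec(Wq, xt) for xt in x]        # queries from the decoder
    k = [matvec(Wk, et) for et in enc]      # keys from the encoder
    v = [matvec(Wv, et) for et in enc]      # values from the encoder
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(C) for kj in k]
        top = max(scores)
        e = [math.exp(s - top) for s in scores]
        total = sum(e)
        att = [w / total for w in e]
        out.append([sum(a * vj[d] for a, vj in zip(att, v)) for d in range(C)])
    return out

x = [[random.gauss(0, 1) for _ in range(C)] for _ in range(3)]    # 3 decoder tokens
enc = [[random.gauss(0, 1) for _ in range(C)] for _ in range(7)]  # 7 encoder tokens
y = cross_attention(x, enc)
```

In a full Block, this output would be added to `x` through the residual pathway, between the self-attention and the MLP.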

Different types of Transformer

  • You can have a decoder-only model like GPT, an encoder-only model like BERT, or an encoder-decoder model like T5, which can do tasks such as machine translation.
  • BERT is not trained with an autoregressive language-modeling setup that predicts the next word in the sequence. It is trained on slightly different objectives: you put in the full sentence, which is allowed to communicate fully, and then, for example, classify its sentiment. Since you are not modeling the next token in the sequence, training is done via masking and other denoising techniques.
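A simplified sketch of BERT-style masking, assuming a 15% masking rate, a hypothetical `MASK_ID`, and always substituting the mask token (BERT itself sometimes keeps the original token or substitutes a random one instead). The loss would be computed only at masked positions, marked here by a non-`None` target.

```python
import random

MASK_ID = 0          # hypothetical id of the [MASK] token

def mask_tokens(tokens, rate=0.15, rng=random):
    inputs, targets = [], []
    for tok in tokens:
        if rng.random() < rate:
            inputs.append(MASK_ID)   # hide the token from the model...
            targets.append(tok)      # ...and ask the model to recover it
        else:
            inputs.append(tok)
            targets.append(None)     # no loss at unmasked positions
    return inputs, targets

rng = random.Random(0)
tokens = [rng.randrange(1, 100) for _ in range(50)]
inputs, targets = mask_tokens(tokens, rng=rng)
```

Because no causal mask restricts attention, every position sees the full (corrupted) sentence on both sides when predicting the hidden tokens.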

Computer Vision: Vision Transformer (ViT)

Transformers have been applied to many other fields. ViT takes an image and chops it into small square patches, which are fed to the transformer. The patches are positionally encoded and communicate with each other throughout the entire transformer.
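Chopping an image into patches is simple indexing. A minimal sketch, assuming a single-channel image stored as nested lists and a hypothetical patch size that divides the image evenly; in ViT, each flattened patch is then linearly projected and given a position embedding before entering the transformer.

```python
def patchify(image, p):
    """Split an H x W image into (H//p)*(W//p) flattened p*p patches, row by row."""
    H, W = len(image), len(image[0])
    assert H % p == 0 and W % p == 0, "patch size must divide the image"
    patches = []
    for top in range(0, H, p):
        for left in range(0, W, p):
            patches.append([image[r][c]
                            for r in range(top, top + p)
                            for c in range(left, left + p)])
    return patches

image = [[r * 8 + c for c in range(8)] for r in range(8)]   # 8x8 toy image
patches = patchify(image, 4)
print(len(patches), patches[0][:4])   # 4 [0, 1, 2, 3]
```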

Speech Recognition: Conformer

In speech recognition, you can take your mel spectrogram, chop it into slices and feed them into a transformer. Papers such as Conformer do this, and so does Whisper from OpenAI: you chop up the mel spectrogram, feed it into a transformer and pretend you are dealing with text. It works very well.

Reinforcement Learning: Decision Transformer

Decision Transformer applies the same idea to RL. You take the states, actions and rewards you experience in an environment and pretend they are a language. You model that sequence, and you can then use the model for planning later. That works really well.
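A sketch of how a trajectory can be turned into a token sequence. The Decision Transformer paper conditions on the return-to-go (the sum of future rewards) rather than the raw per-step reward, so the sketch computes that first; the trajectory values and the tuple tags are made up for illustration.

```python
def returns_to_go(rewards):
    """R_t = r_t + r_{t+1} + ... + r_T (undiscounted)."""
    out, running = [], 0.0
    for r in reversed(rewards):
        running += r
        out.append(running)
    return out[::-1]

def to_sequence(states, actions, rewards):
    # interleave (return-to-go, state, action) per timestep, like words in a sentence
    seq = []
    for R, s, a in zip(returns_to_go(rewards), states, actions):
        seq += [("return", R), ("state", s), ("action", a)]
    return seq

states = ["s0", "s1", "s2"]
actions = ["left", "right", "left"]
rewards = [0.0, 1.0, 2.0]
seq = to_sequence(states, actions, rewards)
```

A model trained on such sequences can then be prompted with a desired return and asked to predict the action tokens.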

Biology: AlphaFold

At the heart of AlphaFold’s computation is also a transformer.

Flexibility of Transformer

Transformers are very flexible. For example, at Tesla, you have a ConvNet that takes an image and makes predictions about it. The question is how to feed in extra information, and that is not always trivial. Suppose I have additional information that I want the output to be informed by: other sensors like radar, some mapping information, the vehicle type, or some audio. How do I feed that into a ConvNet? Where do you feed it in? Do you concatenate it? Do you add it? At what stage? With a transformer it is much easier: you take whatever you want, slice it up into pieces, add it to the set, and let self-attention figure out how everything should communicate. And that actually works. It frees neural nets from the burden of Euclidean space. Previously, you had to arrange your computation to conform to the Euclidean, almost 3D, layout in which the compute happens. But in attention, everything is just sets. So it is a very flexible framework: you can throw things into your conditioning set and everything just self-attends over itself.

What makes the Transformer so effective?

What makes the transformer so effective? A good example is the GPT-3 paper, Language Models are Few-Shot Learners. A better title might have been something like Language Models are Capable of In-Context Learning, or of Meta-Learning, because that's really what makes them special. The setting they work with is like this:

Suppose you have a passage and I am asking questions about it. Then, as part of the context in the prompt, I give questions along with their answers. So I am giving several question-answer examples.
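The setup above is just string assembly: the "training examples" live in the prompt, and no weights change. A minimal sketch follows. The `Q:`/`A:` layout and the function name are assumptions for illustration, not the exact format used in the GPT-3 paper.

```python
from typing import List, Tuple

def build_few_shot_prompt(passage: str, examples: List[Tuple[str, str]], query: str) -> str:
    """Passage, then k solved question-answer pairs, then the open question for the model to answer."""
    lines = [passage, ""]
    for q, a in examples:
        lines += [f"Q: {q}", f"A: {a}", ""]
    lines += [f"Q: {query}", "A:"]  # the model continues the document from here
    return "\n".join(lines)

prompt = build_few_shot_prompt(
    "Paris is the capital of France.",
    [("What is the capital of France?", "Paris")],
    "Which country is Paris in?",
)
```

Adding more `(question, answer)` pairs to `examples` is the "more examples in the context" knob discussed next.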

With more examples given in the context, the accuracy improves. This shows that the transformer is somehow able to learn in its activations, without doing any gradient descent as in typical fine-tuning.

If you fine-tune instead, you have to give an example and its answer, and you fine-tune the model using gradient descent.

But it looks like the transformer internally, in its weights, is doing something like gradient descent, some kind of meta-learning, as it reads the prompt. So, in this paper, they distinguish between the outer loop with SGD and the inner loop with in-context learning. The inner loop is the transformer reading the sequence; the outer loop is training by gradient descent. So some training may be happening in the activations of the transformer as it consumes a sequence, and that training may look a lot like gradient descent.

There are some recent papers that hint at this and study it. One of them proposes something called the RAW (read-arithmetic-write) operator. The authors argue that the RAW operator can be implemented by a transformer, and they show that you can implement things like ridge regression on top of it.
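To make the "inner loop" concrete, here is the kind of learning algorithm the paper argues a transformer could run over its in-context examples: fit ridge regression on the (x, y) pairs in the prompt and predict y for a query x. The sketch uses ordinary NumPy and shows two ways to reach the same answer, the closed form and plain gradient descent. It does not show how transformer layers would implement either. The step size and iteration count are assumptions chosen so that this small example converges.

```python
import numpy as np

def ridge_predict(X, y, x_query, lam):
    """Closed-form ridge fit on the in-context examples, then predict for x_query."""
    d = X.shape[1]
    w = np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)
    return x_query @ w

def gd_predict(X, y, x_query, lam, lr=0.01, steps=2000):
    """Same ridge objective, minimised by plain gradient descent from w = 0."""
    w = np.zeros(X.shape[1])
    for _ in range(steps):
        # gradient of 0.5 * ||Xw - y||^2 + 0.5 * lam * ||w||^2
        grad = X.T @ (X @ w - y) + lam * w
        w -= lr * grad
    return x_query @ w

# "In-context" examples: 20 noisy (x, y) pairs from a hidden linear function.
rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true + 0.01 * rng.normal(size=20)
x_query = rng.normal(size=3)
```

Both functions should agree closely on `x_query` and land near `x_query @ w_true`.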

These papers hint that there may be something like gradient-based learning inside the activations of the transformer. This is not implausible if you think it through. What is gradient-based learning? A forward pass, a backward pass, and then an update. That looks like a ResNet, because you are adding to the weights: you start from an initial random set of weights, do a forward pass and a backward pass, update the weights, and repeat.
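Written side by side, the resemblance is that both updates are additive. A gradient step adds a correction to the weights $w$ (with learning rate $\eta$ and loss $\mathcal{L}$), and a residual block adds a correction $F_\ell$ to the stream $x_\ell$ it carries. This is a structural analogy only, not a derivation that a transformer actually computes these gradients.

```latex
\begin{aligned}
\text{gradient descent:}\quad & w_{k+1} = w_k + \bigl(-\eta\,\nabla_w \mathcal{L}(w_k)\bigr)
  &&\Longrightarrow\quad w_K = w_0 - \eta \sum_{k=0}^{K-1} \nabla_w \mathcal{L}(w_k) \\
\text{residual stack:}\quad & x_{\ell+1} = x_\ell + F_\ell(x_\ell)
  &&\Longrightarrow\quad x_L = x_0 + \sum_{\ell=0}^{L-1} F_\ell(x_\ell)
\end{aligned}
```

Unrolling $K$ update steps and stacking $L$ residual blocks both produce "initial value plus a sum of corrections".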

These are some general, high-level speculations. Why is this architecture so interesting, and why did it become so popular? I think it optimizes three properties that are very desirable.

  1. The transformer is very expressive in the forward pass. It can implement very interesting functions, potentially even functions that do meta-learning.
  2. It is very optimizable, thanks to things like residual connections, layer norms, and so on.
  3. It is extremely efficient. This is not always appreciated, but if you look at the computational graph, it is a shallow, wide network, which is perfect for taking advantage of the parallelism of GPUs. The transformer was designed very deliberately to run efficiently on GPUs. There is previous work, such as the Neural GPU, that also designed neural nets to be efficient on GPUs by thinking backwards from the constraints of the hardware.
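The "shallow, wide" point can be seen in a toy NumPy comparison. A vanilla RNN has to loop over time because step t needs the hidden state from step t-1. Causal self-attention computes every position with a handful of matrix multiplies, which is exactly the kind of work GPUs parallelise well. Both are single-layer sketches with random weights, not full models.

```python
import numpy as np

def rnn_forward(x, w_xh, w_hh):
    """Vanilla RNN: each step depends on the previous hidden state, so T steps run serially."""
    h = np.zeros(w_hh.shape[0])
    hs = []
    for t in range(x.shape[0]):  # inherently sequential loop over time
        h = np.tanh(x[t] @ w_xh + h @ w_hh)
        hs.append(h)
    return np.stack(hs)

def causal_attention_forward(x, wq, wk, wv):
    """Causal self-attention: all T positions at once, with no loop over time."""
    T = x.shape[0]
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(k.shape[-1])
    causal = np.tril(np.ones((T, T), dtype=bool))  # position t sees positions <= t
    scores = np.where(causal, scores, -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v

rng = np.random.default_rng(0)
T, D = 6, 4
x = rng.normal(size=(T, D))
h_rnn = rnn_forward(x, rng.normal(size=(D, D)), rng.normal(size=(D, D)))
wq, wk, wv = (rng.normal(size=(D, D)) for _ in range(3))
h_att = causal_attention_forward(x, wq, wk, wv)
```

Both are causal (changing the last input leaves earlier outputs alone), but only the RNN forces a serial chain of T dependent steps.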

In hindsight, the paper “Attention is All You Need” would have been better titled “Transformer: A general-purpose, efficient, optimizable computer”.

If you can scale up the training set and use a powerful enough neural net (a transformer), the network becomes a general-purpose computer over time. Whereas previous neural nets were special-purpose computers designed for specific tasks, GPT is a general-purpose computer, reconfigurable at runtime to run natural language programs. The programs are given as prompts, and GPT runs a program by completing the document. So we can compare LLMs to computers, with the twist that this computer is optimizable by gradient descent.
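"Running the program by completing the document" is, mechanically, autoregressive decoding. The model repeatedly predicts the next token given everything so far, and that token is appended. Below is a minimal greedy loop. `next_token` is a hypothetical stand-in for GPT (here a toy rule on integer token ids), and real systems usually sample from the predicted distribution rather than always taking one fixed choice.

```python
from typing import Callable, List, Optional

def complete(prompt: List[int],
             next_token: Callable[[List[int]], int],
             max_new_tokens: int,
             eos_id: Optional[int] = None) -> List[int]:
    """Greedy autoregressive completion: keep appending the model's next token."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tok = next_token(tokens)  # the "model" only ever sees the document so far
        tokens.append(tok)
        if tok == eos_id:         # stop early at an end-of-sequence token, if given
            break
    return tokens

def toy_model(tokens: List[int]) -> int:
    """Hypothetical stand-in for GPT: predicts (last token + 1) mod 10."""
    return (tokens[-1] + 1) % 10

result = complete([1, 2, 3], toy_model, max_new_tokens=4)
```

Here `result` is `[1, 2, 3, 4, 5, 6, 7]`: the prompt followed by four generated tokens.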

Questions & Answers

  • Why is the Transformer better than the RNN?
    Although RNNs can in principle implement arbitrary programs, they are not very optimizable. They are also not efficient, because they are serial computing devices. As a computational graph, an RNN is a very long, thin graph: if you stretched out the neurons and visualized all of their individual interconnections, the RNN would be a very long graph, and that's bad. It's bad for optimizability, probably because when you backpropagate you do not want to take too many steps.
    Transformers, on the other hand, are a shallow, wide graph. The number of hops from the supervision to the inputs is very small. There are long residual pathways that make gradients flow very easily, and there are layer norms that control the scale of all the activations. So there are not too many hops: you get from the supervision to the input very quickly, and gradients just flow through the graph. It can also all be done in parallel. With encoder and decoder RNNs, you have to process the first word, then the second word, then the third word; in a transformer, all words are processed in parallel. Efficiency matters because in deep learning, scale matters: the size of the network you can train is important, and if an architecture is efficient on current hardware, you can make it bigger.
  • How do you handle different modalities in a neural network?
    You take your image and chop it up into patches, so that's the first thousand tokens or so. A radar signal you can also just chop up and feed in. Then you have to encode it somehow, because the transformer needs to know that those tokens come from radar. So you add a special embedding that makes the radar tokens slightly different in representation, and that embedding is learnable by gradient descent. In this way, each kind of information comes in with its own special embedding, which can be learned.
  • Will the positional encoding introduce some inductive bias into the model? Can we encode this kind of information structurally?
    The positional encoding has very little inductive bias. If you have enough data, trying to mess with it is usually a bad idea: injecting knowledge when the dataset itself already contains enough of it is not usually productive. It really depends on what scale you are operating at. If you have infinite data, you actually want to encode less and less; that tends to work better. If you have very little data, then you do want to encode some biases. With a much smaller dataset, convolutions may work better, because the filters give you a useful bias.
    The transformer is pretty general, but there are ways to mess with it to put in more structure. For example, you can go into the attention mechanism and say: if my image is chopped up into patches, each patch can only communicate with its neighborhood. You do that in the attention matrix by masking out whatever you do not want to communicate. People really play with this, because full attention is inefficient: they intersperse layers that only communicate within small patches with layers that communicate globally, and they use all kinds of tricks like that. So you can slowly bring in more inductive bias, but the inductive bias is factored out of the core transformer into the connectivity of the nodes and the positional encoding.
  • How can in-context learning be improved? Is it useful to increase the context length?
    Rather than increasing the context length, it may be better to keep it fixed but allow the network to use a scratch pad. You teach the transformer, via examples in the prompt, that it has a scratch pad: it cannot remember too much because its context length is finite, but it can use the scratch pad. It does this by emitting a “start scratch pad” token, writing whatever it wants to remember, then emitting an “end scratch pad” token, after which it continues with whatever it was doing. Later, during decoding, you have special logic so that when you detect “start scratch pad”, you save whatever the model writes there to some external store and allow it to attend over it. Because the transformer is so meta-learned, you can teach it dynamically to use other gizmos and gadgets and to expand its memory that way. It's like a human learning to use a notepad: you don't have to keep everything in your brain. Keeping things in your brain is like the context length for the transformer, but maybe we can just give it a notebook that it can query, read from, and write to.
