Yeah, it always seemed pretty wasteful to me. In every single forward pass the LLM basically starts from scratch, without any of the forward-looking plans it made the previous times, and has to figure out again what we are doing and where we are in the generation process. It's like the movie Memento: you wake up after an episode of amnesia, except you're waking up in the middle of typing out a sentence. You can look at the previously typed words, but you can't carry your future plans forward to the next word. At the next word, you (or your clone) wake up again and have to figure out from scratch what it is we're supposed to be typing.
The obvious way to deal with this would be to send forward some of the internal activations as well as the generated words in the autoregressive chain. That would basically turn the thing into a recurrent network though. And those are more difficult to train and have a host of issues. Maybe there will be a better way.
> The obvious way to deal with this would be to send forward some of the internal activations as well as the generated words in the autoregressive chain.
Hi! I lead interpretability research at Anthropic.
That's a great intuition, and the transformer architecture actually does exactly what you suggest! Activations from earlier time steps are sent forward to later time steps via attention. (This is another thing that's lost in the "models just predict the next word" framing.)
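To make that concrete, here's a minimal NumPy sketch of causal self-attention (my own toy names and shapes, not any particular library's API): row t of the output is a weighted mix of value vectors computed at positions up to t, which is exactly the "activations flowing forward" part.

```python
import numpy as np

def causal_self_attention(x, Wq, Wk, Wv):
    # x: (T, d) activations for T tokens; Wq/Wk/Wv: (d, d) projection matrices.
    T, d = x.shape
    q, k, v = x @ Wq, x @ Wk, x @ Wv              # (T, d) each
    scores = q @ k.T / np.sqrt(d)                 # (T, T) pairwise attention scores
    future = np.triu(np.ones((T, T), dtype=bool), k=1)
    scores[future] = -np.inf                      # causal mask: no attending to future tokens
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v                            # row t mixes value vectors from positions <= t
```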
This has interesting practical implications -- for example, it's in some sense the deep reason costs can sometimes be reduced via "prompt caching".
I'm more of a vision person and haven't looked much into NLP transformers, but is this because the attention is masked so that each query can only look at keys/values from its own past? So when we are at token #5, token #3's query cannot attend to token #4's info? And hence the previously computed attention values and activations remain the same and can be cached, because they would come out the same in a new forward pass anyway?
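For what it's worth, here's a toy sketch of the caching idea described above (my own function and variable names, nothing from a real library): during generation you only compute the query/key/value for the newest token and append its key/value to a cache, because the causal mask guarantees the earlier entries would be recomputed identically anyway.

```python
import numpy as np

def generate_step(x_new, kv_cache, Wq, Wk, Wv):
    # x_new: (d,) activation of the newest token only.
    # kv_cache: list of (key, value) pairs from all earlier tokens.
    # The causal mask means those cached entries never change,
    # so we only do the attention work for the new token.
    q, k, v = x_new @ Wq, x_new @ Wk, x_new @ Wv
    kv_cache.append((k, v))
    K = np.stack([key for key, _ in kv_cache])    # (T, d)
    V = np.stack([val for _, val in kv_cache])    # (T, d)
    scores = K @ q / np.sqrt(x_new.shape[0])      # new token attends to every cached position
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ V, kv_cache
```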
If you want to be precise, there are “autoregressive transformers” and “bidirectional transformers”. Bidirectional is a lot more common in vision. In language models you do see bidirectional models like BERT, but autoregressive is dominant.
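For concreteness, the only structural difference in the attention layer is the mask. A tiny sketch (toy arrays, not any specific implementation):

```python
import numpy as np

T = 4
# Autoregressive (GPT-style): token t can only attend to positions <= t.
autoregressive_mask = np.tril(np.ones((T, T), dtype=bool))
# Bidirectional (BERT-style, common in vision): every token attends to every token.
bidirectional_mask = np.ones((T, T), dtype=bool)
```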