Transformer architectures generally take quadratic time wrt sequence length, not exponential. Architectural innovations like FlashAttention also mitigate this somewhat.
Backtracking isn't involved; transformers are feedforward.
No, additional context does not cause exponential slowdowns, and you absolutely can use FlashAttention tricks during training; I'm doing it right now. Transformers are not RNNs and are not unrolled across timesteps: the backpropagation path for a 1,000,000-token-context LLM is no longer than for a 100-token-context LLM of the same size. The only thing that is larger is the self-attention calculation, which is quadratic wrt compute and linear wrt memory if you use FlashAttention or similar fused self-attention calculations. These calculations can be further parallelized using tricks like ring attention to distribute very large attention computations over many nodes. This is how Google trained their 10M-context version of Gemini.
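For intuition, here's a minimal PyTorch sketch of the "quadratic compute, linear memory" point. It is not the fused FlashAttention kernel (which also tiles over queries and keeps blocks in SRAM); the shapes and chunk size are illustrative assumptions. Keys/values are processed block by block with an online softmax, so the full [n, n] score matrix is never materialized:

```python
import torch

def chunked_attention(q, k, v, chunk=1024):
    # q, k, v: [n, d]; matches softmax(q @ k.T / sqrt(d)) @ v without ever
    # building the full [n, n] score matrix -- only an [n, chunk] block at a time.
    n, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((n, 1), float("-inf"))
    row_sum = torch.zeros(n, 1)
    for start in range(0, n, chunk):
        k_blk, v_blk = k[start:start + chunk], v[start:start + chunk]
        scores = (q @ k_blk.T) * scale                    # [n, chunk]
        blk_max = scores.max(dim=-1, keepdim=True).values
        new_max = torch.maximum(row_max, blk_max)
        correction = torch.exp(row_max - new_max)         # rescale old accumulators
        p = torch.exp(scores - new_max)
        out = out * correction + p @ v_blk
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        row_max = new_max
    return out / row_sum

q, k, v = (torch.randn(4096, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
assert torch.allclose(chunked_attention(q, k, v), ref, atol=1e-4)
```

Total FLOPs here are still quadratic in sequence length; only the peak activation memory drops from O(n²) to O(n · chunk), and that per-block structure is what ring attention then shards across devices.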
So why are the context windows so "small", then? It would seem that if the cost was not so great, then having a larger context window would give an advantage over the competition.
The cost for both training and inference is roughly quadratic, while, for the vast majority of users, the marginal utility of additional context is sharply diminishing. For 99% of ChatGPT users something like 8192 tokens, or about 20 pages of context, would be plenty. Companies have to balance the cost of training and serving models. Google did train an uber-long-context version of Gemini, but since Gemini itself fundamentally was not better than GPT-4 or Claude, it didn't matter much: so few people actually benefited from such a niche advantage that it didn't shift the playing field in their favor.
Marginal utility only drops because effective context is really bad: most models still vastly prefer the first things they see, and those "needle in a haystack" tests are misleading in that they convince people LLMs do a good job of handling their whole context when they just don't.
If we have the effective context window equal to the claimed context window, well, I'd start worrying a bit about most of the risks that AI doomers talk about...
There has been a huge increase in context windows recently.
I think the larger problem is "effective context" and training data.
Being technically able to use a large context window doesn't mean a model can actually remember or attend to that larger context well. In my experience, the kinds of synthetic "needle in haystack" tasks that AI companies use to show how large of a context their model can handle don't translate very well to more complicated use cases.
You can create data with large context for training by synthetically adding in random stuff, but there's not a ton of organic training data where something meaningfully depends on something 100,000 tokens back.
Also, even if it's not scaling exponentially, it's still scaling: at what point is RAG going to be more effective than just having a large context?
Great point about the meaningful datasets, this makes perfect sense, especially with regard to SFT and RLHF. Although I suppose it would be somewhat easier to do pretraining on really long contexts (books, I assume?)
Because you have to do inference distributed across multiple nodes at this point: for prefill, because prefill is actually quadratic, but also for memory reasons. The KV cache for 405B at 10M context length would take more than 5 terabytes (at bf16). That's 36 H200s just for the KV cache, but you would need roughly 48 GPUs to serve the bf16 version of the model. Generation speed at that setup would be roughly 30 tokens per second, 100k tokens per hour, and you can serve only a single user because batching doesn't make sense at these kinds of context lengths. If you pay 3 dollars per hour per GPU, that's a cost of $1440 per million tokens. For the fp8 version the numbers are a bit better: you need only 24 GPUs and generation speed stays roughly the same, so it's only 700 dollars per million tokens. There are architectural modifications that would bring that down significantly, but nonetheless it's still really, really expensive, and also quite hard to get to work.
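For anyone who wants to sanity-check those numbers, a quick back-of-the-envelope in Python, assuming Llama 3.1 405B's published config (126 layers, 8 KV heads via GQA, head_dim 128); these are rough estimates, not measurements:

```python
layers, kv_heads, head_dim = 126, 8, 128     # Llama 3.1 405B (GQA), assumed config
bytes_per_val = 2                            # bf16
ctx = 10_000_000                             # 10M-token context

kv_per_token = 2 * layers * kv_heads * head_dim * bytes_per_val   # K and V
kv_total = kv_per_token * ctx
print(f"KV cache: {kv_total / 1e12:.1f} TB")                       # ~5.2 TB
print(f"H200s (141 GB each) just for the KV cache: {kv_total / 141e9:.0f}")  # ~37

# serving cost: ~100k generated tokens/hour on 48 GPUs at $3/GPU-hour
print(f"${48 * 3 / 0.1:.0f} per million tokens")                   # $1440
```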
Another factor in context window size is effective recall. If the model can't actually use a fact from 1M tokens earlier, accurately and precisely, then there's no benefit, and it's harmful to the user experience to allow the use of a poorly functioning feature. Part of what Google have done with Gemini's 1-2M token context window is demonstrate that the model will actually recall and use that data. Disclosure: I do work at Google, but not on this, and I don't have any inside info on the model.
Memory. I don't know the equation, but it's very easy to see when you load a 128k-context model at 8K vs 80K. The quant I am running would double its VRAM requirements when loading 80K.
> The only thing which is larger is the self attention calculation which is quadratic wrt compute and linear wrt memory if you use FlashAttention or similar fused self attention calculations.
FFWD input is the self-attention output. And since the output of the self-attention layer is [context, d_model], the FFWD layer input will grow as well. Consequently, the FFWD layer compute cost will grow as well, no?
According to my calculations, the cost of the FFWD layer is ~(4 + 2·[1 if w3 is present, else 0]) * d_model * d_ff * n_layers * context_size, so the FFWD cost grows linearly wrt the context size.
So, unless I've misunderstood the transformer architecture, the larger the context, the larger the compute for both self-attention and FFWD?
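Writing that formula out as a tiny function (assuming a SwiGLU-style FFN with an optional gate w3; the example shape is made up, roughly Llama-2-70B-sized): each matmul costs about 2 · in_dim · out_dim FLOPs per token, so prefill FFN cost does grow linearly with the number of tokens.

```python
def ffn_prefill_flops(context_size, d_model, d_ff, n_layers, has_w3=True):
    # (4 + 2*[w3 present]) = 2 FLOPs per multiply-add * (2 or 3 matmuls of d_model x d_ff)
    flops_per_token = (4 + 2 * has_w3) * d_model * d_ff
    return flops_per_token * n_layers * context_size

# illustrative shape: d_model=8192, d_ff=28672, 80 layers, 8K-token prefill
print(f"{ffn_prefill_flops(8192, 8192, 28672, 80):.3e} FLOPs")
```

(During incremental decoding with a KV cache, the FFN runs only on the newly generated token each step, so its per-step cost doesn't grow with context; the linear growth above is the prefill/training cost.)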
So you're saying that if I have a sentence of 10 words, and I want the LLM to predict the 11th word, FFWD compute is going to be independent of the context size?
I don't understand how, since that very context is what determines whether the next prediction is any good or not?
More specifically, isn't the FFWD layer essentially the self-attention output, a [context, d_model] matrix, matmul'd with the W1, W2 and W3 weights?
I may be missing something, but I thought that each context token adds 3 additional vectors for self-attention to build its map, since attention for each token must be calculated considering all existing context.
Google advertises support for 128k tokens, with 2M-token sequences available to folks who pay the big bucks: https://blog.google/technology/ai/google-gemini-next-generat...