Hacker News

This is actually completely unnecessary in the batched inference case.

Here is an oversimplified explanation that gets the gist across:

The standard architecture for transformer-based LLMs is as follows: Token Embedding -> N layers, each consisting of an attention sublayer and an MLP sublayer -> Output Embedding.

Most attention implementations use a simple KV caching strategy. In prefill, you first compute the query, key and value vectors for every prompt token by performing a GEMM against the W_Q, W_K and W_V tensors, and store the keys and values as KV cache entries. During token generation, you only need to do this for the current token. Next comes the quadratic part of attention: you need to calculate softmax(Q K^T / sqrt(d_k)) V, where d_k is the key dimension. This is two matrix multiplications. To generate the next token, the cost is linear in the number of entries in the KV cache, because you need to re-read the entire KV cache plus the new entry. In prefill you are processing n tokens, each attending to up to n entries, so the cost is quadratic.

The KV cache is unique to every user session, and it grows with the size of the context. This makes the KV cache really expensive memory-wise: it consumes both memory capacity and bandwidth, and since every session reads its own cache, it doesn't permit batching across users.
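A minimal NumPy sketch of the caching scheme described above, for a single attention head with made-up dimensions and random weights (`KVCacheHead`, `prefill` and `decode` are hypothetical names, not any library's API). `prefill` runs the Q/K/V GEMMs over the whole prompt and builds an n x n score matrix, while `decode` projects only the new token and then re-reads every cached key and value, so its work grows linearly with the cache length. Real implementations add multiple heads, batching and fused kernels.

```python
import numpy as np


def softmax(x):
    # Numerically stable softmax over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


class KVCacheHead:
    """Hypothetical single attention head holding one user's KV cache."""

    def __init__(self, d_model, d_head, seed=0):
        rng = np.random.default_rng(seed)
        scale = 1.0 / np.sqrt(d_model)
        self.w_q = rng.standard_normal((d_model, d_head)) * scale
        self.w_k = rng.standard_normal((d_model, d_head)) * scale
        self.w_v = rng.standard_normal((d_model, d_head)) * scale
        self.d_head = d_head
        # The cache starts empty and grows by one row per processed token.
        self.keys = np.zeros((0, d_head))
        self.values = np.zeros((0, d_head))

    def prefill(self, x):
        """Process an (n, d_model) prompt at once; assumes an empty cache."""
        assert self.keys.shape[0] == 0
        q, k, v = x @ self.w_q, x @ self.w_k, x @ self.w_v  # three GEMMs
        self.keys, self.values = k, v
        n = x.shape[0]
        scores = q @ k.T / np.sqrt(self.d_head)  # n x n: quadratic in n
        # Causal mask: token i may only attend to tokens 0..i.
        scores[np.triu(np.ones((n, n), dtype=bool), k=1)] = -np.inf
        return softmax(scores) @ v

    def decode(self, x_t):
        """Process one (d_model,) token; cost is linear in the cache length."""
        q = x_t @ self.w_q  # only the current token is projected
        self.keys = np.vstack([self.keys, x_t @ self.w_k])
        self.values = np.vstack([self.values, x_t @ self.w_v])
        scores = self.keys @ q / np.sqrt(self.d_head)  # re-reads the whole cache
        return softmax(scores) @ self.values
```

Feeding a prompt through `decode` one token at a time gives the same outputs as `prefill` on the whole prompt, which is the point of the cache: nothing already computed is recomputed, but everything cached is re-read on every step.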

Meanwhile, the MLP sublayer is so boring that I won't bother going into the details. The gist is that you have a simple gated network: two feed-forward projections, known as the gate and up projections, map the token vector into a higher dimension (i.e. more outputs than inputs). You apply an activation function (e.g. SiLU) to the gate output, multiply the two vectors element-wise, and feed the result into the down projection, which reduces it back to the original dimension of the token vector. Since the weight matrices are the same for every token, you can process the tokens of multiple users at once.
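The gated MLP described above can be sketched in a few lines of NumPy. This assumes a SwiGLU-style layer with SiLU as the gate activation, and the sizes and random weights are placeholders. The interesting part is the end: because every user's tokens are multiplied by the same three matrices, stacking tokens from several sessions and doing one GEMM per matrix gives the same result as processing each user separately.

```python
import numpy as np


def silu(x):
    # SiLU / swish activation: x * sigmoid(x).
    return x / (1.0 + np.exp(-x))


def gated_mlp(x, w_gate, w_up, w_down):
    """x: (tokens, d_model). Gate and up project to d_ff, down projects back."""
    hidden = silu(x @ w_gate) * (x @ w_up)  # element-wise product in d_ff
    return hidden @ w_down


rng = np.random.default_rng(0)
d_model, d_ff = 16, 64  # hypothetical sizes; d_ff > d_model
w_gate = rng.standard_normal((d_model, d_ff)) * 0.1
w_up = rng.standard_normal((d_model, d_ff)) * 0.1
w_down = rng.standard_normal((d_ff, d_model)) * 0.1

# The current token of three different user sessions.
users = [rng.standard_normal((1, d_model)) for _ in range(3)]

# One batched GEMM per weight matrix versus one call per user.
batched = gated_mlp(np.vstack(users), w_gate, w_up, w_down)
separate = np.vstack([gated_mlp(u, w_gate, w_up, w_down) for u in users])
```

In the batched version the weights are read once for all three users instead of three times, which is where the memory bandwidth savings come from.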

Now here are the implications of what I wrote above. Prefill is generally compute bound and is therefore mostly uninteresting, or rather, interesting mainly to ASIC designers, because FLOPS are cheap and SRAM is expensive. Token generation, meanwhile, is a mix of being memory bandwidth bound and compute bound in the batched case. The MLP sublayer is trivially parallelized through GEMM-based batching. Having lots of SRAM is beneficial for GEMM, but it is not super critical in a double-buffered implementation that performs loading and computation simultaneously, with the memory bandwidth chosen so that both finish at roughly the same time.
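A back-of-the-envelope timing model of the double-buffered scheme, assuming a fixed per-tile load time and compute time in whole milliseconds (the function names and numbers are made up for illustration). While tile i is being computed, tile i + 1 is being loaded, so only the first load and the last compute are not overlapped.

```python
def sequential_time(num_tiles, load_ms, compute_ms):
    # No overlap: every tile is loaded, then computed.
    return num_tiles * (load_ms + compute_ms)


def double_buffered_time(num_tiles, load_ms, compute_ms):
    # Load tile 0, then overlap the compute of tile i with the load of
    # tile i + 1, then compute the final tile.
    return load_ms + (num_tiles - 1) * max(load_ms, compute_ms) + compute_ms


def balanced_bandwidth(tile_bytes, tile_flops, flops_per_s):
    # Memory bandwidth (bytes/s) at which loading a tile takes exactly as
    # long as computing on it, so neither side waits on the other.
    return tile_bytes * flops_per_s / tile_flops


# Balanced case: 100 tiles, 1 ms to load and 1 ms to compute each.
print(sequential_time(100, 1, 1))       # 200
print(double_buffered_time(100, 1, 1))  # 101
```

When load and compute are balanced, overlapping them roughly halves the total time. If one side is much slower, the `max` term dominates and whatever you spent on the other side is wasted.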

What SRAM buys you for GEMM is the following. Take two square matrices A and B, each 1 GiB in size, their output A*B = C of the same dimension, and x MiB of SRAM. You tile the GEMM operation so that each sub-matrix (one tile each of A, B and C) is x/3 MiB in size. Let's say x = 120 MiB, which means 40 MiB per tile. You will split the matrices A and B into approximately 25 tiles each (1024/40 = 25.6). For every tile in A, you have to load all tiles in B. That means 25 (A) + 25*25 (B, once per A tile) = 650 load operations of 40 MiB each, for a total of 26000 MiB of reads. If you double the SRAM, you now have 13 tiles of 80 MiB: 13 + 13*13 = 182 loads, and 182 * 80 MiB = 14560 MiB. Loosely speaking, doubling SRAM halves the required memory bandwidth. This is boring old linear scaling: fewer tiles also means bigger tiles, so the quadratic, roughly 4x reduction in the number of loads is partially offset by each load being 2x bigger, for a net reduction of about 2x. Having more SRAM is still good, though.
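The same arithmetic as a small Python sketch, using the comment's simplified model (each A tile loaded once, every B tile re-loaded for each A tile, writes of C ignored). `reads_for_sram` is a hypothetical helper that rounds the tile count up rather than using the rounded figures from the text, which is why its 120 MiB result differs a bit from 26000 MiB.

```python
import math


def tiled_gemm_reads_mib(num_tiles, tile_mib):
    # Load each A tile once (num_tiles loads), and for each A tile load
    # every B tile (num_tiles * num_tiles loads).
    loads = num_tiles + num_tiles * num_tiles
    return loads * tile_mib


def reads_for_sram(matrix_mib, sram_mib):
    # SRAM holds one tile each of A, B and C, so a tile is a third of it.
    tile_mib = sram_mib / 3
    num_tiles = math.ceil(matrix_mib / tile_mib)
    return tiled_gemm_reads_mib(num_tiles, tile_mib)


# The figures from the text, with its rounded tile counts.
print(tiled_gemm_reads_mib(25, 40))  # 26000
print(tiled_gemm_reads_mib(13, 80))  # 14560

# Deriving the tile counts from a 1 GiB matrix instead.
print(reads_for_sram(1024, 120))  # 28080.0
print(reads_for_sram(1024, 240))  # 14560.0
```

In this model, with matrix size M and tile size t there are M/t tiles, so the total reads are (M/t + (M/t)^2) * t = M + M^2/t. The M^2/t term dominates, and it halves whenever the tile size, and therefore the SRAM, doubles.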

Now onto Flash Attention. If I had to dumb it down, Flash Attention is a very quirky way of arranging the two GEMM operations so that the intermediate C matrix of the first multiplication, Q*K^T, never has to be allocated in full. Otherwise it is the same as two GEMMs with smaller tiles, so here too doubling SRAM halves the necessary memory bandwidth.
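A minimal NumPy sketch of that idea for a single head, without a causal mask. It shows the online-softmax tiling that Flash Attention is built on, not the actual fused GPU kernel: K and V are processed in blocks of a hypothetical size `block`, so only a (queries x block) slice of Q*K^T exists at any time. A running row maximum and row sum let earlier partial results be rescaled when a later block contains a larger score.

```python
import numpy as np


def naive_attention(q, k, v):
    # Materializes the full (n_q x n_k) score matrix Q K^T.
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def tiled_attention(q, k, v, block=64):
    # Processes K and V in blocks; only an (n_q x block) score tile exists at once.
    n_q, d = q.shape
    out = np.zeros((n_q, v.shape[-1]))
    row_max = np.full((n_q, 1), -np.inf)  # running max of scores per query
    row_sum = np.zeros((n_q, 1))          # running softmax denominator
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = q @ kb.T / np.sqrt(d)  # first GEMM, one tile
        new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
        # Rescale everything accumulated under the old max (0 on the first block).
        correction = np.exp(row_max - new_max)
        p = np.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + p @ vb  # second GEMM, one tile
        row_max = new_max
    return out / row_sum


rng = np.random.default_rng(0)
q = rng.standard_normal((8, 16))
k = rng.standard_normal((200, 16))
v = rng.standard_normal((200, 16))
print(np.allclose(naive_attention(q, k, v), tiled_attention(q, k, v, block=64)))  # True
```

The block size plays the role of the tile size from the GEMM example: a larger block means fewer passes over Q and the accumulators, at the cost of more on-chip memory per tile.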

Final conclusion: in the batched multi-user inference case, your goal is twofold. On the attention nodes, allocate the KV cache to SRAM. On the MLP nodes, achieve as large a batch size as possible and use the SRAM to operate on tiles that are as large as possible. If you achieve both, the required memory bandwidth scales inversely with the amount of SRAM. Storing full tensors in SRAM is not necessary at large batch sizes.
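To put a number on why large MLP batches matter, here is a rough arithmetic-intensity estimate for one weight matrix of an MLP sublayer. It assumes 2-byte (fp16/bf16) values, weights read once per batch, and activations read and written once; the function name and the 4096 dimension are illustrative. At batch size 1 you get about one FLOP per byte read, which is firmly bandwidth bound, and the ratio grows roughly linearly with the batch size until compute becomes the limit.

```python
def matmul_intensity(batch, d_in, d_out, bytes_per_elem=2):
    """FLOPs per byte of memory traffic for a (batch, d_in) @ (d_in, d_out) GEMM."""
    flops = 2 * batch * d_in * d_out  # one multiply and one add per weight per token
    bytes_moved = bytes_per_elem * (
        d_in * d_out      # weights, read once and shared by the whole batch
        + batch * d_in    # input activations
        + batch * d_out   # output activations
    )
    return flops / bytes_moved


for batch in (1, 8, 64, 256):
    print(batch, round(matmul_intensity(batch, 4096, 4096), 1))
```

Comparing this ratio with a chip's FLOPS-to-bandwidth ratio tells you roughly which batch size is needed before the MLP stops being bandwidth bound.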

Of course, since I only looked at the memory aspects, I shouldn't leave out that you need to evenly match compute and memory resources. Having SRAM on its own doesn't really buy you anything.



In batched inference Cerebras have no advantage but cost more AFAIU.



