At first, there are 16 fetches per row x column (i.e. per output element), 1024 in total. Then, it is observed that an input row needs to be fetched only once per output row. That reduces the count to 8 fetches per row plus 8 per row x column: 8 * 8 + 8 * 64 = 576 in total. This still requires only 16 numbers to be kept in registers.
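To double-check the arithmetic above, here is a minimal counting sketch. It assumes an 8x8 by 8x8 multiply where a "fetch" is one input element loaded from memory. The function names are illustrative, not from the post.

```python
N = 8  # 8x8 inputs, so 64 output elements


def naive_fetches(n: int) -> int:
    """Every output element C[i][j] re-fetches row i of A and column j of B."""
    fetches = 0
    for _ in range(n):          # output rows
        for _ in range(n):      # output columns
            fetches += n        # row i of A
            fetches += n        # column j of B
    return fetches


def row_reuse_fetches(n: int) -> int:
    """Row i of A is fetched once and kept in registers for the whole output row."""
    fetches = 0
    for _ in range(n):          # output rows
        fetches += n            # row i of A, reused across the output row
        for _ in range(n):      # output columns
            fetches += n        # column j of B, fetched fresh per output element
    return fetches


print(naive_fetches(N))      # 1024
print(row_reuse_fetches(N))  # 576
```

Both schemes hold one row of A and one column of B at a time, i.e. 2 * 8 = 16 numbers in registers.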
But then it is claimed that by doing one quadrant at a time, only 64 fetches are needed per quadrant, or 256 in total. That assumes we can keep 4 rows and 4 columns of 8 numbers each, i.e. 64 numbers, in registers! If we can only keep 16 numbers as above, each row of the quadrant takes 40 fetches (8 for the input row plus 4 columns x 8), so we get 160 fetches per quadrant or 640 in total, a pessimization from 576 fetches!
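The quadrant numbers depend entirely on the register budget. A small sketch, assuming a hypothetical two-level model: either a whole tile's rows and columns fit (2 * t * n values), or only one row and one column fit (2 * n values). `tiled_fetches` is an illustrative name.

```python
def tiled_fetches(n: int, t: int, registers: int) -> int:
    """Element fetches for an n x n matmul computed one t x t output tile at a time.

    Hypothetical register model: if 2*t*n values fit, each of the tile's t rows
    of A and t columns of B is loaded once. Otherwise only one row and one
    column fit, so the tile's columns are reloaded for every output row.
    """
    assert n % t == 0 and registers >= 2 * n
    tiles = (n // t) ** 2
    if registers >= 2 * t * n:
        per_tile = t * n + t * n      # t rows + t columns, each loaded once
    else:
        per_tile = t * (n + t * n)    # per output row: the row once + t columns
    return tiles * per_tile


print(tiled_fetches(8, 4, registers=64))  # 256 (64 per quadrant)
print(tiled_fetches(8, 4, registers=16))  # 640 (160 per quadrant)
```

With only 16 registers, the quadrant scheme does worse than the 576 fetches of the row-reuse scheme.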
That’s a valid point - I’m assuming infinite register capacity at that point in the post.
The next section discusses what you’re talking about, e.g. how to deal with finite register/shared memory capacity by splitting the k dimension. I’ll mention the shared/register memory limitation sooner to clear up the confusion.
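For readers following along, splitting the k dimension can be sketched like this. It is a pure-Python illustration, not the post's implementation. It assumes the tile size t divides the output dimensions and the chunk size kc divides k. The point is that the live data per step stays at 2 * t * kc + t * t values however long k is.

```python
def matmul_naive(A, B):
    """Reference result: every C[i][j] is a full-length inner product over k."""
    n, k, m = len(A), len(B), len(B[0])
    return [[sum(A[i][p] * B[p][j] for p in range(k)) for j in range(m)] for i in range(n)]


def matmul_k_split(A, B, t, kc):
    """C = A @ B, one t x t output tile at a time, streaming k in chunks of kc.

    Assumes t divides the row/column counts and kc divides k. Per step, the live
    data is a t x kc slice of A, a kc x t slice of B and the t x t accumulator.
    """
    n, k, m = len(A), len(B), len(B[0])
    C = [[0] * m for _ in range(n)]
    for i0 in range(0, n, t):
        for j0 in range(0, m, t):
            acc = [[0] * t for _ in range(t)]  # stays resident across all k-chunks
            for k0 in range(0, k, kc):
                a_tile = [A[i][k0:k0 + kc] for i in range(i0, i0 + t)]
                b_tile = [B[p][j0:j0 + t] for p in range(k0, k0 + kc)]
                for ii in range(t):
                    for jj in range(t):
                        acc[ii][jj] += sum(a_tile[ii][pp] * b_tile[pp][jj] for pp in range(kc))
            for ii in range(t):
                C[i0 + ii][j0:j0 + t] = acc[ii]  # one store per output element
    return C


A = [[i * 8 + j for j in range(8)] for i in range(8)]
B = [[i - j for j in range(8)] for i in range(8)]
assert matmul_k_split(A, B, t=4, kc=2) == matmul_naive(A, B)
```

With t = 4 and kc = 2, the working set is 2 * 4 * 2 + 4 * 4 = 32 values, versus 64 for holding four full rows and four full columns.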
The overall problem with your blog post is that it beats around the bush rather than getting to the point. It feels like it explains tiling in the reverse of the order needed to understand it.
"How effective is tiling?" and "Why tiling is so fast" should be at the end. Meanwhile the key section, "Why there's a limit to tiling", which should be front and center, sits in the middle, followed by a subversion of the entire concept in "How to sidestep tiling limits".
It's also incredibly jarring to read this:
"Wondering how we were able to reduce memory usage "for free"? Indeed, the reduction wasn't free. In fact, we paid for this reduction a different way — by incurring more writes."
Again, this is completely backwards. Suppose you don't have a cache at all: you'll have to write everything out to DRAM every single time. The same holds at the other extreme. Imagine you had an infinite number of registers: every addition accumulates into a register, which is still a write operation. Hence, the number of write operations doesn't change.
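The argument about writes can be illustrated with a toy count. It assumes each multiply-add performs exactly one accumulator write, and that the "infinite registers" case adds one final store per output element. It illustrates this comment's reasoning and does not model any real hardware.

```python
def count_writes(n: int, accumulate_in_registers: bool) -> tuple[int, int]:
    """Return (dram_writes, register_writes) for an n x n x n matmul in a toy model."""
    dram_writes = 0
    register_writes = 0
    for _ in range(n):              # i
        for _ in range(n):          # j
            for _ in range(n):      # k: one C[i][j] += A[i][k] * B[k][j]
                if accumulate_in_registers:
                    register_writes += 1
                else:
                    dram_writes += 1
            if accumulate_in_registers:
                dram_writes += 1    # store the finished accumulator to C[i][j]
    return dram_writes, register_writes


print(count_writes(8, accumulate_in_registers=False))  # (512, 0)
print(count_writes(8, accumulate_in_registers=True))   # (64, 512)
```

In both cases there are n^3 = 512 accumulator writes. Only where they land changes.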
Really, the main points should come in this order:
1. Matrix multiplication works best with square or nearly square matrices.
2. Registers and SRAM (including caches) are limited, which forces you to process matrices of finite size (aka tiles).
3. Because of the memory hierarchy, the biggest matrix you can store grows at each successive level of the hierarchy.
4. You can split matrix multiplication into inner and outer products.
5. Outer products take few inputs and have many outputs/accumulators; inner products take many inputs and have few outputs/accumulators.
6. You want to compute the biggest outer product you can get away with, since this significantly reduces the memory needed to store inputs and maximizes the number of cycles spent doing calculations. Once you hit the limit, you want to reuse the accumulator, so you compute inner products of outer products.
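Points 4 to 6 can be sketched in a few lines. This is one possible illustration with made-up function names, and it assumes t divides the output dimensions. Each t x t outer-product step reads 2t inputs and performs t * t multiply-adds, while a length-k inner product reads 2k inputs to feed a single accumulator. Summing outer products along k into the same accumulator tile is the "inner product of outer products".

```python
def outer_product_update(acc, col, row):
    """acc += outer(col, row): reads len(col) + len(row) inputs, updates len(col) * len(row) accumulators."""
    for i, a in enumerate(col):
        for j, b in enumerate(row):
            acc[i][j] += a * b


def matmul_outer_products(A, B, t):
    """Each t x t output tile is a sum over k of t x t outer products.

    Assumes t divides the number of rows of A and columns of B.
    """
    n, k, m = len(A), len(B), len(B[0])
    C = [[0] * m for _ in range(n)]
    for i0 in range(0, n, t):
        for j0 in range(0, m, t):
            acc = [[0] * t for _ in range(t)]  # t*t accumulators reused for every p
            for p in range(k):
                col = [A[i][p] for i in range(i0, i0 + t)]  # t inputs from A
                row = B[p][j0:j0 + t]                       # t inputs from B
                outer_product_update(acc, col, row)         # t*t multiply-adds
            for ii in range(t):
                C[i0 + ii][j0:j0 + t] = acc[ii]
    return C


A = [[i * 8 + j for j in range(8)] for i in range(8)]
B = [[i - j for j in range(8)] for i in range(8)]
reference = [[sum(A[i][p] * B[p][j] for p in range(8)) for j in range(8)] for i in range(8)]
assert matmul_outer_products(A, B, t=4) == reference
```

Doubling t doubles the inputs per step but quadruples the multiply-adds, which is why the biggest outer product that fits is preferred.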
I see, thanks for the feedback - the current blog post’s flow certainly isn’t optimal. I’ll try reordering to eliminate jarring bits and see how it flows.