On DRAM Sharded Matmul in TTNN
I spent some time trying to get the DRAM Sharded matmul kernel to work well for a simple MLP, and it turned out to be less useful than I had hoped. In the end, unless I used LoFi math fidelity, performance was worse than the DRAM Interleaved matmul, because the compute only runs on 8 cores.
How to use DRAM Sharded matmul
The two shapes I used as examples were 5120x17408 (Qwen3.6-27B) and 5376x21504 (Gemma4-31B), plus their down projection counterparts. Sharding the weights in DRAM was easy: simply split them into shards based on the number of DRAM channels (8 on a BH p150). For 5120x17408 (160x544 tiles), each shard is 160x68 tiles. For 5376x21504 (168x672 tiles), each shard is 168x84 tiles.
Sharding the input vector is more interesting, and there are several factors to consider:
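The per-channel arithmetic is easy to double-check. The sketch below uses a hypothetical helper, `dram_weight_shard_tiles` (not a ttnn function), and assumes the K x N weight is split along N into one shard per DRAM channel, with 8 channels as on the BH p150. It also covers the down projection shapes.

```python
TILE = 32  # tt-metal tiles are 32x32 elements


def dram_weight_shard_tiles(k: int, n: int, num_channels: int = 8) -> tuple[int, int]:
    """Return (height, width) in tiles of each per-channel weight shard.

    Hypothetical helper: the K x N weight is split along N, one shard per
    DRAM channel. No padding is modelled, so N's tile count must divide
    evenly by the channel count.
    """
    if k % TILE or n % TILE:
        raise ValueError("K and N must be multiples of 32")
    k_tiles, n_tiles = k // TILE, n // TILE
    if n_tiles % num_channels:
        raise ValueError(f"{n_tiles} width tiles do not split evenly over {num_channels} channels")
    return k_tiles, n_tiles // num_channels


# Up/gate projections followed by their down projection counterparts.
for k, n in [(5120, 17408), (5376, 21504), (17408, 5120), (21504, 5376)]:
    print(f"{k}x{n}: shard = {dram_weight_shard_tiles(k, n)} tiles")
```

For the down projections the shards are much narrower: 544x20 tiles for Qwen and 672x21 tiles for Gemma.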
- Even though we are only sharding a single dimension, it must be sharded evenly onto a rectangular core grid, e.g. 8x2 or 8x4.
  - Presumably this is for NoC broadcasting reasons (NoC broadcasts can target a rectangular grid of cores). This means that awkwardly shaped inputs may need to be padded.
- Using more cores means each core stores less data, which is important for not exceeding the L1 size of each core.
- The number of tiles on each core should be divisible by a small number (e.g. 2, 3, 5, …).
  - This is needed when transferring the data during compute (in0_block_w). Transferring a single tile at a time hurts performance. In practice I didn't notice any performance improvement for in0_block_w > 2.
  - However, if this value is too large, we run into trouble later when allocating circular buffers, whose size scales proportionally with in0_block_w.
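The padding rule in the list above can be sketched as: round the K dimension (in tiles) up so it splits evenly over the grid, then round each core's share up to a multiple of in0_block_w. `padded_tiles_per_core` is a hypothetical helper, and the 32-core grid (e.g. 8x4) is an illustrative choice.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division."""
    return -(-a // b)


def padded_tiles_per_core(k_tiles: int, num_cores: int, block_w: int) -> int:
    """Tiles of the input vector each core holds after padding.

    Hypothetical helper: K (in tiles) must split evenly over the rectangular
    grid, and each core's share must be a multiple of in0_block_w.
    """
    per_core = ceil_div(k_tiles, num_cores)      # even split over the grid
    return ceil_div(per_core, block_w) * block_w  # round up to a multiple of in0_block_w


# Qwen up/gate input: K = 5120 -> 160 tiles on 32 cores.
print(padded_tiles_per_core(160, 32, block_w=5))  # 5: no padding needed
print(padded_tiles_per_core(160, 32, block_w=2))  # 6: K padded from 160 to 192 tiles
# Qwen down input: K = 17408 -> 544 tiles, i.e. 17 (prime) tiles per core on 32 cores.
print(padded_tiles_per_core(544, 32, block_w=2))  # 18: K padded from 544 to 576 tiles
```

Padding K on the input also means zero-padding the same K rows of the weight, so the extra products contribute nothing to the result.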
Together, these considerations mean we should pick a grid of cores that spreads the data over as many cores as possible, while leaving a per-core tile count that allows an in0_block_w of at least 2 but is small enough that the circular buffers still fit in L1. The existing code in ttnn does not take this into account when choosing the L1 grid, which leads to failures later.
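One way to do that selection is a brute-force search over rectangular grids. This is a minimal sketch, not what ttnn does: `choose_input_grid` is hypothetical, the 8x8 grid limit is illustrative rather than the real p150 compute grid, and `max_tiles_per_core` stands in for the actual L1 budget. It keeps the smallest workable in0_block_w, since values above 2 did not help and larger ones grow the circular buffers.

```python
from typing import NamedTuple, Optional


class GridChoice(NamedTuple):
    grid_x: int
    grid_y: int
    tiles_per_core: int
    in0_block_w: int


def choose_input_grid(
    k_tiles: int,
    max_x: int,
    max_y: int,
    max_tiles_per_core: int,
    block_w_candidates: tuple[int, ...] = (2, 3, 4, 5),  # ascending, so the smallest wins
) -> Optional[GridChoice]:
    """Largest rectangular grid that shards k_tiles evenly, or None.

    A grid is valid if K splits evenly across it (no padding), each core's
    share fits the L1 stand-in budget, and the share is divisible by some
    in0_block_w >= 2.
    """
    best: Optional[GridChoice] = None
    for gx in range(max_x, 0, -1):  # iterate wide-first so ties favour wider grids
        for gy in range(1, max_y + 1):
            n = gx * gy
            if k_tiles % n:
                continue  # uneven split would need padding
            per_core = k_tiles // n
            if per_core > max_tiles_per_core:
                continue  # shard would not fit the L1 budget
            block_w = next((b for b in block_w_candidates if per_core % b == 0), None)
            if block_w is None:
                continue  # only in0_block_w = 1 would work, which is slow
            if best is None or n > best.grid_x * best.grid_y:
                best = GridChoice(gx, gy, per_core, block_w)
    return best


print(choose_input_grid(160, 8, 8, max_tiles_per_core=64))  # Qwen up/gate input
print(choose_input_grid(544, 8, 8, max_tiles_per_core=64))  # Qwen down input
```

Under these assumptions the 160-tile input lands on an 8x5 grid with 4 tiles per core, while the 544-tile down projection input only finds an unpadded 8x2 grid (34 tiles per core), which is where padding K becomes attractive.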
Finally, the program config for the DRAM sharded matmul has 4 parameters:
- in0_block_w: size of each chunk (in tiles) to process. 2 to 5 is probably fine; more might exceed memory, and 1 is really slow.
- per_core_M: M is 1, so why is this even here? DRAM Sharded matmul doesn't even support M > 1 tile.
- per_core_N: this only affects placement of the output data. The actual computation is ALWAYS done with one core per DRAM channel.
- fused_activation: optionally apply an activation function. I found that attaching the activation function here always made the op slower, likely due to the compute bottleneck.
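A sketch of deriving those four values from the matmul shape follows. `DramShardedMatmulConfig` and `make_dram_sharded_config` are hypothetical stand-ins that only mirror the fields listed above; in the ttnn Python API the values would, as far as I know, be passed to `ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig`. The 32 output cores are an illustrative choice.

```python
from dataclasses import dataclass
from typing import Optional

TILE = 32


@dataclass(frozen=True)
class DramShardedMatmulConfig:
    """Hypothetical mirror of the four program-config fields."""
    in0_block_w: int
    per_core_M: int
    per_core_N: int
    fused_activation: Optional[str] = None


def make_dram_sharded_config(m, k, n, num_input_cores, num_output_cores, in0_block_w=2):
    """Derive config values for an M x K x N matmul (hypothetical helper)."""
    if k % TILE or n % TILE:
        raise ValueError("K and N must be multiples of 32")
    m_tiles = -(-m // TILE)
    if m_tiles != 1:
        raise ValueError("DRAM sharded matmul only supports M of one tile")
    k_tiles, n_tiles = k // TILE, n // TILE
    if k_tiles % num_input_cores or (k_tiles // num_input_cores) % in0_block_w:
        raise ValueError("each input shard must be a multiple of in0_block_w tiles")
    if n_tiles % num_output_cores:
        raise ValueError("output width must split evenly over the output cores")
    return DramShardedMatmulConfig(
        in0_block_w=in0_block_w,
        per_core_M=m_tiles,                      # always 1 here
        per_core_N=n_tiles // num_output_cores,  # output placement only
        fused_activation=None,                   # left unset: fusing was slower
    )


# Qwen up/gate: 160 input tiles over 32 cores (5 each), 544 output tiles over 32 cores.
print(make_dram_sharded_config(32, 5120, 17408, num_input_cores=32, num_output_cores=32, in0_block_w=5))
```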
The most interesting fact here is that compute is only done on 8 cores. In practice, for my example matrices with BFP8 weights, this means that unless I use LoFi math fidelity, the operation ends up compute limited and actually runs slower than the DRAM Interleaved matmul, which is DRAM limited (see Results below). However, I found that LoFi math fidelity introduced too many precision issues to be usable.
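The FLOPS and DRAM Util columns in the Results tables can be reproduced from the shape and device time, which makes the compute-bound argument easy to check. In this sketch the FLOP count is the standard 2·M·K·N, the 512 GB/s peak is an assumption that matches the table's utilisation figures, and the per-core figure divides by the 8 compute cores stated above (109 for the interleaved op).

```python
def tflops(m: int, k: int, n: int, device_time_us: float) -> float:
    """Achieved throughput of an M x K x N matmul: 2*M*K*N FLOPs over the device time."""
    return 2 * m * k * n / (device_time_us * 1e-6) / 1e12


def dram_util(gb_per_s: float, peak_gb_per_s: float = 512.0) -> float:
    """Fraction of peak DRAM bandwidth (512 GB/s peak assumed)."""
    return gb_per_s / peak_gb_per_s


# Up/gate projection (32 x 5120 x 17408), device times taken from the Results tables.
runs = [("interleaved HiFi2", 249, 109), ("sharded HiFi2", 315, 8), ("sharded LoFi", 197, 8)]
for label, t_us, cores in runs:
    total = tflops(32, 5120, 17408, t_us)
    print(f"{label}: {total:.1f} TFLOPS total, {total / cores:.2f} TFLOPS per compute core")
```

Per compute core, the sharded op goes from roughly 2.26 TFLOPS at HiFi2 to 3.62 TFLOPS at LoFi, while DRAM utilisation rises from about 55% to 88%, consistent with HiFi2 being limited by the 8 compute cores rather than by DRAM.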
Results
Decode HiFi2
DRAM Interleaved: BFP8 weights, BF16 inputs, HiFi2
| Op | Type | Layout | Device Time | Cores | DRAM BW | DRAM Util | FLOPS | Fidelity / dtypes |
|---|---|---|---|---|---|---|---|---|
| MatmulDeviceOperation 32 x 5120 x 17408 | DRAM | INTERLEAVED | 249 µs | 109 | 363.1 GB/s | 70.9% | 22.9 TFLOPS | HiFi2 BF16 x BFP8 => BF16 |
| MatmulDeviceOperation 32 x 5120 x 17408 | DRAM | INTERLEAVED | 248 µs | 109 | 364.5 GB/s | 71.2% | 23 TFLOPS | HiFi2 BF16 x BFP8 => BF16 |
| BinaryNgDeviceOperation | DRAM | INTERLEAVED | 11 µs | 110 | | | | BF16, BF16 => BF16 |
| MatmulDeviceOperation 32 x 17408 x 5120 | DRAM | INTERLEAVED | 345 µs | 80 | 262.7 GB/s | 51.3% | 16.5 TFLOPS | HiFi2 BF16 x BFP8 => BF16 |
| Total | | | 853.67 µs | | | | | |
DRAM Sharded: BFP8 weights, BF16 inputs, HiFi2
| Op | Type | Layout | Device Time | Cores | DRAM BW | DRAM Util | FLOPS | Fidelity / dtypes |
|---|---|---|---|---|---|---|---|---|
| InterleavedToShardedDeviceOperation | DRAM | INTERLEAVED | 1 µs | 32 | | | | BF16 => BF16 |
| MatmulDeviceOperation 32 x 5120 x 17408 | L1 | WIDTH_SHARDED | 315 µs | 12 | 282.7 GB/s | 55.2% | 18.1 TFLOPS | HiFi2 BF16 x BFP8 => BF16 |
| MatmulDeviceOperation 32 x 5120 x 17408 | L1 | WIDTH_SHARDED | 316 µs | 12 | 282.3 GB/s | 55.1% | 18.1 TFLOPS | HiFi2 BF16 x BFP8 => BF16 |
| ShardedToInterleavedDeviceOperation | L1 | WIDTH_SHARDED | 2 µs | 32 | | | | BF16 => BF16 |
| ShardedToInterleavedDeviceOperation | L1 | WIDTH_SHARDED | 2 µs | 32 | | | | BF16 => BF16 |
| BinaryNgDeviceOperation | L1 | INTERLEAVED | 10 µs | 110 | | | | BF16, BF16 => BF16 |
| InterleavedToShardedDeviceOperation | L1 | INTERLEAVED | 2 µs | 32 | | | | BF16 => BF16 |
| MatmulDeviceOperation 32 x 17408 x 5120 | L1 | WIDTH_SHARDED | 308 µs | 12 | 289.8 GB/s | 56.6% | 18.5 TFLOPS | HiFi2 BF16 x BFP8 => BF16 |
| ShardedToInterleavedDeviceOperation | L1 | WIDTH_SHARDED | 1 µs | 32 | | | | BF16 => BF16 |
| Total | | | 957.56 µs | | | | | |
Decode LoFi
DRAM Interleaved: BFP8 weights, BF16 inputs, LoFi
| Op | Type | Layout | Device Time | Cores | DRAM BW | DRAM Util | FLOPS | Fidelity / dtypes |
|---|---|---|---|---|---|---|---|---|
| MatmulDeviceOperation 32 x 5120 x 17408 | DRAM | INTERLEAVED | 247 µs | 109 | 366.8 GB/s | 71.6% | 23.1 TFLOPS | LoFi BF16 x BFP8 => BF16 |
| MatmulDeviceOperation 32 x 5120 x 17408 | DRAM | INTERLEAVED | 248 µs | 109 | 365.7 GB/s | 71.4% | 23 TFLOPS | LoFi BF16 x BFP8 => BF16 |
| BinaryNgDeviceOperation | DRAM | INTERLEAVED | 11 µs | 110 | | | | BF16, BF16 => BF16 |
| MatmulDeviceOperation 32 x 17408 x 5120 | DRAM | INTERLEAVED | 344 µs | 80 | 263.2 GB/s | 51.4% | 16.6 TFLOPS | LoFi BF16 x BFP8 => BF16 |
| Total | | | 849.84 µs | | | | | |
DRAM Sharded: BFP8 weights, BF16 inputs, LoFi
| Op | Type | Layout | Device Time | Cores | DRAM BW | DRAM Util | FLOPS | Fidelity / dtypes |
|---|---|---|---|---|---|---|---|---|
| InterleavedToShardedDeviceOperation | DRAM | INTERLEAVED | 1 µs | 32 | | | | BF16 => BF16 |
| MatmulDeviceOperation 32 x 5120 x 17408 | L1 | WIDTH_SHARDED | 197 µs | 12 | 452.5 GB/s | 88.4% | 29 TFLOPS | LoFi BF16 x BFP8 => BF16 |
| MatmulDeviceOperation 32 x 5120 x 17408 | L1 | WIDTH_SHARDED | 197 µs | 12 | 452.2 GB/s | 88.3% | 28.9 TFLOPS | LoFi BF16 x BFP8 => BF16 |
| ShardedToInterleavedDeviceOperation | L1 | WIDTH_SHARDED | 2 µs | 32 | | | | BF16 => BF16 |
| ShardedToInterleavedDeviceOperation | L1 | WIDTH_SHARDED | 2 µs | 32 | | | | BF16 => BF16 |
| BinaryNgDeviceOperation | L1 | INTERLEAVED | 10 µs | 110 | | | | BF16, BF16 => BF16 |
| InterleavedToShardedDeviceOperation | L1 | INTERLEAVED | 2 µs | 32 | | | | BF16 => BF16 |
| MatmulDeviceOperation 32 x 17408 x 5120 | L1 | WIDTH_SHARDED | 194 µs | 12 | 460.5 GB/s | 89.9% | 29.5 TFLOPS | LoFi BF16 x BFP8 => BF16 |
| ShardedToInterleavedDeviceOperation | L1 | WIDTH_SHARDED | 1 µs | 32 | | | | BF16 => BF16 |
| Total | | | 606.71 µs | | | | | |