On DRAM Sharded Matmul in TTNN


I spent some time trying to get the DRAM Sharded matmul kernel to work well for a simple MLP, and it turned out to be less useful than I had hoped. In the end, performance was worse than with DRAM Interleaved matmuls unless I used LoFi math fidelity, because the compute only runs on 8 cores.

How to set up a DRAM Sharded matmul

The two shapes I used as examples were 5120x17408 (Qwen3.6-27B) and 5376x21504 (Gemma4-31B), plus their down projection counterparts. Sharding the weights in DRAM was easy: simply split them along N into one shard per DRAM channel (8 on a BH p150). For 5120x17408 (160x544 tiles), each shard gets 160x68 tiles. For 5376x21504 (168x672 tiles), each shard gets 168x84 tiles.
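The shard arithmetic is simple enough to check in a few lines of Python. This is a minimal illustration, not TTNN code: `dram_shard_shape` is a hypothetical helper, and it assumes 32x32 tiles, width sharding along N, and an N that splits evenly across the channels (it raises instead of padding).

```python
# Minimal sketch (hypothetical helper, not TTNN API): shape of each DRAM shard
# when the weights are width-sharded along N, one shard per DRAM channel.
TILE = 32  # tiles are 32x32 elements


def dram_shard_shape(k: int, n: int, num_dram_channels: int = 8) -> tuple:
    """Return (K tiles, N tiles) held by each DRAM shard."""
    if k % TILE or n % TILE:
        raise ValueError("K and N must be multiples of the tile size")
    k_tiles, n_tiles = k // TILE, n // TILE
    if n_tiles % num_dram_channels:
        # A real setup would pad N; this sketch just rejects it.
        raise ValueError("N tiles must split evenly across DRAM channels")
    return k_tiles, n_tiles // num_dram_channels


shapes = {
    "Qwen3.6-27B": (5120, 17408),
    "Qwen3.6-27B down": (17408, 5120),
    "Gemma4-31B": (5376, 21504),
    "Gemma4-31B down": (21504, 5376),
}
for name, (k, n) in shapes.items():
    print(name, dram_shard_shape(k, n))
```

The down projections split the same way, giving 544x20 tiles per shard for Qwen3.6-27B and 672x21 for Gemma4-31B.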

Sharding the input vector is more interesting, and there are several factors to consider:

  • Even though we are only sharding a single dimension, it must be sharded evenly onto a rectangular core grid, e.g. 8x2 or 8x4.
    • Presumably this is because of NoC broadcasting (NoC broadcasts can target a rectangular grid of cores). This means that oddly shaped inputs may need to be padded.
    • Using more cores means each core stores less data, which matters for staying within each core's L1.
  • The number of tiles on each core should be divisible by a small number greater than 1 (e.g. 2, 3, 5 …).
    • That divisor becomes in0_block_w, the number of tiles (along K) moved per step during compute. Transferring a single tile at a time hurts performance, although in practice I didn't notice any improvement for in0_block_w > 2.
    • However, if this value is too large, we later run into trouble allocating circular buffers, whose size scales proportionally with in0_block_w.

Together, these considerations mean we should pick a core grid that spreads the data over as many cores as possible, while leaving each core a tile count that allows an in0_block_w of at least 2 but small enough that the circular buffers still fit in L1. The existing code in ttnn does not take this into account when choosing the L1 grid, which leads to failures later.
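One way to encode these constraints is a brute-force search over grid shapes. The sketch below is hypothetical (`choose_input_grid` and `choose_input_grid_padded` are not TTNN functions): it takes K in tiles plus assumed maximum grid dimensions, prefers the most cores, and among equal core counts prefers the smallest in0_block_w in an assumed range of 2-5, since larger values only grow the circular buffers. L1 pressure is modeled only through that range, not through real buffer sizes.

```python
# Hypothetical sketch: choose (rows, cols, in0_block_w) for width-sharding a
# 1 x k_tiles input vector over a rectangular core grid. Grid limits and the
# block-width range are assumptions, not TTNN constants.
from itertools import product
from typing import Optional, Tuple


def choose_input_grid(
    k_tiles: int,
    max_rows: int,
    max_cols: int,
    min_block_w: int = 2,
    max_block_w: int = 5,
) -> Optional[Tuple[int, int, int]]:
    best, best_key = None, None
    for rows, cols in product(range(1, max_rows + 1), range(1, max_cols + 1)):
        cores = rows * cols
        if k_tiles % cores:
            continue  # shards would be uneven on this grid
        per_core = k_tiles // cores
        widths = [w for w in range(min_block_w, max_block_w + 1) if per_core % w == 0]
        if not widths:
            continue  # no usable in0_block_w for this shard size
        block_w = min(widths)  # smallest usable width keeps circular buffers small
        key = (cores, -block_w)  # more cores first, then smaller block width
        if best_key is None or key > best_key:
            best_key, best = key, (rows, cols, block_w)
    return best


def choose_input_grid_padded(k_tiles: int, max_rows: int, max_cols: int, **kwargs):
    """Pad K (in tiles) until some grid works; returns (padded_k_tiles, grid)."""
    padded = k_tiles
    grid = choose_input_grid(padded, max_rows, max_cols, **kwargs)
    while grid is None:
        padded += 1
        grid = choose_input_grid(padded, max_rows, max_cols, **kwargs)
    return padded, grid


for k_tiles in (160, 544, 168, 672):
    print(k_tiles, choose_input_grid(k_tiles, max_rows=4, max_cols=8))
```

With an assumed 4x8 limit, 160 tiles (K=5120) lands on 4x8 cores with in0_block_w = 5, while 544 tiles (K=17408) cannot use all 32 cores because 17 tiles per core is prime, so the search falls back to 16 cores with in0_block_w = 2. If padding is needed, the weights' K dimension has to be padded to match.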

Finally, the program config for the DRAM sharded matmul has 4 parameters:

  • in0_block_w: Size of each chunk to process. Values of 2-5 are probably fine; larger values may exceed memory, and 1 is really slow.
  • per_core_M: M is 1 tile, so it's unclear why this is even here. DRAM Sharded matmul doesn't support M > 1 tile.
  • per_core_N: This only affects placement of the output data. The actual computation is ALWAYS done one core per DRAM channel.
  • fused_activation: Optionally applies an activation function. I found that fusing the activation here always made the matmul slower, likely because the compute is already the bottleneck.
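To make the roles of the four parameters concrete, here is a hypothetical Python mirror of them. `DramShardedParams`, `make_params` and `compute_tiles_per_dram_core` are illustrative names, not the TTNN program-config class, and computing per_core_N as output tiles per output core is an assumption based on the description above.

```python
# Hypothetical mirror of the four program-config fields, plus the sanity checks
# described above. Not the TTNN class; names and formulas are assumptions.
import math
from dataclasses import dataclass
from typing import Optional

TILE = 32


@dataclass(frozen=True)
class DramShardedParams:
    in0_block_w: int
    per_core_M: int
    per_core_N: int
    fused_activation: Optional[str] = None  # placeholder, e.g. "silu"


def make_params(m, k, n, num_input_cores, num_output_cores,
                in0_block_w=2, fused_activation=None):
    if m > TILE:
        raise ValueError("DRAM sharded matmul only supports M <= 1 tile")
    k_tiles = math.ceil(k / TILE)
    if k_tiles % num_input_cores:
        raise ValueError("K tiles must split evenly across the input core grid")
    if (k_tiles // num_input_cores) % in0_block_w:
        raise ValueError("in0_block_w must divide the K tiles on each input core")
    n_tiles = math.ceil(n / TILE)
    return DramShardedParams(
        in0_block_w=in0_block_w,
        per_core_M=1,  # M is a single tile, so this is always 1
        per_core_N=math.ceil(n_tiles / num_output_cores),  # output placement only
        fused_activation=fused_activation,
    )


def compute_tiles_per_dram_core(n, num_dram_channels=8):
    """N tiles each compute core handles: fixed at one core per DRAM channel."""
    return math.ceil(math.ceil(n / TILE) / num_dram_channels)


params = make_params(32, 5120, 17408, num_input_cores=32, num_output_cores=32,
                     in0_block_w=5)
print(params)
print(compute_tiles_per_dram_core(17408))
```

For the 32 x 5120 x 17408 matmul with the output spread over 32 cores, per_core_N comes out to 17, but each of the 8 compute cores still produces 68 output tiles.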

The most interesting takeaway is that compute is only done on 8 cores. In practice, for my example matrices with BFP8 weights, this means that unless I use LoFi math fidelity, the operation ends up compute-bound and actually runs slower than DRAM Interleaved matmuls, which are DRAM-bound (see Results below). However, I found that LoFi introduced too much precision loss to be usable.
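The FLOPS column in the Results tables can be reproduced from the problem size and device time, which also shows how concentrated the work becomes. This is a back-of-the-envelope sketch: it counts 2·M·K·N FLOPs with M padded to 32, and it assumes the 8 compute cores stated above for the sharded runs (the Cores column reports 12 for the sharded matmul) and the reported 109 cores for the interleaved run.

```python
# Back-of-the-envelope sketch: reproduce the FLOPS column and compare
# per-core throughput. Core counts are assumptions (see lead-in).


def matmul_flops(m: int, k: int, n: int) -> int:
    """One multiply and one add per (m, k, n) triple."""
    return 2 * m * k * n


def tflops(m: int, k: int, n: int, time_us: float) -> float:
    return matmul_flops(m, k, n) / (time_us * 1e-6) / 1e12


# (label, device time in microseconds, cores doing the math) for 32 x 5120 x 17408
runs = [
    ("DRAM Interleaved, HiFi2", 249, 109),
    ("DRAM Sharded, HiFi2", 315, 8),
    ("DRAM Sharded, LoFi", 197, 8),
]
for label, time_us, cores in runs:
    total = tflops(32, 5120, 17408, time_us)
    print(f"{label}: {total:.1f} TFLOPS total, {total / cores:.2f} TFLOPS per core")
```

At HiFi2, each of the 8 sharded compute cores sustains roughly 2.3 TFLOPS, more than ten times the per-core rate of the interleaved matmul spread across 109 cores.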

Results

Decode HiFi2

DRAM Interleaved: BFP8 weights, BF16 inputs, HiFi2

Op | Memory | Layout | Device Time | Cores | DRAM BW | DRAM Util | FLOPS | Fidelity | Data Types
--- | --- | --- | --- | --- | --- | --- | --- | --- | ---
MatmulDeviceOperation 32 x 5120 x 17408 | DRAM | INTERLEAVED | 249 µs | 109 | 363.1 GB/s | 70.9% | 22.9 TFLOPS | HiFi2 | BF16 x BFP8 => BF16
MatmulDeviceOperation 32 x 5120 x 17408 | DRAM | INTERLEAVED | 248 µs | 109 | 364.5 GB/s | 71.2% | 23 TFLOPS | HiFi2 | BF16 x BFP8 => BF16
BinaryNgDeviceOperation | DRAM | INTERLEAVED | 11 µs | 110 | - | - | - | - | BF16, BF16 => BF16
MatmulDeviceOperation 32 x 17408 x 5120 | DRAM | INTERLEAVED | 345 µs | 80 | 262.7 GB/s | 51.3% | 16.5 TFLOPS | HiFi2 | BF16 x BFP8 => BF16
Total device time: 853.67 µs

DRAM Sharded: BFP8 weights, BF16 inputs, HiFi2

Op | Memory | Layout | Device Time | Cores | DRAM BW | DRAM Util | FLOPS | Fidelity | Data Types
--- | --- | --- | --- | --- | --- | --- | --- | --- | ---
InterleavedToShardedDeviceOperation | DRAM | INTERLEAVED | 1 µs | 32 | - | - | - | - | BF16 => BF16
MatmulDeviceOperation 32 x 5120 x 17408 | L1 | WIDTH_SHARDED | 315 µs | 12 | 282.7 GB/s | 55.2% | 18.1 TFLOPS | HiFi2 | BF16 x BFP8 => BF16
MatmulDeviceOperation 32 x 5120 x 17408 | L1 | WIDTH_SHARDED | 316 µs | 12 | 282.3 GB/s | 55.1% | 18.1 TFLOPS | HiFi2 | BF16 x BFP8 => BF16
ShardedToInterleavedDeviceOperation | L1 | WIDTH_SHARDED | 2 µs | 32 | - | - | - | - | BF16 => BF16
ShardedToInterleavedDeviceOperation | L1 | WIDTH_SHARDED | 2 µs | 32 | - | - | - | - | BF16 => BF16
BinaryNgDeviceOperation | L1 | INTERLEAVED | 10 µs | 110 | - | - | - | - | BF16, BF16 => BF16
InterleavedToShardedDeviceOperation | L1 | INTERLEAVED | 2 µs | 32 | - | - | - | - | BF16 => BF16
MatmulDeviceOperation 32 x 17408 x 5120 | L1 | WIDTH_SHARDED | 308 µs | 12 | 289.8 GB/s | 56.6% | 18.5 TFLOPS | HiFi2 | BF16 x BFP8 => BF16
ShardedToInterleavedDeviceOperation | L1 | WIDTH_SHARDED | 1 µs | 32 | - | - | - | - | BF16 => BF16
Total device time: 957.56 µs

Decode LoFi

DRAM Interleaved: BFP8 weights, BF16 inputs, LoFi

Op | Memory | Layout | Device Time | Cores | DRAM BW | DRAM Util | FLOPS | Fidelity | Data Types
--- | --- | --- | --- | --- | --- | --- | --- | --- | ---
MatmulDeviceOperation 32 x 5120 x 17408 | DRAM | INTERLEAVED | 247 µs | 109 | 366.8 GB/s | 71.6% | 23.1 TFLOPS | LoFi | BF16 x BFP8 => BF16
MatmulDeviceOperation 32 x 5120 x 17408 | DRAM | INTERLEAVED | 248 µs | 109 | 365.7 GB/s | 71.4% | 23 TFLOPS | LoFi | BF16 x BFP8 => BF16
BinaryNgDeviceOperation | DRAM | INTERLEAVED | 11 µs | 110 | - | - | - | - | BF16, BF16 => BF16
MatmulDeviceOperation 32 x 17408 x 5120 | DRAM | INTERLEAVED | 344 µs | 80 | 263.2 GB/s | 51.4% | 16.6 TFLOPS | LoFi | BF16 x BFP8 => BF16
Total device time: 849.84 µs

DRAM Sharded: BFP8 weights, BF16 inputs, LoFi

Op | Memory | Layout | Device Time | Cores | DRAM BW | DRAM Util | FLOPS | Fidelity | Data Types
--- | --- | --- | --- | --- | --- | --- | --- | --- | ---
InterleavedToShardedDeviceOperation | DRAM | INTERLEAVED | 1 µs | 32 | - | - | - | - | BF16 => BF16
MatmulDeviceOperation 32 x 5120 x 17408 | L1 | WIDTH_SHARDED | 197 µs | 12 | 452.5 GB/s | 88.4% | 29 TFLOPS | LoFi | BF16 x BFP8 => BF16
MatmulDeviceOperation 32 x 5120 x 17408 | L1 | WIDTH_SHARDED | 197 µs | 12 | 452.2 GB/s | 88.3% | 28.9 TFLOPS | LoFi | BF16 x BFP8 => BF16
ShardedToInterleavedDeviceOperation | L1 | WIDTH_SHARDED | 2 µs | 32 | - | - | - | - | BF16 => BF16
ShardedToInterleavedDeviceOperation | L1 | WIDTH_SHARDED | 2 µs | 32 | - | - | - | - | BF16 => BF16
BinaryNgDeviceOperation | L1 | INTERLEAVED | 10 µs | 110 | - | - | - | - | BF16, BF16 => BF16
InterleavedToShardedDeviceOperation | L1 | INTERLEAVED | 2 µs | 32 | - | - | - | - | BF16 => BF16
MatmulDeviceOperation 32 x 17408 x 5120 | L1 | WIDTH_SHARDED | 194 µs | 12 | 460.5 GB/s | 89.9% | 29.5 TFLOPS | LoFi | BF16 x BFP8 => BF16
ShardedToInterleavedDeviceOperation | L1 | WIDTH_SHARDED | 1 µs | 32 | - | - | - | - | BF16 => BF16
Total device time: 606.71 µs
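Comparing the end-to-end totals from the four runs gives the overall picture. A small sketch that uses only the total device times reported in the Results tables:

```python
# Total device time per MLP pass, copied from the Results tables (microseconds).
totals_us = {
    ("HiFi2", "interleaved"): 853.67,
    ("HiFi2", "sharded"): 957.56,
    ("LoFi", "interleaved"): 849.84,
    ("LoFi", "sharded"): 606.71,
}


def sharded_speedup(fidelity: str) -> float:
    """Values > 1 mean DRAM Sharded is faster than DRAM Interleaved."""
    return totals_us[(fidelity, "interleaved")] / totals_us[(fidelity, "sharded")]


for fidelity in ("HiFi2", "LoFi"):
    print(f"{fidelity}: {sharded_speedup(fidelity):.2f}x")
```

At HiFi2 the DRAM Sharded MLP runs at about 0.89x the speed of the interleaved one, while at LoFi it is about 1.40x faster.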