@@ -268,7 +268,9 @@ def my_matmul(
268268 # Fix fifo depth for C objfifo to 1 since 1 buffer will be used for accumulation
269269 # and another for transfer to L2
270270 fifo_depth_out = 1
271+ # Set the type for accumulation
271272 C_l1_ty_internal = np .ndarray [(m , n ), np .dtype [dtype_out_internal ]]
273+ # A kernel to convert from the internal f32 accumulation to bf16 for transfer to L2 is needed
272274 convert_copy_kernel = Kernel (
273275 f"convert_copy_f32_to_bf16" ,
274276 f"gemm_{ m } x{ k } x{ n } _archive.a" ,
@@ -287,6 +289,8 @@ def my_matmul(
287289 [A_l1_ty , B_l1_ty , C_l1_ty_internal ],
288290 )
289291 else :
292+ # No need to use separate buffers for accumulation and transfer to L2, so
293+ # we only need the zero and matmul kernels
290294 fifo_depth_out = fifo_depth
291295 zero_kernel = Kernel (
292296 f"zero{ scalar_suffix } _{ dtype_out_str } " ,
@@ -470,6 +474,8 @@ def core_fn(in_a, in_b, out_c, zero, matmul, convert_copy, my_rtp, barrier):
470474 # tb = transfer block; block of transfers before sync call
471475 tb_max_n_rows = 4
472476 tb_n_rows = tb_max_n_rows // 2
477+
478+ # Calculate RTP values for the reduction loop and total C tiles
473479 K_div_k = K // k
474480 n_c_col_tiles_per_core = N // mem_tile_n
475481 n_c_row_tiles_per_core = M // mem_tile_m_C
0 commit comments