Skip to content

Commit

Permalink
Added tl.assume clauses for all gemm-kernel strides
Browse files Browse the repository at this point in the history
  • Loading branch information
ravil-mobile committed Oct 29, 2024
1 parent 0bce6bf commit 42bca31
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions python/perf-kernels/tools/tune_gemm/matmul_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak,
stride_cn, stride_bias, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, SPLIT_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BIAS: tl.constexpr,
EVEN_K: tl.constexpr, GRID_MN: tl.constexpr, NUM_XCDS: tl.constexpr):

tl.assume(stride_am > 0)
tl.assume(stride_ak > 0)
tl.assume(stride_bk > 0)
tl.assume(stride_bn > 0)
tl.assume(stride_cm > 0)
tl.assume(stride_cn > 0)
tl.assume(stride_bias > 0)

pid = tl.program_id(axis=0)
pid_z = tl.program_id(1)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
Expand All @@ -33,6 +42,9 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak,
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

tl.assume(pid_m > 0)
tl.assume(pid_n > 0)

if SPLIT_K == 1:
offs_k = tl.arange(0, BLOCK_SIZE_K)
else:
Expand Down

0 comments on commit 42bca31

Please sign in to comment.