Skip to content

Commit

Permalink
Replaced nested for-loops in tune_gemm.py with itertools.product
Browse files Browse the repository at this point in the history
  • Loading branch information
ravil-mobile committed Nov 26, 2024
1 parent 88d06ec commit 71f945b
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 41 deletions.
38 changes: 20 additions & 18 deletions python/perf-kernels/streamk/tune_streamk.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from datetime import datetime
import multiprocessing
import pandas as pd
import itertools

from utils.file_generator import (
gen_configStr,
Expand Down Expand Up @@ -63,22 +64,17 @@ def get_full_tuning_space():
kpack_range = [1, 2]
num_sms_range = [304]

for block_m in block_mn_range:
for block_n in block_mn_range:
for block_k in block_k_range:
for num_warps in num_warps_range:
for group_m in group_m_range:
for num_sms in num_sms_range:
for num_stages in num_stage_range:
for waves_per_eu in waves_per_eu_range:
for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
for kpack in kpack_range:
configs.append({
'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K':
block_k, 'GROUP_SIZE_M': group_m, 'NUM_SMS': num_sms, 'num_warps':
num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu,
'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack
})
space = itertools.product(block_mn_range, block_mn_range, block_k_range, num_warps_range, group_m_range,
num_sms_range, num_stage_range, waves_per_eu_range, matrix_instr_nonkdim_range,
kpack_range)

for instance in space:
block_m, block_n, block_k, num_warps, group_m, num_sms, num_stages, waves_per_eu, matrix_instr_nonkdim, kpack = instance
configs.append({
'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m,
'NUM_SMS': num_sms, 'num_warps': num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu,
'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack
})

return configs

Expand Down Expand Up @@ -139,8 +135,14 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
continue
# out of shared memory resource
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
LDS = BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
LDS = LDS if not num_stages else LDS * (num_stages - 1)
LDSA = BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
LDSB = BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
if num_stages <= 1:
# No pipeline, buffer A and buffer B can re-use each other
LDS = max(LDSA, LDSB)
else:
# Pipeline, we need (num_stages - 1) buffers for both A and B at the same time
LDS = (LDSA + LDSB) * (num_stages - 1)
if LDS > 65536:
continue
# Skip small block sizes and num_warps for large gemm
Expand Down
34 changes: 14 additions & 20 deletions python/perf-kernels/tools/tune_gemm/tune_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from datetime import datetime
import multiprocessing
import pandas as pd
import itertools

from utils.file_generator import (
gen_configStr,
Expand Down Expand Up @@ -64,26 +65,19 @@ def get_full_tuning_space():
waves_per_eu_range = [0]
matrix_instr_nonkdim_range = [16, 32]
kpack_range = [1, 2]
sched_variants = ["\"default\""]

for block_m in block_mn_range:
for block_n in block_mn_range:
for block_k in block_k_range:
for num_warps in num_warps_range:
for group_m in group_m_range:
for split_k in split_k_range:
for num_stages in num_stage_range:
for waves_per_eu in waves_per_eu_range:
for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
for sched_variant in sched_variants:
for kpack in kpack_range:
configs.append({
'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K':
block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k, 'num_warps':
num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu,
'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack,
'instruction_sched_variant': sched_variant
})
sched_variants = ["none"]

space = itertools.product(block_mn_range, block_mn_range, block_k_range, num_warps_range, group_m_range,
split_k_range, num_stage_range, waves_per_eu_range, matrix_instr_nonkdim_range,
sched_variants, kpack_range)

for instance in space:
block_m, block_n, block_k, num_warps, group_m, split_k, num_stages, waves_per_eu, matrix_instr_nonkdim, sched_variant, kpack = instance
configs.append({
'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m,
'SPLIT_K': split_k, 'num_warps': num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu,
'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack, 'instruction_sched_variant': sched_variant
})

return configs

Expand Down
7 changes: 4 additions & 3 deletions python/perf-kernels/tools/tune_gemm/utils/file_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def gen_configStr(config):
config)

## {M}_{N}_{K} is removed since the same kernel can be used for differen gemm sizes
configStr = f"BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}_sched{sched_variant[1:-1].upper()}"
sched_variant = sched_variant.upper().replace('-', '_')
configStr = f"BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}_sched{sched_variant}"

return configStr

Expand Down Expand Up @@ -113,7 +114,7 @@ def matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn):
EVEN_K = {EVEN_K},
GRID_MN = grid_mn,
NUM_XCDS = {num_xcds},
instruction_sched_variant = {sched_variant},
instruction_sched_variant = \"{sched_variant}\",
grid=(1,),
)
return None
Expand Down Expand Up @@ -148,7 +149,7 @@ def matmul_{configStr}(a, b, c, bias, M, N, K, am, ak, bk, bn, cm, cn, biasn):
EVEN_K = {EVEN_K},
GRID_MN = grid[0],
NUM_XCDS = {num_xcds},
instruction_sched_variant = {sched_variant},
instruction_sched_variant = \"{sched_variant}\",
)
return c
"""
Expand Down

0 comments on commit 71f945b

Please sign in to comment.