Add bench_gemm to benchmark matmul kernel with Triton and hipblaslt #616

Draft · wants to merge 6 commits into base: triton-mlir
212 changes: 212 additions & 0 deletions scripts/amd/gemm/bench_gemm.py
@@ -0,0 +1,212 @@
import argparse
import sys
import yaml
import os
import pandas as pd

from utils.utils import *


def parse_args():
    parser = argparse.ArgumentParser(
        prog="tune a specific gemm size",
        allow_abbrev=False,
    )

    parser.add_argument("-dtype_a",
                        type=str,
                        default='fp16',
                        help="matrix a element data type")
    parser.add_argument("-dtype_b",
                        type=str,
                        default='fp16',
                        help="matrix b element data type")
    parser.add_argument("-dtype_c",
                        type=str,
                        default='fp16',
                        help="output element data type")
    parser.add_argument("--gemm_size_file",
                        type=str,
                        default="",
                        help='yaml file to indicate matrix size')
    parser.add_argument("--triton_result",
                        type=str,
                        default="",
                        help='yaml file to load (for benchmarking) or '
                        'store (for tuning) triton results')
    parser.add_argument("--tune_hipblaslt",
                        action='store_true',
                        default=False,
                        help='Run tuning with hipblaslt')
    parser.add_argument("--tune_triton",
                        action='store_true',
                        default=False,
                        help='Run tuning with triton')
    parser.add_argument("--hipblaslt_result",
                        type=str,
                        default="",
                        help='csv file to load (if not tuning hipblaslt) or '
                        'store (if tuning hipblaslt) hipblaslt tuning results')

    args = parser.parse_args()
    return args


def run_hipblaslt_bench(hipblaslt_bench, M, N, K, transA, transB, dtype):
    ITER = 10
    WARMUP = 100
    dtype = 'f16_r' if dtype == "fp16" else 'f8_r'
    hipBLASLt_bench_args = f"-f matmul -r {dtype} -m {M} -n {N} -k {K}"
    hipBLASLt_bench_args += f" --transA {transA} --transB {transB}"
    hipBLASLt_bench_args += f" --compute_type f32_r --algo_method all"
    hipBLASLt_bench_args += f" -i {ITER} -j {WARMUP} --print_kernel_info"
    SED_WINNER = "sed -n '/Winner:/, $p'"

    print(f"Tuning hipblaslt with {hipBLASLt_bench_args}", flush=True)

    winner = run_bash_command(
        f"HIP_FORCE_DEV_KERNARG=1 {hipblaslt_bench} {hipBLASLt_bench_args} | {SED_WINNER}"
    )

    for line in winner:
        line = line.decode('utf-8')

        if "Solution index" in line:
            winner_index = int(line.split(':', 1)[1].strip())
        if "kernel name" in line:
            kernel_name = line.split(':', 1)[1].strip()
        if f"{M},{N},{K}" in line:
            tflops = int(line.split(',')[-2].strip()) / 1000
            us = float(line.split(',')[-1].strip())

    return winner_index, kernel_name, tflops, us


def run_triton_tuning(input, output, dtype_a):
    print(f"Tuning gemm sizes from {input} with Triton", flush=True)
    run_bash_command(
        f"./tune_gemm.py --gemm_size_file {input} -dtype_a {dtype_a} -dtype_b {dtype_a} --ngpus 8 --jobs 32 --o {output} --rotating_tensor 512",
        False)


def run_triton_bench(input, dtype_a):
    if not os.path.exists(input):
        print(f"{input} does not exist, please run tuning first")
        sys.exit(1)
    print(f"Benchmarking gemms from {input} with Triton", flush=True)
    triton_output = run_bash_command(
        f"./tune_gemm.py --gemm_size_file {input} -dtype_a {dtype_a} -dtype_b {dtype_a} --benchmark --rotating_tensor 512"
    )

    data = []
    for line in triton_output:
        line = line.decode('utf-8')

        if "Benchmarking" in line or "trans" in line:
            continue

        items = line.split()
        trans = items[0].strip()
        M = items[1].strip()
        N = items[2].strip()
        K = items[3].strip()
        tflops = items[4].strip()
        us = items[5].strip()

        data.append([trans, int(M), int(N), int(K), float(tflops), float(us)])

    return pd.DataFrame(data, columns=['trans', 'M', 'N', 'K', 'tflops', 'us'])


def main():
    args = parse_args()
    gemm_size_file = args.gemm_size_file
    hipblaslt_csv = args.hipblaslt_result
    triton_yaml = args.triton_result

    if not gemm_size_file:
        print("Need to provide gemm size file: --gemm_size_file filename")
        sys.exit(1)

    if not triton_yaml:
        print(
            "Need to provide triton result filename: --triton_result filename.yaml"
        )
        sys.exit(1)

    if not hipblaslt_csv:
        print(
            "Need to provide hipblaslt result filename: --hipblaslt_result filename.csv"
        )
        sys.exit(1)

    # Get element type
    dtype_a = args.dtype_a
    dtype_b = args.dtype_b
    dtype_c = args.dtype_c

    if args.tune_triton:
        run_triton_tuning(gemm_size_file, triton_yaml, dtype_a)

    df_triton = run_triton_bench(triton_yaml, dtype_a)

    if args.tune_hipblaslt:
        with open(gemm_size_file) as inFile:
            matrix_sizes = yaml.safe_load(inFile)

        mnks = []
        for item in matrix_sizes:
            M = item['M']
            N = item['N']
            K = item['K']
            transA = item['rowMajorA']
            transB = item['rowMajorB']
            mnks.append((M, N, K, transA, transB))

        hipblaslt_ROOT_DIR = os.environ.get('HIPBLASLT_ROOT')
        if not hipblaslt_ROOT_DIR:
            print("Need to provide hipblaslt root dir: HIPBLASLT_ROOT")
            sys.exit(1)
        hipblaslt_bench = os.path.join(hipblaslt_ROOT_DIR,
                                       "build/clients/staging",
                                       "hipblaslt-bench")

        hipblaslt_data = []

        for (M, N, K, transA, transB) in mnks:
            if not (transA == 'T' and transB == 'N'):
                ## It seems hipblaslt does not support TT case?
                continue
            winner_index, kernel_name, tflops, us = run_hipblaslt_bench(
                hipblaslt_bench, M, N, K, transA, transB, dtype_a)
            hipblaslt_data.append([
                f"{transA}{transB}", M, N, K, tflops, us, winner_index,
                kernel_name
            ])

        df_hipblaslt = pd.DataFrame(hipblaslt_data,
                                    columns=[
                                        'trans', 'M', 'N', 'K', 'tflops', 'us',
                                        'winner_idx', 'kernel_name'
                                    ])
        df_hipblaslt.to_csv(hipblaslt_csv, index=False)
    else:
        if not os.path.exists(hipblaslt_csv):
            print(f"{hipblaslt_csv} does not exist, please run tuning first")
            sys.exit(1)
        df_hipblaslt = pd.read_csv(hipblaslt_csv)

    df_merged = pd.merge(df_triton,
                         df_hipblaslt,
                         on=['trans', 'M', 'N', 'K'],
                         how='left',
                         suffixes=('_triton', '_hipblaslt'))

    print(df_merged[[
        'trans', 'M', 'N', 'K', 'tflops_triton', 'tflops_hipblaslt'
    ]])


if __name__ == '__main__':
    sys.exit(main())
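
A possible first run of the new script, sketched only from the flags defined in its argument parser (file names and paths are placeholders; the script invokes ./tune_gemm.py, so it is assumed to be run from scripts/amd/gemm, and HIPBLASLT_ROOT is expected to point at a hipBLASLt tree containing build/clients/staging/hipblaslt-bench when --tune_hipblaslt is used):

    HIPBLASLT_ROOT=/path/to/hipBLASLt python bench_gemm.py \
        --gemm_size_file configs/beautiful.yaml \
        --triton_result triton_results.yaml \
        --hipblaslt_result hipblaslt_results.csv \
        --tune_triton --tune_hipblaslt

On later runs the two --tune_* flags can be dropped: Triton is then re-benchmarked from the stored yaml, and the hipblaslt numbers are read back from the csv.
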
15 changes: 15 additions & 0 deletions scripts/amd/gemm/configs/beautiful.yaml
@@ -0,0 +1,15 @@
## TN
- {'M': 4864, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N'}
- {'M': 4864, 'N': 4096, 'K': 4097, 'rowMajorA': 'T', 'rowMajorB': 'N'}
- {'M': 4864, 'N': 4096, 'K': 4098, 'rowMajorA': 'T', 'rowMajorB': 'N'}
- {'M': 4864, 'N': 4096, 'K': 4100, 'rowMajorA': 'T', 'rowMajorB': 'N'}
- {'M': 4864, 'N': 4096, 'K': 4104, 'rowMajorA': 'T', 'rowMajorB': 'N'}
- {'M': 4864, 'N': 4096, 'K': 4112, 'rowMajorA': 'T', 'rowMajorB': 'N'}
- {'M': 4864, 'N': 4096, 'K': 4128, 'rowMajorA': 'T', 'rowMajorB': 'N'}
- {'M': 4864, 'N': 4096, 'K': 4160, 'rowMajorA': 'T', 'rowMajorB': 'N'}
- {'M': 4864, 'N': 4096, 'K': 4224, 'rowMajorA': 'T', 'rowMajorB': 'N'}
- {'M': 4864, 'N': 4096, 'K': 4288, 'rowMajorA': 'T', 'rowMajorB': 'N'}

## TT: same configs as above, but with rowMajorB set to 'T'
#- {'M': 4864, 'N': 4096, 'K': 4224, 'rowMajorA': 'T', 'rowMajorB': 'T'}
#- {'M': 4864, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'T'}
29 changes: 14 additions & 15 deletions scripts/amd/gemm/tune_gemm.py
@@ -86,6 +86,13 @@ def get_default_config():
     return full_configs[0]


+def infer_k_dim(nonkdim, elemBytes_a):
+    if nonkdim == 16:
+        return 32 if elemBytes_a == 1 else 16
+    elif nonkdim == 32:
+        return 16 if elemBytes_a == 1 else 8
+
+
 def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
     pruned_configs = []

@@ -107,6 +114,10 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
         num_stages = config.get("num_stages")
         matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
         kpack = config.get("kpack")
+        ## Need to skip the config if K dim is not large enough to include
+        ## mfma_k * kpack
+        if kpack * infer_k_dim(matrix_instr_nonkdim, elemBytes_a) > K:
+            continue
         if matrix_instr_nonkdim > mfma:
             continue
         if mfma == 4 and BLOCK_SIZE_K < 64:
@@ -151,12 +162,10 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
         if large_gemm:
             if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
                 continue
-            if BLOCK_SIZE_K < 64:
-                continue
             if num_warps < 4:
                 continue
             # check if tiling is integer multiple of GEMM size because we have no boundary check
-            if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0 or K % BLOCK_SIZE_K != 0:
+            if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0:
                 continue

         pruned_configs.append(config)
@@ -668,16 +677,6 @@ def type_name_to_bytes(ty_name):
         sys.exit(1)


-def format_output(unformatted):
-    if unformatted < 0.0001:
-        formatted = "{:.3e}".format(unformatted)
-    elif unformatted > 1000:
-        formatted = "{:.1f}".format(unformatted)
-    else:
-        formatted = "{:.2f}".format(unformatted)
-    return formatted
-
-
 def get_rocm_version():
     torch_hip_version = torch.version.hip
     vers = torch_hip_version.split('.')
@@ -873,7 +872,7 @@ def main():

         # write best config to tuning_results.yaml
         if run_bench:
-            print(f"{formatted_tflops} {minTime}")
+            print(f"{formatted_tflops} {minTime} {bestConfig_compact_str}")
             f_results.write(f"{formatted_tflops},{minTime}\n")

         sizeDict = {
@@ -887,7 +886,7 @@ def main():
         if not run_bench:
             f_results.write("- " + str(sizeDict) + " ")
             f_results.write(
-                f'# TFLOPS: {formatted_tflops} time(us): {minTime}\n')
+                f'# {bestConfig_compact_str}\n')

         # remove generated files if asked to
         if not keepTmp:
10 changes: 10 additions & 0 deletions scripts/amd/gemm/utils/utils.py
@@ -113,3 +113,13 @@ def patch_triton_compiler():
     run_bash_command(f"sed -i 's/import torch/return True/g' {hip_driver_filename}")
     run_bash_command(f"sed -i 's/device = self.get_current_device()/return GPUTarget(\"hip\", \"{target.arch}\", 64)/g' {hip_driver_filename}")
     run_bash_command(f"sed -i 's/import torch/return False/g' {cuda_driver_filename}")
+
+
+def format_output(unformatted):
+    if unformatted < 0.0001:
+        formatted = "{:.3e}".format(unformatted)
+    elif unformatted > 1000:
+        formatted = "{:.1f}".format(unformatted)
+    else:
+        formatted = "{:.2f}".format(unformatted)
+    return formatted