From 32e83eacbc734e6c497d238cb0fb1e03dfae5972 Mon Sep 17 00:00:00 2001
From: Lixun Zhang
Date: Tue, 23 Jul 2024 16:05:13 -0500
Subject: [PATCH 1/6] Initial commit of bench_gemm.py

---
 scripts/amd/gemm/bench_gemm.py  | 212 ++++++++++++++++++++++++++++++++
 scripts/amd/gemm/tune_gemm.py   |  14 +--
 scripts/amd/gemm/utils/utils.py |  10 ++
 3 files changed, 223 insertions(+), 13 deletions(-)
 create mode 100644 scripts/amd/gemm/bench_gemm.py

diff --git a/scripts/amd/gemm/bench_gemm.py b/scripts/amd/gemm/bench_gemm.py
new file mode 100644
index 000000000000..47dc1291d119
--- /dev/null
+++ b/scripts/amd/gemm/bench_gemm.py
@@ -0,0 +1,212 @@
+import argparse
+import sys
+import yaml
+import os
+import pandas as pd
+
+from utils.utils import *
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="benchmark gemm sizes with Triton and hipBLASLt",
+        allow_abbrev=False,
+    )
+
+    parser.add_argument("-dtype_a",
+                        type=str,
+                        default='fp16',
+                        help="matrix a element data type")
+    parser.add_argument("-dtype_b",
+                        type=str,
+                        default='fp16',
+                        help="matrix b element data type")
+    parser.add_argument("-dtype_c",
+                        type=str,
+                        default='fp16',
+                        help="output element data type")
+    parser.add_argument("--gemm_size_file",
+                        type=str,
+                        default="",
+                        help='yaml file to indicate matrix size')
+    parser.add_argument("--triton_result",
+                        type=str,
+                        default="",
+                        help='yaml file to load (for benchmarking) or '
+                        'store (for tuning) triton results')
+    parser.add_argument("--tune_hipblaslt",
+                        action='store_true',
+                        default=False,
+                        help='Run tuning with hipblaslt')
+    parser.add_argument("--tune_triton",
+                        action='store_true',
+                        default=False,
+                        help='Run tuning with triton')
+    parser.add_argument("--hipblaslt_result",
+                        type=str,
+                        default="",
+                        help='csv file to load (if not tuning hipblaslt) or '
+                        'store (if tuning hipblaslt) hipblaslt tuning results')
+
+    args = parser.parse_args()
+    return args
+
+
+def run_hipblaslt_bench(hipblaslt_bench, M, N, K, transA, transB, dtype):
+    ITER = 10
+    WARMUP = 100
+    dtype = 'f16_r' if dtype == "fp16" else 'f8_r'
+    hipBLASLt_bench_args = f"-f matmul -r {dtype} -m {M} -n {N} -k {K}"
+    hipBLASLt_bench_args += f" --transA {transA} --transB {transB}"
+    hipBLASLt_bench_args += f" --compute_type f32_r --algo_method all"
+    hipBLASLt_bench_args += f" -i {ITER} -j {WARMUP} --print_kernel_info"
+    SED_WINNER = "sed -n '/Winner:/, $p'"
+
+    print(f"Tuning hipblaslt with {hipBLASLt_bench_args}")
+
+    winner = run_bash_command(
+        f"HIP_FORCE_DEV_KERNARG=1 {hipblaslt_bench} {hipBLASLt_bench_args} | {SED_WINNER}"
+    )
+
+    for line in winner:
+        line = line.decode('utf-8')
+
+        if "Solution index" in line:
+            winner_index = int(line.split(':', 1)[1].strip())
+        if "kernel name" in line:
+            kernel_name = line.split(':', 1)[1].strip()
+        if f"{M},{N},{K}" in line:
+            tflops = int(line.split(',')[-2].strip()) / 1000
+            us = float(line.split(',')[-1].strip())
+
+    return winner_index, kernel_name, tflops, us
+
+
+def run_triton_tuning(input, output, dtype_a):
+    print(f"Tuning gemm sizes from {input} with Triton")
+    run_bash_command(
+        f"./tune_gemm.py --gemm_size_file {input} -dtype_a {dtype_a} -dtype_b {dtype_a} --ngpus 8 --jobs 32 --o {output}",
+        False)
+
+
+def run_triton_bench(input, dtype_a):
+    if not os.path.exists(input):
+        print(f"{input} does not exist, please run tuning first")
+        sys.exit(1)
+    print(f"Benchmarking gemms from {input} with Triton")
+    triton_output = run_bash_command(
+        f"./tune_gemm.py --gemm_size_file {input} -dtype_a {dtype_a} -dtype_b {dtype_a} --benchmark"
+    )
+
+    data = []
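+    ## Each remaining output line from `tune_gemm.py --benchmark` is assumed
+    ## to hold six space-separated fields -- trans M N K TFLOPS us -- which
+    ## is the order the indexing below relies on.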
+    for line in triton_output:
+        line = line.decode('utf-8')
+
+        if "Benchmarking" in line or "trans" in line:
+            continue
+
+        items = line.split()
+        trans = items[0].strip()
+        M = items[1].strip()
+        N = items[2].strip()
+        K = items[3].strip()
+        tflops = items[4].strip()
+        us = items[5].strip()
+
+        data.append([trans, int(M), int(N), int(K), float(tflops), float(us)])
+
+    return pd.DataFrame(data, columns=['trans', 'M', 'N', 'K', 'tflops', 'us'])
+
+
+def main():
+    args = parse_args()
+    gemm_size_file = args.gemm_size_file
+    hipblaslt_csv = args.hipblaslt_result
+    triton_yaml = args.triton_result
+
+    if not gemm_size_file:
+        print("Need to provide gemm size file: --gemm_size_file filename")
+        sys.exit(1)
+
+    if not triton_yaml:
+        print(
+            "Need to provide triton result filename: --triton_result filename.yaml"
+        )
+        sys.exit(1)
+
+    if not hipblaslt_csv:
+        print(
+            "Need to provide hipblaslt result filename: --hipblaslt_result filename.csv"
+        )
+        sys.exit(1)
+
+    # Get element type
+    dtype_a = args.dtype_a
+    dtype_b = args.dtype_b
+    dtype_c = args.dtype_c
+
+    if args.tune_triton:
+        run_triton_tuning(gemm_size_file, triton_yaml, dtype_a)
+
+    df_triton = run_triton_bench(triton_yaml, dtype_a)
+
+    if args.tune_hipblaslt:
+        with open(gemm_size_file) as inFile:
+            matrix_sizes = yaml.safe_load(inFile)
+
+        mnks = []
+        for item in matrix_sizes:
+            M = item['M']
+            N = item['N']
+            K = item['K']
+            transA = item['rowMajorA']
+            transB = item['rowMajorB']
+            mnks.append((M, N, K, transA, transB))
+
+        hipblaslt_ROOT_DIR = os.environ.get['HIPBLASLT_ROOT']
+        if not hipblaslt_ROOT_DIR:
+            print("Need to provide hipblaslt root dir: HIPBLASLT_ROOT")
+            sys.exit(1)
+        hipblaslt_bench = os.path.join(hipblaslt_ROOT_DIR,
+                                       "build/clients/staging",
+                                       "hipblaslt-bench")
+
+        hipblaslt_data = []
+
+        for (M, N, K, transA, transB) in mnks:
+            if not (transA == 'T' and transB == 'N'):
+                ## It seems hipblaslt does not support the TT case?
+                continue
+            winner_index, kernel_name, tflops, us = run_hipblaslt_bench(
+                hipblaslt_bench, M, N, K, transA, transB, dtype_a)
+            hipblaslt_data.append([
+                f"{transA}{transB}", M, N, K, tflops, us, winner_index,
+                kernel_name
+            ])
+
+        df_hipblaslt = pd.DataFrame(hipblaslt_data,
+                                    columns=[
+                                        'trans', 'M', 'N', 'K', 'tflops', 'us',
+                                        'winner_idx', 'kernel_name'
+                                    ])
+        df_hipblaslt.to_csv(hipblaslt_csv, index=False)
+    else:
+        if not os.path.exists(hipblaslt_csv):
+            print(f"{hipblaslt_csv} does not exist, please run tuning first")
+            sys.exit(1)
+        df_hipblaslt = pd.read_csv(hipblaslt_csv)
+
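+    ## how='left' keeps every Triton row; the hipblaslt columns come out as
+    ## NaN for sizes that were skipped above (e.g. non-TN layouts).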
+    df_merged = pd.merge(df_triton,
+                         df_hipblaslt,
+                         on=['trans', 'M', 'N', 'K'],
+                         how='left',
+                         suffixes=('_triton', '_hipblaslt'))
+
+
+    print(df_merged[[
+        'trans', 'M', 'N', 'K', 'tflops_triton', 'tflops_hipblaslt'
+    ]])
+
+
+if __name__ == '__main__':
+    sys.exit(main())
diff --git a/scripts/amd/gemm/tune_gemm.py b/scripts/amd/gemm/tune_gemm.py
index 3fdd7da082b5..4760e793e29b 100755
--- a/scripts/amd/gemm/tune_gemm.py
+++ b/scripts/amd/gemm/tune_gemm.py
@@ -668,16 +668,6 @@ def type_name_to_bytes(ty_name):
         sys.exit(1)
 
 
-def format_output(unformatted):
-    if unformatted < 0.0001:
-        formatted = "{:.3e}".format(unformatted)
-    elif unformatted > 1000:
-        formatted = "{:.1f}".format(unformatted)
-    else:
-        formatted = "{:.2f}".format(unformatted)
-    return formatted
-
-
 def get_rocm_version():
     torch_hip_version = torch.version.hip
     vers = torch_hip_version.split('.')
@@ -885,9 +875,7 @@ def main():
         }
         sizeDict.update(bestConfig)
         if not run_bench:
-            f_results.write("- " + str(sizeDict) + " ")
-            f_results.write(
-                f'# TFLOPS: {formatted_tflops} time(us): {minTime}\n')
+            f_results.write("- " + str(sizeDict) + "\n")
 
         # remove generated files if asked to
         if not keepTmp:
diff --git a/scripts/amd/gemm/utils/utils.py b/scripts/amd/gemm/utils/utils.py
index 9b6b50ea626b..028714bcad38 100644
--- a/scripts/amd/gemm/utils/utils.py
+++ b/scripts/amd/gemm/utils/utils.py
@@ -113,3 +113,13 @@ def patch_triton_compiler():
     run_bash_command(f"sed -i 's/import torch/return True/g' {hip_driver_filename}")
     run_bash_command(f"sed -i 's/device = self.get_current_device()/return GPUTarget(\"hip\", \"{target.arch}\", 64)/g' {hip_driver_filename}")
     run_bash_command(f"sed -i 's/import torch/return False/g' {cuda_driver_filename}")
+
+
+def format_output(unformatted):
+    if unformatted < 0.0001:
+        formatted = "{:.3e}".format(unformatted)
+    elif unformatted > 1000:
+        formatted = "{:.1f}".format(unformatted)
+    else:
+        formatted = "{:.2f}".format(unformatted)
+    return formatted

From af3e467d2d4fae40527db5d0107b684df7c43808 Mon Sep 17 00:00:00 2001
From: Lixun Zhang
Date: Wed, 24 Jul 2024 10:42:59 -0500
Subject: [PATCH 2/6] Add beautiful gemm sizes

---
 scripts/amd/gemm/configs/beautiful.yaml | 9 +++++++++
 1 file changed, 9 insertions(+)
 create mode 100644 scripts/amd/gemm/configs/beautiful.yaml

diff --git a/scripts/amd/gemm/configs/beautiful.yaml b/scripts/amd/gemm/configs/beautiful.yaml
new file mode 100644
index 000000000000..3c7a5cfac9ce
--- /dev/null
+++ b/scripts/amd/gemm/configs/beautiful.yaml
@@ -0,0 +1,9 @@
+## TN
+# The gemm size that gives the best perf number
+- {'M': 4864, 'N': 4096, 'K': 4224, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+# When K % 256 == 0, there is some cache conflict issue
+- {'M': 4864, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+
+## TT: the same sizes as above, but in TT layout
+- {'M': 4864, 'N': 4096, 'K': 4224, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+- {'M': 4864, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'T'}

From 6a3d43210fa9db2e0dcdae396b824bb56b8b333c Mon Sep 17 00:00:00 2001
From: Lixun Zhang
Date: Wed, 24 Jul 2024 20:59:47 -0500
Subject: [PATCH 3/6] [tune_gemm] refactor and relax restrictions for BLOCK_K

---
 scripts/amd/gemm/tune_gemm.py | 21 ++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/scripts/amd/gemm/tune_gemm.py b/scripts/amd/gemm/tune_gemm.py
index 4760e793e29b..ecd24032564e 100755
--- a/scripts/amd/gemm/tune_gemm.py
+++ b/scripts/amd/gemm/tune_gemm.py
@@ -86,6 +86,13 @@ def get_default_config():
     return full_configs[0]
 
 
+def infer_k_dim(nonkdim, elemBytes_a):
+    if nonkdim == 16:
+        return 32 if elemBytes_a == 1 else 16
+    elif nonkdim == 32:
+        return 16 if elemBytes_a == 1 else 8
+
+
 def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
     pruned_configs = []
 
@@ -107,6 +114,10 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
         num_stages = config.get("num_stages")
        matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
         kpack = config.get("kpack")
+        ## Need to skip the config if K dim is not large enough to include
+        ## mfma_k * kpack
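+        ## (e.g. with fp16 inputs, elemBytes_a == 2 and nonkdim == 16 give
+        ## mfma_k = 16 above, so kpack = 2 requires K >= 32)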
+        if kpack * infer_k_dim(matrix_instr_nonkdim, elemBytes_a) > K:
+            continue
         if matrix_instr_nonkdim > mfma:
             continue
         if mfma == 4 and BLOCK_SIZE_K < 64:
             continue
@@ -151,12 +162,10 @@
         if large_gemm:
             if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
                 continue
-            if BLOCK_SIZE_K < 64:
-                continue
             if num_warps < 4:
                 continue
         # check if tiling is integer multiple of GEMM size because we have no boundary check
-        if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0 or K % BLOCK_SIZE_K != 0:
+        if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0:
             continue
 
         pruned_configs.append(config)
@@ -863,7 +872,7 @@ def main():
 
         # write best config to tuning_results.yaml
         if run_bench:
-            print(f"{formatted_tflops} {minTime}")
+            print(f"{formatted_tflops} {minTime} {bestConfig_compact_str}")
             f_results.write(f"{formatted_tflops},{minTime}\n")
 
         sizeDict = {
@@ -875,7 +884,9 @@ def main():
         }
         sizeDict.update(bestConfig)
         if not run_bench:
-            f_results.write("- " + str(sizeDict) + "\n")
+            f_results.write("- " + str(sizeDict) + " ")
+            f_results.write(
+                f'# {bestConfig_compact_str}\n')
 
         # remove generated files if asked to
         if not keepTmp:

From 390e3ea02cd1802f36c7d465c9d64052639b7d2f Mon Sep 17 00:00:00 2001
From: Lixun Zhang
Date: Wed, 24 Jul 2024 21:00:04 -0500
Subject: [PATCH 4/6] fix get HIPBLASLT_ROOT
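
os.environ.get is a method: subscripting it with [] raises a TypeError at
runtime, so the HIPBLASLT_ROOT check in bench_gemm.py never got a chance to
print its error message.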
---
 scripts/amd/gemm/bench_gemm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/amd/gemm/bench_gemm.py b/scripts/amd/gemm/bench_gemm.py
index 47dc1291d119..39252680436d 100644
--- a/scripts/amd/gemm/bench_gemm.py
+++ b/scripts/amd/gemm/bench_gemm.py
@@ -163,7 +163,7 @@ def main():
             transB = item['rowMajorB']
             mnks.append((M, N, K, transA, transB))
 
-        hipblaslt_ROOT_DIR = os.environ.get['HIPBLASLT_ROOT']
+        hipblaslt_ROOT_DIR = os.environ.get('HIPBLASLT_ROOT')
         if not hipblaslt_ROOT_DIR:
             print("Need to provide hipblaslt root dir: HIPBLASLT_ROOT")
             sys.exit(1)

From db775fbc0aeccd44e10acf0e5904ba8d35f27495 Mon Sep 17 00:00:00 2001
From: Lixun Zhang
Date: Thu, 25 Jul 2024 10:39:28 -0500
Subject: [PATCH 5/6] Add some configs to beautiful.yaml

---
 scripts/amd/gemm/bench_gemm.py          |  6 +++---
 scripts/amd/gemm/configs/beautiful.yaml | 20 ++++++++++++++------
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/scripts/amd/gemm/bench_gemm.py b/scripts/amd/gemm/bench_gemm.py
index 39252680436d..c3b16cc6e2e9 100644
--- a/scripts/amd/gemm/bench_gemm.py
+++ b/scripts/amd/gemm/bench_gemm.py
@@ -62,7 +62,7 @@ def run_hipblaslt_bench(hipblaslt_bench, M, N, K, transA, transB, dtype):
     hipBLASLt_bench_args += f" -i {ITER} -j {WARMUP} --print_kernel_info"
     SED_WINNER = "sed -n '/Winner:/, $p'"
 
-    print(f"Tuning hipblaslt with {hipBLASLt_bench_args}")
+    print(f"Tuning hipblaslt with {hipBLASLt_bench_args}", flush=True)
 
     winner = run_bash_command(
         f"HIP_FORCE_DEV_KERNARG=1 {hipblaslt_bench} {hipBLASLt_bench_args} | {SED_WINNER}"
@@ -83,7 +83,7 @@ def run_hipblaslt_bench(hipblaslt_bench, M, N, K, transA, transB, dtype):
 
 
 def run_triton_tuning(input, output, dtype_a):
-    print(f"Tuning gemm sizes from {input} with Triton")
+    print(f"Tuning gemm sizes from {input} with Triton", flush=True)
     run_bash_command(
         f"./tune_gemm.py --gemm_size_file {input} -dtype_a {dtype_a} -dtype_b {dtype_a} --ngpus 8 --jobs 32 --o {output}",
         False)
@@ -93,7 +93,7 @@ def run_triton_bench(input, dtype_a):
     if not os.path.exists(input):
         print(f"{input} does not exist, please run tuning first")
         sys.exit(1)
-    print(f"Benchmarking gemms from {input} with Triton")
+    print(f"Benchmarking gemms from {input} with Triton", flush=True)
     triton_output = run_bash_command(
         f"./tune_gemm.py --gemm_size_file {input} -dtype_a {dtype_a} -dtype_b {dtype_a} --benchmark"
     )
diff --git a/scripts/amd/gemm/configs/beautiful.yaml b/scripts/amd/gemm/configs/beautiful.yaml
index 3c7a5cfac9ce..5d085e028018 100644
--- a/scripts/amd/gemm/configs/beautiful.yaml
+++ b/scripts/amd/gemm/configs/beautiful.yaml
@@ -1,9 +1,17 @@
 ## TN
-# The gemm size that gives the best perf number
-- {'M': 4864, 'N': 4096, 'K': 4224, 'rowMajorA': 'T', 'rowMajorB': 'N'}
-# When K % 256 == 0, there is some cache conflict issue
-- {'M': 4864, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+#- {'M': 4864, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+#- {'M': 4864, 'N': 4096, 'K': 4160, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+#- {'M': 4864, 'N': 4096, 'K': 4224, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+#- {'M': 4864, 'N': 4096, 'K': 4288, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+
+- {'M': 4864, 'N': 4096, 'K': 4097, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 4096, 'K': 4098, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 4096, 'K': 4100, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 4096, 'K': 4104, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 4096, 'K': 4112, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 4096, 'K': 4128, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+
 ## TT: the same sizes as above, but in TT layout
-- {'M': 4864, 'N': 4096, 'K': 4224, 'rowMajorA': 'T', 'rowMajorB': 'T'}
-- {'M': 4864, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+#- {'M': 4864, 'N': 4096, 'K': 4224, 'rowMajorA': 'T', 'rowMajorB': 'T'}
+#- {'M': 4864, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'T'}

From 2e19802e805caa53536caab38f375db957436bb0 Mon Sep 17 00:00:00 2001
From: Lixun Zhang
Date: Thu, 25 Jul 2024 14:08:42 -0500
Subject: [PATCH 6/6] Add rotating_tensor 512 when tuning and benchmarking Triton

---
 scripts/amd/gemm/bench_gemm.py          |  4 ++--
 scripts/amd/gemm/configs/beautiful.yaml | 10 ++++------
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/scripts/amd/gemm/bench_gemm.py b/scripts/amd/gemm/bench_gemm.py
index c3b16cc6e2e9..d65655bd0542 100644
--- a/scripts/amd/gemm/bench_gemm.py
+++ b/scripts/amd/gemm/bench_gemm.py
@@ -85,7 +85,7 @@ def run_hipblaslt_bench(hipblaslt_bench, M, N, K, transA, transB, dtype):
 
 
 def run_triton_tuning(input, output, dtype_a):
     print(f"Tuning gemm sizes from {input} with Triton", flush=True)
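+    ## rotating_tensor 512: rotate through a ~512 MB pool of input tensors
+    ## between timed runs, presumably so results are not flattered by
+    ## cache-resident inputs (an assumption about the semantics of
+    ## tune_gemm.py's --rotating_tensor flag).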
     run_bash_command(
-        f"./tune_gemm.py --gemm_size_file {input} -dtype_a {dtype_a} -dtype_b {dtype_a} --ngpus 8 --jobs 32 --o {output}",
+        f"./tune_gemm.py --gemm_size_file {input} -dtype_a {dtype_a} -dtype_b {dtype_a} --ngpus 8 --jobs 32 --o {output} --rotating_tensor 512",
         False)
@@ -95,7 +95,7 @@ def run_triton_bench(input, dtype_a):
         sys.exit(1)
     print(f"Benchmarking gemms from {input} with Triton", flush=True)
     triton_output = run_bash_command(
-        f"./tune_gemm.py --gemm_size_file {input} -dtype_a {dtype_a} -dtype_b {dtype_a} --benchmark"
+        f"./tune_gemm.py --gemm_size_file {input} -dtype_a {dtype_a} -dtype_b {dtype_a} --benchmark --rotating_tensor 512"
     )
 
     data = []
diff --git a/scripts/amd/gemm/configs/beautiful.yaml b/scripts/amd/gemm/configs/beautiful.yaml
index 5d085e028018..c020d51576df 100644
--- a/scripts/amd/gemm/configs/beautiful.yaml
+++ b/scripts/amd/gemm/configs/beautiful.yaml
@@ -1,16 +1,14 @@
 ## TN
-#- {'M': 4864, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N'}
-#- {'M': 4864, 'N': 4096, 'K': 4160, 'rowMajorA': 'T', 'rowMajorB': 'N'}
-#- {'M': 4864, 'N': 4096, 'K': 4224, 'rowMajorA': 'T', 'rowMajorB': 'N'}
-#- {'M': 4864, 'N': 4096, 'K': 4288, 'rowMajorA': 'T', 'rowMajorB': 'N'}
-
+- {'M': 4864, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N'}
 - {'M': 4864, 'N': 4096, 'K': 4097, 'rowMajorA': 'T', 'rowMajorB': 'N'}
 - {'M': 4864, 'N': 4096, 'K': 4098, 'rowMajorA': 'T', 'rowMajorB': 'N'}
 - {'M': 4864, 'N': 4096, 'K': 4100, 'rowMajorA': 'T', 'rowMajorB': 'N'}
 - {'M': 4864, 'N': 4096, 'K': 4104, 'rowMajorA': 'T', 'rowMajorB': 'N'}
 - {'M': 4864, 'N': 4096, 'K': 4112, 'rowMajorA': 'T', 'rowMajorB': 'N'}
 - {'M': 4864, 'N': 4096, 'K': 4128, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 4096, 'K': 4160, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 4096, 'K': 4224, 'rowMajorA': 'T', 'rowMajorB': 'N'}
+- {'M': 4864, 'N': 4096, 'K': 4288, 'rowMajorA': 'T', 'rowMajorB': 'N'}
 
 ## TT: the same sizes as above, but in TT layout
 #- {'M': 4864, 'N': 4096, 'K': 4224, 'rowMajorA': 'T', 'rowMajorB': 'T'}
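
--
Usage sketch for the series (the paths are placeholders; the flags mirror
bench_gemm.py's argparse options above):

  export HIPBLASLT_ROOT=/path/to/hipBLASLt   # expects build/clients/staging/hipblaslt-bench inside
  cd scripts/amd/gemm
  ./bench_gemm.py --gemm_size_file configs/beautiful.yaml \
      --triton_result triton_results.yaml \
      --hipblaslt_result hipblaslt_results.csv \
      --tune_triton --tune_hipblaslt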