diff --git a/scripts/amd/gemm/tune_gemm.py b/scripts/amd/gemm/tune_gemm.py
index 3b712ee5bdb0..2f03871b37f4 100644
--- a/scripts/amd/gemm/tune_gemm.py
+++ b/scripts/amd/gemm/tune_gemm.py
@@ -1,3 +1,4 @@
+# fp8
 import argparse
 import sys
 import yaml
@@ -221,7 +222,7 @@ def generated_kernel_name(M, N, K, gpu_id):
 # 4. test_gemm to invoke
 # 4.1 run try_config in parallel
 # 4.2 matmul in a loop of 10 iterations
-def generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, jobs, run_bench):
+def generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, jobs, run_bench):
     filenames = []
     for i in range(jobs):
         filenames.append(generated_kernel_name(M, N, K, i))
@@ -259,8 +260,8 @@ def generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, j
     # pre string
     test_gemm_pre_str = f"""def test_gemm(M, N, K, num_threads):
     thread_pool = multiprocessing.Pool(processes=num_threads)
-    a, a_fp16 = gen_input(M, K, '{dtype_a}', {col_a}, 1, device='cuda')
-    b, b_fp16 = gen_input(K, N, '{dtype_b}', {col_b}, 2, device='cuda')
+    a, a_fp16 = gen_input(M, K, '{dtype_a}', {col_a}, 1, '{init_type}', device='cuda')
+    b, b_fp16 = gen_input(K, N, '{dtype_b}', {col_b}, 2, '{init_type}', device='cuda')
     c = torch.zeros((M, N), device=a.device, dtype={tl_to_torch_types[name_to_tl_types[dtype_c]]})
     task_args = (M, N, K,
                  a.stride(0), a.stride(1),
@@ -359,9 +360,9 @@ def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
             jobId += ngpus
 
 
-def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, run_bench, jobs, verbose=0, num_threads=16, gpus=[0]):
+def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, run_bench, jobs, verbose=0, num_threads=16, gpus=[0]):
     # Generate kernel out of all configs
-    generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, jobs, run_bench)
+    generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, jobs, run_bench)
 
     # remove any compiled kernel in the cache
     run_bash_command("rm -rf ~/.triton/cache")
@@ -418,7 +419,7 @@ def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs,
         print(f"post procesing time: {post_time}", flush=True)
     return minTime, bestConfig, compile_time, profile_time, post_time
 
-def gen_input(M, N, ty_name, needTrans, seed, device='cuda'):
+def gen_input(M, N, ty_name, needTrans, seed, init_type, device='cuda'):
    d_type = name_to_tl_types[ty_name]
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
@@ -431,10 +432,24 @@ def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
         output = input
         tl.store(output_ptr + offsets, output, mask=mask)
 
+    def init_by_size_and_type(size, dtype, init_type):
+        if init_type == 'hpl':
+            return torch.empty(size, device='cuda', dtype=dtype).uniform_(-0.5, 0.5)
+        # This init type has element[i] in row[j] equal to sin(i+j*N)
+        elif init_type == 'trig_float':
+            M, N = size
+            return torch.reshape(torch.arange(0, M*N), (M, N)).sin().to(dtype=dtype, device='cuda')
+        elif init_type == 'zeros':
+            return torch.zeros(size, dtype=dtype, device='cuda')
+        elif init_type == "randn":
+            temp = torch.randn(size, dtype=dtype, device='cuda')
+            return temp
+        else:
+            raise ValueError("Bad matrix initialization type.")
+
+    raw_data = init_by_size_and_type((N,M) if needTrans else (M,N), torch.float32, init_type)
     if needTrans:
-        raw_data = torch.randn((N, M), dtype=torch.float32, device='cuda').T
-    else:
-        raw_data = torch.randn((M, N), dtype=torch.float32, device='cuda')
+        raw_data = raw_data.T
     if (d_type == tl.float8e4b8 and TORCH_HAS_FP8E4B8) or \
         (d_type == tl.float8e5b16 and TORCH_HAS_FP8E5B16) or not d_type.is_fp8():
         input = raw_data.to(tl_to_torch_types[d_type])
@@ -481,14 +496,14 @@ def matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_
     return c
 
 
-def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, config, verbose):
+def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, verbose):
     block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize = read_config(config)
 
     torch.manual_seed(0)
     #a = torch.randn((M, K), device='cuda', dtype=datatype)
     #b = torch.randn((K, N), device='cuda', dtype=datatype)
-    a, a_fp16 = gen_input(M, K, dtype_a, col_a, 1, device='cuda')
-    b, b_fp16 = gen_input(K, N, dtype_b, col_b, 2, device='cuda')
+    a, a_fp16 = gen_input(M, K, dtype_a, col_a, 1, init_type, device='cuda')
+    b, b_fp16 = gen_input(K, N, dtype_b, col_b, 2, init_type, device='cuda')
     # Allocates output.
     c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]])
     triton_output = matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize)
@@ -544,6 +559,7 @@ def parse_args():
     parser.add_argument("--verbose", action='store_true', default=False, help="enables time_breakdown and additional logging messages")
     parser.add_argument("--num_threads", type=int, default=16, help="number of threads to use for kernel compilation and post processing")
     parser.add_argument("--jobs", type=int, default=1, help="number of generated files")
+    parser.add_argument("--init_type", type=str, default='randn', help="Initialization type for input matrices: 'hpl', 'trig_float', 'zeros', or 'randn' (default: 'randn', standard normal)")
     args = parser.parse_args()
 
     return args
@@ -643,6 +659,7 @@ def main():
     mnks = []
 
     # TODO: make it more robust to get user input
+    init_type = args.init_type
     if matrix_size_file == "" or not os.path.isfile(matrix_size_file):
         M = args.m
         N = args.n
@@ -660,7 +677,7 @@ def main():
     # Check correctness from given configs
     if args.compare_wo_tuning:
         for (M, N, K, col_a, col_b, myConfig) in mnks:
-            test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, item, True)
+            test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, item, True)
         return
 
     configs_full = get_full_tuning_space()
@@ -670,7 +687,7 @@ def main():
         print(f"Benchmarking gemm with {dtype_a} inputs")
         print("trans M N K TFLOPS")
     else:
-        print(f"Tuning starts at: {start_time}", flush=True)
+        print(f"Tuning {len(mnks)} gemm sizes starts at: {start_time}", flush=True)
         f_results = open(tuning_output_file, 'w')
 
     for (M, N, K, col_a, col_b, myConfig) in mnks:
@@ -684,6 +701,8 @@ def main():
         size_str = f'SIZE: {M} {N} {K} {row_a_str}{row_b_str}'
         if not run_bench:
             print(f"{size_str} nConfigs: {len(pruned_configs)}", end=" ", flush=True)
+        else:
+            print(f"{row_a_str}{row_b_str} {M:5d} {N:5d} {K:5d} ", end="")
 
         # The main tuning funtion for one gemm size
         verbose_level = 0
@@ -691,7 +710,11 @@ def main():
            verbose_level = 1
         if args.verbose:
             verbose_level = 2
-        minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, pruned_configs, run_bench, jobs, num_threads=args.num_threads, gpus=gpus, verbose=verbose_level)
+        minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(
+            M, N, K, col_a, col_b, dtype_a,
+            dtype_b, dtype_c, init_type, pruned_configs,
+            run_bench, jobs, num_threads=args.num_threads, gpus=gpus,
+            verbose=verbose_level)
 
         # post processing the numbers
         perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6)
@@ -707,7 +730,7 @@ def main():
 
         # write best config to tuning_results.yaml
         if run_bench:
-            print(f"{row_a_str}{row_b_str} {M:5d} {N:5d} {K:5d} {formatted_tflops}")
+            print(f"{formatted_tflops}")
 
         sizeDict = {'M': M, 'N': N, 'K': K, 'rowMajorA': row_a_str, 'rowMajorB': row_b_str}
         sizeDict.update(bestConfig)
@@ -727,7 +750,7 @@ def main():
 
         # Check correctness if asked to
         if args.compare:
             print("correctness: ", end=" ", flush=True)
-            test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, bestConfig, False)
+            test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, bestConfig, False)
         elif not run_bench:
             print("", flush=True)
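
For reviewers who want to try the new initialization modes in isolation, here is a minimal sketch, not part of the patch. It mirrors the init_by_size_and_type helper that this change nests inside gen_input; the standalone wrapper make_matrix, its signature, and the configurable device argument are illustrative assumptions, but the four init_type values ('hpl', 'trig_float', 'zeros', 'randn') and the allocate-transposed-then-view handling of column-major operands match the diff above.

import torch

def make_matrix(M, N, init_type, needTrans=False, dtype=torch.float32, device='cuda'):
    # Column-major operands are produced the same way the patch does it:
    # allocate the transposed shape, then return a transposed view.
    size = (N, M) if needTrans else (M, N)
    if init_type == 'hpl':
        # HPL-style uniform values in [-0.5, 0.5)
        data = torch.empty(size, dtype=dtype, device=device).uniform_(-0.5, 0.5)
    elif init_type == 'trig_float':
        # Deterministic values sin(0), sin(1), ... laid out row-major,
        # handy for reproducible correctness comparisons.
        rows, cols = size
        data = torch.arange(0, rows * cols, dtype=torch.float32,
                            device=device).reshape(rows, cols).sin().to(dtype)
    elif init_type == 'zeros':
        data = torch.zeros(size, dtype=dtype, device=device)
    elif init_type == 'randn':
        # Standard normal distribution; this is the script's default.
        data = torch.randn(size, dtype=dtype, device=device)
    else:
        raise ValueError("Bad matrix initialization type.")
    return data.T if needTrans else data

# Example: a 128x256 column-major operand with deterministic trig values.
# a = make_matrix(128, 256, 'trig_float', needTrans=True)

From the command line the same choice is made with the new --init_type flag (e.g. --init_type hpl); 'randn' remains the default, which matches the previous torch.randn behavior of gen_input.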