From 49f6c3aa2b325ee4c770aa4d4cc5a7882333a2fa Mon Sep 17 00:00:00 2001
From: Vinayak Gokhale <vinayak.gokhale@amd.com>
Date: Fri, 26 Jan 2024 21:05:58 -0600
Subject: [PATCH] [GEMM] [Tuning] Option to try different initialization
 strategies (#486)

* Add a few different GEMM init strategies

* Minor fixes

* Fix order when transposed

* Fix transpose

* Update tune_gemm.py

Fix blank lines

* Remove init_type from mnks

* Fix trig_float

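Supported values for the new --init_type flag (see init_by_size_and_type in the
diff): randn (default, standard normal), hpl (uniform in [-0.5, 0.5)),
trig_float (element i of row j is sin(i + j*N)), and zeros.

A minimal usage sketch; the -m/-n/-k size flags are assumed from the existing
args.m/args.n/args.k handling and may be spelled differently in this script:

    python tune_gemm.py -m 1024 -n 1024 -k 1024 --init_type hpl --compare
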
---------

Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
---
 scripts/amd/gemm/tune_gemm.py | 57 ++++++++++++++++++++++++-----------
 1 file changed, 40 insertions(+), 17 deletions(-)

diff --git a/scripts/amd/gemm/tune_gemm.py b/scripts/amd/gemm/tune_gemm.py
index 3b712ee5bdb0..2f03871b37f4 100644
--- a/scripts/amd/gemm/tune_gemm.py
+++ b/scripts/amd/gemm/tune_gemm.py
@@ -1,3 +1,4 @@
+# fp8
 import argparse
 import sys
 import yaml
@@ -221,7 +222,7 @@ def generated_kernel_name(M, N, K, gpu_id):
 # 4. test_gemm to invoke
 # 4.1 run try_config in parallel
 # 4.2 matmul in a loop of 10 iterations
-def generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, jobs, run_bench):
+def generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, jobs, run_bench):
     filenames = []
     for i in range(jobs):
         filenames.append(generated_kernel_name(M, N, K, i))
@@ -259,8 +260,8 @@ def generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, j
     # pre string
     test_gemm_pre_str = f"""def test_gemm(M, N, K, num_threads):
     thread_pool = multiprocessing.Pool(processes=num_threads)
-    a, a_fp16 = gen_input(M, K, '{dtype_a}', {col_a}, 1, device='cuda')
-    b, b_fp16 = gen_input(K, N, '{dtype_b}', {col_b}, 2, device='cuda')
+    a, a_fp16 = gen_input(M, K, '{dtype_a}', {col_a}, 1, '{init_type}', device='cuda')
+    b, b_fp16 = gen_input(K, N, '{dtype_b}', {col_b}, 2, '{init_type}', device='cuda')
     c = torch.zeros((M, N), device=a.device, dtype={tl_to_torch_types[name_to_tl_types[dtype_c]]})
     task_args = (M, N, K,
                  a.stride(0), a.stride(1),
@@ -359,9 +360,9 @@ def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
         jobId += ngpus
 
 
-def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, run_bench, jobs, verbose=0, num_threads=16, gpus=[0]):
+def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, run_bench, jobs, verbose=0, num_threads=16, gpus=[0]):
     # Generate kernel out of all configs
-    generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs, jobs, run_bench)
+    generate_kernel(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, jobs, run_bench)
 
     # remove any compiled kernel in the cache
     run_bash_command("rm -rf ~/.triton/cache")
@@ -418,7 +419,7 @@ def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, configs,
         print(f"post procesing time: {post_time}", flush=True)
     return minTime, bestConfig, compile_time, profile_time, post_time
 
-def gen_input(M, N, ty_name, needTrans, seed, device='cuda'):
+def gen_input(M, N, ty_name, needTrans, seed, init_type, device='cuda'):
     d_type = name_to_tl_types[ty_name]
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
@@ -431,10 +432,24 @@ def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
         output = input
         tl.store(output_ptr + offsets, output, mask=mask)
 
+    def init_by_size_and_type(size, dtype, init_type):
+        if init_type == 'hpl':
+            return torch.empty(size, device='cuda', dtype=dtype).uniform_(-0.5, 0.5)
+        # trig_float: element i of row j is sin(i + j*N), i.e. sin of the flattened index
+        elif init_type == 'trig_float':
+            M, N = size
+            return torch.reshape(torch.arange(0, M*N), (M, N)).sin().to(dtype=dtype, device='cuda')
+        elif init_type == 'zeros':
+            return torch.zeros(size, dtype=dtype, device='cuda')
+        elif init_type == 'randn':
+            # Standard normal distribution (the default)
+            return torch.randn(size, dtype=dtype, device='cuda')
+        else:
+            raise ValueError(f"Unknown matrix initialization type: {init_type}")
+
+    raw_data = init_by_size_and_type((N, M) if needTrans else (M, N), torch.float32, init_type)
     if needTrans:
-        raw_data = torch.randn((N, M), dtype=torch.float32, device='cuda').T
-    else:
-        raw_data = torch.randn((M, N), dtype=torch.float32, device='cuda')
+        raw_data = raw_data.T
     if (d_type == tl.float8e4b8 and TORCH_HAS_FP8E4B8) or \
         (d_type == tl.float8e5b16 and TORCH_HAS_FP8E5B16) or not d_type.is_fp8():
         input = raw_data.to(tl_to_torch_types[d_type])
@@ -481,14 +496,14 @@ def matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_
     return c
 
 
-def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, config, verbose):
+def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, verbose):
     block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize = read_config(config)
 
     torch.manual_seed(0)
     #a = torch.randn((M, K), device='cuda', dtype=datatype)
     #b = torch.randn((K, N), device='cuda', dtype=datatype)
-    a, a_fp16 = gen_input(M, K, dtype_a, col_a, 1, device='cuda')
-    b, b_fp16 = gen_input(K, N, dtype_b, col_b, 2, device='cuda')
+    a, a_fp16 = gen_input(M, K, dtype_a, col_a, 1, init_type, device='cuda')
+    b, b_fp16 = gen_input(K, N, dtype_b, col_b, 2, init_type, device='cuda')
     # Allocates output.
     c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]])
     triton_output = matmul(a, b, c, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize)
@@ -544,6 +559,7 @@ def parse_args():
     parser.add_argument("--verbose", action='store_true', default=False, help="enables time_breakdown and additional logging messages")
     parser.add_argument("--num_threads", type=int, default=16, help="number of threads to use for kernel compilation and post processing")
     parser.add_argument("--jobs", type=int, default=1, help="number of generated files")
+    parser.add_argument("--init_type", type=str, default='randn', help="Initialization type for input matrices: randn (default, standard normal), hpl (uniform in [-0.5, 0.5)), trig_float, or zeros")
     args = parser.parse_args()
 
     return args
@@ -643,6 +659,7 @@ def main():
 
     mnks = []
     # TODO: make it more robust to get user input
+    init_type = args.init_type
     if matrix_size_file == "" or not os.path.isfile(matrix_size_file):
         M = args.m
         N = args.n
@@ -660,7 +677,7 @@ def main():
     # Check correctness from given configs
     if args.compare_wo_tuning:
         for (M, N, K, col_a, col_b, myConfig) in mnks:
-            test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, item, True)
+            test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, item, True)
         return
 
     configs_full = get_full_tuning_space()
@@ -670,7 +687,7 @@ def main():
         print(f"Benchmarking gemm with {dtype_a} inputs")
         print("trans     M      N      K    TFLOPS")
     else:
-        print(f"Tuning starts at: {start_time}", flush=True)
+        print(f"Tuning {len(mnks)} gemm sizes starts at: {start_time}", flush=True)
         f_results = open(tuning_output_file, 'w')
 
     for (M, N, K, col_a, col_b, myConfig) in mnks:
@@ -684,6 +701,8 @@ def main():
         size_str = f'SIZE: {M} {N} {K} {row_a_str}{row_b_str}'
         if not run_bench:
             print(f"{size_str} nConfigs: {len(pruned_configs)}", end=" ", flush=True)
+        else:
+            print(f"{row_a_str}{row_b_str}    {M:5d}  {N:5d}  {K:5d}    ", end="", flush=True)
 
         # The main tuning funtion for one gemm size
         verbose_level = 0
@@ -691,7 +710,11 @@ def main():
             verbose_level = 1
         if args.verbose:
             verbose_level = 2
-        minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, pruned_configs, run_bench, jobs, num_threads=args.num_threads, gpus=gpus, verbose=verbose_level)
+        minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(
+                M, N, K, col_a, col_b, dtype_a,
+                dtype_b, dtype_c, init_type, pruned_configs,
+                run_bench, jobs, num_threads=args.num_threads, gpus=gpus,
+                verbose=verbose_level)
 
         # post processing the numbers
         perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6)
@@ -707,7 +730,7 @@ def main():
 
         # write best config to tuning_results.yaml
         if run_bench:
-            print(f"{row_a_str}{row_b_str}    {M:5d}  {N:5d}  {K:5d}    {formatted_tflops}")
+            print(f"{formatted_tflops}")
 
         sizeDict = {'M': M, 'N': N, 'K': K, 'rowMajorA': row_a_str, 'rowMajorB': row_b_str}
         sizeDict.update(bestConfig)
@@ -727,7 +750,7 @@ def main():
         # Check correctness if asked to
         if args.compare:
             print("correctness: ", end=" ", flush=True)
-            test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, bestConfig, False)
+            test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, bestConfig, False)
         elif not run_bench:
             print("", flush=True)