From 1b3c16352d0097cd227a4b92d749cf0f68fee842 Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Wed, 25 Sep 2024 15:37:45 -0700 Subject: [PATCH] Cache B tiles in L1D for AMX int8 WoQ micro-kernel --- torch/_inductor/codegen/cpp_gemm_template.py | 9 +++++++-- torch/_inductor/codegen/cpp_micro_gemm.py | 17 ++++++++++++----- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index e88a634721e4e..3f481a98895b8 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -436,8 +436,10 @@ def get_cache_blocking(register_blocking, thread_blocking): def get_num_byte(dtype): return torch.tensor([], dtype=dtype).element_size() - num_byte_A = get_num_byte(self.input_nodes[0].get_dtype()) - num_byte_B = get_num_byte(self.input_nodes[1].get_dtype()) + dtype_A = self.input_nodes[0].get_dtype() + dtype_B = self.input_nodes[1].get_dtype() + num_byte_A = get_num_byte(dtype_A) + num_byte_B = get_num_byte(dtype_B) # NOTE [CPP GEMM Cache Blocking Algorithm] # Our overall strategy is to @@ -449,6 +451,9 @@ def get_num_byte(dtype): # Step 1: Decide Kc assuming B block is L1-reside. size_cache_B = Kr * Kt_blocks * Nr * num_byte_B + if dtype_A is torch.bfloat16 and dtype_B is torch.int8: + # We will cache dequantized weights (BF16) in L1D + size_cache_B = size_cache_B * 2 Kc_blocks = Kt_blocks if size_cache_B > L1: Kc_blocks = math.floor(L1 / (Kr * Nr * num_byte_B)) diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index 2f0c3e78f5679..478742843dab8 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -526,6 +526,7 @@ class CppMicroGemmAMX(CppMicroGemm): A + m * lda, B + n, C + m * ldc + n, + m, K, lda, ldb, @@ -542,6 +543,7 @@ class CppMicroGemmAMX(CppMicroGemm): A + m_tail * lda, B + n, C + m_tail * ldc + n, + m, K, lda, ldb, @@ -561,6 +563,7 @@ class CppMicroGemmAMX(CppMicroGemm): const {{input_t}}* {{restrict_keyword}} A, const {{input2_t}}* {{restrict_keyword}} B, {{output_t}}* {{restrict_keyword}} C, + int64_t m, int64_t K, int64_t lda, int64_t ldb, @@ -603,7 +606,7 @@ class CppMicroGemmAMX(CppMicroGemm): {%- if input_dtype == torch.bfloat16 and input2_dtype == torch.int8 %} // create a buffer for tiles of B. - alignas(64) {{input_t}} bf16_weights_buf[512]; + alignas(4096) {{input_t}} bf16_weights_buf[(K / 16) * 512]; int num_b_rows = (last_k_offset > 0) ? 16 : (tail_k_size * sizeof({{input_t}})) / 4; int b_tile_ptr_stride = ldb * {{vnni_size}}; @@ -618,12 +621,13 @@ class CppMicroGemmAMX(CppMicroGemm): } }; - auto load_B_in_buf = [&]({{input2_t}}* B_ptr) { + auto load_B_in_buf = [&]({{input2_t}}* B_ptr, int idx) { + {{input_t}}* base_addr = &bf16_weights_buf[idx]; {{kernel.unroll_pragma(8)}} for (int i = 0; i < num_b_rows; i++) { load_B_row( B_ptr + i * b_tile_ptr_stride, - bf16_weights_buf + i * 32 + base_addr + i * 32 ); } }; @@ -642,8 +646,11 @@ class CppMicroGemmAMX(CppMicroGemm): {%- endif %} {%- if tile_row == 0 %} {%- if input_dtype == torch.bfloat16 and input2_dtype == torch.int8 %} - load_B_in_buf(const_cast<{{input2_t}}*>(B) + k * ldb + {{tile_col * 16 * vnni_size}}); - _tile_loadd({{tile_idx_b}}, bf16_weights_buf, 64); + if C10_UNLIKELY(m == 0) { + load_B_in_buf(const_cast<{{input2_t}}*>(B) + k * ldb + {{tile_col * 16 * vnni_size}}, (k/16 + {{tile_col}}) * 512); + } + // We duplicate (k/16 + {{tile_col}}) * 512 because a variable holding it would have been declared twice + _tile_loadd({{tile_idx_b}}, &bf16_weights_buf[(k/16 + {{tile_col}}) * 512], 64); {%- else %} _tile_loadd({{tile_idx_b}}, B + k * ldb + {{tile_col * 16 * vnni_size}}, ldb * {{vnni_size}} * sizeof({{input_t}})); {%- endif %}