Cache B tiles in L1D for AMX int8 WoQ micro-kernel
sanchitintel committed Sep 25, 2024
1 parent c9d12f6 commit 1b3c163
Showing 2 changed files with 19 additions and 7 deletions.
9 changes: 7 additions & 2 deletions torch/_inductor/codegen/cpp_gemm_template.py
@@ -436,8 +436,10 @@ def get_cache_blocking(register_blocking, thread_blocking):
             def get_num_byte(dtype):
                 return torch.tensor([], dtype=dtype).element_size()

-            num_byte_A = get_num_byte(self.input_nodes[0].get_dtype())
-            num_byte_B = get_num_byte(self.input_nodes[1].get_dtype())
+            dtype_A = self.input_nodes[0].get_dtype()
+            dtype_B = self.input_nodes[1].get_dtype()
+            num_byte_A = get_num_byte(dtype_A)
+            num_byte_B = get_num_byte(dtype_B)

             # NOTE [CPP GEMM Cache Blocking Algorithm]
             # Our overall strategy is to
@@ -449,6 +451,9 @@ def get_num_byte(dtype):

             # Step 1: Decide Kc assuming B block is L1-reside.
             size_cache_B = Kr * Kt_blocks * Nr * num_byte_B
+            if dtype_A is torch.bfloat16 and dtype_B is torch.int8:
+                # We will cache dequantized weights (BF16) in L1D
+                size_cache_B = size_cache_B * 2
             Kc_blocks = Kt_blocks
             if size_cache_B > L1:
                 Kc_blocks = math.floor(L1 / (Kr * Nr * num_byte_B))
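
A minimal standalone sketch of the adjusted L1 sizing above, assuming the BF16-activation / int8-weight path caches the dequantized BF16 copies of B (2 bytes per element) rather than the packed int8 weights (1 byte); the function name and standalone form are illustrative, not part of the commit:

```cpp
#include <cmath>
#include <cstdint>

// Sketch: choose Kc (cache blocking along K) so the cached B block fits in L1D.
// When A is bf16 and B is int8 (WoQ), the cached B tiles are the dequantized
// bf16 copies, so they occupy twice the bytes of the packed int8 weights.
int64_t choose_Kc_blocks(int64_t Kr, int64_t Kt_blocks, int64_t Nr,
                         int64_t num_byte_B, bool bf16_a_int8_b, int64_t L1) {
  int64_t size_cache_B = Kr * Kt_blocks * Nr * num_byte_B;
  if (bf16_a_int8_b) {
    size_cache_B *= 2;  // dequantized BF16 weights are what sit in L1D
  }
  int64_t Kc_blocks = Kt_blocks;
  if (size_cache_B > L1) {
    Kc_blocks = static_cast<int64_t>(
        std::floor(static_cast<double>(L1) / (Kr * Nr * num_byte_B)));
  }
  return Kc_blocks;
}
```

Once the doubled footprint exceeds L1, fewer K blocks are kept so the dequantized B block still stays L1-resident.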
17 changes: 12 additions & 5 deletions torch/_inductor/codegen/cpp_micro_gemm.py
@@ -526,6 +526,7 @@ class CppMicroGemmAMX(CppMicroGemm):
             A + m * lda,
             B + n,
             C + m * ldc + n,
+            m,
             K,
             lda,
             ldb,
@@ -542,6 +543,7 @@ class CppMicroGemmAMX(CppMicroGemm):
             A + m_tail * lda,
             B + n,
             C + m_tail * ldc + n,
+            m,
             K,
             lda,
             ldb,
@@ -561,6 +563,7 @@ class CppMicroGemmAMX(CppMicroGemm):
     const {{input_t}}* {{restrict_keyword}} A,
     const {{input2_t}}* {{restrict_keyword}} B,
     {{output_t}}* {{restrict_keyword}} C,
+    int64_t m,
     int64_t K,
     int64_t lda,
     int64_t ldb,
@@ -603,7 +606,7 @@ class CppMicroGemmAMX(CppMicroGemm):
 {%- if input_dtype == torch.bfloat16 and input2_dtype == torch.int8 %}
     // create a buffer for tiles of B.
-    alignas(64) {{input_t}} bf16_weights_buf[512];
+    alignas(4096) {{input_t}} bf16_weights_buf[(K / 16) * 512];
     int num_b_rows = (last_k_offset > 0) ? 16 : (tail_k_size * sizeof({{input_t}})) / 4;
     int b_tile_ptr_stride = ldb * {{vnni_size}};
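
For context on the new size: each cached B tile in the generated kernel is 16 rows of 32 BF16 values (512 elements), and the buffer now holds one such tile per 16-wide slice of K rather than a single scratch tile. A small sketch of the implied indexing (helper names are illustrative, not from the commit):

```cpp
#include <cstdint>

// Each dequantized B tile: 16 rows x 32 bf16 elements = 512 values.
constexpr int64_t kTileElems = 16 * 32;

// Offset of the cached tile for K-slice `k` (and register-tile column
// `tile_col`), mirroring the (k/16 + {{tile_col}}) * 512 expression used
// further down in this diff.
inline int64_t cached_tile_offset(int64_t k, int64_t tile_col) {
  return (k / 16 + tile_col) * kTileElems;
}
```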
@@ -618,12 +621,13 @@ class CppMicroGemmAMX(CppMicroGemm):
         }
     };
-    auto load_B_in_buf = [&]({{input2_t}}* B_ptr) {
+    auto load_B_in_buf = [&]({{input2_t}}* B_ptr, int idx) {
+        {{input_t}}* base_addr = &bf16_weights_buf[idx];
         {{kernel.unroll_pragma(8)}}
         for (int i = 0; i < num_b_rows; i++) {
             load_B_row(
                 B_ptr + i * b_tile_ptr_stride,
-                bf16_weights_buf + i * 32
+                base_addr + i * 32
             );
         }
     };
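
The lambda above still converts one int8 tile row at a time via load_B_row; what changes is where the result lands (a per-tile slot in the persistent buffer instead of one shared scratch tile). A minimal scalar sketch of that widen-into-buffer step, assuming a plain int8-to-BF16 conversion with any scaling handled elsewhere in the WoQ flow; names are hypothetical:

```cpp
#include <cstdint>
#include <cstring>

// Truncating float -> bf16 conversion (ignores rounding, for illustration).
static inline uint16_t to_bf16_bits(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>(u >> 16);
}

// Widen one 32-element int8 row of B into its slot in the bf16 buffer.
void dequant_b_row(const int8_t* b_row, uint16_t* dst_row, int n = 32) {
  for (int j = 0; j < n; ++j) {
    dst_row[j] = to_bf16_bits(static_cast<float>(b_row[j]));
  }
}
```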
@@ -642,8 +646,11 @@ class CppMicroGemmAMX(CppMicroGemm):
 {%- endif %}
 {%- if tile_row == 0 %}
 {%- if input_dtype == torch.bfloat16 and input2_dtype == torch.int8 %}
-        load_B_in_buf(const_cast<{{input2_t}}*>(B) + k * ldb + {{tile_col * 16 * vnni_size}});
-        _tile_loadd({{tile_idx_b}}, bf16_weights_buf, 64);
+        if C10_UNLIKELY(m == 0) {
+            load_B_in_buf(const_cast<{{input2_t}}*>(B) + k * ldb + {{tile_col * 16 * vnni_size}}, (k/16 + {{tile_col}}) * 512);
+        }
+        // We duplicate (k/16 + {{tile_col}}) * 512 because a variable holding it would have been declared twice
+        _tile_loadd({{tile_idx_b}}, &bf16_weights_buf[(k/16 + {{tile_col}}) * 512], 64);
 {%- else %}
         _tile_loadd({{tile_idx_b}}, B + k * ldb + {{tile_col * 16 * vnni_size}}, ldb * {{vnni_size}} * sizeof({{input_t}}));
 {%- endif %}
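
Putting the pieces together, the reason `m` is now threaded into the kernel is that dequantization happens only for the first row block of A (m == 0); every later row block reloads the already-converted tiles from the L1D-resident buffer. A control-flow sketch of that reuse, with an assumed loop structure and a hypothetical dequant_b_tile helper (not the generated template code):

```cpp
#include <cstdint>
#include <immintrin.h>

// Hypothetical helper (declaration only): dequantize one 16x32 int8 B tile
// into `dst` as bf16 values.
void dequant_b_tile(const int8_t* b_src, uint16_t* dst);

// Sketch: B is dequantized once (m == 0) and every later A row block reloads
// the already-converted tiles straight from the L1D-resident buffer.
void run_row_blocks(const int8_t* B, uint16_t* bf16_weights_buf,
                    int64_t M, int64_t K, int64_t ldb) {
  for (int64_t m = 0; m < M; m += 16) {        // A row blocks
    for (int64_t k = 0; k < K; k += 16) {      // 16-wide K slices
      uint16_t* cached = bf16_weights_buf + (k / 16) * 512;
      if (m == 0) {
        dequant_b_tile(B + k * ldb, cached);   // convert int8 -> bf16 once
      }
      _tile_loadd(6, cached, 64);              // AMX reads the cached bf16 tile
      // ... load the A tile, _tile_dpbf16ps into the accumulator, etc. ...
    }
  }
}
```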