Skip to content

Commit

Permalink
x64: brgemm matmul: fix tile configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
akharito committed Feb 25, 2023
1 parent 05629a5 commit 28ddb5b
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 25 deletions.
58 changes: 35 additions & 23 deletions src/cpu/x64/matmul/brgemm_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,22 @@ using namespace nstl;

using namespace data_type;

namespace {
void maybe_tile_configure(bool is_amx,
const char brg_kernel_palettes[][AMX_PALETTE_SIZE], int brg_ker_idx,
int &prev_ker_idx) {
if (!is_amx) return;
if (brg_ker_idx == prev_ker_idx) return;
// TODO: more accurately estimate the costs of memcmp and tile configuration
if (prev_ker_idx == -1
|| std::memcmp(&brg_kernel_palettes[brg_ker_idx][0],
&brg_kernel_palettes[prev_ker_idx][0], AMX_PALETTE_SIZE)
!= 0)
amx_tile_configure(&brg_kernel_palettes[brg_ker_idx][0]);
prev_ker_idx = brg_ker_idx;
}
} // namespace

template <cpu_isa_t isa>
status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
const auto src_dt = src_md_.data_type;
Expand Down Expand Up @@ -119,8 +135,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
brgemm_attr_t brgattr;
brgattr.generate_skip_accumulation
= bgmmc_.post_ops_applicable && bgmmc_.nthr_k > 1;
constexpr bool is_amx = one_of(
isa, avx512_core_bf16_amx_int8, avx512_core_bf16_amx_bf16);
const bool is_amx = is_superset(isa, avx512_core_amx);
if (is_amx) {
if (!brgattr.generate_skip_accumulation) {
// TODO: uker doesn't yet support generate_skip_accumulation
Expand Down Expand Up @@ -217,10 +232,9 @@ status_t brgemm_matmul_t<isa>::execute_body(const exec_ctx_t &ctx) const {
balance211((int)bgmmc.K_chunks, brgmm_ctx.get_num_threads_for_k(),
ithr_k, kc_start, kc_end);

if (is_amx) {
const auto base_ker_idx = brgmm_ctx.get_base_brgemm_kernel_idx();
amx_tile_configure(&brg_kernel_palettes_[base_ker_idx][0]);
}
int prev_ker_idx = -1;
maybe_tile_configure(is_amx, brg_kernel_palettes_,
brgmm_ctx.get_base_brgemm_kernel_idx(), prev_ker_idx);

int b {0}, mc {0}, nc {0};
nd_iterator_init(
Expand All @@ -239,8 +253,8 @@ status_t brgemm_matmul_t<isa>::execute_body(const exec_ctx_t &ctx) const {
for (int mb = m_start; mb < m_end; mb++) {
if (use_buffer_a && nb == n_start)
copy_a_chunk_in_buffer(brgmm_ctx, ithr, b, mb, kc);
compute_kernel(
brgmm_ctx, ithr, b, mb, nb, kc, kc == kc_start);
compute_kernel(brgmm_ctx, ithr, b, mb, nb, kc,
kc == kc_start, prev_ker_idx);
}
}
++start;
Expand All @@ -258,12 +272,12 @@ status_t brgemm_matmul_t<isa>::execute_body(const exec_ctx_t &ctx) const {
template <cpu_isa_t isa>
void brgemm_matmul_t<isa>::compute_kernel(
const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr, int b_idx,
int m_blk_idx, int n_blk_idx, int k_chunk_idx, bool do_init) const {
int m_blk_idx, int n_blk_idx, int k_chunk_idx, bool do_init,
int &prev_ker_idx) const {
constexpr bool is_amx
= one_of(isa, avx512_core_bf16_amx_int8, avx512_core_bf16_amx_bf16);
const auto &bgmmc = pd()->get_brgemm_matmul_conf();
const auto addr_batch = brgmm_ctx.get_batch_elem_ptr(ithr);
const int base_brg_ker_idx = brgmm_ctx.get_base_brgemm_kernel_idx();

const auto wsp_tile = brgmm_ctx.get_tile_workspace(ithr);
const int m = m_blk_idx * bgmmc.M_blk;
Expand Down Expand Up @@ -302,10 +316,8 @@ void brgemm_matmul_t<isa>::compute_kernel(
if (gemm_batch > 0 && brg_ker_idx >= 0) {
const auto brg_kernel = brg_kernels_[brg_ker_idx].get();
assert(brg_kernel != nullptr);

const bool is_tile_reconf_required = is_amx && (is_M_tail || is_N_tail);
if (is_tile_reconf_required)
amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]);
maybe_tile_configure(
is_amx, brg_kernel_palettes_, brg_ker_idx, prev_ker_idx);

brgmm_ctx.init_brgemm_batch_elements_values(
ithr, 0, gemm_batch, b_idx, m_blk_idx, k_blk_idx, n_blk_idx);
Expand Down Expand Up @@ -339,9 +351,6 @@ void brgemm_matmul_t<isa>::compute_kernel(
brgemm_kernel_execute(brg_kernel, gemm_batch, addr_batch,
(void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr);
}

if (is_tile_reconf_required)
amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);
}
if (is_K_tail) {
brgmm_ctx.init_brgemm_batch_elements_values(
Expand All @@ -350,11 +359,10 @@ void brgemm_matmul_t<isa>::compute_kernel(
const bool use_init_ker = (do_init && gemm_batch == 0);
const int brg_ker_idx = pd()->get_brg_kernel_idx(
false, use_init_ker, is_M_tail, is_N_tail, true);
maybe_tile_configure(
is_amx, brg_kernel_palettes_, brg_ker_idx, prev_ker_idx);
const auto brg_kernel_k_tail = brg_kernels_[brg_ker_idx].get();
const bool is_tile_reconf_required
= is_amx && bgmmc.K_tail != bgmmc.K_blk;
if (is_tile_reconf_required)
amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]);

if (post_ops_applicable) {
void *scratch = is_amx
? static_cast<void *>(wsp_tile)
Expand Down Expand Up @@ -384,8 +392,6 @@ void brgemm_matmul_t<isa>::compute_kernel(
brgemm_kernel_execute(brg_kernel_k_tail, 1, addr_batch,
(void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr);
}
if (is_tile_reconf_required)
amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);
}
}

Expand All @@ -394,6 +400,8 @@ void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
const brg_matmul_exec_ctx_t &brgmm_ctx) const {
if (!brgmm_ctx.parallel_reduction_is_used()) return;

const bool is_amx = is_superset(isa, avx512_core_amx);

const auto &bgmmc = pd()->get_brgemm_matmul_conf();
const int num_threads = brgmm_ctx.get_num_threads_for_parallelization();

Expand All @@ -412,6 +420,8 @@ void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
bmn_end);
balance211(bmn_end - bmn_start, nthr_k, ithr_k, start, end);

int prev_ker_idx = -1;

int b {0}, mc {0}, nc {0};

assert(bgmmc.batch == 1);
Expand Down Expand Up @@ -450,6 +460,8 @@ void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
= (bgmmc.N - nb * bgmmc.N_blk < bgmmc.N_blk);
const int brg_ker_idx = pd()->get_brg_kernel_idx(
false, false, is_M_tail, is_N_tail, false);
maybe_tile_configure(is_amx, brg_kernel_palettes_,
brg_ker_idx, prev_ker_idx);
const auto brg_kernel = brg_kernels_[brg_ker_idx].get();
const int m = mb * bgmmc.M_blk;
const int n = nb * bgmmc.N_blk;
Expand Down
4 changes: 2 additions & 2 deletions src/cpu/x64/matmul/brgemm_matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ struct brgemm_matmul_t : public primitive_t {
status_t execute_body(const exec_ctx_t &ctx) const;
void compute_kernel(const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr,
int b_idx, int m_blk_idx, int n_blk_idx, int k_blk_idx,
bool do_init) const;
bool do_init, int &prev_ker_idx) const;
void copy_a_chunk_in_buffer(const brg_matmul_exec_ctx_t &brgmm_ctx,
int ithr, int b_idx, int m_blk_idx, int k_blk_idx) const;
void copy_b_chunk_in_buffer(const brg_matmul_exec_ctx_t &brgmm_ctx,
Expand All @@ -116,7 +116,7 @@ struct brgemm_matmul_t : public primitive_t {
char *result_ptr, const char *reduce_ptr, size_t size) const;

std::unique_ptr<brgemm_kernel_t> brg_kernels_[max_num_brg_kernels_matmul];
char brg_kernel_palettes_[max_num_brg_kernels_matmul][64];
alignas(64) char brg_kernel_palettes_[max_num_brg_kernels_matmul][64];
std::unique_ptr<jit_brgemm_matmul_copy_b_t> copy_B_kernel_;
std::unique_ptr<jit_brgemm_matmul_copy_a_t> copy_A_kernel_;
std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_f32_;
Expand Down

0 comments on commit 28ddb5b

Please sign in to comment.