x64: brgemm matmul: enable blocked B for 3D problems
akharito committed Jul 7, 2023
1 parent acb8e12 commit 8c20f62
Showing 3 changed files with 63 additions and 23 deletions.
23 changes: 15 additions & 8 deletions src/cpu/x64/matmul/brgemm_matmul.cpp
@@ -305,7 +305,8 @@ void brgemm_matmul_t<isa>::compute_kernel(
             ? brgmm_ctx.get_buf_C_ptr(ithr, m_blk_idx, n_blk_idx)
             : ptr_D;
 
-    const auto zp_comp_a = brgmm_ctx.get_zp_a_compensation_ptr(ithr, n_blk_idx);
+    const auto zp_comp_a
+            = brgmm_ctx.get_zp_a_compensation_ptr(ithr, b_idx, n_blk_idx);
     const auto zp_comp_b
             = brgmm_ctx.get_zp_b_compensation_result_ptr(ithr, m_blk_idx);
     const auto zp_c_val_ptr = brgmm_ctx.get_zp_c_val_ptr();
@@ -475,7 +476,8 @@ void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
                     // TODO: support reduction for zp/s8s8 compensations
                     // computed in copy routines
                     const auto zp_comp_a
-                            = brgmm_ctx.get_zp_a_compensation_ptr(ithr, nb);
+                            = brgmm_ctx.get_zp_a_compensation_ptr(
+                                    ithr, b, nb);
                     const auto zp_comp_b
                             = brgmm_ctx.get_zp_b_compensation_result_ptr(
                                     ithr, mb);
@@ -579,8 +581,8 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
     const int n = n_blk_idx * bgmmc.N_blk;
     const bool is_N_tail = (bgmmc.N - n < bgmmc.N_blk);
     ctx.current_N_blk = is_N_tail ? bgmmc.N_tail : bgmmc.N_blk;
-    ctx.zp_a_compensation_ptr
-            = (void *)brgmm_ctx.get_zp_a_compensation_ptr(ithr, n_blk_idx);
+    ctx.zp_a_compensation_ptr = (void *)brgmm_ctx.get_zp_a_compensation_ptr(
+            ithr, b_idx, n_blk_idx);
     ctx.zp_a_neg_value_ptr = (void *)brgmm_ctx.get_zp_a_neg_val_ptr();
 
     int gb = 0;
@@ -709,8 +711,10 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
             // multitreaded execution mode
             const size_t reorder_zp_a_comp_offset
                     = weights_d.size() - weights_d.additional_buffer_size();
+            const size_t b_batch
+                    = get_bb_idx(bgmmc.batch - 1, bgmmc_.bcast_B_desc) + 1;
             const size_t s8s8_buffer_sz = bgmmc.s8s8_compensation_required
-                    ? bgmmc.s8s8_comp_b_str * sizeof(int32_t)
+                    ? sizeof(int32_t) * b_batch * bgmmc.s8s8_comp_b_str
                     : 0;
             reorder_zp_a_comp_ptr_
                     = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(
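The b_batch term above is what grows the s8s8 compensation buffer from one block to one block per distinct B batch. A minimal standalone sketch of the sizing rule (the function name and the broadcast behavior of get_bb_idx are assumptions read off this diff, not oneDNN API):

```cpp
#include <cstddef>
#include <cstdint>

// Sketch: when B is broadcast across the batch, get_bb_idx() collapses
// every batch index to 0, so b_batch == 1 and a single compensation
// block suffices; otherwise each batch owns s8s8_comp_b_str values.
std::size_t s8s8_buffer_size(std::size_t batch, bool B_is_broadcast,
        std::size_t s8s8_comp_b_str, bool compensation_required) {
    const std::size_t b_batch = B_is_broadcast ? 1 : batch;
    return compensation_required
            ? sizeof(int32_t) * b_batch * s8s8_comp_b_str
            : 0;
}
```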
@@ -965,7 +969,7 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
                 ? n_blk_idx % bgmmc_.N_chunk_size
                 : n_blk_idx;
         return s8s8_compensation_ptr_ + ithr * bgmmc_.s8s8_comp_ithr_str
-                + b * bgmmc_.s8s8_comp_b_str
+                + get_bb_idx(b, bgmmc_.bcast_B_desc) * bgmmc_.s8s8_comp_b_str
                 + n_blk_local * bgmmc_.s8s8_comp_n_str;
     }

@@ -987,7 +991,8 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
 
     const int32_t *get_zp_c_val_ptr() const { return &zero_point_c_val_; }
 
-    int32_t *get_zp_a_compensation_ptr(int ithr, int n_blk_idx) const {
+    int32_t *get_zp_a_compensation_ptr(
+            int ithr, int b_idx, int n_blk_idx) const {
         if (!bgmmc_.has_zero_point_a) return nullptr;
 
         const int n_blk_local = n_blk_idx % bgmmc_.N_chunk_size;
@@ -1000,7 +1005,9 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
         // locally just before usage. Using the single global scaling before
         // parallel section might produce significant overhead for small
         // problems running in multitreaded execution mode
-        const int base_offset = n_blk_idx * bgmmc_.wei_n_blk;
+        const int base_offset = get_bb_idx(b_idx, bgmmc_.bcast_B_desc)
+                * rnd_up(bgmmc_.N, bgmmc_.wei_n_blk)
+                + n_blk_idx * bgmmc_.wei_n_blk;
         PRAGMA_OMP_SIMD()
         for (int b = 0; b < bgmmc_.wei_n_blk; b++)
             zp_comp[b] = -zero_point_a_negative_val_
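The reworked base_offset gives each broadcast-aware B batch its own stretch of rnd_up(N, wei_n_blk) compensation entries, with the N-blocks laid out inside that stretch exactly as before. A worked example with illustrative numbers (N = 70, wei_n_blk = 16, so each batch owns rnd_up(70, 16) = 80 entries):

```cpp
// Standalone sketch of the offset arithmetic above; bb stands for the
// broadcast-aware batch index that get_bb_idx(b_idx, bcast_B_desc)
// would return (illustrative, not the library's helper).
int zp_a_comp_base_offset(int bb, int n_blk_idx, int N, int wei_n_blk) {
    const int per_batch = (N + wei_n_blk - 1) / wei_n_blk * wei_n_blk; // rnd_up
    return bb * per_batch + n_blk_idx * wei_n_blk;
}
// zp_a_comp_base_offset(1, 2, 70, 16) == 1 * 80 + 2 * 16 == 112
```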
47 changes: 32 additions & 15 deletions src/cpu/x64/matmul/brgemm_matmul_utils.cpp
@@ -42,15 +42,27 @@ int get_default_n_block(format_tag_t matrix_b_tag) {
     // Note: consider using weights mem_descriptor 'inner_blks' to
     // return B's inner block for non-default cases.
     switch (matrix_b_tag) {
+        case aCB16b64c:
+        case aCB16b64c2b:
+        case aCB16b64c4b:
         case BA16a64b4a:
         case BA16a64b2a:
         case BA16a64b: return 64;
+        case aCB16b48c:
+        case aCB16b48c2b:
+        case aCB16b48c4b:
         case BA16a48b:
         case BA16a48b2a:
         case BA16a48b4a: return 48;
+        case aCB16b32c:
+        case aCB16b32c2b:
+        case aCB16b32c4b:
         case BA16a32b:
         case BA16a32b2a:
         case BA16a32b4a: return 32;
+        case aCB16b16c:
+        case aCB16b16c2b:
+        case aCB16b16c4b:
         case BA16a16b:
         case BA16a16b2a:
         case BA16a16b4a: return 16;
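Each new 3D case (aCB16b…) reports the same default N block as its 2D counterpart (BA16a…): the block size is the number in front of the c (respectively b) dimension. A hypothetical sanity check, not part of the patch:

```cpp
#include <cassert>

// Hypothetical check: a 3D tag and its 2D counterpart must agree on the
// default N block (both tags below encode a 64-wide N block).
void check_default_n_block() {
    assert(get_default_n_block(format_tag::aCB16b64c4b) == 64);
    assert(get_default_n_block(format_tag::BA16a64b4a) == 64);
}
```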
@@ -242,14 +254,17 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_tags(memory_desc_t &A_md,
 status_t brgemm_matmul_conf_utils_t::set_B_flags(memory_desc_t &B_md) const {
 
     memory_desc_t want_B_md = B_md;
+    // Set bits for all dimensions except k dimension
+    const int compensation_mask
+            = ((1 << bgmmc.ndims) - 1 - (1 << (bgmmc.ndims - 2)));
     if (bgmmc.s8s8_compensation_required && bgmmc.blocked_B) {
         want_B_md.extra.flags |= memory_extra_flags::compensation_conv_s8s8;
-        want_B_md.extra.compensation_mask = (1 << 1);
+        want_B_md.extra.compensation_mask = compensation_mask;
     }
     if (bgmmc.src_zp_type != brgemm_broadcast_t::none && bgmmc.blocked_B) {
         want_B_md.extra.flags
                 |= memory_extra_flags::compensation_conv_asymmetric_src;
-        want_B_md.extra.asymm_compensation_mask = (1 << 1);
+        want_B_md.extra.asymm_compensation_mask = compensation_mask;
     }
 
     if (B_any_layout) {
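The computed mask generalizes the old hard-coded (1 << 1): B's dims are {K, N} in 2D and {batch, K, N} in 3D, so K, the reduced dimension, is always dim ndims - 2. The worked bit arithmetic, as a sketch:

```cpp
// ndims == 2: ((1 << 2) - 1) - (1 << 0) = 0b11 - 0b01 = 0b10
//             -> compensation varies along N only, same as (1 << 1).
// ndims == 3: ((1 << 3) - 1) - (1 << 1) = 0b111 - 0b010 = 0b101
//             -> compensation varies along batch and N; K is summed away.
static_assert(((1 << 2) - 1) - (1 << 0) == (1 << 1), "2D mask unchanged");
static_assert(((1 << 3) - 1) - (1 << 1) == 0b101, "3D mask: batch and N");
```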
@@ -262,27 +277,29 @@ status_t brgemm_matmul_conf_utils_t::set_B_flags(memory_desc_t &B_md) const {
 
 format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
         int n_blk) const {
-    if (bgmmc.ndims > 2) return format_tag::undef;
-
+    if (bgmmc.ndims > 3) return format_tag::undef;
     if (this->is_int8()) switch (n_blk) {
-            case 64: return BA16a64b4a;
-            case 48: return BA16a48b4a;
-            case 32: return BA16a32b4a;
-            case 16: return BA16a16b4a;
+            case 64: return bgmmc.ndims == 3 ? aCB16b64c4b : BA16a64b4a;
+            case 48: return bgmmc.ndims == 3 ? aCB16b48c4b : BA16a48b4a;
+            case 32: return bgmmc.ndims == 3 ? aCB16b32c4b : BA16a32b4a;
+            case 16: return bgmmc.ndims == 3 ? aCB16b16c4b : BA16a16b4a;
             default: return format_tag::undef;
         }
 
     if (this->is_bf16()) switch (n_blk) {
-            case 64: return BA16a64b2a;
-            case 48: return BA16a48b2a;
-            case 32: return BA16a32b2a;
-            case 16: return BA16a16b2a;
+            case 64: return bgmmc.ndims == 3 ? aCB16b64c2b : BA16a64b2a;
+            case 48: return bgmmc.ndims == 3 ? aCB16b48c2b : BA16a48b2a;
+            case 32: return bgmmc.ndims == 3 ? aCB16b32c2b : BA16a32b2a;
+            case 16: return bgmmc.ndims == 3 ? aCB16b16c2b : BA16a16b2a;
             default: return format_tag::undef;
         }
 
     // Note: bf32 assumes f32 blocking
     if (this->is_f32() || this->is_bf32()) switch (n_blk) {
-            case 64: return BA16a64b;
-            case 48: return BA16a48b;
-            case 32: return BA16a32b;
-            case 16: return BA16a16b;
+            case 64: return bgmmc.ndims == 3 ? aCB16b64c : BA16a64b;
+            case 48: return bgmmc.ndims == 3 ? aCB16b48c : BA16a48b;
+            case 32: return bgmmc.ndims == 3 ? aCB16b32c : BA16a32b;
+            case 16: return bgmmc.ndims == 3 ? aCB16b16c : BA16a16b;
             default: return format_tag::undef;
         }
     return format_tag::undef;
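Decoding the new tag names (a nomenclature recap, not new behavior): in aCB16b64c4b the logical dims of B are a = batch, b = K, c = N. The batch dim stays outermost and unblocked, the uppercase CB makes the K/N plane N-block-major, 16b and 64c give the 16 (K) by 64 (N) tile shape, and the trailing 4b (2b for bf16) packs four (two) consecutive K elements contiguously for VNNI-style int8 (bf16) kernels. Each 3D tag is thus the per-batch replica of the matching 2D BA16a*b* layout, which is exactly what the ndims == 3 ternaries above select.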
16 changes: 16 additions & 0 deletions tests/benchdnn/inputs/matmul/harness_matmul_data_tags
@@ -13,6 +13,7 @@
 --attr-fpmath=,bf16
 --wtag=BA16a64b,BA16a48b,BA16a32b,BA16a16b
 --batch=shapes_2d
+--attr-fpmath=
 
 --cfg=bf16bf16bf16
 --wtag=BA16a64b2a,BA16a48b2a,BA16a32b2a,BA16a16b2a
@@ -21,3 +22,18 @@
 --cfg=u8s8f32
 --wtag=BA16a64b4a,BA16a48b4a,BA16a32b4a,BA16a16b4a
 --batch=shapes_2d
+
+--stag=abc --dtag=abc
+--cfg=f32
+--attr-fpmath=,bf16
+--wtag=aCB16b16c,aCB16b32c,aCB16b48c,aCB16b64c
+--batch=shapes_3d
+--attr-fpmath=
+
+--cfg=bf16bf16bf16
+--wtag=aCB16b16c2b,aCB16b32c2b,aCB16b48c2b,aCB16b64c2b
+--batch=shapes_3d
+
+--cfg=u8s8f32
+--wtag=aCB16b16c4b,aCB16b32c4b,aCB16b48c4b,aCB16b64c4b
+--batch=shapes_3d
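Usage note: like the 2D entries above them, the new 3D cases run through benchdnn's matmul driver, e.g. `./benchdnn --matmul --batch=inputs/matmul/harness_matmul_data_tags` (the exact batch-file path depends on the build layout); shapes_3d is the stock batched-shape set that ships alongside shapes_2d.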
