Skip to content

Commit

Permalink
x64: brgemm bwd_w conv: use ih_block instead of ih for tr_src scratchpad
Browse files Browse the repository at this point in the history
  • Loading branch information
ankalinin committed Apr 17, 2023
1 parent f7acf98 commit 8da1083
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 14 deletions.
8 changes: 4 additions & 4 deletions src/cpu/x64/jit_brgemm_conv_bwd_w.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ status_t brgemm_convolution_bwd_weights_t::pd_t::init(engine_t *engine) {
brgattr.max_top_vpad = 0;
brgattr.max_bottom_vpad = 0;

brgattr.LDA2 = jcp_.tr_iw * jcp_.ih * jcp_.id;
brgattr.LDA2 = jcp_.tr_iw * jcp_.ih_block * jcp_.id;
brgattr.LDB2 = jcp_.tr_ow * jcp_.oc_block * jcp_.oh * jcp_.od;
brgattr.LDC2_M = jcp_.oc_block * jcp_.kd * jcp_.kh * jcp_.kw;
brgattr.LDC2_N = jcp_.nb_ic * jcp_.ic_block * jcp_.oc_block
Expand Down Expand Up @@ -464,7 +464,7 @@ struct brgemm_convolution_bwd_weights_t::thread_info_t {

size_t tr_src_off(int g, int icb, int id, int ih) const {
const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
const size_t tr_3d_size = tr_row_size * jcp.ih;
const size_t tr_3d_size = tr_row_size * jcp.ih_block;
int adj = (jcp.global_transpose) ? 1 : jcp.nb_ic_blocking;
// Aligned to buffer end to use guard elements
return tr_src_buf_number(g, icb) * adj * jcp.tr_src_buf_size
Expand Down Expand Up @@ -1024,7 +1024,7 @@ void brgemm_convolution_bwd_weights_t::compute_diff_weights_3d(
+ _pd->filter_w_to_src(kw) / jcp.stride_w
+ (kw % jcp.stride_w) * src_stride_w_shift
+ (bs_ih_s - ih_s) * jcp.tr_iw * jcp.ic_block
+ (bs_id_s - id_s) * jcp.ih * jcp.tr_iw * jcp.ic_block;
+ (bs_id_s - id_s) * jcp.ih_block * jcp.tr_iw * jcp.ic_block;
const void *ptr_B = ((diff_dst_data_t *)p_dst)
+ (bs_oh_s - oh_s) * jcp.tr_ow * jcp.oc_block
+ (bs_od_s - od_s) * jcp.oh * jcp.tr_ow * jcp.oc_block;
Expand All @@ -1045,7 +1045,7 @@ void brgemm_convolution_bwd_weights_t::compute_diff_weights_3d(
ti->brg_batch[odb * bs_h + ohb].ptr.A = (char *)ptr_A
+ ohb * jcp.typesize_in * jcp.tr_iw * jcp.ic_block
* jcp.stride_h
+ odb * jcp.typesize_in * jcp.ih * jcp.tr_iw
+ odb * jcp.typesize_in * jcp.ih_block * jcp.tr_iw
* jcp.ic_block * jcp.stride_d;
ti->brg_batch[odb * bs_h + ohb].ptr.B = (char *)ptr_B
+ ohb * jcp.typesize_in * jcp.tr_ow * jcp.oc_block
Expand Down
23 changes: 13 additions & 10 deletions src/cpu/x64/jit_brgemm_conv_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2634,16 +2634,6 @@ void balance_bwd_w(jit_brgemm_conv_conf_t &jcp) {
jcp.nthr_g = nthr_g;
jcp.nthr_oc_b = nthr_oc_b;
jcp.nthr_ic_b = nthr_ic_b;

// TODO: Optimize memory allocation when threaded on height and depth
jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih * jcp.id;
jcp.tr_diff_dst_buf_size = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od;
jcp.tr_src_buf_count = jcp.global_transpose
? jcp.nthr_mb * jcp.nb_ic * jcp.ngroups
: jcp.nthr;
jcp.tr_diff_dst_buf_count = jcp.global_transpose
? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups
: jcp.nthr;
}

status_t init_conf_bwd_w(jit_brgemm_conv_conf_t &jcp,
Expand Down Expand Up @@ -2886,6 +2876,19 @@ status_t init_conf_bwd_w(jit_brgemm_conv_conf_t &jcp,
// try to split oh by equal oh blocks
oh_block_limit = div_up(jcp.oh, div_up(jcp.oh, oh_block_limit));
jcp.oh_block = utils::saturate(1, jcp.oh, oh_block_limit);
jcp.ih_block = nstl::min(jcp.ih,
jcp.stride_h
* brg_blocking_t::get_inp_size(jcp.ih, jcp.oh_block, jcp.kh,
jcp.stride_h, jcp.dilate_h));
// TODO: Optimize memory allocation when threaded on height and depth
jcp.tr_src_buf_count = jcp.global_transpose
? jcp.nthr_mb * jcp.nb_ic * jcp.ngroups
: jcp.nthr;
jcp.tr_diff_dst_buf_count = jcp.global_transpose
? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups
: jcp.nthr;
jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih_block * jcp.id;
jcp.tr_diff_dst_buf_size = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od;

const int iframe_size = irow_size * jcp.id;
const int oframe_size = orow_size * jcp.od;
Expand Down

0 comments on commit 8da1083

Please sign in to comment.