[v1.x][Feature] Add flag for disabling oneDNN BRGEMM implementation of FC #20450
Hey @bgawrych , Thanks for submitting the PR
CI supported jobs: [edge, sanity, centos-cpu, centos-gpu, windows-gpu, clang, miscellaneous, unix-gpu, unix-cpu, website, windows-cpu]
@mxnet-bot run ci [unix-cpu]
Jenkins CI successfully triggered : [unix-cpu]
@@ -312,7 +312,8 @@ inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray &arr, int dtype
for (size_t i = 0; i < dims.size(); i++) dims[i] = arr.shape()[i];
auto format = mkldnn::memory::format_tag::any;
// for batch 256 alexnet benchmark test
if (dims.size() == 2) {
const bool brgemm_disabled = dmlc::GetEnv("MXNET_DISABLE_ONEDNN_BRGEMM_FC", true);
It would be good to add a description to docs/static_site/src/pages/api/faq/env_var.md.
Also, I am not sure whether the name has to include ONEDNN on the 1.x branch, so maybe MXNET_MKLDNN_DISABLE_BRGEMM_FC?
@szha Can you help with the merge?
LGTM
Looks good to me. Ideally we don't want to introduce more environment variables, but instead make the decision automatically on which implementation to use based on input size.
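To illustrate what choosing the implementation automatically from the input size could look like, here is a minimal sketch. Everything in it is hypothetical: the `FCShape` struct, the `ShouldUseBrgemm` helper, and the threshold value are placeholders for illustration and are not taken from MXNet or oneDNN. A real cutoff would have to come from benchmarking.

```cpp
#include <cstddef>

// Hypothetical description of a FullyConnected problem:
// an (m x k) input multiplied by a weight matrix with n output features.
struct FCShape {
  std::size_t m;  // batch size
  std::size_t k;  // input features
  std::size_t n;  // output features
};

// Illustrative threshold only (an assumption, not a measured value).
constexpr std::size_t kMinWorkForBrgemm = std::size_t{1} << 20;

// Returns true when the problem is "large enough" that the BRGEMM path
// might pay off, using the m*k*n product as a rough measure of work.
inline bool ShouldUseBrgemm(const FCShape& shape) {
  return shape.m * shape.k * shape.n >= kMinWorkForBrgemm;
}
```

With such a helper, the environment variable would only be needed as an override, while the default path would be chosen per layer from its shape.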
@mxnet-bot run ci [unix-gpu]
Unauthorized access detected.
@mxnet-bot run ci [unix-gpu]
Jenkins CI successfully triggered : [unix-gpu]
@@ -312,7 +312,8 @@ inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray &arr, int dtype
for (size_t i = 0; i < dims.size(); i++) dims[i] = arr.shape()[i];
auto format = mkldnn::memory::format_tag::any;
// for batch 256 alexnet benchmark test
if (dims.size() == 2) {
const bool brgemm_disabled = dmlc::GetEnv("MXNET_MKLDNN_DISABLE_BRGEMM_FC", true);
if (dims.size() == 2 && brgemm_disabled) {
Could you please provide more benchmarking data with different m/n/k and formats?
BTW, brgemm_disabled actually looks misleading to me. According to the code change, I would rather call the flag force_plain_format or force_ab_format.
…f FC (apache#20450) * Add flag for disabling oneDNN BRGEMM implementation of FC * Review fixes * Update env_var.md
* [v1.x][Feature] Add flag for disabling oneDNN BRGEMM implementation of FC (#20450) * Add flag for disabling oneDNN BRGEMM implementation of FC * Review fixes * Update env_var.md * [v1.x] Enabling BRGEMM FullyConnected based on shapes (#20533) * Enable brgemm based on input info * fix sanity * Review fixes * Change function name * Fix typo * Align variable assignments * Fix review * use const reference * Update flag name
Description
The new oneDNN version includes BRGEMM kernels for FullyConnected, which require a special memory format for the weights.
This PR lets the user decide, via a flag, whether the BRGEMM implementation should be used. It can significantly speed up FC execution for large tensors (a 42% speedup was measured on BERT with batch size 64). The feature is disabled by default because it is less efficient on small tensors.
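The flag check described above can be sketched in isolation as follows. This is a simplified illustration, not the MXNet code: `GetBoolEnv` is a hypothetical stand-in for `dmlc::GetEnv`, the `WeightFormat` enum stands in for `mkldnn::memory::format_tag`, and the sketch assumes (as the reviewer's suggested name `force_ab_format` implies) that the guarded branch forces the plain `ab` layout while `any` lets oneDNN pick the layout itself.

```cpp
#include <cstddef>
#include <cstdlib>
#include <string>

// Hypothetical stand-in for dmlc::GetEnv with a bool default:
// returns default_value when the variable is unset or empty.
inline bool GetBoolEnv(const char* name, bool default_value) {
  const char* raw = std::getenv(name);
  if (raw == nullptr || *raw == '\0') return default_value;
  const std::string value(raw);
  return value != "0" && value != "false";
}

// Stand-in for the two weight layouts relevant here (assumption).
enum class WeightFormat { kAny, kPlainAB };

// Mirrors the condition in the diff: 2-D weights are forced to the plain
// layout only while BRGEMM is disabled; otherwise the layout is left open.
inline WeightFormat ChooseFCWeightFormat(std::size_t ndim, bool brgemm_disabled) {
  return (ndim == 2 && brgemm_disabled) ? WeightFormat::kPlainAB
                                        : WeightFormat::kAny;
}

// Reads the flag from the environment; defaults to true (BRGEMM disabled),
// matching the default passed to GetEnv in the diff.
inline WeightFormat ChooseFCWeightFormatFromEnv(std::size_t ndim) {
  const bool brgemm_disabled =
      GetBoolEnv("MXNET_MKLDNN_DISABLE_BRGEMM_FC", true);
  return ChooseFCWeightFormat(ndim, brgemm_disabled);
}
```

Under these assumptions, leaving MXNET_MKLDNN_DISABLE_BRGEMM_FC unset keeps the previous behavior, and setting it to 0 before starting the process would presumably leave the weight format open so that the BRGEMM kernels can be selected.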
Checklist
Essentials