This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Backport] Enabling BRGEMM FullyConnected based on shapes (#20568)
* [v1.x][Feature] Add flag for disabling oneDNN BRGEMM implementation of FC (#20450)

* Add flag for disabling oneDNN BRGEMM implementation of FC

* Review fixes

* Update env_var.md

* [v1.x] Enabling BRGEMM FullyConnected based on shapes (#20533)

* Enable brgemm based on input info

* fix sanity

* Review fixes

* Change function name

* Fix typo

* Align variable assignments

* Fix review

* use const reference

* Update flag name
bgawrych authored Sep 6, 2021
1 parent 8c7d5c6 commit 1b98299
Showing 3 changed files with 24 additions and 7 deletions.
4 changes: 4 additions & 0 deletions docs/static_site/src/pages/api/faq/env_var.md
@@ -333,6 +333,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
 - Values: Int ```(default=-1)```
 - Sets the number of elements the oneDNN cache can hold. The default of -1 leaves the cache size unbounded. A limit should be set only if your model has variable input shapes, since the cache may otherwise grow without bound. The value is the number of items in the cache and is proportional to the number of layers that use oneDNN and the number of distinct input shapes they see.
 
+* MXNET_ONEDNN_FORCE_FC_AB_FORMAT
+- Values: 0, 1 ```(default=0)```
+- If set to true, FullyConnected will use only the AB format for weights, so MXNet will not use the BRGEMM implementation of FC on machines with AVX512-VNNI support, which requires a special weights format.
+
 * MXNET_ENFORCE_DETERMINISM
 - Values: 0(false) or 1(true) ```(default=0)```
 - If set to true, MXNet will only use deterministic algorithms in forward and backward computation.
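The new flag is a boolean environment variable; a minimal sketch of how such a flag is typically consumed (plain C++; `env_flag` is a hypothetical stand-in for MXNet's `dmlc::GetEnv`, which actually parses the value as an integer):

```cpp
#include <cstdlib>
#include <string>

// Hypothetical stand-in for dmlc::GetEnv(name, default_value): returns the
// default when the variable is unset; any value other than "0" counts as true.
static bool env_flag(const char* name, bool default_value) {
  const char* value = std::getenv(name);
  if (value == nullptr) return default_value;
  return std::string(value) != "0";
}
```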
16 changes: 14 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -306,16 +306,28 @@ inline static mkldnn::memory::desc GetMemDesc(const NDArray& arr, int dtype = -1
   return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any};
 }
 
-inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr, int dtype = -1) {
+inline static bool ChooseBRGEMMImpl(const mkldnn::memory::dims& weight_dims, size_t batch_size) {
+  // Conditions based on measurement results done on CLX8280
+  // https://github.com/apache/incubator-mxnet/pull/20533
+  return weight_dims[0] >= 1024 && weight_dims[1] >= 1024 && batch_size >= 16384 &&
+         weight_dims[0] % 64 == 0 && weight_dims[1] % 64 == 0;
+}
+
+inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr,
+                                                   size_t batch_size,
+                                                   int dtype = -1) {
   int ndim = arr.shape().ndim();
   mkldnn::memory::dims dims(ndim);
   dtype = (dtype == -1) ? arr.dtype() : dtype;
   for (size_t i = 0; i < dims.size(); i++)
     dims[i] = arr.shape()[i];
   auto format = mkldnn::memory::format_tag::any;
   // for batch 256 alexnet benchmark test
+  const bool force_fc_ab_format = dmlc::GetEnv("MXNET_ONEDNN_FORCE_FC_AB_FORMAT", false);
   if (dims.size() == 2) {
-    format = mkldnn::memory::format_tag::ab;
+    if (force_fc_ab_format || !ChooseBRGEMMImpl(dims, batch_size)) {
+      format = mkldnn::memory::format_tag::ab;
+    }
   }
 
   return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), format};
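In isolation, the new heuristic is a small predicate on the weight shape and batch size; a minimal standalone sketch (plain C++, no oneDNN dependency, with `dims` standing in for `mkldnn::memory::dims`):

```cpp
#include <cstddef>
#include <vector>

using dims = std::vector<long long>;  // stand-in for mkldnn::memory::dims

// Mirrors the ChooseBRGEMMImpl condition above: BRGEMM is selected only for
// large, 64-aligned weight matrices combined with a large batch (thresholds
// measured on CLX8280, see PR #20533).
static bool choose_brgemm(const dims& weight_dims, std::size_t batch_size) {
  return weight_dims[0] >= 1024 && weight_dims[1] >= 1024 && batch_size >= 16384 &&
         weight_dims[0] % 64 == 0 && weight_dims[1] % 64 == 0;
}
```

Note that all three conditions must hold at once: either weight dimension below 1024, a dimension not divisible by 64, or a batch under 16384 keeps the default (non-BRGEMM) path.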
11 changes: 6 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -38,10 +38,11 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(const MKLDNNFCFullParam
                                                            const NDArray& weight,
                                                            const NDArray* bias,
                                                            const mkldnn::memory::desc& out_md) {
-  auto data_md = GetMemDesc(data);
-  auto weight_md = full_param.mkldnn_param.quantized ? GetFCWeightDesc(weight, mshadow::kInt8)
-                                                     : GetFCWeightDesc(weight);
   auto engine = CpuEngine::Get()->get_engine();
+  auto data_md = GetMemDesc(data);
+  auto weight_md = full_param.mkldnn_param.quantized
+                       ? GetFCWeightDesc(weight, data.shape()[0], mshadow::kInt8)
+                       : GetFCWeightDesc(weight, data.shape()[0]);
   auto propagation =
       is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
 
@@ -92,7 +93,7 @@ inline static mkldnn::inner_product_backward_data::primitive_desc GetFCBwdData(
     const NDArray& output,
     mkldnn::inner_product_forward::primitive_desc fwd_pd) {
   auto data_md = GetMemDesc(data);
-  auto weight_md = GetFCWeightDesc(weight);
+  auto weight_md = GetFCWeightDesc(weight, data.shape()[0]);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
   mkldnn::inner_product_backward_data::desc desc(data_md, weight_md, out_md);
 
@@ -106,7 +107,7 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWei
     const NDArray& output,
     mkldnn::inner_product_forward::primitive_desc fwd_pd) {
   auto data_md = GetMemDesc(data);
-  auto weight_md = GetFCWeightDesc(weight);
+  auto weight_md = GetFCWeightDesc(weight, data.shape()[0]);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
   if (bias) {
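All three call sites above derive the batch size from `data.shape()[0]` and forward it to `GetFCWeightDesc`. A condensed sketch of the resulting format decision (plain C++; `choose_brgemm` restates the heuristic from `mkldnn_base-inl.h`, and the string return value is illustrative, not MXNet's actual `format_tag` type):

```cpp
#include <cstddef>
#include <string>
#include <vector>

using dims = std::vector<long long>;  // stand-in for mkldnn::memory::dims

// Restates ChooseBRGEMMImpl from the header diff above.
static bool choose_brgemm(const dims& w, std::size_t batch) {
  return w[0] >= 1024 && w[1] >= 1024 && batch >= 16384 &&
         w[0] % 64 == 0 && w[1] % 64 == 0;
}

// Illustrative reduction of GetFCWeightDesc's format choice: "any" lets
// oneDNN pick a blocked, BRGEMM-friendly layout; "ab" is plain row-major.
static std::string pick_weight_format(const dims& weight_dims,
                                      std::size_t batch_size,
                                      bool force_ab) {
  if (weight_dims.size() == 2 &&
      (force_ab || !choose_brgemm(weight_dims, batch_size)))
    return "ab";
  return "any";
}
```

So the BRGEMM path is taken only when the shapes qualify and MXNET_ONEDNN_FORCE_FC_AB_FORMAT is left at its default of 0.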
