diff --git a/docs/static_site/src/pages/api/faq/env_var.md b/docs/static_site/src/pages/api/faq/env_var.md index d5234d49a82c..a4b4915c84c6 100644 --- a/docs/static_site/src/pages/api/faq/env_var.md +++ b/docs/static_site/src/pages/api/faq/env_var.md @@ -333,6 +333,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`. - Values: Int ```(default=-1)``` - Flag to set num of elements that ONEDNN cache can hold. Default is -1 which means cache size is unbounded. Should only be set if your model has variable input shapes, as cache size may grow unbounded. The number represents the number of items in the cache and is proportional to the number of layers that use ONEDNN and different input shape. +* MXNET_ONEDNN_FORCE_FC_AB_FORMAT + - Values: 0, 1 ```(default=0)``` + - If set to true, FullyConnected will use only the AB format for weights; as a result, MXNet will not use the BRGEMM implementation of FC on machines with AVX512-VNNI support, which requires a special weights format. + * MXNET_ENFORCE_DETERMINISM - Values: 0(false) or 1(true) ```(default=0)``` - If set to true, MXNet will only use deterministic algorithms in forward and backward computation. 
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 2ee0793d3db2..2af0b5f59212 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -306,7 +306,16 @@ inline static mkldnn::memory::desc GetMemDesc(const NDArray& arr, int dtype = -1 return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any}; } -inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr, int dtype = -1) { +inline static bool ChooseBRGEMMImpl(const mkldnn::memory::dims& weight_dims, size_t batch_size) { + // Conditions based on measurement results done on CLX8280 + // https://github.com/apache/incubator-mxnet/pull/20533 + return weight_dims[0] >= 1024 && weight_dims[1] >= 1024 && batch_size >= 16384 && + weight_dims[0] % 64 == 0 && weight_dims[1] % 64 == 0; +} + +inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr, + size_t batch_size, + int dtype = -1) { int ndim = arr.shape().ndim(); mkldnn::memory::dims dims(ndim); dtype = (dtype == -1) ? 
arr.dtype() : dtype; @@ -314,8 +323,11 @@ inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr, int dtype dims[i] = arr.shape()[i]; auto format = mkldnn::memory::format_tag::any; // for batch 256 alexnet benchmark test + const bool force_fc_ab_format = dmlc::GetEnv("MXNET_ONEDNN_FORCE_FC_AB_FORMAT", false); if (dims.size() == 2) { - format = mkldnn::memory::format_tag::ab; + if (force_fc_ab_format || !ChooseBRGEMMImpl(dims, batch_size)) { + format = mkldnn::memory::format_tag::ab; + } } return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), format}; diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc index a215d280cf17..4bd0b94e5b26 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc @@ -38,10 +38,11 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(const MKLDNNFCFullPar const NDArray& weight, const NDArray* bias, const mkldnn::memory::desc& out_md) { - auto data_md = GetMemDesc(data); - auto weight_md = full_param.mkldnn_param.quantized ? GetFCWeightDesc(weight, mshadow::kInt8) - : GetFCWeightDesc(weight); auto engine = CpuEngine::Get()->get_engine(); + auto data_md = GetMemDesc(data); + auto weight_md = full_param.mkldnn_param.quantized + ? GetFCWeightDesc(weight, data.shape()[0], mshadow::kInt8) + : GetFCWeightDesc(weight, data.shape()[0]); auto propagation = is_train ? 
mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring; @@ -92,7 +93,7 @@ inline static mkldnn::inner_product_backward_data::primitive_desc GetFCBwdData( const NDArray& output, mkldnn::inner_product_forward::primitive_desc fwd_pd) { auto data_md = GetMemDesc(data); - auto weight_md = GetFCWeightDesc(weight); + auto weight_md = GetFCWeightDesc(weight, data.shape()[0]); auto out_md = GetMemDesc(output); auto engine = CpuEngine::Get()->get_engine(); mkldnn::inner_product_backward_data::desc desc(data_md, weight_md, out_md); @@ -106,7 +107,7 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWei const NDArray& output, mkldnn::inner_product_forward::primitive_desc fwd_pd) { auto data_md = GetMemDesc(data); - auto weight_md = GetFCWeightDesc(weight); + auto weight_md = GetFCWeightDesc(weight, data.shape()[0]); auto out_md = GetMemDesc(output); auto engine = CpuEngine::Get()->get_engine(); if (bias) {