Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] Enabling BRGEMM FullyConnected based on shapes #20533

Merged
merged 8 commits into from
Sep 2, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/static_site/src/pages/api/faq/env_var.md
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,9 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
- Values: Int ```(default=-1)```
- Flag to set the number of elements that the MKLDNN cache can hold. Default is -1, which means the cache size is unbounded. Should only be set if your model has variable input shapes, as the cache size may grow unbounded. The number represents the number of items in the cache and is proportional to the number of layers that use MKLDNN and the number of different input shapes.

* MXNET_MKLDNN_DISABLE_BRGEMM_FC
- Values: 0, 1 ```(default=1)```
- Flag which disables BRGEMM kernels in FullyConnected executed with MKLDNN support - Should only be set to 0 if your model has constant input shapes or FullyConnected is calculated with large tensors. Supported on machines with AVX512-VNNI.
* MXNET_MKLDNN_FORCE_FC_AB_FORMAT
- Values: 0, 1 ```(default=0)```
- If set to 1, FullyConnected will use only the AB format for weights, so MXNet won't use the BRGEMM implementation of FC on machines with AVX512-VNNI support.
bgawrych marked this conversation as resolved.
Show resolved Hide resolved

* MXNET_ENFORCE_DETERMINISM
- Values: 0(false) or 1(true) ```(default=0)```
Expand Down
19 changes: 15 additions & 4 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,17 +305,28 @@ inline static mkldnn::memory::desc GetMemDesc(const NDArray& arr, int dtype = -1
return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any};
}

inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr, int dtype = -1) {
inline static bool ChooseBRGEMMImpl(mkldnn::memory::dims weight_dims, size_t batch_size) {
bgawrych marked this conversation as resolved.
Show resolved Hide resolved
// Conditions based on measurment results done on CLX8280
bgawrych marked this conversation as resolved.
Show resolved Hide resolved
// https://github.com/apache/incubator-mxnet/pull/20533
return weight_dims[0] % 64 == 0 && weight_dims[1] % 64 == 0 && weight_dims[0] >= 1024 &&
bgawrych marked this conversation as resolved.
Show resolved Hide resolved
bgawrych marked this conversation as resolved.
Show resolved Hide resolved
weight_dims[1] >= 1024 && batch_size >= 2 << 13;
bgawrych marked this conversation as resolved.
Show resolved Hide resolved
}

inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr,
size_t batch_size,
int dtype = -1) {
bgawrych marked this conversation as resolved.
Show resolved Hide resolved
int ndim = arr.shape().ndim();
mkldnn::memory::dims dims(ndim);
dtype = (dtype == -1) ? arr.dtype() : dtype;
for (size_t i = 0; i < dims.size(); i++)
dims[i] = arr.shape()[i];
auto format = mkldnn::memory::format_tag::any;
// for batch 256 alexnet benchmark test
const bool brgemm_disabled = dmlc::GetEnv("MXNET_MKLDNN_DISABLE_BRGEMM_FC", true);
if (dims.size() == 2 && brgemm_disabled) {
format = mkldnn::memory::format_tag::ab;
const bool force_fc_ab_format = dmlc::GetEnv("MXNET_MKLDNN_FORCE_FC_AB_FORMAT", false);
bgawrych marked this conversation as resolved.
Show resolved Hide resolved
if (dims.size() == 2) {
if (force_fc_ab_format || !ChooseBRGEMMImpl(dims, batch_size)) {
format = mkldnn::memory::format_tag::ab;
}
}

return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), format};
Expand Down
11 changes: 6 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(const MKLDNNFCFullPar
const NDArray* bias,
const mkldnn::memory::desc& out_md) {
auto data_md = GetMemDesc(data);
auto weight_md = full_param.mkldnn_param.quantized ? GetFCWeightDesc(weight, mshadow::kInt8)
: GetFCWeightDesc(weight);
auto engine = CpuEngine::Get()->get_engine();
auto weight_md = full_param.mkldnn_param.quantized
? GetFCWeightDesc(weight, data.shape()[0], mshadow::kInt8)
: GetFCWeightDesc(weight, data.shape()[0]);
auto engine = CpuEngine::Get()->get_engine();
bgawrych marked this conversation as resolved.
Show resolved Hide resolved
auto propagation =
is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;

Expand Down Expand Up @@ -107,7 +108,7 @@ inline static mkldnn::inner_product_backward_data::primitive_desc GetFCBwdData(
const NDArray& output,
mkldnn::inner_product_forward::primitive_desc fwd_pd) {
auto data_md = GetMemDesc(data);
auto weight_md = GetFCWeightDesc(weight);
auto weight_md = GetFCWeightDesc(weight, data.shape()[0]);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
mkldnn::inner_product_backward_data::desc desc(data_md, weight_md, out_md);
Expand All @@ -121,7 +122,7 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWei
const NDArray& output,
mkldnn::inner_product_forward::primitive_desc fwd_pd) {
auto data_md = GetMemDesc(data);
auto weight_md = GetFCWeightDesc(weight);
auto weight_md = GetFCWeightDesc(weight, data.shape()[0]);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
if (bias) {
Expand Down