From a4c4fa011e12d77e9b222079095fbab3db4493db Mon Sep 17 00:00:00 2001 From: bgawrych Date: Fri, 16 Jul 2021 23:30:24 +0200 Subject: [PATCH] [v1.x][Feature] Add flag for disabling oneDNN BRGEMM implementation of FC (#20450) * Add flag for disabling oneDNN BRGEMM implementation of FC * Review fixes * Update env_var.md --- docs/static_site/src/pages/api/faq/env_var.md | 4 ++++ src/operator/nn/mkldnn/mkldnn_base-inl.h | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/static_site/src/pages/api/faq/env_var.md b/docs/static_site/src/pages/api/faq/env_var.md index 831f7ee3e043..6336b3da955d 100644 --- a/docs/static_site/src/pages/api/faq/env_var.md +++ b/docs/static_site/src/pages/api/faq/env_var.md @@ -298,6 +298,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`. - Values: Int ```(default=-1)``` - Flag to set num of elements that MKLDNN cache can hold. Default is -1 which means cache size is unbounded. Should only be set if your model has variable input shapes, as cache size may grow unbounded. The number represents the number of items in the cache and is proportional to the number of layers that use MKLDNN and different input shape. +* MXNET_MKLDNN_DISABLE_BRGEMM_FC + - Values: 0, 1 ```(default=1)``` + - Flag which disables BRGEMM kernels in FullyConnected executed with MKLDNN support - Should only be set to 0 if your model has constant input shapes or FullyConnected is calculated with large tensors. Supported on machines with AVX512-VNNI. + * MXNET_ENFORCE_DETERMINISM - Values: 0(false) or 1(true) ```(default=0)``` - If set to true, MXNet will only use deterministic algorithms in forward and backward computation. diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 48c7445ce642..5ac053e0b6a6 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -313,7 +313,8 @@ inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr, int dtype dims[i] = arr.shape()[i]; auto format = mkldnn::memory::format_tag::any; // for batch 256 alexnet benchmark test - if (dims.size() == 2) { + const bool brgemm_disabled = dmlc::GetEnv("MXNET_MKLDNN_DISABLE_BRGEMM_FC", true); + if (dims.size() == 2 && brgemm_disabled) { format = mkldnn::memory::format_tag::ab; }