This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MKLDNN] mkldnn RNN operator enhancement #17075

Merged (3 commits, Dec 16, 2019)
docs/static_site/src/pages/api/faq/env_var.md (12 changes: 8 additions & 4 deletions)
@@ -289,11 +289,11 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
If no such algorithm exists given other constraints, MXNet will error out. This variable affects the choice
of CUDNN convolution algorithms. Please see [CUDNN developer guide](https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html) for more details.

- * MXNET_CPU_PARALLEL_COPY_SIZE
+ * MXNET_CPU_PARALLEL_SIZE
  - Values: Int ```(default=200000)```
- - The minimum size to call parallel copy by OpenMP in CPU2CPU mode.
- - When the array size is bigger than or equal to this threshold, NDArray::Copy(from, to) is implemented by OpenMP with the Recommended OMP Thread Count.
- - When the array size is less than this threshold, NDArray::Copy(from , to)) is implemented by memcpy in single thread.
+ - The minimum size to call parallel operations by OpenMP for CPU context.
+ - When the array size is bigger than or equal to this threshold, the operation implemented by OpenMP is executed with the Recommended OMP Thread Count.
+ - When the array size is less than this threshold, the operation is implemented naively in single thread.

* MXNET_OPTIMIZER_AGGREGATION_SIZE
- Values: Int ```(default=4)```
@@ -349,6 +349,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
- Values: 0(false) or 1(true) ```(default=1)```
- If this variable is set, MXNet will simplify the computation graph, eliminating duplicated operations on the same inputs.

+ * MXNET_USE_MKLDNN_RNN
+ - Values: 0(false) or 1(true) ```(default=1)```
+ - This variable controls whether to use the MKL-DNN backend in fused RNN operator for CPU context. There are two fusion implementations of RNN operator in MXNet. The MKL-DNN implementation has a better performance than the naive one, but the latter is more stable in the backward operation currently.
Review comment (Contributor):
Do you mean the MKL-DNN fused kernel is not stable in the backward pass? Or is the MKL-DNN version not as flexible as the naive one due to some implementation limitation?

Reply (Contributor, PR author):
I think it is not stable in the backward pass. I have trained the bucketing model (https://github.com/apache/incubator-mxnet/tree/master/example/rnn/bucketing) with the MKL-DNN RNN backward kernels, and it produced a convergent training curve. But it has not been tested in other applications for training a model, so I provided an environment variable that lets users switch to the naive implementation.

Reply (Contributor):
OK, got it. I had previously thought it meant the numerical results were not stable :) A description like that can only be verified against a limited set of test cases rather than a broader one.


Settings for Minimum Memory Usage
---------------------------------
- Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1```
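For readers who want to try the switch described above, the flag is an ordinary environment variable (for example, running a training script with `MXNET_USE_MKLDNN_RNN=0` falls back to the naive fused kernels). The sketch below is not the actual MXNet dispatch code; it is only a minimal illustration of how such a toggle can be consumed, with `std::getenv` standing in for `dmlc::GetEnv` and the two `Run*` functions as hypothetical placeholders for the MKL-DNN and naive fused RNN paths.

```c++
#include <cstdlib>
#include <cstring>
#include <iostream>

// Hypothetical stand-ins for the two fused RNN code paths mentioned in the docs.
void RunMKLDNNFusedRNN() { std::cout << "MKL-DNN fused RNN kernel\n"; }
void RunNaiveFusedRNN()  { std::cout << "naive fused RNN kernel\n"; }

// Read MXNET_USE_MKLDNN_RNN once, defaulting to enabled; a simplified stand-in
// for how MXNet reads such integer flags via dmlc::GetEnv.
bool UseMKLDNNRnn() {
  static const bool use_mkldnn = [] {
    const char* v = std::getenv("MXNET_USE_MKLDNN_RNN");
    return v == nullptr || std::strcmp(v, "0") != 0;
  }();
  return use_mkldnn;
}

int main() {
  // Running with MXNET_USE_MKLDNN_RNN=0 selects the naive implementation,
  // which the review thread above describes as currently more stable in backward.
  if (UseMKLDNNRnn()) {
    RunMKLDNNFusedRNN();
  } else {
    RunNaiveFusedRNN();
  }
  return 0;
}
```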
src/common/utils.h (20 changes: 19 additions & 1 deletion)
@@ -769,7 +769,7 @@ inline void EmplaceBackZeros(const NDArrayStorageType stype, const mxnet::TShape
*/
template<typename DType>
inline void ParallelCopy(DType* dst, const DType* src, index_t size) {
- static index_t copy_block_size = dmlc::GetEnv("MXNET_CPU_PARALLEL_COPY_SIZE", 200000);
+ static index_t copy_block_size = dmlc::GetEnv("MXNET_CPU_PARALLEL_SIZE", 200000);
if (size >= copy_block_size) {
#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
for (index_t i = 0; i < size; ++i) {
@@ -780,6 +780,24 @@ inline void ParallelCopy(DType* dst, const DType* src, index_t size) {
}
}

+ /*!
+ * \brief parallelize add by OpenMP
+ */
+ template<typename DType>
+ inline void ParallelAdd(DType* dst, const DType* src, index_t size) {
+ static index_t add_block_size = dmlc::GetEnv("MXNET_CPU_PARALLEL_SIZE", 200000);
+ if (size >= add_block_size) {
+ #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+ for (index_t i = 0; i < size; ++i) {
+ dst[i] += src[i];
+ }
+ } else {
+ for (index_t i = 0; i < size; ++i) {
+ dst[i] += src[i];
+ }
+ }
+ }

/*!
* \brief If numpy compatibility is turned off (default), the shapes passed in
* by users follow the legacy shape definition:
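As a standalone illustration of the threshold dispatch used by `ParallelCopy` and the new `ParallelAdd` above, the sketch below reads the same `MXNET_CPU_PARALLEL_SIZE` variable, with `std::getenv` standing in for `dmlc::GetEnv` and a plain `#pragma omp parallel for` in place of MXNet's recommended-thread-count helper (build with `-fopenmp` to enable the parallel path):

```c++
#include <cstdlib>
#include <iostream>
#include <vector>

// Threshold below which the loop stays single-threaded; mirrors the
// dmlc::GetEnv("MXNET_CPU_PARALLEL_SIZE", 200000) call in src/common/utils.h.
inline long ParallelThreshold() {
  static const long threshold = [] {
    const char* v = std::getenv("MXNET_CPU_PARALLEL_SIZE");
    return v ? std::atol(v) : 200000L;
  }();
  return threshold;
}

// Same shape as ParallelAdd: OpenMP only when the array is large enough
// for the thread start-up overhead to pay off.
template <typename DType>
void AddSketch(DType* dst, const DType* src, long size) {
  if (size >= ParallelThreshold()) {
    #pragma omp parallel for
    for (long i = 0; i < size; ++i) dst[i] += src[i];
  } else {
    for (long i = 0; i < size; ++i) dst[i] += src[i];
  }
}

int main() {
  std::vector<float> dst(300000, 1.0f), src(300000, 2.0f);
  AddSketch(dst.data(), src.data(), static_cast<long>(dst.size()));
  std::cout << dst[0] << "\n";  // prints 3: every element was accumulated
  return 0;
}
```

The point of the threshold is that spawning OpenMP threads has a fixed cost, so small arrays are cheaper to process in a single thread.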
src/operator/nn/mkldnn/mkldnn_rnn-inl.h (35 changes: 18 additions & 17 deletions)
@@ -120,33 +120,32 @@ class RnnPrimitive {
template<typename rnn_fwd, typename... Args>
static RnnPrimitive Create(Args&&... args) {
RnnPrimitive rnn_fwd_prim;
- rnn_fwd_prim.pd_.reset(
- new typename rnn_fwd::desc(std::forward<Args>(args)...),
- [](typename rnn_fwd::desc* pd) {
- delete reinterpret_cast<typename rnn_fwd::desc*>(pd);
+ auto fwd_desc = typename rnn_fwd::desc(std::forward<Args>(args)...);
+ rnn_fwd_prim.fwd_pd_.reset(
+ new typename rnn_fwd::primitive_desc(fwd_desc, CpuEngine::Get()->get_engine()),
+ [](typename rnn_fwd::primitive_desc* pd) {
+ delete reinterpret_cast<typename rnn_fwd::primitive_desc*>(pd);
});
- const typename rnn_fwd::desc& fwd_desc =
- *(reinterpret_cast<typename rnn_fwd::desc*>(rnn_fwd_prim.pd_.get()));
- typename rnn_fwd::primitive_desc fwd_pd(fwd_desc, CpuEngine::Get()->get_engine());
- rnn_fwd_prim.weights_layer_desc_ = fwd_pd.weights_layer_desc();
- rnn_fwd_prim.weights_iter_desc_ = fwd_pd.weights_iter_desc();
- rnn_fwd_prim.workspace_desc_ = fwd_pd.workspace_desc();
+ auto fwd_pd = reinterpret_cast<typename rnn_fwd::primitive_desc*>(rnn_fwd_prim.fwd_pd_.get());
+ rnn_fwd_prim.weights_layer_desc_ = fwd_pd->weights_layer_desc();
+ rnn_fwd_prim.weights_iter_desc_ = fwd_pd->weights_iter_desc();
+ rnn_fwd_prim.workspace_desc_ = fwd_pd->workspace_desc();

- rnn_fwd_prim.primitive_ = std::shared_ptr<mkldnn::primitive>(new rnn_fwd(fwd_pd));
+ rnn_fwd_prim.primitive_ = std::shared_ptr<mkldnn::primitive>(new rnn_fwd(*fwd_pd));

return rnn_fwd_prim;
}

RnnPrimitive() {
- this->pd_ = nullptr;
+ this->fwd_pd_ = nullptr;
this->primitive_ = nullptr;
this->weights_layer_desc_ = mkldnn::memory::desc();
this->weights_iter_desc_ = mkldnn::memory::desc();
this->workspace_desc_ = mkldnn::memory::desc();
}

RnnPrimitive(const RnnPrimitive& rnn_fwd_prim) {
- this->pd_ = rnn_fwd_prim.pd_;
+ this->fwd_pd_ = rnn_fwd_prim.fwd_pd_;
this->primitive_ = rnn_fwd_prim.primitive_;
this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
this->weights_iter_desc_ = rnn_fwd_prim.weights_iter_desc_;
@@ -155,7 +154,7 @@

RnnPrimitive& operator=(const RnnPrimitive& rnn_fwd_prim) {
if (this != &rnn_fwd_prim) {
- this->pd_ = rnn_fwd_prim.pd_;
+ this->fwd_pd_ = rnn_fwd_prim.fwd_pd_;
this->primitive_ = rnn_fwd_prim.primitive_;
this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
this->weights_iter_desc_ = rnn_fwd_prim.weights_iter_desc_;
@@ -165,7 +164,7 @@
return *this;
}

- const void* GetPrimDesc() const { return pd_.get(); }
+ const void* GetPrimDesc() const { return fwd_pd_.get(); }
const mkldnn::primitive& GetPrim() const { return *primitive_; }

const mkldnn::memory::desc& GetLayerDesc() const {
@@ -181,7 +180,7 @@
}

private:
- std::shared_ptr<void> pd_;
+ std::shared_ptr<void> fwd_pd_;
std::shared_ptr<mkldnn::primitive> primitive_;
mkldnn::memory::desc weights_layer_desc_;
mkldnn::memory::desc weights_iter_desc_;
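The change above keeps the forward `primitive_desc` (rather than the op descriptor, as before) alive behind the type-erased `std::shared_ptr<void> fwd_pd_`, whose custom deleter still knows the concrete template type. A minimal self-contained sketch of that idiom, using a made-up `LstmPrimitiveDesc` in place of oneDNN's `rnn_fwd::primitive_desc`, might look like this:

```c++
#include <iostream>
#include <memory>
#include <utility>

// Made-up stand-in for a concrete, template-dependent descriptor type
// (e.g. rnn_fwd::primitive_desc in the diff above).
struct LstmPrimitiveDesc {
  int hidden = 128;
};

// Holder mimicking RnnPrimitive::fwd_pd_: the pointer is erased to void,
// but the lambda deleter still destroys the concrete type correctly.
class PrimitiveHolder {
 public:
  template <typename PD, typename... Args>
  static PrimitiveHolder Create(Args&&... args) {
    PrimitiveHolder holder;
    holder.pd_.reset(new PD(std::forward<Args>(args)...),
                     [](PD* pd) { delete pd; });
    return holder;
  }

  // Type-erased access, like RnnPrimitive::GetPrimDesc() in the diff.
  const void* Get() const { return pd_.get(); }

 private:
  std::shared_ptr<void> pd_;  // type-erased, like fwd_pd_
};

int main() {
  PrimitiveHolder holder = PrimitiveHolder::Create<LstmPrimitiveDesc>();
  // Callers that know the concrete type cast back, much as the diff does with
  // reinterpret_cast<typename rnn_fwd::primitive_desc*>(fwd_pd_.get()).
  const auto* pd = static_cast<const LstmPrimitiveDesc*>(holder.Get());
  std::cout << pd->hidden << "\n";  // prints 128
  return 0;
}
```

Erasing the type this way is presumably what lets the non-template `RnnPrimitive` class hold descriptors for different `rnn_fwd` specializations behind a single member.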
@@ -370,7 +369,9 @@ class MKLDNNRnnBackward {
void SetDataGradsMem(void* diff_src, void* diff_state, void* diff_statecell,
void* diff_out, void* diff_state_out, void* diff_statecell_out,
const int dtype = mshadow::kFloat32);
- void CommitWeightsDiff(void* diff_weights, void* diff_bias, const int dtype = mshadow::kFloat32);
+ void CommitWeightsDiff(void* diff_weights, void* diff_bias,
+ const OpReqType req,
+ const int dtype = mshadow::kFloat32);

const mkldnn::primitive& GetBwd() const { return *bwd_.primitive_; }
const mkldnn_args_map_t& GetArgsMap() const { return net_args_; }
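The extra `OpReqType req` parameter on `CommitWeightsDiff` lets the backward pass honor the gradient request: overwrite the destination buffer for `kWriteTo`/`kWriteInplace`, accumulate into it for `kAddTo` (presumably where the new `ParallelAdd` helper is used), and leave it untouched for `kNullOp`. A simplified, hypothetical sketch of that dispatch, with plain loops standing in for the `ParallelCopy`/`ParallelAdd` helpers from `src/common/utils.h`, could look like this:

```c++
#include <cstddef>
#include <iostream>
#include <vector>

// Mirrors the spirit of mxnet::OpReqType; redeclared here only so the
// sketch is self-contained.
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Plain stand-ins for the ParallelCopy / ParallelAdd helpers.
template <typename DType>
void CopySketch(DType* dst, const DType* src, std::size_t size) {
  for (std::size_t i = 0; i < size; ++i) dst[i] = src[i];
}
template <typename DType>
void AddSketch(DType* dst, const DType* src, std::size_t size) {
  for (std::size_t i = 0; i < size; ++i) dst[i] += src[i];
}

// Hypothetical commit step: move locally computed weight gradients into the
// user-visible gradient buffer according to the request type.
template <typename DType>
void CommitDiffSketch(DType* diff_weights, const DType* local_diff,
                      std::size_t size, OpReqType req) {
  switch (req) {
    case kWriteTo:
    case kWriteInplace:
      CopySketch(diff_weights, local_diff, size);  // overwrite
      break;
    case kAddTo:
      AddSketch(diff_weights, local_diff, size);   // accumulate
      break;
    case kNullOp:
      break;  // gradient not requested: leave the buffer untouched
  }
}

int main() {
  std::vector<float> grad(4, 1.0f), local(4, 0.5f);
  CommitDiffSketch(grad.data(), local.data(), grad.size(), kAddTo);
  std::cout << grad[0] << "\n";  // prints 1.5: accumulated, not overwritten
  return 0;
}
```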