Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] Backport #17702 and #17872 to v1.x branch (#18038)
Browse files Browse the repository at this point in the history
* Support projection feature for LSTM on CPU (Only Inference) (#17702)

* Support projection feature for LSTM on CPU

* test solution for -Werror=maybe-uninitialized

* Check device type when create state

* Document the projection feature of LSTM for RNN operator

* Minor fix

* Re-run CI

* Fix issue of zeros gradients w.r.t. RNN bias when num_layers > 1 (#17872)

* Fix issue of zeros gradients w.r.t. RNN bias when num_layers > 1

* Use nd.copy() to initialize parameters of new operator

* Add check for output states

* Initialize i2h/h2h_weights with zeros for rnn_relu/tanh, and reduce size

* Split fused rnn layer test into tests of individual mode

* Skip lstm and gru tests on CPU context without DNNL
  • Loading branch information
zixuanweeei authored Apr 15, 2020
1 parent 0d3aa67 commit 6fa374b
Show file tree
Hide file tree
Showing 8 changed files with 311 additions and 104 deletions.
1 change: 1 addition & 0 deletions python/mxnet/gluon/rnn/rnn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self, hidden_size, num_layers, layout,
self._lstm_state_clip_nan = lstm_state_clip_nan
self._dtype = dtype
self._use_sequence_length = use_sequence_length
self.skip_states = None

self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]

Expand Down
8 changes: 0 additions & 8 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,6 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
(ndim == 1 || ndim == 2 || ndim == 4);
}

static inline bool SupportMKLDNNRnn(const NDArray &input) {
if (input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 3
&& dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
return true;
}
return false;
}

static inline bool SupportMKLDNNQuantize(int dtype) {
return dtype == mshadow::kFloat32 || dtype == mshadow::kInt8 ||
dtype == mshadow::kUint8 || dtype == mshadow::kBfloat16;
Expand Down
16 changes: 14 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_rnn-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ struct MKLDNNRnnLayerParam {
size_t reserve_size; // used for the reserved cached memory in Backward
size_t single_w_size; // weights size of a single cell
size_t single_b_size; // bias size of a single cell from mkl-dnn
size_t naive_single_b_size; // bias size of a single cell from framework
size_t single_state_size; // state size of a single cell, hy, cy
size_t native_single_b_size; // bias size of a single cell from framework
size_t single_state_size; // state size of a single cell, hy, cy

MKLDNNRnnLayerParam(int num_layer, int batch_size, int seq_len,
int input_size, int state_size,
Expand Down Expand Up @@ -441,6 +441,18 @@ class MKLDNNRnnOp {
const std::vector<NDArray> &outputs);
};

inline bool SupportMKLDNNRnn(const int input_dtype) {
if (input_dtype == mshadow::kFloat32 && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
return true;
}
return false;
}

inline bool SupportMKLDNNRnn(const RNNParam &param, const int input_dtype) {
if (param.projection_size.has_value()) return false;
return SupportMKLDNNRnn(input_dtype);
}

} // namespace op
} // namespace mxnet

Expand Down
43 changes: 21 additions & 22 deletions src/operator/nn/mkldnn/mkldnn_rnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void MKLDNNRnnLayerParam::SetDims() {
// unidirectional size of a single cell
single_w_size = (input_size + state_size) * ngates * state_size;
single_b_size = nbias * state_size;
naive_single_b_size = ngates * state_size * 2; // naive RNN variants have double bias
native_single_b_size = ngates * state_size * 2; // native RNN variants have double bias
single_state_size = batch_size * state_size;

// Get workspace size for cached weights memory
Expand Down Expand Up @@ -265,7 +265,7 @@ RnnBwdPrimitive GetRnnBwdPrim(const MKLDNNRnnForwardTraining &fwd,
}

/*
* Naive weights layout is:
* Native weights layout is:
* | l0_l2r_wx | l0_l2r_wh | l0_r2l_wx | l0_r2l_wh |
* | l1_l2r_wx | l1_l2r_wh | l1_r2l_wx | l1_r2l_wh |
* ...
Expand Down Expand Up @@ -339,7 +339,6 @@ FUNC(MKLDNN_ARG_DIFF_##NAME, ARGS.at(MKLDNN_ARG_##NAME).get_desc(), HANDLE)
void MKLDNNRnnForward::SetNewDataMem(void* x, void* hx, void* cx,
void* y, void* hy, void* cy,
const int dtype) {
using dims = mkldnn::memory::dims;
using desc = mkldnn::memory::desc;
using format_tag = mkldnn::memory::format_tag;
auto& cpu_engine = CpuEngine::Get()->get_engine();
Expand Down Expand Up @@ -462,12 +461,12 @@ inline void EmplaceNetArgs(mkldnn_args_map_t* net_args, const int arg_name,
}

/*
* Copy naive memory to mkldnn-format memory. It will initialize the memory
* when first invoked. Then, the naive weight_layer and weight_iter are
* Copy native memory to mkldnn-format memory. It will initialize the memory
* when first invoked. Then, the native weight_layer and weight_iter are
* concatenated to xxx_xx_r memory. Per the different gates order of GRU,
* it will swap the memory blocks of gates among concatenated memory
* inplace. From then on, the xxx_xx_r memory is reordered to target
* memory with preferred format_tag. Finally, naive bias is fused to MKLDNN
* memory with preferred format_tag. Finally, native bias is fused to MKLDNN
* bias memory.
*/
void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, void *w_ptr, void *b_ptr,
Expand Down Expand Up @@ -551,13 +550,13 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, void *w_ptr, void *b_

// Process bias
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
DType* naive_b_ptr = static_cast<DType*>(b_ptr);
DType* native_b_ptr = static_cast<DType*>(b_ptr);
DType* fused_bias = static_cast<DType*>(bias_->get_data_handle());
for (int lyr = 0; lyr < param_.num_layer; ++lyr) {
for (int d = 0; d < param_.bidirectional + 1; ++d) {
FuseBias<DType>(fused_bias, naive_b_ptr, param_.mode, param_.state_size);
FuseBias<DType>(fused_bias, native_b_ptr, param_.mode, param_.state_size);
fused_bias += param_.single_b_size;
naive_b_ptr += param_.naive_single_b_size;
native_b_ptr += param_.native_single_b_size;
}
}
});
Expand Down Expand Up @@ -632,7 +631,6 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
using memory = mkldnn::memory;
using format_tag = mkldnn::memory::format_tag;

// In the `autograd.record()` context, RNNOp is required to run into
Expand Down Expand Up @@ -674,10 +672,10 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
default_param.bidirectional + 1, default_param.mode)) * dtype_bytes;
for (auto& fwd_layer : fwd_inf_vec_) {
size_t single_w_bytes = fwd_layer.GetParam().single_w_size * dtype_bytes;
size_t single_b_bytes = fwd_layer.GetParam().naive_single_b_size * dtype_bytes;
size_t single_b_bytes = fwd_layer.GetParam().native_single_b_size * dtype_bytes;
size_t directions = fwd_layer.GetParam().bidirectional ? 2 : 1;
size_t layer_weights_bytes = single_w_bytes * directions;
size_t layer_bias_bytes = single_b_bytes * directions; // Naive MXNet has double bias
size_t layer_bias_bytes = single_b_bytes * directions; // Native MXNet has double bias

if (!fwd_layer.IsInitialized() || is_training)
fwd_layer.SetWeightsMem(&(this->mgr_), weights_ptr, bias_ptr, is_training, dtype);
Expand Down Expand Up @@ -857,7 +855,7 @@ void MKLDNNRnnBackward::CommitWeightsGrads(void* diff_weights, void* diff_bias,
const size_t wx_size = param.input_size * param.state_size * ngates;
const size_t wh_size = param.state_size * param.state_size * ngates;

/* naive weights layout is:
/* native weights layout is:
1st-layer: | wx_lr | wh_lr | wx_rl | wh_rl |
2st-layer: | wx_lr | wh_lr | wx_rl | wh_rl |
size: | wxh_bytes |
Expand Down Expand Up @@ -903,33 +901,33 @@ void MKLDNNRnnBackward::CommitWeightsGrads(void* diff_weights, void* diff_bias,
});

const size_t bias_size = param.single_b_size;
const size_t naive_bias_size = param.naive_single_b_size;
const size_t native_bias_size = param.native_single_b_size;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
DType* native_bias = static_cast<DType *>(diff_bias);
DType* diff_bias_ptr = static_cast<DType *>(this->diff_bias_->get_data_handle());
OPREQTYPE_SWITCH(req, DType, FAccGrad, {
if (param.mode != rnn_enum::kGru) {
for (int shift = 0; shift < num_layer * direction; ++shift) {
FAccGrad(native_bias + shift * naive_bias_size,
FAccGrad(native_bias + shift * native_bias_size,
diff_bias_ptr + shift * bias_size, bias_size);
FAccGrad(native_bias + shift * naive_bias_size + bias_size,
FAccGrad(native_bias + shift * native_bias_size + bias_size,
diff_bias_ptr + shift * bias_size, bias_size);
}
} else {
const size_t bias_size_per_gate = param.state_size;
for (int shift = 0; shift < num_layer * direction; ++shift) {
DType* native_reset = native_bias + shift * naive_bias_size;
DType* native_reset = native_bias + shift * native_bias_size;
DType* native_update = native_reset + bias_size_per_gate;
DType* update = diff_bias_ptr + shift * bias_size;
DType* reset = update + bias_size_per_gate;

FAccGrad(native_update, update, bias_size_per_gate);
FAccGrad(native_reset, reset, bias_size_per_gate);
FAccGrad(native_update + naive_bias_size / 2, update, bias_size_per_gate);
FAccGrad(native_reset + naive_bias_size / 2, reset, bias_size_per_gate);
FAccGrad(native_update + native_bias_size / 2, update, bias_size_per_gate);
FAccGrad(native_reset + native_bias_size / 2, reset, bias_size_per_gate);

DType* native_new_bx = native_update + bias_size_per_gate;
DType* native_new_bh = native_new_bx + naive_bias_size / 2;
DType* native_new_bh = native_new_bx + native_bias_size / 2;
DType* new_bx = reset + bias_size_per_gate;
DType* new_bh = new_bx + bias_size_per_gate;
FAccGrad(native_new_bx, new_bx, bias_size_per_gate);
Expand Down Expand Up @@ -1186,10 +1184,11 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx,

// Commit weights diff
if (req[rnn_enum::kParams] != kNullOp) {
const int directions = default_param.bidirectional ? 2 : 1;
for (size_t lyr = 0; lyr < bwd_vec_.size(); ++lyr) {
bwd_vec_.at(lyr).CommitWeightsGrads(dw, db, req[rnn_enum::kParams], w_dtype);
dw += full_param_.layer_params.at(lyr).single_w_size * w_bytes;
db += full_param_.layer_params.at(lyr).single_b_size * w_bytes;
dw += full_param_.layer_params.at(lyr).single_w_size * directions * w_bytes;
db += full_param_.layer_params.at(lyr).native_single_b_size * directions * w_bytes;
}
}
}
Expand Down
33 changes: 24 additions & 9 deletions src/operator/rnn-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ inline int GetRnnBiasSize(int num_layer,
inline size_t GetRNNWorkspaceSize(int seq_length,
int batch_size,
int hidden_size,
int projection_size,
int direction,
int mode) {
size_t size = 0;
Expand Down Expand Up @@ -324,6 +325,7 @@ void RNNForwardInference(DType* ws,
const int batch_size,
const int input_size,
const int state_size,
const int projection_size,
DType* x_ptr,
DType* hx_ptr,
DType* cx_ptr,
Expand All @@ -336,8 +338,8 @@ void RNNForwardInference(DType* ws,
switch (mode) {
case rnn_enum::kLstm:
LstmForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
batch_size, input_size, state_size, projection_size,
x_ptr, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
break;
case rnn_enum::kGru:
GruForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
Expand Down Expand Up @@ -511,10 +513,7 @@ class RNNOp {
this->temp_init_space_ = false;
this->reserve_cpu_space_size_ = 0;
this->temp_cpu_space_size_ = 0;
if (param_.projection_size.has_value()) {
LOG(FATAL) <<
"hidden layer projection is only supported for GPU with CuDNN later than 7.1.1";
}

if (param_.lstm_state_clip_min.has_value()
|| param_.lstm_state_clip_max.has_value()) {
LOG(FATAL) << "LSTM state clipping is only supported for GPU with CuDNN later than 7.2.1";
Expand Down Expand Up @@ -843,9 +842,14 @@ class RNNOp {
#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)

if (ctx_.dev_type == kCPU) {
int projection_size = 0;
if (param_.projection_size.has_value()) {
projection_size = param_.projection_size.value();
}

// allocate temp space
const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
param_.state_size, direction, param_.mode);
param_.state_size, projection_size, direction, param_.mode);
if (!temp_init_space_ || temp_cpu_space_size_ < work_cpu_space_size) {
temp_cpu_space_size_ = work_cpu_space_size;
temp_cpu_space_ = NDArray(TShape({static_cast<dim_t>(temp_cpu_space_size_)}), ctx_,
Expand All @@ -856,6 +860,9 @@ class RNNOp {

if (ctx.is_train || ctx.need_grad) {
// allocate reserve space
if (param_.projection_size.has_value()) {
LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
}

const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
param_.seq_length_, param_.batch_size_,
Expand Down Expand Up @@ -896,6 +903,7 @@ class RNNOp {
param_.batch_size_,
param_.input_size_,
param_.state_size,
projection_size,
x.dptr_,
hx.dptr_,
cx_ptr,
Expand Down Expand Up @@ -1096,10 +1104,17 @@ class RNNOp {
#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)

if (ctx_.dev_type == kCPU) {
int projection_size = 0;
if (param_.projection_size.has_value()) {
// TODO(zixuanweeei): Add training support for LSTM with projection on CPU.
// projection_size = param_.projection_size.value();
LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
}

// allocate temp space
const size_t work_cpu_space_size =
GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
param_.state_size, direction, param_.mode);
GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, param_.state_size,
projection_size, direction, param_.mode);
if (!temp_init_space_ || temp_cpu_space_size_ != work_cpu_space_size) {
LOG(FATAL) << "Check temp init error";
}
Expand Down
44 changes: 29 additions & 15 deletions src/operator/rnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,20 +190,19 @@ static std::vector<ResourceRequest> RNNResourceEx(const NodeAttrs& attrs, const
return request;
}

#if MXNET_USE_MKLDNN == 1
inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
DispatchMode wanted_mode = DispatchMode::kFCompute;

#if MXNET_USE_MKLDNN == 1
wanted_mode = DispatchMode::kFComputeEx;
#endif // MXNET_USE_MKLDNN == 1

return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, wanted_mode);
const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
const bool support_mkldnn_rnn =
!param.projection_size.has_value() && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1);
return MKLDNNStorageType(attrs, dev_mask, support_mkldnn_rnn,
dispatch_mode, in_attrs, out_attrs);
}
#endif // MXNET_USE_MKLDNN == 1

struct RNNGrad {
const char *op_name;
Expand Down Expand Up @@ -246,9 +245,7 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs,
}

#if MXNET_USE_MKLDNN == 1
if ((in_types[0] == mshadow::kFloat32 || in_types[0] == mshadow::kFloat16)
&& in_shapes[0].ndim() == 3 && ctx.dev_type == kCPU
&& dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
if (ctx.dev_type == kCPU && SupportMKLDNNRnn(param, in_types[rnn_enum::kData])) {
const mxnet::TShape& data_shape = in_shapes[rnn_enum::kData];
state = OpStatePtr::Create<MKLDNNRnnOp>(param, data_shape[0],
data_shape[1], data_shape[2]);
Expand All @@ -274,7 +271,7 @@ static void RNNStatefulComputeExCPU(const OpStatePtr& state_ptr,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNRnn(inputs[0])) {
if (SupportMKLDNNRnn(inputs[rnn_enum::kData].dtype())) {
MKLDNNRnnOp& op = state_ptr.get_state<MKLDNNRnnOp>();
op.Forward(ctx, inputs, req, outputs);
} else {
Expand All @@ -287,7 +284,7 @@ static void RNNStatefulGradComputeExCPU(const OpStatePtr& state_ptr,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNRnn(inputs[0])) {
if (SupportMKLDNNRnn(inputs[rnn_enum::kData].dtype())) {
MKLDNNRnnOp& op = state_ptr.get_state<MKLDNNRnnOp>();
op.Backward(ctx, inputs, req, outputs);
} else {
Expand Down Expand Up @@ -338,6 +335,23 @@ Long Short-Term Memory - Hochreiter, 1997. http://www.bioinf.jku.at/publications
h_t = o_t * \tanh(c_t)
\end{array}
With the projection size being set, LSTM could use the projection feature to reduce the parameters
size and give some speedups without significant damage to the accuracy.
Long Short-Term Memory Based Recurrent Neural Network Architectures for Large Vocabulary Speech
Recognition - Sak et al. 2014. https://arxiv.org/abs/1402.1128
.. math::
\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{ri} r_{(t-1)} + b_{ri}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{rf} r_{(t-1)} + b_{rf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{rc} r_{(t-1)} + b_{rg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{o} + W_{ro} r_{(t-1)} + b_{ro}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t)
r_t = W_{hr} h_t
\end{array}
**GRU**
Gated Recurrent Unit - Cho et al. 2014. http://arxiv.org/abs/1406.1078
Expand Down Expand Up @@ -385,10 +399,10 @@ The definition of GRU here is slightly different from paper but compatible with
})
.set_attr<mxnet::FInferShape>("FInferShape", RNNShape)
.set_attr<nnvm::FInferType>("FInferType", RNNType)
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<FCreateOpState>("FCreateOpState", CreateRNNState)
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulComputeExCPU)
#endif
Expand Down Expand Up @@ -427,9 +441,9 @@ NNVM_REGISTER_OP(_backward_RNN)
.set_attr_parser(ParamParser<RNNParam>)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulGradCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulGradComputeExCPU)
#endif
Expand Down
Loading

0 comments on commit 6fa374b

Please sign in to comment.