diff --git a/docs/python_docs/environment.yml b/docs/python_docs/environment.yml index 6c4a5beebe66..91f0f12c14f8 100644 --- a/docs/python_docs/environment.yml +++ b/docs/python_docs/environment.yml @@ -22,6 +22,7 @@ dependencies: - conda>=4.6.13 - pip - python +- setuptools==49.6.0 - jupyter - sphinx==2.4.0 - matplotlib diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index 88d21f23ea0c..1802cca8edd1 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -348,6 +348,21 @@ using FAvoidQuantizeInput = std::function; +/*! + * \brief Register a function to determine if the input of a quantized operator + * needs to be quantized asymmetrically. + */ +using FNeedAsymQuantizeInput = std::function; + +/*! + * \brief Register a function to determine if the output of a quantized operator + * needs to be dequantized. This is usually used for the quantized operators + * which can produce fp32 outputs directly. + */ +using FAvoidDequantizeOutput = std::function; + /*! * \brief Register a function to determine if the input of a quantized operator * needs to be calibrated. This is usually used for the quantized operators diff --git a/python/mxnet/io/io.py b/python/mxnet/io/io.py index 5a022ea1c81c..f051836d5aeb 100644 --- a/python/mxnet/io/io.py +++ b/python/mxnet/io/io.py @@ -37,7 +37,7 @@ from ..ndarray import array from ..ndarray import concat, tile -from .utils import _init_data, _has_instance, _getdata_by_idx +from .utils import _init_data, _has_instance, _getdata_by_idx, _slice_along_batch_axis class DataDesc(namedtuple('DataDesc', ['name', 'shape'])): """DataDesc is used to store name, shape, type and layout @@ -602,10 +602,12 @@ class NDArrayIter(DataIter): The data name. label_name : str, optional The label name. + layout : str, optional + The data layout """ def __init__(self, data, label=None, batch_size=1, shuffle=False, last_batch_handle='pad', data_name='data', - label_name='softmax_label'): + label_name='softmax_label', layout='NCHW'): super(NDArrayIter, self).__init__(batch_size) self.data = _init_data(data, allow_empty=False, default_name=data_name) @@ -631,20 +633,27 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False, # used for 'roll_over' self._cache_data = None self._cache_label = None + self.layout = layout @property def provide_data(self): """The name and shape of data provided by this iterator.""" + batch_axis = self.layout.find('N') return [ - DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype) + DataDesc(k, tuple(list(v.shape[:batch_axis]) + \ + [self.batch_size] + list(v.shape[batch_axis + 1:])), + v.dtype, layout=self.layout) for k, v in self.data ] @property def provide_label(self): """The name and shape of label provided by this iterator.""" + batch_axis = self.layout.find('N') return [ - DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype) + DataDesc(k, tuple(list(v.shape[:batch_axis]) + \ + [self.batch_size] + list(v.shape[batch_axis + 1:])), + v.dtype, layout=self.layout) for k, v in self.label ] @@ -681,7 +690,7 @@ def next(self): data = self.getdata() label = self.getlabel() # iter should stop when last batch is not complete - if data[0].shape[0] != self.batch_size: + if data[0].shape[self.layout.find('N')] != self.batch_size: # in this case, cache it for next epoch self._cache_data = data self._cache_label = label @@ -697,7 +706,7 @@ def _getdata(self, data_source, start=None, end=None): end = data_source[0][1].shape[0] if data_source else 0 s = slice(start, end) return [ - x[1][s] + _slice_along_batch_axis(x[1], s, self.layout.find('N')) if isinstance(x[1], (np.ndarray, NDArray)) else # h5py (only supports indices in increasing order) array(x[1][sorted(self.idx[s])][[ @@ -716,7 +725,7 @@ def _concat(self, first_data, second_data): concat( first_data[i], second_data[i], - dim=0 + dim=self.layout.find('N') ) for i in range(len(first_data)) ] diff --git a/python/mxnet/io/utils.py b/python/mxnet/io/utils.py index 55ba34aea426..5003dd10ccdb 100644 --- a/python/mxnet/io/utils.py +++ b/python/mxnet/io/utils.py @@ -84,3 +84,8 @@ def _getdata_by_idx(data, idx): shuffle_data.append((k, array(v.asnumpy()[idx], v.context))) return shuffle_data + +def _slice_along_batch_axis(data, s, batch_axis): + """Apply slice along the batch axis""" + ret = data.slice_axis(axis=batch_axis, begin=s.start, end=s.stop) + return ret diff --git a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h index 6590c91d832e..85b1c3102424 100644 --- a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h @@ -33,10 +33,42 @@ #include "../../rnn-inl.h" #include "./mkldnn_base-inl.h" +#include "../../quantization/quantized_rnn-inl.h" namespace mxnet { namespace op { +struct MKLDNNRnnParam : public dmlc::Parameter { + bool quantized; + + DMLC_DECLARE_PARAMETER(MKLDNNRnnParam) { + DMLC_DECLARE_FIELD(quantized).set_default(false).describe( + "Whether it's a quantized RNN operator"); + } +}; + +inline void MKLDNNMemoryReorder(const mkldnn::memory& src, const mkldnn::memory& dst) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map reorderPrimitives; +#else + static MX_THREAD_LOCAL std::unordered_map reorderPrimitives; +#endif + OpSignature key{}; + key.AddSign(src); + key.AddSign(dst); + + auto it = reorderPrimitives.find(key); + if (it == reorderPrimitives.end()) { + auto reorder = mkldnn::reorder(src, dst); + it = AddToCache(&reorderPrimitives, key, reorder); + } + + mkldnn_args_map_t net_args; + net_args.emplace(MKLDNN_ARG_SRC, src); + net_args.emplace(MKLDNN_ARG_DST, dst); + MKLDNNStream::Get()->RegisterPrimArgs(it->second, net_args); +} + struct MKLDNNRnnLayerParam { using memory = mkldnn::memory; using dims = mkldnn::memory::dims; @@ -65,6 +97,10 @@ struct MKLDNNRnnLayerParam { size_t native_single_b_size; // bias size of a single cell from framework size_t single_state_size; // state size of a single cell, hy, cy + bool quantized; // whether this layer is quantized + bool enable_u8_output; // true by default, only be false when it is the last fusion layer of the + // quantized rnn operator + MKLDNNRnnLayerParam(int num_layer, int batch_size, int seq_len, @@ -79,7 +115,9 @@ struct MKLDNNRnnLayerParam { batch_size(batch_size), input_size(input_size), state_size(state_size), - seq_len(seq_len) {} + seq_len(seq_len), + quantized(false), + enable_u8_output(false) {} void SetDims(); }; @@ -87,10 +125,11 @@ struct MKLDNNRnnLayerParam { typedef std::vector LayerParamVector; struct MKLDNNRnnFullParam { RNNParam default_param; + MKLDNNRnnParam mkldnn_param; LayerParamVector layer_params; }; -MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const RNNParam& rnn_param, +MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const nnvm::NodeAttrs& attrs, const int seq_len, const int batch_size, const int input_size); @@ -102,7 +141,7 @@ class MKLDNNRnnMemMgr { // The memory buffer in NDArray life-cycle NDArray workspace_; // This points to the memory buffer from a NDArray - char* curr_mem; + char* curr_mem = nullptr; // The total bytes of the workspace of a MKLDNNRnnOp size_t mem_size = 0; // The current available memory bytes @@ -113,7 +152,7 @@ class MKLDNNRnnMemMgr { std::vector> mem_holder; public: - void Init(dim_t size, const Context& ctx, int dtype = mshadow::kFloat32); + void Init(const dim_t size, const Context& ctx, int dtype = mshadow::kFloat32); void RegisterMem(std::shared_ptr mem) { mem_holder.push_back(mem); @@ -122,6 +161,8 @@ class MKLDNNRnnMemMgr { mkldnn::memory* Alloc(const mkldnn::memory::desc& md); }; +typedef std::shared_ptr shared_mkldnn_attr_t; + /* * Rnn Primitive. */ @@ -131,15 +172,15 @@ class RnnPrimitive { * lstm_forward, lbr_gru_forward, vanilla_rnn_forward */ template - static RnnPrimitive Create(Args&&... args) { + static RnnPrimitive Create(const shared_mkldnn_attr_t attr, Args&&... args) { RnnPrimitive rnn_fwd_prim; auto fwd_desc = typename rnn_fwd::desc(std::forward(args)...); rnn_fwd_prim.fwd_pd_.reset( - new typename rnn_fwd::primitive_desc(fwd_desc, CpuEngine::Get()->get_engine()), - [](typename rnn_fwd::primitive_desc* pd) { - delete reinterpret_cast(pd); - }); + new typename rnn_fwd::primitive_desc( + fwd_desc, attr ? *attr : mkldnn::primitive_attr(), CpuEngine::Get()->get_engine()), + [](void* pd) { delete reinterpret_cast(pd); }); auto fwd_pd = reinterpret_cast(rnn_fwd_prim.fwd_pd_.get()); + rnn_fwd_prim.attr_ = attr; rnn_fwd_prim.weights_layer_desc_ = fwd_pd->weights_layer_desc(); rnn_fwd_prim.weights_iter_desc_ = fwd_pd->weights_iter_desc(); rnn_fwd_prim.workspace_desc_ = fwd_pd->workspace_desc(); @@ -150,6 +191,7 @@ class RnnPrimitive { } RnnPrimitive() { + this->attr_ = nullptr; this->fwd_pd_ = nullptr; this->primitive_ = nullptr; this->weights_layer_desc_ = mkldnn::memory::desc(); @@ -158,6 +200,7 @@ class RnnPrimitive { } RnnPrimitive(const RnnPrimitive& rnn_fwd_prim) { + this->attr_ = rnn_fwd_prim.attr_; this->fwd_pd_ = rnn_fwd_prim.fwd_pd_; this->primitive_ = rnn_fwd_prim.primitive_; this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_; @@ -167,6 +210,7 @@ class RnnPrimitive { RnnPrimitive& operator=(const RnnPrimitive& rnn_fwd_prim) { if (this != &rnn_fwd_prim) { + this->attr_ = rnn_fwd_prim.attr_; this->fwd_pd_ = rnn_fwd_prim.fwd_pd_; this->primitive_ = rnn_fwd_prim.primitive_; this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_; @@ -196,9 +240,14 @@ class RnnPrimitive { return workspace_desc_; } + const mkldnn::primitive_attr& GetPrimAttr() const { + return *attr_; + } + private: std::shared_ptr fwd_pd_; std::shared_ptr primitive_; + shared_mkldnn_attr_t attr_; mkldnn::memory::desc weights_layer_desc_; mkldnn::memory::desc weights_iter_desc_; mkldnn::memory::desc workspace_desc_; @@ -207,7 +256,8 @@ class RnnPrimitive { RnnPrimitive GetRnnFwdPrim(const MKLDNNRnnLayerParam& layer_param, const bool is_train, const NDArray& data, - const NDArray& params); + const NDArray& params, + const shared_mkldnn_attr_t attr = nullptr); /* * Use this to manage memory and primitive of MKL-DNN RNN forward inference. @@ -217,10 +267,11 @@ class MKLDNNRnnForward { MKLDNNRnnForward(const MKLDNNRnnLayerParam& layer_param, const bool is_train, const NDArray& data, - const NDArray& params) + const NDArray& params, + const shared_mkldnn_attr_t attr = nullptr) : initialized_(false), param_(layer_param), - fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params)) {} + fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params, attr)) {} void SetNewDataMem(void* x, void* hx, @@ -240,6 +291,10 @@ class MKLDNNRnnForward { return fwd_inf_.GetPrim(); } + void ResetFwd(const NDArray& data, const NDArray& params, const shared_mkldnn_attr_t& attr) { + fwd_inf_ = GetRnnFwdPrim(this->param_, false, data, params, attr); + } + const size_t GetSize(int dtype) const { size_t bytes = mshadow::mshadow_sizeof(dtype); size_t size = 0; @@ -458,13 +513,13 @@ class MKLDNNRnnBackward { */ class MKLDNNRnnOp { public: - explicit MKLDNNRnnOp(const RNNParam& param, + explicit MKLDNNRnnOp(const nnvm::NodeAttrs &attrs, const int seq_len, const int batch_size, const int input_size) : initialized_(false), weights_version_(0), - full_param_(MKLDNNRnnFullParamParser(param, seq_len, batch_size, input_size)) {} + full_param_(MKLDNNRnnFullParamParser(attrs, seq_len, batch_size, input_size)) {} void Forward(const OpContext& ctx, const std::vector& inputs, diff --git a/src/operator/nn/mkldnn/mkldnn_rnn.cc b/src/operator/nn/mkldnn/mkldnn_rnn.cc index 863abdcdeddd..e4c5f546dbf6 100644 --- a/src/operator/nn/mkldnn/mkldnn_rnn.cc +++ b/src/operator/nn/mkldnn/mkldnn_rnn.cc @@ -33,6 +33,8 @@ namespace mxnet { namespace op { +DMLC_REGISTER_PARAMETER(MKLDNNRnnParam); + inline int GetRnnGatesNum(int mode) { switch (mode) { case rnn_enum::kLstm: @@ -82,12 +84,26 @@ void MKLDNNRnnLayerParam::SetDims() { reserve_size = 0; } -MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const RNNParam& rnn_param, +MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const NodeAttrs& attrs, const int seq_len, const int batch_size, const int input_size) { + const RNNParam& rnn_param = nnvm::get(attrs.parsed); MKLDNNRnnFullParam full_param; - full_param.default_param = rnn_param; + full_param.default_param = rnn_param; + try { + full_param.mkldnn_param.Init(attrs.dict, dmlc::parameter::kAllowUnknown); + } catch (const dmlc::ParamError& e) { + std::ostringstream os; + os << e.what(); + os << ", in operator " << attrs.op->name << "(" + << "name=\"" << attrs.name << "\""; + for (const auto& k : attrs.dict) { + os << ", " << k.first << "=\"" << k.second << "\""; + } + os << ")"; + throw dmlc::ParamError(os.str()); + } size_t state_size = rnn_param.state_size; LayerParamVector& layer_params = full_param.layer_params; @@ -116,15 +132,20 @@ MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const RNNParam& rnn_param, false); } - // Set dims, workspace size, and state_outputs flag + // Set dims, workspace size, state_outputs, quantized and enable_u8_output flag for (auto& layer_param : layer_params) { layer_param.SetDims(); - layer_param.state_outputs = rnn_param.state_outputs; + layer_param.state_outputs = rnn_param.state_outputs; + layer_param.quantized = full_param.mkldnn_param.quantized; + layer_param.enable_u8_output = true; } + // Quantized RNN operator produces kFloat32 outputs. + if (full_param.mkldnn_param.quantized) + layer_params.back().enable_u8_output = false; return full_param; } -void MKLDNNRnnMemMgr::Init(dim_t size, const Context& ctx, int dtype) { +void MKLDNNRnnMemMgr::Init(const dim_t size, const Context& ctx, int dtype) { workspace_ = NDArray(TShape({size}), ctx, false, dtype); curr_mem = static_cast(workspace_.data().dptr_); mem_size = size * mshadow::mshadow_sizeof(dtype); @@ -157,31 +178,39 @@ mkldnn::memory* MKLDNNRnnMemMgr::Alloc(const mkldnn::memory::desc& md) { RnnPrimitive GetRnnFwdPrim(const MKLDNNRnnLayerParam& layer_param, const bool is_train, const NDArray& data, - const NDArray& params) { + const NDArray& params, + const shared_mkldnn_attr_t attr) { using namespace mkldnn; - using tag = mkldnn::memory::format_tag; - const int mode = layer_param.mode; - memory::data_type data_type = get_mkldnn_type(data.dtype()); - memory::data_type weight_type = get_mkldnn_type(params.dtype()); + using tag = mkldnn::memory::format_tag; + const int mode = layer_param.mode; + memory::data_type src_layer_dtype = get_mkldnn_type(data.dtype()); + memory::data_type iter_dtype = get_mkldnn_type(mshadow::kFloat32); + memory::data_type weight_dtype = + get_mkldnn_type(layer_param.quantized ? mshadow::kInt8 : params.dtype()); + memory::data_type bias_dtype = get_mkldnn_type(mshadow::kFloat32); + memory::data_type dst_layer_dtype = + get_mkldnn_type((layer_param.quantized && layer_param.enable_u8_output) ? mshadow::kUint8 : + mshadow::kFloat32); const prop_kind prop = is_train ? prop_kind::forward_training : prop_kind::forward_inference; - const rnn_direction mkldnn_rnn_direction = layer_param.bidirectional - ? rnn_direction::bidirectional_concat - : rnn_direction::unidirectional; - - auto src_layer_desc = memory::desc(layer_param.src_dims, data_type, tag::tnc); - auto weight_layer_desc = memory::desc(layer_param.weight_layer_dims, weight_type, tag::any); - auto weight_iter_desc = memory::desc(layer_param.weight_iter_dims, weight_type, tag::any); - auto bias_desc = memory::desc(layer_param.bias_dims, data_type, tag::ldgo); - auto dst_layer_desc = memory::desc(layer_param.dst_dims, data_type, tag::tnc); - auto src_state_desc = memory::desc(layer_param.state_dims, data_type, tag::ldnc); - auto dst_state_desc = layer_param.state_outputs - ? memory::desc(layer_param.state_dims, data_type, tag::ldnc) - : memory::desc(); + const rnn_direction mkldnn_rnn_direction = layer_param.bidirectional ? + rnn_direction::bidirectional_concat : + rnn_direction::unidirectional; + + auto src_layer_desc = memory::desc(layer_param.src_dims, src_layer_dtype, tag::tnc); + auto weight_layer_desc = memory::desc(layer_param.weight_layer_dims, weight_dtype, tag::any); + auto weight_iter_desc = memory::desc(layer_param.weight_iter_dims, weight_dtype, tag::any); + auto bias_desc = memory::desc(layer_param.bias_dims, bias_dtype, tag::ldgo); + auto dst_layer_desc = memory::desc(layer_param.dst_dims, dst_layer_dtype, tag::tnc); + auto src_state_desc = memory::desc(layer_param.state_dims, iter_dtype, tag::ldnc); + auto dst_state_desc = layer_param.state_outputs ? + memory::desc(layer_param.state_dims, iter_dtype, tag::ldnc) : + memory::desc(); auto fwd = RnnPrimitive(); switch (mode) { case rnn_enum::kLstm: - fwd = RnnPrimitive::Create(prop, + fwd = RnnPrimitive::Create(attr, + prop, mkldnn_rnn_direction, src_layer_desc, src_state_desc, @@ -194,7 +223,8 @@ RnnPrimitive GetRnnFwdPrim(const MKLDNNRnnLayerParam& layer_param, dst_state_desc); break; case rnn_enum::kGru: - fwd = RnnPrimitive::Create(prop, + fwd = RnnPrimitive::Create(attr, + prop, mkldnn_rnn_direction, src_layer_desc, src_state_desc, @@ -207,6 +237,7 @@ RnnPrimitive GetRnnFwdPrim(const MKLDNNRnnLayerParam& layer_param, case rnn_enum::kRnnRelu: case rnn_enum::kRnnTanh: fwd = RnnPrimitive::Create( + attr, prop, mode == rnn_enum::kRnnTanh ? algorithm::eltwise_tanh : algorithm::eltwise_relu, mkldnn_rnn_direction, @@ -418,11 +449,19 @@ void MKLDNNRnnForward::SetNewDataMem(void* x, auto& cpu_engine = CpuEngine::Get()->get_engine(); mkldnn_args_map_t& args = net_args_; + int src_dtype = dtype; + int dst_dtype = dtype; + if (param_.quantized) { + src_dtype = mshadow::kUint8; + if (param_.enable_u8_output) + dst_dtype = mshadow::kUint8; + } + RNN_HANDLE_FUNC(RNN_HANDLE_FUNC_NAME); // Set various data memory - RNN_FWD_SET(SRC, param_.src_dims, format_tag::tnc, x, dtype); - RNN_FWD_SET(DST, param_.dst_dims, format_tag::tnc, y, dtype); + RNN_FWD_SET(SRC, param_.src_dims, format_tag::tnc, x, src_dtype); + RNN_FWD_SET(DST, param_.dst_dims, format_tag::tnc, y, dst_dtype); RNN_FWD_SET(SRC_ITER, param_.state_dims, format_tag::ldnc, hx, dtype); if (param_.state_outputs) { @@ -437,35 +476,26 @@ void MKLDNNRnnForward::SetNewDataMem(void* x, } } -inline void MKLDNNMemoryReorder(const mkldnn::memory& src, const mkldnn::memory& dst) { -#if DMLC_CXX11_THREAD_LOCAL - static thread_local std::unordered_map reorderPrimitives; -#else - static MX_THREAD_LOCAL std::unordered_map reorderPrimitives; -#endif - OpSignature key{}; - key.AddSign(src); - key.AddSign(dst); - - auto it = reorderPrimitives.find(key); - if (it == reorderPrimitives.end()) { - auto reorder = mkldnn::reorder(src, dst); - it = AddToCache(&reorderPrimitives, key, reorder); - } - - mkldnn_args_map_t net_args; - net_args.emplace(MKLDNN_ARG_SRC, src); - net_args.emplace(MKLDNN_ARG_DST, dst); - MKLDNNStream::Get()->RegisterPrimArgs(it->second, net_args); -} - /* * Reorder the concatenated weights memory to a efficient memory block * with primitive-prefered format. */ void MKLDNNRnnForward::ReorderWeights() { - MKLDNNMemoryReorder(*weights_layer_r_, *weights_layer_); - MKLDNNMemoryReorder(*weights_iter_r_, *weights_iter_); + if (param_.quantized) { + const mkldnn::primitive_attr& attr = this->fwd_inf_.GetPrimAttr(); + auto ReorderWithAttr = [&](mkldnn::memory& src, mkldnn::memory& dst) { + auto reorder_pd = mkldnn::reorder::primitive_desc(src, dst, attr); + mkldnn_args_map_t net_args; + net_args[MKLDNN_ARG_SRC] = src; + net_args[MKLDNN_ARG_DST] = dst; + MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(reorder_pd), net_args); + }; + ReorderWithAttr(*weights_layer_r_, *weights_layer_); + ReorderWithAttr(*weights_iter_r_, *weights_iter_); + } else { + MKLDNNMemoryReorder(*weights_layer_r_, *weights_layer_); + MKLDNNMemoryReorder(*weights_iter_r_, *weights_iter_); + } } void AdjustGruGateOrder(char* weight, @@ -546,7 +576,7 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, const bool is_train, const int dtype) { using format_tag = mkldnn::memory::format_tag; - auto mkldnn_dtype = get_mkldnn_type(dtype); + const auto mkldnn_dtype = get_mkldnn_type(dtype); // Get the weights' memory for RNN forward primitive if (weights_layer_ == nullptr) { weights_layer_ = mgr->Alloc(fwd_inf_.GetLayerDesc()); @@ -643,7 +673,7 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, // space for weights and their gradients. Then, forward training primitives // could fetch them from the scope of forward inference. And from there, we // don't need to reorder the plain memory to the optimal rnn-packed memory - // for forward inference. + // for forward inference ReorderWeights(); initialized_ = true; } @@ -705,6 +735,19 @@ void MKLDNNRnnOp::Init(const OpContext& ctx, const std::vector& outputs) { using format_tag = mkldnn::memory::format_tag; + // Get the bytes of a real type + const NDArray& weights = inputs[rnn_enum::kParams]; + int dtype = weights.dtype(); + size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); + const RNNParam& default_param = full_param_.default_param; + const size_t weights_size = + weights.data().Size() - GetRnnBiasSize(default_param.num_layers, + default_param.state_size, + default_param.bidirectional + 1, + default_param.mode); + char* weights_ptr = static_cast(weights.data().dptr_); + char* bias_ptr = weights_ptr + weights_size * dtype_bytes; + // In the `autograd.record()` context, RNNOp is required to run into // `forward_training` mode. const bool is_training = (ctx.is_train || ctx.need_grad); @@ -719,7 +762,7 @@ void MKLDNNRnnOp::Init(const OpContext& ctx, for (auto& layer_param : full_param_.layer_params) { fwd_inf_vec_.emplace_back( - layer_param, ctx.is_train, inputs[rnn_enum::kData], inputs[rnn_enum::kParams]); + layer_param, ctx.is_train, inputs[rnn_enum::kData], inputs[rnn_enum::kParams], nullptr); buffer_size += fwd_inf_vec_.back().GetSize(inputs[rnn_enum::kParams].dtype()); } mgr_.Init(buffer_size, ctx.run_ctx.ctx, inputs[rnn_enum::kParams].dtype()); @@ -732,19 +775,7 @@ void MKLDNNRnnOp::Init(const OpContext& ctx, } } - // Get the bytes of a real type - const NDArray& weights = inputs[rnn_enum::kParams]; - int dtype = weights.dtype(); - size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); - const RNNParam& default_param = full_param_.default_param; - char* weights_ptr = static_cast(weights.data().dptr_); - char* bias_ptr = - weights_ptr + (weights.data().Size() - GetRnnBiasSize(default_param.num_layers, - default_param.state_size, - default_param.bidirectional + 1, - default_param.mode)) * - dtype_bytes; for (auto& fwd_layer : fwd_inf_vec_) { size_t single_w_bytes = fwd_layer.GetParam().single_w_size * dtype_bytes; size_t single_b_bytes = fwd_layer.GetParam().native_single_b_size * dtype_bytes; @@ -769,7 +800,7 @@ void MKLDNNRnnOp::Init(const OpContext& ctx, CHECK_EQ(num_fusion, fwd_inf_vec_.size()) << "Layer vector's size has a different value than the number of fusion."; if (dst_.size() < num_fusion - 1) { - int data_dtype = outputs[rnn_enum::kOut].dtype(); + const int data_dtype = outputs[rnn_enum::kOut].dtype(); // Here we need `fwd_inf_vec_.size() - 1` spaces for the intermediate // results of the multiple fused layers. And for the result of the last // fused layer, `outputs[rnn_enum::kOut]` could provide the space. Hence, @@ -1047,6 +1078,12 @@ void MKLDNNRnnOp::Forward(const OpContext& ctx, weights_version_ = inputs[rnn_enum::kParams].version(); } + if (dmlc::GetEnv("MXNET_RNN_USE_WEIGHT_CACHE", 0) && !initialized_) { + LOG(INFO) << "The current weight of RNN is assumed to be fixed and cached during " + "the whole inference pipeline. Please set MXNET_RNN_USE_WEIGHT_CACHE=0, if " + "the weight is going to be changed at runtime."; + } + // Check if weights NDArray was changed. If so, reset initialized_ if (!is_training && fwd_inf_vec_.size() > 0 && weights_version_ != inputs[rnn_enum::kParams].version()) { @@ -1056,12 +1093,12 @@ void MKLDNNRnnOp::Forward(const OpContext& ctx, weights_version_ = inputs[rnn_enum::kParams].version(); } - if (!initialized_ || is_training || fwd_inf_vec_.size() == 0) { + if (!initialized_ || is_training || fwd_inf_vec_.empty()) { Init(ctx, inputs, req, outputs); } // Get data type - int data_dtype = inputs[rnn_enum::kData].dtype(); + int data_dtype = outputs[rnn_enum::kOut].dtype(); // Get temporary memory for output, state_out, statecell_out const int num_layers = default_param.num_layers; const int seq_length = default_param.seq_length_; diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_asym-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_asym-inl.h new file mode 100644 index 000000000000..32dcf5aab249 --- /dev/null +++ b/src/operator/quantization/mkldnn/mkldnn_quantize_asym-inl.h @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file mkldnn_quantize_asym-inl.h + * \brief implementation of asymmetric quantize operation using DNNL + */ + +#ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_ASYM_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_ASYM_INL_H_ +#if MXNET_USE_MKLDNN == 1 + +#include +#include +#include "../../nn/mkldnn/mkldnn_base-inl.h" +#include "../quantize_asym-inl.h" + +namespace mxnet { +namespace op { + +class MKLDNNQuantizeAsymOp { + public: + explicit MKLDNNQuantizeAsymOp(const nnvm::NodeAttrs& attrs) + : param_(nnvm::get(attrs.parsed)) {} + + void Forward(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + + private: + QuantizeAsymParam param_; + bool initialized_{false}; + float cached_scale_{0.f}; + float cached_shift_{0.f}; + mkldnn::memory::desc o_desc_; + mkldnn_args_map_t args_; + std::shared_ptr fwd_pd_; +}; + +void MKLDNNQuantizeAsymOp::Forward(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using mshadow::red::limits::MaxValue; + using mshadow::red::limits::MinValue; + NDArray in_buffer = inputs[0]; + float scale = 0.f; + float shift = 0.f; + + // Pass through quantized data + if (inputs[0].dtype() == mshadow::kUint8) { + *outputs[1].data().dptr() = 1; + *outputs[2].data().dptr() = 0; + if (req[0] != kWriteInplace) { + const_cast(outputs[0]).CopyFrom(inputs[0].GetMKLDNNData()); + MKLDNNStream::Get()->Submit(); + } + } else { + in_buffer = inputs[0].Reorder2Default(); + const mkldnn::memory* i_mem = static_cast(in_buffer.GetMKLDNNData()); + float* in_ptr = in_buffer.data().dptr(); + const int nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + if (inputs[0].dtype() == mshadow::kInt8) { + *outputs[1].data().dptr() = 1; + *outputs[2].data().dptr() = 128; +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(in_buffer.shape().Size()); ++i) { + in_ptr[i] += 128.0f; + } + } else if (inputs[0].dtype() == mshadow::kFloat32) { + if (param_.min_calib_range.has_value() && param_.max_calib_range.has_value()) { + scale = + MaxValue() / (param_.max_calib_range.value() - param_.min_calib_range.value()); + shift = MaxValue() - param_.max_calib_range.value() * scale; + } else { + float data_min = mshadow::red::limits::MaxValue(); + float data_max = mshadow::red::limits::MinValue(); + std::vector data_maxs(nthreads, data_max); + std::vector data_mins(nthreads, data_min); +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(in_buffer.shape().Size()); i++) { + int tid = omp_get_thread_num(); + if (in_ptr[i] > data_maxs[tid]) + data_maxs[tid] = in_ptr[i]; + if (in_ptr[i] < data_mins[tid]) + data_mins[tid] = in_ptr[i]; + } + for (index_t i = 0; i < nthreads; i++) { + if (data_maxs[i] > data_max) + data_max = data_maxs[i]; + if (data_mins[i] < data_min) + data_min = data_mins[i]; + } + scale = MaxValue() / (data_max - data_min); + shift = MaxValue() - data_max * scale; + } + + if (initialized_ && (cached_scale_ != scale || cached_shift_ != shift)) + initialized_ = false; + } + + *outputs[1].data().dptr() = scale; + *outputs[2].data().dptr() = shift; + + if (!initialized_) { + cached_scale_ = scale; + cached_shift_ = shift; + mkldnn::primitive_attr attr; + attr.set_rnn_data_qparams(scale, shift); + const mkldnn::engine& cpu_engine = mxnet::CpuEngine::Get()->get_engine(); + const mkldnn::memory::desc& i_desc = i_mem->get_desc(); + o_desc_ = i_desc; + o_desc_.data.data_type = get_mkldnn_type_t(outputs[0].dtype()); + mkldnn::reorder::primitive_desc reorder_pd(cpu_engine, i_desc, cpu_engine, o_desc_, attr); + fwd_pd_ = std::make_shared(reorder_pd); + initialized_ = true; + } + mkldnn_output_t o_mem = CreateMKLDNNMem(outputs[0], o_desc_, req[0]); + args_[MKLDNN_ARG_FROM] = *i_mem; + args_[MKLDNN_ARG_TO] = *o_mem.second; + MKLDNNStream::Get()->RegisterPrimArgs(*fwd_pd_, args_); + CommitOutput(outputs[0], o_mem); + MKLDNNStream::Get()->Submit(); + } +} + +void MKLDNNQuantizeAsymForward(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (inputs[0].shape().ndim() == 3 && inputs[0].dtype() == mshadow::kFloat32) { + MKLDNNQuantizeAsymOp& op = state_ptr.get_state(); + op.Forward(ctx, inputs, req, outputs); + } else { + FallBackCompute(QuantizeAsymForward, state_ptr, ctx, inputs, req, outputs); + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_ASYM_INL_H_ diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_rnn-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantized_rnn-inl.h new file mode 100644 index 000000000000..7950f7dbdafa --- /dev/null +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_rnn-inl.h @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file mkldnn_quantized_rnn-inl.h + * \brief Common functions for quantized recurrent neural network + * \author Zixuan Wei + */ + +#ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_RNN_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_RNN_INL_H_ + +#if MXNET_USE_MKLDNN == 1 + +#include +#include "../../nn/mkldnn/mkldnn_rnn-inl.h" +#include "../../rnn-inl.h" +#include "../quantized_rnn-inl.h" + +namespace mxnet { +namespace op { + +class MKLDNNQuantizedRnnOp { + public: + explicit MKLDNNQuantizedRnnOp(const nnvm::NodeAttrs& attrs, + const int seq_len, + const int batch_size, + const int input_size) + : initialized_(false), + weights_ver_(0), + rnn_attr_(new mkldnn::primitive_attr), + full_param_(MKLDNNRnnFullParamParser(attrs, seq_len, batch_size, input_size)) {} + + void Forward(const OpContext& op_ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + + private: + bool initialized_; + size_t weights_ver_; + shared_mkldnn_attr_t rnn_attr_; + MKLDNNRnnFullParam full_param_; + MKLDNNRnnMemMgr mgr_; + std::vector fwd_inf_vec_; // forward inference layers + + // Used to store the intermediate results of multi-layer + std::vector dst_; + // According to + // https://intel.github.io/mkl-dnn/dev_guide_int8_computations.html, the + // non-symmetric quantization is assumed by LSTM primitive. Namely, the + // formula is: + // data_f32 = (data_u8 - shift) / scale + float cached_data_shift_{0.0}; + float cached_data_scale_{0.0}; + void Init(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_RNN_INL_H_ diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_rnn.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_rnn.cc new file mode 100644 index 000000000000..9b4377aeb74e --- /dev/null +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_rnn.cc @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file mkldnn_quantized_rnn.cc + * \brief Common functions for quantized recurrent neural network + * \author Zixuan Wei + */ + +#if MXNET_USE_MKLDNN == 1 + +#include "../quantization_utils.h" +#include "./mkldnn_quantized_rnn-inl.h" + +namespace mxnet { +namespace op { + +std::vector GetMKLDNNRnnWeightsQParams(const MKLDNNRnnFullParam& full_param, float* w_ptr) { + const int nthreads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + const RNNParam& default_param = full_param.default_param; + const LayerParamVector& layer_params = full_param.layer_params; + + const MKLDNNRnnLayerParam& layer_param0 = layer_params.at(0); + const size_t w_size0 = layer_param0.single_w_size; + const size_t wx_size0 = 4 * layer_param0.state_size * layer_param0.input_size; + const size_t wh_size0 = 4 * layer_param0.state_size * layer_param0.state_size; + + int directions = 1; + float* wx = w_ptr; + float* wh = wx + wx_size0; + float* fake_wx = wx; + float* fake_wh = wh; + + std::vector wx_goi_max; + std::vector wh_goi_max; + if (default_param.bidirectional) { + directions = 2; + wx_goi_max.resize(wx_size0); + wh_goi_max.resize(wh_size0); + fake_wx = wx_goi_max.data(); + fake_wh = wh_goi_max.data(); +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(wx_size0); ++i) { + fake_wx[i] = MaxAbs(wx[i], wx[i + w_size0]); + } +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(wh_size0); ++i) { + fake_wh[i] = MaxAbs(wh[i], wh[i + w_size0]); + } + } + std::vector w_max(4 * layer_param0.state_size, 0.0); + const index_t input_size = layer_param0.input_size; // input + const index_t state_size = layer_param0.state_size; // state + const index_t gates_nblks = 4 * layer_param0.state_size; // gates * state + for (index_t go = 0; go < gates_nblks; ++go) { + float tmp_max = w_max[go]; + for (index_t i = 0; i < input_size; ++i) { + tmp_max = MaxAbs(fake_wx[go * input_size + i], tmp_max); + } + for (index_t i = 0; i < state_size; ++i) { + tmp_max = MaxAbs(fake_wh[go * state_size + i], tmp_max); + } + w_max[go] = tmp_max; + } + wx += layer_param0.single_w_size * directions; + wh += layer_param0.single_w_size * directions; + + std::vector goi_max(wh_size0, 0.0); + for (size_t lyr = 1; lyr < layer_params.size(); ++lyr) { + const MKLDNNRnnLayerParam& layer_param = layer_params.at(lyr); + const int weight_nblks = layer_param.num_layer * directions; + for (int blk = 0; blk < weight_nblks; ++blk) { +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(wh_size0); ++i) { + goi_max[i] = MaxAbs(wx[i], wh[i]); + } + for (index_t go = 0; go < gates_nblks; ++go) { + float tmp = w_max[go]; +// NOTES: min/max reductions were supported since OpenMP 3.1, which was +// released in Jul 2011 (hence the version number). +#if _OPENMP >= 201107 +#pragma omp parallel for reduction(max : tmp) num_threads(nthreads) +#endif + for (index_t i = 0; i < state_size; ++i) { + tmp = Max(goi_max[go * state_size + i], tmp); + } + w_max[go] = tmp; + } + } + wx += layer_param.single_w_size * directions; + wh = wx + wh_size0; + } +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(w_max.size()); ++i) { + w_max[i] = mshadow::red::limits::MaxValue() / w_max[i]; + } + return w_max; +} + +void MKLDNNQuantizedRnnOp::Init(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using format_tag = mkldnn::memory::format_tag; + + // Get the bytes of a real type + const NDArray& weights = inputs[rnn_enum::kParams]; + int dtype = weights.dtype(); + int weights_dtype = weights.dtype(); + size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); + const RNNParam& default_param = full_param_.default_param; + const size_t weights_size = + weights.data().Size() - GetRnnBiasSize(default_param.num_layers, + default_param.state_size, + default_param.bidirectional + 1, + default_param.mode); + char* weights_ptr = static_cast(weights.data().dptr_); + char* bias_ptr = weights_ptr + weights_size * dtype_bytes; + + // In the `autograd.record()` context, RNNOp is required to run into + // `forward_training` mode. + + const size_t num_fusion = full_param_.layer_params.size(); + if (fwd_inf_vec_.size() < num_fusion) { + size_t buffer_size = 0; // Element number, instead of bytes, in the buffer + for (auto& layer_param : full_param_.layer_params) { + buffer_size += layer_param.workspace_size + layer_param.reserve_size; + } + buffer_size += outputs[rnn_enum::kOut].data().Size() * (num_fusion - 1); + buffer_size += kMKLDNNAlign * num_fusion * 5; // Add margin for alignment + + for (auto& layer_param : full_param_.layer_params) { + fwd_inf_vec_.emplace_back( + layer_param, false, inputs[rnn_enum::kData], inputs[rnn_enum::kParams], rnn_attr_); + buffer_size += fwd_inf_vec_.back().GetSize(inputs[rnn_enum::kParams].dtype()); + } + mgr_.Init(buffer_size, ctx.run_ctx.ctx, inputs[rnn_enum::kParams].dtype()); + } + + for (auto& fwd_layer : fwd_inf_vec_) { + size_t single_w_bytes = fwd_layer.GetParam().single_w_size * dtype_bytes; + size_t single_b_bytes = fwd_layer.GetParam().native_single_b_size * dtype_bytes; + size_t directions = fwd_layer.GetParam().bidirectional ? 2 : 1; + size_t layer_weights_bytes = single_w_bytes * directions; + size_t layer_bias_bytes = single_b_bytes * directions; // Native MXNet has double bias + + if (!fwd_layer.IsInitialized()) + fwd_layer.SetWeightsMem(&(this->mgr_), weights_ptr, bias_ptr, false, weights_dtype); + weights_ptr += layer_weights_bytes; + bias_ptr += layer_bias_bytes; + } + + CHECK_EQ(num_fusion, fwd_inf_vec_.size()) + << "Layer vector's size has a different value than the number of fusion."; + if (dst_.size() < num_fusion - 1) { + const int data_dtype = outputs[rnn_enum::kOut].dtype(); + // Here we need `fwd_inf_vec_.size() - 1` spaces for the intermediate + // results of the multiple fused layers. And for the result of the last + // fused layer, `outputs[rnn_enum::kOut]` could provide the space. Hence, + // `forward_inf_vec_.back()` is excluded when allocates the spaces for + // intermediate results. + for (std::vector::const_iterator fwd = fwd_inf_vec_.begin(); + fwd != fwd_inf_vec_.end() - 1; + ++fwd) + dst_.push_back( + mgr_.Alloc({fwd->GetParam().dst_dims, get_mkldnn_type(data_dtype), format_tag::tnc})); + } + + initialized_ = true; +} + +template +inline void RegisterMKLDNNRnn(MKLDNNRnnX const& rnn) { + MKLDNNStream::Get()->RegisterPrimArgs(rnn.GetFwd(), rnn.GetArgsMap()); +} + +template <> +inline void RegisterMKLDNNRnn(MKLDNNRnnBackward const& rnn) { + MKLDNNStream::Get()->RegisterPrimArgs(rnn.GetBwd(), rnn.GetArgsMap()); + rnn.SetNativeWeightsGrads(); +} + +void MKLDNNQuantizedRnnOp::Forward(const OpContext& op_ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + TmpMemMgr::Get()->Init(op_ctx.requested[0]); + + const RNNParam& default_param = full_param_.default_param; + const uint32_t num_base_inputs = GetRnnNumInputs(default_param); + float data_scale = inputs[num_base_inputs + quantized_rnn::kDataScale].data().dptr()[0]; + float data_shift = inputs[num_base_inputs + quantized_rnn::kDataShift].data().dptr()[0]; + + const bool need_reset_weight = (!dmlc::GetEnv("MXNET_RNN_USE_WEIGHT_CACHE", 0) && + weights_ver_ != inputs[rnn_enum::kParams].version()) ? + true : + false; + const NDArray& weights = inputs.at(rnn_enum::kParams); + float* weights_ptr = weights.data().dptr(); + if (!initialized_ || fwd_inf_vec_.empty()) { + weights_ver_ = inputs[rnn_enum::kParams].version(); + cached_data_scale_ = data_scale; + cached_data_shift_ = data_shift; + rnn_attr_->set_rnn_data_qparams(data_scale, data_shift); + if (need_reset_weight || fwd_inf_vec_.empty()) + rnn_attr_->set_rnn_weights_qparams(0 + (1 << 3) + (1 << 4), + GetMKLDNNRnnWeightsQParams(full_param_, weights_ptr)); + } + + // Initialize weights version + if (!initialized_ && weights_ver_ == 0) { + weights_ver_ = inputs[rnn_enum::kParams].version(); + cached_data_scale_ = data_scale; + cached_data_shift_ = data_shift; + } + + if (!fwd_inf_vec_.empty() && + ((cached_data_scale_ != data_scale || cached_data_shift_ != data_shift))) { + initialized_ = false; + weights_ver_ = inputs[rnn_enum::kParams].version(); + cached_data_scale_ = data_scale; + cached_data_shift_ = data_shift; + } + + // Check if weights NDArray was changed. If so, reset initialized_ + if (fwd_inf_vec_.size() > 0 && weights_ver_ != inputs[rnn_enum::kParams].version()) { + initialized_ = false; + for (auto& fwd : fwd_inf_vec_) + fwd.Reset(); + weights_ver_ = inputs[rnn_enum::kParams].version(); + cached_data_scale_ = data_scale; + cached_data_shift_ = data_shift; + } + + if (!initialized_ || fwd_inf_vec_.empty()) { + Init(op_ctx, inputs, req, outputs); + } + + // Get data type + int data_dtype = outputs[rnn_enum::kOut].dtype(); + // Get temporary memory for output, state_out, statecell_out + const int num_layers = default_param.num_layers; + const int seq_length = default_param.seq_length_; + const int batch_size = default_param.batch_size_; + const int state_size = default_param.state_size; + const int directions = default_param.bidirectional ? 2 : 1; + mkldnn::memory::desc dst_desc({seq_length, batch_size, directions * state_size}, + get_mkldnn_type(data_dtype), + mkldnn::memory::format_tag::tnc); + mkldnn::memory::desc state_desc({num_layers, directions, batch_size, state_size}, + get_mkldnn_type(data_dtype), + mkldnn::memory::format_tag::ldnc); + auto out_mem = CreateMKLDNNMem(outputs[rnn_enum::kOut], dst_desc, req[rnn_enum::kOut]); + mkldnn_output_t stateout_mem; + mkldnn_output_t statecellout_mem; + + // Get input & output NDArray + char* src = static_cast(inputs[rnn_enum::kData].data().dptr_); + char* src_state = static_cast(inputs[rnn_enum::kState].data().dptr_); + char* dst = static_cast(out_mem.second->get_data_handle()); + char* dst_state = nullptr; // Output state + char* src_state_cell = nullptr; // Used in LSTM for cell state + char* dst_state_cell = nullptr; // Used in LSTM for cell state + const size_t cell_bytes = (default_param.bidirectional + 1) * default_param.batch_size_ * + default_param.state_size * mshadow::mshadow_sizeof(data_dtype); + + if (default_param.state_outputs && req[rnn_enum::kStateOut] != kNullOp) { + stateout_mem = + CreateMKLDNNMem(outputs[rnn_enum::kStateOut], state_desc, req[rnn_enum::kStateOut]); + dst_state = static_cast(stateout_mem.second->get_data_handle()); + } + + if (default_param.mode == rnn_enum::kLstm) { + src_state_cell = static_cast(inputs[rnn_enum::kStateCell].data().dptr_); + if (default_param.state_outputs && req[rnn_enum::kStateCellOut] != kNullOp) { + statecellout_mem = CreateMKLDNNMem( + outputs[rnn_enum::kStateCellOut], state_desc, req[rnn_enum::kStateCellOut]); + dst_state_cell = static_cast(statecellout_mem.second->get_data_handle()); + } + } + + if (fwd_inf_vec_.size() == 1) { + fwd_inf_vec_.front().SetNewDataMem( + src, src_state, src_state_cell, dst, dst_state, dst_state_cell, data_dtype); + } else { + CHECK_EQ(fwd_inf_vec_.size(), dst_.size() + 1) << "Output memory error."; + size_t cell_bytes = (default_param.bidirectional + 1) * default_param.batch_size_ * + default_param.state_size * mshadow::mshadow_sizeof(data_dtype); + + // Set input data memory for the first layer. This stores intermediate + // output results in this->xxx, used as the source input of the next layer. + fwd_inf_vec_.front().SetNewDataMem(src, + src_state, + src_state_cell, + this->dst_.front()->get_data_handle(), + dst_state, + dst_state_cell, + data_dtype); + // 1st_lyr -> dst_handle -> next_lyr -> dst_handle -> next_lyr -> ... + for (size_t lyr = 1; lyr < fwd_inf_vec_.size() - 1; ++lyr) { + src_state += cell_bytes; + if (src_state_cell) + src_state_cell += cell_bytes; + if (dst_state) + dst_state += cell_bytes; + if (dst_state_cell) + dst_state_cell += cell_bytes; + fwd_inf_vec_.at(lyr).SetNewDataMem(this->dst_.at(lyr - 1)->get_data_handle(), + src_state, + src_state_cell, + this->dst_.at(lyr)->get_data_handle(), + dst_state, + dst_state_cell, + data_dtype); + } + // Set output data memory for the last layer. + src_state += cell_bytes; + if (src_state_cell) + src_state_cell += cell_bytes; + if (dst_state) + dst_state += cell_bytes; + if (dst_state_cell) + dst_state_cell += cell_bytes; + fwd_inf_vec_.back().SetNewDataMem(this->dst_.back()->get_data_handle(), + src_state, + src_state_cell, + dst, + dst_state, + dst_state_cell, + data_dtype); + } + + for (auto& inf_lyr : fwd_inf_vec_) + RegisterMKLDNNRnn(inf_lyr); + + CommitOutput(outputs[rnn_enum::kOut], out_mem); + if (default_param.state_outputs) { + CommitOutput(outputs[rnn_enum::kStateOut], stateout_mem); + if (default_param.mode == rnn_enum::kLstm) + CommitOutput(outputs[rnn_enum::kStateCellOut], statecellout_mem); + } + MKLDNNStream::Get()->Submit(); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/quantization/quantize_asym-inl.h b/src/operator/quantization/quantize_asym-inl.h new file mode 100644 index 000000000000..3aa44c4e4fd6 --- /dev/null +++ b/src/operator/quantization/quantize_asym-inl.h @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file quantize_asym-inl.h + * \brief implementation of asymmetric quantize operation + */ +#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZE_ASYM_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_QUANTIZE_ASYM_INL_H_ + +#include +#include +#include +#include +#include + +#include "../mshadow_op.h" +#include "../mxnet_op.h" +#include "../tensor/broadcast_reduce_op.h" +#include "./quantization_utils.h" + +namespace mxnet { +namespace op { + +struct QuantizeAsymParam : public dmlc::Parameter { + dmlc::optional min_calib_range; + dmlc::optional max_calib_range; + + DMLC_DECLARE_PARAMETER(QuantizeAsymParam) { + DMLC_DECLARE_FIELD(min_calib_range) + .set_default(dmlc::optional()) + .describe( + "The minimum scalar value in the form of float32. If " + "present, it will be used to " + "quantize the fp32 data."); + DMLC_DECLARE_FIELD(max_calib_range) + .set_default(dmlc::optional()) + .describe( + "The maximum scalar value in the form of float32. If " + "present, it will be used to " + "quantize the fp32 data."); + } +}; + +// quantize float to uint8_t +struct quantize_asymmetric { + template + MSHADOW_XINLINE static void Map(int i, + DstDType* out, + float* oscale, + float* oshift, + const SrcDType* in, + const float scale, + const float shift) { + out[i] = static_cast(in[i] * scale + shift + 0.5); + *oscale = scale; + *oshift = shift; + } +}; + +template +class QuantizeAsymOp { + public: + explicit QuantizeAsymOp(const nnvm::NodeAttrs& attrs) : attrs_(attrs) {} + + void Forward(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + using mshadow::red::limits::MaxValue; + using mshadow::red::limits::MinValue; + + CHECK_EQ(outputs[0].type_flag_, mshadow::kUint8) + << "Asymmetric quantization only supports uint8 outputs."; + mshadow::Stream* s = ctx.get_stream(); + const int input_data_dtype = inputs[0].type_flag_; + if (input_data_dtype == mshadow::kUint8) { + *outputs[1].dptr() = 1; + *outputs[2].dptr() = 0; + UnaryOp::IdentityCompute(attrs_, ctx, {inputs[0]}, req, outputs); + } else if (input_data_dtype == mshadow::kInt8) { + const float scale = 1; + const float shift = 128; + Kernel::Launch(s, + outputs[0].Size(), + outputs[0].dptr(), + outputs[1].dptr(), + outputs[2].dptr(), + inputs[0].dptr(), + scale, + shift); + } else if (input_data_dtype == mshadow::kFloat32) { + const QuantizeAsymParam& param = nnvm::get(attrs_.parsed); + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + const float scale = + MaxValue() / (param.max_calib_range.value() - param.min_calib_range.value()); + const float shift = MaxValue() - param.max_calib_range.value() * scale; + Kernel::Launch(s, + outputs[0].Size(), + outputs[0].dptr(), + outputs[1].dptr(), + outputs[2].dptr(), + inputs[0].dptr(), + scale, + shift); + } else { + mxnet::TShape src_shape, dst_shape; + const size_t float_bytes = sizeof(float); + const size_t temp_reduce_size = ConfigReduce( + s, inputs[0].shape_, mxnet::TShape(1, 1), &src_shape, &dst_shape); + Tensor temp_space = ctx.requested[0].get_space_typed( + Shape1(2 * float_bytes + temp_reduce_size), s); + const int dev_id = ctx.run_ctx.ctx.dev_id; + TBlob in_min_t( + reinterpret_cast(temp_space.dptr_), Shape1(1), xpu::kDevMask, dev_id); + TBlob in_max_t( + reinterpret_cast(temp_space.dptr_) + 1, Shape1(1), xpu::kDevMask, dev_id); + Tensor workspace( + temp_space.dptr_ + 2 * float_bytes, Shape1(temp_reduce_size), s); + broadcast::Reduce( + s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); + broadcast::Reduce( + s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); + const float scale = + MaxValue() / (*in_max_t.dptr() - *in_min_t.dptr()); + const float shift = MaxValue() - *in_max_t.dptr() * scale; + Kernel::Launch(s, + outputs[0].Size(), + outputs[0].dptr(), + outputs[1].dptr(), + outputs[2].dptr(), + inputs[0].dptr(), + scale, + shift); + } + } else { + LOG(FATAL) << "Asymmetric quantizaiton only supports int8, uint8 and " + "float inputs"; + } + } + + private: + nnvm::NodeAttrs attrs_; +}; + +template +void QuantizeAsymForward(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + QuantizeAsymOp& op = state_ptr.get_state>(); + op.Forward(ctx, inputs, req, outputs); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZE_ASYM_INL_H_ diff --git a/src/operator/quantization/quantize_asym.cc b/src/operator/quantization/quantize_asym.cc new file mode 100644 index 000000000000..f44dadb5b991 --- /dev/null +++ b/src/operator/quantization/quantize_asym.cc @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file quantize_asym.cc + * \brief implementation of asymmetric quantize operation + */ + +#include "./quantize_asym-inl.h" +#if MXNET_USE_MKLDNN == 1 +#include "./mkldnn/mkldnn_quantize_asym-inl.h" +#endif + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(QuantizeAsymParam); + +inline bool QuantizeAsymShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 3U); + + mxnet::TShape dshape = in_attrs->at(0); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape); + SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape(1, 1)); + SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape(1, 1)); + + if (out_attrs->at(0).ndim() > 0) { + dshape[0] = out_attrs->at(0)[0]; + SHAPE_ASSIGN_CHECK(*in_attrs, 0, dshape); + } + + return !shape_is_none(out_attrs->at(0)); +} + +inline bool QuantizeAsymType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 3U); + + CHECK_EQ(in_attrs->at(0), mshadow::kFloat32); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8); + TYPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_attrs, 2, mshadow::kFloat32); + + return !type_is_none(out_attrs->at(0)); +} + +bool QuantizeAsymStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + *dispatch_mode = DispatchMode::kFCompute; +#if MXNET_USE_MKLDNN == 1 + if (dev_mask == mshadow::cpu::kDevMask) { + *dispatch_mode = DispatchMode::kFComputeEx; + } +#endif + out_attrs->at(0) = kDefaultStorage; + out_attrs->at(1) = kDefaultStorage; + out_attrs->at(2) = kDefaultStorage; + return true; +} + +OpStatePtr CreateQuantizeAsymState(const nnvm::NodeAttrs& attrs, + const Context& ctx, + const std::vector& in_shapes, + const std::vector& in_types) { + OpStatePtr state; + if (ctx.dev_type == kGPU) { + state = OpStatePtr::Create>(attrs); + } else { +#if MXNET_USE_MKLDNN == 1 + if (in_shapes[0].ndim() == 3 && in_types[0] == mshadow::kFloat32) { + state = OpStatePtr::Create(attrs); + return state; + } +#else + state = OpStatePtr::Create>(attrs); +#endif + } + return state; +} + +NNVM_REGISTER_OP(_contrib_quantize_asym) + .describe(R"code(Quantize a input tensor from float to uint8_t. +Output `scale` and `shift` are scalar floats that specify the quantization parameters for the input +data. +The output is calculated using the following equation: +`out[i] = in[i] * scale + shift + 0.5`, +where `scale = uint8_range / (max_range - min_range)` and +`shift = numeric_limits::max - max_range * scale`. +.. Note:: + This operator only supports forward propagation. DO NOT use it in training.)code" ADD_FILELINE) + .set_attr_parser(ParamParser) + .set_num_inputs(1) + .set_num_outputs(3) + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) + .set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output", "scale", "shift"}; + }) + .set_attr("FInferShape", QuantizeAsymShape) + .set_attr("FInferType", QuantizeAsymType) + .set_attr("FInferStorageType", QuantizeAsymStorageType) + .set_attr("FGradient", MakeZeroGradNodes) + .set_attr("FCreateOpState", CreateQuantizeAsymState) +#if MXNET_USE_MKLDNN == 1 + .set_attr("TIsMKLDNN", true) + .set_attr("FStatefulComputeEx", MKLDNNQuantizeAsymForward) +#endif + .set_attr("FStatefulCompute", QuantizeAsymForward) + .set_attr("FNeedCalibrateInput", + [](const NodeAttrs& attrs) { return std::vector{0}; }) + .set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + const QuantizeAsymParam& param = + nnvm::get(attrs.parsed); + if (param.max_calib_range.has_value() && + param.max_calib_range.has_value()) { + return std::vector(); + } else { + return std::vector( + 1, ResourceRequest::kTempSpace); + } + }) + .add_argument("data", "NDArray-or-Symbol", "A ndarray/symbol of type `float32`") + .add_arguments(QuantizeAsymParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index 9a8cbc24ff45..a8fd6b07bb0f 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2016 by Contributors * \file quantize_graph_pass.cc * \brief */ @@ -251,10 +250,14 @@ Graph QuantizeGraph(Graph &&src) { static const auto& need_requantize_map = Op::GetAttr("FNeedRequantize"); static const auto& avoid_quantize_input_map = Op::GetAttr("FAvoidQuantizeInput"); - const auto offline_params = src.GetAttr>("offline_params"); - const auto quantized_dtype = src.GetAttr("quantized_dtype"); + static const auto& avoid_dequantize_map = + Op::GetAttr("FAvoidDequantizeOutput"); + static const auto& need_asym_quantize_map = + Op::GetAttr("FNeedAsymQuantizeInput"); + const auto offline_params = src.GetAttr>("offline_params"); + const auto quantized_dtype = src.GetAttr("quantized_dtype"); const auto quantize_granularity = src.GetAttr("quantize_granularity"); - const auto dev_type = src.GetAttr("target_ctx"); + const auto dev_type = src.GetAttr("target_ctx"); if (dev_type == Context::kGPU && quantize_granularity == "channel-wise") { LOG(FATAL) << "`channel-wise` quantization option is not supported yet by GPU," @@ -295,7 +298,14 @@ Graph QuantizeGraph(Graph &&src) { if (avoid_quantize_input_map.count(node->op()) && avoid_quantize_input_map[node->op()](node->attrs, i, quantize_granularity)) { new_node->inputs.emplace_back(mirror_entry); - } else if (!quantized_node_map.count(e.node)) { + } else if (!quantized_node_map.count(e.node) || + (avoid_dequantize_map.count(e.node->op()) && + avoid_dequantize_map[e.node->op()](e.node->attrs, e.index))) { + // If the input of current quantized node has non-support of quantization, a quantize op + // is supposed to insert into the position after the input node to quantize the float + // input to int8/uint8 type. Also, a quantized operator with avoid-dequantize attribute + // can produce float outputs directly. A quantize op is necessary to convert them into + // int8/uint8 type as the input of current quantized node. if (mirror_entry_map.count(e)) { new_node->inputs.emplace_back(mirror_entry_map[e]); } else { @@ -314,9 +324,22 @@ Graph QuantizeGraph(Graph &&src) { } } - ObjectPtr quantize_node = InsertNode("_contrib_quantize_v2", - e.node->attrs.name + suffix + "_quantize", new_node, mirror_entry); - quantize_node->attrs.dict["out_type"] = quantized_dtype; + ObjectPtr quantize_node; + if (need_asym_quantize_map.count(node->op()) && + need_asym_quantize_map[node->op()](node->attrs, i)) { + quantize_node = InsertNode("_contrib_quantize_asym", + e.node->attrs.name + suffix + "_quantize", + new_node, + mirror_entry); + } else { + quantize_node = InsertNode("_contrib_quantize_v2", + e.node->attrs.name + suffix + "_quantize", + new_node, + mirror_entry); + // If current node is rnn op, the quantize op is supposed to quantize the result of + // pre-node to uint8, as quantized rnn op requires uint8 input. + quantize_node->attrs.dict["out_type"] = quantized_dtype; + } quantize_node->op()->attr_parser(&(quantize_node->attrs)); mirror_entry_map[e] = NodeEntry{quantize_node, 0, e.version}; } @@ -401,9 +424,13 @@ Graph QuantizeGraph(Graph &&src) { ObjectPtr mirror_node = mirror_map.at(e.node.get()); NodeEntry mirror_entry = NodeEntry{ mirror_node, e.index, e.version}; - // if input node is quantized operator, add dequantize node + // If input node is quantized operator, add dequantize node. But if input node is a + // quantized operator with avoid-dequantize attribute, its output may be already in float + // type, which dosen't need a dequantize op. if (quantized_node_map.count(e.node) && - (mirror_node->op() != Op::Get("_contrib_dequantize"))) { + mirror_node->op() != Op::Get("_contrib_dequantize") && + !(avoid_dequantize_map.count(e.node->op()) && + avoid_dequantize_map[e.node->op()](e.node->attrs, e.index))) { // here we calculate the output number (exclude min/max, in order to // calculate min/max index from mirror node) based on assumption that // there is only 1 min and 1 max output from mirror node (which is @@ -435,7 +462,9 @@ Graph QuantizeGraph(Graph &&src) { std::vector outputs; for (const auto& e : src.outputs) { - if (quantized_node_map.count(e.node)) { + if (quantized_node_map.count(e.node) && + !(avoid_dequantize_map.count(e.node->op()) && + avoid_dequantize_map[e.node->op()](e.node->attrs, e.index))) { // Only insert dequantize for those Ops supports quantize and not excluded. ObjectPtr mirror_node = mirror_map.at(e.node.get()); NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version}; diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc index 9a30386723be..8d4a87655a31 100644 --- a/src/operator/quantization/quantize_v2.cc +++ b/src/operator/quantization/quantize_v2.cc @@ -18,9 +18,7 @@ */ /*! - * Copyright (c) 2017 by Contributors - * \file quantize.cc - * \brief + * \file quantize_v2.cc */ #include "./quantize_v2-inl.h" diff --git a/src/operator/quantization/quantized_rnn-inl.h b/src/operator/quantization/quantized_rnn-inl.h new file mode 100644 index 000000000000..6ab53cef867c --- /dev/null +++ b/src/operator/quantization/quantized_rnn-inl.h @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file quantized_rnn-inl.h + * \brief Common functions for quantized recurrent neural network + * \author Zixuan Wei + */ + +#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RNN_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RNN_INL_H_ + +namespace mxnet { +namespace op { + +namespace quantized_rnn { +enum QuantizedRnnInputs { kData, kParams, kState, kStateCell }; +enum QuantizedRnnInputMinMax { kDataScale, kDataShift }; +enum QuantizedRnnOutputs { kOut, kStateOut, kStateCellOut }; +} // namespace quantized_rnn + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RNN_INL_H_ diff --git a/src/operator/quantization/quantized_rnn.cc b/src/operator/quantization/quantized_rnn.cc new file mode 100644 index 000000000000..1870fd2a40c2 --- /dev/null +++ b/src/operator/quantization/quantized_rnn.cc @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file quantized_rnn.cc + * \brief Common functions for quantized recurrent neural network + * \author Zixuan Wei + */ + +#include +#include +#include + +#include "../rnn-inl.h" +#include "./quantization_utils.h" +#include "./quantized_rnn-inl.h" + +#if MXNET_USE_MKLDNN == 1 +#include "./mkldnn/mkldnn_quantized_rnn-inl.h" +#endif + +namespace mxnet { +namespace op { + +uint32_t QuantizedRnnNumInputs(const NodeAttrs& attrs) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) + << "Quantized recurrent neural network only supports LSTM operator on " + "CPU."; + return 6U; +} + +uint32_t QuantizedRnnNumOutputs(const NodeAttrs& attrs) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) + << "Quantized recurrent neural network only supports LSTM operator on " + "CPU."; + return param.state_outputs ? 3U : 1U; +} + +std::vector QuantizedRnnInputNames(const NodeAttrs& attrs) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) + << "Quantized recurrent neural network only supports LSTM operator on " + "CPU."; + return std::vector{ + "data", "parameters", "state", "state_cell", "min_data", "max_data"}; +} + +std::vector QuantizedRnnOutputNames(const NodeAttrs& attrs) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) + << "Quantized recurrent neural network only supports LSTM operator on " + "CPU."; + if (param.state_outputs) { + return std::vector{"output", "state_output", "statecell_ouput"}; + } else { + return std::vector{"output"}; + } +} + +bool QuantizedRnnShape(const nnvm::NodeAttrs& attrs, + std::vector* in_shape, + std::vector* out_shape) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) << "Quantized RNN operator only supports LSTM mode."; + + const uint32_t num_inputs = QuantizedRnnNumInputs(attrs); + const uint32_t num_outputs = QuantizedRnnNumOutputs(attrs); + CHECK_EQ(in_shape->size(), num_inputs) + << "Arguments' size of quantized RNN operator is mismatched. Expected " << num_inputs + << " argmuments but got " << in_shape->size() << "."; + CHECK_EQ(out_shape->size(), num_outputs); + + const mxnet::TShape dshape = in_shape->at(quantized_rnn::kData); + if (!mxnet::ndim_is_known(dshape)) + return false; + CHECK_EQ(dshape.ndim(), 3U) << "Input data of RNN operator should be 3-rank " + "tensor of dim [steps, batch, input size]"; + const dim_t batch_size = dshape[1]; + const dim_t input_size = dshape[2]; + const dim_t directions = param.bidirectional ? 2 : 1; + const dim_t total_lyrs = directions * param.num_layers; + const dim_t state_size = param.state_size; + SHAPE_ASSIGN_CHECK(*in_shape, quantized_rnn::kState, Shape3(total_lyrs, batch_size, state_size)); + if (param.mode == rnn_enum::kLstm) + SHAPE_ASSIGN_CHECK( + *in_shape, quantized_rnn::kStateCell, Shape3(total_lyrs, batch_size, state_size)); + + const int param_size_fp = GetRnnParamSize( + param.num_layers, input_size, state_size, directions, param.mode, param.projection_size); + SHAPE_ASSIGN_CHECK(*in_shape, quantized_rnn::kParams, Shape1(param_size_fp)); + const uint32_t num_base_inputs = GetRnnNumInputs(param); + for (size_t i = num_base_inputs; i < num_inputs; ++i) + SHAPE_ASSIGN_CHECK(*in_shape, i, Shape1(1)); + + out_shape->clear(); + out_shape->push_back({dshape[0], batch_size, directions * state_size}); // output dim: [T, N, C] + if (param.state_outputs) { + out_shape->push_back({total_lyrs, batch_size, state_size}); // state dim: [L*D, N, C] + if (param.mode == rnn_enum::kLstm) + out_shape->push_back({total_lyrs, batch_size, state_size}); // cell dim: [L*D, N, C] + } + return true; +} + +bool QuantizedRnnType(const nnvm::NodeAttrs& attrs, + std::vector* in_type, + std::vector* out_type) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) << "Quantized RNN operator only supports LSTM mode."; + + const uint32_t num_inputs = QuantizedRnnNumInputs(attrs); + const uint32_t num_outputs = QuantizedRnnNumOutputs(attrs); + CHECK_EQ(in_type->size(), num_inputs); + CHECK_EQ(out_type->size(), num_outputs); + + CHECK_EQ(in_type->at(quantized_rnn::kData), mshadow::kUint8) + << "Quantized RNN operator only supports uint8 input, while " + << in_type->at(quantized_rnn::kData) << " is given."; + TYPE_ASSIGN_CHECK(*in_type, quantized_rnn::kParams, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*in_type, quantized_rnn::kState, mshadow::kFloat32); + const uint32_t num_base_inputs = GetRnnNumInputs(param); + if (param.mode == rnn_enum::kLstm) + TYPE_ASSIGN_CHECK(*in_type, quantized_rnn::kStateCell, mshadow::kFloat32); + for (size_t i = num_base_inputs; i < num_inputs; ++i) + TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kFloat32); + + TYPE_ASSIGN_CHECK(*out_type, quantized_rnn::kOut, mshadow::kFloat32); + if (param.state_outputs) { + TYPE_ASSIGN_CHECK(*out_type, quantized_rnn::kStateOut, mshadow::kFloat32); + if (param.mode == rnn_enum::kLstm) + TYPE_ASSIGN_CHECK(*out_type, quantized_rnn::kStateCellOut, mshadow::kFloat32); + } + return true; +} + +bool QuantizedRnnStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + const uint32_t num_inputs = QuantizedRnnNumInputs(attrs); + const uint32_t num_outputs = QuantizedRnnNumOutputs(attrs); + CHECK_EQ(in_attrs->size(), num_inputs); + CHECK_EQ(out_attrs->size(), num_outputs); + +#if MXNET_USE_MKLDNN == 1 + return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs); +#else + *dispatch_mode = DispatchMode::kFCompute; + + for (auto& v : *out_attrs) { + v = kDefaultStorage; + if (common::stype_string(v).compare("unknown") == 0) { + return false; + } + } + + for (auto& v : *in_attrs) { + v = kDefaultStorage; + if (common::stype_string(v).compare("unknown") == 0) { + return false; + } + } + return true; +#endif +} + +void QuantizedRnnParamParser(nnvm::NodeAttrs* attrs) { + RNNParam param; + attrs->dict["quantized"] = "true"; + try { + param.Init(attrs->dict, dmlc::parameter::kAllowUnknown); + } catch (const dmlc::ParamError& e) { + std::ostringstream os; + os << e.what(); + os << ", in operator " << attrs->op->name << "(" + << "name=\"" << attrs->name << "\""; + for (const auto& k : attrs->dict) { + os << ", " << k.first << "=\"" << k.second << "\""; + } + os << ")"; + throw dmlc::ParamError(os.str()); + } + attrs->parsed = std::move(param); +} + +OpStatePtr CreateQuantizedRnnState(const nnvm::NodeAttrs& attrs, + const Context ctx, + const mxnet::ShapeVector& in_shapes, + const std::vector& in_types) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) << "Quantized RNN operator only supports LSTM mode."; + OpStatePtr state = OpStatePtr(); +#if MXNET_USE_MKLDNN == 1 + const int data_type = in_types[quantized_rnn::kData]; + const int weight_type = in_types[quantized_rnn::kParams]; + if (data_type == mshadow::kUint8 && weight_type == mshadow::kFloat32) { + const mxnet::TShape& data_shape = in_shapes[quantized_rnn::kData]; + state = OpStatePtr::Create( + attrs, data_shape[0], data_shape[1], data_shape[2]); + } +#else + LOG(FATAL) << "Quantized RNN operator relies on MKL-DNN library." + << " Please build MXNet with USE_MKLDNN=ON to leverage this operator."; +#endif + return state; +} + +void QuantizedRnnForwardCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& in_data, + const std::vector& req, + const std::vector& out_data) { + LOG(FATAL) << "Quantized RNN operator relies on MKL-DNN library." + << " Please build MXNet with USE_MKLDNN=ON to leverage this operator."; +} + +#if MXNET_USE_MKLDNN == 1 +void QuantizedRnnForwardCPUEx(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& in_data, + const std::vector& req, + const std::vector& out_data) { + MKLDNNQuantizedRnnOp& op = state_ptr.get_state(); + op.Forward(ctx, in_data, req, out_data); +} +#endif // MXNET_USE_MKLDNN == 1 + +bool NeedAsymQuantizeRnnInput(const NodeAttrs& attrs, const size_t index_to_check) { + bool need_asym_quantize = false; + switch (index_to_check) { + case rnn_enum::kData: { + need_asym_quantize = true; + break; + } + default: { + need_asym_quantize = false; + } + } + return need_asym_quantize; +} + +bool AvoidRnnQuantizeInput(const NodeAttrs& attrs, + const size_t index_to_check, + const std::string quantize_granularity) { + std::unordered_set avoid_indexes; + avoid_indexes.insert({quantized_rnn::kParams, quantized_rnn::kState, quantized_rnn::kStateCell}); + + return avoid_indexes.count(index_to_check); +} + +bool AvoidRnnDequantizeOutput(const NodeAttrs& attrs, const size_t index_to_check) { + return true; +} + +static std::vector QuantizedRnnResourceEx(const NodeAttrs& attrs, + const int dev_mask, + const DispatchMode dispatch_mode) { + std::vector request; + if (dev_mask == kGPU) { +#if MXNET_USE_CUDNN == 1 + LOG(FATAL) << "Currently, quantized RNN is not supported on the GPU platform."; +#endif + } else { +#if MXNET_USE_MKLDNN == 1 + request.emplace_back(ResourceRequest::kTempSpace); +#endif + } + return request; +} + +NNVM_REGISTER_OP(_contrib_quantized_rnn) + .describe( + R"code(RNN operator for input data type of uint8. The weight of each gates is converted +to int8, while bias is accumulated in type float32. The hidden state and cell state are in type +float32. For the input data, two more arguments of type float32 must be provided representing the +thresholds of quantizing argument from data type float32 to uint8. The final outputs contain the +recurrent result in float32. It only supports quantization for Vanilla LSTM network. +.. Note:: + This operator only supports forward propagation. DO NOT use it in training.)code" ADD_FILELINE) + .set_num_inputs(QuantizedRnnNumInputs) + .set_num_outputs(QuantizedRnnNumOutputs) + .set_attr_parser(QuantizedRnnParamParser) + .set_attr("FListInputNames", QuantizedRnnInputNames) + .set_attr("FListOutputNames", QuantizedRnnOutputNames) + .set_attr("FInferShape", QuantizedRnnShape) + .set_attr("FInferType", QuantizedRnnType) + .set_attr("FInferStorageType", QuantizedRnnStorageType) + .set_attr("FCreateOpState", CreateQuantizedRnnState) + .set_attr("FStatefulCompute", QuantizedRnnForwardCPU) +#if MXNET_USE_MKLDNN == 1 + .set_attr("TIsMKLDNN", true) + .set_attr("FStatefulComputeEx", QuantizedRnnForwardCPUEx) +#endif + .set_attr("FResourceRequestEx", QuantizedRnnResourceEx) + .add_argument("data", "NDArray-or-Symbol", "Input data.") + .add_argument("parameters", "NDArray-or-Symbol", "weight.") + .add_argument("state", "NDArray-or-Symbol", "initial hidden state of the RNN") + .add_argument("state_cell", + "NDArray-or-Symbol", + "initial cell state for LSTM networks (only for LSTM)") + .add_argument("data_scale", "NDArray-or-Symbol", "quantization scale of data.") + .add_argument("data_shift", "NDArray-or-Symbol", "quantization shift of data.") + .add_arguments(RNNParam::__FIELDS__()); + +NNVM_REGISTER_OP(RNN) + .set_attr("FQuantizable", + [](const NodeAttrs& attrs) { +#if MXNET_USE_MKLDNN == 1 + const RNNParam& param = nnvm::get(attrs.parsed); + if (param.mode != rnn_enum::kLstm) + LOG(INFO) << "Quantized RNN only supports LSTM mode."; + return param.mode == rnn_enum::kLstm ? QuantizeType::kMust : + QuantizeType::kNone; +#else + LOG(INFO) << "Quantized RNN is not supported by this MXNet release. Please enable MKL-DNN to " + << "use the feature."; + return QuantizeType::kNone; +#endif // MXNET_USE_MKLDNN == 1 + }) + .set_attr("FQuantizedOp", + [](const NodeAttrs& attrs) { + nnvm::ObjectPtr node = nnvm::Node::Create(); + node->attrs.op = Op::Get("_contrib_quantized_rnn"); + node->attrs.name = "quantized_" + attrs.name; + node->attrs.dict = attrs.dict; + node->attrs.dict["quantized"] = "true"; + if (node->op()->attr_parser != nullptr) { + node->op()->attr_parser(&(node->attrs)); + } + return node; + }) + .set_attr("FNeedAsymQuantizeInput", NeedAsymQuantizeRnnInput) + .set_attr("FAvoidQuantizeInput", AvoidRnnQuantizeInput) + .set_attr("FAvoidDequantizeOutput", AvoidRnnDequantizeOutput); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 1d6f2eb5c36a..c55abf9aa49e 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -18,9 +18,7 @@ */ /*! - * Copyright (c) 2015 by Contributors * \file rnn-inl.h - * \brief * \author Sebastian Bodenstein, Shu Zhang */ #ifndef MXNET_OPERATOR_RNN_INL_H_ @@ -244,9 +242,9 @@ inline size_t GetRNNReserveSpaceSize(int num_layer, return size; } -inline size_t GetNumInputArguments(RNNParam param_) { - size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4U : 3U; - if (param_.use_sequence_length) num_inputs += 1U; +inline size_t GetRnnNumInputs(RNNParam param) { + size_t num_inputs = (param.mode == rnn_enum::kLstm) ? 4U : 3U; + if (param.use_sequence_length) num_inputs += 1U; return num_inputs; } @@ -571,7 +569,7 @@ class RNNOp { using namespace mshadow::expr; CHECK(param_.p >= 0.0f && param_.p < 1.0f) << "unsupported dropout value, should be 0 <= dropout < 1"; - size_t num_inputs = GetNumInputArguments(param_); + size_t num_inputs = GetRnnNumInputs(param_); // kOut size_t num_outputs = 1; @@ -940,7 +938,7 @@ class RNNOp { CHECK(param_.p >= 0.0f && param_.p < 1.0f) << "unsupported dropout value, should be 0 <= dropout < 1"; - size_t num_inputs = GetNumInputArguments(param_); + size_t num_inputs = GetRnnNumInputs(param_); // kOut size_t num_outputs = 1; @@ -1178,7 +1176,7 @@ class RNNOp { const std::vector &out_data) { using namespace mshadow; - size_t num_inputs = GetNumInputArguments(param_); + size_t num_inputs = GetRnnNumInputs(param_); // kOut size_t num_outputs = 1; if (param_.state_outputs) { diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index efebc915a0e7..514b4c5c7326 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2015 by Contributors * \file rnn.cc * \brief * \author Sebastian Bodenstein @@ -35,31 +34,41 @@ namespace mxnet { namespace op { DMLC_REGISTER_PARAMETER(RNNParam); -static inline std::vector ListArguments(const RNNParam& param_) { +static inline std::vector ListRnnInputNames(const RNNParam& param) { // All RNNs start off with same 3 input arguments std::vector arguments{"data", "parameters", "state"}; // LSTMs also have an additional state_cell argument - if (param_.mode == rnn_enum::kLstm) { + if (param.mode == rnn_enum::kLstm) { arguments.emplace_back("state_cell"); } // All RNNs have option of additional sequence_length argument - if (param_.use_sequence_length) { + if (param.use_sequence_length) { arguments.emplace_back("sequence_length"); } return arguments; } +static inline std::vector ListRnnOutputNames(const RNNParam& param) { + std::vector names{"output"}; + if (param.state_outputs) { + names.emplace_back("state_output"); + if (param.mode == rnn_enum::kLstm) + names.emplace_back("statecell_output"); + } + return names; +} + static bool RNNShape(const nnvm::NodeAttrs& attrs, std::vector *in_shape, std::vector *out_shape) { - const RNNParam& param_ = nnvm::get(attrs.parsed); using namespace mshadow; + const RNNParam& param = nnvm::get(attrs.parsed); - // Query param_ object to figure out what the expectd input arguments are - std::vector expected_arguments = ListArguments(param_); + // Query param object to figure out what the expectd input arguments are + std::vector expected_arguments = ListRnnInputNames(param); CHECK_EQ(in_shape->size(), expected_arguments.size()) << "Input shape mismatch. Expected " << expected_arguments.size() << " input parameters but got " << in_shape->size() << "."; @@ -69,34 +78,31 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(dshape.ndim(), 3U) \ << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]"; // data: [sequence len, batch, input dimension] - int batch_size = dshape[1]; - int input_size = dshape[2]; - int numDirections = param_.bidirectional ? 2 : 1; - int total_layers = numDirections * param_.num_layers; // double for bidirectional - int layer_size = (param_.projection_size.has_value()) ? - param_.projection_size.value() : param_.state_size; - SHAPE_ASSIGN_CHECK(*in_shape, - rnn_enum::kState, - Shape3(total_layers, batch_size, layer_size)); - if (param_.mode == rnn_enum::kLstm) { - SHAPE_ASSIGN_CHECK(*in_shape, - rnn_enum::kStateCell, - Shape3(total_layers, batch_size, param_.state_size)); + int batch_size = dshape[1]; + int input_size = dshape[2]; + int numDirections = param.bidirectional ? 2 : 1; + int total_layers = numDirections * param.num_layers; // double for bidirectional + int layer_size = + (param.projection_size.has_value()) ? param.projection_size.value() : param.state_size; + SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kState, Shape3(total_layers, batch_size, layer_size)); + if (param.mode == rnn_enum::kLstm) { + SHAPE_ASSIGN_CHECK( + *in_shape, rnn_enum::kStateCell, Shape3(total_layers, batch_size, param.state_size)); } // calculate parameter vector length - int param_size = GetRnnParamSize(param_.num_layers, + int param_size = GetRnnParamSize(param.num_layers, input_size, - param_.state_size, + param.state_size, numDirections, - param_.mode, - param_.projection_size); + param.mode, + param.projection_size); SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size)); // Check on sequence_length shape if using - if (param_.use_sequence_length) { + if (param.use_sequence_length) { size_t seq_len_input_idx = rnn_enum::kSequenceLength; - if (param_.mode != rnn_enum::kLstm) --seq_len_input_idx; + if (param.mode != rnn_enum::kLstm) --seq_len_input_idx; SHAPE_ASSIGN_CHECK(*in_shape, seq_len_input_idx, Shape1(batch_size)); } @@ -104,29 +110,29 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs, out_shape->clear(); // output: [sequence len, batch, output size] TShape oshape = dshape; - if (param_.projection_size.has_value()) { - oshape[2] = numDirections * param_.projection_size.value(); + if (param.projection_size.has_value()) { + oshape[2] = numDirections * param.projection_size.value(); } else { - oshape[2] = numDirections * param_.state_size; + oshape[2] = numDirections * param.state_size; } out_shape->push_back(oshape); - if (param_.state_outputs) { + if (param.state_outputs) { // outStateShape: [layer_num, batch, state size] TShape outStateShape = dshape; outStateShape[0] = total_layers; outStateShape[1] = batch_size; - if (param_.projection_size.has_value()) { - outStateShape[2] = param_.projection_size.value(); + if (param.projection_size.has_value()) { + outStateShape[2] = param.projection_size.value(); } else { - outStateShape[2] = param_.state_size; + outStateShape[2] = param.state_size; } out_shape->push_back(outStateShape); // Deal with lstm cell state - if (param_.mode == rnn_enum::kLstm) { + if (param.mode == rnn_enum::kLstm) { TShape cellStateShape = dshape; cellStateShape[0] = total_layers; cellStateShape[1] = batch_size; - cellStateShape[2] = param_.state_size; + cellStateShape[2] = param.state_size; out_shape->push_back(cellStateShape); } } @@ -137,33 +143,33 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs, static bool RNNType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type) { - const RNNParam& param_ = nnvm::get(attrs.parsed); + const RNNParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(in_type->size(), GetNumInputArguments(param_)); + CHECK_EQ(in_type->size(), GetRnnNumInputs(param)); size_t seq_len_input_idx = rnn_enum::kSequenceLength; - if (param_.mode != rnn_enum::kLstm) --seq_len_input_idx; + if (param.mode != rnn_enum::kLstm) --seq_len_input_idx; int dtype = (*in_type)[0]; CHECK_NE(dtype, -1) << "First input must have specified type"; - std::vector arguments = ListArguments(param_); + std::vector arguments = ListRnnInputNames(param); for (size_t i = 0; i < in_type->size(); ++i) { if ((*in_type)[i] == -1) { TYPE_ASSIGN_CHECK(*in_type, i, dtype); } else { // If using sequence length argument, it has its own indexing type // All other input arguments must match the main data type - if (!(param_.use_sequence_length && i == seq_len_input_idx)) { + if (!(param.use_sequence_length && i == seq_len_input_idx)) { UNIFORM_TYPE_CHECK((*in_type)[i], dtype, arguments[i]); } } } out_type->clear(); out_type->push_back(dtype); - if (param_.state_outputs) { + if (param.state_outputs) { out_type->push_back(dtype); // Deal with lstm cell state - if (param_.mode == rnn_enum::kLstm) { + if (param.mode == rnn_enum::kLstm) { out_type->push_back(dtype); } } @@ -248,7 +254,7 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs, #if MXNET_USE_MKLDNN == 1 if (ctx.dev_type == kCPU && SupportMKLDNNRnn(param, in_types[rnn_enum::kData])) { const mxnet::TShape& data_shape = in_shapes[rnn_enum::kData]; - state = OpStatePtr::Create(param, data_shape[0], + state = OpStatePtr::Create(attrs, data_shape[0], data_shape[1], data_shape[2]); return state; } @@ -370,7 +376,7 @@ The definition of GRU here is slightly different from paper but compatible with .set_attr_parser(ParamParser) .set_num_inputs([](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); - return GetNumInputArguments(params); + return GetRnnNumInputs(params); }) .set_num_outputs([](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); @@ -383,20 +389,13 @@ The definition of GRU here is slightly different from paper but compatible with return num_outputs; }) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); - return ListArguments(params); + return ListRnnInputNames(params); }) .set_attr("FListOutputNames", [](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); - std::vector names{"output"}; - if (params.state_outputs) { - names.emplace_back("state_output"); - if (params.mode == rnn_enum::kLstm) - names.emplace_back("statecell_output"); - } - return names; + return ListRnnOutputNames(params); }) .set_attr("FInferShape", RNNShape) .set_attr("FInferType", RNNType) @@ -437,7 +436,7 @@ NNVM_REGISTER_OP(_backward_RNN) }) .set_num_outputs([](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); - return GetNumInputArguments(params); + return GetRnnNumInputs(params); }) .set_attr_parser(ParamParser) .set_attr("TIsLayerOpBackward", true) diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index d36094fb9665..f85a5e83d214 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -594,6 +594,69 @@ def maxabs(a, b): check_quantized_fc((256, 2048, 2, 2), 800, True, qdtype) check_quantized_fc((256, 111, 2, 2), 800, True, qdtype) +@with_seed() +def test_quantized_rnn(): + def check_quantized_rnn(num_layers, bidirectional, seq_len, batch_size, input_dim, state_dim): + if is_test_for_gpu(): + print('skipped testing test_quantized_rnn for gpu since it is not supported yet') + return + if is_test_for_native_cpu(): + print('skipped testing test_quantized_rnn for native cpu since it is not supported yet') + return + + data_shape = (seq_len, batch_size, input_dim) + data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32') + rnn_fp32 = mx.sym.RNN(data=data, + num_layers=num_layers, + bidirectional=bidirectional, + state_outputs=True, + state_size=state_dim, + mode='lstm', + name='rnn') + arg_shapes, _, _ = rnn_fp32.infer_shape(data=data_shape) + arg_names = rnn_fp32.list_arguments() + rnn_fp32_exe = rnn_fp32.simple_bind(ctx=mx.current_context(), grad_req='null') + + data = mx.nd.random.uniform(low=-1, high=1, shape=arg_shapes[0]) + weight = mx.nd.random.uniform(low=-1, high=1, shape=arg_shapes[1]) + state = mx.nd.random.uniform(low=-1, high=1, shape=arg_shapes[2]) + cell = mx.nd.random.uniform(low=-1, high=1, shape=arg_shapes[3]) + + rnn_fp32_exe.arg_dict[arg_names[0]][:] = data + rnn_fp32_exe.arg_dict[arg_names[1]][:] = weight + rnn_fp32_exe.arg_dict[arg_names[2]][:] = state + rnn_fp32_exe.arg_dict[arg_names[3]][:] = cell + output = rnn_fp32_exe.forward()[0] + + data_min = mx.nd.min(data) + data_max = mx.nd.max(data) + qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype='uint8') + rnn_int8 = mx.sym.contrib.quantized_rnn(data=qdata, + num_layers=num_layers, + bidirectional=bidirectional, + state_outputs=True, + state_size=state_dim, + mode='lstm', + name='qrnn') + qarg_names = rnn_int8.list_arguments() + rnn_int8_exe = rnn_int8.simple_bind(ctx=mx.current_context(), grad_req='null') + data_scale = 128.0 / (data_max - data_min) + data_shift = 128.0 - data_max * data_scale + qdata = (data * data_scale + data_shift + 0.5).astype('uint8') + rnn_int8_exe.arg_dict[qarg_names[0]][:] = qdata + rnn_int8_exe.arg_dict[qarg_names[1]][:] = weight + rnn_int8_exe.arg_dict[qarg_names[2]][:] = state + rnn_int8_exe.arg_dict[qarg_names[3]][:] = cell + rnn_int8_exe.arg_dict[qarg_names[4]][:] = data_scale + rnn_int8_exe.arg_dict[qarg_names[5]][:] = data_shift + qoutput = rnn_int8_exe.forward()[0] + + mse = np.mean((output.asnumpy() - qoutput.asnumpy())**2) + assert mse < 0.001 + + check_quantized_rnn(1, False, 5, 2, 16, 16) + check_quantized_rnn(1, True, 5, 2, 16, 16) + @with_seed() def test_quantized_embedding(): def check_quantized_embedding(data_shape, input_dim, output_dim):