[v1.x] Restore quantized RNN operator from MXNet 1.6 (#20759)
* restore but seg fault

* Refactor & seg fault fixed

* apply formatter

* fix sanity

* Fix docs build

* anko review

* Remove copyright by contributors from touched files

* remove comments / apply formatter
bgawrych authored Dec 3, 2021
1 parent 6cb49f0 commit 8c69a9f
Showing 18 changed files with 1,712 additions and 167 deletions.
1 change: 1 addition & 0 deletions docs/python_docs/environment.yml
@@ -22,6 +22,7 @@ dependencies:
- conda>=4.6.13
- pip
- python
- setuptools==49.6.0
- jupyter
- sphinx==2.4.0
- matplotlib
15 changes: 15 additions & 0 deletions include/mxnet/op_attr_types.h
@@ -348,6 +348,21 @@ using FAvoidQuantizeInput = std::function<bool (const NodeAttrs& attrs,
const size_t index,
const std::string quantize_granularity)>;

/*!
* \brief Register a function to determine if the input of a quantized operator
* needs to be quantized asymmetrically.
*/
using FNeedAsymQuantizeInput = std::function<bool (const NodeAttrs& attrs,
const size_t index)>;

/*!
* \brief Register a function to determine if the output of a quantized operator
* needs to be dequantized. This is usually used for the quantized operators
* which can produce fp32 outputs directly.
*/
using FAvoidDequantizeOutput = std::function<bool (const NodeAttrs& attrs,
const size_t index)>;

/*!
* \brief Register a function to determine if the input of a quantized operator
* needs to be calibrated. This is usually used for the quantized operators
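
The two attribute types added above (FNeedAsymQuantizeInput and FAvoidDequantizeOutput) are meant to be queried by the quantization graph pass. A minimal, hypothetical sketch of how an operator could register them is shown below; the op name _contrib_quantized_example_op and the per-index decisions are illustrative only and not taken from this commit:

#include <mxnet/op_attr_types.h>
#include <nnvm/op.h>

NNVM_REGISTER_OP(_contrib_quantized_example_op)
    .set_attr<mxnet::FNeedAsymQuantizeInput>(
        "FNeedAsymQuantizeInput",
        [](const nnvm::NodeAttrs& attrs, const size_t index) {
          // Only the data input (index 0) is quantized asymmetrically (shifted uint8).
          return index == 0;
        })
    .set_attr<mxnet::FAvoidDequantizeOutput>(
        "FAvoidDequantizeOutput",
        [](const nnvm::NodeAttrs& attrs, const size_t index) {
          // Output 0 is already produced in fp32, so no dequantize node is inserted after it.
          return index == 0;
        });
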
23 changes: 16 additions & 7 deletions python/mxnet/io/io.py
@@ -37,7 +37,7 @@
from ..ndarray import array
from ..ndarray import concat, tile

from .utils import _init_data, _has_instance, _getdata_by_idx
from .utils import _init_data, _has_instance, _getdata_by_idx, _slice_along_batch_axis

class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
"""DataDesc is used to store name, shape, type and layout
@@ -602,10 +602,12 @@ class NDArrayIter(DataIter):
The data name.
label_name : str, optional
The label name.
layout : str, optional
The data layout, e.g. 'NCHW' or 'TNC'; the position of 'N' gives the batch axis.
"""
def __init__(self, data, label=None, batch_size=1, shuffle=False,
last_batch_handle='pad', data_name='data',
label_name='softmax_label'):
label_name='softmax_label', layout='NCHW'):
super(NDArrayIter, self).__init__(batch_size)

self.data = _init_data(data, allow_empty=False, default_name=data_name)
@@ -631,20 +633,27 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False,
# used for 'roll_over'
self._cache_data = None
self._cache_label = None
self.layout = layout

@property
def provide_data(self):
"""The name and shape of data provided by this iterator."""
batch_axis = self.layout.find('N')
return [
DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype)
DataDesc(k, tuple(list(v.shape[:batch_axis]) + \
[self.batch_size] + list(v.shape[batch_axis + 1:])),
v.dtype, layout=self.layout)
for k, v in self.data
]

@property
def provide_label(self):
"""The name and shape of label provided by this iterator."""
batch_axis = self.layout.find('N')
return [
DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype)
DataDesc(k, tuple(list(v.shape[:batch_axis]) + \
[self.batch_size] + list(v.shape[batch_axis + 1:])),
v.dtype, layout=self.layout)
for k, v in self.label
]

@@ -681,7 +690,7 @@ def next(self):
data = self.getdata()
label = self.getlabel()
# iter should stop when last batch is not complete
if data[0].shape[0] != self.batch_size:
if data[0].shape[self.layout.find('N')] != self.batch_size:
# in this case, cache it for next epoch
self._cache_data = data
self._cache_label = label
@@ -697,7 +706,7 @@ def _getdata(self, data_source, start=None, end=None):
end = data_source[0][1].shape[0] if data_source else 0
s = slice(start, end)
return [
x[1][s]
_slice_along_batch_axis(x[1], s, self.layout.find('N'))
if isinstance(x[1], (np.ndarray, NDArray)) else
# h5py (only supports indices in increasing order)
array(x[1][sorted(self.idx[s])][[
@@ -716,7 +725,7 @@ def _concat(self, first_data, second_data):
concat(
first_data[i],
second_data[i],
dim=0
dim=self.layout.find('N')
) for i in range(len(first_data))
]

5 changes: 5 additions & 0 deletions python/mxnet/io/utils.py
@@ -84,3 +84,8 @@ def _getdata_by_idx(data, idx):
shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))

return shuffle_data

def _slice_along_batch_axis(data, s, batch_axis):
"""Apply slice along the batch axis"""
ret = data.slice_axis(axis=batch_axis, begin=s.start, end=s.stop)
return ret
83 changes: 69 additions & 14 deletions src/operator/nn/mkldnn/mkldnn_rnn-inl.h
@@ -33,10 +33,42 @@

#include "../../rnn-inl.h"
#include "./mkldnn_base-inl.h"
#include "../../quantization/quantized_rnn-inl.h"

namespace mxnet {
namespace op {

struct MKLDNNRnnParam : public dmlc::Parameter<MKLDNNRnnParam> {
bool quantized;

DMLC_DECLARE_PARAMETER(MKLDNNRnnParam) {
DMLC_DECLARE_FIELD(quantized).set_default(false).describe(
"Whether it's a quantized RNN operator");
}
};

inline void MKLDNNMemoryReorder(const mkldnn::memory& src, const mkldnn::memory& dst) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<OpSignature, mkldnn::reorder, OpHash> reorderPrimitives;
#else
static MX_THREAD_LOCAL std::unordered_map<OpSignature, mkldnn::reorder, OpHash> reorderPrimitives;
#endif
OpSignature key{};
key.AddSign(src);
key.AddSign(dst);

auto it = reorderPrimitives.find(key);
if (it == reorderPrimitives.end()) {
auto reorder = mkldnn::reorder(src, dst);
it = AddToCache(&reorderPrimitives, key, reorder);
}

mkldnn_args_map_t net_args;
net_args.emplace(MKLDNN_ARG_SRC, src);
net_args.emplace(MKLDNN_ARG_DST, dst);
MKLDNNStream::Get()->RegisterPrimArgs(it->second, net_args);
}

struct MKLDNNRnnLayerParam {
using memory = mkldnn::memory;
using dims = mkldnn::memory::dims;
@@ -65,6 +97,10 @@ struct MKLDNNRnnLayerParam {
size_t native_single_b_size; // bias size of a single cell from framework
size_t single_state_size; // state size of a single cell, hy, cy

bool quantized; // whether this layer is quantized
bool enable_u8_output; // true by default; only false when this is the last fusion layer of the
// quantized RNN operator

MKLDNNRnnLayerParam(int num_layer,
int batch_size,
int seq_len,
@@ -79,18 +115,21 @@
batch_size(batch_size),
input_size(input_size),
state_size(state_size),
seq_len(seq_len) {}
seq_len(seq_len),
quantized(false),
enable_u8_output(false) {}

void SetDims();
};

typedef std::vector<MKLDNNRnnLayerParam> LayerParamVector;
struct MKLDNNRnnFullParam {
RNNParam default_param;
MKLDNNRnnParam mkldnn_param;
LayerParamVector layer_params;
};

MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const RNNParam& rnn_param,
MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const nnvm::NodeAttrs& attrs,
const int seq_len,
const int batch_size,
const int input_size);
@@ -102,7 +141,7 @@ class MKLDNNRnnMemMgr {
// The memory buffer in NDArray life-cycle
NDArray workspace_;
// This points to the memory buffer from a NDArray
char* curr_mem;
char* curr_mem = nullptr;
// The total bytes of the workspace of a MKLDNNRnnOp
size_t mem_size = 0;
// The current available memory bytes
@@ -113,7 +152,7 @@
std::vector<std::shared_ptr<const mkldnn::memory>> mem_holder;

public:
void Init(dim_t size, const Context& ctx, int dtype = mshadow::kFloat32);
void Init(const dim_t size, const Context& ctx, int dtype = mshadow::kFloat32);

void RegisterMem(std::shared_ptr<const mkldnn::memory> mem) {
mem_holder.push_back(mem);
@@ -122,6 +161,8 @@
mkldnn::memory* Alloc(const mkldnn::memory::desc& md);
};

typedef std::shared_ptr<mkldnn::primitive_attr> shared_mkldnn_attr_t;

/*
* Rnn Primitive.
*/
@@ -131,15 +172,15 @@ class RnnPrimitive {
* lstm_forward, lbr_gru_forward, vanilla_rnn_forward
*/
template <typename rnn_fwd, typename... Args>
static RnnPrimitive Create(Args&&... args) {
static RnnPrimitive Create(const shared_mkldnn_attr_t attr, Args&&... args) {
RnnPrimitive rnn_fwd_prim;
auto fwd_desc = typename rnn_fwd::desc(std::forward<Args>(args)...);
rnn_fwd_prim.fwd_pd_.reset(
new typename rnn_fwd::primitive_desc(fwd_desc, CpuEngine::Get()->get_engine()),
[](typename rnn_fwd::primitive_desc* pd) {
delete reinterpret_cast<typename rnn_fwd::primitive_desc*>(pd);
});
new typename rnn_fwd::primitive_desc(
fwd_desc, attr ? *attr : mkldnn::primitive_attr(), CpuEngine::Get()->get_engine()),
[](void* pd) { delete reinterpret_cast<typename rnn_fwd::primitive_desc*>(pd); });
auto fwd_pd = reinterpret_cast<typename rnn_fwd::primitive_desc*>(rnn_fwd_prim.fwd_pd_.get());
rnn_fwd_prim.attr_ = attr;
rnn_fwd_prim.weights_layer_desc_ = fwd_pd->weights_layer_desc();
rnn_fwd_prim.weights_iter_desc_ = fwd_pd->weights_iter_desc();
rnn_fwd_prim.workspace_desc_ = fwd_pd->workspace_desc();
@@ -150,6 +191,7 @@
}

RnnPrimitive() {
this->attr_ = nullptr;
this->fwd_pd_ = nullptr;
this->primitive_ = nullptr;
this->weights_layer_desc_ = mkldnn::memory::desc();
@@ -158,6 +200,7 @@
}

RnnPrimitive(const RnnPrimitive& rnn_fwd_prim) {
this->attr_ = rnn_fwd_prim.attr_;
this->fwd_pd_ = rnn_fwd_prim.fwd_pd_;
this->primitive_ = rnn_fwd_prim.primitive_;
this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
@@ -167,6 +210,7 @@

RnnPrimitive& operator=(const RnnPrimitive& rnn_fwd_prim) {
if (this != &rnn_fwd_prim) {
this->attr_ = rnn_fwd_prim.attr_;
this->fwd_pd_ = rnn_fwd_prim.fwd_pd_;
this->primitive_ = rnn_fwd_prim.primitive_;
this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
@@ -196,9 +240,14 @@ class RnnPrimitive {
return workspace_desc_;
}

const mkldnn::primitive_attr& GetPrimAttr() const {
return *attr_;
}

private:
std::shared_ptr<void> fwd_pd_;
std::shared_ptr<mkldnn::primitive> primitive_;
shared_mkldnn_attr_t attr_;
mkldnn::memory::desc weights_layer_desc_;
mkldnn::memory::desc weights_iter_desc_;
mkldnn::memory::desc workspace_desc_;
@@ -207,7 +256,8 @@
RnnPrimitive GetRnnFwdPrim(const MKLDNNRnnLayerParam& layer_param,
const bool is_train,
const NDArray& data,
const NDArray& params);
const NDArray& params,
const shared_mkldnn_attr_t attr = nullptr);

/*
* Use this to manage memory and primitive of MKL-DNN RNN forward inference.
@@ -217,10 +267,11 @@ class MKLDNNRnnForward {
MKLDNNRnnForward(const MKLDNNRnnLayerParam& layer_param,
const bool is_train,
const NDArray& data,
const NDArray& params)
const NDArray& params,
const shared_mkldnn_attr_t attr = nullptr)
: initialized_(false),
param_(layer_param),
fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params)) {}
fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params, attr)) {}

void SetNewDataMem(void* x,
void* hx,
@@ -240,6 +291,10 @@
return fwd_inf_.GetPrim();
}

void ResetFwd(const NDArray& data, const NDArray& params, const shared_mkldnn_attr_t& attr) {
fwd_inf_ = GetRnnFwdPrim(this->param_, false, data, params, attr);
}

const size_t GetSize(int dtype) const {
size_t bytes = mshadow::mshadow_sizeof(dtype);
size_t size = 0;
@@ -458,13 +513,13 @@ class MKLDNNRnnBackward {
*/
class MKLDNNRnnOp {
public:
explicit MKLDNNRnnOp(const RNNParam& param,
explicit MKLDNNRnnOp(const nnvm::NodeAttrs &attrs,
const int seq_len,
const int batch_size,
const int input_size)
: initialized_(false),
weights_version_(0),
full_param_(MKLDNNRnnFullParamParser(param, seq_len, batch_size, input_size)) {}
full_param_(MKLDNNRnnFullParamParser(attrs, seq_len, batch_size, input_size)) {}

void Forward(const OpContext& ctx,
const std::vector<NDArray>& inputs,
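
For the quantized path, MKLDNNRnnForward, GetRnnFwdPrim and ResetFwd now take an optional shared mkldnn::primitive_attr. A minimal sketch of how a caller might build such an attribute and construct a forward layer follows; it assumes the standard mkldnn::primitive_attr RNN quantization calls (set_rnn_data_qparams, set_rnn_weights_qparams), and the scale/shift values, layer_param and NDArray arguments are placeholders rather than anything taken from this commit:

#include <memory>
#include <vector>
// Assumes the declarations from mkldnn_rnn-inl.h shown above are visible.

void BuildQuantizedRnnLayer(const mxnet::op::MKLDNNRnnLayerParam& layer_param,
                            const mxnet::NDArray& data,
                            const mxnet::NDArray& params,
                            float data_scale,
                            float data_shift,
                            float weights_scale) {
  using mxnet::op::shared_mkldnn_attr_t;
  shared_mkldnn_attr_t attr = std::make_shared<mkldnn::primitive_attr>();
  // int8 RNN inference in oneDNN is configured through primitive attributes:
  // a scale/shift pair for the data and, with mask = 0, one common weights scale.
  attr->set_rnn_data_qparams(data_scale, data_shift);
  attr->set_rnn_weights_qparams(/*mask=*/0, {weights_scale});
  // The attr is forwarded to GetRnnFwdPrim() inside the constructor.
  mxnet::op::MKLDNNRnnForward fwd(layer_param, /*is_train=*/false, data, params, attr);
  // If calibration statistics change later, the cached primitive can be rebuilt in place.
  fwd.ResetFwd(data, params, attr);
}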