This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] Restore quantized RNN operator from MXNet 1.6 #20759

Merged · 8 commits · Dec 3, 2021
1 change: 1 addition & 0 deletions docs/python_docs/environment.yml
@@ -22,6 +22,7 @@ dependencies:
- conda>=4.6.13
- pip
- python
- setuptools==49.6.0
- jupyter
- sphinx==2.4.0
- matplotlib
15 changes: 15 additions & 0 deletions include/mxnet/op_attr_types.h
@@ -348,6 +348,21 @@ using FAvoidQuantizeInput = std::function<bool (const NodeAttrs& attrs,
const size_t index,
const std::string quantize_granularity)>;

/*!
* \brief Register a function to determine if the input of a quantized operator
* needs to be quantized asymmetrically.
*/
using FNeedAsymQuantizeInput = std::function<bool (const NodeAttrs& attrs,
const size_t index)>;

/*!
* \brief Register a function to determine if the output of a quantized operator
* needs to be dequantized. This is usually used for the quantized operators
* which can produce fp32 outputs directly.
*/
using FAvoidDequantizeOutput = std::function<bool (const NodeAttrs& attrs,
const size_t index)>;

/*!
* \brief Register a function to determine if the input of a quantized operator
* needs to be calibrated. This is usually used for the quantized operators
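These two hooks follow the same registration pattern as the existing quantization attributes above. A minimal sketch of how a quantized RNN operator might register them; the operator name and the lambda bodies are illustrative, not part of this diff:

#include <mxnet/op_attr_types.h>
#include <nnvm/op.h>

namespace mxnet {
namespace op {

NNVM_REGISTER_OP(_contrib_quantized_rnn)
    // Quantize the data input asymmetrically (uint8 with a shift) instead of
    // the default symmetric int8 path.
    .set_attr<FNeedAsymQuantizeInput>("FNeedAsymQuantizeInput",
        [](const nnvm::NodeAttrs& attrs, const size_t index) {
          return index == 0;  // only the data input
        })
    // The operator already emits fp32 outputs, so the quantization pass must
    // not append a dequantize node.
    .set_attr<FAvoidDequantizeOutput>("FAvoidDequantizeOutput",
        [](const nnvm::NodeAttrs& attrs, const size_t index) {
          return true;
        });

}  // namespace op
}  // namespace mxnet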
23 changes: 16 additions & 7 deletions python/mxnet/io/io.py
@@ -37,7 +37,7 @@
from ..ndarray import array
from ..ndarray import concat, tile

from .utils import _init_data, _has_instance, _getdata_by_idx
from .utils import _init_data, _has_instance, _getdata_by_idx, _slice_along_batch_axis

class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
"""DataDesc is used to store name, shape, type and layout
@@ -602,10 +602,12 @@ class NDArrayIter(DataIter):
The data name.
label_name : str, optional
The label name.
layout : str, optional
The data layout; the position of 'N' marks the batch axis.
"""
def __init__(self, data, label=None, batch_size=1, shuffle=False,
last_batch_handle='pad', data_name='data',
label_name='softmax_label'):
label_name='softmax_label', layout='NCHW'):
super(NDArrayIter, self).__init__(batch_size)

self.data = _init_data(data, allow_empty=False, default_name=data_name)
@@ -631,20 +633,27 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False,
# used for 'roll_over'
self._cache_data = None
self._cache_label = None
self.layout = layout

@property
def provide_data(self):
"""The name and shape of data provided by this iterator."""
batch_axis = self.layout.find('N')
return [
DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype)
DataDesc(k, tuple(list(v.shape[:batch_axis]) + \
[self.batch_size] + list(v.shape[batch_axis + 1:])),
v.dtype, layout=self.layout)
for k, v in self.data
]

@property
def provide_label(self):
"""The name and shape of label provided by this iterator."""
batch_axis = self.layout.find('N')
return [
DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype)
DataDesc(k, tuple(list(v.shape[:batch_axis]) + \
[self.batch_size] + list(v.shape[batch_axis + 1:])),
v.dtype, layout=self.layout)
for k, v in self.label
]

@@ -681,7 +690,7 @@ def next(self):
data = self.getdata()
label = self.getlabel()
# iter should stop when last batch is not complete
if data[0].shape[0] != self.batch_size:
if data[0].shape[self.layout.find('N')] != self.batch_size:
# in this case, cache it for next epoch
self._cache_data = data
self._cache_label = label
@@ -697,7 +706,7 @@ def _getdata(self, data_source, start=None, end=None):
end = data_source[0][1].shape[0] if data_source else 0
s = slice(start, end)
return [
x[1][s]
_slice_along_batch_axis(x[1], s, self.layout.find('N'))
if isinstance(x[1], (np.ndarray, NDArray)) else
# h5py (only supports indices in increasing order)
array(x[1][sorted(self.idx[s])][[
@@ -716,7 +725,7 @@ def _concat(self, first_data, second_data):
concat(
first_data[i],
second_data[i],
dim=0
dim=self.layout.find('N')
) for i in range(len(first_data))
]

5 changes: 5 additions & 0 deletions python/mxnet/io/utils.py
@@ -84,3 +84,8 @@ def _getdata_by_idx(data, idx):
shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))

return shuffle_data

def _slice_along_batch_axis(data, s, batch_axis):
"""Apply slice along the batch axis"""
ret = data.slice_axis(axis=batch_axis, begin=s.start, end=s.stop)
return ret
83 changes: 69 additions & 14 deletions src/operator/nn/mkldnn/mkldnn_rnn-inl.h
@@ -33,10 +33,42 @@

#include "../../rnn-inl.h"
#include "./mkldnn_base-inl.h"
#include "../../quantization/quantized_rnn-inl.h"

namespace mxnet {
namespace op {

struct MKLDNNRnnParam : public dmlc::Parameter<MKLDNNRnnParam> {
bool quantized;

DMLC_DECLARE_PARAMETER(MKLDNNRnnParam) {
DMLC_DECLARE_FIELD(quantized).set_default(false).describe(
"Whether it's a quantized RNN operator");
}
};

inline void MKLDNNMemoryReorder(const mkldnn::memory& src, const mkldnn::memory& dst) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<OpSignature, mkldnn::reorder, OpHash> reorderPrimitives;
#else
static MX_THREAD_LOCAL std::unordered_map<OpSignature, mkldnn::reorder, OpHash> reorderPrimitives;
#endif
OpSignature key{};
key.AddSign(src);
key.AddSign(dst);

auto it = reorderPrimitives.find(key);
if (it == reorderPrimitives.end()) {
auto reorder = mkldnn::reorder(src, dst);
it = AddToCache(&reorderPrimitives, key, reorder);
}

mkldnn_args_map_t net_args;
net_args.emplace(MKLDNN_ARG_SRC, src);
net_args.emplace(MKLDNN_ARG_DST, dst);
MKLDNNStream::Get()->RegisterPrimArgs(it->second, net_args);
}
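MKLDNNMemoryReorder only registers a cached reorder primitive on the MKL-DNN stream; the caller still submits the stream to run it. A small usage sketch, assuming the two memory descriptors and raw buffers already exist in the caller (all names below are illustrative, not part of this diff):

// Reorder fp32 weights into the s8 format expected by a quantized RNN primitive.
mkldnn::memory src_mem(weights_fp32_desc, CpuEngine::Get()->get_engine(), weights_fp32_ptr);
mkldnn::memory dst_mem(weights_s8_desc, CpuEngine::Get()->get_engine(), weights_s8_ptr);
MKLDNNMemoryReorder(src_mem, dst_mem);  // looks up or creates the reorder and registers it
MKLDNNStream::Get()->Submit();          // executes all primitives registered on the stream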

struct MKLDNNRnnLayerParam {
using memory = mkldnn::memory;
using dims = mkldnn::memory::dims;
@@ -65,6 +97,10 @@ struct MKLDNNRnnLayerParam {
size_t native_single_b_size; // bias size of a single cell from framework
size_t single_state_size; // state size of a single cell, hy, cy

bool quantized; // whether this layer is quantized
bool enable_u8_output; // true by default; only false when this is the last fusion layer of the
// quantized RNN operator

MKLDNNRnnLayerParam(int num_layer,
int batch_size,
int seq_len,
@@ -79,18 +115,21 @@
batch_size(batch_size),
input_size(input_size),
state_size(state_size),
seq_len(seq_len) {}
seq_len(seq_len),
quantized(false),
enable_u8_output(false) {}

void SetDims();
};

typedef std::vector<MKLDNNRnnLayerParam> LayerParamVector;
struct MKLDNNRnnFullParam {
RNNParam default_param;
MKLDNNRnnParam mkldnn_param;
LayerParamVector layer_params;
};

MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const RNNParam& rnn_param,
MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const nnvm::NodeAttrs& attrs,
const int seq_len,
const int batch_size,
const int input_size);
@@ -102,7 +141,7 @@ class MKLDNNRnnMemMgr {
// The memory buffer in NDArray life-cycle
NDArray workspace_;
// This points to the memory buffer from a NDArray
char* curr_mem;
char* curr_mem = nullptr;
// The total bytes of the workspace of a MKLDNNRnnOp
size_t mem_size = 0;
// The current available memory bytes
@@ -113,7 +152,7 @@
std::vector<std::shared_ptr<const mkldnn::memory>> mem_holder;

public:
void Init(dim_t size, const Context& ctx, int dtype = mshadow::kFloat32);
void Init(const dim_t size, const Context& ctx, int dtype = mshadow::kFloat32);

void RegisterMem(std::shared_ptr<const mkldnn::memory> mem) {
mem_holder.push_back(mem);
@@ -122,6 +161,8 @@
mkldnn::memory* Alloc(const mkldnn::memory::desc& md);
};

typedef std::shared_ptr<mkldnn::primitive_attr> shared_mkldnn_attr_t;

/*
* Rnn Primitive.
*/
@@ -131,15 +172,15 @@ class RnnPrimitive {
* lstm_forward, lbr_gru_forward, vanilla_rnn_forward
*/
template <typename rnn_fwd, typename... Args>
static RnnPrimitive Create(Args&&... args) {
static RnnPrimitive Create(const shared_mkldnn_attr_t attr, Args&&... args) {
RnnPrimitive rnn_fwd_prim;
auto fwd_desc = typename rnn_fwd::desc(std::forward<Args>(args)...);
rnn_fwd_prim.fwd_pd_.reset(
new typename rnn_fwd::primitive_desc(fwd_desc, CpuEngine::Get()->get_engine()),
[](typename rnn_fwd::primitive_desc* pd) {
delete reinterpret_cast<typename rnn_fwd::primitive_desc*>(pd);
});
new typename rnn_fwd::primitive_desc(
fwd_desc, attr ? *attr : mkldnn::primitive_attr(), CpuEngine::Get()->get_engine()),
[](void* pd) { delete reinterpret_cast<typename rnn_fwd::primitive_desc*>(pd); });
auto fwd_pd = reinterpret_cast<typename rnn_fwd::primitive_desc*>(rnn_fwd_prim.fwd_pd_.get());
rnn_fwd_prim.attr_ = attr;
rnn_fwd_prim.weights_layer_desc_ = fwd_pd->weights_layer_desc();
rnn_fwd_prim.weights_iter_desc_ = fwd_pd->weights_iter_desc();
rnn_fwd_prim.workspace_desc_ = fwd_pd->workspace_desc();
@@ -150,6 +191,7 @@
}

RnnPrimitive() {
this->attr_ = nullptr;
this->fwd_pd_ = nullptr;
this->primitive_ = nullptr;
this->weights_layer_desc_ = mkldnn::memory::desc();
@@ -158,6 +200,7 @@
}

RnnPrimitive(const RnnPrimitive& rnn_fwd_prim) {
this->attr_ = rnn_fwd_prim.attr_;
this->fwd_pd_ = rnn_fwd_prim.fwd_pd_;
this->primitive_ = rnn_fwd_prim.primitive_;
this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
@@ -167,6 +210,7 @@

RnnPrimitive& operator=(const RnnPrimitive& rnn_fwd_prim) {
if (this != &rnn_fwd_prim) {
this->attr_ = rnn_fwd_prim.attr_;
this->fwd_pd_ = rnn_fwd_prim.fwd_pd_;
this->primitive_ = rnn_fwd_prim.primitive_;
this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
@@ -196,9 +240,14 @@
return workspace_desc_;
}

const mkldnn::primitive_attr& GetPrimAttr() const {
return *attr_;
}

private:
std::shared_ptr<void> fwd_pd_;
std::shared_ptr<mkldnn::primitive> primitive_;
shared_mkldnn_attr_t attr_;
mkldnn::memory::desc weights_layer_desc_;
mkldnn::memory::desc weights_iter_desc_;
mkldnn::memory::desc workspace_desc_;
@@ -207,7 +256,8 @@ RnnPrimitive GetRnnFwdPrim(const MKLDNNRnnLayerParam& layer_param,
RnnPrimitive GetRnnFwdPrim(const MKLDNNRnnLayerParam& layer_param,
const bool is_train,
const NDArray& data,
const NDArray& params);
const NDArray& params,
const shared_mkldnn_attr_t attr = nullptr);

/*
* Use this to manage memory and primitive of MKL-DNN RNN forward inference.
@@ -217,10 +267,11 @@ class MKLDNNRnnForward {
MKLDNNRnnForward(const MKLDNNRnnLayerParam& layer_param,
const bool is_train,
const NDArray& data,
const NDArray& params)
const NDArray& params,
const shared_mkldnn_attr_t attr = nullptr)
: initialized_(false),
param_(layer_param),
fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params)) {}
fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params, attr)) {}

void SetNewDataMem(void* x,
void* hx,
@@ -240,6 +291,10 @@ class MKLDNNRnnForward {
return fwd_inf_.GetPrim();
}

void ResetFwd(const NDArray& data, const NDArray& params, const shared_mkldnn_attr_t& attr) {
fwd_inf_ = GetRnnFwdPrim(this->param_, false, data, params, attr);
}

const size_t GetSize(int dtype) const {
size_t bytes = mshadow::mshadow_sizeof(dtype);
size_t size = 0;
@@ -458,13 +513,13 @@ class MKLDNNRnnBackward {
*/
class MKLDNNRnnOp {
public:
explicit MKLDNNRnnOp(const RNNParam& param,
explicit MKLDNNRnnOp(const nnvm::NodeAttrs &attrs,
const int seq_len,
const int batch_size,
const int input_size)
: initialized_(false),
weights_version_(0),
full_param_(MKLDNNRnnFullParamParser(param, seq_len, batch_size, input_size)) {}
full_param_(MKLDNNRnnFullParamParser(attrs, seq_len, batch_size, input_size)) {}

void Forward(const OpContext& ctx,
const std::vector<NDArray>& inputs,
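Together, shared_mkldnn_attr_t, the attr-aware RnnPrimitive::Create / GetRnnFwdPrim overloads, and MKLDNNRnnForward::ResetFwd let the quantized path rebuild a forward primitive with MKL-DNN RNN quantization parameters attached. A rough sketch of such a caller, assuming the scale and shift values come from calibration and that the layer, data and params already exist (all names are illustrative, not part of this diff):

// Sketch only; assumes the declarations from mkldnn_rnn-inl.h above are in scope.
// Attach u8 data and s8 weight quantization parameters, then rebuild the
// forward primitive of one fused RNN layer.
void ResetWithQuantParams(MKLDNNRnnForward* fwd_layer,
                          const NDArray& data,
                          const NDArray& params,
                          const float data_scale,
                          const float data_shift,
                          const std::vector<float>& weight_scales) {
  shared_mkldnn_attr_t attr = std::make_shared<mkldnn::primitive_attr>();
  attr->set_rnn_data_qparams(data_scale, data_shift);  // asymmetric u8 input: u8 = f32 * scale + shift
  attr->set_rnn_weights_qparams(0, weight_scales);     // mask 0: a single common s8 weight scale
  fwd_layer->ResetFwd(data, params, attr);             // recreate the MKL-DNN primitive with the attr
}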