This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

fix sanity
Bartlomiej Gawrych committed Nov 26, 2021
1 parent 552f0c2 commit 82c647a
Showing 10 changed files with 62 additions and 83 deletions.
2 changes: 1 addition & 1 deletion python/mxnet/io/utils.py
@@ -88,4 +88,4 @@ def _getdata_by_idx(data, idx):
def _slice_along_batch_axis(data, s, batch_axis):
"""Apply slice along the batch axis"""
ret = data.slice_axis(axis=batch_axis, begin=s.start, end=s.stop)
return ret
return ret
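
A minimal sketch of what _slice_along_batch_axis does, assuming MXNet's NDArray.slice_axis API (axis/begin/end keyword arguments); the array shape and slice bounds are illustrative only:

import mxnet as mx

data = mx.nd.arange(24).reshape((4, 3, 2))   # batch axis 0, batch size 4
s = slice(1, 3)                              # keep samples 1 and 2
ret = data.slice_axis(axis=0, begin=s.start, end=s.stop)
print(ret.shape)                             # (2, 3, 2)
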
4 changes: 1 addition & 3 deletions src/operator/nn/mkldnn/mkldnn_rnn-inl.h
@@ -246,9 +246,7 @@ class RnnPrimitive {
return workspace_desc_;
}

const mkldnn::primitive_attr& GetPrimAttr() const {
return *attr_;
}
const mkldnn::primitive_attr &GetPrimAttr() const { return *attr_; }

private:
std::shared_ptr<void> fwd_pd_;
30 changes: 4 additions & 26 deletions src/operator/nn/mkldnn/mkldnn_rnn.cc
@@ -135,7 +135,7 @@ MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const NodeAttrs& attrs,
// Set dims, workspace size, state_outputs, quantized and enable_u8_output flag
for (auto& layer_param : layer_params) {
layer_param.SetDims();
layer_param.state_outputs = rnn_param.state_outputs;
layer_param.state_outputs = rnn_param.state_outputs;
layer_param.quantized = full_param.mkldnn_param.quantized;
layer_param.enable_u8_output = true;
}
@@ -472,28 +472,6 @@ void MKLDNNRnnForward::SetNewDataMem(void* x,
}
}

// inline void MKLDNNMemoryReorder(const mkldnn::memory& src, const mkldnn::memory& dst) {
// #if DMLC_CXX11_THREAD_LOCAL
// static thread_local std::unordered_map<OpSignature, mkldnn::reorder, OpHash> reorderPrimitives;
// #else
// static MX_THREAD_LOCAL std::unordered_map<OpSignature, mkldnn::reorder, OpHash> reorderPrimitives;
// #endif
// OpSignature key{};
// key.AddSign(src);
// key.AddSign(dst);

// auto it = reorderPrimitives.find(key);
// if (it == reorderPrimitives.end()) {
// auto reorder = mkldnn::reorder(src, dst);
// it = AddToCache(&reorderPrimitives, key, reorder);
// }

// mkldnn_args_map_t net_args;
// net_args.emplace(MKLDNN_ARG_SRC, src);
// net_args.emplace(MKLDNN_ARG_DST, dst);
// MKLDNNStream::Get()->RegisterPrimArgs(it->second, net_args);
// }

/*
* Reorder the concatenated weights memory to an efficient memory block
* with primitive-preferred format.
@@ -621,7 +599,7 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr,
// convert void* to char* for arithmetic operations
char* weights_ptr = static_cast<char*>(w_ptr);
size_t wx_bytes = GetRnnGatesNum(param_.mode) * param_.state_size * param_.input_size *
dtype_bytes; //* DIMS: ngates x state_size x input_size
dtype_bytes; //* DIMS: ngates x state_size x input_size
size_t wh_bytes = GetRnnGatesNum(param_.mode) * param_.state_size * param_.state_size *
dtype_bytes; //* DIMS: ngates x state_size x state_size
char* l2r_wx = weights_ptr;
@@ -635,7 +613,7 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr,
ConcatWeights(*weights_iter_r_, 1, {l2r_wh, r2l_wh}, format_tag::ldgoi);
} else if (param_.num_layer == 1 && !param_.bidirectional) {
//* single uni-directional layer, no concatenate operator needed
// tutttaj
// tutttaj
std::memcpy(weights_layer_r_->get_data_handle(), l2r_wx, wx_bytes);
std::memcpy(weights_iter_r_->get_data_handle(), l2r_wh, wh_bytes);
} else if (param_.num_layer > 1 && !param_.bidirectional) {
@@ -1097,7 +1075,7 @@ void MKLDNNRnnOp::Forward(const OpContext& ctx,
weights_version_ = inputs[rnn_enum::kParams].version();
}

if (dmlc::GetEnv("MXNET_RNN_USE_WEIGHT_CACHE", 0) && !initialized_) {
if (dmlc::GetEnv("MXNET_RNN_USE_WEIGHT_CACHE", 0) && !initialized_) {
LOG(INFO) << "The current weight of RNN is assumed to be fixed and cached during "
"the whole inference pipeline. Please set MXNET_RNN_USE_WEIGHT_CACHE=0, if "
"the weight changed at runtime.";
16 changes: 8 additions & 8 deletions src/operator/quantization/mkldnn/mkldnn_quantize_asym-inl.h
@@ -26,24 +26,24 @@
#define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_ASYM_INL_H_
#if MXNET_USE_MKLDNN == 1

#include "../../nn/mkldnn/mkldnn_base-inl.h"
#include "../quantize_asym-inl.h"
#include <memory>
#include <vector>
#include "../../nn/mkldnn/mkldnn_base-inl.h"
#include "../quantize_asym-inl.h"

namespace mxnet {
namespace op {

class MKLDNNQuantizeAsymOp {
public:
public:
explicit MKLDNNQuantizeAsymOp(const nnvm::NodeAttrs &attrs)
: param_(nnvm::get<QuantizeAsymParam>(attrs.parsed)) {}

void Forward(const OpContext &ctx, const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs);

private:
private:
QuantizeAsymParam param_;
bool initialized_{false};
float cached_scale_{0.f};
@@ -159,8 +159,8 @@ void MKLDNNQuantizeAsymForward(const OpStatePtr &state_ptr,
}
}

} // namespace op
} // namespace mxnet
} // namespace op
} // namespace mxnet

#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_ASYM_INL_H_
#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_ASYM_INL_H_
16 changes: 8 additions & 8 deletions src/operator/quantization/mkldnn/mkldnn_quantized_rnn-inl.h
@@ -29,16 +29,16 @@

#if MXNET_USE_MKLDNN == 1

#include <vector>
#include "../../nn/mkldnn/mkldnn_rnn-inl.h"
#include "../../rnn-inl.h"
#include "../quantized_rnn-inl.h"
#include <vector>

namespace mxnet {
namespace op {

class MKLDNNQuantizedRnnOp {
public:
public:
explicit MKLDNNQuantizedRnnOp(const nnvm::NodeAttrs &attrs, const int seq_len,
const int batch_size, const int input_size)
: initialized_(false), weights_ver_(0),
@@ -50,13 +50,13 @@ class MKLDNNQuantizedRnnOp {
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs);

private:
private:
bool initialized_;
size_t weights_ver_;
shared_mkldnn_attr_t rnn_attr_;
MKLDNNRnnFullParam full_param_;
MKLDNNRnnMemMgr mgr_;
std::vector<MKLDNNRnnForward> fwd_inf_vec_; // forward inference layers
std::vector<MKLDNNRnnForward> fwd_inf_vec_; // forward inference layers

// Used to store the intermediate results of multi-layer
std::vector<mkldnn::memory *> dst_;
@@ -72,8 +72,8 @@ class MKLDNNQuantizedRnnOp {
const std::vector<NDArray> &outputs);
};

} // namespace op
} // namespace mxnet
} // namespace op
} // namespace mxnet

#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_RNN_INL_H_
#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_RNN_INL_H_
24 changes: 12 additions & 12 deletions src/operator/quantization/mkldnn/mkldnn_quantized_rnn.cc
@@ -68,9 +68,9 @@ GetMKLDNNRnnWeightsQParams(const MKLDNNRnnFullParam &full_param, float *w_ptr) {
}
}
std::vector<float> w_max(4 * layer_param0.state_size, 0.0);
const index_t input_size = layer_param0.input_size; // input
const index_t state_size = layer_param0.state_size; // state
const index_t gates_nblks = 4 * layer_param0.state_size; // gates * state
const index_t input_size = layer_param0.input_size; // input
const index_t state_size = layer_param0.state_size; // state
const index_t gates_nblks = 4 * layer_param0.state_size; // gates * state
for (index_t go = 0; go < gates_nblks; ++go) {
float tmp_max = w_max[go];
for (index_t i = 0; i < input_size; ++i) {
@@ -141,12 +141,12 @@ void MKLDNNQuantizedRnnOp::Init(const OpContext &ctx,

const size_t num_fusion = full_param_.layer_params.size();
if (fwd_inf_vec_.size() < num_fusion) {
size_t buffer_size = 0; // Element number, instead of bytes, in the buffer
size_t buffer_size = 0; // Element number, instead of bytes, in the buffer
for (auto &layer_param : full_param_.layer_params) {
buffer_size += layer_param.workspace_size + layer_param.reserve_size;
}
buffer_size += outputs[rnn_enum::kOut].data().Size() * (num_fusion - 1);
buffer_size += kMKLDNNAlign * num_fusion * 5; // Add margin for alignment
buffer_size += kMKLDNNAlign * num_fusion * 5; // Add margin for alignment

for (auto &layer_param : full_param_.layer_params) {
fwd_inf_vec_.emplace_back(layer_param, false, inputs[rnn_enum::kData],
@@ -164,7 +164,7 @@ void MKLDNNQuantizedRnnOp::Init(const OpContext &ctx,
size_t directions = fwd_layer.GetParam().bidirectional ? 2 : 1;
size_t layer_weights_bytes = single_w_bytes * directions;
size_t layer_bias_bytes =
single_b_bytes * directions; // Native MXNet has double bias
single_b_bytes * directions; // Native MXNet has double bias

if (!fwd_layer.IsInitialized())
fwd_layer.SetWeightsMem(&(this->mgr_), weights_ptr, bias_ptr, false,
@@ -289,9 +289,9 @@ void MKLDNNQuantizedRnnOp::Forward(const OpContext &op_ctx,
char *src = static_cast<char *>(inputs[rnn_enum::kData].data().dptr_);
char *src_state = static_cast<char *>(inputs[rnn_enum::kState].data().dptr_);
char *dst = static_cast<char *>(out_mem.second->get_data_handle());
char *dst_state = nullptr; // Output state
char *src_state_cell = nullptr; // Used in LSTM for cell state
char *dst_state_cell = nullptr; // Used in LSTM for cell state
char *dst_state = nullptr; // Output state
char *src_state_cell = nullptr; // Used in LSTM for cell state
char *dst_state_cell = nullptr; // Used in LSTM for cell state
const size_t cell_bytes =
(default_param.bidirectional + 1) * default_param.batch_size_ *
default_param.state_size * mshadow::mshadow_sizeof(data_dtype);
@@ -368,7 +368,7 @@ void MKLDNNQuantizedRnnOp::Forward(const OpContext &op_ctx,
MKLDNNStream::Get()->Submit();
}

} // namespace op
} // namespace mxnet
} // namespace op
} // namespace mxnet

#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_USE_MKLDNN == 1
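
The GetMKLDNNRnnWeightsQParams hunk near the top of this file scans the LSTM gate weights for per-output-channel maxima before deriving quantization parameters. A rough NumPy sketch of that scan, as shown below; the [4 * state_size, input_size] layout, the use of the absolute maximum, and the int8 scale rule are assumptions rather than the operator's exact semantics:

import numpy as np

state_size, input_size = 16, 32
w = np.random.randn(4 * state_size, input_size).astype(np.float32)  # gate weights, one row per output channel
w_max = np.abs(w).max(axis=1)                                        # per-channel maximum, shape (4 * state_size,)
scale = 127.0 / w_max                                                # assumed int8 scaling; MKL-DNN's exact rule may differ
print(w_max.shape, scale[:4])
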
19 changes: 10 additions & 9 deletions src/operator/quantization/quantize_asym-inl.h
@@ -25,16 +25,17 @@
#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZE_ASYM_INL_H_
#define MXNET_OPERATOR_QUANTIZATION_QUANTIZE_ASYM_INL_H_

#include "../mshadow_op.h"
#include "../mxnet_op.h"
#include "../tensor/broadcast_reduce_op.h"
#include "./quantization_utils.h"
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mshadow/tensor.h>
#include <mxnet/operator_util.h>
#include <vector>

#include "../mshadow_op.h"
#include "../mxnet_op.h"
#include "../tensor/broadcast_reduce_op.h"
#include "./quantization_utils.h"

namespace mxnet {
namespace op {

@@ -69,7 +70,7 @@ struct quantize_asymmetric {
};

template <typename xpu> class QuantizeAsymOp {
public:
public:
explicit QuantizeAsymOp(const nnvm::NodeAttrs &attrs) : attrs_(attrs) {}

void Forward(const OpContext &ctx, const std::vector<TBlob> &inputs,
@@ -145,7 +146,7 @@ template <typename xpu> class QuantizeAsymOp {
}
}

private:
private:
nnvm::NodeAttrs attrs_;
};

@@ -158,7 +159,7 @@ void QuantizeAsymForward(const OpStatePtr &state_ptr, const OpContext &ctx,
op.Forward(ctx, inputs, req, outputs);
}

} // namespace op
} // namespace mxnet
} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZE_ASYM_INL_H_
#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZE_ASYM_INL_H_
4 changes: 2 additions & 2 deletions src/operator/quantization/quantize_asym.cc
@@ -158,5 +158,5 @@ where `scale = uint8_range / (max_range - min_range)` and
"A ndarray/symbol of type `float32`")
.add_arguments(QuantizeAsymParam::__FIELDS__());

} // namespace op
} // namespace mxnet
} // namespace op
} // namespace mxnet
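
The operator description quoted in the hunk header above gives scale = uint8_range / (max_range - min_range). A minimal NumPy sketch of asymmetric uint8 quantization under that formula; the shift definition, rounding, and clipping below are assumptions and may differ from the operator's exact behaviour:

import numpy as np

def quantize_asym(x):
    min_range, max_range = float(x.min()), float(x.max())
    scale = 255.0 / (max_range - min_range)   # uint8_range / (max_range - min_range)
    shift = -min_range * scale                # assumed: maps min_range to 0
    q = np.clip(np.round(x * scale + shift), 0, 255).astype(np.uint8)
    return q, scale, shift

x = np.array([-1.0, -0.25, 0.0, 0.5, 1.0], dtype=np.float32)
print(quantize_asym(x))
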
8 changes: 4 additions & 4 deletions src/operator/quantization/quantized_rnn-inl.h
@@ -34,9 +34,9 @@ namespace quantized_rnn {
enum QuantizedRnnInputs { kData, kParams, kState, kStateCell };
enum QuantizedRnnInputMinMax { kDataScale, kDataShift };
enum QuantizedRnnOutputs { kOut, kStateOut, kStateCellOut };
} // namespace quantized_rnn
} // namespace quantized_rnn

} // namespace op
} // namespace mxnet
} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RNN_INL_H_
#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RNN_INL_H_
22 changes: 12 additions & 10 deletions src/operator/quantization/quantized_rnn.cc
@@ -24,12 +24,14 @@
* \author Zixuan Wei
*/

#include "../rnn-inl.h"
#include "./quantization_utils.h"
#include "./quantized_rnn-inl.h"
#include <dmlc/logging.h>
#include <utility>
#include <vector>

#include "../rnn-inl.h"
#include "./quantization_utils.h"
#include "./quantized_rnn-inl.h"

#if MXNET_USE_MKLDNN == 1
#include "./mkldnn/mkldnn_quantized_rnn-inl.h"
#endif
@@ -115,13 +117,13 @@ bool QuantizedRnnShape(const nnvm::NodeAttrs &attrs,

out_shape->clear();
out_shape->push_back({dshape[0], batch_size,
directions * state_size}); // output dim: [T, N, C]
directions * state_size}); // output dim: [T, N, C]
if (param.state_outputs) {
out_shape->push_back(
{total_lyrs, batch_size, state_size}); // state dim: [L*D, N, C]
{total_lyrs, batch_size, state_size}); // state dim: [L*D, N, C]
if (param.mode == rnn_enum::kLstm)
out_shape->push_back(
{total_lyrs, batch_size, state_size}); // cell dim: [L*D, N, C]
{total_lyrs, batch_size, state_size}); // cell dim: [L*D, N, C]
}
return true;
}
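
A worked example of the shape logic above, with assumed dimensions (T = seq_len, N = batch, C = state_size, D = directions); total_lyrs is taken to be num_layers * directions, as the [L*D, N, C] comments suggest:

seq_len, batch_size, state_size = 35, 8, 512
num_layers, bidirectional, mode_is_lstm = 2, True, True
directions = 2 if bidirectional else 1
total_lyrs = num_layers * directions

out_shape = (seq_len, batch_size, directions * state_size)   # output dim: [T, N, D*C]
state_shape = (total_lyrs, batch_size, state_size)            # state dim: [L*D, N, C]
cell_shape = state_shape if mode_is_lstm else None            # cell dim only for LSTM
print(out_shape, state_shape, cell_shape)
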
@@ -250,7 +252,7 @@ void QuantizedRnnForwardCPUEx(const OpStatePtr &state_ptr, const OpContext &ctx,
MKLDNNQuantizedRnnOp &op = state_ptr.get_state<MKLDNNQuantizedRnnOp>();
op.Forward(ctx, in_data, req, out_data);
}
#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_USE_MKLDNN == 1

bool NeedAsymQuantizeRnnInput(const NodeAttrs &attrs,
const size_t index_to_check) {
@@ -352,7 +354,7 @@ NNVM_REGISTER_OP(RNN)
LOG(INFO) << "Quantized RNN is not supported by this MXNet release. Please enable MKL-DNN to "
<< "use the feature.";
return QuantizeType::kNone;
#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_USE_MKLDNN == 1
})
.set_attr<FQuantizedOp>("FQuantizedOp",
[](const NodeAttrs &attrs) {
@@ -373,5 +375,5 @@ .set_attr<FAvoidDequantizeOutput>("FAvoidDequantizeOutput",
.set_attr<FAvoidDequantizeOutput>("FAvoidDequantizeOutput",
AvoidRnnDequantizeOutput);

} // namespace op
} // namespace mxnet
} // namespace op
} // namespace mxnet
