Add support for int8 x uint8 for MatMulInteger, and int16 x int16 custom op #1391

Merged
merged 7 commits on Aug 29, 2019
45 changes: 45 additions & 0 deletions onnxruntime/contrib_ops/cpu/matmul_integer16.cc
@@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cpu/matmul_integer16.h"
#include "core/providers/cpu/math/matmul_helper.h"

namespace onnxruntime {
namespace contrib {

ONNX_OPERATOR_KERNEL_EX(
MatMulInteger16,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int16_t>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<int16_t>())
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
MatMulInteger16<int16_t, int16_t, int32_t>);

template <>
Status MatMulInteger16<int16_t, int16_t, int32_t>::Compute(OpKernelContext* ctx) const {
auto A = ctx->Input<Tensor>(0);
auto B = ctx->Input<Tensor>(1);
ORT_ENFORCE(A != nullptr && B != nullptr);

MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(A->Shape(), B->Shape()));
Tensor* Y = ctx->Output(0, helper.OutputShape());

for (int i = 0; i < static_cast<int>(helper.OutputOffsets().size()); i++) {
EigenCastGEMM<int16_t, int16_t, int32_t>(
A->template Data<int16_t>() + helper.LeftOffsets()[i],
B->template Data<int16_t>() + helper.RightOffsets()[i],
Y->template MutableData<int32_t>() + helper.OutputOffsets()[i],
static_cast<int>(helper.M()),
static_cast<int>(helper.N()),
static_cast<int>(helper.K()));
}

return Status::OK();
}

} // namespace contrib
} // namespace onnxruntime
22 changes: 22 additions & 0 deletions onnxruntime/contrib_ops/cpu/matmul_integer16.h
@@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/util/math_cpuonly.h"

namespace onnxruntime {
namespace contrib {

template <typename T1, typename T2, typename T3>
class MatMulInteger16 final : public OpKernel {
public:
MatMulInteger16(const OpKernelInfo& info) : OpKernel(info) {
}

Status Compute(OpKernelContext* context) const override;
};
} // namespace contrib
} // namespace onnxruntime
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu_contrib_kernels.cc
@@ -17,6 +17,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordConvEmbedding);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MurmurHash3);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MaxpoolWithMask);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Pad);
@@ -87,6 +88,7 @@ void RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordConvEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MurmurHash3)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MaxpoolWithMask)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Unique)>,
39 changes: 38 additions & 1 deletion onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -21,6 +21,10 @@ void convPoolShapeInference(
int input1Idx,
int input2Idx);
void globalPoolTypeShapeInference(ONNX_NAMESPACE::InferenceContext& ctx);
void matmulShapeInference(
ONNX_NAMESPACE::InferenceContext& ctx,
int input1Idx,
int input2Idx);
} // namespace ONNX_NAMESPACE

namespace onnxruntime {
@@ -1158,6 +1162,39 @@ of [N, 0] then [N, 0].
updateOutputShape(ctx, 0, output_shape);
});

ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulInteger16)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetDoc(R"DOC(
Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html.
The product MUST never overflow. The accumulation may overflow if it does not fit in 32 bits.)DOC")
.Input(0, "A", "N-dimensional matrix A", "T1")
.Input(1, "B", "N-dimensional matrix B", "T2")
.Output(0, "Y", "Matrix multiply results from A * B", "T3")
.TypeConstraint("T1", {"tensor(int16)", "tensor(uint16)"}, "Constrain input A data types as 16-bit integer tensor")
.TypeConstraint("T2", {"tensor(int16)", "tensor(uint16)"}, "Constrain input B data types as 16-bit integer tensor")
.TypeConstraint("T3",
{"tensor(int32)", "tensor(uint32)"},
"Constrain output Y data types as 32-bit integer tensor."
"T3 must be tensor(uint32) when both T1 and T2 are tensor(uint16),"
"or must be tensor(int32) when either T1 or T2 is tensor(int16).")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
auto a_type = ctx.getInputType(0);
auto b_type = ctx.getInputType(1);
auto y_type = ctx.getOutputType(0);
if (nullptr == a_type || nullptr == b_type || nullptr == y_type ||
a_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType ||
b_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType) {
fail_type_inference(
"inputs are expected to have tensor type and output type should not be null.");
}

// Right now we only support int32
y_type->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto::INT32);

matmulShapeInference(ctx, 0, 1);
});

ONNX_CONTRIB_OPERATOR_SCHEMA(ReduceSumInteger)
.SetDomain(kMSDomain)
.SinceVersion(1)
@@ -1599,4 +1636,4 @@ Example 4:
#endif
}
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime
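For reference, a worked instance of the schema above (not part of this diff): an int16 A of shape {3, 2} and an int16 B of shape {2, 4} infer an int32 Y of shape {3, 4}, which is exactly the case exercised by the MatMulInteger16_3 test added below.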
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -277,7 +277,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, QuantizeLinear);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, QLinearMatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, MatMulInteger);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, MatMulInteger);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, MatMulInteger);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ConvInteger);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, QLinearConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, bool, Slice);
@@ -564,7 +565,8 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, QLinearMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, MatMulInteger)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, MatMulInteger)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, MatMulInteger)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ConvInteger)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, QLinearConv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, bool, Slice)>,
62 changes: 59 additions & 3 deletions onnxruntime/core/providers/cpu/math/matmul_integer.cc
@@ -9,19 +9,32 @@
namespace onnxruntime {

// only register this operator if low precision computation is enabled.
ONNX_OPERATOR_KERNEL_EX(
ONNX_OPERATOR_TYPED_KERNEL_EX(
MatMulInteger,
kOnnxDomain,
10,
uint8_t,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>())
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
MatMulInteger<uint8_t, uint8_t, int32_t>);
MatMulInteger<uint8_t, uint8_t>);

ONNX_OPERATOR_TYPED_KERNEL_EX(
MatMulInteger,
kOnnxDomain,
10,
int8_t,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<int8_t>())
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
MatMulInteger<uint8_t, int8_t>);

template <>
Status MatMulInteger<uint8_t, uint8_t, int32_t>::Compute(OpKernelContext* ctx) const {
Status MatMulInteger<uint8_t, uint8_t>::Compute(OpKernelContext* ctx) const {
auto a = ctx->Input<Tensor>(0);
auto b = ctx->Input<Tensor>(1);
ORT_ENFORCE(a != nullptr && b != nullptr);
@@ -60,7 +73,50 @@ Status MatMulInteger<uint8_t, uint8_t, int32_t>::Compute(OpKernelContext* ctx) c
static_cast<int>(helper.N()),
nullptr);
}
return Status::OK();
}

template <>
Status MatMulInteger<uint8_t, int8_t>::Compute(OpKernelContext* ctx) const {
auto a = ctx->Input<Tensor>(0);
auto b = ctx->Input<Tensor>(1);
ORT_ENFORCE(a != nullptr && b != nullptr);

MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape()));
Tensor* y = ctx->Output(0, helper.OutputShape());

if (has_a_zero_point_ || has_b_zero_point_) {
// currently zero point is only supported in Gemmlowp path above
// in future, the selection of Eigen/Gemmlowp/mklml/etc. should be in a common math library like SGEMM

auto IsZeroPointTensorAllZero = [](OpKernelContext* ctx, int input_idx) -> bool {
auto t = ctx->Input<Tensor>(input_idx);
ORT_ENFORCE(t->Shape().NumDimensions() <= 1 && t->Shape().Size() == 1,
"Currently only scalar zero_point is supported. TODO: add per channel zero point support.");
ORT_ENFORCE(t->DataType() == DataTypeImpl::GetType<int8_t>() ||
t->DataType() == DataTypeImpl::GetType<uint8_t>());
auto data = reinterpret_cast<const int8_t*>(t->DataRaw());
auto vec = std::vector<int8_t>(data, data + t->Shape().Size());
return std::all_of(vec.begin(), vec.end(), [](int8_t v) { return v == 0; });
};

if ((has_a_zero_point_ && !IsZeroPointTensorAllZero(ctx, 2)) ||
(has_b_zero_point_ && !IsZeroPointTensorAllZero(ctx, 3))) {
ORT_NOT_IMPLEMENTED("MatMulInteger: Unsupported input types with zero point");
}
}

// NOTE: Eigen based implementation is a reference implementation for accuracy only
for (int i = 0; i < static_cast<int>(helper.OutputOffsets().size()); i++) {
EigenCastGEMM<uint8_t, int8_t, int32_t>(
a->template Data<uint8_t>() + helper.LeftOffsets()[i],
b->template Data<int8_t>() + helper.RightOffsets()[i],
y->template MutableData<int32_t>() + helper.OutputOffsets()[i],
static_cast<int>(helper.M()),
static_cast<int>(helper.N()),
static_cast<int>(helper.K()));
}
return Status::OK();
}
} // namespace onnxruntime
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/cpu/math/matmul_integer.h
@@ -9,14 +9,14 @@

namespace onnxruntime {

template <typename T1, typename T2, typename T3>
template <typename T1, typename T2>
class MatMulInteger final : public OpKernel {
public:
MatMulInteger(const OpKernelInfo& info) : OpKernel(info) {
has_a_zero_point_ = false;
has_b_zero_point_ = false;
if (info.GetInputCount() > 2) {
has_a_zero_point_ = true;
has_a_zero_point_ = true;
}
if (info.GetInputCount() > 3) {
has_b_zero_point_ = true;
@@ -29,4 +29,4 @@ class MatMulInteger final : public OpKernel {
bool has_a_zero_point_;
bool has_b_zero_point_;
};
} // namespace onnxruntime
} // namespace onnxruntime
14 changes: 12 additions & 2 deletions onnxruntime/core/util/math_cpuonly.h
@@ -20,7 +20,7 @@
#include "onnxruntime_config.h"
// external/eigen/Eigen/src/Core/AssignEvaluator.h:86:63:
// error: enum constant in boolean context [-Werror=int-in-bool-context]
#if defined(__GNUC__) && __GNUC__>=7
#if defined(__GNUC__) && __GNUC__ >= 7
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wint-in-bool-context"
#ifdef HAS_DEPRECATED_COPY
@@ -30,7 +30,7 @@

#include "Eigen/Core"

#if defined(__GNUC__) && __GNUC__>=7
#if defined(__GNUC__) && __GNUC__ >= 7
#pragma GCC diagnostic pop
#endif

@@ -101,4 +101,14 @@ void FuseActivation(const std::string& activation, T* y_data, size_t size, float
}
}

// cast TA and TB to TC, and do matrix multiply in Eigen
// note that inputs/outputs are row-major, while Eigen is col-major
// so (M, K) x (K, N) -> (M, N) becomes (N, K) x (K, M) -> (N, M) in Eigen
template <typename TA, typename TB, typename TY>
void EigenCastGEMM(const TA* A_data, const TB* B_data, TY* Y_data, int M, int N, int K) {
auto A = ConstEigenMatrixMap<TA>(A_data, K, M);
auto B = ConstEigenMatrixMap<TB>(B_data, N, K);
EigenMatrixMap<TY>(Y_data, N, M) = B.template cast<TY>() * A.template cast<TY>();
}

} // namespace onnxruntime
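As a sanity check on the transpose trick described in the comment above, here is a minimal standalone sketch (not part of this diff; it assumes the onnxruntime headers are on the include path) that calls EigenCastGEMM on row-major 2x2 buffers and prints the row-major product:

// Minimal sketch, not part of this diff. With M = N = K = 2 the call below
// computes B^T * A^T inside Eigen; because the output map is a col-major
// N x M view, the memory written to y is exactly the row-major M x N product A * B.
#include <cstdint>
#include <iostream>
#include "core/util/math_cpuonly.h"

int main() {
  const int16_t a[] = {1, 2, 3, 4};  // A = [[1, 2], [3, 4]], row-major
  const int16_t b[] = {5, 6, 7, 8};  // B = [[5, 6], [7, 8]], row-major
  int32_t y[4] = {};
  onnxruntime::EigenCastGEMM<int16_t, int16_t, int32_t>(a, b, y, /*M*/ 2, /*N*/ 2, /*K*/ 2);
  for (int32_t v : y) std::cout << v << " ";  // prints: 19 22 43 50
  return 0;
}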
41 changes: 41 additions & 0 deletions onnxruntime/test/contrib_ops/matmul_integer16_test.cc
@@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"

#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/util/math_cpuonly.h"

namespace onnxruntime {
namespace test {

TEST(MatmulInteger16OpTest, MatMulInteger16_1) {
OpTester test("MatMulInteger16", 1, onnxruntime::kMSDomain);
test.AddInput<int16_t>("T1", {1, 1}, {15});
test.AddInput<int16_t>("T2", {1, 1}, {16});
test.AddOutput<int32_t>("T3", {1, 1}, {240});
test.Run();
}

TEST(MatmulInteger16OpTest, MatMulInteger16_2) {
OpTester test("MatMulInteger16", 1, onnxruntime::kMSDomain);
test.AddInput<int16_t>("T1", {1, 2}, {-7, 10});
test.AddInput<int16_t>("T2", {2, 1}, {-8, -11});
test.AddOutput<int32_t>("T3", {1, 1}, {-54});
test.Run();
}

TEST(MatmulInteger16OpTest, MatMulInteger16_3) {
OpTester test("MatMulInteger16", 1, onnxruntime::kMSDomain);
test.AddInput<int16_t>("T1", {3, 2}, {-7, 10, 10, -1113, 22, -356});
test.AddInput<int16_t>("T2", {2, 4}, {-8, -11, 13, 14, -99, 1234, 321, -6});
test.AddOutput<int32_t>("T3", {3, 4}, {-934, 12417, 3119, -158,
110107, -1373552, -357143, 6818,
35068, -439546, -113990, 2444});
test.Run();
}

} // namespace test
} // namespace onnxruntime
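The uint8 x int8 MatMulInteger path added in matmul_integer.cc is not covered by a test in this diff; a hypothetical OpTester case for it, following the same pattern as the int16 tests above, might look like the sketch below (the test name and values are illustrative, not from the PR):

// Hypothetical sketch, not part of this diff: exercises the new
// uint8 x int8 MatMulInteger kernel (ONNX domain, opset 10).
TEST(MatmulIntegerOpTest, MatMulInteger_Uint8xInt8_Sketch) {
  OpTester test("MatMulInteger", 10);
  test.AddInput<uint8_t>("A", {1, 2}, {3, 4});
  test.AddInput<int8_t>("B", {2, 1}, {-5, 6});
  test.AddOutput<int32_t>("Y", {1, 1}, {9});  // 3*(-5) + 4*6 = 9
  test.Run();
}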