Skip to content

Commit

Permalink
Address CR from MatMulInteger PR in GitHub
Browse files Browse the repository at this point in the history
  • Loading branch information
KeDengMS committed Jul 16, 2019
1 parent 74dac1c commit a78ff58
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 42 deletions.
19 changes: 17 additions & 2 deletions onnxruntime/core/providers/cpu/math/matmul_integer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,25 @@ Status MatMulInteger::Compute(OpKernelContext* ctx) const {
static_cast<int>(helper.K()));
}
} else {
if (has_a_zero_point_) {
if (has_a_zero_point_ || has_b_zero_point_) {
// currently zero point is only supported in Gemmlowp path above
// in future, the selection of Eigen/Gemmlowp/mklml/etc. should be in a common math library like SGEMM
ORT_NOT_IMPLEMENTED("MatMulInteger: Unsupported input types with zero point");

// Returns true when the zero-point tensor at input index `input_idx` contains only zeros.
// Only a single-element zero point (rank <= 1, one element) of type int8_t or uint8_t is
// accepted; any other input fails one of the ORT_ENFORCE checks below.
auto IsZeroPointTensorAllZero = [](OpKernelContext* ctx, int input_idx) -> bool {
  const auto* t = ctx->Input<Tensor>(input_idx);
  // Fail with a clear message instead of dereferencing a null tensor pointer.
  ORT_ENFORCE(t != nullptr, "MatMulInteger: zero point input is missing.");
  ORT_ENFORCE(t->Shape().NumDimensions() <= 1 && t->Shape().Size() == 1,
              "Currently only scalar zero_point is supported. TODO: add per channel zero point support.");
  ORT_ENFORCE(t->DataType() == DataTypeImpl::GetType<int8_t>() ||
              t->DataType() == DataTypeImpl::GetType<uint8_t>());
  // A zero byte is zero under both int8_t and uint8_t, so a single int8_t view covers both types.
  const auto* data = reinterpret_cast<const int8_t*>(t->DataRaw());
  // Scan the tensor buffer in place; copying it into a temporary vector is unnecessary.
  return std::all_of(data, data + t->Shape().Size(), [](int8_t v) { return v == 0; });
};

if ((has_a_zero_point_ && !IsZeroPointTensorAllZero(ctx, 2)) ||
(has_b_zero_point_ && !IsZeroPointTensorAllZero(ctx, 3))) {
ORT_NOT_IMPLEMENTED("MatMulInteger: Unsupported input types with zero point");
}
}

#define HANDLE_TYPES_WITH_EIGEN(T1, T2, T3) \
Expand Down
78 changes: 38 additions & 40 deletions onnxruntime/test/providers/cpu/math/matmul_integer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "core/framework/op_kernel.h"
#include "core/util/math_cpuonly.h"

#include <random>

namespace onnxruntime {
namespace test {

Expand All @@ -31,53 +33,49 @@ TEST(MatmulIntegerOpTest, MatMulInteger) {
test.Run();
}

// Runs the opset-10 MatMulInteger operator on a fixed 2x40 uint8 input "T1" and a fixed
// 40x2 int8 input "T2", and checks the result against the hard-coded 2x2 int32 output "T3".
TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8_GEMM) {
OpTester test("MatMulInteger", 10);
// Left operand: shape {2, 40}, uint8.
test.AddInput<uint8_t>("T1", {2, 40}, {4, 26, 2, 69, 76, 14, 14, 44, 58, 2, 21, 21, 60, 32, 59, 23, 74, 56, 37, 25, 26, 31, 60, 37, 46, 18, 30, 56, 3, 38, 55, 72, 31, 56, 37, 52, 11, 50, 57, 74, 48, 59, 33, 0, 51, 22, 73, 50, 70, 43, 34, 63, 74, 63, 34, 15, 43, 36, 52, 72, 7, 17, 33, 49, 3, 67, 47, 12, 75, 62, 43, 46, 43, 3, 18, 8, 61, 19, 58, 13});
// Right operand: shape {40, 2}, int8.
test.AddInput<int8_t>("T2", {40, 2}, {12, 10, 9, 13, 2, 7, 10, 4, 17, 15, 14, 19, 4, 11, 1, 19, 7, 0, 12, 3, 13, 8, 4, 12, 16, 18, 15, 12, 9, 8, 12, 19, 14, 11, 18, 15, 9, 3, 0, 13, 5, 1, 5, 3, 8, 3, 2, 17, 0, 0, 4, 13, 6, 12, 11, 17, 9, 10, 5, 8, 15, 5, 4, 18, 5, 0, 4, 10, 11, 4, 4, 0, 9, 14, 14, 17, 19, 13, 3, 13});
// Expected output: shape {2, 2}, int32.
test.AddOutput<int32_t>("T3", {2, 2}, {14110, 15747, 14114, 17441});
test.Run();
}

// Runs the opset-10 MatMulInteger operator on a fixed 1x64 uint8 input "T1" and a fixed
// 64x2 int8 input "T2", and checks the result against the hard-coded 1x2 int32 output "T3".
TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8_GEMV_1) {
OpTester test("MatMulInteger", 10);
// Left operand: shape {1, 64}, uint8 (a single row).
test.AddInput<uint8_t>("T1", {1, 64}, {1, 9, 0, 20, 24, 22, 22, 18, 20, 2, 21, 13, 9, 3, 19, 8, 16, 16, 9, 15, 2, 21, 4, 18, 2, 2, 10, 16, 0, 19, 17, 13, 21, 21, 7, 15, 5, 8, 23, 14, 17, 0, 21, 23, 24, 7, 18, 9, 17, 13, 14, 23, 1, 18, 3, 17, 5, 17, 15, 24, 24, 9, 22, 7});
// Right operand: shape {64, 2}, int8.
test.AddInput<int8_t>("T2", {64, 2}, {18, 16, 3, 13, 14, 9, 18, 17, 7, 12, 11, 5, 11, 0, 13, 13, 9, 17, 10, 12, 10, 3, 16, 12, 5, 10, 2, 2, 3, 15, 17, 3, 0, 3, 16, 6, 19, 19, 17, 2, 5, 7, 11, 19, 2, 0, 15, 11, 0, 8, 17, 10, 1, 11, 8, 14, 8, 13, 7, 13, 11, 13, 5, 18, 4, 3, 16, 12, 10, 12, 4, 10, 11, 18, 19, 3, 5, 4, 9, 2, 19, 4, 17, 6, 5, 6, 14, 3, 18, 3, 13, 17, 16, 2, 5, 2, 15, 4, 16, 14, 11, 18, 0, 17, 1, 17, 8, 3, 18, 13, 7, 1, 7, 6, 11, 8, 7, 11, 10, 13, 16, 16, 4, 19, 4, 9, 4, 14});
// Expected output: shape {1, 2}, int32.
test.AddOutput<int32_t>("T3", {1, 2}, {8553, 7926});
test.Run();
}

TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8_1) {
OpTester test("MatMulInteger", 10);
test.AddInput<uint8_t>("T1", {1, 288}, {3, 5, 2, 5, 3, 8, 8, 6, 1, 4, 8, 1, 5, 1, 3, 6, 1, 3, 1, 8, 6, 0, 8, 2, 7, 0, 2, 2, 7, 6, 0, 0, 1, 8, 0, 1, 6, 1, 7, 5, 4, 1, 7, 8, 8, 8, 4, 1, 4, 6, 2, 6, 0, 3, 4, 3, 5, 1, 7, 4, 3, 2, 2, 1, 2, 0, 6, 6, 8, 0, 7, 1, 7, 2, 1, 2, 1, 3, 3, 4, 3, 0, 6, 0, 7, 2, 6, 2, 2, 2, 6, 0, 3, 1, 3, 5, 6, 8, 1, 3, 6, 5, 3, 5, 3, 8, 4, 5, 5, 6, 7, 3, 6, 7, 8, 7, 7, 8, 7, 6, 6, 4, 2, 3, 2, 1, 0, 6, 1, 6, 5, 7, 0, 5, 6, 0, 1, 8, 1, 5, 7, 6, 2, 3, 7, 1, 5, 3, 8, 6, 3, 5, 2, 2, 3, 8, 6, 4, 2, 0, 5, 6, 6, 1, 1, 5, 2, 2, 1, 0, 5, 8, 3, 3, 5, 3, 6, 2, 7, 4, 3, 7, 2, 7, 4, 7, 6, 4, 6, 7, 2, 2, 4, 1, 7, 5, 5, 6, 7, 4, 2, 4, 6, 8, 0, 1, 6, 5, 1, 5, 1, 1, 4, 7, 6, 6, 4, 2, 4, 7, 0, 8, 4, 6, 6, 7, 7, 7, 1, 5, 6, 7, 2, 4, 2, 0, 7, 8, 1, 8, 5, 5, 5, 5, 7, 5, 3, 8, 6, 8, 6, 4, 8, 1, 2, 4, 7, 3, 2, 5, 4, 7, 5, 5, 7, 8, 4, 2, 0, 4, 5, 5, 6, 5, 6, 2, 0, 6, 4, 0, 6, 1, 6, 3, 2, 6, 7, 5});
test.AddInput<int8_t>("T2", {288, 1}, {2, 3, 8, 1, 9, 7, 4, 6, 8, 1, 9, 2, 2, 9, 1, 1, 9, 4, 4, 5, 0, 4, 1, 6, 3, 7, 0, 3, 8, 9, 1, 9, 8, 9, 0, 4, 9, 4, 6, 8, 0, 6, 2, 7, 5, 0, 8, 3, 7, 3, 2, 7, 7, 8, 8, 4, 5, 7, 1, 2, 6, 4, 1, 9, 7, 9, 4, 5, 3, 0, 6, 0, 0, 1, 7, 0, 8, 4, 4, 5, 9, 3, 9, 8, 5, 8, 6, 7, 8, 7, 0, 8, 6, 2, 2, 7, 2, 2, 8, 8, 0, 1, 9, 7, 4, 3, 6, 3, 7, 3, 5, 5, 0, 3, 0, 4, 9, 3, 7, 4, 7, 0, 8, 2, 8, 9, 8, 1, 7, 8, 6, 6, 0, 6, 7, 6, 2, 9, 6, 9, 7, 1, 1, 5, 3, 4, 1, 8, 7, 2, 6, 9, 7, 7, 8, 2, 6, 8, 8, 2, 5, 4, 8, 9, 9, 7, 7, 9, 9, 9, 8, 4, 0, 1, 5, 5, 3, 6, 0, 0, 0, 0, 9, 8, 5, 2, 2, 4, 4, 5, 2, 4, 9, 9, 5, 9, 9, 0, 5, 5, 5, 0, 7, 1, 0, 4, 1, 9, 6, 3, 3, 4, 9, 1, 7, 5, 5, 0, 6, 8, 1, 0, 4, 6, 5, 1, 5, 9, 5, 0, 4, 8, 6, 8, 2, 4, 8, 2, 6, 5, 8, 9, 7, 7, 4, 4, 3, 1, 4, 7, 1, 9, 5, 8, 6, 7, 9, 3, 8, 4, 5, 3, 9, 2, 3, 9, 0, 3, 4, 2, 6, 0, 9, 6, 4, 0, 5, 2, 3, 6, 6, 5, 1, 9, 6, 8, 1, 0});
test.AddOutput<int32_t>("T3", {1, 1}, {5566});
test.Run();
// Builds a std::vector<T> from the first `size` ints at `value`, converting each
// element with static_cast<T>.
template <typename T>
std::vector<T> ToVector(const int* value, int size) {
  std::vector<T> converted(size);
  const int* src = value;
  for (auto& dst : converted) {
    dst = static_cast<T>(*src++);
  }
  return converted;
}

TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8_GEMV_2) {
// [M x N] = [M x K] x [K x N] = [batch_seq x input_dim] x [input_dim x embed_dim]
void RunMatMulIntegerU8S8Test(const int M, const int N, const int K) {
OpTester test("MatMulInteger", 10);
test.AddInput<uint8_t>("T1", {1, 16}, {15, 5, 21, 21, 16, 19, 25, 19, 8, 3, 12, 4, 28, 25, 25, 26});
test.AddInput<int8_t>("T2", {16, 2}, {26, 20, 23, 19, 6, 1, 20, 8, 16, 1, 11, 13, 25, 0, 19, 5, 6, 20, 9, 1, 16, 13, 2, 21, 27, 16, 19, 20, 5, 27, 6, 24});
test.AddOutput<int32_t>("T3", {1, 2}, {4289, 3592});
test.Run();
}
static std::default_random_engine e(123);
static std::uniform_int_distribution<int> n_unsigned(0, 127); // reserve 1-bit
static std::uniform_int_distribution<int> n_signed(-128, 127); // no reserved bit
Eigen::MatrixXi T1 = Eigen::MatrixXi::Random(K, M)
.unaryExpr([](int) { return n_unsigned(e); });
Eigen::MatrixXi T2 = Eigen::MatrixXi::Random(N, K)
.unaryExpr([](int) { return n_signed(e); });
Eigen::MatrixXi T3 = (T2 * T1).eval();

TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8_2) {
OpTester test("MatMulInteger", 10);
test.AddInput<uint8_t>("T1", {1, 260}, {6, 8, 4, 3, 1, 2, 7, 0, 7, 4, 2, 6, 9, 7, 5, 4, 8, 4, 1, 9, 4, 3, 1, 9, 9, 5, 7, 5, 9, 0, 3, 7, 3, 3, 9, 3, 8, 1, 4, 2, 7, 2, 6, 0, 4, 3, 4, 5, 2, 5, 8, 6, 6, 3, 0, 5, 8, 3, 5, 2, 8, 3, 3, 1, 2, 8, 3, 4, 1, 0, 9, 3, 3, 6, 9, 3, 3, 8, 6, 4, 6, 5, 2, 4, 4, 0, 2, 7, 7, 4, 9, 1, 0, 5, 4, 7, 5, 3, 7, 0, 8, 5, 1, 3, 4, 9, 3, 6, 9, 5, 4, 2, 7, 1, 7, 0, 2, 9, 7, 1, 7, 3, 6, 7, 5, 6, 2, 7, 1, 7, 5, 5, 8, 7, 1, 0, 0, 4, 0, 5, 1, 7, 8, 8, 9, 5, 8, 5, 6, 5, 9, 7, 6, 7, 7, 9, 9, 0, 1, 3, 6, 8, 0, 6, 0, 2, 7, 1, 5, 8, 2, 7, 8, 9, 2, 8, 2, 3, 6, 2, 8, 4, 3, 1, 4, 9, 1, 8, 9, 5, 7, 4, 2, 6, 4, 9, 5, 3, 4, 8, 4, 3, 5, 7, 3, 3, 7, 0, 7, 4, 2, 4, 6, 2, 9, 7, 9, 6, 4, 8, 7, 9, 8, 5, 3, 7, 3, 7, 6, 7, 8, 5, 0, 2, 1, 5, 6, 2, 9, 2, 3, 5, 9, 3, 1, 7, 9, 1, 9, 2, 8, 4, 1, 5, 7, 6, 8, 7, 2, 9});
test.AddInput<int8_t>("T2", {260, 1}, {2, 5, 3, 0, 8, 3, 4, 0, 6, 5, 2, 3, 0, 4, 4, 5, 7, 5, 7, 2, 4, 2, 5, 1, 0, 4, 6, 7, 5, 2, 1, 2, 4, 5, 7, 7, 1, 1, 7, 4, 8, 5, 0, 6, 0, 5, 3, 2, 8, 1, 5, 2, 4, 2, 2, 8, 5, 2, 7, 8, 6, 5, 3, 7, 7, 4, 5, 5, 7, 1, 4, 5, 1, 2, 7, 2, 5, 5, 6, 1, 5, 8, 2, 5, 5, 4, 7, 2, 1, 5, 5, 6, 5, 0, 0, 1, 0, 0, 8, 4, 5, 4, 0, 5, 4, 8, 0, 1, 2, 4, 7, 3, 3, 3, 6, 3, 6, 3, 5, 0, 7, 5, 0, 2, 7, 7, 4, 4, 0, 7, 5, 2, 6, 2, 4, 0, 7, 1, 7, 5, 4, 0, 1, 8, 4, 1, 8, 3, 5, 8, 2, 5, 1, 7, 7, 2, 0, 8, 7, 2, 3, 3, 4, 8, 3, 3, 7, 6, 2, 6, 2, 5, 3, 2, 0, 5, 5, 3, 6, 0, 2, 8, 4, 6, 2, 7, 1, 4, 5, 6, 2, 5, 6, 0, 5, 2, 4, 1, 5, 2, 2, 8, 6, 4, 4, 7, 3, 4, 3, 5, 8, 7, 7, 8, 5, 0, 2, 4, 1, 7, 4, 6, 8, 3, 7, 2, 8, 2, 4, 6, 8, 5, 3, 6, 0, 8, 1, 7, 1, 7, 8, 7, 3, 1, 0, 4, 3, 0, 1, 5, 5, 7, 2, 7, 3, 4, 8, 5, 3, 1});
test.AddOutput<int32_t>("T3", {1, 1}, {5056});
test.AddInput<uint8_t>("T1", {M, K},
ToVector<uint8_t>(T1.data(), M * K));
test.AddInput<int8_t>("T2", {K, N},
ToVector<int8_t>(T2.data(), K * N), /*is_initializer*/ true);
test.AddOutput<int32_t>("T3", {M, N},
ToVector<int32_t>(T3.data(), M * N));
test.Run();
}

TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8_3) {
OpTester test("MatMulInteger", 10);
test.AddInput<uint8_t>("T1", {1, 32}, {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31});
test.AddInput<int8_t>("T2", {32, 1}, {2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32});
test.AddOutput<int32_t>("T3", {1, 1}, {11424});
test.Run();
TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8) {
// GEMV
RunMatMulIntegerU8S8Test(1, 2, 64);
RunMatMulIntegerU8S8Test(1, 2, 16);
RunMatMulIntegerU8S8Test(1, 1, 288);
RunMatMulIntegerU8S8Test(1, 1, 32);
RunMatMulIntegerU8S8Test(1, 1, 260);
// GEMM
RunMatMulIntegerU8S8Test(2, 2, 40);
RunMatMulIntegerU8S8Test(2, 48, 33);
RunMatMulIntegerU8S8Test(2, 51, 40);
RunMatMulIntegerU8S8Test(6, 10, 34);
RunMatMulIntegerU8S8Test(8, 16, 64);
}

} // namespace test
} // namespace onnxruntime
} // namespace onnxruntime

0 comments on commit a78ff58

Please sign in to comment.