Fix QNN type inference #7074

Merged, 3 commits, Dec 10, 2020
36 changes: 30 additions & 6 deletions src/relay/qnn/op/concatenate.cc
@@ -38,29 +38,53 @@ namespace qnn {

 bool QnnConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                        const TypeReporter& reporter) {
+  // Expected Types: data, input_scales, input_zero_points, output_scale, output_zero_point,
+  // out_type
   ICHECK_EQ(types.size(), 6);

+  if (types[0].as<IncompleteTypeNode>()) {
+    return false;
+  }
   // Check the scale and zero point types
   const auto* input_scales_tuple = types[1].as<TupleTypeNode>();
   if (input_scales_tuple == nullptr) {
-    throw Error(ErrorBuilder()
-                << "qnn concatenate requires a tuple of scales as the second argument, found "
-                << PrettyPrint(types[1]));
+    if (types[1].as<IncompleteTypeNode>()) {
+      return false;
+    } else {
+      throw Error(ErrorBuilder()
+                  << "qnn concatenate requires a tuple of scales as the second argument, found "
+                  << PrettyPrint(types[1]));
+    }
   }
   for (const auto& input_scale : input_scales_tuple->fields) {
+    if (input_scale.as<IncompleteTypeNode>()) {
+      return false;
+    }
     ICHECK(IsScalarType(input_scale, DataType::Float(32)));  // input_scales[idx]
   }

   const auto* input_zero_points_tuple = types[2].as<TupleTypeNode>();
   if (input_zero_points_tuple == nullptr) {
-    throw Error(ErrorBuilder()
-                << "qnn concatenate requires a tuple of zero_points as the third argument, found "
-                << PrettyPrint(types[2]));
+    if (types[2].as<IncompleteTypeNode>()) {
+      return false;
+    } else {
+      throw Error(ErrorBuilder()
+                  << "qnn concatenate requires a tuple of zero_points as the third argument, found "
+                  << PrettyPrint(types[2]));
+    }
   }
   for (const auto& input_zero_point : input_zero_points_tuple->fields) {
+    if (input_zero_point.as<IncompleteTypeNode>()) {
+      return false;
+    }
     ICHECK(IsScalarType(input_zero_point, DataType::Int(32)));  // input_zero_points[idx]
   }

+  for (size_t i = 3; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   ICHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
   ICHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point

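A note on what the new early-returns buy: the relation must tolerate arguments whose types are still unresolved (for example, unannotated parameters of a function produced by MergeComposite) by returning false so the solver revisits it later, instead of throwing mid-inference. A minimal Python sketch of the documented contract, assuming the public relay.qnn.op.concatenate API; shapes and quantization constants are illustrative:

import tvm
from tvm import relay

a = relay.var("a", shape=(1, 4), dtype="int8")
b = relay.var("b", shape=(2, 4), dtype="int8")
f32 = lambda v: relay.const(v, "float32")
i32 = lambda v: relay.const(v, "int32")
out = relay.qnn.op.concatenate(
    (a, b),
    input_scales=(f32(0.5), f32(0.25)),  # tuple of scalar float32
    input_zero_points=(i32(0), i32(0)),  # tuple of scalar int32
    output_scale=f32(1.0),               # scalar float32
    output_zero_point=i32(0),            # scalar int32
    axis=0,
)
mod = tvm.IRModule.from_expr(relay.Function([a, b], out))
mod = relay.transform.InferType()(mod)
print(mod["main"].body.checked_type)  # Tensor[(3, 4), int8]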
13 changes: 10 additions & 3 deletions src/relay/qnn/op/convolution.cc
@@ -42,6 +42,8 @@ namespace qnn {

 bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter) {
+  // Expected Types: data, weight, input_zero_point, weight_zero_point, input_scale, weight_scale,
+  // out_type
   ICHECK_EQ(types.size(), 7);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* weight = types[1].as<TensorTypeNode>();
@@ -57,22 +59,27 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

   // Check the types of scale and zero points.
+  for (size_t i = 2; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
-  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // kernel_zero_point
+  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // weight_zero_point
   ICHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
   // Kernel scale can be a vector of length output_channels or a scalar.
   if (param->groups == 1) {
     size_t axis = param->kernel_layout.operator std::string().find('O');
     ICHECK(axis != std::string::npos) << "Kernel layout attribute is not defined";
-    AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter);  // kernel scale
+    AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter);  // weight_scale
   } else {
     // Here, total number of output channels depend on depth multiplier.
     size_t o_axis = param->kernel_layout.operator std::string().find('O');
     size_t i_axis = param->kernel_layout.operator std::string().find('I');
     ICHECK(o_axis != std::string::npos || i_axis != std::string::npos)
         << "Kernel layout attribute is not defined";
     AssignType(types[5], DataType::Float(32), weight->shape[i_axis] * weight->shape[o_axis],
-               reporter);  // kernel scale
+               reporter);  // weight_scale
   }
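The kernel-scale branch is worth a sketch of its own, since AssignType accepts either a scalar or a per-output-channel vector. A hedged example, not from the PR (the numpy values are made up), for a groups == 1, OIHW convolution, where the (16,) float32 vector must match the O axis the relation looks up. The same incomplete-type guard and the kernel/weight comment renames recur in convolution_transpose.cc and dense.cc below.

import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 3, 8, 8), dtype="int8")
weight = relay.var("weight", shape=(16, 3, 3, 3), dtype="int8")
# Per-output-channel scale: length 16 = size of the O axis in OIHW.
per_channel_scale = relay.const(np.full(16, 0.02, dtype="float32"))
conv = relay.qnn.op.conv2d(
    data, weight,
    input_zero_point=relay.const(0, "int32"),
    kernel_zero_point=relay.const(0, "int32"),
    input_scale=relay.const(0.1, "float32"),
    kernel_scale=per_channel_scale,
    kernel_size=(3, 3), channels=16,
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))
mod = relay.transform.InferType()(mod)
print(mod["main"].body.checked_type)  # Tensor[(1, 16, 6, 6), int32]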
11 changes: 9 additions & 2 deletions src/relay/qnn/op/convolution_transpose.cc
@@ -81,6 +81,8 @@ Array<Array<Layout>> QnnConvTransposeInferCorrectLayout(

 bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                            const TypeReporter& reporter) {
+  // Expected Types: data, weight, input_zero_point, weight_zero_point, input_scale, weight_scale,
+  // out_type
   ICHECK_EQ(types.size(), 7);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* weight = types[1].as<TensorTypeNode>();
@@ -96,14 +98,19 @@ bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs
   ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

   // Check the types of scale and zero points.
+  for (size_t i = 2; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
-  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // kernel_zero_point
+  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // weight_zero_point
   ICHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
   // Kernel scale can be a vector of length output_channels or a scalar.
   if (param->groups == 1) {
     size_t axis = param->kernel_layout.find('O');
     ICHECK(axis != std::string::npos) << "Kernel layout attribute is not defined";
-    AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter);  // kernel scale
+    AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter);  // weight_scale
   } else {
     // Here, total number of output channels depend on depth multiplier.
     size_t o_axis = param->kernel_layout.find('O');
15 changes: 11 additions & 4 deletions src/relay/qnn/op/dense.cc
@@ -39,6 +39,8 @@ namespace qnn {

 bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                  const TypeReporter& reporter) {
+  // Expected Types: data, weight, input_zero_point, weight_zero_point, input_scale, weight_scale,
+  // out_type
   ICHECK_EQ(types.size(), 7);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* weight = types[1].as<TensorTypeNode>();
@@ -53,10 +55,15 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       << "Expected quantized dense type(int32) for output but was " << param->out_dtype;

   // Check the types of scale and zero points.
-  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
-  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // kernel_zero_point
-  ICHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
-  AssignType(types[5], DataType::Float(32), param->units, reporter);
+  for (size_t i = 2; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
+  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
+  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // weight_zero_point
+  ICHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
+  AssignType(types[5], DataType::Float(32), param->units, reporter);  // weight_scale

   ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
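Same shape of change as the convolutions. A small sketch of the scalar contract the ICHECKs enforce (illustrative values; note the Python API still spells the arguments kernel_*, while the comments now say weight_*):

import tvm
from tvm import relay

data = relay.var("data", shape=(2, 32), dtype="int8")
weight = relay.var("weight", shape=(16, 32), dtype="int8")
dense = relay.qnn.op.dense(
    data, weight,
    input_zero_point=relay.const(0, "int32"),   # scalar int32
    kernel_zero_point=relay.const(0, "int32"),  # scalar int32
    input_scale=relay.const(0.05, "float32"),   # scalar float32
    kernel_scale=relay.const(0.02, "float32"),  # scalar, or a (units,) vector
    units=16,
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], dense))
mod = relay.transform.InferType()(mod)
print(mod["main"].body.checked_type)  # Tensor[(2, 16), int32]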
13 changes: 13 additions & 0 deletions src/relay/qnn/op/op_common.h
@@ -168,9 +168,22 @@ inline Array<Array<Layout> > QnnBinaryBroadcastLayout(const Attrs& attrs,

 static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                                    const TypeReporter& reporter) {
+  // Expected Types: lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale,
+  // output_zero_point, out_type
   ICHECK_EQ(types.size(), kNumQnnBinaryOpArgTypes);

+  // Check the lhs and rhs types
+  for (size_t i = 0; i < 2; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   // Check the scale and zero point types
+  for (size_t i = 2; i < 8; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   ICHECK(IsScalarType(types[2], DataType::Float(32)));  // lhs_scale
   ICHECK(IsScalarType(types[3], DataType::Int(32)));    // lhs_zero_point
   ICHECK(IsScalarType(types[4], DataType::Float(32)));  // rhs_scale
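This relation is where the bug bites in practice: MergeComposite wraps a matched qnn.add in a fresh function whose parameters carry no type annotations, so on the first solver pass every scale and zero point is an IncompleteType. A rough repro sketch, assuming the solver visits the relation before the call site resolves the parameters (which is the MergeComposite situation; names and constants are illustrative):

import tvm
from tvm import relay

# Composite-style inner function: quantization params are unannotated vars.
x = relay.var("x", shape=(1, 8), dtype="int8")
y = relay.var("y", shape=(1, 8), dtype="int8")
qparams = [relay.var("q%d" % i) for i in range(6)]
add = relay.qnn.op.add(x, y, *qparams)
inner = relay.Function([x, y] + qparams, add)

# A call site with concrete scalars lets unification resolve the params
# on a later pass; previously the ICHECKs fired while they were still
# incomplete and type inference aborted.
a = relay.var("a", shape=(1, 8), dtype="int8")
b = relay.var("b", shape=(1, 8), dtype="int8")
f32 = lambda v: relay.const(v, "float32")
i32 = lambda v: relay.const(v, "int32")
call = inner(a, b, f32(0.5), i32(0), f32(0.25), i32(0), f32(1.0), i32(0))
mod = tvm.IRModule.from_expr(relay.Function([a, b], call))
mod = relay.transform.InferType()(mod)  # succeeds with this fix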
7 changes: 7 additions & 0 deletions src/relay/qnn/op/requantize.cc
@@ -256,13 +256,20 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
  */
 bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                    const TypeReporter& reporter) {
+  // Expected Types: data, input_scale, input_zero_point, output_scale, output_zero_point, output
   ICHECK_EQ(types.size(), 6);
   const auto* data = types[0].as<TensorTypeNode>();

   if (data == nullptr) {
     return false;
   }

+  // Check the scale and zero point types
+  for (size_t i = 3; i < 5; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) {
+      return false;
+    }
+  }
   const auto in_dtype = data->dtype;
   ICHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
          in_dtype == DataType::Int(32))
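Note the guard here covers indices 3-4 (output_scale, output_zero_point); the data tensor is already handled by the nullptr check above it. A quick sketch of the happy path, with illustrative scale values:

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 64), dtype="int32")
rq = relay.qnn.op.requantize(
    data,
    input_scale=relay.const(0.25, "float32"),
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.5, "float32"),
    output_zero_point=relay.const(0, "int32"),
    out_dtype="int8",
)
mod = tvm.IRModule.from_expr(relay.Function([data], rq))
mod = relay.transform.InferType()(mod)
print(mod["main"].body.checked_type)  # Tensor[(1, 64), int8]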
41 changes: 41 additions & 0 deletions tests/python/frontend/pytorch/qnn_test.py
@@ -32,17 +32,58 @@
 from tvm.relay.frontend.pytorch_utils import is_version_greater_than
 from tvm.contrib.download import download_testdata

+from tvm.relay.dataflow_pattern import wildcard, is_op
+from tvm.relay.op.contrib.register import register_pattern_table
+from tvm.relay.op.contrib.register import get_pattern_table
+

 def torch_version_check():
     from packaging import version

     return version.parse(torch.__version__) > version.parse("1.4.0")


+def make_qnn_add_pattern():
+    lhs = wildcard()
+    rhs = wildcard()
+    lhs_scale = wildcard()
+    lhs_zero_point = wildcard()
+    rhs_scale = wildcard()
+    rhs_zero_point = wildcard()
+    output_scale = wildcard()
+    output_zero_point = wildcard()
+    qadd = is_op("qnn.add")(
+        lhs,
+        rhs,
+        lhs_scale,
+        lhs_zero_point,
+        rhs_scale,
+        rhs_zero_point,
+        output_scale,
+        output_zero_point,
+    )
+    return qadd.optional(is_op("clip"))
+
+
+@register_pattern_table("test_table")
+def pattern_table():
+    return [
+        ("qnn_add", make_qnn_add_pattern()),
+    ]
+
+
 def get_tvm_runtime(script_module, input_name, ishape):
     input_shapes = [(input_name, ishape)]
     mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+    pattern_table = get_pattern_table("test_table")
+    with tvm.transform.PassContext(opt_level=3):
+        pass_list = [
+            tvm.relay.transform.SimplifyInference(),
+            tvm.relay.transform.MergeComposite(pattern_table),
+        ]
+        composite_partition = tvm.transform.Sequential(pass_list)
+        partitioned = composite_partition(mod)

     with tvm.transform.PassContext(opt_level=3):
         # test on only cpu for now, torch cannot run quant models on cuda
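The partitioning above is what regression-tests the fix end to end: MergeComposite rewrites every matched qnn.add (plus optional clip) into a call of a composite function with unannotated parameters, and the subsequent type inference used to trip the ICHECKs patched in this PR. A hypothetical extra check, not in the PR's test, that could slot in right after `partitioned` is computed:

    # Hypothetical: re-run inference on the partitioned module; before this
    # PR the "qnn_add" composite's IncompleteType parameters crashed it.
    reinferred = tvm.relay.transform.InferType()(partitioned)
    assert reinferred["main"].checked_type is not None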