Skip to content

Commit

Permalink
Do type checking for the input and kernel in the qnn conv2d (#3904)
Browse files Browse the repository at this point in the history
* [QNN] Convolution 2D Implementation.

Rebasing. Empty commit.

Clang-format styling.

* Reformatting code.

* Fixing lint issues.
  • Loading branch information
shoubhik authored and zhiics committed Sep 12, 2019
1 parent 88f9bfd commit 880c260
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion src/relay/qnn/op/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,26 @@ namespace qnn {
// relay.op.qnn.conv2d
TVM_REGISTER_NODE_TYPE(QnnConv2DAttrs);

bool QnnConv2DRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr || weight == nullptr) return false;
const auto* param = attrs.as<QnnConv2DAttrs>();
CHECK(param != nullptr) << "QnnConv2DAttrs cannot be nullptr.";
CHECK(data->dtype == Int(8) || data->dtype == UInt(8))
<< "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype;
CHECK(weight->dtype == Int(8) || weight->dtype == UInt(8))
<< "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype;
CHECK(param->out_dtype == Int(16) || param->out_dtype == Int(32))
<< "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype;
CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
return Conv2DRel<QnnConv2DAttrs>(types, num_inputs, attrs, reporter);
}

// Workload - batch_size, in_channels, out_channels, kernel_h, kernel_w
using WorkloadType = std::tuple<int, int, int, int, int>;

Expand Down Expand Up @@ -475,7 +495,7 @@ operator to understand how to scale back the int32 output to (u)int8.
.add_argument("data", "Tensor", "The quantized input data tensor.")
.add_argument("weight", "Tensor", "The quantized weight tensor.")
.set_support_level(11)
.add_type_rel("QnnConv2D", Conv2DRel<QnnConv2DAttrs>)
.add_type_rel("QnnConv2D", QnnConv2DRel)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnConv2DCanonicalize);

TVM_REGISTER_API("relay.qnn.op._make.conv2d").set_body_typed(MakeQnnConv2D);
Expand Down

0 comments on commit 880c260

Please sign in to comment.