[RELAY] Refactor FoldConstant to skip TNonComputationalOps #6720

Merged · 12 commits · Oct 24, 2020
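In brief (as read from the diff below): FoldConstant previously skipped a hardcoded list of ops (invoke_tvm_op, shape_func, alloc_tensor, alloc_storage, device_copy); this PR replaces that check with a lookup of a new TNonComputational op attribute, and registers that attribute on the QNN ops, which have no compute semantics of their own until FTVMQnnCanonicalize lowers them.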
1 change: 1 addition & 0 deletions src/relay/qnn/op/concatenate.cc
@@ -207,6 +207,7 @@ RELAY_REGISTER_OP("qnn.concatenate")
"The quantization zero_point of the output tensor.")
.set_support_level(11)
.add_type_rel("QnnConcatenate", QnnConcatenateRel)
+.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnConcatenateLayout);
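The same one-line registration recurs in each QNN op file below. As a sketch of how a pass can consume the flag (a hypothetical helper, not code from this PR; it assumes the TNonComputational = bool alias from include/tvm/relay/op_attr_types.h):

#include <tvm/ir/op.h>
#include <tvm/relay/op_attr_types.h>

// Hypothetical helper: true when `op` has opted out of constant folding
// via the TNonComputational attribute registered above.
bool IsNonComputational(const tvm::Op& op) {
  static auto attr_map =
      tvm::Op::GetAttrMap<tvm::relay::TNonComputational>("TNonComputational");
  return attr_map.count(op) && attr_map[op];
}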

1 change: 1 addition & 0 deletions src/relay/qnn/op/convolution.cc
@@ -733,6 +733,7 @@ operator to understand how to scale back the int32 output to (u)int8.
"The quantization zero_point of the weight tensor.")
.set_support_level(11)
.add_type_rel("QnnConv2D", QnnConv2DRel)
+.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnConv2DCanonicalize)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnConvInferCorrectLayout);

1 change: 1 addition & 0 deletions src/relay/qnn/op/dense.cc
@@ -189,6 +189,7 @@ RELAY_REGISTER_OP("qnn.dense")
"The quantization zero_point of the weight tensor.")
.set_support_level(11)
.add_type_rel("QDense", QnnDenseRel)
+.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnDenseCanonicalize);

TVM_REGISTER_GLOBAL("relay.qnn.op._make.dense").set_body_typed(MakeQuantizedDense);
1 change: 1 addition & 0 deletions src/relay/qnn/op/dequantize.cc
@@ -136,6 +136,7 @@ The input is always quantized (int8, uint8) and will be converted to float32 giv
.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
.set_support_level(11)
.add_type_rel("Dequantize", DequantizeRel)
+.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", DequantizeQnnCanonicalize);

TVM_REGISTER_GLOBAL("relay.qnn.op._make.dequantize").set_body_typed(MakeDequantize);
1 change: 1 addition & 0 deletions src/relay/qnn/op/op_common.h
@@ -215,6 +215,7 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
.add_argument("output_scale", "Tensor", "The scale of the output tensor.") \
.add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.") \
.add_type_rel("QnnBroadcast", QnnBroadcastRel) \
+.set_attr<TNonComputational>("TNonComputational", true) \
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnBinaryBroadcastLayout)

} // namespace qnn
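Note that this registration sits inside the shared binary-op registration macro in op_common.h (the trailing backslashes mark the macro continuation), so the attribute is applied once here and inherited by every QNN broadcast op stamped out through it (e.g. qnn.add, qnn.subtract, qnn.mul).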
1 change: 1 addition & 0 deletions src/relay/qnn/op/quantize.cc
@@ -150,6 +150,7 @@ scale and zero point.
"The quantization zero_point of the output tensor.")
.set_support_level(11)
.add_type_rel("Quantize", QuantizeRel)
+.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QuantizeQnnCanonicalize);

TVM_REGISTER_GLOBAL("relay.qnn.op._make.quantize").set_body_typed(MakeQuantize);
1 change: 1 addition & 0 deletions src/relay/qnn/op/requantize.cc
@@ -324,6 +324,7 @@ Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
"The quantization zero_point of the output tensor.")
.set_support_level(11)
.add_type_rel("Requantize", RequantizeRel)
+.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", RequantizeQnnCanonicalize)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", RequantizeInferCorrectLayout);

11 changes: 8 additions & 3 deletions src/relay/transforms/fold_constant.cc
@@ -151,9 +151,14 @@ class ConstantFolder : public MixedModeMutator {
}

// We should think about potentially constant evaluation over these ops too.
-  if (call->op == invoke_tvm_op_ || call->op == shape_func_op_ || call->op == alloc_tensor_op_ ||
-      call->op == alloc_storage_op_ || call->op == device_copy_op_) {
-    return GetRef<Call>(call);
+  static auto nonComputational = Op::GetAttrMap<TNonComputational>("TNonComputational");
+  if (auto call_node = call->op.as<OpNode>()) {
+    Op op = GetRef<Op>(call_node);
+    if (nonComputational.count(op)) {
+      if (nonComputational[op]) {
+        return GetRef<Call>(call);
+      }
+    }
  }

bool all_const_args = true;
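With the attribute-driven check in place (the nested conditions are equivalent to nonComputational.count(op) && nonComputational[op]), any op can opt out of constant folding without this pass being edited again. A minimal sketch of such an opt-out, using a hypothetical op name:

// Hypothetical registration for a custom op "my.opaque"; after this PR,
// FoldConstant leaves calls to it intact even when all arguments are constant.
RELAY_REGISTER_OP("my.opaque")
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .set_attr<TNonComputational>("TNonComputational", true);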