[RELAY] Refactor FoldConstant to skip TNonComputationalOps (apache#6720)
* add TNonComputational to qnn ops and change FoldConstant

* remove comments

* check if op in nonComputational map

* forgot to mark device_copy op as TNonComputational

* hacky fix to fuseops pass

* fix typo

* manually skip device_copy in fold_constant

* Update src/relay/transforms/fold_constant.cc

Co-authored-by: Junru Shao <junrushao1994@gmail.com>

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
2 people authored and Trevor Morris committed Dec 4, 2020
1 parent 0e3a2f2 commit 4e5ce77
Showing 8 changed files with 13 additions and 3 deletions.
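Every file below follows the same pattern: each QNN op registration gains a boolean TNonComputational attribute, and FoldConstant then consults the op attribute map instead of a hard-coded list of ops to skip. As a hedged sketch of the query side (illustrative only, not part of the diff; it mirrors the fold_constant.cc hunk below and uses the TNonComputational attribute type from include/tvm/relay/op_attr_types.h):

#include <tvm/ir/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op_attr_types.h>

// Illustrative helper, not part of this commit: returns true when `call`
// targets an op whose registration set TNonComputational to true.
bool IsNonComputational(const tvm::relay::Call& call) {
  using tvm::Op;
  using tvm::OpNode;
  static auto fnoncomputational =
      Op::GetAttrMap<tvm::relay::TNonComputational>("TNonComputational");
  if (const auto* op_node = call->op.as<OpNode>()) {
    Op op = tvm::runtime::GetRef<Op>(op_node);
    return fnoncomputational.count(op) && fnoncomputational[op];
  }
  return false;
}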
1 change: 1 addition & 0 deletions src/relay/qnn/op/concatenate.cc
@@ -207,6 +207,7 @@ RELAY_REGISTER_OP("qnn.concatenate")
                   "The quantization zero_point of the output tensor.")
     .set_support_level(11)
     .add_type_rel("QnnConcatenate", QnnConcatenateRel)
+    .set_attr<TNonComputational>("TNonComputational", true)
     .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize)
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnConcatenateLayout);

1 change: 1 addition & 0 deletions src/relay/qnn/op/convolution.cc
@@ -733,6 +733,7 @@ operator to understand how to scale back the int32 output to (u)int8.
                 "The quantization zero_point of the weight tensor.")
     .set_support_level(11)
     .add_type_rel("QnnConv2D", QnnConv2DRel)
+    .set_attr<TNonComputational>("TNonComputational", true)
     .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnConv2DCanonicalize)
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnConvInferCorrectLayout);

1 change: 1 addition & 0 deletions src/relay/qnn/op/dense.cc
@@ -189,6 +189,7 @@ RELAY_REGISTER_OP("qnn.dense")
                   "The quantization zero_point of the weight tensor.")
     .set_support_level(11)
     .add_type_rel("QDense", QnnDenseRel)
+    .set_attr<TNonComputational>("TNonComputational", true)
     .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnDenseCanonicalize);

 TVM_REGISTER_GLOBAL("relay.qnn.op._make.dense").set_body_typed(MakeQuantizedDense);
1 change: 1 addition & 0 deletions src/relay/qnn/op/dequantize.cc
@@ -136,6 +136,7 @@ The input is always quantized (int8, uint8) and will be converted to float32 giv
     .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
     .set_support_level(11)
     .add_type_rel("Dequantize", DequantizeRel)
+    .set_attr<TNonComputational>("TNonComputational", true)
     .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", DequantizeQnnCanonicalize);

 TVM_REGISTER_GLOBAL("relay.qnn.op._make.dequantize").set_body_typed(MakeDequantize);
1 change: 1 addition & 0 deletions src/relay/qnn/op/op_common.h
@@ -215,6 +215,7 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
       .add_argument("output_scale", "Tensor", "The scale of the output tensor.")           \
       .add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.") \
       .add_type_rel("QnnBroadcast", QnnBroadcastRel)                                       \
+      .set_attr<TNonComputational>("TNonComputational", true)                              \
       .set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnBinaryBroadcastLayout)

 }  // namespace qnn
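The backslash-continued lines above are the tail of this header's shared registration macro for QNN broadcasting binary ops (QNN_REGISTER_BINARY_OP), so every op registered through it now carries the flag. Roughly, and assuming the macro's elided earlier lines, an instantiation such as QNN_REGISTER_BINARY_OP("add") now ends with:

// Hedged sketch of the expansion tail for qnn.add after this change; the
// preceding .set_num_inputs/.add_argument lines are elided, not literal.
RELAY_REGISTER_OP("qnn.add")
    .add_type_rel("QnnBroadcast", QnnBroadcastRel)
    .set_attr<TNonComputational>("TNonComputational", true)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnBinaryBroadcastLayout);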
1 change: 1 addition & 0 deletions src/relay/qnn/op/quantize.cc
@@ -150,6 +150,7 @@ scale and zero point.
                   "The quantization zero_point of the output tensor.")
     .set_support_level(11)
     .add_type_rel("Quantize", QuantizeRel)
+    .set_attr<TNonComputational>("TNonComputational", true)
     .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QuantizeQnnCanonicalize);

 TVM_REGISTER_GLOBAL("relay.qnn.op._make.quantize").set_body_typed(MakeQuantize);
1 change: 1 addition & 0 deletions src/relay/qnn/op/requantize.cc
@@ -324,6 +324,7 @@ Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
                   "The quantization zero_point of the output tensor.")
     .set_support_level(11)
     .add_type_rel("Requantize", RequantizeRel)
+    .set_attr<TNonComputational>("TNonComputational", true)
     .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", RequantizeQnnCanonicalize)
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", RequantizeInferCorrectLayout);

9 changes: 6 additions & 3 deletions src/relay/transforms/fold_constant.cc
@@ -151,9 +151,12 @@ class ConstantFolder : public MixedModeMutator {
     }

     // We should think about potentially constant evaluation over these ops too.
-    if (call->op == invoke_tvm_op_ || call->op == shape_func_op_ || call->op == alloc_tensor_op_ ||
-        call->op == alloc_storage_op_ || call->op == device_copy_op_) {
-      return GetRef<Call>(call);
+    static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational");
+    if (const auto* call_node = call->op.as<OpNode>()) {
+      Op op = GetRef<Op>(call_node);
+      if ((fnoncomputational.count(op) && fnoncomputational[op]) || (call->op == device_copy_op_)) {
+        return GetRef<Call>(call);
+      }
     }

     bool all_const_args = true;
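Note the one remaining special case: per the commit notes, device_copy is skipped by the explicit equality test above instead of being tagged with the attribute. The device_copy_op_ member is a cached op handle; its initializer sits outside this hunk, but the conventional pattern (an assumption, shown for context) is:

// Assumed ConstantFolder member initializer (not shown in the hunk);
// caching the handle once makes the skip check a cheap reference compare.
const tvm::Op& device_copy_op_ = tvm::Op::Get("device_copy");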
