From 3b7cda07004f63f7174dde095d3fc37e2990a478 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 20 Oct 2020 11:57:01 -0700 Subject: [PATCH 1/8] add TNonComputational to qnn ops and change FoldConstant --- src/relay/qnn/op/concatenate.cc | 1 + src/relay/qnn/op/convolution.cc | 1 + src/relay/qnn/op/dense.cc | 1 + src/relay/qnn/op/dequantize.cc | 1 + src/relay/qnn/op/op_common.h | 1 + src/relay/qnn/op/quantize.cc | 1 + src/relay/qnn/op/requantize.cc | 1 + src/relay/transforms/fold_constant.cc | 11 ++++++++--- 8 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc index bda8cf878793..11ec3bfca57e 100644 --- a/src/relay/qnn/op/concatenate.cc +++ b/src/relay/qnn/op/concatenate.cc @@ -207,6 +207,7 @@ RELAY_REGISTER_OP("qnn.concatenate") "The quantization zero_point of the output tensor.") .set_support_level(11) .add_type_rel("QnnConcatenate", QnnConcatenateRel) + .set_attr("TNonComputational", true) .set_attr("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize) .set_attr("FInferCorrectLayout", QnnConcatenateLayout); diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 5d2e360e0951..f4e017183fa9 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -714,6 +714,7 @@ operator to understand how to scale back the int32 output to (u)int8. "The quantization zero_point of the weight tensor.") .set_support_level(11) .add_type_rel("QnnConv2D", QnnConv2DRel) + .set_attr("TNonComputational", true) .set_attr("FTVMQnnCanonicalize", QnnConv2DCanonicalize) .set_attr("FInferCorrectLayout", QnnConvInferCorrectLayout); diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index 464b3f9aeff3..4e719aace27a 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -189,6 +189,7 @@ RELAY_REGISTER_OP("qnn.dense") "The quantization zero_point of the weight tensor.") .set_support_level(11) .add_type_rel("QDense", QnnDenseRel) + .set_attr("TNonComputational", true) .set_attr("FTVMQnnCanonicalize", QnnDenseCanonicalize); TVM_REGISTER_GLOBAL("relay.qnn.op._make.dense").set_body_typed(MakeQuantizedDense); diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index da804dace60d..1a75e4771fa5 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -132,6 +132,7 @@ The input is always quantized (int8, uint8) and will be converted to float32 giv .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") .set_support_level(11) .add_type_rel("Dequantize", DequantizeRel) + .set_attr("TNonComputational", true) .set_attr("FTVMQnnCanonicalize", DequantizeQnnCanonicalize); TVM_REGISTER_GLOBAL("relay.qnn.op._make.dequantize").set_body_typed(MakeDequantize); diff --git a/src/relay/qnn/op/op_common.h b/src/relay/qnn/op/op_common.h index 50fc0cda30cf..6c3bf9ca5ce7 100644 --- a/src/relay/qnn/op/op_common.h +++ b/src/relay/qnn/op/op_common.h @@ -215,6 +215,7 @@ static inline bool QnnBroadcastRel(const Array& types, int num_inputs, con .add_argument("output_scale", "Tensor", "The scale of the output tensor.") \ .add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.") \ .add_type_rel("QnnBroadcast", QnnBroadcastRel) \ + .set_attr("TNonComputational", true) \ .set_attr("FInferCorrectLayout", QnnBinaryBroadcastLayout) } // namespace qnn diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 28f0b8994a01..1e2b5048fb6e 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -146,6 +146,7 @@ scale and zero point. "The quantization zero_point of the output tensor.") .set_support_level(11) .add_type_rel("Quantize", QuantizeRel) + .set_attr("TNonComputational", true) .set_attr("FTVMQnnCanonicalize", QuantizeQnnCanonicalize); TVM_REGISTER_GLOBAL("relay.qnn.op._make.quantize").set_body_typed(MakeQuantize); diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 222d91021b19..29be3fcaccfc 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -320,6 +320,7 @@ Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input) "The quantization zero_point of the output tensor.") .set_support_level(11) .add_type_rel("Requantize", RequantizeRel) + .set_attr("TNonComputational", true) .set_attr("FTVMQnnCanonicalize", RequantizeQnnCanonicalize) .set_attr("FInferCorrectLayout", RequantizeInferCorrectLayout); diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 0ecbfea8c905..1cb8d9216f68 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -133,9 +133,14 @@ class ConstantFolder : public ExprMutator { } // We should think about potentially constant evaluation over these ops too. - if (call->op == invoke_tvm_op_ || call->op == shape_func_op_ || call->op == alloc_tensor_op_ || - call->op == alloc_storage_op_) { - return GetRef(call); + //if (call->op == invoke_tvm_op_ || call->op == shape_func_op_ || call->op == alloc_tensor_op_ || + // call->op == alloc_storage_op_) { + static auto nonComputational = Op::GetAttrMap("TNonComputational"); + if (auto call_node = call->op.as()) { + Op op = GetRef(call_node); + if (nonComputational[op]) { + return GetRef(call); + } } bool all_const_args = true; From 064b270d58effb54f21d3dca1aab94ed76a84039 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 20 Oct 2020 12:04:54 -0700 Subject: [PATCH 2/8] remove comments --- src/relay/transforms/fold_constant.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 1cb8d9216f68..ad0912ef600d 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -133,8 +133,6 @@ class ConstantFolder : public ExprMutator { } // We should think about potentially constant evaluation over these ops too. - //if (call->op == invoke_tvm_op_ || call->op == shape_func_op_ || call->op == alloc_tensor_op_ || - // call->op == alloc_storage_op_) { static auto nonComputational = Op::GetAttrMap("TNonComputational"); if (auto call_node = call->op.as()) { Op op = GetRef(call_node); From 4d983a9aa01b49fc49e27454955ddb781de0c405 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 20 Oct 2020 13:23:48 -0700 Subject: [PATCH 3/8] check if op in nonComputational map --- src/relay/transforms/fold_constant.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index da9053612914..f8f31ea42d56 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -154,8 +154,10 @@ class ConstantFolder : public MixedModeMutator { static auto nonComputational = Op::GetAttrMap("TNonComputational"); if (auto call_node = call->op.as()) { Op op = GetRef(call_node); - if (nonComputational[op]) { - return GetRef(call); + if (nonComputational.count(op)) { + if (nonComputational[op]) { + return GetRef(call); + } } } From de609284368cb131d73cd9eb056c993b9da2bf4b Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 20 Oct 2020 17:16:14 -0700 Subject: [PATCH 4/8] forgot to mark device_copy op as TNonComputational --- src/relay/op/device_copy.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/relay/op/device_copy.cc b/src/relay/op/device_copy.cc index b26dc879be0a..c5a9e9320423 100644 --- a/src/relay/op/device_copy.cc +++ b/src/relay/op/device_copy.cc @@ -61,6 +61,7 @@ on different devices. .add_type_rel("Identity", IdentityRel) .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) + .set_attr("TNonComputational", true) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, From b1ff6ff2404f99545ac69fcb35cd9d8c981ab516 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 20 Oct 2020 20:31:22 -0700 Subject: [PATCH 5/8] hacky fix to fuseops pass --- src/relay/transforms/fuse_ops.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index bc6335a539af..5769b38cdf69 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -857,6 +857,10 @@ class FuseMutator : private ExprMutator { if (call->op.as()) { static auto fnoncomputational = Op::GetAttrMap("TNonComputational"); + if(call->op == Op::Get("device_copy")) { + return ExprMutator::VisitExpr_(call); + } + if (fnoncomputational.get(Downcast(call->op), false)) { return ExprMutator::VisitExpr_(call); } From c4f79426b24c0518f39ab6d74d5bed42ce26a121 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 20 Oct 2020 20:34:52 -0700 Subject: [PATCH 6/8] fix typo --- src/relay/transforms/fuse_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index 5769b38cdf69..3a5d064871ab 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -857,7 +857,7 @@ class FuseMutator : private ExprMutator { if (call->op.as()) { static auto fnoncomputational = Op::GetAttrMap("TNonComputational"); - if(call->op == Op::Get("device_copy")) { + if (call->op == Op::Get("device_copy")) { return ExprMutator::VisitExpr_(call); } From 68150e217e35d0d2a9e3c5589066b4eaab6176c5 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Wed, 21 Oct 2020 09:41:08 -0700 Subject: [PATCH 7/8] manually skip device_copy in fold_constant --- src/relay/op/device_copy.cc | 1 - src/relay/transforms/fold_constant.cc | 8 +++----- src/relay/transforms/fuse_ops.cc | 4 ---- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/relay/op/device_copy.cc b/src/relay/op/device_copy.cc index c5a9e9320423..b26dc879be0a 100644 --- a/src/relay/op/device_copy.cc +++ b/src/relay/op/device_copy.cc @@ -61,7 +61,6 @@ on different devices. .add_type_rel("Identity", IdentityRel) .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) - .set_attr("TNonComputational", true) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index f8f31ea42d56..417a39a7e604 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -151,13 +151,11 @@ class ConstantFolder : public MixedModeMutator { } // We should think about potentially constant evaluation over these ops too. - static auto nonComputational = Op::GetAttrMap("TNonComputational"); + static auto fnoncomputational = Op::GetAttrMap("TNonComputational"); if (auto call_node = call->op.as()) { Op op = GetRef(call_node); - if (nonComputational.count(op)) { - if (nonComputational[op]) { - return GetRef(call); - } + if ((fnoncomputational.count(op) && fnoncomputational[op]) || (call->op == device_copy_op_)) { + return GetRef(call); } } diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index 3a5d064871ab..bc6335a539af 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -857,10 +857,6 @@ class FuseMutator : private ExprMutator { if (call->op.as()) { static auto fnoncomputational = Op::GetAttrMap("TNonComputational"); - if (call->op == Op::Get("device_copy")) { - return ExprMutator::VisitExpr_(call); - } - if (fnoncomputational.get(Downcast(call->op), false)) { return ExprMutator::VisitExpr_(call); } From 2653ecd8c70aa80cb26849def2a5416a26deb98e Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Thu, 22 Oct 2020 13:34:27 -0700 Subject: [PATCH 8/8] Update src/relay/transforms/fold_constant.cc Co-authored-by: Junru Shao --- src/relay/transforms/fold_constant.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 417a39a7e604..4a739ddba40f 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -152,7 +152,7 @@ class ConstantFolder : public MixedModeMutator { // We should think about potentially constant evaluation over these ops too. static auto fnoncomputational = Op::GetAttrMap("TNonComputational"); - if (auto call_node = call->op.as()) { + if (const auto* call_node = call->op.as()) { Op op = GetRef(call_node); if ((fnoncomputational.count(op) && fnoncomputational[op]) || (call->op == device_copy_op_)) { return GetRef(call);