diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 867d647c9e418..783204fb700f0 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -205,7 +205,11 @@ def pad(expr, type_map): arg = expr.args[0] t = type_map[arg] pad_value = expr.args[1] + ## TF2ONNX will sometimes implement the pad_value as a constant without a quantize + ## To support that, the pass lets branches that terminate in a constant through if pad_value in type_map: + ## if the pad value is calculated from a dequantize op, it should be in the type map + ## and we need to make sure its affine type matches the arg pad_t = type_map[pad_value] if not tvm.ir.structural_equal(t, pad_t): pad_value = relay.qnn.op.requantize( @@ -217,6 +221,7 @@ def pad(expr, type_map): out_dtype=t.dtype, ) else: + ## If the pad-value is a constant, we need to quantize it assert isinstance(pad_value, relay.expr.Constant) pad_value = relay.qnn.op.quantize(pad_value, t.scale, t.zero_point) diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc index 0b96b086eb5bb..b5f434e74c43b 100644 --- a/src/relay/transforms/fake_quantization_to_integer.cc +++ b/src/relay/transforms/fake_quantization_to_integer.cc @@ -92,9 +92,14 @@ class SubgraphExtractor : public ExprVisitor { } const AffineTypeMap GetAffineTypes() { return affine_types_; } void VisitExpr(const Expr& expr) override { + // When looking for fake quantized subgraphs, we only support data-flow regions of the graph, + // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we + // abort the rewrite. 
if (expr.as() == nullptr && expr.as() == nullptr && expr.as() == nullptr && expr.as() == nullptr && expr.as() == nullptr) { + LOG(INFO) << "FakeQuantizationToInteger found a non-dataflow op inside" + << " a fake quantize region, aborting this rewrite"; is_fake_quantized_ = false; } else { ExprVisitor::VisitExpr(expr); @@ -172,7 +177,7 @@ class SubgraphMutator : public ExprMutator { } // Call the rewrite Array vals = fqfq[op](expr, affine_types_); - // Save teh outputs of the rewrite + // Save the outputs of the rewrite ICHECK(vals.size() == 2) << "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for " << AsText(op, false);