From 58a6b1c76260ec562a4320d003f6ec1eeca9b95b Mon Sep 17 00:00:00 2001 From: Perry Gibson Date: Mon, 13 Mar 2023 18:19:57 +0000 Subject: [PATCH 1/2] [fix][relay][qnn] Bug fix for 8-bit quantized mul --- python/tvm/relay/quantize/_annotate.py | 43 +++++++++++++++++++------ python/tvm/relay/quantize/_partition.py | 5 +++ 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index ff673d23144a..0222e99e730f 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -107,7 +107,9 @@ def frewrite_with_guard(ref_call, new_args, ctx): return default_rewrite(ref_call, new_args, ctx) return func(ref_call, new_args, ctx) - return tvm.ir.register_op_attr(op_name, "FQAnnotateRewrite", frewrite_with_guard, level) + return tvm.ir.register_op_attr( + op_name, "FQAnnotateRewrite", frewrite_with_guard, level + ) return _register(frewrite) if frewrite is not None else _register @@ -125,7 +127,11 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"): """ quantize_op = _op.get("relay.op.annotation.simulated_quantize") if isinstance(data, _expr.Call) and data.op == quantize_op: - if data.attrs.kind == kind and data.attrs.sign == sign and data.attrs.rounding == rounding: + if ( + data.attrs.kind == kind + and data.attrs.sign == sign + and data.attrs.rounding == rounding + ): return data qctx = quantize_context() @@ -136,12 +142,16 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"): dom_scale = _expr.var("dom_scale") clip_min = _expr.var("clip_min") clip_max = _expr.var("clip_max") - qnode = _quantize.simulated_quantize(data, dom_scale, clip_min, clip_max, kind, sign, rounding) + qnode = _quantize.simulated_quantize( + data, dom_scale, clip_min, clip_max, kind, sign, rounding + ) qctx.qnode_map[key] = qnode return qnode -tvm._ffi.register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize) +tvm._ffi.register_func( + "relay.quantize.attach_simulated_quantize", attach_simulated_quantize +) @register_annotate_function("nn.contrib_conv2d_NCHWc") @@ -244,6 +254,16 @@ def multiply_rewrite(ref_call, new_args, ctx): rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) + if rhs_kind in [QAnnotateKind.ACTIVATION, QAnnotateKind.INPUT] and lhs_kind is None: + # quantize rhs to INPUT field + if rhs_kind == QAnnotateKind.ACTIVATION: + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) + if _analysis.check_constant(lhs_expr): + lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.WEIGHT) + else: + lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) + expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) raise ValueError @@ -281,13 +301,16 @@ def add_rewrite(ref_call, new_args, ctx): if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT: expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.INPUT) - if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION: + if ( + lhs_kind == QAnnotateKind.ACTIVATION + and rhs_kind == QAnnotateKind.ACTIVATION + ): rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) - if (lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT) or ( - lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.ACTIVATION - ): + if ( + lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT + ) or (lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.ACTIVATION): expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) raise ValueError() @@ -387,7 +410,9 @@ def concatenate_rewrite(ref_call, new_args, ctx): return None for i, k in enumerate(kind_list): if k is None: - expr_list[i] = attach_simulated_quantize(expr_list[i], QAnnotateKind.ACTIVATION) + expr_list[i] = attach_simulated_quantize( + expr_list[i], QAnnotateKind.ACTIVATION + ) expr = _forward_op(ref_call, [_expr.Tuple(expr_list)]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py index 563d28366874..ca980dd9f8fc 100644 --- a/python/tvm/relay/quantize/_partition.py +++ b/python/tvm/relay/quantize/_partition.py @@ -133,6 +133,11 @@ def mul_partition_generic(ref_call, new_args, ctx): lhs = new_args[0].realize() return QPartitionExpr(_forward_op(ref_call, [lhs, rhs])) + if rhs_cond: + # introduced by efficientnet + rhs = new_args[1].realize() + return QPartitionExpr(_forward_op(ref_call, [lhs, rhs])) + if not lhs_cond and not rhs_cond: # trivial case return None From 250dc3fc1389553e345d4b36b60c47a0287f484f Mon Sep 17 00:00:00 2001 From: Perry Gibson Date: Mon, 13 Mar 2023 18:47:02 +0000 Subject: [PATCH 2/2] Update _annotate.py Revert black formatting from my text editor, which I had assumed matched TVM's linter --- python/tvm/relay/quantize/_annotate.py | 33 +++++++------------------- 1 file changed, 9 insertions(+), 24 deletions(-) diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 0222e99e730f..c1a7b50d3f45 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -107,9 +107,7 @@ def frewrite_with_guard(ref_call, new_args, ctx): return default_rewrite(ref_call, new_args, ctx) return func(ref_call, new_args, ctx) - return tvm.ir.register_op_attr( - op_name, "FQAnnotateRewrite", frewrite_with_guard, level - ) + return tvm.ir.register_op_attr(op_name, "FQAnnotateRewrite", frewrite_with_guard, level) return _register(frewrite) if frewrite is not None else _register @@ -127,11 +125,7 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"): """ quantize_op = _op.get("relay.op.annotation.simulated_quantize") if isinstance(data, _expr.Call) and data.op == quantize_op: - if ( - data.attrs.kind == kind - and data.attrs.sign == sign - and data.attrs.rounding == rounding - ): + if data.attrs.kind == kind and data.attrs.sign == sign and data.attrs.rounding == rounding: return data qctx = quantize_context() @@ -142,16 +136,12 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"): dom_scale = _expr.var("dom_scale") clip_min = _expr.var("clip_min") clip_max = _expr.var("clip_max") - qnode = _quantize.simulated_quantize( - data, dom_scale, clip_min, clip_max, kind, sign, rounding - ) + qnode = _quantize.simulated_quantize(data, dom_scale, clip_min, clip_max, kind, sign, rounding) qctx.qnode_map[key] = qnode return qnode -tvm._ffi.register_func( - "relay.quantize.attach_simulated_quantize", attach_simulated_quantize -) +tvm._ffi.register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize) @register_annotate_function("nn.contrib_conv2d_NCHWc") @@ -301,16 +291,13 @@ def add_rewrite(ref_call, new_args, ctx): if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT: expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.INPUT) - if ( - lhs_kind == QAnnotateKind.ACTIVATION - and rhs_kind == QAnnotateKind.ACTIVATION - ): + if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION: rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) - if ( - lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT - ) or (lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.ACTIVATION): + if (lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT) or ( + lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.ACTIVATION + ): expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) raise ValueError() @@ -410,9 +397,7 @@ def concatenate_rewrite(ref_call, new_args, ctx): return None for i, k in enumerate(kind_list): if k is None: - expr_list[i] = attach_simulated_quantize( - expr_list[i], QAnnotateKind.ACTIVATION - ) + expr_list[i] = attach_simulated_quantize(expr_list[i], QAnnotateKind.ACTIVATION) expr = _forward_op(ref_call, [_expr.Tuple(expr_list)]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)