diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index c75746812ecc..08d56a3ca23f 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -166,9 +166,10 @@ def multiply_rewrite(ref_call, new_args, ctx): if lhs_kind is None and rhs_kind is None: return None - if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind is None: + if lhs_kind in [QAnnotateKind.ACTIVATION, QAnnotateKind.INPUT] and rhs_kind is None: # quantize lhs to INPUT field - lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) + if lhs_kind == QAnnotateKind.ACTIVATION: + lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) # quantize rhs to WEIGHT field rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr])