From ad608d4a021626f6827e9b2015c50fc0ab05c62b Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 1 Jun 2021 10:09:07 -0600 Subject: [PATCH] use an identify function for some ops --- .../transform/quantize_fake_quantization.py | 32 ++++++------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/python/tvm/relay/transform/quantize_fake_quantization.py b/python/tvm/relay/transform/quantize_fake_quantization.py index 8ca5782643026..d5407d353ac38 100644 --- a/python/tvm/relay/transform/quantize_fake_quantization.py +++ b/python/tvm/relay/transform/quantize_fake_quantization.py @@ -57,31 +57,19 @@ def quantize_qfq(expr, type_map): return [out, expr.args[1], expr.args[2], expr.attrs.out_dtype] -@register_quantize_fake_quantization("reshape") -def reshape_qfq(expr, type_map): - """Rewrite a reshape op""" - arg = expr.args[0] - t = type_map[arg] - out = relay.op.reshape(arg, **expr.attrs) - return [out, t.scale, t.zero_point, t.dtype] +def register_qfq_identity(op_name, op): + def identity(expr, type_map): + arg = expr.args[0] + t = type_map[arg] + out = op(arg, **expr.attrs) + return [out, t.scale, t.zero_point, t.dtype] - -@register_quantize_fake_quantization("transpose") -def transpose_qfq(expr, type_map): - """Rewrite a transpose op""" - arg = expr.args[0] - t = type_map[arg] - out = relay.op.transpose(arg, **expr.attrs) - return [out, t.scale, t.zero_point, t.dtype] + return register_quantize_fake_quantization(op_name, identity) -@register_quantize_fake_quantization("nn.max_pool2d") -def maxpool_qfq(expr, type_map): - """Rewrite a maxpool op""" - arg = expr.args[0] - t = type_map[arg] - out = relay.op.nn.max_pool2d(arg, **expr.attrs) - return [out, t.scale, t.zero_point, t.dtype] +register_qfq_identity("reshape", relay.op.reshape) +register_qfq_identity("transpose", relay.op.transpose) +register_qfq_identity("nn.max_pool2d", relay.op.nn.max_pool2d) @register_quantize_fake_quantization("nn.avg_pool2d")