Skip to content

Commit

Permalink
use an identify function for some ops
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew committed Jun 1, 2021
1 parent 424eaa1 commit ad608d4
Showing 1 changed file with 10 additions and 22 deletions.
32 changes: 10 additions & 22 deletions python/tvm/relay/transform/quantize_fake_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,31 +57,19 @@ def quantize_qfq(expr, type_map):
return [out, expr.args[1], expr.args[2], expr.attrs.out_dtype]


@register_quantize_fake_quantization("reshape")
def reshape_qfq(expr, type_map):
"""Rewrite a reshape op"""
arg = expr.args[0]
t = type_map[arg]
out = relay.op.reshape(arg, **expr.attrs)
return [out, t.scale, t.zero_point, t.dtype]
def register_qfq_identity(op_name, op):
def identity(expr, type_map):
arg = expr.args[0]
t = type_map[arg]
out = op(arg, **expr.attrs)
return [out, t.scale, t.zero_point, t.dtype]


@register_quantize_fake_quantization("transpose")
def transpose_qfq(expr, type_map):
"""Rewrite a transpose op"""
arg = expr.args[0]
t = type_map[arg]
out = relay.op.transpose(arg, **expr.attrs)
return [out, t.scale, t.zero_point, t.dtype]
return register_quantize_fake_quantization(op_name, identity)


@register_quantize_fake_quantization("nn.max_pool2d")
def maxpool_qfq(expr, type_map):
"""Rewrite a maxpool op"""
arg = expr.args[0]
t = type_map[arg]
out = relay.op.nn.max_pool2d(arg, **expr.attrs)
return [out, t.scale, t.zero_point, t.dtype]
register_qfq_identity("reshape", relay.op.reshape)
register_qfq_identity("transpose", relay.op.transpose)
register_qfq_identity("nn.max_pool2d", relay.op.nn.max_pool2d)


@register_quantize_fake_quantization("nn.avg_pool2d")
Expand Down

0 comments on commit ad608d4

Please sign in to comment.