diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 0ed75191c40d..1adde9a4a430 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -101,6 +101,17 @@ def avgpool2d(expr, type_map): return [out, t] +@register_fake_quantization_to_integer("nn.global_avg_pool2d") +def global_avgpool2d(expr, type_map): + """Rewrite a global_avgpool op""" + arg = expr.args[0] + t = type_map[arg] + arg = relay.op.cast(arg, "int32") + out = relay.op.nn.global_avg_pool2d(arg) + out = relay.op.cast(out, t.dtype) + return [out, t] + + @register_fake_quantization_to_integer("nn.bias_add") def bias_add(expr, type_map): """Rewrite a bias_add op""" diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 3680310b4f92..c49d837ed920 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -268,6 +268,19 @@ def test_fake_quantize_avgpool(): compare_fq_to_int(op, [x_np], True) +def test_fake_quantize_global_avg_pool(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.nn.global_avg_pool2d(x) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + compare_fq_to_int(op, [x_np], True) + + def test_fake_quantize_reshape(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")