Skip to content

Commit

Permalink
add more binary ops
Browse files Browse the repository at this point in the history
fix pylint

fix black

black broke pylint

oops on black
  • Loading branch information
Matthew committed Jul 15, 2021
1 parent 83c299a commit 3213e48
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 45 deletions.
104 changes: 67 additions & 37 deletions python/tvm/relay/transform/fake_quantization_to_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def quantize(expr, type_map):
return [out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype)]


def register_unary_identity(op_name, op):
def register_unary_identity(op_name):
def identity(expr, type_map):
assert len(expr.args) == 1
arg = expr.args[0]
Expand All @@ -66,13 +66,13 @@ def identity(expr, type_map):
return register_fake_quantization_to_integer(op_name, identity)


register_unary_identity("reshape", relay.op.reshape)
register_unary_identity("squeeze", relay.op.squeeze)
register_unary_identity("strided_slice", relay.op.strided_slice)
register_unary_identity("transpose", relay.op.transpose)
register_unary_identity("expand_dims", relay.op.expand_dims)
register_unary_identity("nn.max_pool2d", relay.op.nn.max_pool2d)
register_unary_identity("nn.batch_flatten", relay.op.nn.batch_flatten)
# Ops registered as pass-through identities: the op is applied to the quantized
# tensor unchanged and the input's affine type is propagated to the output
# (see register_unary_identity above).
register_unary_identity("reshape")
register_unary_identity("squeeze")
register_unary_identity("strided_slice")
register_unary_identity("transpose")
register_unary_identity("expand_dims")
register_unary_identity("nn.max_pool2d")
register_unary_identity("nn.batch_flatten")


@register_fake_quantization_to_integer("nn.avg_pool2d")
Expand Down Expand Up @@ -201,6 +201,7 @@ def clip(expr, type_map):

@register_fake_quantization_to_integer("nn.pad")
def pad(expr, type_map):
"""Rewrite an nn.pad op"""
arg = expr.args[0]
t = type_map[arg]
pad_value = expr.args[1]
Expand All @@ -219,12 +220,12 @@ def pad(expr, type_map):
assert isinstance(pad_value, relay.expr.Constant)
pad_value = relay.qnn.op.quantize(pad_value, t.scale, t.zero_point)

z_p = fold_constant(t.zero_point)
out = relay.op.nn.pad(arg, pad_value=pad_value, **expr.attrs)
return [out, t]


def get_binary_types(expr, type_map):
"""Get Affine types of a binary op's inputs and unify them"""
# Support the case where one input is quantized and the other is a constant float
left = expr.args[0]
right = expr.args[1]
Expand Down Expand Up @@ -261,33 +262,62 @@ def get_binary_types(expr, type_map):
return left, right, left_t, right_t, out_t


@register_fake_quantization_to_integer("add")
def add(expr, type_map):
    """Rewrite a fake-quantized add into qnn.add on the quantized operands."""
    lhs, rhs, lhs_t, rhs_t, out_type = get_binary_types(expr, type_map)
    # qnn.add takes (lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp,
    # out_scale, out_zp) positionally.
    qnn_args = [lhs, rhs]
    for affine in (lhs_t, rhs_t, out_type):
        qnn_args.extend([affine.scale, affine.zero_point])
    return [relay.qnn.op.add(*qnn_args), out_type]
def register_binary_qnn(op_name, op):
    """Register a Binary Op that converts to QNN.

    ``op`` is invoked as ``op(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp,
    out_scale, out_zp)`` and must return the rewritten QNN expression.
    """

    def converter(expr, type_map):
        # Unify the affine types of the two inputs and derive the output type.
        lhs, rhs, lhs_t, rhs_t, out_t = get_binary_types(expr, type_map)
        qnn_args = [lhs, rhs]
        for affine in (lhs_t, rhs_t, out_t):
            qnn_args.extend([affine.scale, affine.zero_point])
        return [op(*qnn_args), out_t]

    return register_fake_quantization_to_integer(op_name, converter)

@register_fake_quantization_to_integer("multiply")
def multiply(expr, type_map):
    """Rewrite a fake-quantized multiply into qnn.mul on the quantized operands."""
    lhs, rhs, lhs_t, rhs_t, out_type = get_binary_types(expr, type_map)
    # qnn.mul takes (lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp,
    # out_scale, out_zp) positionally.
    qnn_args = [lhs, rhs]
    for affine in (lhs_t, rhs_t, out_type):
        qnn_args.extend([affine.scale, affine.zero_point])
    return [relay.qnn.op.mul(*qnn_args), out_type]

# Use lambdas here to avoid a circular import problem: the lambda defers the
# relay.qnn.op.* attribute lookup until the converter is actually called,
# after all modules have finished importing.
# pylint: disable=unnecessary-lambda
register_binary_qnn("add", lambda *args: relay.qnn.op.add(*args))
# NOTE: the relay op is named "multiply" but its QNN counterpart is "mul".
register_binary_qnn("multiply", lambda *args: relay.qnn.op.mul(*args))
register_binary_qnn("subtract", lambda *args: relay.qnn.op.subtract(*args))


def register_binary_identity(op_name, op):
    """Register a binary op that works directly on int8.

    Both inputs are requantized to the output affine type when needed, then
    ``op`` is applied directly to the integer tensors.
    """

    def converter(expr, type_map):
        lhs, rhs, lhs_t, rhs_t, out_t = get_binary_types(expr, type_map)

        def to_out_type(value, affine_t):
            # Requantize only when the input's affine type differs from the
            # unified output type.
            if affine_t == out_t:
                return value
            return relay.qnn.op.requantize(
                value,
                affine_t.scale,
                affine_t.zero_point,
                out_t.scale,
                out_t.zero_point,
                out_dtype=out_t.dtype,
            )

        result = op(to_out_type(lhs, lhs_t), to_out_type(rhs, rhs_t))
        return [result, out_t]

    return register_fake_quantization_to_integer(op_name, converter)


# Binary ops applied directly to the int8 tensors once both inputs have been
# requantized to a common affine type (see register_binary_identity).
register_binary_identity("minimum", relay.op.minimum)
register_binary_identity("maximum", relay.op.maximum)
32 changes: 24 additions & 8 deletions tests/python/relay/test_pass_fake_quantization_to_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ def compare_fq_to_int(expr, args, allow_rounding_error=False):
mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod)
assert not tvm.ir.structural_equal(mod, mod_int)

mod_int = tvm.relay.transform.FoldConstant()(mod_int)

ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm")
result = ex.evaluate()(*args).numpy()

Expand Down Expand Up @@ -274,32 +272,50 @@ def test_fake_quantize_clip():
compare_fq_to_int(op, [x_np])


@pytest.mark.parametrize(
    "operator",
    [relay.op.add, relay.op.multiply, relay.op.subtract, relay.op.minimum, relay.op.maximum],
)
def test_fake_quantize_binary(operator):
    """A fake-quantized binary op should convert to an integer-only form.

    The span contained interleaved pre-/post-change diff lines (x and y each
    dequantized twice with conflicting zero points, op quantized twice); this
    is the reconstructed post-change version.
    """
    x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
    x = relay.qnn.op.dequantize(x, relay.const(0.1), relay.const(0))

    y = relay.var("y", shape=[1, 3, 224, 224], dtype="int8")
    y = relay.qnn.op.dequantize(y, relay.const(0.2), relay.const(0))

    op = operator(x, y)
    # multiply produces a wider dynamic range than the other ops, so it gets
    # a larger output scale; the rest reuse the lhs input scale.
    if operator == relay.op.multiply:
        out_scale = relay.const(20.0)
    else:
        out_scale = relay.const(0.1)

    op = relay.qnn.op.quantize(op, out_scale, relay.const(0), out_dtype="int8")

    x_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8")
    y_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8")

    compare_fq_to_int(op, [x_np, y_np])


@pytest.mark.parametrize("operator", [relay.op.add, relay.op.multiply])
@pytest.mark.parametrize(
"operator",
[
relay.op.add,
relay.op.multiply,
relay.op.subtract,
relay.op.minimum,
relay.op.maximum,
],
)
def test_fake_quantize_binary_const(operator):
x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
x = relay.qnn.op.dequantize(x, relay.const(0.1), relay.const(10))

y = relay.const(1.0)

op = operator(x, y)
op = relay.qnn.op.quantize(op, relay.const(20.0), relay.const(0), out_dtype="int8")
op = relay.qnn.op.quantize(op, relay.const(0.1), relay.const(10), out_dtype="int8")

x_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8")

Expand Down

0 comments on commit 3213e48

Please sign in to comment.