Skip to content

Commit

Permalink
[Test] Add Test Case to Cover Bug Fix by PR#7432 (#7601)
Browse files Browse the repository at this point in the history
  • Loading branch information
Johnson9009 authored Mar 11, 2021
1 parent df6fb69 commit 56feab9
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions tests/python/relay/test_pass_auto_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,39 @@ def @main(
verify_partition_fails(mod, params)


def test_left_shift_negative():
data = relay.var("data", shape=(1, 16, 64, 64))
weight = relay.const(np.full((16, 16, 3, 3), 256.0))
conv2d = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=16)
relu = relay.nn.relu(conv2d)

mod = tvm.IRModule.from_expr(relu)

with tvm.transform.PassContext(opt_level=3):
with relay.quantize.qconfig(
calibrate_mode="global_scale", global_scale=8.0, skip_conv_layers=None
):
qnn_mod = relay.quantize.quantize(mod)

class OpFinder(relay.ExprVisitor):
def __init__(self, op_name):
super(OpFinder, self).__init__()
self._op_name = op_name
self.ops = list()

def visit_call(self, call):
super().visit_call(call)
if call.op.name == self._op_name:
self.ops.append(call)

opf = OpFinder("left_shift")
opf.visit(qnn_mod["main"])
assert len(opf.ops) > 0, 'Broken case, can\'t find any "left_shift" operators.'
for left_shift_op in opf.ops:
shift_amount = left_shift_op.args[1].data.asnumpy()
assert shift_amount >= 0, "Shift amount must be non-negative."


if __name__ == "__main__":
test_mul_rewrite()
test_batch_flatten_rewrite()
Expand All @@ -320,3 +353,4 @@ def @main(
test_unquantizable_prefix_partition()
test_unquantizable_core_partition()
test_unquantizable_suffix_partition()
test_left_shift_negative()

0 comments on commit 56feab9

Please sign in to comment.