diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index b4e8bfd71b62..3b5195c051b8 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -160,7 +160,7 @@ def get_output(func, golden_inputs): qnn_output = get_output(qnn_func, golden_inputs) np.testing.assert_equal(qnn_output, golden_output) -def no_zero_point_test(): +def test_no_zero_point(): # uint8 input data_shape = (2, 1, 2, 4) data_dtype = 'uint8' @@ -203,7 +203,7 @@ def no_zero_point_test(): verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) -def kernel_zero_point_test(): +def test_kernel_zero_point(): # uint8 input data_shape = (2, 4, 2, 4) data_dtype = 'uint8' @@ -247,7 +247,7 @@ def kernel_zero_point_test(): kernel_shape, kernel_dtype) -def input_zero_point_test(): +def test_input_zero_point(): # uint8 input data_shape = (2, 4, 2, 4) data_dtype = 'uint8' @@ -290,7 +290,7 @@ def input_zero_point_test(): verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) -def both_zero_point_test(): +def test_both_zero_point(): # uint8 input data_shape = (2, 4, 2, 4) data_dtype = 'uint8' @@ -333,7 +333,7 @@ def both_zero_point_test(): verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) -def layout_test(): +def test_layout(): # uint8 input data_shape = (2, 2, 4, 4) # NHWC data_dtype = 'uint8' @@ -378,7 +378,7 @@ def layout_test(): -def padding_test(): +def test_padding(): # uint8 input data_shape = (1, 4, 2, 2) data_dtype = 'uint8' @@ -421,7 +421,7 @@ def padding_test(): verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) -def dilation_test(): +def test_dilation(): # uint8 input data_shape = (2, 4, 4, 4) data_dtype = 'uint8' @@ -444,7 +444,7 @@ def dilation_test(): kernel_shape, kernel_dtype) -def const_folding_test(): +def test_const_folding(): data_shape = (2, 4, 2, 4) data_dtype = 'uint8' kernel_shape = (3, 4, 2, 2) @@ 
-470,7 +470,7 @@ def const_folding_test(): folded_func = folded_mod["main"] assert "reshape" not in folded_func.astext() -def kernel_size_1x1_test(): +def test_kernel_size_1x1(): # uint8 input data_shape = (2, 4, 2, 4) data_dtype = 'uint8' @@ -493,7 +493,7 @@ def kernel_size_1x1_test(): verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) -def tflite_large_irregular_test(): +def test_tflite_large_irregular(): # uint8 input data_shape = (1, 1024, 1, 1) data_dtype = 'uint8' @@ -607,7 +607,7 @@ def tflite_anistropic_strides(): golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2) np.testing.assert_equal(qnn_output, golden_output) -def broadcast_layout_test(): +def test_broadcast_layout(): # Test broadcast support for NHWC layout. data_shape = (1, 229, 229, 3) # NHWC data_dtype = 'uint8' @@ -640,17 +640,52 @@ def broadcast_layout_test(): with relay.build_config(opt_level=3): graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512") + +def test_conv2d_int8(): + target = "llvm -mcpu=core-avx2" + if not tvm.module.enabled(target): + print("skip because %s is not enabled..." % target) + return + + data = relay.var("data", shape=(1, 28, 28, 128), dtype='uint8') + kernel = relay.var("w", shape=(3, 3, 128, 256), dtype='int8') + conv = relay.nn.conv2d( + data, + kernel, + kernel_size=(3, 3), + out_dtype='int32', + data_layout='NHWC', + kernel_layout='HWIO') + func = relay.Function([data, kernel], conv) + + with relay.build_config(opt_level=0): + params = {"w": np.zeros((3, 3, 128, 256)).astype("int8")} + # -mcpu should be specified to avoid the llvm jitting error here: + # https://discuss.tvm.ai/t/segfault-in-llvm/3567 + # To use VNNI, we need to specify the micro-architecture that supports + # it, e.g. cascadelake. 
+    graph, lib, params = relay.build(func, target, params=params) + mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) + mod.set_input("data", np.zeros((1, 28, 28, 128)).astype("uint8")) + mod.set_input(**params) + mod.run() + qnn_output = mod.get_output(0).asnumpy() + golden_output = np.zeros((1, 26, 26, 256)).astype("int32") + np.testing.assert_equal(qnn_output, golden_output) + + if __name__ == "__main__": - no_zero_point_test() - input_zero_point_test() - kernel_zero_point_test() - both_zero_point_test() - layout_test() - padding_test() - dilation_test() - const_folding_test() - kernel_size_1x1_test() - tflite_large_irregular_test() - tflite_output_multiplier_greater_than_one() - tflite_anistropic_strides() - broadcast_layout_test() + test_no_zero_point() + test_input_zero_point() + test_kernel_zero_point() + test_both_zero_point() + test_layout() + test_padding() + test_dilation() + test_const_folding() + test_kernel_size_1x1() + test_tflite_large_irregular() + tflite_output_multiplier_greater_than_one() + tflite_anistropic_strides() + test_broadcast_layout() + test_conv2d_int8()