Skip to content

Commit

Permalink
[Relay][QNN] Add unit test for int8 (apache#4159)
Browse files Browse the repository at this point in the history
* [bugfix][codegen] fix casting bug in llvm codegen

* update example

* retrigger ci

* check llvm version
  • Loading branch information
zhiics authored and icemelon committed Oct 21, 2019
1 parent e0d286a commit 6f9d028
Showing 1 changed file with 59 additions and 24 deletions.
83 changes: 59 additions & 24 deletions tests/python/relay/test_op_qnn_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def get_output(func, golden_inputs):
qnn_output = get_output(qnn_func, golden_inputs)
np.testing.assert_equal(qnn_output, golden_output)

def no_zero_point_test():
def test_no_zero_point():
# uint8 input
data_shape = (2, 1, 2, 4)
data_dtype = 'uint8'
Expand Down Expand Up @@ -203,7 +203,7 @@ def no_zero_point_test():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

def kernel_zero_point_test():
def test_kernel_zero_point():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
Expand Down Expand Up @@ -247,7 +247,7 @@ def kernel_zero_point_test():
kernel_shape, kernel_dtype)


def input_zero_point_test():
def test_input_zero_point():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
Expand Down Expand Up @@ -290,7 +290,7 @@ def input_zero_point_test():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

def both_zero_point_test():
def test_both_zero_point():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
Expand Down Expand Up @@ -333,7 +333,7 @@ def both_zero_point_test():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

def layout_test():
def test_layout():
# uint8 input
data_shape = (2, 2, 4, 4) # NHWC
data_dtype = 'uint8'
Expand Down Expand Up @@ -378,7 +378,7 @@ def layout_test():



def padding_test():
def test_padding():
# uint8 input
data_shape = (1, 4, 2, 2)
data_dtype = 'uint8'
Expand Down Expand Up @@ -421,7 +421,7 @@ def padding_test():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

def dilation_test():
def test_dilation():
# uint8 input
data_shape = (2, 4, 4, 4)
data_dtype = 'uint8'
Expand All @@ -444,7 +444,7 @@ def dilation_test():
kernel_shape, kernel_dtype)


def const_folding_test():
def test_const_folding():
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
Expand All @@ -470,7 +470,7 @@ def const_folding_test():
folded_func = folded_mod["main"]
assert "reshape" not in folded_func.astext()

def kernel_size_1x1_test():
def test_kernel_size_1x1():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
Expand All @@ -493,7 +493,7 @@ def kernel_size_1x1_test():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

def tflite_large_irregular_test():
def test_tflite_large_irregular():
# uint8 input
data_shape = (1, 1024, 1, 1)
data_dtype = 'uint8'
Expand Down Expand Up @@ -607,7 +607,7 @@ def tflite_anistropic_strides():
golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
np.testing.assert_equal(qnn_output, golden_output)

def broadcast_layout_test():
def test_broadcast_layout():
# Test broadcast support for NHWC layout.
data_shape = (1, 229, 229, 3) # NHWC
data_dtype = 'uint8'
Expand Down Expand Up @@ -640,17 +640,52 @@ def broadcast_layout_test():
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")


def test_conv2d_int8():
target = "llvm -mcpu=core-avx2"
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return

data = relay.var("data", shape=(1, 28, 28, 128), dtype='uint8')
kernel = relay.var("w", shape=(3, 3, 128, 256), dtype='int8')
conv = relay.nn.conv2d(
data,
kernel,
kernel_size=(3, 3),
out_dtype='int32',
data_layout='NHWC',
kernel_layout='HWIO')
func = relay.Function([data, kernel], conv)

with relay.build_config(opt_level=0):
params = {"w": np.zeros((3, 3, 128, 256)).astype("int8")}
# -mcpu should be specified to avoid the llvm jitting error here:
# https://discuss.tvm.ai/t/segfault-in-llvm/3567
# To use VNNI, we need to specify the micro-architecture that supports
# it, e.g. cascadelake.
graph, lib, params = relay.build(func, target, params=params)
mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
mod.set_input("data", np.zeros((1, 28, 28, 128)).astype("uint8"))
mod.set_input(**params)
mod.run()
qnn_output = mod.get_output(0).asnumpy()
golden_output = np.zeros((1, 26, 26, 256)).astype("int32")
np.testing.assert_equal(qnn_output, golden_output)


if __name__ == "__main__":
no_zero_point_test()
input_zero_point_test()
kernel_zero_point_test()
both_zero_point_test()
layout_test()
padding_test()
dilation_test()
const_folding_test()
kernel_size_1x1_test()
tflite_large_irregular_test()
tflite_output_multiplier_greater_than_one()
tflite_anistropic_strides()
broadcast_layout_test()
test_no_zero_point()
test_input_zero_point()
test_kernel_zero_point()
test_both_zero_point()
test_layout()
test_padding()
test_dilation()
test_const_folding()
test_kernel_size_1x1g()
test_tflite_large_irregularg()
test_tflite_output_multiplier_greater_than_one()
test_tflite_anistropic_strides()
test_broadcast_layoutg()
test_conv2d_int8()

0 comments on commit 6f9d028

Please sign in to comment.