diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 4d515dea329f..af5072ef74cd 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -239,7 +239,7 @@ def is_fast_int8_on_arm(): def is_aarch64_arm(): """ Checks whether we are compiling for an AArch64 target. """ target = tvm.target.Target.current(allow_none=False) - return 'aarch64' in target.attrs.get("target", "") + return 'aarch64' in target.attrs.get("mtriple", "") ######################## # ARM CPU legalizations. diff --git a/topi/python/topi/arm_cpu/conv2d_gemm.py b/topi/python/topi/arm_cpu/conv2d_gemm.py index 68161c32a0fa..63d96bb44d92 100644 --- a/topi/python/topi/arm_cpu/conv2d_gemm.py +++ b/topi/python/topi/arm_cpu/conv2d_gemm.py @@ -27,7 +27,7 @@ def is_aarch64_arm(): """ Checks whether we are compiling for an AArch64 target. """ target = tvm.target.Target.current(allow_none=False) - return 'aarch64' in target.attrs.get("target", "") + return 'aarch64' in target.attrs.get("mtriple", "") # Compute function diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py index d8d9481c2a32..dfa2f05e7960 100644 --- a/topi/python/topi/arm_cpu/tensor_intrin.py +++ b/topi/python/topi/arm_cpu/tensor_intrin.py @@ -267,7 +267,7 @@ def gemv_quantized_impl(M, N, data_type='uint8'): ll_path = temp.relpath("temp.ll") # Create LLVM ir from c source code ll_code = clang.create_llvm(cc_code, - options=["-mtriple=aarch64-linux-gnu -mattr=+neon"], + options=["--target=aarch64-linux-gnu -mattr=+neon"], output=ll_path) return ll_code diff --git a/topi/tests/python/test_topi_conv2d_int8.py b/topi/tests/python/test_topi_conv2d_int8.py index edf4267ddaee..5659147f8c41 100644 --- a/topi/tests/python/test_topi_conv2d_int8.py +++ b/topi/tests/python/test_topi_conv2d_int8.py @@ -26,9 +26,70 @@ from tvm.contrib.pickle_memoize import memoize from topi.nn.util import get_pad_tuple from topi.util import get_const_tuple +from topi.arm_cpu.conv2d_gemm import is_aarch64_arm from common import get_all_backend, Int8Fallback +def compile_conv2d_NHWC_gemm_int8_arm(batch, in_channel, in_size, num_filter, kernel, stride, padding, + dilation=1, add_bias=False, add_relu=False): + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) + padding_sum = pad_top + pad_left + pad_bottom + pad_right + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, + kernel, stride, padding_sum, dilation)) + + in_height = in_width = in_size + A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='int8') + W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype='int8') + bias = te.placeholder((num_filter,), name='bias', dtype='int8') + dtype = 'int32' + device = "llvm --device arm_cpu --mtriple aarch64-linux-gnu" + + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Compiling on arm AArch64 target: %s" % device) + with tvm.target.create(device): + assert is_aarch64_arm(), "AArch64 target not recognized" + + C = topi.arm_cpu.compute_conv2d_NHWC_quantized(A, W, (stride, stride), padding, + (dilation, dilation), dtype) + if add_bias: + C = topi.add(C, bias) + if add_relu: + C = topi.nn.relu(C) + s = topi.arm_cpu.schedule_conv2d_NHWC_quantized([C]) + + if add_bias: + tvm.build(s, [A, W, bias, C], device, + name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding_sum, + dilation)) + func = tvm.build(s, [A, W, bias, C], device, + name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding_sum, + dilation)) + else: + func = tvm.build(s, [A, W, C], device, + name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding_sum, + dilation)) + def verify_conv2d_NHWC_gemm_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False): pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) @@ -409,6 +470,9 @@ def test_conv2d_nhwc(): verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 448, 1, 1, 'SAME', add_bias=True, add_relu=True) verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 192, 1, 1, 'SAME', add_bias=True) + # Let's also verify that it compiles fine on AArch64 targets + compile_conv2d_NHWC_gemm_int8_arm(1, 3, 299, 32, 3, 2, 'SAME') + if __name__ == "__main__": test_conv2d_nchw()