Fix conv2d_gemm after target structure update #6037

Merged
merged 1 commit on Jul 14, 2020
2 changes: 1 addition & 1 deletion python/tvm/relay/qnn/op/legalizations.py
@@ -239,7 +239,7 @@ def is_fast_int8_on_arm():
def is_aarch64_arm():
""" Checks whether we are compiling for an AArch64 target. """
target = tvm.target.Target.current(allow_none=False)
return 'aarch64' in target.attrs.get("target", "")
return 'aarch64' in target.attrs.get("mtriple", "")

########################
# ARM CPU legalizations.
2 changes: 1 addition & 1 deletion topi/python/topi/arm_cpu/conv2d_gemm.py
@@ -27,7 +27,7 @@
def is_aarch64_arm():
""" Checks whether we are compiling for an AArch64 target. """
target = tvm.target.Target.current(allow_none=False)
return 'aarch64' in target.attrs.get("target", "")
return 'aarch64' in target.attrs.get("mtriple", "")


# Compute function
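The same one-line fix appears in both legalizations.py and conv2d_gemm.py: after the target structure update, the LLVM triple supplied through --mtriple is exposed as the target's "mtriple" attribute, so the old "target" key no longer matches anything. A minimal sketch of the check, reusing the AArch64 target string from the new test below (illustration only, not part of the patch):

    import tvm

    # --mtriple surfaces as target.attrs["mtriple"], which is what is_aarch64_arm() now reads.
    with tvm.target.create("llvm --device arm_cpu --mtriple aarch64-linux-gnu"):
        target = tvm.target.Target.current(allow_none=False)
        print(target.attrs.get("mtriple", ""))                # expected: aarch64-linux-gnu
        print('aarch64' in target.attrs.get("mtriple", ""))   # expected: True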
2 changes: 1 addition & 1 deletion topi/python/topi/arm_cpu/tensor_intrin.py
@@ -267,7 +267,7 @@ def gemv_quantized_impl(M, N, data_type='uint8'):
    ll_path = temp.relpath("temp.ll")
    # Create LLVM ir from c source code
    ll_code = clang.create_llvm(cc_code,
-                               options=["-mtriple=aarch64-linux-gnu -mattr=+neon"],
+                               options=["--target=aarch64-linux-gnu -mattr=+neon"],
                                output=ll_path)
    return ll_code

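The clang change is the same fix seen from the toolchain side: the triple is now passed with the clang-style --target flag instead of -mtriple. A small sketch of what tvm.contrib.clang.create_llvm does with such options, using a toy C snippet instead of the real GEMV kernel source (illustration only; it assumes a local clang able to target AArch64, and the two flags are given as separate list entries here):

    from tvm.contrib import clang

    # Toy C source; in tensor_intrin.py, cc_code holds the hand-written GEMV kernel.
    cc_code = "int add_one(int x) { return x + 1; }"

    # create_llvm shells out to clang with the given options and returns the LLVM text IR.
    ll_code = clang.create_llvm(cc_code,
                                options=["--target=aarch64-linux-gnu", "-mattr=+neon"])
    print(ll_code.splitlines()[0])   # e.g. the ModuleID header of the generated IR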
64 changes: 64 additions & 0 deletions topi/tests/python/test_topi_conv2d_int8.py
@@ -26,9 +26,70 @@
from tvm.contrib.pickle_memoize import memoize
from topi.nn.util import get_pad_tuple
from topi.util import get_const_tuple
from topi.arm_cpu.conv2d_gemm import is_aarch64_arm

from common import get_all_backend, Int8Fallback

def compile_conv2d_NHWC_gemm_int8_arm(batch, in_channel, in_size, num_filter, kernel, stride, padding,
                                      dilation=1, add_bias=False, add_relu=False):
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
    padding_sum = pad_top + pad_left + pad_bottom + pad_right
    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter,
                                                          kernel, stride, padding_sum, dilation))

    in_height = in_width = in_size
    A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='int8')
    W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype='int8')
    bias = te.placeholder((num_filter,), name='bias', dtype='int8')
    dtype = 'int32'
    device = "llvm --device arm_cpu --mtriple aarch64-linux-gnu"

    ctx = tvm.context(device, 0)
    if not ctx.exist:
        print("Skip because %s is not enabled" % device)
        return
    print("Compiling on arm AArch64 target: %s" % device)
    with tvm.target.create(device):
        assert is_aarch64_arm(), "AArch64 target not recognized"

        C = topi.arm_cpu.compute_conv2d_NHWC_quantized(A, W, (stride, stride), padding,
                                                       (dilation, dilation), dtype)
        if add_bias:
            C = topi.add(C, bias)
        if add_relu:
            C = topi.nn.relu(C)
        s = topi.arm_cpu.schedule_conv2d_NHWC_quantized([C])

    if add_bias:
        tvm.build(s, [A, W, bias, C], device,
                  name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
                                                         in_channel,
                                                         in_size,
                                                         num_filter,
                                                         kernel,
                                                         stride,
                                                         padding_sum,
                                                         dilation))
        func = tvm.build(s, [A, W, bias, C], device,
                         name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
                                                                in_channel,
                                                                in_size,
                                                                num_filter,
                                                                kernel,
                                                                stride,
                                                                padding_sum,
                                                                dilation))
    else:
        func = tvm.build(s, [A, W, C], device,
                         name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
                                                                in_channel,
                                                                in_size,
                                                                num_filter,
                                                                kernel,
                                                                stride,
                                                                padding_sum,
                                                                dilation))

def verify_conv2d_NHWC_gemm_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding,
                                 dilation=1, add_bias=False, add_relu=False):
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
@@ -409,6 +470,9 @@ def test_conv2d_nhwc():
    verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 448, 1, 1, 'SAME', add_bias=True, add_relu=True)
    verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 192, 1, 1, 'SAME', add_bias=True)

    # Let's also verify that it compiles fine on AArch64 targets
    compile_conv2d_NHWC_gemm_int8_arm(1, 3, 299, 32, 3, 2, 'SAME')


if __name__ == "__main__":
    test_conv2d_nchw()
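As a closing note on where such an AArch64 build can go next (not part of this PR): once tvm.build succeeds for the "llvm --device arm_cpu --mtriple aarch64-linux-gnu" target, the module can be cross-compiled into a shared library and loaded on the device. A hedged sketch with a trivial compute, assuming an aarch64-linux-gnu-gcc toolchain is available locally:

    import tvm
    from tvm import te
    from tvm.contrib import cc

    # Trivial element-wise op, just to have something to build for the AArch64 target.
    n = te.var("n")
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)

    target = "llvm --device arm_cpu --mtriple aarch64-linux-gnu"
    func = tvm.build(s, [A, B], target, name="add_one")

    # Export a .so loadable on the AArch64 device; the gcc toolchain name is an assumption.
    func.export_library("add_one_aarch64.so", cc.cross_compiler("aarch64-linux-gnu-gcc"))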