Skip to content

Commit

Permalink
Address review comments and test failures
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe Rossini committed Nov 2, 2020
1 parent ca33ca4 commit b8c8c91
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 93 deletions.
11 changes: 8 additions & 3 deletions python/tvm/topi/arm_cpu/arm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ def get_arch_version(target_mattr):


def is_dotprod_available():
""" Checks whether the hardware has support for fast Int8 arithmetic operations. """
""" Checks whether the hardware has support for udot/sdot instructions. """
target = tvm.target.Target.current(allow_none=False)
arch_version = get_arch_version(target.mattr)
return arch_version >= 8.4 or ((arch_version in (8.2, 8.3)) and "+dotprod" in target.mattr)


def is_mmla_available():
""" Checks whether the hardware has support for fast Int8 arithmetic operations. """
""" Checks whether the hardware has support for ummla/smmla instructions. """
target = tvm.target.Target.current(allow_none=False)
arch_version = get_arch_version(target.mattr)
return arch_version >= 8.6 or (
Expand All @@ -72,7 +72,9 @@ def get_tiling_B_interleaved_t(interleave_A):
tile computation.
Please refer to:
- https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-performance-for-armv8-architectures/6920
- https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product
- https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-through-mmla-instruction
- Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
In order to have more information
Expand All @@ -87,6 +89,9 @@ def get_tiling_B_interleaved_t(interleave_A):
tile_cols_B: the output tile columns of B'
"""
if is_mmla_available():
# If smmla/ummla is available, A must be interleaved.
# Each load from B' will contain 8 elements
# and we are loading 12 rows of B' (i.e., 12 columns of B)
tile_rows_B = 12
tile_cols_B = 8
elif is_dotprod_available():
Expand All @@ -104,7 +109,7 @@ def get_tiling_B_interleaved_t(interleave_A):
# rows of the original matrix B) need to be 4.
tile_cols_B = 4
else:
# If dot product is not available, A must be interleaved. In this case
# If no acceleration is available, A must be interleaved. In this case
# we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements
tile_rows_B = 4
tile_cols_B = 16
Expand Down
92 changes: 48 additions & 44 deletions python/tvm/topi/arm_cpu/conv2d_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,15 @@ def compute_conv2d_gemm_without_weight_transform(
# the tile computation.
#
# Please refer to:
# - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-performance-for-armv8-architectures/6920
# - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product
# - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-through-mmla-instruction
# - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
# In order to have more information
#
if is_mmla_available():
# If smmla/ummla is enabled, we are loading 8 rows from A. Each row
# will contain 8 elements
tile_rows_A = 8
tile_cols_A = 8
elif is_dotprod_available() and interleave_A:
Expand Down Expand Up @@ -322,50 +326,50 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):

k = C_interleaved.op.reduce_axis[0]
_, M, N = C.shape
if is_mmla_available():
gemm_acc = gemm_acc_2x2_int8_int8_int32(in_type)
xi_inner, yi_inner = C_interleaved.op.axis[5:7]
k_outer, k_inner = s[C_interleaved].split(k, 8)
s[C_interleaved].reorder(
b_outer_gemm_fused, inner_gemm, k_outer, xi, yi, xi_inner, yi_inner, k_inner
)
s[C_interleaved].tensorize(xi_inner, gemm_acc)
s[C_interleaved].unroll(xi)
s[C_interleaved].unroll(yi)
elif is_dotprod_available():
gemm_acc = gemm_acc_4x4_int8_int8_int32(in_type)
xi_outer, yi_outer, xi_inner, yi_inner = s[C_interleaved].tile(
xi, yi, x_factor=8, y_factor=4
)
k_outer, k_inner = s[C_interleaved].split(k, 4)
xi_inner_outer, xi_inner_inner = s[C_interleaved].split(xi_inner, 4)
s[C_interleaved].reorder(
b_outer_gemm_fused,
inner_gemm,
xi_outer,
yi_outer,
k_outer,
xi_inner_outer,
xi_inner_inner,
yi_inner,
k_inner,
)
s[C_interleaved].tensorize(xi_inner_inner, gemm_acc)
s[C_interleaved].unroll(xi_inner_outer)

elif is_aarch64_arm():
s[C_interleaved].reorder(yi, xi)
K = A_interleaved_input.shape[2]
assert in_type in ["int8", "uint8"], "Only int8 and uint8 gemm are supported"
unroll = cfg["gemm_quantized_unroll"].val
interleave = cfg["gemm_quantized_interleave"].val
gemm = gemm_quantized(M, N, K, unroll, interleave, in_type, out_type)
s[C_interleaved].pragma(
b_outer_gemm_fused,
"import_llvm",
gemm_quantized_impl(M, N, K, unroll, interleave, in_type),
)
s[C_interleaved].tensorize(yi, gemm)
if in_type in ["int8", "uint8"]:
if is_mmla_available():
gemm_acc = gemm_acc_2x2_int8_int8_int32(in_type)
xi_inner, yi_inner = C_interleaved.op.axis[-2:]
k_outer, k_inner = s[C_interleaved].split(k, 8)
s[C_interleaved].reorder(
b_outer_gemm_fused, inner_gemm, k_outer, xi, yi, xi_inner, yi_inner, k_inner
)
s[C_interleaved].tensorize(xi_inner, gemm_acc)
s[C_interleaved].unroll(xi)
s[C_interleaved].unroll(yi)
elif is_dotprod_available():
gemm_acc = gemm_acc_4x4_int8_int8_int32(in_type)
xi_outer, yi_outer, xi_inner, yi_inner = s[C_interleaved].tile(
xi, yi, x_factor=8, y_factor=4
)
k_outer, k_inner = s[C_interleaved].split(k, 4)
xi_inner_outer, xi_inner_inner = s[C_interleaved].split(xi_inner, 4)
s[C_interleaved].reorder(
b_outer_gemm_fused,
inner_gemm,
xi_outer,
yi_outer,
k_outer,
xi_inner_outer,
xi_inner_inner,
yi_inner,
k_inner,
)
s[C_interleaved].tensorize(xi_inner_inner, gemm_acc)
s[C_interleaved].unroll(xi_inner_outer)

elif is_aarch64_arm():
s[C_interleaved].reorder(yi, xi)
K = A_interleaved_input.shape[2]
unroll = cfg["gemm_quantized_unroll"].val
interleave = cfg["gemm_quantized_interleave"].val
gemm = gemm_quantized(M, N, K, unroll, interleave, in_type, out_type)
s[C_interleaved].pragma(
b_outer_gemm_fused,
"import_llvm",
gemm_quantized_impl(M, N, K, unroll, interleave, in_type),
)
s[C_interleaved].tensorize(yi, gemm)

# Output transform
if out != final_out:
Expand Down
15 changes: 7 additions & 8 deletions python/tvm/topi/arm_cpu/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ def gemm_quantized(M, N, K, unroll, interleave, in_type, out_type):
intrin : TensorIntrin
The ARM uint8/int8 TensorIntrin that can be used in tensorizing schedule
"""
assert in_type in ["uint8", "int8"]
A = te.placeholder((K // 16, te.var("m"), 16), dtype=in_type, name="A")
B = te.placeholder((K // 16, te.var("n"), 16), dtype=in_type, name="B")

Expand Down Expand Up @@ -627,7 +628,7 @@ def gemm_acc_4x4_int8_int8_int32(dtype):
Int8 4x4 matrix multiplication and accumulation using sdot/udot
instructions. This function takes two arrays of int8 datatype
-- A[4][4] and B[4][4] and produces a 4x4 matrix
which is equal to A*B.
which is equal to A*B'.
The pseudo code is as follows.
Expand All @@ -643,7 +644,6 @@ def gemm_acc_4x4_int8_int8_int32(dtype):
}
Notes:
* The rows of matrix B are transposed
* The tiling strategy is picked to maximize register usage.
Parameters
Expand All @@ -656,6 +656,7 @@ def gemm_acc_4x4_int8_int8_int32(dtype):
intrin : TensorIntrin
The Arm TensorIntrin that can be used in tensorizing schedule
"""
assert dtype in ["uint8", "int8"]
# This needs to be a variable number of "rows" since TVM
# "thinks" I only need to compute one row because of
# padding
Expand Down Expand Up @@ -755,7 +756,7 @@ def gemm_acc_nx16_int8_int8_int32(dtype, rows):
"""
Int8 nx16 matrix multiplication and accumulation using sdot/udot instructions
This function takes two arrays of int8 datatype -- A[n][4] and
B[4][16] and produces a rowsx16 matrix which is equal to A*B
B[4][16] and produces a rowsx16 matrix which is equal to A*B'
The pseudo code is as follows.
.. code-block:: c
Expand All @@ -771,7 +772,6 @@ def gemm_acc_nx16_int8_int8_int32(dtype, rows):
}
Notes:
* The rows of matrix B are transposed
* The tile size of B is 16x4. Since the reduction variable k moves between 0 and 16
we need 4 tiles of B to compute a single row of the output. The first 4 values of
k will be fetched from B[0][j][k], the second batch of 4 from B[1][j][k] and so on
Expand All @@ -789,6 +789,7 @@ def gemm_acc_nx16_int8_int8_int32(dtype, rows):
intrin : TensorIntrin
The Arm TensorIntrin that can be used in tensorizing schedule
"""
assert dtype in ["uint8", "int8"]
A = te.placeholder((rows, 16), dtype, name="A")
B = te.placeholder((4, 16, 4), dtype, name="B")
dtype_vec = dtype + "x16"
Expand Down Expand Up @@ -883,7 +884,7 @@ def gemm_acc_2x2_int8_int8_int32(dtype):
"""
Int8 2x2 matrix multiplication using smmla/ummla instructions
This function takes two arrays of int8 datatype -- A[2][8] and
B[2][8] and produces a 2x2 matrix which is equal to A*B
B[2][8] and produces a 2x2 matrix which is equal to A*B'
The pseudo code is as follows.
.. code-block:: c
Expand All @@ -897,9 +898,6 @@ def gemm_acc_2x2_int8_int8_int32(dtype):
}
}
Notes:
* The rows of matrix B are transposed
Parameters
----------
dtype: str, {"uint8", "int8"}
Expand All @@ -910,6 +908,7 @@ def gemm_acc_2x2_int8_int8_int32(dtype):
intrin : TensorIntrin
The Arm TensorIntrin that can be used in tensorizing schedule
"""
assert dtype in ["uint8", "int8"]
A = te.placeholder((2, 8), dtype, name="A")
B = te.placeholder((2, 8), dtype, name="B")
dtype_vec = dtype + "x16"
Expand Down
77 changes: 39 additions & 38 deletions tests/python/topi/python/test_topi_conv2d_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,12 @@ def compile_conv2d_NHWC_gemm_int8_arm(
topi.arm_cpu.compute_conv2d_NHWC_quantized_native,
topi.arm_cpu.schedule_conv2d_NHWC_quantized_native,
),
(
"llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+i8mm",
topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
),
# TODO(giuseros) Need LLVM-11 in order to compile with +i8mm extension
# (
# "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+i8mm",
# topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
# topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
# ),
]

for device_tuple in devices:
Expand Down Expand Up @@ -551,43 +552,43 @@ def test_conv2d_nchw():
def test_conv2d_nhwc():
with Int8Fallback():
# Subset of inception v3 expanded (dilation > 1, batch > 1, 'VALID' padding)
#verify_conv2d_NHWC_gemm_int8(1, 3, 299, 32, 3, 2, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 32, 149, 32, 3, 1, "SAME", dilation=2)
#verify_conv2d_NHWC_gemm_int8(4, 32, 147, 64, 3, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 64, 73, 80, 1, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 80, 73, 192, 3, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 192, 35, 48, 1, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 192, 35, 64, 1, 1, "VALID")
#verify_conv2d_NHWC_gemm_int8(1, 192, 35, 32, 1, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 48, 35, 64, 5, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 256, 35, 48, 1, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 256, 35, 64, 1, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 288, 35, 64, 1, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 288, 35, 48, 1, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 2, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 128, 17, 192, 7, 1, "SAME", dilation=2)
#verify_conv2d_NHWC_gemm_int8(1, 160, 17, 160, 7, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 160, 17, 192, 1, 1, "VALID")
#verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 1, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 768, 5, 128, 1, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 192, 17, 320, 3, 2, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 3, 2, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 192, 1, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 384, 1, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 320, 1, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 448, 1, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 1, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 3, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 448, 8, 384, 3, 1, "VALID")
#verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 320, 1, 1, "SAME")
#verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 448, 1, 1, "SAME", add_bias=True, add_relu=True)
#verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 192, 1, 1, "SAME", add_bias=True)
verify_conv2d_NHWC_gemm_int8(1, 3, 299, 32, 3, 2, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 32, 149, 32, 3, 1, "SAME", dilation=2)
verify_conv2d_NHWC_gemm_int8(4, 32, 147, 64, 3, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 64, 73, 80, 1, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 80, 73, 192, 3, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 192, 35, 48, 1, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 192, 35, 64, 1, 1, "VALID")
verify_conv2d_NHWC_gemm_int8(1, 192, 35, 32, 1, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 48, 35, 64, 5, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 256, 35, 48, 1, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 256, 35, 64, 1, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 288, 35, 64, 1, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 288, 35, 48, 1, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 2, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 128, 17, 192, 7, 1, "SAME", dilation=2)
verify_conv2d_NHWC_gemm_int8(1, 160, 17, 160, 7, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 160, 17, 192, 1, 1, "VALID")
verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 1, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 768, 5, 128, 1, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 192, 17, 320, 3, 2, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 3, 2, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 192, 1, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 384, 1, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 320, 1, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 448, 1, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 1, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 3, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 448, 8, 384, 3, 1, "VALID")
verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 320, 1, 1, "SAME")
verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 448, 1, 1, "SAME", add_bias=True, add_relu=True)
verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 192, 1, 1, "SAME", add_bias=True)

# Let's also verify that it compiles fine on AArch64 targets
compile_conv2d_NHWC_gemm_int8_arm(1, 3, 299, 32, 3, 2, "SAME")


if __name__ == "__main__":
#test_conv2d_nchw()
test_conv2d_nchw()
test_conv2d_nhwc()

0 comments on commit b8c8c91

Please sign in to comment.