Add smmla/ummla support in quantized Conv2d #6802

Merged · 4 commits · Nov 6, 2020
python/tvm/topi/arm_cpu/arm_utils.py: 22 additions & 5 deletions
@@ -43,12 +43,21 @@ def get_arch_version(target_mattr):


def is_dotprod_available():
""" Checks whether the hardware has support for fast Int8 arithmetic operations. """
""" Checks whether the hardware has support for udot/sdot instructions. """
target = tvm.target.Target.current(allow_none=False)
arch_version = get_arch_version(target.mattr)
return arch_version >= 8.4 or ((arch_version in (8.2, 8.3)) and "+dotprod" in target.mattr)


+def is_mmla_available():
+    """ Checks whether the hardware has support for ummla/smmla instructions. """
+    target = tvm.target.Target.current(allow_none=False)
+    arch_version = get_arch_version(target.mattr)
+    return arch_version >= 8.6 or (
+        (arch_version in (8.2, 8.3, 8.4, 8.5)) and "+i8mm" in target.mattr
+    )


def is_aarch64_arm():
""" Checks whether we are compiling for an AArch64 target. """
target = tvm.target.Target.current(allow_none=False)
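As a quick illustration of how these helpers resolve in practice, here is a minimal sketch; the target string is illustrative, and the checks only inspect the arch version plus the `+dotprod`/`+i8mm` flags of the current target:

```python
import tvm
from tvm.topi.arm_cpu.arm_utils import is_dotprod_available, is_mmla_available

# An Armv8.2-a AArch64 target with both optional extensions enabled.
with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod,+i8mm"):
    assert is_dotprod_available()  # 8.2 with "+dotprod"
    assert is_mmla_available()     # 8.2 with "+i8mm"
```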
@@ -63,8 +72,10 @@ def get_tiling_B_interleaved_t(interleave_A):
    tile computation.

    Please refer to:
-    - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product
-    - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
+    - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-performance-for-armv8-architectures # pylint: disable=line-too-long
+    - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product
+    - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-through-mmla-instruction
+    - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
    for more information.

    Parameters
@@ -77,7 +88,13 @@ def get_tiling_B_interleaved_t(interleave_A):
    tile_rows_B: the output tile rows of B'
    tile_cols_B: the output tile columns of B'
    """
-    if is_dotprod_available():
+    if is_mmla_available():
+        # If smmla/ummla is available, A must be interleaved.
+        # Each load from B' will contain 8 elements
+        # and we are loading 12 rows of B' (i.e., 12 columns of B)
+        tile_rows_B = 12
+        tile_cols_B = 8
+    elif is_dotprod_available():
        # The number of tile rows of B' varies depending on the
        # strategy:
        # * If we are interleaving A, then we select 12 columns from B' (i.e.,
@@ -92,7 +109,7 @@ def get_tiling_B_interleaved_t(interleave_A):
        # rows of the original matrix B) need to be 4.
        tile_cols_B = 4
    else:
-        # If dot product is not available, A must be interleaved. In this case
+        # If no acceleration is available, A must be interleaved. In this case
        # we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements
        tile_rows_B = 4
        tile_cols_B = 16
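Collapsing the branches above, the returned (tile_rows_B, tile_cols_B) pairs reduce to the mapping sketched below. The 16-row value for the dotprod/no-interleave case sits in the lines hidden behind the collapsed hunk, so treat it as an assumption here:

```python
def tiling_B_sketch(interleave_A, mmla, dotprod):
    """Illustrative restatement of get_tiling_B_interleaved_t's selection."""
    if mmla:                                          # smmla/ummla path
        return 12, 8
    if dotprod:                                       # udot/sdot path
        return (12, 4) if interleave_A else (16, 4)   # 16 assumed (collapsed hunk)
    return 4, 16                                      # plain NEON fallback
```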
python/tvm/topi/arm_cpu/conv2d_gemm.py: 122 additions & 57 deletions
@@ -28,8 +28,9 @@
    gemm_quantized_impl,
    gemm_acc_4x4_int8_int8_int32,
    gemm_acc_nx16_int8_int8_int32,
+    gemm_acc_2x2_int8_int8_int32,
)
-from .arm_utils import is_aarch64_arm, is_dotprod_available
+from .arm_utils import is_aarch64_arm, is_dotprod_available, is_mmla_available


def configure_knobs(cfg, M, K):
@@ -130,11 +131,18 @@ def compute_conv2d_gemm_without_weight_transform(
    # the tile computation.
    #
    # Please refer to:
-    # - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product
-    # - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
+    # - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-performance-for-armv8-architectures # pylint: disable=line-too-long
+    # - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product
+    # - https://discuss.tvm.apache.org/t/rfc-improve-quantized-convolution-through-mmla-instruction
+    # - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
    # for more information.
    #
-    if is_dotprod_available() and interleave_A:
+    if is_mmla_available():
+        # If smmla/ummla is enabled, we are loading 8 rows from A. Each row
+        # will contain 8 elements
+        tile_rows_A = 8
+        tile_cols_A = 8
+    elif is_dotprod_available() and interleave_A:
        # If dot product has been enabled, and we are interleaving A
        # tile size should be 8x4
        tile_rows_A = 8
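To make the mmla numbers concrete: an 8x8 tile of A against a 12x8 tile of B' produces an 8x12 int32 output tile, realized as a 4x6 grid of the 2x2 accumulators that each smmla/ummla issue computes. A small bookkeeping sketch (values taken from the comments above):

```python
# Illustrative macro-tile bookkeeping for the mmla path.
tile_rows_A, tile_cols_A = 8, 8     # 8 rows of A, 8 int8 values per load
tile_rows_B, tile_cols_B = 12, 8    # 12 columns of B (rows of B'), 8 values per load
out_tile = (tile_rows_A, tile_rows_B)             # (8, 12) int32 results
sub_grid = (tile_rows_A // 2, tile_rows_B // 2)   # (4, 6) grid of 2x2 sub-tiles
```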
@@ -177,24 +185,71 @@ def compute_conv2d_gemm_without_weight_transform(
            lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y],
            name="A_interleaved",
        )
-        # Execute GEMM
-        C_interleaved = te.compute(
-            (batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B),
-            lambda b, x, y, w, z: te.sum(
-                A_interleaved[b, x, k // tile_cols_A, w, idxm(k, tile_cols_A)].astype("int32")
-                * B_interleaved_t[y, k // tile_cols_B, z, idxm(k, tile_cols_B)].astype("int32"),
-                axis=k,
-            ),
-            name="C_interleaved",
-        )
-        # Unpack the result
-        C = te.compute(
-            (batches, M, N),
-            lambda b, x, y: C_interleaved[
-                b, x // tile_rows_A, y // tile_rows_B, idxm(x, tile_rows_A), idxm(y, tile_rows_B)
-            ].astype(out_dtype),
-            name="C",
-        )
+        if is_mmla_available():
+            # Execute GEMM. In the case of mmla, we need to enforce the tiling
+            # from the compute. This is because mmla is doing a tiled computation
+            # as well. So we have a big 8x12 tile, with small 2x2 sub-tiles
+            # generated by mmla. In theory we could make the tile 2x2 and
+            # fuse and split during scheduling, but this would not work
+            # because of possible padding
+            C_interleaved = te.compute(
+                (
+                    batches,
+                    M_padded // tile_rows_A,
+                    N_transformed,
+                    tile_rows_A // 2,
+                    tile_rows_B // 2,
+                    2,
+                    2,
+                ),
+                lambda b, x, y, w, z, s, t: te.sum(
+                    A_interleaved[b, x, k // tile_cols_A, 2 * w + s, idxm(k, tile_cols_A)].astype(
+                        "int32"
+                    )
+                    * B_interleaved_t[y, k // tile_cols_B, 2 * z + t, idxm(k, tile_cols_B)].astype(
+                        "int32"
+                    ),
+                    axis=k,
+                ),
+                name="C_interleaved",
+            )
+            # Unpack the result
+            C = te.compute(
+                (batches, M, N),
+                lambda b, x, y: C_interleaved[
+                    b,
+                    x // tile_rows_A,
+                    y // tile_rows_B,
+                    idxm(x, tile_rows_A) // 2,
+                    idxm(y, tile_rows_B) // 2,
+                    idxm(idxm(x, tile_rows_A), 2),
+                    idxm(idxm(y, tile_rows_B), 2),
+                ].astype(out_dtype),
+                name="C",
+            )
+        else:
+            # Execute GEMM
+            C_interleaved = te.compute(
+                (batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B),
+                lambda b, x, y, w, z: te.sum(
+                    A_interleaved[b, x, k // tile_cols_A, w, idxm(k, tile_cols_A)].astype("int32")
+                    * B_interleaved_t[y, k // tile_cols_B, z, idxm(k, tile_cols_B)].astype("int32"),
+                    axis=k,
+                ),
+                name="C_interleaved",
+            )
+            # Unpack the result
+            C = te.compute(
+                (batches, M, N),
+                lambda b, x, y: C_interleaved[
+                    b,
+                    x // tile_rows_A,
+                    y // tile_rows_B,
+                    idxm(x, tile_rows_A),
+                    idxm(y, tile_rows_B),
+                ].astype(out_dtype),
+                name="C",
+            )
        zero = tvm.tir.const(0)
    else:
        # No need to pack/unpack, execute GEMM directly
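A NumPy reference for what the mmla branch computes on a single macro-tile may help when reading the nested `idxm` unpacking. This is a sketch with simplified, non-interleaved A and B, not the actual TE layout:

```python
import numpy as np

M, N, K = 8, 12, 8  # one 8x12 output macro-tile, reduction depth 8
A = np.random.randint(-128, 128, (M, K)).astype(np.int8)
B = np.random.randint(-128, 128, (K, N)).astype(np.int8)

# Tiled accumulator: a (M//2, N//2) grid of 2x2 sub-tiles, mirroring
# the (tile_rows_A // 2, tile_rows_B // 2, 2, 2) axes of C_interleaved.
C_tiled = np.zeros((M // 2, N // 2, 2, 2), dtype=np.int32)
for w in range(M // 2):
    for z in range(N // 2):
        for s in range(2):
            for t in range(2):
                a_row = A[2 * w + s].astype(np.int32)
                b_col = B[:, 2 * z + t].astype(np.int32)
                C_tiled[w, z, s, t] = a_row @ b_col

# Unpacking matches the nested idxm(): row x sits in sub-tile row x // 2
# at offset x % 2, and columns likewise.
C = np.empty((M, N), dtype=np.int32)
for x in range(M):
    for y in range(N):
        C[x, y] = C_tiled[x // 2, y // 2, x % 2, y % 2]

assert np.array_equal(C, A.astype(np.int32) @ B.astype(np.int32))
```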
@@ -255,7 +310,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
    s[data_im2col].compute_inline()

    # Computation (through tensorize)
-    b, xo, yo, xi, yi = C_interleaved.op.axis
+    b, xo, yo, xi, yi = C_interleaved.op.axis[0:5]
    outer_gemm, inner_gemm = cfg["reorder_gemm"].apply(s, C_interleaved, [xo, yo])

    b_outer_gemm_fused = s[C_interleaved].fuse(b, outer_gemm)
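The `[0:5]` slice is what lets this schedule serve both layouts: without mmla, `C_interleaved` has exactly five axes, while the mmla compute introduced above appends two trailing axes for the 2x2 sub-tile. A sketch of the two shapes for reference (names illustrative):

```python
# Without mmla: (batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B)
# With mmla:    (batches, M_padded // tile_rows_A, N_transformed,
#                tile_rows_A // 2, tile_rows_B // 2, 2, 2)
b, xo, yo, xi, yi = C_interleaved.op.axis[0:5]  # the leading five axes in either case
```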
@@ -271,40 +326,50 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):

    k = C_interleaved.op.reduce_axis[0]
    _, M, N = C.shape
-    if is_dotprod_available():
-        gemm_acc = gemm_acc_4x4_int8_int8_int32(in_type)
-        xi_outer, yi_outer, xi_inner, yi_inner = s[C_interleaved].tile(
-            xi, yi, x_factor=8, y_factor=4
-        )
-        k_outer, k_inner = s[C_interleaved].split(k, 4)
-        xi_inner_outer, xi_inner_inner = s[C_interleaved].split(xi_inner, 4)
-        s[C_interleaved].reorder(
-            b_outer_gemm_fused,
-            inner_gemm,
-            xi_outer,
-            yi_outer,
-            k_outer,
-            xi_inner_outer,
-            xi_inner_inner,
-            yi_inner,
-            k_inner,
-        )
-        s[C_interleaved].tensorize(xi_inner_inner, gemm_acc)
-        s[C_interleaved].unroll(xi_inner_outer)
-
-    elif is_aarch64_arm():
-        s[C_interleaved].reorder(yi, xi)
-        K = A_interleaved_input.shape[2]
-        assert in_type in ["int8", "uint8"], "Only int8 and uint8 gemm are supported"
-        unroll = cfg["gemm_quantized_unroll"].val
-        interleave = cfg["gemm_quantized_interleave"].val
-        gemm = gemm_quantized(M, N, K, unroll, interleave, in_type, out_type)
-        s[C_interleaved].pragma(
-            b_outer_gemm_fused,
-            "import_llvm",
-            gemm_quantized_impl(M, N, K, unroll, interleave, in_type),
-        )
-        s[C_interleaved].tensorize(yi, gemm)
+    if in_type in ["int8", "uint8"]:
+        if is_mmla_available():
+            gemm_acc = gemm_acc_2x2_int8_int8_int32(in_type)
+            xi_inner, yi_inner = C_interleaved.op.axis[-2:]
+            k_outer, k_inner = s[C_interleaved].split(k, 8)
+            s[C_interleaved].reorder(
+                b_outer_gemm_fused, inner_gemm, k_outer, xi, yi, xi_inner, yi_inner, k_inner
+            )
+            s[C_interleaved].tensorize(xi_inner, gemm_acc)
+            s[C_interleaved].unroll(xi)
+            s[C_interleaved].unroll(yi)
+        elif is_dotprod_available():
+            gemm_acc = gemm_acc_4x4_int8_int8_int32(in_type)
+            xi_outer, yi_outer, xi_inner, yi_inner = s[C_interleaved].tile(
+                xi, yi, x_factor=8, y_factor=4
+            )
+            k_outer, k_inner = s[C_interleaved].split(k, 4)
+            xi_inner_outer, xi_inner_inner = s[C_interleaved].split(xi_inner, 4)
+            s[C_interleaved].reorder(
+                b_outer_gemm_fused,
+                inner_gemm,
+                xi_outer,
+                yi_outer,
+                k_outer,
+                xi_inner_outer,
+                xi_inner_inner,
+                yi_inner,
+                k_inner,
+            )
+            s[C_interleaved].tensorize(xi_inner_inner, gemm_acc)
+            s[C_interleaved].unroll(xi_inner_outer)
+
+        elif is_aarch64_arm():
+            s[C_interleaved].reorder(yi, xi)
+            K = A_interleaved_input.shape[2]
+            unroll = cfg["gemm_quantized_unroll"].val
+            interleave = cfg["gemm_quantized_interleave"].val
+            gemm = gemm_quantized(M, N, K, unroll, interleave, in_type, out_type)
+            s[C_interleaved].pragma(
+                b_outer_gemm_fused,
+                "import_llvm",
+                gemm_quantized_impl(M, N, K, unroll, interleave, in_type),
+            )
+            s[C_interleaved].tensorize(yi, gemm)

    # Output transform
    if out != final_out:
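For reference, the intrinsic tensorized in the mmla branch corresponds to the Armv8.6-A smmla/ummla instruction: each issue multiplies two 2x8 int8 operands and accumulates a 2x2 int32 result, which is why `k` is split by a factor of 8 before tensorizing. A NumPy sketch of that semantics (illustrative, not the TIR definition of `gemm_acc_2x2_int8_int8_int32`):

```python
import numpy as np

def smmla_reference(acc, a, b):
    """acc: (2, 2) int32; a, b: (2, 8) int8. Emulates acc += a @ b.T."""
    return acc + a.astype(np.int32) @ b.astype(np.int32).T

acc = np.zeros((2, 2), dtype=np.int32)
a = np.random.randint(-128, 128, (2, 8)).astype(np.int8)  # two rows of A
b = np.random.randint(-128, 128, (2, 8)).astype(np.int8)  # two rows of B' (columns of B)
acc = smmla_reference(acc, a, b)
```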