Add smmla/ummla support in quantized Conv2d
This introduces support for `smmla`/`ummla` instructions in TVM:
- Added `is_mmla_available` function in `arm_utils.py`
- Added the tiling node + tensorization schedule in `conv2d_gemm.py`
- Added the intrinsic support in `tensor_intrin.py`
- Added the test-case in `test_topi_conv2d_int8.py`

Change-Id: Iff48c77f16fe1e64ecb733da965a879651ce635f
Giuseppe Rossini committed Oct 30, 2020
1 parent 0ce55cb commit 9c8a6fe
Showing 4 changed files with 200 additions and 23 deletions.
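To make the new tiling concrete before the per-file diffs: with mmla available, A is tiled 8x8 and B' 12x8, and every 8x12 output macro-tile is accumulated from 2x2 sub-tiles, one smmla/ummla per 8-deep slice of the reduction axis. A minimal NumPy sketch of that decomposition (illustrative only, not part of the commit; the sizes match the diff below):

import numpy as np

# One output macro-tile of the mmla GEMM: 8x12 int32, built from 4x6 = 24
# sub-tiles of 2x2, each accumulated over the reduction axis in steps of 8.
rng = np.random.default_rng(0)
K = 32                                                # illustrative reduction size
A = rng.integers(-128, 127, (8, K), dtype=np.int8)    # one 8-row tile of A
Bt = rng.integers(-128, 127, (12, K), dtype=np.int8)  # one 12-row tile of B' (transposed B)

C = np.zeros((8, 12), dtype=np.int32)
for w in range(4):                  # 2-row blocks of the A tile
    for z in range(6):              # 2-row blocks of the B' tile
        for k0 in range(0, K, 8):   # one smmla/ummla per 8-deep slice
            a = A[2 * w:2 * w + 2, k0:k0 + 8].astype(np.int32)
            b = Bt[2 * z:2 * z + 2, k0:k0 + 8].astype(np.int32)
            C[2 * w:2 * w + 2, 2 * z:2 * z + 2] += a @ b.T

assert np.array_equal(C, A.astype(np.int32) @ Bt.astype(np.int32).T)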
14 changes: 13 additions & 1 deletion python/tvm/topi/arm_cpu/arm_utils.py
@@ -49,6 +49,15 @@ def is_dotprod_available():
return arch_version >= 8.4 or ((arch_version in (8.2, 8.3)) and "+dotprod" in target.mattr)


def is_mmla_available():
""" Checks whether the hardware has support for fast Int8 arithmetic operations. """
target = tvm.target.Target.current(allow_none=False)
arch_version = get_arch_version(target.mattr)
return arch_version >= 8.6 or (
(arch_version in (8.2, 8.3, 8.4, 8.5)) and "+i8mm" in target.mattr
)


def is_aarch64_arm():
""" Checks whether we are compiling for an AArch64 target. """
target = tvm.target.Target.current(allow_none=False)
@@ -77,7 +86,10 @@ def get_tiling_B_interleaved_t(interleave_A):
tile_rows_B: the output tile rows of B'
tile_cols_B: the output tile columns of B'
"""
if is_dotprod_available():
if is_mmla_available():
tile_rows_B = 12
tile_cols_B = 8
elif is_dotprod_available():
# The number of tile rows of B' varies depending on the
# strategy:
# * If we are interleaving A, then we select 12 columns from B' (i.e.,
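A hedged usage sketch for the new predicate (assuming a TVM build that contains this commit): under an AArch64 target carrying +i8mm, is_mmla_available() returns True and get_tiling_B_interleaved_t hands back the 12x8 tile for B'.

import tvm
from tvm.topi.arm_cpu.arm_utils import is_mmla_available, get_tiling_B_interleaved_t

# Same target string as the new test case below; any AArch64 target with
# +i8mm (or an arch version >= 8.6) should satisfy the predicate.
target = tvm.target.Target(
    "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+i8mm"
)
with target:
    assert is_mmla_available()
    tile_rows_B, tile_cols_B = get_tiling_B_interleaved_t(interleave_A=True)
    print(tile_rows_B, tile_cols_B)  # expected: 12 8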
105 changes: 83 additions & 22 deletions python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -28,8 +28,9 @@
gemm_quantized_impl,
gemm_acc_4x4_int8_int8_int32,
gemm_acc_nx16_int8_int8_int32,
gemm_acc_2x2_int8_int8_int32,
)
from .arm_utils import is_aarch64_arm, is_dotprod_available
from .arm_utils import is_aarch64_arm, is_dotprod_available, is_mmla_available


def configure_knobs(cfg, M, K):
@@ -134,7 +135,10 @@ def compute_conv2d_gemm_without_weight_transform(
# - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
# In order to have more information
#
if is_dotprod_available() and interleave_A:
if is_mmla_available():
tile_rows_A = 8
tile_cols_A = 8
elif is_dotprod_available() and interleave_A:
# If the dot product is available and we are interleaving A,
# the tile size should be 8x4
tile_rows_A = 8
@@ -177,24 +181,71 @@ def compute_conv2d_gemm_without_weight_transform(
lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y],
name="A_interleaved",
)
# Execute GEMM
C_interleaved = te.compute(
(batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B),
lambda b, x, y, w, z: te.sum(
A_interleaved[b, x, k // tile_cols_A, w, idxm(k, tile_cols_A)].astype("int32")
* B_interleaved_t[y, k // tile_cols_B, z, idxm(k, tile_cols_B)].astype("int32"),
axis=k,
),
name="C_interleaved",
)
# Unpack the result
C = te.compute(
(batches, M, N),
lambda b, x, y: C_interleaved[
b, x // tile_rows_A, y // tile_rows_B, idxm(x, tile_rows_A), idxm(y, tile_rows_B)
].astype(out_dtype),
name="C",
)
if is_mmla_available():
# Execute GEMM. In the case of mmla, we need to enforce the tiling
# from the compute. This is because mmla itself performs a tiled
# computation: the output is a big 8x12 tile made of small 2x2
# sub-tiles, each one generated by an mmla instruction. In theory we
# could make the tile 2x2 and fuse and split during scheduling, but
# this would not work because of possible padding
C_interleaved = te.compute(
(
batches,
M_padded // tile_rows_A,
N_transformed,
tile_rows_A // 2,
tile_rows_B // 2,
2,
2,
),
lambda b, x, y, w, z, s, t: te.sum(
A_interleaved[b, x, k // tile_cols_A, 2 * w + s, idxm(k, tile_cols_A)].astype(
"int32"
)
* B_interleaved_t[y, k // tile_cols_B, 2 * z + t, idxm(k, tile_cols_B)].astype(
"int32"
),
axis=k,
),
name="C_interleaved",
)
# Unpack the result
C = te.compute(
(batches, M, N),
lambda b, x, y: C_interleaved[
b,
x // tile_rows_A,
y // tile_rows_B,
idxm(x, tile_rows_A) // 2,
idxm(y, tile_rows_B) // 2,
idxm(idxm(x, tile_rows_A), 2),
idxm(idxm(y, tile_rows_B), 2),
].astype(out_dtype),
name="C",
)
else:
# Execute GEMM
C_interleaved = te.compute(
(batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B),
lambda b, x, y, w, z: te.sum(
A_interleaved[b, x, k // tile_cols_A, w, idxm(k, tile_cols_A)].astype("int32")
* B_interleaved_t[y, k // tile_cols_B, z, idxm(k, tile_cols_B)].astype("int32"),
axis=k,
),
name="C_interleaved",
)
# Unpack the result
C = te.compute(
(batches, M, N),
lambda b, x, y: C_interleaved[
b,
x // tile_rows_A,
y // tile_rows_B,
idxm(x, tile_rows_A),
idxm(y, tile_rows_B),
].astype(out_dtype),
name="C",
)
zero = tvm.tir.const(0)
else:
# No need to pack/unpack, execute GEMM directly
@@ -255,7 +306,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
s[data_im2col].compute_inline()

# Computation (through tensorize)
b, xo, yo, xi, yi = C_interleaved.op.axis
b, xo, yo, xi, yi = C_interleaved.op.axis[0:5]
outer_gemm, inner_gemm = cfg["reorder_gemm"].apply(s, C_interleaved, [xo, yo])

b_outer_gemm_fused = s[C_interleaved].fuse(b, outer_gemm)
@@ -271,7 +322,17 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):

k = C_interleaved.op.reduce_axis[0]
_, M, N = C.shape
if is_dotprod_available():
if is_mmla_available():
gemm_acc = gemm_acc_2x2_int8_int8_int32(in_type)
xi_inner, yi_inner = C_interleaved.op.axis[5:7]
k_outer, k_inner = s[C_interleaved].split(k, 8)
s[C_interleaved].reorder(
b_outer_gemm_fused, inner_gemm, k_outer, xi, yi, xi_inner, yi_inner, k_inner
)
s[C_interleaved].tensorize(xi_inner, gemm_acc)
s[C_interleaved].unroll(xi)
s[C_interleaved].unroll(yi)
elif is_dotprod_available():
gemm_acc = gemm_acc_4x4_int8_int8_int32(in_type)
xi_outer, yi_outer, xi_inner, yi_inner = s[C_interleaved].tile(
xi, yi, x_factor=8, y_factor=4
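A small pure-Python round-trip of the index arithmetic used above to unpack C_interleaved in the mmla branch (tile sizes taken from the diff; the helper names are illustrative):

# Map an output coordinate (x, y) to the (outer tile, 2x2 sub-tile, lane)
# coordinates used by the mmla C_interleaved layout, and back.
tile_rows_A, tile_rows_B = 8, 12

def pack(x, y):
    return (x // tile_rows_A, y // tile_rows_B,
            (x % tile_rows_A) // 2, (y % tile_rows_B) // 2,
            (x % tile_rows_A) % 2, (y % tile_rows_B) % 2)

def unpack(xo, yo, w, z, s, t):
    # Mirrors the interleaved access pattern: row 2*w + s, column 2*z + t.
    return xo * tile_rows_A + 2 * w + s, yo * tile_rows_B + 2 * z + t

for x in range(2 * tile_rows_A):
    for y in range(2 * tile_rows_B):
        assert unpack(*pack(x, y)) == (x, y)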
99 changes: 99 additions & 0 deletions python/tvm/topi/arm_cpu/tensor_intrin.py
@@ -879,6 +879,105 @@ def _instr(index):
)


def gemm_acc_2x2_int8_int8_int32(dtype):
"""
Int8 2x2 matrix multiplication using smmla/ummla instructions
This function takes two arrays of int8 datatype -- A[2][8] and
B[2][8] -- and produces a 2x2 matrix which is equal to A*B'.
The pseudo code is as follows.
.. code-block:: c
    void mmla_2x2_int8_int8_int32(int8 A[2][8], int8 B[2][8], int32 C[2][2]){
        for (int i = 0; i < 2; i++){
            for (int j = 0; j < 2; j++){
                for (int k = 0; k < 8; k++){
                    C[i][j] += A[i][k] * B[j][k]
                }
            }
        }
    }
Notes:
* The rows of matrix B are transposed
Parameters
----------
dtype: str, {"uint8", "int8"}
Whether it works on unsigned int or signed int
Returns
-------
intrin : TensorIntrin
The Arm TensorIntrin that can be used in tensorizing schedule
"""
A = te.placeholder((2, 8), dtype, name="A")
B = te.placeholder((2, 8), dtype, name="B")
dtype_vec = dtype + "x16"

k = te.reduce_axis((0, 8), name="k")
C = te.compute(
(2, 2),
lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
name="C",
)

aa_buffer = tvm.tir.decl_buffer(
A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
)
bb_buffer = tvm.tir.decl_buffer(
B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
)
cc_buffer = tvm.tir.decl_buffer(
C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
)

llvm_intrin = "llvm.aarch64.neon.smmla" if dtype == "int8" else "llvm.aarch64.neon.ummla"

def _intrin_func(ins, outs):
def _instr(index):
ib = tvm.tir.ir_builder.create()
if index == 1:
ib.emit(outs[0].vstore([0, 0], tvm.tir.const(0, "int32x4")))
return ib.get()
# Load in vec_a the two rows of A
# vec_a = [a, b, c, d, e, f, g, h;
# i, j, k, l, m, n, o, p]
vec_a = ins[0].vload([0, 0], dtype_vec)
# Load in vec_b the two rows of B
# vec_b = [0, 2, 4, 6, 8, 10, 12, 14;
# 1, 3, 5, 7, 9, 11, 13, 15]
vec_b = ins[1].vload([0, 0], dtype_vec)

# Execute the matrix multiplication via (s/u)mmla:
# vec_c = [a*0 + b*2 + c*4 + d*6 +e*8 + f*10 + g*12 + h*14;
# a*1 + b*3 + c*5 + d*7 +e*9 + f*11 + g*13 + h*15;
# i*0 + j*2 + k*4 + l*6 +m*8 + n*10 + o*12 + p*14;
# i*1 + j*3 + k*5 + l*7 +m*9 + n*11 + o*13 + p*15]
vec_c = outs[0].vload([0, 0], "int32x4")
vmmla = tvm.tir.call_llvm_intrin(
"int32x4",
llvm_intrin,
tvm.tir.const(3, "uint32"),
vec_c,
vec_a,
vec_b,
)
# Store the result
ib.emit(outs[0].vstore([0, 0], vmmla))
return ib.get()

# body, reset, update
return _instr(0), _instr(1), _instr(2)

buffer_params = {"offset_factor": 1}
return te.decl_tensor_intrin(
C.op,
_intrin_func,
binds={A: aa_buffer, B: bb_buffer, C: cc_buffer},
default_buffer_params=buffer_params,
)


def _q_multiply_shift_arm(op):
"""
Implementation of q_multiply_shift_arm through arm intrinsics
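A NumPy emulation of the lane layout that gemm_acc_2x2_int8_int8_int32 expects from one smmla/ummla call (the helper name is made up; this mirrors the comments in the intrinsic, not the actual LLVM call):

import numpy as np

def mmla_2x2_emulated(vec_c, vec_a, vec_b):
    # vec_a: 16 int8 lanes = the two rows of A; vec_b: 16 int8 lanes = the two
    # (transposed) rows of B; vec_c: 4 int32 lanes laid out as
    # [C[0][0], C[0][1], C[1][0], C[1][1]].
    a = vec_a.astype(np.int32).reshape(2, 8)
    b = vec_b.astype(np.int32).reshape(2, 8)
    return (vec_c.reshape(2, 2) + a @ b.T).reshape(4)

vec_a = np.arange(16, dtype=np.int8)   # [a..h, i..p] in the comments above
vec_b = np.arange(16, dtype=np.int8)   # the two rows of B', column-interleaved
vec_c = np.zeros(4, dtype=np.int32)
print(mmla_2x2_emulated(vec_c, vec_a, vec_b))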
5 changes: 5 additions & 0 deletions tests/python/topi/python/test_topi_conv2d_int8.py
@@ -72,6 +72,11 @@ def compile_conv2d_NHWC_gemm_int8_arm(
topi.arm_cpu.compute_conv2d_NHWC_quantized_native,
topi.arm_cpu.schedule_conv2d_NHWC_quantized_native,
),
(
"llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+i8mm",
topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
),
]

for device_tuple in devices:
