Commit
Add sse4/avx2 support for fast x86 int8 (vpmaddubsw/vpmaddwd/vpaddd) (#8897)

* Add sse4/avx2 support for vpmaddubsw/vpmaddwd/vpaddd

- Extend the list of supported targets for the x86 topi schedules
- Extend the conv2d x86 int8 tests to cover fast-int8 x86 platforms

* Fix code style

* Change x86-64-v2 to nehalem in the test to support LLVM 11

* Change test target to get NCHW8c
elvin-n authored Sep 9, 2021
1 parent c650f9a commit 1bebd0a
Showing 13 changed files with 164 additions and 51 deletions.
3 changes: 2 additions & 1 deletion python/tvm/relay/qnn/op/legalizations.py
@@ -22,6 +22,7 @@
from tvm import relay
from tvm._ffi.base import TVMError
from .. import op as reg
from ....topi.x86.utils import target_has_sse42

#################################################
# Register the functions for different operators.
@@ -343,7 +344,7 @@ def _shift(data, zero_point, out_dtype):
def is_fast_int8_on_intel():
"""Checks whether the hardware has support for fast Int8 arithmetic operations."""
target = tvm.target.Target.current(allow_none=False)
return target.mcpu in {"skylake-avx512", "cascadelake"}
return target_has_sse42(target.mcpu)


def is_fast_int8_on_arm():
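For illustration, a minimal sketch of how the relaxed is_fast_int8_on_intel check now behaves (the -mcpu names and the printed results are assumptions based on the target_has_sse42 table added by this commit, not output captured from it): any CPU with SSE4.2 or newer takes the fast-int8 legalization path, not only skylake-avx512/cascadelake.

import tvm
from tvm.topi.x86.utils import target_has_sse42

for mcpu in ["core2", "nehalem", "core-avx2", "skylake-avx512", "cascadelake"]:
    with tvm.target.Target("llvm -mcpu=%s" % mcpu):
        target = tvm.target.Target.current(allow_none=False)
        print(mcpu, target_has_sse42(target.mcpu))
# expected: core2 -> False, all others -> True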
6 changes: 3 additions & 3 deletions python/tvm/topi/x86/conv2d_avx_1x1.py
@@ -26,11 +26,11 @@
from ..generic import conv2d as conv2d_generic
from ..utils import get_const_tuple, simplify
from .tensor_intrin import dot_16x1x16_uint8_int8_int32
from .utils import get_fp32_len
from .utils import get_simd_32bit_lanes


def _fallback_schedule(cfg, wkl):
simd_width = get_fp32_len()
simd_width = get_simd_32bit_lanes()
pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
HSTR, WSTR = wkl.stride_h, wkl.stride_w
dilated_kernel_h = (wkl.kernel_h - 1) * wkl.dilation_h + 1
@@ -157,7 +157,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data_vec, kernel_vec, conv_out, last):
kernel_vec,
conv_out,
last,
int32_lanes=16,
int32_lanes=get_simd_32bit_lanes(),
intrin=dot_16x1x16_uint8_int8_int32(),
)

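With this change the schedule's int32_lanes and the intrinsic returned by dot_16x1x16_uint8_int8_int32() both follow the SIMD width of the target (4 on SSE4.2, 8 on AVX2, 16 on AVX-512). A rough NumPy reference of the update the tensorized intrinsic performs, using hypothetical inputs and assuming an AVX2 target:

import numpy as np

lanes = 8                                    # get_simd_32bit_lanes() on an AVX2 target (assumption)
data = np.arange(4, dtype="uint8")           # (num_int8_elements,)
kernel = np.ones((lanes, 4), dtype="int8")   # (int32_lanes, num_int8_elements)
out = np.zeros(lanes, dtype="int32")

# out[i] += sum_k data[k] * kernel[i, k] -- the reduction the intrinsic covers per call
out += (data.astype("int32") * kernel.astype("int32")).sum(axis=1)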
6 changes: 3 additions & 3 deletions python/tvm/topi/x86/conv2d_avx_common.py
@@ -22,11 +22,11 @@
from ..generic import conv2d as conv2d_generic
from ..utils import get_const_tuple
from .tensor_intrin import dot_16x1x16_uint8_int8_int32
from .utils import get_fp32_len
from .utils import get_simd_32bit_lanes


def _fallback_schedule(cfg, wkl):
simd_width = get_fp32_len()
simd_width = get_simd_32bit_lanes()
pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
HSTR, WSTR = wkl.stride_h, wkl.stride_w
dilated_kernel_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1
@@ -174,6 +174,6 @@ def _schedule_conv_NCHWc_int8(s, cfg, data_vec, kernel_vec, conv_out, last):
kernel_vec,
conv_out,
last,
int32_lanes=16,
int32_lanes=get_simd_32bit_lanes(),
intrin=dot_16x1x16_uint8_int8_int32(),
)
5 changes: 2 additions & 3 deletions python/tvm/topi/x86/conv2d_int8.py
@@ -30,6 +30,7 @@
from ..utils import get_const_tuple, traverse_inline
from .. import nn
from . import conv2d_avx_1x1, conv2d_avx_common
from .utils import target_has_sse42


def _get_default_config_int8(
@@ -73,9 +74,7 @@ def is_int8_hw_support(data_dtype, kernel_dtype):

# 3) Check target
mcpu = tvm.target.Target.current().mcpu
is_target_support = False
if mcpu in ("skylake-avx512", "cascadelake"):
is_target_support = True
is_target_support = target_has_sse42(mcpu)

return is_dtype_support and is_llvm_support and is_target_support

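A short usage sketch of the updated check (it assumes LLVM >= 8 so the intrinsic-presence test passes): uint8 data and int8 kernel are still required, but any SSE4.2-capable mcpu now qualifies.

import tvm
from tvm.topi.x86.conv2d_int8 import is_int8_hw_support

with tvm.target.Target("llvm -mcpu=nehalem"):
    print(is_int8_hw_support("uint8", "int8"))  # True on LLVM >= 8
    print(is_int8_hw_support("int8", "int8"))   # False: data dtype must be uint8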
4 changes: 2 additions & 2 deletions python/tvm/topi/x86/conv3d.py
@@ -26,7 +26,7 @@
from ..nn.utils import get_pad_tuple3d, infer_pad3d
from ..nn.pad import pad
from ..utils import get_const_tuple, simplify, get_const_int
from .utils import get_fp32_len
from .utils import get_simd_32bit_lanes

Workload3D = namedtuple(
"Workload",
@@ -520,7 +520,7 @@ def _get_conv3d_workload(data, kernel, stride, padding, out_dtype, data_layout="


def _fallback_schedule(cfg, wkl):
simd_width = get_fp32_len()
simd_width = get_simd_32bit_lanes()
DPAD, HPAD, WPAD = wkl.dpad, wkl.hpad, wkl.wpad
DSTR, HSTR, WSTR = wkl.dstride, wkl.hstride, wkl.wstride
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
6 changes: 3 additions & 3 deletions python/tvm/topi/x86/dense.py
@@ -26,7 +26,7 @@
from tvm.contrib import mkl
from tvm.contrib import mkldnn

from .utils import get_fp32_len
from .utils import get_simd_32bit_lanes
from .. import generic, tag
from ..utils import traverse_inline, get_const_tuple

@@ -107,7 +107,7 @@ def _default_dense_pack_config(cfg, M, N, K):
if isinstance(K, (tvm.tir.Var, tvm.tir.Any)):
K = 16

vec_width = get_fp32_len()
vec_width = get_simd_32bit_lanes()
tilex_ii = 1
for bn in range(vec_width * 2, 0, -1):
if N % bn == 0:
@@ -145,7 +145,7 @@ def _default_dense_nopack_config(cfg, M, N, K):
if isinstance(K, (tvm.tir.Var, tvm.tir.Any)):
K = 16

vec_width = get_fp32_len()
vec_width = get_simd_32bit_lanes()
tilek_bn = 1
for bn in range(vec_width * 2, 0, -1):
if K % bn == 0:
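The truncated loops above implement a simple tile-size search: pick the largest block size not exceeding 2 * vec_width that evenly divides the dimension, so the chosen tile still depends on the target's SIMD width. A worked sketch with assumed values:

vec_width = 8   # get_simd_32bit_lanes() on an AVX2 target (assumption for this sketch)
N = 24
tilex_ii = 1
for bn in range(vec_width * 2, 0, -1):  # try 16, 15, ..., 1
    if N % bn == 0:
        tilex_ii = bn
        break
print(tilex_ii)  # 12 -- the largest divisor of 24 that does not exceed 16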
4 changes: 2 additions & 2 deletions python/tvm/topi/x86/depthwise_conv2d.py
@@ -27,7 +27,7 @@
from ..nn.depthwise_conv2d import _get_workload, depthwise_conv2d_infer_layout
from ..nn.conv2d import unpack_NCHWc_to_nchw
from ..utils import traverse_inline
from .utils import get_fp32_len
from .utils import get_simd_32bit_lanes


def _fallback_schedule(cfg, wkl):
@@ -40,7 +40,7 @@ def _fallback_schedule(cfg, wkl):
wkl : topi.nn.depthwise_conv2d.Workload
Convolution workload
"""
simd_width = get_fp32_len()
simd_width = get_simd_32bit_lanes()

pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
HSTR, WSTR = wkl.stride_h, wkl.stride_w
4 changes: 2 additions & 2 deletions python/tvm/topi/x86/group_conv2d.py
@@ -23,7 +23,7 @@
from tvm import te
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity

from .utils import get_fp32_len
from .utils import get_simd_32bit_lanes
from ..utils import get_const_tuple
from ..nn.pad import pad
from .. import tag
@@ -62,7 +62,7 @@ def _get_default_config(


def _fallback_schedule(cfg, wkl):
simd_width = get_fp32_len()
simd_width = get_simd_32bit_lanes()
pad_left, pad_right = wkl.padl, wkl.padr
stride_w = wkl.stride_w
out_width = (wkl.width + pad_left + pad_right - wkl.kernel_w) // stride_w + 1
4 changes: 2 additions & 2 deletions python/tvm/topi/x86/sparse.py
@@ -21,15 +21,15 @@

from ..transform import reshape
from ..utils import traverse_inline, get_const_int
from .utils import get_fp32_len
from .utils import get_simd_32bit_lanes


def schedule_sparse_dense(outs):
"""Create schedule for sparse dense"""
s = te.create_schedule([x.op for x in outs])

def _callback(op):
simd_width = get_fp32_len()
simd_width = get_simd_32bit_lanes()
if op.tag == "sparse_dense_sp_lhs_csrmm" or op.tag == "sparse_dense_sp_lhs_csrmm":
(y_o, y_i) = s[op].split(s[op].op.axis[1], 2)
fused = s[op].fuse(s[op].op.axis[0], y_o)
60 changes: 41 additions & 19 deletions python/tvm/topi/x86/tensor_intrin.py
@@ -19,20 +19,19 @@
import tvm
from tvm import te
import tvm.target.codegen
from .utils import target_has_sse42, target_has_vnni, get_simd_32bit_lanes


def dot_16x1x16_uint8_int8_int32():
"""Dispatch the most optimized intrin depending on the target"""
mcpu = tvm.target.Target.current().mcpu

assert mcpu in (
"skylake-avx512",
"cascadelake",
), "An old Intel machine that does not have fast Int8 support."
if mcpu == "skylake-avx512":
return dot_16x1x16_uint8_int8_int32_skylake()
# cascadelake
return dot_16x1x16_uint8_int8_int32_cascadelake()
assert target_has_sse42(mcpu), "An old Intel machine that does not have fast Int8 support."
if target_has_vnni(mcpu):
# VNNI capable platform
return dot_16x1x16_uint8_int8_int32_cascadelake()
# vpmaddubsw/vpmaddwd fallback
return dot_16x1x16_uint8_int8_int32_skylake()


def dot_16x1x16_uint8_int8_int32_skylake():
@@ -64,7 +63,7 @@ def dot_16x1x16_uint8_int8_int32_skylake():
The Skylake int8 TensorIntrin that can be used in tensorizing schedule
"""

int32_lanes = 16 # 16 int32 lanes in AVX512
int32_lanes = get_simd_32bit_lanes()
num_int8_elements = 4 # 4 int8 elements in int32
data = te.placeholder((num_int8_elements,), dtype="uint8", name="data")
kernel = te.placeholder((int32_lanes, num_int8_elements), dtype="int8", name="kernel")
@@ -84,35 +83,58 @@

def _intrin_func(ins, outs):
def _instr(index):
# int_lx32 - dtype produced by pmaddubs: int16, with 2x the number of int32 lanes
# int_8xl  - dtype fed to pmaddubs: int8, with 4x the number of int32 lanes
# int_32xl - dtype produced by pmaddw: int32, one element per output lane

if int32_lanes == 4:
int_lx32 = "int16x8"
int_8xl = "int8x16"
int_32xl = "int32x4"
pmaddubs = "llvm.x86.ssse3.pmadd.ub.sw.128"
pmaddw = "llvm.x86.sse2.pmadd.wd"
elif int32_lanes == 8:
int_lx32 = "int16x16"
int_8xl = "int8x32"
int_32xl = "int32x8"
pmaddubs = "llvm.x86.avx2.pmadd.ub.sw"
pmaddw = "llvm.x86.avx2.pmadd.wd"
elif int32_lanes == 16:
int_lx32 = "int16x32"
int_8xl = "int8x64"
int_32xl = "int32x16"
pmaddubs = "llvm.x86.avx512.pmaddubs.w.512"
pmaddw = "llvm.x86.avx512.pmaddw.d.512"

ib = tvm.tir.ir_builder.create()
if index == 1:
ib.emit(outs[0].vstore(0, tvm.tir.const(0, "int32x16")))
ib.emit(outs[0].vstore(0, tvm.tir.const(0, int_32xl)))
return ib.get()

a_int8 = ins[0].vload([0], "uint8x4")
re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_int8)
vec_ai32 = re_int32.astype("int32x16")
vec_a = tvm.tir.call_intrin("int8x64", "tir.reinterpret", vec_ai32)
vec_b = ins[1].vload([0, 0], "int8x64")
vec_one = tvm.tir.const(1, "int16x32")
vec_ai32 = re_int32.astype(int_32xl)
vec_a = tvm.tir.call_intrin(int_8xl, "tir.reinterpret", vec_ai32)
vec_b = ins[1].vload([0, 0], int_8xl)
vec_one = tvm.tir.const(1, int_lx32)
pair_reduction = tvm.tir.call_llvm_pure_intrin(
"int16x32",
"llvm.x86.avx512.pmaddubs.w.512",
int_lx32,
pmaddubs,
tvm.tir.const(0, "uint32"),
vec_a,
vec_b,
)
quad_reduction = tvm.tir.call_llvm_pure_intrin(
"int32x16",
"llvm.x86.avx512.pmaddw.d.512",
int_32xl,
pmaddw,
tvm.tir.const(0, "uint32"),
pair_reduction,
vec_one,
)
if index == 0:
ib.emit(outs[0].vstore(0, quad_reduction))
else:
ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], "int32x16")))
ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], int_32xl)))
return ib.get()

# body, reset, update
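On non-VNNI targets the intrinsic lowers to vpmaddubsw (multiply uint8 by int8 element-wise and add adjacent pairs into int16), vpmaddwd against a vector of ones (add adjacent int16 pairs into int32), and vpaddd to accumulate into the output. A NumPy sketch of the per-lane arithmetic, assuming no int16 saturation occurs:

import numpy as np

data = np.array([1, 2, 3, 4], dtype="uint8")    # broadcast to every output lane
krow = np.array([5, -6, 7, -8], dtype="int8")   # the 4 kernel values of one int32 lane

# vpmaddubsw: uint8 * int8 element-wise, adjacent pairs summed into int16
prod = data.astype("int16") * krow.astype("int16")
pairs = prod[0::2] + prod[1::2]                 # [5 - 12, 21 - 32] = [-7, -11]

# vpmaddwd with a vector of ones: adjacent int16 pairs summed into int32
lane = int(pairs.astype("int32").sum())         # -18

# vpaddd then adds `lane` into the corresponding int32 accumulator lane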
92 changes: 89 additions & 3 deletions python/tvm/topi/x86/utils.py
@@ -18,9 +18,95 @@
import tvm


def get_fp32_len():
def target_has_sse42(target):
return (
target_has_avx(target)
or target_has_avx2(target)
or target_has_avx512(target)
or target_has_vnni(target)
or target
in {
"silvermont",
"slm",
"goldmont",
"goldmont-plus",
"tremont",
"nehalem",
"corei7",
"westmere",
"bdver1",
"bdver2",
"bdver3",
"x86-64-v2",
}
)


def target_has_avx(target):
return (
target_has_avx2(target)
or target_has_avx512(target)
or target_has_vnni(target)
or target in {"sandybridge", "corei7-avx", "ivybridge", "core-avx-i"}
)


def target_has_avx2(target):
return (
target_has_avx512(target)
or target_has_vnni(target)
or target
in {
"haswell",
"core-avx2",
"broadwell",
"skylake",
"bdver4",
"znver1",
"znver2",
"znver3",
"x86-64-v3",
}
)


def target_has_avx512(target):
return target in {
"skylake-avx512",
"skx",
"knl",
"knm",
"x86-64-v4",
"cannonlake",
# VNNI-capable CPUs that also have AVX-512 are listed explicitly,
# because alderlake has VNNI but not AVX-512, so target_has_vnni cannot be reused here
"cascadelake",
"icelake-client",
"rocketlake",
"icelake",
"tigerlake",
"cooperlake",
"sapphirerapids",
}


def target_has_vnni(target):
return target in {
"cascadelake",
"icelake-client",
"rocketlake",
"icelake",
"tigerlake",
"cooperlake",
"sapphirerapids",
"alderlake",
}


def get_simd_32bit_lanes():
mcpu = tvm.target.Target.current().mcpu
fp32_vec_len = 8
if mcpu in ("skylake-avx512", "cascadelake"):
fp32_vec_len = 4
if target_has_avx512(mcpu):
fp32_vec_len = 16
elif target_has_avx2(mcpu):
fp32_vec_len = 8
return fp32_vec_len
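A usage sketch of the new helper (the expected lane counts follow from the table above; the -mcpu strings are ordinary LLVM CPU names):

import tvm
from tvm.topi.x86.utils import get_simd_32bit_lanes

for mcpu in ["nehalem", "core-avx2", "skylake-avx512"]:
    with tvm.target.Target("llvm -mcpu=%s" % mcpu):
        print(mcpu, get_simd_32bit_lanes())
# expected: nehalem -> 4, core-avx2 -> 8, skylake-avx512 -> 16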