diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index fd3a1686f5a8..52fe6c8ebe2f 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -22,6 +22,7 @@ from tvm import relay from tvm._ffi.base import TVMError from .. import op as reg +from ....topi.x86.utils import target_has_sse42 ################################################# # Register the functions for different operators. @@ -343,7 +344,7 @@ def _shift(data, zero_point, out_dtype): def is_fast_int8_on_intel(): """Checks whether the hardware has support for fast Int8 arithmetic operations.""" target = tvm.target.Target.current(allow_none=False) - return target.mcpu in {"skylake-avx512", "cascadelake"} + return target_has_sse42(target.mcpu) def is_fast_int8_on_arm(): diff --git a/python/tvm/topi/x86/conv2d_avx_1x1.py b/python/tvm/topi/x86/conv2d_avx_1x1.py index 32b06725cdc2..bda1f8c725f5 100644 --- a/python/tvm/topi/x86/conv2d_avx_1x1.py +++ b/python/tvm/topi/x86/conv2d_avx_1x1.py @@ -26,11 +26,11 @@ from ..generic import conv2d as conv2d_generic from ..utils import get_const_tuple, simplify from .tensor_intrin import dot_16x1x16_uint8_int8_int32 -from .utils import get_fp32_len +from .utils import get_simd_32bit_lanes def _fallback_schedule(cfg, wkl): - simd_width = get_fp32_len() + simd_width = get_simd_32bit_lanes() pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr HSTR, WSTR = wkl.stride_h, wkl.stride_w dilated_kernel_h = (wkl.kernel_h - 1) * wkl.dilation_h + 1 @@ -157,7 +157,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data_vec, kernel_vec, conv_out, last): kernel_vec, conv_out, last, - int32_lanes=16, + int32_lanes=get_simd_32bit_lanes(), intrin=dot_16x1x16_uint8_int8_int32(), ) diff --git a/python/tvm/topi/x86/conv2d_avx_common.py b/python/tvm/topi/x86/conv2d_avx_common.py index 5e63de329bba..4f129fc6912f 100644 --- a/python/tvm/topi/x86/conv2d_avx_common.py +++ b/python/tvm/topi/x86/conv2d_avx_common.py @@ -22,11 +22,11 @@ from ..generic import conv2d as conv2d_generic from ..utils import get_const_tuple from .tensor_intrin import dot_16x1x16_uint8_int8_int32 -from .utils import get_fp32_len +from .utils import get_simd_32bit_lanes def _fallback_schedule(cfg, wkl): - simd_width = get_fp32_len() + simd_width = get_simd_32bit_lanes() pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr HSTR, WSTR = wkl.stride_h, wkl.stride_w dilated_kernel_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1 @@ -174,6 +174,6 @@ def _schedule_conv_NCHWc_int8(s, cfg, data_vec, kernel_vec, conv_out, last): kernel_vec, conv_out, last, - int32_lanes=16, + int32_lanes=get_simd_32bit_lanes(), intrin=dot_16x1x16_uint8_int8_int32(), ) diff --git a/python/tvm/topi/x86/conv2d_int8.py b/python/tvm/topi/x86/conv2d_int8.py index ca0d0b8b223c..075723303841 100644 --- a/python/tvm/topi/x86/conv2d_int8.py +++ b/python/tvm/topi/x86/conv2d_int8.py @@ -30,6 +30,7 @@ from ..utils import get_const_tuple, traverse_inline from .. import nn from . import conv2d_avx_1x1, conv2d_avx_common +from .utils import target_has_sse42 def _get_default_config_int8( @@ -73,9 +74,7 @@ def is_int8_hw_support(data_dtype, kernel_dtype): # 3) Check target mcpu = tvm.target.Target.current().mcpu - is_target_support = False - if mcpu in ("skylake-avx512", "cascadelake"): - is_target_support = True + is_target_support = target_has_sse42(mcpu) return is_dtype_support and is_llvm_support and is_target_support diff --git a/python/tvm/topi/x86/conv3d.py b/python/tvm/topi/x86/conv3d.py index d5b09e640e16..c4194167ce47 100644 --- a/python/tvm/topi/x86/conv3d.py +++ b/python/tvm/topi/x86/conv3d.py @@ -26,7 +26,7 @@ from ..nn.utils import get_pad_tuple3d, infer_pad3d from ..nn.pad import pad from ..utils import get_const_tuple, simplify, get_const_int -from .utils import get_fp32_len +from .utils import get_simd_32bit_lanes Workload3D = namedtuple( "Workload", @@ -520,7 +520,7 @@ def _get_conv3d_workload(data, kernel, stride, padding, out_dtype, data_layout=" def _fallback_schedule(cfg, wkl): - simd_width = get_fp32_len() + simd_width = get_simd_32bit_lanes() DPAD, HPAD, WPAD = wkl.dpad, wkl.hpad, wkl.wpad DSTR, HSTR, WSTR = wkl.dstride, wkl.hstride, wkl.wstride out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 29c378dda30f..9799ec02d644 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -26,7 +26,7 @@ from tvm.contrib import mkl from tvm.contrib import mkldnn -from .utils import get_fp32_len +from .utils import get_simd_32bit_lanes from .. import generic, tag from ..utils import traverse_inline, get_const_tuple @@ -107,7 +107,7 @@ def _default_dense_pack_config(cfg, M, N, K): if isinstance(K, (tvm.tir.Var, tvm.tir.Any)): K = 16 - vec_width = get_fp32_len() + vec_width = get_simd_32bit_lanes() tilex_ii = 1 for bn in range(vec_width * 2, 0, -1): if N % bn == 0: @@ -145,7 +145,7 @@ def _default_dense_nopack_config(cfg, M, N, K): if isinstance(K, (tvm.tir.Var, tvm.tir.Any)): K = 16 - vec_width = get_fp32_len() + vec_width = get_simd_32bit_lanes() tilek_bn = 1 for bn in range(vec_width * 2, 0, -1): if K % bn == 0: diff --git a/python/tvm/topi/x86/depthwise_conv2d.py b/python/tvm/topi/x86/depthwise_conv2d.py index a0225ef9e147..5e49c2cb3b78 100644 --- a/python/tvm/topi/x86/depthwise_conv2d.py +++ b/python/tvm/topi/x86/depthwise_conv2d.py @@ -27,7 +27,7 @@ from ..nn.depthwise_conv2d import _get_workload, depthwise_conv2d_infer_layout from ..nn.conv2d import unpack_NCHWc_to_nchw from ..utils import traverse_inline -from .utils import get_fp32_len +from .utils import get_simd_32bit_lanes def _fallback_schedule(cfg, wkl): @@ -40,7 +40,7 @@ def _fallback_schedule(cfg, wkl): wkl : topi.nn.depthwise_conv2d.Workload Convolution workload """ - simd_width = get_fp32_len() + simd_width = get_simd_32bit_lanes() pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr HSTR, WSTR = wkl.stride_h, wkl.stride_w diff --git a/python/tvm/topi/x86/group_conv2d.py b/python/tvm/topi/x86/group_conv2d.py index 0e10052e2428..890a15898a1a 100644 --- a/python/tvm/topi/x86/group_conv2d.py +++ b/python/tvm/topi/x86/group_conv2d.py @@ -23,7 +23,7 @@ from tvm import te from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity -from .utils import get_fp32_len +from .utils import get_simd_32bit_lanes from ..utils import get_const_tuple from ..nn.pad import pad from .. import tag @@ -62,7 +62,7 @@ def _get_default_config( def _fallback_schedule(cfg, wkl): - simd_width = get_fp32_len() + simd_width = get_simd_32bit_lanes() pad_left, pad_right = wkl.padl, wkl.padr stride_w = wkl.stride_w out_width = (wkl.width + pad_left + pad_right - wkl.kernel_w) // stride_w + 1 diff --git a/python/tvm/topi/x86/sparse.py b/python/tvm/topi/x86/sparse.py index 48ec233fa4bb..8a2cb0b69475 100644 --- a/python/tvm/topi/x86/sparse.py +++ b/python/tvm/topi/x86/sparse.py @@ -21,7 +21,7 @@ from ..transform import reshape from ..utils import traverse_inline, get_const_int -from .utils import get_fp32_len +from .utils import get_simd_32bit_lanes def schedule_sparse_dense(outs): @@ -29,7 +29,7 @@ def schedule_sparse_dense(outs): s = te.create_schedule([x.op for x in outs]) def _callback(op): - simd_width = get_fp32_len() + simd_width = get_simd_32bit_lanes() if op.tag == "sparse_dense_sp_lhs_csrmm" or op.tag == "sparse_dense_sp_lhs_csrmm": (y_o, y_i) = s[op].split(s[op].op.axis[1], 2) fused = s[op].fuse(s[op].op.axis[0], y_o) diff --git a/python/tvm/topi/x86/tensor_intrin.py b/python/tvm/topi/x86/tensor_intrin.py index 818765dc0b27..727319c95c5c 100644 --- a/python/tvm/topi/x86/tensor_intrin.py +++ b/python/tvm/topi/x86/tensor_intrin.py @@ -19,20 +19,19 @@ import tvm from tvm import te import tvm.target.codegen +from .utils import target_has_sse42, target_has_vnni, get_simd_32bit_lanes def dot_16x1x16_uint8_int8_int32(): """Dispatch the most optimized intrin depending on the target""" mcpu = tvm.target.Target.current().mcpu - assert mcpu in ( - "skylake-avx512", - "cascadelake", - ), "An old Intel machine that does not have fast Int8 support." - if mcpu == "skylake-avx512": - return dot_16x1x16_uint8_int8_int32_skylake() - # cascadelake - return dot_16x1x16_uint8_int8_int32_cascadelake() + assert target_has_sse42(mcpu), "An old Intel machine that does not have fast Int8 support." + if target_has_vnni(mcpu): + # VNNI capable platform + return dot_16x1x16_uint8_int8_int32_cascadelake() + # vpmaddubsw/vpmaddwd fallback + return dot_16x1x16_uint8_int8_int32_skylake() def dot_16x1x16_uint8_int8_int32_skylake(): @@ -64,7 +63,7 @@ def dot_16x1x16_uint8_int8_int32_skylake(): The Skylake int8 TensorIntrin that can be used in tensorizing schedule """ - int32_lanes = 16 # 16 int32 lanes in AVX512 + int32_lanes = get_simd_32bit_lanes() num_int8_elements = 4 # 4 int8 elements in int32 data = te.placeholder((num_int8_elements,), dtype="uint8", name="data") kernel = te.placeholder((int32_lanes, num_int8_elements), dtype="int8", name="kernel") @@ -84,27 +83,50 @@ def dot_16x1x16_uint8_int8_int32_skylake(): def _intrin_func(ins, outs): def _instr(index): + # int_lx32 - output datatype after pmaddubs - 16 bits to number of lanes + # int_8xl - input datatype to pmaddubs - 8 bits to number of lanes + # int_32xl - output datatype after pmaddw - 32 bits per number of lanes + + if int32_lanes == 4: + int_lx32 = "int16x8" + int_8xl = "int8x16" + int_32xl = "int32x4" + pmaddubs = "llvm.x86.ssse3.pmadd.ub.sw.128" + pmaddw = "llvm.x86.sse2.pmadd.wd" + elif int32_lanes == 8: + int_lx32 = "int16x16" + int_8xl = "int8x32" + int_32xl = "int32x8" + pmaddubs = "llvm.x86.avx2.pmadd.ub.sw" + pmaddw = "llvm.x86.avx2.pmadd.wd" + elif int32_lanes == 16: + int_lx32 = "int16x32" + int_8xl = "int8x64" + int_32xl = "int32x16" + pmaddubs = "llvm.x86.avx512.pmaddubs.w.512" + pmaddw = "llvm.x86.avx512.pmaddw.d.512" + ib = tvm.tir.ir_builder.create() if index == 1: - ib.emit(outs[0].vstore(0, tvm.tir.const(0, "int32x16"))) + ib.emit(outs[0].vstore(0, tvm.tir.const(0, int_32xl))) return ib.get() a_int8 = ins[0].vload([0], "uint8x4") re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_int8) - vec_ai32 = re_int32.astype("int32x16") - vec_a = tvm.tir.call_intrin("int8x64", "tir.reinterpret", vec_ai32) - vec_b = ins[1].vload([0, 0], "int8x64") - vec_one = tvm.tir.const(1, "int16x32") + vec_ai32 = re_int32.astype(int_32xl) + vec_a = tvm.tir.call_intrin(int_8xl, "tir.reinterpret", vec_ai32) + vec_b = ins[1].vload([0, 0], int_8xl) + vec_one = tvm.tir.const(1, int_lx32) pair_reduction = tvm.tir.call_llvm_pure_intrin( - "int16x32", - "llvm.x86.avx512.pmaddubs.w.512", + int_lx32, + pmaddubs, tvm.tir.const(0, "uint32"), vec_a, vec_b, ) quad_reduction = tvm.tir.call_llvm_pure_intrin( - "int32x16", - "llvm.x86.avx512.pmaddw.d.512", + int_32xl, + pmaddw, tvm.tir.const(0, "uint32"), pair_reduction, vec_one, @@ -112,7 +134,7 @@ def _instr(index): if index == 0: ib.emit(outs[0].vstore(0, quad_reduction)) else: - ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], "int32x16"))) + ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], int_32xl))) return ib.get() # body, reset, update diff --git a/python/tvm/topi/x86/utils.py b/python/tvm/topi/x86/utils.py index 92c11a7f1ef1..658a92966257 100644 --- a/python/tvm/topi/x86/utils.py +++ b/python/tvm/topi/x86/utils.py @@ -18,9 +18,95 @@ import tvm -def get_fp32_len(): +def target_has_sse42(target): + return ( + target_has_avx(target) + or target_has_avx2(target) + or target_has_avx512(target) + or target_has_vnni(target) + or target + in { + "silvermont", + "slm", + "goldmont", + "goldmont-plus", + "tremont", + "nehalem", + "corei7", + "westmere", + "bdver1", + "bdver2", + "bdver3", + "x86-64-v2", + } + ) + + +def target_has_avx(target): + return ( + target_has_avx2(target) + or target_has_avx512(target) + or target_has_vnni(target) + or target in {"sandybridge", "corei7-avx", "ivybridge", "core-avx-i"} + ) + + +def target_has_avx2(target): + return ( + target_has_avx512(target) + or target_has_vnni(target) + or target + in { + "haswell", + "core-avx2", + "broadwell", + "skylake", + "bdver4", + "znver1", + "znver2", + "znver3", + "x86-64-v3", + } + ) + + +def target_has_avx512(target): + return target in { + "skylake-avx512", + "skx", + "knl", + "knm", + "x86-64-v4", + "cannonlake", + # explicit enumeration of VNNI capable due to collision with alderlake + "cascadelake", + "icelake-client", + "rocketlake", + "icelake", + "tigerlake", + "cooperlake", + "sapphirerapids", + } + + +def target_has_vnni(target): + return target in { + "cascadelake", + "icelake-client", + "rocketlake", + "icelake", + "tigerlake", + "cooperlake", + "sapphirerapids", + "alderlake", + } + + +def get_simd_32bit_lanes(): mcpu = tvm.target.Target.current().mcpu - fp32_vec_len = 8 - if mcpu in ("skylake-avx512", "cascadelake"): + fp32_vec_len = 4 + if target_has_avx512(mcpu): fp32_vec_len = 16 + elif target_has_avx2(mcpu): + fp32_vec_len = 8 return fp32_vec_len diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 0ae88fce5b8c..19c12c612ee5 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1751,7 +1751,7 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes): return assembly def _has_fast_int8_instructions(asm, target): - if "skylake-avx512" in target: + if "nehalem" in target or "core-avx2" in target or "skylake-avx512" in target: return "pmaddubs" in asm elif "cascadelake" in target: return "vpdpbusd" in asm @@ -1761,8 +1761,13 @@ def _has_fast_int8_instructions(asm, target): # TODO(@anijain2305, @icemelon9): disable conv2d_int8 for NHWC data layout. # Re-enable this after adding conv2d_NCHWc_int8 support for NHWC. - # compile conv2d for x86 (skylake, cascadelake) and test assembly contains *pmadd* instructions - targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"] + # compile conv2d for x86 (SSE3/AVX2/AVX512/VNNI capable) and test assembly contains *pmadd* instructions + targets = [ + "llvm -mcpu=nehalem", + "llvm -mcpu=core-avx2", + "llvm -mcpu=skylake-avx512", + "llvm -mcpu=cascadelake", + ] llvm_version = tvm.target.codegen.llvm_version_major() for target in targets: if tvm.testing.device_enabled(target) and llvm_version >= 8: @@ -1838,7 +1843,7 @@ def _has_fast_int8_instructions(asm, target): # Check that a vectorized instruction is generated for older Intel # generations, because we default to NCHWc layout. - target = "llvm -mcpu=core-avx2" + target = "llvm -mcpu=x86-64" if tvm.testing.device_enabled(target): fast_int8_dtypes = ("uint8", "int8", "int32") asm = _compile( @@ -1850,7 +1855,7 @@ def _has_fast_int8_instructions(asm, target): dtypes=fast_int8_dtypes, ) # Check that vector int mult and add instructions are generated. - assert "vpmulld" in asm and "vpadd" in asm + assert "pmulhw" in asm and "paddd" in asm @tvm.testing.uses_gpu diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index cbb9285c31d1..ef5824c957e8 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -838,7 +838,7 @@ def before(): from tvm import topi def alter_conv2d(attrs, inputs, tinfos, out_type): - with tvm.target.Target("llvm"): + with tvm.target.Target("llvm -mcpu=core-avx2"): return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type) def expected(): @@ -1322,7 +1322,7 @@ def expected(): y = relay.Function(analysis.free_vars(y), y) return y - target = "llvm" + target = "llvm -mcpu=core-avx2" with tvm.target.Target(target): with TempOpAttr( "nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout @@ -1390,7 +1390,7 @@ def expected(): ) return relay.Function(analysis.free_vars(dense), dense) - with tvm.target.Target("llvm"): + with tvm.target.Target("llvm -mcpu=core-avx2"): with TempOpAttr( "nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout ):