diff --git a/cmake/modules/LLVM.cmake b/cmake/modules/LLVM.cmake index 6c21356ae880..6fb74fc1ef6c 100644 --- a/cmake/modules/LLVM.cmake +++ b/cmake/modules/LLVM.cmake @@ -29,6 +29,9 @@ add_definitions(-DDMLC_USE_FOPEN64=0 -DNDEBUG=1) # It may be a boolean or a string if(NOT ${USE_LLVM} MATCHES ${IS_FALSE_PATTERN}) find_llvm(${USE_LLVM}) + if (${TVM_LLVM_VERSION} LESS 60) + message(FATAL_ERROR "LLVM version 6.0 or greater is required.") + endif() include_directories(SYSTEM ${LLVM_INCLUDE_DIRS}) add_definitions(${LLVM_DEFINITIONS}) message(STATUS "Build with LLVM " ${LLVM_PACKAGE_VERSION}) diff --git a/python/tvm/target/codegen.py b/python/tvm/target/codegen.py index 5d43f4ae24ab..1a2efd4efaff 100644 --- a/python/tvm/target/codegen.py +++ b/python/tvm/target/codegen.py @@ -17,6 +17,7 @@ """Code generation related functions.""" from . import _ffi_api from .target import Target +from ..ir.container import Array def build_module(mod, target): @@ -39,6 +40,30 @@ def build_module(mod, target): return _ffi_api.Build(mod, target) +def target_has_features(cpu_features, target=None): + """Check CPU features for the target's `-mtriple` and `-mcpu` and `-mattr`. + + Parameters + ---------- + target : Target + The TVM target. + cpu_features : str or Array + CPU Feature(s) to check. + + Returns + ------- + has_features : bool + True if target has the feature(s). + """ + assert isinstance(target, Target) or target is None + assert isinstance(cpu_features, (Array, list, tuple, str)) + has_feats = True + cpu_features = [cpu_features] if isinstance(cpu_features, str) else cpu_features + for feat in cpu_features: + has_feats &= _ffi_api.target_has_feature(feat, target) + return has_feats + + def llvm_lookup_intrinsic_id(name): """Lookup LLVM intrinsic id by name. @@ -71,36 +96,76 @@ def llvm_get_intrinsic_name(intrin_id: int) -> str: return _ffi_api.llvm_get_intrinsic_name(intrin_id) -def llvm_x86_get_archlist(only64bit=False): - """Get X86 CPU name list. +def llvm_get_targets(): + """Get LLVM target list. + + Parameters + ---------- + + Returns + ------- + llvm_targets : list[str] + List of available LLVM targets. + """ + return _ffi_api.llvm_get_targets() + + +def llvm_get_cpu_archlist(target=None): + """Get CPU architectures for the target's `-mtriple`. + + Parameters + ---------- + target : Target + The TVM target. + + Returns + ------- + cpu_archlist : list[str] + List of available CPU architectures. + """ + assert isinstance(target, Target) or target is None + return _ffi_api.llvm_get_cpu_archlist(target) + + +def llvm_get_cpu_features(target=None): + """Get CPU features for the target's `-mtriple` and `-mcpu` and considering `-mattr`. Parameters ---------- - only64bit : bool - Filter 64bit architectures. + target : Target + The TVM target. Returns ------- - features : list[str] - String list of X86 architectures. + cpu_features : list[str] + List of available CPU features. """ - return _ffi_api.llvm_x86_get_archlist(only64bit) + assert isinstance(target, Target) or target is None + return _ffi_api.llvm_get_cpu_features(target) -def llvm_x86_get_features(cpu_name): - """Get X86 CPU features. +def llvm_cpu_has_features(cpu_features, target=None): + """Check CPU features for the target's `-mtriple` and `-mcpu` and considering `-mattr`. Parameters ---------- - cpu_name : string - X86 CPU name (e.g. "skylake"). + target : Target + The TVM target. + cpu_features : str or Array + CPU Feature(s) to check. Returns ------- - features : list[str] - String list of X86 CPU features. + has_features : bool + True if target CPU has the feature(s). """ - return _ffi_api.llvm_x86_get_features(cpu_name) + assert isinstance(target, Target) or target is None + assert isinstance(cpu_features, (Array, list, tuple, str)) + has_feats = True + cpu_features = [cpu_features] if isinstance(cpu_features, str) else cpu_features + for feat in cpu_features: + has_feats &= _ffi_api.llvm_cpu_has_feature(feat, target) + return has_feats def llvm_version_major(allow_none=False): diff --git a/python/tvm/target/x86.py b/python/tvm/target/x86.py index a3dcb62e8aa7..c040eface808 100644 --- a/python/tvm/target/x86.py +++ b/python/tvm/target/x86.py @@ -16,30 +16,7 @@ # under the License. """Common x86 related utilities""" from .._ffi import register_func -from . import _ffi_api -from ..ir.container import Array - - -@register_func("tvm.target.x86.target_has_features") -def target_has_features(features, target=None): - """Check X86 CPU features. - Parameters - ---------- - features : str or Array - Feature(s) to check. - target : Target - Optional TVM target, default `None` use the global context target. - Returns - ------- - has_feats : bool - True if feature(s) are in the target arch. - """ - has_feats = True - assert isinstance(features, (Array, str)) - features = [features] if isinstance(features, str) else features - for feat in features: - has_feats &= _ffi_api.llvm_x86_has_feature(feat, target) - return has_feats +from .codegen import target_has_features @register_func("tvm.topi.x86.utils.get_simd_32bit_lanes") @@ -53,9 +30,6 @@ def get_simd_32bit_lanes(): The optimal vector length of CPU from the global context target. """ vec_len = 4 - # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - # + llvm.x86.avx512.pmaddw.d.512" if target_has_features(["avx512bw", "avx512f"]): vec_len = 16 elif target_has_features("avx2"): diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index 85accab87b2a..e10313323089 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -21,7 +21,7 @@ from tvm import autotvm, te from tvm.autotvm.task.space import SplitEntity from tvm.contrib import cblas, mkl -from tvm.target.x86 import target_has_features +from tvm.target.codegen import target_has_features from .. import generic, nn from ..transform import layout_transform @@ -38,9 +38,6 @@ def batch_matmul_int8_compute(cfg, x, y, *_): packed_y = layout_transform(y, "BNK", packed_y_layout) _, n_o, _, n_i, _ = packed_y.shape ak = te.reduce_axis((0, k), name="k") - # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - # + llvm.x86.avx512.pmaddw.d.512" if target_has_features(["avx512bw", "avx512f"]): attrs_info = {"schedule_rule": "batch_matmul_int8"} else: @@ -241,9 +238,6 @@ def _callback(op): layout_trans = op.input_tensors[1] if target_has_features("amx-int8"): batch_matmul_amx_schedule(cfg, s, op.output(0), outs[0], layout_trans) - # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - # + llvm.x86.avx512.pmaddw.d.512" elif target_has_features(["avx512bw", "avx512f"]): batch_matmul_int8_schedule(cfg, s, op.output(0), outs[0], layout_trans) diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 2437b1a69564..4151ea0b7006 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -23,7 +23,8 @@ from tvm import autotvm, te from tvm.autotvm.task.space import SplitEntity from tvm.contrib import cblas, dnnl, mkl -from tvm.target.x86 import get_simd_32bit_lanes, target_has_features +from tvm.target.x86 import get_simd_32bit_lanes +from tvm.target.codegen import target_has_features from .. import generic, tag from ..utils import get_const_tuple, traverse_inline @@ -303,9 +304,6 @@ def _callback(op): if "dense_int8" in op.tag: if target_has_features("amx-int8"): dense_amx_int8_schedule(cfg, s, op.output(0), outs[0]) - # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - # + llvm.x86.avx512.pmaddw.d.512" elif target_has_features(["avx512bw", "avx512f"]): dense_int8_schedule(cfg, s, op.output(0), outs[0]) @@ -318,9 +316,6 @@ def dense_int8_compute(cfg, X, packed_w, bias=None): m, k = X.shape n_o, _, n_i, _ = packed_w.shape ak = te.reduce_axis((0, k), name="k") - # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - # + llvm.x86.avx512.pmaddw.d.512" if target_has_features(["avx512bw", "avx512f"]): target_attr = {"schedule_rule": "meta_schedule.x86.dense_int8"} else: diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index ef6df7dd2c9b..0e9b1f7b65f0 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -19,7 +19,7 @@ import tvm from tvm import autotvm, relay, te -from tvm.target.x86 import target_has_features +from tvm.target.codegen import target_has_features from .. import nn from ..nn import dense_alter_layout @@ -28,9 +28,6 @@ def check_int8_applicable(x, y, allow_padding=False): - # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - # + llvm.x86.avx512.pmaddw.d.512" simd_avai = target_has_features(["avx512bw", "avx512f"]) simd_avai |= target_has_features("amx-int8") # TODO(vvchernov): may be also target_has_features("avx2") or lower? diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 4657f962f32c..73df303f725c 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -24,20 +24,17 @@ namespace meta_schedule { String GetRuleKindFromTarget(const Target& target) { if (target->kind->name == "llvm") { - static const PackedFunc* llvm_x86_has_feature_fn_ptr = - runtime::Registry::Get("target.llvm_x86_has_feature"); - ICHECK(llvm_x86_has_feature_fn_ptr != nullptr) - << "The `target.llvm_x86_has_feature` func is not in tvm registry."; - bool have_avx512vnni = (*llvm_x86_has_feature_fn_ptr)("avx512vnni", target); - bool have_avxvnni = (*llvm_x86_has_feature_fn_ptr)("avxvnni", target); + static const PackedFunc* target_has_feature_fn_ptr = + runtime::Registry::Get("target.target_has_feature"); + ICHECK(target_has_feature_fn_ptr != nullptr) + << "The `target.target_has_feature` func is not in tvm registry."; + bool have_avx512vnni = (*target_has_feature_fn_ptr)("avx512vnni", target); + bool have_avxvnni = (*target_has_feature_fn_ptr)("avxvnni", target); if (have_avx512vnni || have_avxvnni) { return "vnni"; } else { - // avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - // avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - // + llvm.x86.avx512.pmaddw.d.512" - bool have_avx512f = (*llvm_x86_has_feature_fn_ptr)("avx512f", target); - bool have_avx512bw = (*llvm_x86_has_feature_fn_ptr)("avx512bw", target); + bool have_avx512f = (*target_has_feature_fn_ptr)("avx512f", target); + bool have_avx512bw = (*target_has_feature_fn_ptr)("avx512bw", target); if (have_avx512bw && have_avx512f) { return "avx512"; } diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index b57710b26686..2dd74e1321bf 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -121,9 +121,9 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs, } bool has_current_target_sse41_support() { - auto llvm_x86_has_feature_fn_ptr = tvm::runtime::Registry::Get("target.llvm_x86_has_feature"); - ICHECK(llvm_x86_has_feature_fn_ptr) << "Function target.llvm_x86_has_feature not found"; - return (*llvm_x86_has_feature_fn_ptr)("sse4.1", Target::Current(true)); + auto target_has_feature_fn_ptr = tvm::runtime::Registry::Get("target.target_has_feature"); + ICHECK(target_has_feature_fn_ptr) << "Function target.target_has_feature not found"; + return (*target_has_feature_fn_ptr)("sse4.1", Target::Current(true)); } /* diff --git a/src/relay/qnn/op/requantize_config.h b/src/relay/qnn/op/requantize_config.h index 956bc3533b81..a4238fa498c6 100644 --- a/src/relay/qnn/op/requantize_config.h +++ b/src/relay/qnn/op/requantize_config.h @@ -61,10 +61,10 @@ class RequantizeConfigNode : public Object { // For the x86 architecture, the float32 computation is expected to give significant speedup, // with little loss in the accuracy of the requantize operation. auto target = Target::Current(true); - auto llvm_x86_has_feature_fn_ptr = tvm::runtime::Registry::Get("target.llvm_x86_has_feature"); - ICHECK(llvm_x86_has_feature_fn_ptr) << "Function target.llvm_x86_has_feature not found"; + auto target_has_feature_fn_ptr = tvm::runtime::Registry::Get("target.target_has_feature"); + ICHECK(target_has_feature_fn_ptr) << "Function target.target_has_feature not found"; if (target.defined() && target->kind->name == "llvm") { - if ((*llvm_x86_has_feature_fn_ptr)("sse4.1", target)) { + if ((*target_has_feature_fn_ptr)("sse4.1", target)) { return "float32"; } } diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index efe15c5c4aac..1872d64d71c5 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -29,9 +29,7 @@ #if TVM_LLVM_VERSION >= 100 #include #endif -#include #include -#include #include #include @@ -43,38 +41,6 @@ namespace tvm { namespace codegen { -namespace { -bool TargetHasFeature(const llvm::TargetMachine& tm, const std::string& feature) { - // MCSubTargetInfo::checkFeatures was added in LLVM 6.0 -#if TVM_LLVM_VERSION >= 60 - const auto* MCInfo = tm.getMCSubtargetInfo(); - return MCInfo->checkFeatures(std::string("+") + feature); -#else - return false; - // TODO(tulloch) - enable this block, need to figure out how to reimplement - // this given visibility constraints, similar to - // https://github.com/rust-lang/rust/pull/31709 - - // Copied from - // https://github.com/llvm-mirror/llvm/blob/5136df4/lib/MC/MCSubtargetInfo.cpp#L78-L88. - - // auto checkFeatures = [&](const std::string FS) { - // llvm::SubtargetFeatures T(FS); - // llvm::FeatureBitset Set, All; - // for (std::string F : T.getFeatures()) { - // llvm::SubtargetFeatures::ApplyFeatureFlag(Set, F, MCInfo->ProcFeatures); - // if (F[0] == '-') { - // F[0] = '+'; - // } - // llvm::SubtargetFeatures::ApplyFeatureFlag(All, F, MCInfo->ProcFeatures); - // } - // return (MCInfo->getFeatureBits() & All) == Set; - // }; - // return checkFeatures(MCInfo, std::string("+") + feature); -#endif -} -} // namespace - class CodeGenX86_64 final : public CodeGenCPU { public: llvm::Value* VisitExpr_(const CastNode* op) override; @@ -92,9 +58,8 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { const auto to = op->dtype; if (from.is_float() && to.is_float() && from.bits() == 16 && to.bits() == 32) { ICHECK_EQ(from.lanes(), to.lanes()); - llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine(); - const auto has_avx512 = TargetHasFeature(*tm, "avx512f"); + const auto has_avx512 = llvm_target_->TargetHasCPUFeature("avx512f"); if (from.lanes() >= 16 && has_avx512) { return CallVectorIntrin( @@ -111,7 +76,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { #if TVM_LLVM_VERSION <= 100 // The intrinsic x86_vcvtph2ps_256 was removed in LLVM 11. - const auto has_f16c = TargetHasFeature(*tm, "f16c"); + const auto has_f16c = llvm_target_->TargetHasCPUFeature("f16c"); if (from.lanes() >= 8 && has_f16c) { return CallVectorIntrin(llvm::Intrinsic::x86_vcvtph2ps_256, 8, diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc index 2aa190ad708e..e270a9b66cb6 100644 --- a/src/target/llvm/llvm_instance.cc +++ b/src/target/llvm/llvm_instance.cc @@ -33,6 +33,7 @@ #include #include #include +#include #if TVM_LLVM_VERSION >= 140 #include #else @@ -66,6 +67,27 @@ #include #include +#if TVM_LLVM_VERSION < 180 +namespace llvm { +#if TVM_LLVM_VERSION < 170 +// SubtargetSubTypeKV view +template MCSubtargetInfo::*Member> +struct ArchViewer { + friend ArrayRef& archViewer(MCSubtargetInfo Obj) { return Obj.*Member; } +}; +template struct ArchViewer<&MCSubtargetInfo::ProcDesc>; +ArrayRef& archViewer(MCSubtargetInfo); +#endif +// SubtargetFeatureKV view +template MCSubtargetInfo::*Member> +struct FeatViewer { + friend ArrayRef& featViewer(MCSubtargetInfo Obj) { return Obj.*Member; } +}; +template struct FeatViewer<&MCSubtargetInfo::ProcFeatures>; +ArrayRef& featViewer(MCSubtargetInfo); +} // namespace llvm +#endif + namespace tvm { namespace codegen { @@ -175,6 +197,17 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) { attrs_.push_back(s); } } + // llvm module target + if (target->kind->name == "llvm") { + // legalize -mcpu with the target -mtriple + auto arches = GetAllLLVMTargetArches(); + bool has_arch = + std::any_of(arches.begin(), arches.end(), [&](const auto& var) { return var == cpu_; }); + if (!has_arch) { + LOG(FATAL) << "LLVM cpu architecture `-mcpu=" << cpu_ + << "` is not valid in `-mtriple=" << triple_ << "`"; + } + } if (const Optional>& v = target->GetAttr>("cl-opt")) { llvm::StringMap& options = llvm::cl::getRegisteredOptions(); @@ -288,19 +321,59 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& scope, const std::string& target_st LLVMTargetInfo::~LLVMTargetInfo() = default; +static const llvm::Target* CreateLLVMTargetInstance(const std::string triple, + const bool allow_missing = true) { + std::string error; + // create LLVM instance + // required mimimum: llvm::InitializeAllTargets() + const llvm::Target* llvm_instance = llvm::TargetRegistry::lookupTarget(triple, error); + if (!allow_missing && !llvm_instance) { + ICHECK(llvm_instance) << "LLVM instance error: `" << error << "`"; + } + + return llvm_instance; +} + +static llvm::TargetMachine* CreateLLVMTargetMachine( + const llvm::Target* llvm_instance, const std::string& triple, const std::string& cpu, + const std::string& features, const llvm::TargetOptions& target_options, + const llvm::Reloc::Model& reloc_model, const llvm::CodeModel::Model& code_model, + const llvm::CodeGenOpt::Level& opt_level) { + llvm::TargetMachine* tm = llvm_instance->createTargetMachine( + triple, cpu, features, target_options, reloc_model, code_model, opt_level); + ICHECK(tm != nullptr); + + return tm; +} + +static const llvm::MCSubtargetInfo* GetLLVMSubtargetInfo(const std::string& triple, + const std::string& cpu_name, + const std::string& feats) { + // create a LLVM instance + auto llvm_instance = CreateLLVMTargetInstance(triple, true); + // create a target machine + // required minimum: llvm::InitializeAllTargetMCs() + llvm::TargetOptions target_options; + auto tm = CreateLLVMTargetMachine(llvm_instance, triple, cpu_name, feats, target_options, + llvm::Reloc::Static, llvm::CodeModel::Small, + llvm::CodeGenOpt::Level(0)); + // create subtarget info module + const llvm::MCSubtargetInfo* MCInfo = tm->getMCSubtargetInfo(); + + return MCInfo; +} + llvm::TargetMachine* LLVMTargetInfo::GetOrCreateTargetMachine(bool allow_missing) { if (target_machine_) return target_machine_.get(); std::string error; - if (const llvm::Target* llvm_instance = llvm::TargetRegistry::lookupTarget(triple_, error)) { + if (const llvm::Target* llvm_instance = CreateLLVMTargetInstance(triple_, allow_missing)) { llvm::TargetMachine* tm = - llvm_instance->createTargetMachine(triple_, cpu_, GetTargetFeatureString(), target_options_, - reloc_model_, code_model_, opt_level_); + CreateLLVMTargetMachine(llvm_instance, triple_, cpu_, GetTargetFeatureString(), + target_options_, reloc_model_, code_model_, opt_level_); target_machine_ = std::unique_ptr(tm); } - if (!allow_missing) { - ICHECK(target_machine_ != nullptr) << error; - } + ICHECK(target_machine_ != nullptr); return target_machine_.get(); } @@ -662,6 +735,75 @@ void LLVMTargetInfo::GetOptionValue(LLVMTargetInfo::Option* opt) const { } } +const Array LLVMTargetInfo::GetAllLLVMTargets() const { + Array llvm_targets; + // iterate all archtypes + for (auto a = llvm::Triple::ArchType(llvm::Triple::ArchType::UnknownArch + 1); + a < llvm::Triple::ArchType::LastArchType; a = llvm::Triple::ArchType(a + 1)) { + std::string target_name = llvm::Triple::getArchTypeName(a).str(); + // get valid target + if (CreateLLVMTargetInstance(target_name + "--", true)) { + llvm_targets.push_back(target_name); + } + } + + return llvm_targets; +} + +const Array LLVMTargetInfo::GetAllLLVMTargetArches() const { + Array cpu_arches; + // get the subtarget info module + const auto MCInfo = GetLLVMSubtargetInfo(triple_, "", ""); + if (!MCInfo) { + return cpu_arches; + } + // get all arches + llvm::ArrayRef llvm_arches = +#if TVM_LLVM_VERSION < 170 + llvm::archViewer(*(llvm::MCSubtargetInfo*)MCInfo); +#else + MCInfo->getAllProcessorDescriptions(); +#endif + for (const auto& arch : llvm_arches) { + cpu_arches.push_back(arch.Key); + } + + return cpu_arches; +} + +const Array LLVMTargetInfo::GetAllLLVMCpuFeatures() const { + std::string feats = ""; + for (const auto& attr : attrs_) { + feats += feats.empty() ? attr : ("," + attr); + } + // get the subtarget info module + const auto MCInfo = GetLLVMSubtargetInfo(triple_, cpu_.c_str(), feats); + // get all features for CPU + llvm::ArrayRef llvm_features = +#if TVM_LLVM_VERSION < 180 + llvm::featViewer(*(llvm::MCSubtargetInfo*)MCInfo); +#else + MCInfo->getAllProcessorFeatures(); +#endif + Array cpu_features; + for (const auto& feat : llvm_features) { + if (MCInfo->checkFeatures("+" + std::string(feat.Key))) { + cpu_features.push_back(feat.Key); + } + } + + return cpu_features; +} + +const bool LLVMTargetInfo::TargetHasCPUFeature(const std::string& feature) const { + // lookup features for `-mcpu` + auto feats = GetAllLLVMCpuFeatures(); + bool has_feature = + std::any_of(feats.begin(), feats.end(), [&](const auto& var) { return var == feature; }); + + return has_feature; +} + // LLVMTarget bool LLVMTarget::modified_llvm_state_ = false; diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h index 217db63aad7a..ac08008b8021 100644 --- a/src/target/llvm/llvm_instance.h +++ b/src/target/llvm/llvm_instance.h @@ -266,6 +266,36 @@ class LLVMTargetInfo { */ bool MatchesGlobalState() const; + /*! + * \brief Get all supported targets from the LLVM backend + * \return list with all valid targets + */ + const Array GetAllLLVMTargets() const; + + /*! + * \brief Get all CPU arches from target + * \return list with all valid cpu architectures + * \note The arches are fetched from the LLVM backend using the target `-mtriple`. + */ + const Array GetAllLLVMTargetArches() const; + + /*! + * \brief Get all CPU features from target + * \return list with all valid cpu features + * \note The features are fetched from the LLVM backend using the target `-mtriple` + * and the `-mcpu` architecture, but also consider the `-mattr` attributes. + */ + const Array GetAllLLVMCpuFeatures() const; + + /*! + * \brief Check the target if has a specific cpu feature + * \param feature string with the feature to check + * \return true or false + * \note The feature is checked in the LLVM backend for the target `-mtriple` + * and `-mcpu` architecture, but also consider the `-mattr` attributes. + */ + const bool TargetHasCPUFeature(const std::string& feature) const; + protected: /*! * \brief Get the current value of given LLVM option diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 168163c416cf..05a7df230f61 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -45,16 +45,6 @@ #include #include #include -#if TVM_LLVM_VERSION < 110 -#include -#include -#else -#if TVM_LLVM_VERSION < 170 -#include -#else -#include -#endif -#endif #include #include #include @@ -87,25 +77,6 @@ #include "codegen_llvm.h" #include "llvm_instance.h" -#if TVM_LLVM_VERSION < 110 -namespace llvm { -// SubtargetSubTypeKV view -template MCSubtargetInfo::*Member> -struct ArchViewer { - friend ArrayRef& archViewer(MCSubtargetInfo Obj) { return Obj.*Member; } -}; -template struct ArchViewer<&MCSubtargetInfo::ProcDesc>; -ArrayRef& archViewer(MCSubtargetInfo); -// SubtargetFeatureKV view -template MCSubtargetInfo::*Member> -struct FeatViewer { - friend ArrayRef& featViewer(MCSubtargetInfo Obj) { return Obj.*Member; } -}; -template struct FeatViewer<&MCSubtargetInfo::ProcFeatures>; -ArrayRef& featViewer(MCSubtargetInfo); -} // namespace llvm -#endif - namespace tvm { namespace codegen { @@ -514,131 +485,69 @@ TVM_REGISTER_GLOBAL("target.llvm_get_intrinsic_name").set_body_typed([](int64_t #endif }); -#if TVM_LLVM_VERSION < 110 -static const llvm::MCSubtargetInfo* llvm_compat_get_subtargetinfo(const std::string triple, - const std::string cpu_name) { - std::string error; - llvm::InitializeAllTargets(); - llvm::InitializeAllTargetMCs(); - // create a LLVM x86 instance - auto* llvm_instance = llvm::TargetRegistry::lookupTarget(triple, error); - // create a target machine - llvm::TargetOptions target_options; - auto RM = llvm::Optional(); - auto* tm = llvm_instance->createTargetMachine(triple, cpu_name.c_str(), "", target_options, RM); - // create subtarget info module - const llvm::MCSubtargetInfo* MCInfo = tm->getMCSubtargetInfo(); - - return MCInfo; -} - -static const Array llvm_compat_get_archlist(const std::string triple) { - // get the subtarget info module - const auto* MCInfo = llvm_compat_get_subtargetinfo(triple, ""); - // get all X86 arches - llvm::ArrayRef x86_arches = - llvm::archViewer(*(llvm::MCSubtargetInfo*)MCInfo); - Array cpu_arches; - for (auto& arch : x86_arches) { - cpu_arches.push_back(arch.Key); - } - return cpu_arches; -} - -static const Array llvm_compat_get_features(const std::string triple, - const std::string cpu_name) { - // get the subtarget info module - const auto* MCInfo = llvm_compat_get_subtargetinfo(triple, cpu_name.c_str()); - // get all features - llvm::ArrayRef x86_features = - llvm::featViewer(*(llvm::MCSubtargetInfo*)MCInfo); - // only targeted CPU features - Array cpu_features; - for (auto& feat : x86_features) { - if (MCInfo->checkFeatures("+" + std::string(feat.Key))) { - cpu_features.push_back(feat.Key); - } - } - return cpu_features; -} -#endif +TVM_REGISTER_GLOBAL("target.llvm_get_targets").set_body_typed([]() -> Array { + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, "llvm"); + return llvm_backend.GetAllLLVMTargets(); +}); -TVM_REGISTER_GLOBAL("target.llvm_x86_get_archlist") - .set_body_typed([](bool only64bit) -> Array { - Array cpu_arches; -#if TVM_LLVM_VERSION < 110 - cpu_arches = llvm_compat_get_archlist("x86_64--"); -#else - llvm::SmallVector x86_arches; - llvm::X86::fillValidCPUArchList(x86_arches, only64bit); - for (auto& arch : x86_arches) { - cpu_arches.push_back(arch.str()); +TVM_REGISTER_GLOBAL("target.llvm_get_cpu_archlist") + .set_body_typed([](const Target& target) -> Array { + auto use_target = target.defined() ? target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return Array{}; + } } -#endif - return cpu_arches; + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, use_target); + return llvm_backend.GetAllLLVMTargetArches(); }); -TVM_REGISTER_GLOBAL("target.llvm_x86_get_features") - .set_body_typed([](std::string cpu_name) -> Array { - Array cpu_features; -#if TVM_LLVM_VERSION < 110 - cpu_features = llvm_compat_get_features("x86_64--", cpu_name); -#else - llvm::SmallVector x86_features; - llvm::X86::getFeaturesForCPU(cpu_name, x86_features); - for (auto& feat : x86_features) { - cpu_features.push_back(feat.str()); +TVM_REGISTER_GLOBAL("target.llvm_get_cpu_features") + .set_body_typed([](const Target& target) -> Array { + auto use_target = target.defined() ? target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return Array{}; + } } -#endif - return cpu_features; + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, use_target); + return llvm_backend.GetAllLLVMCpuFeatures(); }); -TVM_REGISTER_GLOBAL("target.llvm_x86_has_feature") - .set_body_typed([](String feature, const Target& target) -> bool { - // target argument is optional (nullptr or None) - // if not explicit then use the current context target - Optional mcpu = target.defined() ? target->GetAttr("mcpu") - : Target::Current(false)->GetAttr("mcpu"); - Optional> mattr = target.defined() - ? target->GetAttr>("mattr") - : Target::Current(false)->GetAttr>("mattr"); - String name = target.defined() ? target->kind->name : Target::Current(false)->kind->name; - // lookup only for `llvm` targets having -mcpu - if ((name != "llvm") || !mcpu) { - return false; - } - // lookup in -mattr flags - bool is_in_mattr = - !mattr ? false - : std::any_of(mattr.value().begin(), mattr.value().end(), - [&](const String& var) { return var == ("+" + feature); }); -#if TVM_LLVM_VERSION < 110 - auto x86_arches = llvm_compat_get_archlist("x86_64--"); - // decline on invalid arch (avoid llvm assertion) - if (!std::any_of(x86_arches.begin(), x86_arches.end(), - [&](const String& var) { return var == mcpu.value(); })) { - return false; +TVM_REGISTER_GLOBAL("target.llvm_cpu_has_feature") + .set_body_typed([](const String feature, const Target& target) -> bool { + auto use_target = target.defined() ? target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return false; + } } - // lookup in -mcpu llvm architecture flags - auto cpu_features = llvm_compat_get_features("x86_64--", mcpu.value()); + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, use_target); + auto cpu_features = llvm_backend.GetAllLLVMCpuFeatures(); bool has_feature = std::any_of(cpu_features.begin(), cpu_features.end(), - [&](const String& var) { return var == feature; }); -#else - llvm::SmallVector x86_arches; - llvm::X86::fillValidCPUArchList(x86_arches, false); - // decline on invalid arch (avoid llvm assertion) - if (!std::any_of(x86_arches.begin(), x86_arches.end(), - [&](const llvm::StringRef& var) { return var == mcpu.value().c_str(); })) { - return false; + [&](auto& var) { return var == feature; }); + return has_feature; + }); + +TVM_REGISTER_GLOBAL("target.target_has_feature") + .set_body_typed([](const String feature, const Target& target) -> bool { + auto use_target = target.defined() ? target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return false; + } } - // lookup in -mcpu llvm architecture flags - llvm::SmallVector x86_features; - llvm::X86::getFeaturesForCPU(mcpu.value().c_str(), x86_features); - bool has_feature = - std::any_of(x86_features.begin(), x86_features.end(), - [&](const llvm::StringRef& var) { return var == feature.c_str(); }); -#endif - return has_feature || is_in_mattr; + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_target(*llvm_instance, use_target); + return llvm_target.TargetHasCPUFeature(feature); }); TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int { diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 0a0ae561ab73..bd984d32e6bd 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1837,7 +1837,10 @@ def test_depthwise_conv2d_int8(): wdata = np.random.rand(*kernel_shape) * 10 parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))} - targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"] + targets = [ + "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512", + "llvm -mtriple=x86_64-linux-gnu -mcpu=cascadelake", + ] llvm_version = tvm.target.codegen.llvm_version_major() for target in targets: if llvm_version >= 8: diff --git a/tests/python/relay/test_op_qnn_conv2_transpose.py b/tests/python/relay/test_op_qnn_conv2_transpose.py index ec273eb2f785..18ad68df9e6c 100644 --- a/tests/python/relay/test_op_qnn_conv2_transpose.py +++ b/tests/python/relay/test_op_qnn_conv2_transpose.py @@ -644,7 +644,7 @@ def test_broadcast_layout(): func = relay.Function(relay.analysis.free_vars(func), func) mod = tvm.IRModule.from_expr(func) with tvm.transform.PassContext(opt_level=3): - libs = relay.build(mod, "llvm -mcpu=skylake-avx512") + libs = relay.build(mod, "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512") def test_non_scalar_input_scale_zp(): diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index 3736350cbfe1..e7d2c8941b9e 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -948,7 +948,9 @@ def test_broadcast_layout(): func = relay.Function(relay.analysis.free_vars(func), func) mod = tvm.IRModule.from_expr(func) with tvm.transform.PassContext(opt_level=3): - graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512") + graph, lib, params = relay.build( + mod, "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512" + ) def test_depthwise_depth_multiplier(): diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 829c1d6ae43f..87065b2d2786 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -886,7 +886,7 @@ def before(): from tvm import topi def alter_conv2d(attrs, inputs, tinfos, out_type): - with tvm.target.Target("llvm -mcpu=core-avx2"): + with tvm.target.Target("llvm -mtriple=x86_64-linux-gnu -mcpu=core-avx2"): return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type) def expected(): @@ -1373,7 +1373,7 @@ def expected(): y = relay.Function(analysis.free_vars(y), y) return y - target = "llvm -mcpu=core-avx2" + target = "llvm -mtriple=x86_64-linux-gnu -mcpu=core-avx2" with tvm.target.Target(target): with TempOpAttr( "nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout @@ -1441,7 +1441,7 @@ def expected(): ) return relay.Function(analysis.free_vars(dense), dense) - with tvm.target.Target("llvm -mcpu=core-avx2"): + with tvm.target.Target("llvm -mtriple=x86_64-linux-gnu -mcpu=core-avx2"): with TempOpAttr( "nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout ): diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py index 73ba9c22082f..a32100ea206e 100644 --- a/tests/python/relay/test_pass_qnn_legalize.py +++ b/tests/python/relay/test_pass_qnn_legalize.py @@ -138,7 +138,10 @@ def _get_mod(data_dtype, kernel_dtype): # Check transformations for platforms with fast Int8 support. ############################################################# # Check that Intel AVX512 (with or w/o VNNI) gets picked up. - for target in ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]: + for target in [ + "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512", + "llvm -mtriple=x86_64-linux-gnu -mcpu=cascadelake", + ]: with tvm.target.Target(target): mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) @@ -170,7 +173,7 @@ def _get_mod(data_dtype, kernel_dtype): # Check transformations for platforms with fast Int8 support. ############################################################# # Check no transformation for Intel AVX512. - with tvm.target.Target("llvm -mcpu=skylake-avx512"): + with tvm.target.Target("llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512"): mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) assert tvm.ir.structural_equal(mod, legalized_mod) @@ -232,7 +235,10 @@ def _get_mod(data_dtype, kernel_dtype): # Check transformations for platforms with fast Int8 support. ############################################################# # Check that Intel AVX512 (with or w/o VNNI) gets picked up. - for target in ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]: + for target in [ + "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512", + "llvm -mtriple=x86_64-linux-gnu -mcpu=cascadelake", + ]: with tvm.target.Target(target): mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) @@ -264,7 +270,7 @@ def _get_mod(data_dtype, kernel_dtype): # Check transformations for platforms with fast Int8 support. ############################################################# # Check no transformation for Intel AVX512. - with tvm.target.Target("llvm -mcpu=skylake-avx512"): + with tvm.target.Target("llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512"): mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) assert tvm.ir.structural_equal(mod, legalized_mod) diff --git a/tests/python/target/test_llvm_features_info.py b/tests/python/target/test_llvm_features_info.py new file mode 100644 index 000000000000..1be71331dda8 --- /dev/null +++ b/tests/python/target/test_llvm_features_info.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm.target import _ffi_api, codegen, Target + +LLVM_VERSION = codegen.llvm_version_major() + + +def test_llvm_targets(): + + ## + ## check LLVM backend + ## + + # check blank results + assert len(codegen.llvm_get_targets()) + # check ffi vs python + assert str(codegen.llvm_get_targets()) == str(_ffi_api.llvm_get_targets()) + + # check LLVM target -mcpu legality + try: + tvm.target.codegen.llvm_get_cpu_features( + tvm.target.Target("llvm -mtriple=x86_64-linux-gnu -mcpu=dummy") + ) + assert False + except tvm.error.TVMError as e: + msg = str(e) + assert ( + msg.find( + "TVMError: LLVM cpu architecture `-mcpu=dummy` is not valid in `-mtriple=x86_64-linux-gnu`" + ) + != -1 + ) + + +min_llvm_version, llvm_target, cpu_arch, cpu_features, is_supported = tvm.testing.parameters( + (-1, "x86_64", "sandybridge", "sse4.1", True), + (-1, "x86_64", "ivybridge", ["sse4.1", "ssse3"], True), + (-1, "x86_64", "ivybridge", ["sse4.1", "ssse3", "avx512bw"], False), + # 32bit vs 64bit + (-1, "aarch64", "cortex-a55", "neon", True), + (-1, "aarch64", "cortex-a55", "dotprod", True), + (-1, "aarch64", "cortex-a55", "dsp", False), + (-1, "arm", "cortex-a55", "dsp", True), + (-1, "aarch64", "cortex-a55", ["neon", "dotprod"], True), + (-1, "aarch64", "cortex-a55", ["neon", "dotprod", "dsp"], False), + (-1, "arm", "cortex-a55", ["neon", "dotprod"], True), + (-1, "aarch64", "cortex-a55", ["neon", "dotprod", "dsp"], False), + (-1, "arm", "cortex-a55", ["neon", "dotprod", "dsp"], True), +) + + +def test_target_features(min_llvm_version, llvm_target, cpu_arch, cpu_features, is_supported): + + target = Target("llvm -mtriple=%s-- -mcpu=%s" % (llvm_target, cpu_arch)) + + ## + ## legalize llvm_target + ## + + assert llvm_target in codegen.llvm_get_targets() + + ## + ## legalize cpu_arch + ## + + ### with context + with target: + assert cpu_arch in codegen.llvm_get_cpu_archlist() + ### no context but with expicit target + assert cpu_arch in codegen.llvm_get_cpu_archlist(target) + # check ffi vs python + assert str(codegen.llvm_get_cpu_archlist(target)) == str(_ffi_api.llvm_get_cpu_archlist(target)) + + ## + ## check has_features + ## + + ### with context + with target: + assert codegen.llvm_cpu_has_features(cpu_features) == is_supported + ### no context but with expicit target + assert codegen.llvm_cpu_has_features(cpu_features, target) == is_supported + # check ffi vs python + for feat in cpu_features: + assert str(codegen.llvm_cpu_has_features(feat, target)) == str( + _ffi_api.llvm_cpu_has_feature(feat, target) + ) diff --git a/tests/python/target/test_x86_features.py b/tests/python/target/test_x86_features.py index 31a823b504eb..ef1ab9b42359 100644 --- a/tests/python/target/test_x86_features.py +++ b/tests/python/target/test_x86_features.py @@ -18,85 +18,89 @@ import tvm from tvm.target import _ffi_api, codegen, Target -from tvm.target.x86 import target_has_features +from tvm.target.codegen import target_has_features LLVM_VERSION = codegen.llvm_version_major() min_llvm_version, tvm_target, x86_feature, is_supported = tvm.testing.parameters( # sse4.1 - (-1, "llvm -mcpu=btver2", "sse4a", True), - (-1, "llvm -mcpu=penryn", "sse4.1", True), - (-1, "llvm -mcpu=silvermont", "sse4.2", True), - (11, "llvm -mcpu=slm", "sse4.2", True), - (-1, "llvm -mcpu=goldmont", "sse4.2", True), - (-1, "llvm -mcpu=goldmont-plus", "sse4.2", True), - (-1, "llvm -mcpu=tremont", "sse4.2", True), - (-1, "llvm -mcpu=nehalem", "sse4.2", True), - (11, "llvm -mcpu=corei7", "sse4.2", True), - (-1, "llvm -mcpu=westmere", "sse4.2", True), - (-1, "llvm -mcpu=bdver1", "sse4.2", True), - (-1, "llvm -mcpu=bdver2", "sse4.2", True), - (-1, "llvm -mcpu=bdver3", "sse4.2", True), - (11, "llvm -mcpu=x86-64-v2", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=btver2", "sse4a", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=penryn", "sse4.1", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=silvermont", "sse4.2", True), + (11, "llvm -mtriple=x86_64-- -mcpu=slm", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=goldmont", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=goldmont-plus", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=tremont", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=nehalem", "sse4.2", True), + (11, "llvm -mtriple=x86_64-- -mcpu=corei7", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=westmere", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=bdver1", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=bdver2", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=bdver3", "sse4.2", True), + (11, "llvm -mtriple=x86_64-- -mcpu=x86-64-v2", "sse4.2", True), # avx - (-1, "llvm -mcpu=sandybridge", "avx", True), - (11, "llvm -mcpu=corei7-avx", "avx", True), - (-1, "llvm -mcpu=ivybridge", "avx", True), - (11, "llvm -mcpu=core-avx-i", "avx", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=sandybridge", "avx", True), + (11, "llvm -mtriple=x86_64-- -mcpu=corei7-avx", "avx", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=ivybridge", "avx", True), + (11, "llvm -mtriple=x86_64-- -mcpu=core-avx-i", "avx", True), # avx2 - (-1, "llvm -mcpu=haswell", "avx2", True), - (11, "llvm -mcpu=core-avx2", "avx2", True), - (-1, "llvm -mcpu=broadwell", "avx2", True), - (-1, "llvm -mcpu=skylake", "avx2", True), - (-1, "llvm -mcpu=bdver4", "avx2", True), - (-1, "llvm -mcpu=znver1", "avx2", True), - (-1, "llvm -mcpu=znver2", "avx2", True), - (11, "llvm -mcpu=znver3", "avx2", True), - (11, "llvm -mcpu=x86-64-v3", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=haswell", "avx2", True), + (11, "llvm -mtriple=x86_64-- -mcpu=core-avx2", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=broadwell", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=skylake", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=bdver4", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=znver1", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=znver2", "avx2", True), + (11, "llvm -mtriple=x86_64-- -mcpu=znver3", "avx2", True), + (11, "llvm -mtriple=x86_64-- -mcpu=x86-64-v3", "avx2", True), # avx512bw - (-1, "llvm -mcpu=skylake-avx512", "avx512bw", True), - (11, "llvm -mcpu=skx", "avx512bw", True), - (11, "llvm -mcpu=knl", "avx512bw", False), - (-1, "llvm -mcpu=knl", "avx512f", True), - (11, "llvm -mcpu=knl", ["avx512bw", "avx512f"], False), - (11, "llvm -mcpu=knl", ("avx512bw", "avx512f"), False), - (-1, "llvm -mcpu=knl", "avx512cd", True), - (11, "llvm -mcpu=knl", ["avx512cd", "avx512f"], True), - (11, "llvm -mcpu=knl", ("avx512cd", "avx512f"), True), - (-1, "llvm -mcpu=knl", "avx512er", True), - (-1, "llvm -mcpu=knl", "avx512pf", True), - (11, "llvm -mcpu=knm", "avx512bw", False), - (-1, "llvm -mcpu=knm", "avx512f", True), - (-1, "llvm -mcpu=knm", "avx512cd", True), - (-1, "llvm -mcpu=knm", "avx512er", True), - (-1, "llvm -mcpu=knm", "avx512pf", True), - (11, "llvm -mcpu=x86-64-v4", "avx512bw", True), - (-1, "llvm -mcpu=cannonlake", "avx512bw", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=skylake-avx512", "avx512bw", True), + (11, "llvm -mtriple=x86_64-- -mcpu=skx", "avx512bw", True), + (11, "llvm -mtriple=x86_64-- -mcpu=knl", "avx512bw", False), + (-1, "llvm -mtriple=x86_64-- -mcpu=knl", "avx512f", True), + (11, "llvm -mtriple=x86_64-- -mcpu=knl", ["avx512bw", "avx512f"], False), + (11, "llvm -mtriple=x86_64-- -mcpu=knl", ("avx512bw", "avx512f"), False), + (-1, "llvm -mtriple=x86_64-- -mcpu=knl", "avx512cd", True), + (11, "llvm -mtriple=x86_64-- -mcpu=knl", ["avx512cd", "avx512f"], True), + (11, "llvm -mtriple=x86_64-- -mcpu=knl", ("avx512cd", "avx512f"), True), + (-1, "llvm -mtriple=x86_64-- -mcpu=knl", "avx512er", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=knl", "avx512pf", True), + (11, "llvm -mtriple=x86_64-- -mcpu=knm", "avx512bw", False), + (-1, "llvm -mtriple=x86_64-- -mcpu=knm", "avx512f", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=knm", "avx512cd", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=knm", "avx512er", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=knm", "avx512pf", True), + (11, "llvm -mtriple=x86_64-- -mcpu=x86-64-v4", "avx512bw", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=cannonlake", "avx512bw", True), # explicit enumeration of VNNI capable due to collision with alderlake - (11, "llvm -mcpu=alderlake", "avx512bw", False), - (-1, "llvm -mcpu=cascadelake", "avx512bw", True), - (-1, "llvm -mcpu=icelake-client", "avx512bw", True), - (-1, "llvm -mcpu=icelake-server", "avx512bw", True), - (11, "llvm -mcpu=rocketlake", "avx512bw", True), - (-1, "llvm -mcpu=tigerlake", "avx512bw", True), - (-1, "llvm -mcpu=cooperlake", "avx512bw", True), - (11, "llvm -mcpu=sapphirerapids", "avx512bw", True), + (11, "llvm -mtriple=x86_64-- -mcpu=alderlake", "avx512bw", False), + (-1, "llvm -mtriple=x86_64-- -mcpu=cascadelake", "avx512bw", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=icelake-client", "avx512bw", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=icelake-server", "avx512bw", True), + (11, "llvm -mtriple=x86_64-- -mcpu=rocketlake", "avx512bw", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=tigerlake", "avx512bw", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=cooperlake", "avx512bw", True), + (11, "llvm -mtriple=x86_64-- -mcpu=sapphirerapids", "avx512bw", True), # avx512vnni - (11, "llvm -mcpu=alderlake", "avx512vnni", False), - (11, "llvm -mcpu=alderlake", "avxvnni", True), - (-1, "llvm -mcpu=cascadelake", "avx512vnni", True), - (-1, "llvm -mcpu=icelake-client", "avx512vnni", True), - (-1, "llvm -mcpu=icelake-server", "avx512vnni", True), - (11, "llvm -mcpu=rocketlake", "avx512vnni", True), - (-1, "llvm -mcpu=tigerlake", "avx512vnni", True), - (-1, "llvm -mcpu=cooperlake", "avx512vnni", True), - (11, "llvm -mcpu=sapphirerapids", "avx512vnni", True), + (11, "llvm -mtriple=x86_64-- -mcpu=alderlake", "avx512vnni", False), + (11, "llvm -mtriple=x86_64-- -mcpu=alderlake", "avxvnni", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=cascadelake", "avx512vnni", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=icelake-client", "avx512vnni", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=icelake-server", "avx512vnni", True), + (11, "llvm -mtriple=x86_64-- -mcpu=rocketlake", "avx512vnni", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=tigerlake", "avx512vnni", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=cooperlake", "avx512vnni", True), + (11, "llvm -mtriple=x86_64-- -mcpu=sapphirerapids", "avx512vnni", True), # amx-int8 - (11, "llvm -mcpu=sapphirerapids", "amx-int8", True), + (11, "llvm -mtriple=x86_64-- -mcpu=sapphirerapids", "amx-int8", True), # generic CPU (no features) but with extra -mattr - (-1, "llvm -mcpu=x86-64 -mattr=+sse4.1,+avx2", "avx2", True), - (-1, "llvm -mcpu=x86-64 -mattr=+sse4.1,+avx2", "sse4.1", True), - (-1, "llvm -mcpu=x86-64 -mattr=+sse4.1,+avx2", "ssse3", False), + (-1, "llvm -mtriple=x86_64-- -mcpu=x86-64 -mattr=+sse4.1,+avx2", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=x86-64 -mattr=+sse4.1,+avx2", "sse4.1", True), + # enabling +sse4.1 implies ssse3 presence in LLVM + (-1, "llvm -mtriple=x86_64-- -mcpu=x86-64 -mattr=+sse4.1,+avx2", "ssse3", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=ivybridge -mattr=-ssse3", "ssse3", False), + # disabling avx512f (foundation) also disables avx512bw + (-1, "llvm -mtriple=x86_64-- -mcpu=cascadelake -mattr=-avx512f", "avx512bw", False), ) @@ -135,7 +139,7 @@ def test_x86_target_features(min_llvm_version, tvm_target, x86_feature, is_suppo if isinstance(x86_feature, str): # check for feature via the ffi llvm api (no explicit target, no context target) try: - assert _ffi_api.llvm_x86_has_feature(x86_feature, None) == is_supported + assert _ffi_api.target_has_feature(x86_feature, None) == is_supported assert False except tvm.error.InternalError as e: msg = str(e) @@ -154,7 +158,7 @@ def test_x86_target_features(min_llvm_version, tvm_target, x86_feature, is_suppo assert target_has_features(x86_feature, Target(tvm_target)) == is_supported if isinstance(x86_feature, str): # check for feature via the ffi llvm api (with explicit target, no context target) - assert _ffi_api.llvm_x86_has_feature(x86_feature, Target(tvm_target)) == is_supported + assert _ffi_api.target_has_feature(x86_feature, Target(tvm_target)) == is_supported ## ## with context @@ -166,11 +170,8 @@ def test_x86_target_features(min_llvm_version, tvm_target, x86_feature, is_suppo assert target_has_features(x86_feature) == is_supported # check for feature via the python api (with explicit target) assert target_has_features(x86_feature, Target(tvm_target)) == is_supported - if isinstance(x86_feature, str): - # check for feature via the ffi llvm api (current context target) - assert _ffi_api.llvm_x86_has_feature(x86_feature, None) == is_supported - # check for feature via the ffi llvm api (with explicit target) - assert _ffi_api.llvm_x86_has_feature(x86_feature, Target(tvm_target)) == is_supported - # check for feature in target's llvm full x86 CPU feature list - if not Target(tvm_target).mattr: - assert (x86_feature in codegen.llvm_x86_get_features(mcpu)) == is_supported + # check for feature via the ffi llvm api (current context target) + (sum(_ffi_api.target_has_feature(feat, None) for feat in x86_feature) > 0) == is_supported + # check for feature in target's llvm full x86 CPU feature list + if (not Target(tvm_target).mattr) and isinstance(x86_feature, str): + assert (x86_feature in codegen.llvm_get_cpu_features()) == is_supported diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 6a2f5573b274..f1316ae3cee0 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -946,7 +946,7 @@ def test_llvm_target_attributes(): xo, xi = s[C].split(C.op.axis[0], nparts=2) s[C].parallel(xo) - target_llvm = "llvm -mcpu=skylake -mattr=+avx512f" + target_llvm = "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake -mattr=+avx512f" target = tvm.target.Target(target_llvm, host=target_llvm) module = tvm.build(s, [A, B, C, n], target=target, name="test_func")