Skip to content

Commit

Permalink
[Target] LLVM helper functions for any target info (#15761)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalint13 authored Sep 27, 2023
1 parent c318fa8 commit cf8521a
Show file tree
Hide file tree
Showing 21 changed files with 538 additions and 351 deletions.
3 changes: 3 additions & 0 deletions cmake/modules/LLVM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ add_definitions(-DDMLC_USE_FOPEN64=0 -DNDEBUG=1)
# It may be a boolean or a string
if(NOT ${USE_LLVM} MATCHES ${IS_FALSE_PATTERN})
find_llvm(${USE_LLVM})
if (${TVM_LLVM_VERSION} LESS 60)
message(FATAL_ERROR "LLVM version 6.0 or greater is required.")
endif()
include_directories(SYSTEM ${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS})
message(STATUS "Build with LLVM " ${LLVM_PACKAGE_VERSION})
Expand Down
93 changes: 79 additions & 14 deletions python/tvm/target/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""Code generation related functions."""
from . import _ffi_api
from .target import Target
from ..ir.container import Array


def build_module(mod, target):
Expand All @@ -39,6 +40,30 @@ def build_module(mod, target):
return _ffi_api.Build(mod, target)


def target_has_features(cpu_features, target=None):
    """Tell whether a target supports every one of the requested CPU features.

    The check is made against the target's `-mtriple`, `-mcpu` and `-mattr`.
    Each feature is queried on its own through ``_ffi_api.target_has_feature``
    and the individual answers are and-ed together, so an empty sequence
    yields ``True``.

    Parameters
    ----------
    cpu_features : str or Array or list or tuple
        CPU feature(s) to check. A single string is handled as a
        one-element list.
    target : Target, optional
        The TVM target. ``None`` (the default) is forwarded unchanged to the
        FFI call.

    Returns
    -------
    has_features : bool
        True if target has the feature(s).
    """
    assert target is None or isinstance(target, Target)
    assert isinstance(cpu_features, (Array, list, tuple, str))
    # Normalise the single-name form so the loop below only sees sequences.
    if isinstance(cpu_features, str):
        cpu_features = [cpu_features]
    supported = True
    for name in cpu_features:
        # Deliberately no short-circuit: every feature is queried.
        supported = supported & _ffi_api.target_has_feature(name, target)
    return supported


def llvm_lookup_intrinsic_id(name):
"""Lookup LLVM intrinsic id by name.
Expand Down Expand Up @@ -71,36 +96,76 @@ def llvm_get_intrinsic_name(intrin_id: int) -> str:
return _ffi_api.llvm_get_intrinsic_name(intrin_id)


def llvm_x86_get_archlist(only64bit=False):
"""Get X86 CPU name list.
def llvm_get_targets():
    """Return the list of LLVM targets.

    Takes no arguments; the value produced by ``_ffi_api.llvm_get_targets``
    is handed back unchanged.

    Returns
    -------
    llvm_targets : list[str]
        List of available LLVM targets.
    """
    llvm_targets = _ffi_api.llvm_get_targets()
    return llvm_targets


def llvm_get_cpu_archlist(target=None):
    """Return the CPU architectures known for the target's `-mtriple`.

    Parameters
    ----------
    target : Target, optional
        The TVM target. ``None`` (the default) is forwarded unchanged to the
        FFI call.

    Returns
    -------
    cpu_archlist : list[str]
        List of available CPU architectures.
    """
    assert target is None or isinstance(target, Target)
    cpu_archlist = _ffi_api.llvm_get_cpu_archlist(target)
    return cpu_archlist


def llvm_get_cpu_features(target=None):
"""Get CPU features for the target's `-mtriple` and `-mcpu` and considering `-mattr`.
Parameters
----------
only64bit : bool
Filter 64bit architectures.
target : Target
The TVM target.
Returns
-------
features : list[str]
String list of X86 architectures.
cpu_features : list[str]
List of available CPU features.
"""
return _ffi_api.llvm_x86_get_archlist(only64bit)
assert isinstance(target, Target) or target is None
return _ffi_api.llvm_get_cpu_features(target)


def llvm_x86_get_features(cpu_name):
"""Get X86 CPU features.
def llvm_cpu_has_features(cpu_features, target=None):
"""Check CPU features for the target's `-mtriple` and `-mcpu` and considering `-mattr`.
Parameters
----------
cpu_name : string
X86 CPU name (e.g. "skylake").
target : Target
The TVM target.
cpu_features : str or Array
CPU Feature(s) to check.
Returns
-------
features : list[str]
String list of X86 CPU features.
has_features : bool
True if target CPU has the feature(s).
"""
return _ffi_api.llvm_x86_get_features(cpu_name)
assert isinstance(target, Target) or target is None
assert isinstance(cpu_features, (Array, list, tuple, str))
has_feats = True
cpu_features = [cpu_features] if isinstance(cpu_features, str) else cpu_features
for feat in cpu_features:
has_feats &= _ffi_api.llvm_cpu_has_feature(feat, target)
return has_feats


def llvm_version_major(allow_none=False):
Expand Down
28 changes: 1 addition & 27 deletions python/tvm/target/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,7 @@
# under the License.
"""Common x86 related utilities"""
from .._ffi import register_func
from . import _ffi_api
from ..ir.container import Array


@register_func("tvm.target.x86.target_has_features")
def target_has_features(features, target=None):
    """Tell whether an X86 target supports every requested CPU feature.

    Each feature is queried on its own through
    ``_ffi_api.llvm_x86_has_feature`` and the answers are and-ed together.

    Parameters
    ----------
    features : str or Array
        Feature(s) to check. A single string is handled as a one-element
        list.
    target : Target
        Optional TVM target, default `None` use the global context target.

    Returns
    -------
    has_feats : bool
        True if feature(s) are in the target arch.
    """
    assert isinstance(features, (Array, str))
    # Normalise the single-name form so the loop below only sees sequences.
    if isinstance(features, str):
        features = [features]
    supported = True
    for name in features:
        # Deliberately no short-circuit: every feature is queried.
        supported = supported & _ffi_api.llvm_x86_has_feature(name, target)
    return supported
from .codegen import target_has_features


@register_func("tvm.topi.x86.utils.get_simd_32bit_lanes")
Expand All @@ -53,9 +30,6 @@ def get_simd_32bit_lanes():
The optimal vector length of CPU from the global context target.
"""
vec_len = 4
# avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
# avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required)
# + llvm.x86.avx512.pmaddw.d.512"
if target_has_features(["avx512bw", "avx512f"]):
vec_len = 16
elif target_has_features("avx2"):
Expand Down
8 changes: 1 addition & 7 deletions python/tvm/topi/x86/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tvm import autotvm, te
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cblas, mkl
from tvm.target.x86 import target_has_features
from tvm.target.codegen import target_has_features

from .. import generic, nn
from ..transform import layout_transform
Expand All @@ -38,9 +38,6 @@ def batch_matmul_int8_compute(cfg, x, y, *_):
packed_y = layout_transform(y, "BNK", packed_y_layout)
_, n_o, _, n_i, _ = packed_y.shape
ak = te.reduce_axis((0, k), name="k")
# avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
# avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required)
# + llvm.x86.avx512.pmaddw.d.512"
if target_has_features(["avx512bw", "avx512f"]):
attrs_info = {"schedule_rule": "batch_matmul_int8"}
else:
Expand Down Expand Up @@ -241,9 +238,6 @@ def _callback(op):
layout_trans = op.input_tensors[1]
if target_has_features("amx-int8"):
batch_matmul_amx_schedule(cfg, s, op.output(0), outs[0], layout_trans)
# avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
# avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required)
# + llvm.x86.avx512.pmaddw.d.512"
elif target_has_features(["avx512bw", "avx512f"]):
batch_matmul_int8_schedule(cfg, s, op.output(0), outs[0], layout_trans)

Expand Down
9 changes: 2 additions & 7 deletions python/tvm/topi/x86/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from tvm import autotvm, te
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cblas, dnnl, mkl
from tvm.target.x86 import get_simd_32bit_lanes, target_has_features
from tvm.target.x86 import get_simd_32bit_lanes
from tvm.target.codegen import target_has_features

from .. import generic, tag
from ..utils import get_const_tuple, traverse_inline
Expand Down Expand Up @@ -303,9 +304,6 @@ def _callback(op):
if "dense_int8" in op.tag:
if target_has_features("amx-int8"):
dense_amx_int8_schedule(cfg, s, op.output(0), outs[0])
# avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
# avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required)
# + llvm.x86.avx512.pmaddw.d.512"
elif target_has_features(["avx512bw", "avx512f"]):
dense_int8_schedule(cfg, s, op.output(0), outs[0])

Expand All @@ -318,9 +316,6 @@ def dense_int8_compute(cfg, X, packed_w, bias=None):
m, k = X.shape
n_o, _, n_i, _ = packed_w.shape
ak = te.reduce_axis((0, k), name="k")
# avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
# avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required)
# + llvm.x86.avx512.pmaddw.d.512"
if target_has_features(["avx512bw", "avx512f"]):
target_attr = {"schedule_rule": "meta_schedule.x86.dense_int8"}
else:
Expand Down
5 changes: 1 addition & 4 deletions python/tvm/topi/x86/dense_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import tvm
from tvm import autotvm, relay, te
from tvm.target.x86 import target_has_features
from tvm.target.codegen import target_has_features

from .. import nn
from ..nn import dense_alter_layout
Expand All @@ -28,9 +28,6 @@


def check_int8_applicable(x, y, allow_padding=False):
# avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
# avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required)
# + llvm.x86.avx512.pmaddw.d.512"
simd_avai = target_has_features(["avx512bw", "avx512f"])
simd_avai |= target_has_features("amx-int8")
# TODO(vvchernov): may be also target_has_features("avx2") or lower?
Expand Down
19 changes: 8 additions & 11 deletions src/meta_schedule/space_generator/space_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,17 @@ namespace meta_schedule {

String GetRuleKindFromTarget(const Target& target) {
if (target->kind->name == "llvm") {
static const PackedFunc* llvm_x86_has_feature_fn_ptr =
runtime::Registry::Get("target.llvm_x86_has_feature");
ICHECK(llvm_x86_has_feature_fn_ptr != nullptr)
<< "The `target.llvm_x86_has_feature` func is not in tvm registry.";
bool have_avx512vnni = (*llvm_x86_has_feature_fn_ptr)("avx512vnni", target);
bool have_avxvnni = (*llvm_x86_has_feature_fn_ptr)("avxvnni", target);
static const PackedFunc* target_has_feature_fn_ptr =
runtime::Registry::Get("target.target_has_feature");
ICHECK(target_has_feature_fn_ptr != nullptr)
<< "The `target.target_has_feature` func is not in tvm registry.";
bool have_avx512vnni = (*target_has_feature_fn_ptr)("avx512vnni", target);
bool have_avxvnni = (*target_has_feature_fn_ptr)("avxvnni", target);
if (have_avx512vnni || have_avxvnni) {
return "vnni";
} else {
// avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
// avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required)
// + llvm.x86.avx512.pmaddw.d.512"
bool have_avx512f = (*llvm_x86_has_feature_fn_ptr)("avx512f", target);
bool have_avx512bw = (*llvm_x86_has_feature_fn_ptr)("avx512bw", target);
bool have_avx512f = (*target_has_feature_fn_ptr)("avx512f", target);
bool have_avx512bw = (*target_has_feature_fn_ptr)("avx512bw", target);
if (have_avx512bw && have_avx512f) {
return "avx512";
}
Expand Down
6 changes: 3 additions & 3 deletions src/relay/qnn/op/requantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs,
}

bool has_current_target_sse41_support() {
auto llvm_x86_has_feature_fn_ptr = tvm::runtime::Registry::Get("target.llvm_x86_has_feature");
ICHECK(llvm_x86_has_feature_fn_ptr) << "Function target.llvm_x86_has_feature not found";
return (*llvm_x86_has_feature_fn_ptr)("sse4.1", Target::Current(true));
auto target_has_feature_fn_ptr = tvm::runtime::Registry::Get("target.target_has_feature");
ICHECK(target_has_feature_fn_ptr) << "Function target.target_has_feature not found";
return (*target_has_feature_fn_ptr)("sse4.1", Target::Current(true));
}

/*
Expand Down
6 changes: 3 additions & 3 deletions src/relay/qnn/op/requantize_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ class RequantizeConfigNode : public Object {
// For the x86 architecture, the float32 computation is expected to give significant speedup,
// with little loss in the accuracy of the requantize operation.
auto target = Target::Current(true);
auto llvm_x86_has_feature_fn_ptr = tvm::runtime::Registry::Get("target.llvm_x86_has_feature");
ICHECK(llvm_x86_has_feature_fn_ptr) << "Function target.llvm_x86_has_feature not found";
auto target_has_feature_fn_ptr = tvm::runtime::Registry::Get("target.target_has_feature");
ICHECK(target_has_feature_fn_ptr) << "Function target.target_has_feature not found";
if (target.defined() && target->kind->name == "llvm") {
if ((*llvm_x86_has_feature_fn_ptr)("sse4.1", target)) {
if ((*target_has_feature_fn_ptr)("sse4.1", target)) {
return "float32";
}
}
Expand Down
39 changes: 2 additions & 37 deletions src/target/llvm/codegen_x86_64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@
#if TVM_LLVM_VERSION >= 100
#include <llvm/IR/IntrinsicsX86.h>
#endif
#include <llvm/MC/MCSubtargetInfo.h>
#include <llvm/Support/Casting.h>
#include <llvm/Target/TargetMachine.h>
#include <tvm/runtime/registry.h>

#include <string>
Expand All @@ -43,38 +41,6 @@
namespace tvm {
namespace codegen {

namespace {
// Query whether the target machine `tm` reports the subtarget feature `feature`.
// With LLVM 6.0 or newer this asks the machine's MCSubtargetInfo via
// checkFeatures("+" + feature); with older LLVM the answer is always false.
bool TargetHasFeature(const llvm::TargetMachine& tm, const std::string& feature) {
// MCSubTargetInfo::checkFeatures was added in LLVM 6.0
#if TVM_LLVM_VERSION >= 60
const auto* MCInfo = tm.getMCSubtargetInfo();
return MCInfo->checkFeatures(std::string("+") + feature);
#else
// No feature query is implemented for this LLVM version, so every feature
// is reported as absent. The commented-out code below is a disabled sketch
// of a manual replacement.
return false;
// TODO(tulloch) - enable this block, need to figure out how to reimplement
// this given visibility constraints, similar to
// https://github.com/rust-lang/rust/pull/31709

// Copied from
// https://github.com/llvm-mirror/llvm/blob/5136df4/lib/MC/MCSubtargetInfo.cpp#L78-L88.

// auto checkFeatures = [&](const std::string FS) {
// llvm::SubtargetFeatures T(FS);
// llvm::FeatureBitset Set, All;
// for (std::string F : T.getFeatures()) {
// llvm::SubtargetFeatures::ApplyFeatureFlag(Set, F, MCInfo->ProcFeatures);
// if (F[0] == '-') {
// F[0] = '+';
// }
// llvm::SubtargetFeatures::ApplyFeatureFlag(All, F, MCInfo->ProcFeatures);
// }
// return (MCInfo->getFeatureBits() & All) == Set;
// };
// return checkFeatures(MCInfo, std::string("+") + feature);
#endif
}
}  // namespace

class CodeGenX86_64 final : public CodeGenCPU {
public:
llvm::Value* VisitExpr_(const CastNode* op) override;
Expand All @@ -92,9 +58,8 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
const auto to = op->dtype;
if (from.is_float() && to.is_float() && from.bits() == 16 && to.bits() == 32) {
ICHECK_EQ(from.lanes(), to.lanes());
llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine();

const auto has_avx512 = TargetHasFeature(*tm, "avx512f");
const auto has_avx512 = llvm_target_->TargetHasCPUFeature("avx512f");

if (from.lanes() >= 16 && has_avx512) {
return CallVectorIntrin(
Expand All @@ -111,7 +76,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {

#if TVM_LLVM_VERSION <= 100
// The intrinsic x86_vcvtph2ps_256 was removed in LLVM 11.
const auto has_f16c = TargetHasFeature(*tm, "f16c");
const auto has_f16c = llvm_target_->TargetHasCPUFeature("f16c");

if (from.lanes() >= 8 && has_f16c) {
return CallVectorIntrin(llvm::Intrinsic::x86_vcvtph2ps_256, 8,
Expand Down
Loading

0 comments on commit cf8521a

Please sign in to comment.