diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h index 3dcc4b943a79..963458ccee4a 100644 --- a/include/tvm/tir/op_attr_types.h +++ b/include/tvm/tir/op_attr_types.h @@ -28,11 +28,12 @@ #ifndef TVM_TIR_OP_ATTR_TYPES_H_ #define TVM_TIR_OP_ATTR_TYPES_H_ +#include #include +#include namespace tvm { namespace tir { - /*! * \brief Global symbol of the op after lowering. */ @@ -43,6 +44,16 @@ using TGlobalSymbol = String; */ using TVectorizable = bool; +/*! + * \brief The intrinsic lowering function for given op. + */ +using FLowerIntrinsic = runtime::TypedPackedFunc; + +/*! + * \brief The legalization function for given tir op. + */ +using FLegalize = runtime::TypedPackedFunc; + /*! * \brief The effect type of the call. */ diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index e35077bb5aab..a12d3e9855f0 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -23,7 +23,7 @@ from .tensor_type import TensorType from .type_relation import TypeCall, TypeRelation from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range -from .op import Op, register_op_attr +from .op import Op, register_op_attr, register_intrin_lowering from .function import CallingConv, BaseFunc from .adt import Constructor, TypeData from .module import IRModule diff --git a/python/tvm/ir/op.py b/python/tvm/ir/op.py index 7b06c3da33d6..88e760ef91a1 100644 --- a/python/tvm/ir/op.py +++ b/python/tvm/ir/op.py @@ -115,3 +115,40 @@ def _register(v): return v return _register(value) if value is not None else _register + + +def register_intrin_lowering( + op_name, + target, + *, + f=None, + level=10, +): + """Register Op lowering function + + Parameters + ---------- + op_name : str + The op name + + target : str + The target string for given intrinsic lowering function + + f : function, optional + The function to be registered. 
+ + level : int + The priority level + + Returns + ------- + fregister : function + Register op lowering function if f is not specified. + """ + + def _register(f): + """internal register function""" + _ffi_api.RegisterOpLowerIntrinsic(op_name, f, target, level) + return f + + return _register(f) if f is not None else _register diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index 9482e782041c..92d72b25b44d 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -61,4 +61,3 @@ from .generic_func import generic_func, get_native_generic_func, override_native_generic_func from . import datatype from . import codegen -from .intrin import register_intrin_rule diff --git a/python/tvm/target/intrin.py b/python/tvm/target/intrin.py index 6d205bc0447c..a0242cf828a3 100644 --- a/python/tvm/target/intrin.py +++ b/python/tvm/target/intrin.py @@ -15,54 +15,10 @@ # specific language governing permissions and limitations # under the License. """Target dependent intrinsic registration.""" -import tvm._ffi +from tvm.ir import register_intrin_lowering from tvm.tir import call_pure_extern -# Intrinsic rule related code -def register_intrin_rule(target, intrin, f=None, override=False): - """Register an intrinsic function generation rule. - - Intrinsic generation rules are callback functions for - code generator to get device specific calls. - This function simply translates to. - - :code:`register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)` - - TVM may already pre-register intrinsic rules in the backend. - However, user can use this function to change the intrinsic translation - behavior or add new intrinsic rules during runtime. - - Parameters - ---------- - target : str - The name of codegen target. - - intrin : str - The name of the intrinsic. - - f : function, optional - The function to be registered. - - override: boolean optional - Whether override existing entry. 
- - Returns - ------- - fregister : function - Register function if f is not specified. - - Examples - -------- - The following code registers exp expansion rule for opencl. - - .. code-block:: python - - register_intrin_rule("opencl", "exp", my_exp_rule, override=True) - """ - return tvm._ffi.register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override) - - def _rule_float_suffix(op): """Intrinsic rule: Add float suffix if it is float32. @@ -81,7 +37,7 @@ def _rule_float_suffix(op): See Also -------- - register_intrin_rule : The registration function for intrin rule. + register_intrin_lowering : The registration function for intrinsic lowering rule. """ name = op.op.name assert name.startswith("tir.") @@ -112,7 +68,7 @@ def _rule_float_direct(op): See Also -------- - register_intrin_rule : The registration function for intrin rule. + register_intrin_lowering : The registration function for intrinsic lowering rule. """ if str(op.dtype).startswith("float"): return call_pure_extern(op.dtype, op.op.name[4:], *op.args) @@ -120,6 +76,6 @@ def _rule_float_direct(op): # opencl pattern for exp -register_intrin_rule("opencl", "exp", _rule_float_direct, override=True) +register_intrin_lowering("tir.exp", target="opencl", f=_rule_float_direct, level=99) # default pattern for exp -register_intrin_rule("default", "exp", _rule_float_suffix, override=True) +register_intrin_lowering("tir.exp", target="default", f=_rule_float_suffix, level=99) diff --git a/python/tvm/topi/arm_cpu/tensor_intrin.py b/python/tvm/topi/arm_cpu/tensor_intrin.py index 4055d7b05c24..bc9b0491947b 100644 --- a/python/tvm/topi/arm_cpu/tensor_intrin.py +++ b/python/tvm/topi/arm_cpu/tensor_intrin.py @@ -19,6 +19,7 @@ import tvm from tvm import te +from tvm.ir import register_intrin_lowering def gemm_4x4_int8_int8_int32(M, N, K, unroll, in_type): @@ -1054,6 +1055,6 @@ def _q_multiply_shift_arm(op): return tvm.tir.Select(s < 0, out_1, out_2) -tvm.target.intrin.register_intrin_rule( - "llvm.aarch64", 
"q_multiply_shift", _q_multiply_shift_arm, override=True +register_intrin_lowering( + "tir.q_multiply_shift", target="llvm.aarch64", f=_q_multiply_shift_arm, level=99 ) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 2789452cc10b..f064360768c2 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -21,6 +21,7 @@ from tvm import te from tvm.contrib import nvcc from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust +from tvm.ir import register_intrin_lowering from tvm.tir import if_then_else from .sort import argsort, argsort_thrust from .scan import exclusive_scan @@ -51,11 +52,9 @@ def opencl_atomic_add_rule(op): raise RuntimeError("only support int32") -tvm.target.intrin.register_intrin_rule("cuda", "atomic_add", cuda_atomic_add_rule, override=True) +register_intrin_lowering("tir.atomic_add", target="cuda", f=cuda_atomic_add_rule, level=99) -tvm.target.intrin.register_intrin_rule( - "opencl", "atomic_add", opencl_atomic_add_rule, override=True -) +register_intrin_lowering("tir.atomic_add", target="opencl", f=opencl_atomic_add_rule, level=99) def atomic_add(x, y): diff --git a/src/ir/op.cc b/src/ir/op.cc index 5d2dc704f5b7..5b258ed2f2f0 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include @@ -36,6 +37,7 @@ namespace tvm { using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; +using tir::FLowerIntrinsic; using OpRegistry = AttrRegistry; @@ -122,6 +124,12 @@ TVM_REGISTER_GLOBAL("ir.RegisterOpAttr") } }); +TVM_REGISTER_GLOBAL("ir.RegisterOpLowerIntrinsic") + .set_body_typed([](String name, PackedFunc f, String target, int plevel) { + tvm::OpRegEntry::RegisterOrGet(name).set_attr(target + ".FLowerIntrinsic", f, + plevel); + }); + // helper to get internal dev function in objectref. 
struct Op2ObjectPtr : public ObjectRef { static ObjectPtr Get(const Op& op) { return GetDataPtr(op); } diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 1a7214476188..bfc3fe6fcc8c 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -24,108 +24,131 @@ #include "intrin_rule.h" #include +#include namespace tvm { namespace codegen { namespace intrin { +using tir::FLowerIntrinsic; -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.exp").set_attr("default.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.erf").set_attr("default.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.log").set_attr("default.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log2").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.log2") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log10").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.log10") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.log1p") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.tanh") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.tan").set_attr("default.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.atan") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atanh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.atanh") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.atan2") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.cos").set_attr("default.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acos").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.acos") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.cosh") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acosh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.acosh") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.sin").set_attr("default.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asin").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.asin") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.sinh") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asinh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.asinh") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.hypot") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nextafter").set_body(DispatchPureExtern); 
+TVM_REGISTER_OP("tir.nextafter") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.copysign").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.copysign") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.ldexp").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.ldexp") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.sqrt") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.floor").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.floor") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.ceil").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.ceil") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.round").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.round") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt") - .set_body([](const TVMArgs& args, TVMRetValue* rv) { - PrimExpr e = args[0]; +TVM_REGISTER_OP("tir.rsqrt") + .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); ICHECK(call != nullptr); - auto one = make_const(call->args[0].dtype(), 1); - *rv = one / sqrt(call->args[0]); + return one / sqrt(call->args[0]); }); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.pow").set_attr("default.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid") - .set_body([](const TVMArgs& args, TVMRetValue* rv) { - PrimExpr e = args[0]; +TVM_REGISTER_OP("tir.sigmoid") + .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = 
e.as(); ICHECK(call != nullptr); - auto one = make_const(call->args[0].dtype(), 1); - *rv = one / (one + exp(-call->args[0])); + return one / (one + exp(-call->args[0])); }); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isfinite") - .set_body([](const TVMArgs& args, TVMRetValue* rv) { - PrimExpr e = args[0]; +TVM_REGISTER_OP("tir.isfinite") + .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); ICHECK(call != nullptr); - *rv = isfinite(call->args[0]); + return isfinite(call->args[0]); }); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isinf") - .set_body([](const TVMArgs& args, TVMRetValue* rv) { - PrimExpr e = args[0]; +TVM_REGISTER_OP("tir.isinf") + .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); ICHECK(call != nullptr); - *rv = isinf(call->args[0]); + return isinf(call->args[0]); }); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.q_multiply_shift") - .set_body([](const TVMArgs& args, TVMRetValue* rv) { +TVM_REGISTER_OP("tir.q_multiply_shift") + .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; - PrimExpr e = args[0]; const tir::CallNode* call = e.as(); ICHECK(call != nullptr); @@ -145,15 +168,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.q_multiply_shift") CHECK(int_node != nullptr); return int_node->value; }; - // Power of 2 is determined by the fixed_point_multiplier == 1 << 30. In case of power of 2, - // fixed point multiplier will represent a float value of 0.5. In fixed point, this is + // Power of 2 is determined by the fixed_point_multiplier == 1 << 30. In case of power of + // 2, fixed point multiplier will represent a float value of 0.5. In fixed point, this is // represented by 1 << 30. if (get_int_value(y) == (1 << 30)) { PrimExpr exp = s - 1; int exp_val = get_int_value(s) - 1; if (exp_val > 0) { // power of 2 is greater than 0, apply left shift. 
- *rv = x << exp; + return x << exp; } else { // power of 2 is less than 0, round and then apply right shift. DataType lp_dtype = DataType::Int(32, x.dtype().lanes()); @@ -161,7 +184,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.q_multiply_shift") exp = -exp; PrimExpr rounding_factor = one << (exp - 1); PrimExpr rounded_t = x + rounding_factor; - *rv = rounded_t >> exp; + return rounded_t >> exp; } } else { // Only int32 types are supported (any number of lanes is allowed) @@ -193,8 +216,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.q_multiply_shift") // 5) Simply right shift the result to get the final output. x = x >> total_right_shift; - // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32. - *rv = cast(lp_dtype, x); + // 6) The fixed point multiplication keeps the value in int32 range. Casting back to + // int32. + return cast(lp_dtype, x); } }); diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h index 69196e1b2c39..6a517a9abd24 100644 --- a/src/target/intrin_rule.h +++ b/src/target/intrin_rule.h @@ -55,8 +55,7 @@ struct Direct { // Call pure extern function. 
template -inline void DispatchPureExtern(const TVMArgs& args, TVMRetValue* rv) { - PrimExpr e = args[0]; +inline PrimExpr DispatchPureExtern(const PrimExpr& e) { const CallNode* call = e.as(); ICHECK(call != nullptr); // Use string based dispatch to extern for backward compact @@ -72,9 +71,9 @@ inline void DispatchPureExtern(const TVMArgs& args, TVMRetValue* rv) { for (auto arg : call->args) { new_args.push_back(arg); } - *rv = Call(call->dtype, tir::builtin::call_pure_extern(), new_args); + return Call(call->dtype, builtin::call_pure_extern(), new_args); } else { - *rv = e; + return e; } } diff --git a/src/target/llvm/intrin_rule_hexagon.cc b/src/target/llvm/intrin_rule_hexagon.cc index d38225184c2e..82f7d5051391 100644 --- a/src/target/llvm/intrin_rule_hexagon.cc +++ b/src/target/llvm/intrin_rule_hexagon.cc @@ -19,44 +19,54 @@ #ifdef TVM_LLVM_VERSION +#include + #include "intrin_rule_llvm.h" namespace tvm { namespace codegen { namespace llvm { +using tir::FLowerIntrinsic; -TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.exp") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); +TVM_REGISTER_OP("tir.exp").set_attr( + "hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.fma") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); +TVM_REGISTER_OP("tir.fma").set_attr( + "hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.log") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); +TVM_REGISTER_OP("tir.log").set_attr( + "hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.sqrt") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); +TVM_REGISTER_OP("tir.sqrt") + .set_attr("hexagon.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.floor") - 
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); +TVM_REGISTER_OP("tir.floor") + .set_attr("hexagon.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.ceil") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); +TVM_REGISTER_OP("tir.ceil") + .set_attr("hexagon.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.trunc") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); +TVM_REGISTER_OP("tir.trunc") + .set_attr("hexagon.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.fabs") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); +TVM_REGISTER_OP("tir.fabs") + .set_attr("hexagon.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.round") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); +TVM_REGISTER_OP("tir.round") + .set_attr("hexagon.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.pow") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 1>); +TVM_REGISTER_OP("tir.pow").set_attr( + "hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.popcount") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); +TVM_REGISTER_OP("tir.ctpop") + .set_attr("hexagon.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); } // namespace llvm } // namespace codegen diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 093a746adcab..2d30c2030685 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -25,73 +25,84 @@ #include "intrin_rule_llvm.h" #include +#include namespace tvm { namespace 
codegen { namespace llvm { +using tir::FLowerIntrinsic; -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch") - .set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>); +TVM_REGISTER_OP("tir.prefetch") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); +TVM_REGISTER_OP("tir.exp").set_attr( + "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp2") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); +TVM_REGISTER_OP("tir.exp2") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); // TODO(tvm-team): migrate the legalization transformations as a separate // set of rules in TIR that can be shared across backends. -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10") - .set_body([](const TVMArgs& targs, TVMRetValue* rv) { +TVM_REGISTER_OP("tir.exp10") + .set_attr("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; using tir::make_zero; - PrimExpr e = targs[0]; const tir::CallNode* call = e.as(); ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; PrimExpr ln10 = make_const(x.dtype(), 2.302585093); PrimExpr ret = exp(x * ln10); - *rv = ret; + return ret; }); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); +TVM_REGISTER_OP("tir.fma").set_attr( + "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); +TVM_REGISTER_OP("tir.log").set_attr( + "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log2") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>); +TVM_REGISTER_OP("tir.log2") + 
.set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log10") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>); +TVM_REGISTER_OP("tir.log10") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); +TVM_REGISTER_OP("tir.sqrt") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.floor") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); +TVM_REGISTER_OP("tir.floor") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); +TVM_REGISTER_OP("tir.ceil") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.trunc") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); +TVM_REGISTER_OP("tir.trunc") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fabs") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); +TVM_REGISTER_OP("tir.fabs") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.round") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); +TVM_REGISTER_OP("tir.round") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); +TVM_REGISTER_OP("tir.nearbyint") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 
1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") - .set_body([](const TVMArgs& targs, TVMRetValue* rv) { +TVM_REGISTER_OP("tir.tanh") + .set_attr("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; using tir::make_zero; - PrimExpr e = targs[0]; const tir::CallNode* call = e.as(); ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; @@ -104,32 +115,33 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); - *rv = tir::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg); + return tir::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg); }); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); +TVM_REGISTER_OP("tir.pow").set_attr( + "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); +TVM_REGISTER_OP("tir.popcount") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan").set_body([](const TVMArgs& targs, TVMRetValue* rv) { - PrimExpr e = targs[0]; - const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr tan_x = sin(x) / cos(x); - *rv = tan_x; -}); +TVM_REGISTER_OP("tir.tan").set_attr("llvm.FLowerIntrinsic", + [](const PrimExpr& e) -> PrimExpr { + const tir::CallNode* call = + e.as(); + ICHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr tan_x = sin(x) / cos(x); + return tan_x; + }); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); +TVM_REGISTER_OP("tir.cos").set_attr( + "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cosh") - .set_body([](const TVMArgs& targs, TVMRetValue* rv) { +TVM_REGISTER_OP("tir.cosh") + .set_attr("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; using tir::make_zero; - PrimExpr e = targs[0]; const tir::CallNode* call = e.as(); ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; @@ -138,17 +150,16 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cosh") PrimExpr exp_negx = exp(neg_one * x); PrimExpr exp_posx = exp(x); PrimExpr ret = (exp_posx + exp_negx) / two; - *rv = ret; + return ret; }); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sin") - .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); +TVM_REGISTER_OP("tir.sin").set_attr( + "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh") - .set_body([](const TVMArgs& targs, TVMRetValue* rv) { +TVM_REGISTER_OP("tir.sinh") + .set_attr("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; using tir::make_zero; - PrimExpr e = targs[0]; const tir::CallNode* call = e.as(); ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; @@ -157,23 +168,23 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh") PrimExpr exp_negx = exp(neg_one * x); PrimExpr exp_posx = exp(x); PrimExpr ret = (exp_posx - exp_negx) / two; - *rv = ret; + return ret; }); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.clz").set_body([](const TVMArgs& targs, TVMRetValue* rv) { - PrimExpr e = targs[0]; - const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); - ICHECK_EQ(call->args.size(), 1); - Array cargs; - cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz)); - cargs.push_back(IntImm(DataType::UInt(32), 2)); - cargs.push_back(call->args[0]); - cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef - // LLVM requires that the return type must match the first argument type - auto clz = tir::Call(call->args[0]->dtype, 
tir::builtin::call_llvm_intrin(), cargs); - *rv = cast(call->dtype, clz); -}); +TVM_REGISTER_OP("tir.clz").set_attr( + "llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr); + ICHECK_EQ(call->args.size(), 1); + Array cargs; + cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz)); + cargs.push_back(IntImm(DataType::UInt(32), 2)); + cargs.push_back(call->args[0]); + cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef + // LLVM requires that the return type must match the first argument type + auto clz = tir::Call(call->args[0]->dtype, tir::builtin::call_llvm_intrin(), cargs); + return cast(call->dtype, clz); + }); } // namespace llvm } // namespace codegen diff --git a/src/target/llvm/intrin_rule_llvm.h b/src/target/llvm/intrin_rule_llvm.h index 99463793d8de..a926d7b9be31 100644 --- a/src/target/llvm/intrin_rule_llvm.h +++ b/src/target/llvm/intrin_rule_llvm.h @@ -38,8 +38,7 @@ namespace tvm { namespace codegen { // num_signature means number of arguments used to query signature template -inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { - PrimExpr e = targs[0]; +inline PrimExpr DispatchLLVMPureIntrin(const PrimExpr& e) { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); Array cargs; @@ -50,12 +49,11 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::Call(call->dtype, tir::builtin::call_llvm_pure_intrin(), cargs); + return tir::Call(call->dtype, tir::builtin::call_llvm_pure_intrin(), cargs); } template -inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { - PrimExpr e = targs[0]; +inline PrimExpr DispatchLLVMIntrin(const PrimExpr& e) { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); Array cargs; @@ -65,7 +63,7 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : 
call->args) { cargs.push_back(arg); } - *rv = tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs); + return tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs); } } // namespace codegen diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index bb653e8ee5e0..0ee01a63c042 100644 --- a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -26,14 +26,14 @@ #include #include #include +#include #include namespace tvm { namespace codegen { -inline void DispatchPureExternLibDevice(const TVMArgs& args, TVMRetValue* rv) { - PrimExpr e = args[0]; +inline PrimExpr DispatchPureExternLibDevice(const PrimExpr& e) { using namespace tir; const CallNode* call = e.as(); ICHECK(call != nullptr); @@ -53,54 +53,77 @@ inline void DispatchPureExternLibDevice(const TVMArgs& args, TVMRetValue* rv) { for (auto arg : call->args) { new_args.push_back(arg); } - *rv = Call(call->dtype, builtin::call_pure_extern(), new_args); + return Call(call->dtype, builtin::call_pure_extern(), new_args); } namespace llvm { +using tir::FLowerIntrinsic; -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.floor").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.floor") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.ceil").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.ceil") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.round") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.trunc") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.fabs") + 
.set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.exp").set_attr("nvptx.FLowerIntrinsic", + DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.exp2") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.exp10") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.erf").set_attr("nvptx.FLowerIntrinsic", + DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.fma").set_attr("nvptx.FLowerIntrinsic", + DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.log").set_attr("nvptx.FLowerIntrinsic", + DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.log2") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.log10") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.sqrt") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.pow").set_attr("nvptx.FLowerIntrinsic", + DispatchPureExternLibDevice); 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.tanh") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.tan").set_attr("nvptx.FLowerIntrinsic", + DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.cos").set_attr("nvptx.FLowerIntrinsic", + DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.cosh") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.sin").set_attr("nvptx.FLowerIntrinsic", + DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.sinh") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan").set_body(DispatchPureExternLibDevice); +TVM_REGISTER_OP("tir.atan") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); } // namespace llvm } // namespace codegen diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index 08b32ed1b946..072686868a81 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -26,14 +26,14 @@ #include #include #include +#include #include namespace tvm { namespace codegen { -inline void DispatchPureExternOCML(const TVMArgs& args, TVMRetValue* rv) { - PrimExpr e = args[0]; +inline PrimExpr DispatchPureExternOCML(const PrimExpr& e) { using namespace tir; const CallNode* call = e.as(); ICHECK(call != nullptr); @@ -51,13 +51,12 @@ inline void DispatchPureExternOCML(const TVMArgs& args, TVMRetValue* rv) { 
new_args.push_back(arg); } - *rv = Call(call->dtype, builtin::call_pure_extern(), new_args); + return Call(call->dtype, builtin::call_pure_extern(), new_args); } -inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) { - PrimExpr e_call = targs[0]; +inline PrimExpr DispatchShuffle(const PrimExpr& e) { using namespace tir; - const CallNode* call = e_call.as(); + const CallNode* call = e.as(); ICHECK(call != nullptr); ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size PrimExpr var = call->args[1]; @@ -89,67 +88,93 @@ inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) { } PrimExpr res = Call(var.dtype(), builtin::call_pure_extern(), {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var}); - *rv = res; + return res; } namespace llvm { +using tir::FLowerIntrinsic; // dummy because we don't have the activemask -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_activemask") - .set_body([](const TVMArgs& targs, TVMRetValue* rv) { +TVM_REGISTER_OP("tir.tvm_warp_activemask") + .set_attr("rocm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { PrimExpr zero = tir::make_zero(DataType::Int(32)); - *rv = zero; + return zero; }); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle").set_body(DispatchShuffle); +TVM_REGISTER_OP("tir.tvm_warp_shuffle") + .set_attr("rocm.FLowerIntrinsic", DispatchShuffle); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle_up").set_body(DispatchShuffle); +TVM_REGISTER_OP("tir.tvm_warp_shuffle_up") + .set_attr("rocm.FLowerIntrinsic", DispatchShuffle); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle_down").set_body(DispatchShuffle); +TVM_REGISTER_OP("tir.tvm_warp_shuffle_down") + .set_attr("rocm.FLowerIntrinsic", DispatchShuffle); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.floor") + .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.ceil") + .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.round") + .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.trunc") + .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.fabs") + .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.exp").set_attr("rocm.FLowerIntrinsic", + DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.exp2") + .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.exp10") + .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.erf").set_attr("rocm.FLowerIntrinsic", + DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.fma").set_attr("rocm.FLowerIntrinsic", + DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.log").set_attr("rocm.FLowerIntrinsic", + DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.log2") + .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10").set_body(DispatchPureExternOCML); 
+TVM_REGISTER_OP("tir.log10") + .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.sqrt") + .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.pow").set_attr("rocm.FLowerIntrinsic", + DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.tanh") + .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.tan").set_attr("rocm.FLowerIntrinsic", + DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.cos").set_attr("rocm.FLowerIntrinsic", + DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.cosh") + .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.sin").set_attr("rocm.FLowerIntrinsic", + DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.sinh") + .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan").set_body(DispatchPureExternOCML); +TVM_REGISTER_OP("tir.atan") + .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); } // namespace llvm } // namespace codegen diff --git a/src/target/source/intrin_rule_aocl.cc b/src/target/source/intrin_rule_aocl.cc index 69279a041413..09fc087ca252 100644 --- a/src/target/source/intrin_rule_aocl.cc +++ b/src/target/source/intrin_rule_aocl.cc @@ -21,55 +21,80 @@ * \file intrin_rule_aocl.cc * \brief AOCL intrinsic rules. 
*/ +#include + #include "../intrin_rule.h" namespace tvm { namespace codegen { namespace intrin { +using tir::FLowerIntrinsic; -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.floor").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.floor") + .set_attr("aocl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.ceil").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.ceil") + .set_attr("aocl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.trunc").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.trunc") + .set_attr("aocl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.fabs").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.fabs") + .set_attr("aocl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.round").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.round") + .set_attr("aocl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.exp").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.exp").set_attr("aocl.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.log").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.log").set_attr("aocl.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.tanh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.tanh") + .set_attr("aocl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.sqrt").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.sqrt") + .set_attr("aocl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.pow").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.pow").set_attr("aocl.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.popcount").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.popcount") + .set_attr("aocl.FLowerIntrinsic", DispatchPureExtern); 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.floor").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.floor") + .set_attr("aocl_sw_emu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.ceil").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.ceil") + .set_attr("aocl_sw_emu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.trunc").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.trunc") + .set_attr("aocl_sw_emu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.fabs").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.fabs") + .set_attr("aocl_sw_emu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.round").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.round") + .set_attr("aocl_sw_emu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.exp").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.exp").set_attr("aocl_sw_emu.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.log").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.log").set_attr("aocl_sw_emu.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.tanh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.tanh") + .set_attr("aocl_sw_emu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.sqrt").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.sqrt") + .set_attr("aocl_sw_emu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.pow").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.pow").set_attr("aocl_sw_emu.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.popcount").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.popcount") + .set_attr("aocl_sw_emu.FLowerIntrinsic", DispatchPureExtern); } // namespace 
intrin } // namespace codegen diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index 965b86c24d9e..edbfb4fe5b83 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -30,6 +30,8 @@ namespace tvm { namespace codegen { namespace intrin { // Add float suffix to the intrinsics, CUDA fast math. +using tir::FLowerIntrinsic; + struct CUDAMath { std::string operator()(DataType t, std::string name) const { if (t.is_float()) { @@ -110,79 +112,102 @@ struct CUDAWarpIntrinsic { } }; -static void DispatchCUDAWarpActiveMask(const TVMArgs& args, TVMRetValue* rv) { - Call call = args[0]; - *rv = Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args); +// Rewrite the call to target the tir.cuda.__activemask op, forwarding the original dtype and args. +static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr& e) { + const CallNode* call = e.as(); + ICHECK(call != nullptr); + return Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args); } template -static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) { - PrimExpr e = args[0]; +static PrimExpr DispatchCUDAShuffle(const PrimExpr& e) { const CallNode* call = e.as(); ICHECK(call != nullptr); ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size Array cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}}; - - *rv = Call(call->dtype, T()(call->dtype, Downcast(call->op)), cuda_args); + return Call(call->dtype, T()(call->dtype, Downcast(call->op)), cuda_args); } -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.floor") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.ceil") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.trunc") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.fabs") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.round") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.exp").set_attr("cuda.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.exp2") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.exp10") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.erf").set_attr("cuda.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.log").set_attr("cuda.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.log2") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.log10") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.tan").set_attr("cuda.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.cos").set_attr("cuda.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.cosh") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.sin").set_attr("cuda.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.sinh") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.atan") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.tanh") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.sqrt") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.pow").set_attr("cuda.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.popcount") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle") - .set_body(DispatchCUDAShuffle); +TVM_REGISTER_OP("tir.tvm_warp_shuffle") + .set_attr("cuda.FLowerIntrinsic", DispatchCUDAShuffle); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_up") - .set_body(DispatchCUDAShuffle); +TVM_REGISTER_OP("tir.tvm_warp_shuffle_up") + .set_attr("cuda.FLowerIntrinsic", DispatchCUDAShuffle); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_down") - .set_body(DispatchCUDAShuffle); +TVM_REGISTER_OP("tir.tvm_warp_shuffle_down") + .set_attr("cuda.FLowerIntrinsic", DispatchCUDAShuffle); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_activemask") - .set_body(DispatchCUDAWarpActiveMask); +TVM_REGISTER_OP("tir.tvm_warp_activemask") + .set_attr("cuda.FLowerIntrinsic", DispatchCUDAWarpActiveMask); 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.fmod") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); // Register low-level builtin ops. // TODO(tvm-team): consider make CUDA its own subfolder and create a file for low-level builtins. diff --git a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc index 80a10312c011..3b072fcc7006 100644 --- a/src/target/source/intrin_rule_metal.cc +++ b/src/target/source/intrin_rule_metal.cc @@ -21,51 +21,74 @@ * \file intrin_rule_metal.cc * \brief Metal intrinsic rules. */ +#include + #include "../intrin_rule.h" namespace tvm { namespace codegen { namespace intrin { +using tir::FLowerIntrinsic; -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.floor") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.ceil") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.trunc") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.fabs") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.round") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.exp").set_attr("metal.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp2").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.exp2") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp10").set_body(DispatchPureExtern); 
+TVM_REGISTER_OP("tir.exp10") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.log").set_attr("metal.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log2").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.log2") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log10").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.log10") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.tanh") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.sqrt") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.pow").set_attr("metal.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.popcount") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.fmod") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sin").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.sin").set_attr("metal.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sinh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.sinh") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cos").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.cos").set_attr("metal.FLowerIntrinsic", + DispatchPureExtern); 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cosh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.cosh") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc index 54da5c74ab02..288bb2cfc069 100644 --- a/src/target/source/intrin_rule_opencl.cc +++ b/src/target/source/intrin_rule_opencl.cc @@ -22,57 +22,78 @@ * \brief OpenCL intrinsic rules. */ #include +#include #include "../intrin_rule.h" namespace tvm { namespace codegen { namespace intrin { +using tir::FLowerIntrinsic; -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.floor") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.ceil") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.trunc") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.fabs") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.round") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.exp").set_attr("opencl.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.exp2") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.exp10") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.log").set_attr("opencl.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.log2") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.log10") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.tanh") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.sqrt") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.pow").set_attr("opencl.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.popcount") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.fmod") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.sin").set_attr("opencl.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.sinh") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.cos").set_attr("opencl.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.cosh") + 
.set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); // There is no warp shuffle instruction in standard OpenCL // When shuffle is used, we assume it is intel's shuffle extension -static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) { - PrimExpr e = args[0]; +static PrimExpr DispatchIntelShuffle(const PrimExpr& e) { const CallNode* call = e.as(); ICHECK(call != nullptr); ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size @@ -80,10 +101,11 @@ static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) { ICHECK(analyzer.CanProve(call->args[3] == call->args[4])) << "Intel warp shuffle dose not support width != warp_size"; Array opencl_args{{StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}}; - *rv = Call(call->dtype, builtin::call_pure_extern(), opencl_args); + return Call(call->dtype, builtin::call_pure_extern(), opencl_args); } -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle").set_body(DispatchIntelShuffle); +TVM_REGISTER_OP("tir.tvm_warp_shuffle") + .set_attr("opencl.FLowerIntrinsic", DispatchIntelShuffle); } // namespace intrin } // namespace codegen diff --git a/src/target/source/intrin_rule_vhls.cc b/src/target/source/intrin_rule_vhls.cc index da9bc79452ed..57be8ae17a57 100644 --- a/src/target/source/intrin_rule_vhls.cc +++ b/src/target/source/intrin_rule_vhls.cc @@ -21,49 +21,71 @@ * \file intrin_rule_vhls.cc * \brief VHLS intrinsic rules. 
*/ +#include + #include "../intrin_rule.h" namespace tvm { namespace codegen { namespace intrin { +using tir::FLowerIntrinsic; -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.floor").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.floor") + .set_attr("sdaccel.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.ceil").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.ceil") + .set_attr("sdaccel.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.trunc").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.trunc") + .set_attr("sdaccel.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.fabs").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.fabs") + .set_attr("sdaccel.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.round") + .set_attr("sdaccel.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.exp").set_attr("sdaccel.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp2").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.exp2") + .set_attr("sdaccel.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp10").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.exp10") + .set_attr("sdaccel.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.log").set_attr("sdaccel.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log2").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.log2") + .set_attr("sdaccel.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log10").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.log10") + 
.set_attr("sdaccel.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.tanh") + .set_attr("sdaccel.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sqrt").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.sqrt") + .set_attr("sdaccel.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.pow").set_attr("sdaccel.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.popcount") + .set_attr("sdaccel.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sin").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.sin").set_attr("sdaccel.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sinh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.sinh") + .set_attr("sdaccel.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cos").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.cos").set_attr("sdaccel.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cosh").set_body(DispatchPureExtern); +TVM_REGISTER_OP("tir.cosh") + .set_attr("sdaccel.FLowerIntrinsic", DispatchPureExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index f77e8f4a26f8..3baa77e961cc 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -25,18 +25,17 @@ #include #include #include +#include namespace tvm { namespace codegen { namespace spirv { - -using namespace runtime; +using tir::FLowerIntrinsic; // num_signature means number of arguments used to query signature template -PrimExpr CallGLSLIntrin(const TVMArgs& targs, 
TVMRetValue* rv) { - PrimExpr e = targs[0]; +PrimExpr CallGLSLIntrin(const PrimExpr& e) { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); Array cargs; @@ -50,73 +49,89 @@ PrimExpr CallGLSLIntrin(const TVMArgs& targs, TVMRetValue* rv) { } template -inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { - *rv = CallGLSLIntrin(targs, rv); +inline PrimExpr DispatchGLSLPureIntrin(const PrimExpr& e) { + return CallGLSLIntrin(e); } -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor") - .set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.floor") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.ceil").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.ceil") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.round") - .set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.round") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.trunc") - .set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.trunc") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.fabs") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.exp").set_attr("vulkan.FLowerIntrinsic", + DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sin").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.sin").set_attr("vulkan.FLowerIntrinsic", + DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.cos").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.cos").set_attr("vulkan.FLowerIntrinsic", + DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log").set_body(DispatchGLSLPureIntrin); 
+TVM_REGISTER_OP("tir.log").set_attr("vulkan.FLowerIntrinsic", + DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log2").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.log2") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.sqrt") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.pow").set_attr("vulkan.FLowerIntrinsic", + DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.tanh") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.clz") - .set_body([](const TVMArgs& targs, TVMRetValue* rv) { - PrimExpr e = targs[0]; +TVM_REGISTER_OP("tir.clz").set_attr( + "vulkan.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); ICHECK_EQ(call->args.size(), 1); PrimExpr arg = call->args[0]; - PrimExpr msb = CallGLSLIntrin(targs, rv); - *rv = PrimExpr(arg.dtype().bits() - 1) - msb; + PrimExpr msb = CallGLSLIntrin(e); + return PrimExpr(arg.dtype().bits() - 1) - msb; }); // WebGPU rules. 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.floor") - .set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.floor") + .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.ceil").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.ceil") + .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.round") - .set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.round") + .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.trunc") - .set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.trunc") + .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.fabs").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.fabs") + .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.exp").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.exp").set_attr("webgpu.FLowerIntrinsic", + DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.log").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.log").set_attr("webgpu.FLowerIntrinsic", + DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.sqrt").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.sqrt") + .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.pow").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.pow").set_attr("webgpu.FLowerIntrinsic", + DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.tanh").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.tanh") + .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); } // namespace spirv } // namespace codegen diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index cd7c10ffa688..4101891db699 100644 --- a/src/tir/transforms/lower_intrin.cc +++ 
b/src/tir/transforms/lower_intrin.cc @@ -42,28 +42,42 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "") : IRMutatorWithAnalyzer(analyzer) { - patterns_.push_back("tvm.intrin.rule." + target + "."); + std::vector patterns_; + patterns_.push_back(target + ".FLowerIntrinsic"); bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos); if (is_llvm_aarch64) { - patterns_.push_back("tvm.intrin.rule." + target + "." + "aarch64."); + patterns_.push_back(target + ".aarch64.FLowerIntrinsic"); } - patterns_.push_back("tvm.intrin.rule.default."); - fma_ = runtime::Registry::Get(patterns_[0] + "fma"); + patterns_.push_back("default.FLowerIntrinsic"); + + fma_ = runtime::Registry::Get("tvm.intrin.rule." + target + ".fma"); if (target == "stackvm") { support_bitwise_op_ = false; } + + for (const std::string& pattern : patterns_) + if (Op::HasAttrMap(pattern)) + lower_intrin_maps_.push_back(Op::GetAttrMap(pattern)); } PrimExpr VisitExpr_(const CallNode* op) final { if (auto* ptr_op = op->op.as()) { - // Still use legacy string based rewriting - // TODO(tvm-team): migrate the pattern application from global function look up - // to an OpAttrMap - std::string name = ptr_op->name; - PrimExpr r = ApplyPattern(name, GetRef(op)); - if (r.defined()) return r; + for (const auto& f_lower_intrin_map : lower_intrin_maps_) { + FLowerIntrinsic f = f_lower_intrin_map.get(GetRef(ptr_op), nullptr); + if (f != nullptr) { + PrimExpr e = GetRef(op); + PrimExpr r = f(e); + ICHECK(r.defined()) << "intrinsic rule must always return valid Expr"; + if (!r.same_as(e)) { + r = this->VisitExpr(r); + if (r.defined()) { + return r; + } + } + } + } } return IRMutatorWithAnalyzer::VisitExpr_(op); } @@ -266,32 +280,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { return IRMutatorWithAnalyzer::VisitExpr_(op); } - PrimExpr ApplyPattern(std::string name, const PrimExpr& e) { - 
if (name.compare(0, 4, "tir.") == 0) { - name = name.substr(4); - } - - for (size_t i = 0; i < patterns_.size(); ++i) { - std::string& p = patterns_[i]; - size_t psize = p.length(); - p.resize(psize + name.length()); - name.copy(&p[0] + psize, name.length()); - const runtime::PackedFunc* f = runtime::Registry::Get(p); - p.resize(psize); - // if pattern exists. - if (f != nullptr) { - PrimExpr r = (*f)(e); - ICHECK(r.defined()) << "intrinsic rule must always return valid Expr"; - if (!r.same_as(e)) { - return this->VisitExpr(r); - } - } - } - return PrimExpr(); - } - // patterns - std::vector patterns_; + std::vector> lower_intrin_maps_; const PackedFunc* fma_{nullptr}; bool support_bitwise_op_{true}; }; diff --git a/tutorials/language/intrin_math.py b/tutorials/language/intrin_math.py index 145322586b41..92383b90a53f 100644 --- a/tutorials/language/intrin_math.py +++ b/tutorials/language/intrin_math.py @@ -29,10 +29,11 @@ the interface via tvm's intrinsic API. """ from __future__ import absolute_import, print_function +import numpy as np import tvm from tvm import te -import numpy as np +from tvm.ir import register_op_attr, register_intrin_lowering ###################################################################### # Direct Declare Extern Math Call @@ -112,7 +113,7 @@ def my_cuda_math_rule(op): return op -tvm.target.register_intrin_rule("cuda", "exp", my_cuda_math_rule, override=True) +register_intrin_lowering("tir.exp", target="cuda", f=my_cuda_math_rule, level=99) ###################################################################### # Register the rule to TVM with override option to override existing rule. 
# Notice the difference between the printed code from previous one: @@ -147,8 +148,8 @@ def my_cuda_mylog_rule(op): # new op registration is triggered by registering an attribute of the op -tvm.ir.register_op_attr("tir.mylog", "TCallEffectKind", tvm.tir.CallEffectKind.Pure) -tvm.target.register_intrin_rule("cuda", "mylog", my_cuda_mylog_rule, override=True) +register_op_attr("tir.mylog", "TCallEffectKind", tvm.tir.CallEffectKind.Pure) +register_intrin_lowering("tir.mylog", target="cuda", f=my_cuda_mylog_rule, level=99) n = te.var("n") A = te.placeholder((n,), name="A") diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py index 4b6e5bdeca78..9181a44fa523 100644 --- a/vta/python/vta/environment.py +++ b/vta/python/vta/environment.py @@ -23,6 +23,7 @@ import copy import tvm from tvm import te +from tvm.ir import register_intrin_lowering from . import intrin @@ -291,8 +292,8 @@ def mem_info_acc_buffer(): ) -# TVM related registration -@tvm.register_func("tvm.intrin.rule.default.vta.coproc_sync") +# TVM Op related registration +@register_intrin_lowering("tir.vta.coproc_sync", "default") def coproc_sync(op): _ = op return tvm.tir.call_extern( @@ -303,14 +304,14 @@ def coproc_sync(op): ) -@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push") +@register_intrin_lowering("tir.vta.coproc_dep_push", "default") def coproc_dep_push(op): return tvm.tir.call_extern( "int32", "VTADepPush", get_env().dev.command_handle, op.args[0], op.args[1] ) -@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_pop") +@register_intrin_lowering("tir.vta.coproc_dep_pop", "default") def coproc_dep_pop(op): return tvm.tir.call_extern( "int32", "VTADepPop", get_env().dev.command_handle, op.args[0], op.args[1]