Skip to content

Commit

Permalink
[Target][Lowering] Update Op Intrinsic Lowering Mechanism And Intrins…
Browse files Browse the repository at this point in the history
…ic Lowering Pass (apache#7809)

This PR updates the intrinsic lowering pass to support the new op registry and to avoid overloading the global TVM registry. It keeps the fallback mechanism for finding the most suitable intrinsic lowering function, e.g., llvm.FLowerIntrinsic vs. default.FLowerIntrinsic. All previous op registrations are ported to the new functions; some missing ops will be added in a separate PR.
  • Loading branch information
zxybazh authored and trevor-m committed May 11, 2021
1 parent 949c9a3 commit 95f8a8c
Show file tree
Hide file tree
Showing 24 changed files with 708 additions and 485 deletions.
13 changes: 12 additions & 1 deletion include/tvm/tir/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@
#ifndef TVM_TIR_OP_ATTR_TYPES_H_
#define TVM_TIR_OP_ATTR_TYPES_H_

#include <tvm/ir/expr.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/packed_func.h>

namespace tvm {
namespace tir {

/*!
* \brief Global symbol of the op after lowering.
*/
Expand All @@ -43,6 +44,16 @@ using TGlobalSymbol = String;
*/
using TVectorizable = bool;

/*!
* \brief The intrinsic lowering function for given op.
*/
using FLowerIntrinsic = runtime::TypedPackedFunc<PrimExpr(PrimExpr)>;

/*!
* \brief The legalization function for given tir op.
*/
using FLegalize = runtime::TypedPackedFunc<PrimExpr(PrimExpr)>;

/*!
* \brief The effect type of the call.
*/
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
from .op import Op, register_op_attr
from .op import Op, register_op_attr, register_intrin_lowering
from .function import CallingConv, BaseFunc
from .adt import Constructor, TypeData
from .module import IRModule
Expand Down
37 changes: 37 additions & 0 deletions python/tvm/ir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,40 @@ def _register(v):
return v

return _register(value) if value is not None else _register


def register_intrin_lowering(
    op_name,
    target,
    *,
    f=None,
    level=10,
):
    """Register an intrinsic lowering function for an op on a given target.

    Can be called directly with ``f``, or without ``f`` to obtain a
    decorator that performs the registration when applied.

    Parameters
    ----------
    op_name : str
        The op name.

    target : str
        The target string the lowering function applies to.

    f : function, optional
        The lowering function to register. Keyword-only.

    level : int
        The priority level forwarded to the registry. Keyword-only.

    Returns
    -------
    fregister : function
        The registered function when ``f`` is given; otherwise a function
        that registers (and returns) whatever callable it is handed.
    """

    def _register(func):
        """Forward ``func`` to the op registry and hand it back unchanged."""
        _ffi_api.RegisterOpLowerIntrinsic(op_name, func, target, level)
        return func

    # Decorator usage: nothing to register yet, so return the registrar.
    if f is None:
        return _register
    return _register(f)
1 change: 0 additions & 1 deletion python/tvm/target/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,3 @@
from .generic_func import generic_func, get_native_generic_func, override_native_generic_func
from . import datatype
from . import codegen
from .intrin import register_intrin_rule
54 changes: 5 additions & 49 deletions python/tvm/target/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,54 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""Target dependent intrinsic registration."""
import tvm._ffi
from tvm.ir import register_intrin_lowering
from tvm.tir import call_pure_extern


# Intrinsic rule related code
def register_intrin_rule(target, intrin, f=None, override=False):
    """Register a generation rule for an intrinsic function.

    Intrinsic generation rules are callbacks the code generator uses to
    obtain device-specific calls. This helper is shorthand for
    :code:`register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)`.

    TVM may already pre-register intrinsic rules in the backend; this
    function lets users change the intrinsic translation behavior or add
    new intrinsic rules at runtime.

    Parameters
    ----------
    target : str
        The name of the codegen target.

    intrin : str
        The name of the intrinsic.

    f : function, optional
        The function to be registered.

    override : boolean, optional
        Whether to override an existing entry.

    Returns
    -------
    fregister : function
        Register function if f is not specified.

    Examples
    --------
    The following code registers an exp expansion rule for opencl.

    .. code-block:: python

        register_intrin_rule("opencl", "exp", my_exp_rule, override=True)
    """
    # Global registry key under which the rule is stored.
    rule_name = "tvm.intrin.rule.%s.%s" % (target, intrin)
    return tvm._ffi.register_func(rule_name, f, override)


def _rule_float_suffix(op):
"""Intrinsic rule: Add float suffix if it is float32.
Expand All @@ -81,7 +37,7 @@ def _rule_float_suffix(op):
See Also
--------
register_intrin_rule : The registration function for intrin rule.
register_intrin_lowering : The registration function for intrinsic lowering rule.
"""
name = op.op.name
assert name.startswith("tir.")
Expand Down Expand Up @@ -112,14 +68,14 @@ def _rule_float_direct(op):
See Also
--------
register_intrin_rule : The registration function for intrin rule.
register_intrin_lowering : The registration function for intrinsic lowering rule.
"""
if str(op.dtype).startswith("float"):
return call_pure_extern(op.dtype, op.op.name[4:], *op.args)
return None


# opencl pattern for exp
register_intrin_rule("opencl", "exp", _rule_float_direct, override=True)
register_intrin_lowering("tir.exp", target="opencl", f=_rule_float_direct, level=99)
# default pattern for exp
register_intrin_rule("default", "exp", _rule_float_suffix, override=True)
register_intrin_lowering("tir.exp", target="default", f=_rule_float_suffix, level=99)
5 changes: 3 additions & 2 deletions python/tvm/topi/arm_cpu/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import tvm
from tvm import te
from tvm.ir import register_intrin_lowering


def gemm_4x4_int8_int8_int32(M, N, K, unroll, in_type):
Expand Down Expand Up @@ -1054,6 +1055,6 @@ def _q_multiply_shift_arm(op):
return tvm.tir.Select(s < 0, out_1, out_2)


tvm.target.intrin.register_intrin_rule(
"llvm.aarch64", "q_multiply_shift", _q_multiply_shift_arm, override=True
register_intrin_lowering(
"tir.q_multiply_shift", target="llvm.aarch64", f=_q_multiply_shift_arm, level=99
)
7 changes: 3 additions & 4 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tvm import te
from tvm.contrib import nvcc
from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust
from tvm.ir import register_intrin_lowering
from tvm.tir import if_then_else
from .sort import argsort, argsort_thrust
from .scan import exclusive_scan
Expand Down Expand Up @@ -51,11 +52,9 @@ def opencl_atomic_add_rule(op):
raise RuntimeError("only support int32")


tvm.target.intrin.register_intrin_rule("cuda", "atomic_add", cuda_atomic_add_rule, override=True)
register_intrin_lowering("tir.atomic_add", target="cuda", f=cuda_atomic_add_rule, level=99)

tvm.target.intrin.register_intrin_rule(
"opencl", "atomic_add", opencl_atomic_add_rule, override=True
)
register_intrin_lowering("tir.atomic_add", target="opencl", f=opencl_atomic_add_rule, level=99)


def atomic_add(x, y):
Expand Down
8 changes: 8 additions & 0 deletions src/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <tvm/runtime/container.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/tir/op_attr_types.h>

#include <memory>

Expand All @@ -36,6 +37,7 @@ namespace tvm {
using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
using tir::FLowerIntrinsic;

using OpRegistry = AttrRegistry<OpRegEntry, Op>;

Expand Down Expand Up @@ -122,6 +124,12 @@ TVM_REGISTER_GLOBAL("ir.RegisterOpAttr")
}
});

// Global function "ir.RegisterOpLowerIntrinsic": attaches the packed function `f` to the op
// entry for `name` as an FLowerIntrinsic attribute keyed "<target>.FLowerIntrinsic", passing
// `plevel` as the attribute's priority level.
// NOTE(review): presumably the backend of Python's register_intrin_lowering — confirm the
// _ffi_api binding.
TVM_REGISTER_GLOBAL("ir.RegisterOpLowerIntrinsic")
.set_body_typed([](String name, PackedFunc f, String target, int plevel) {
tvm::OpRegEntry::RegisterOrGet(name).set_attr<FLowerIntrinsic>(target + ".FLowerIntrinsic", f,
plevel);
});

// helper to get internal dev function in objectref.
struct Op2ObjectPtr : public ObjectRef {
static ObjectPtr<Object> Get(const Op& op) { return GetDataPtr<Object>(op); }
Expand Down
Loading

0 comments on commit 95f8a8c

Please sign in to comment.