Skip to content

Commit

Permalink
[TIR] Add a new intrinsic count leading zeros for LLVM and SPIR-V (#7825
Browse files Browse the repository at this point in the history
)
  • Loading branch information
masahi authored Apr 16, 2021
1 parent aa9cb63 commit cc79e8f
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 3 deletions.
1 change: 1 addition & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,7 @@ TVM_DECLARE_INTRIN_UNARY(atan);
TVM_DECLARE_INTRIN_UNARY(acosh);
TVM_DECLARE_INTRIN_UNARY(asinh);
TVM_DECLARE_INTRIN_UNARY(atanh);
TVM_DECLARE_INTRIN_UNARY(clz);

#define TVM_DECLARE_INTRIN_BINARY(OpName) \
inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) { \
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

from .op import call_packed, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
from .op import sin, sinh, asin, asinh
from .op import cos, cosh, acos, acosh
from .op import tan, tanh, atan, atan2, atanh
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,22 @@ def rsqrt(x):
return call_intrin(x.dtype, "tir.rsqrt", x)


def clz(x):
"""Count leading zero bits of an integer x.
Parameters
----------
x : PrimExpr
Input argument. The result is undefined if the input is 0.
Returns
-------
y : PrimExpr
The result.
"""
return call_intrin("int32", "tir.clz", x)


def floor(x, span=None):
"""Take floor of float input x.
Expand Down
15 changes: 15 additions & 0 deletions src/target/llvm/intrin_rule_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh")
*rv = ret;
});

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.clz").set_body([](const TVMArgs& targs, TVMRetValue* rv) {
PrimExpr e = targs[0];
const tir::CallNode* call = e.as<tir::CallNode>();
ICHECK(call != nullptr);
ICHECK_EQ(call->args.size(), 1);
Array<PrimExpr> cargs;
cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz));
cargs.push_back(IntImm(DataType::UInt(32), 2));
cargs.push_back(call->args[0]);
cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef
// LLVM requires that the return type must match the first argument type
auto clz = tir::Call(call->args[0]->dtype, tir::builtin::call_llvm_intrin(), cargs);
*rv = cast(call->dtype, clz);
});

} // namespace llvm
} // namespace codegen
} // namespace tvm
Expand Down
22 changes: 20 additions & 2 deletions src/target/spirv/intrin_rule_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

namespace tvm {
namespace codegen {
Expand All @@ -32,8 +33,9 @@ namespace spirv {
using namespace runtime;

// num_signature means number of arguments used to query signature

template <unsigned id>
inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
PrimExpr CallGLSLIntrin(const TVMArgs& targs, TVMRetValue* rv) {
PrimExpr e = targs[0];
const tir::CallNode* call = e.as<tir::CallNode>();
ICHECK(call != nullptr);
Expand All @@ -44,7 +46,12 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
for (PrimExpr arg : call->args) {
cargs.push_back(arg);
}
*rv = tir::Call(call->dtype, tir::builtin::call_spirv_pure_glsl450(), cargs);
return tir::Call(call->dtype, tir::builtin::call_spirv_pure_glsl450(), cargs);
}

template <unsigned id>
inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
*rv = CallGLSLIntrin<id>(targs, rv);
}

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
Expand Down Expand Up @@ -76,6 +83,17 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow").set_body(DispatchGLSLPureIntri

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh").set_body(DispatchGLSLPureIntrin<GLSLstd450Tanh>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.clz")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
PrimExpr e = targs[0];
const tir::CallNode* call = e.as<tir::CallNode>();
ICHECK(call != nullptr);
ICHECK_EQ(call->args.size(), 1);
PrimExpr arg = call->args[0];
PrimExpr msb = CallGLSLIntrin<GLSLstd450FindUMsb>(targs, rv);
*rv = PrimExpr(arg.dtype().bits() - 1) - msb;
});

// WebGPU rules.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.floor")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Floor>);
Expand Down
2 changes: 2 additions & 0 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,8 @@ TIR_REGISTER_PURE_UNARY_OP("tir.asinh");

TIR_REGISTER_PURE_UNARY_OP("tir.atanh");

TIR_REGISTER_PURE_UNARY_OP("tir.clz");

// binary intrinsics
TIR_REGISTER_PURE_BINARY_OP("tir.atan2");

Expand Down
38 changes: 38 additions & 0 deletions tests/python/unittest/test_tir_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,47 @@ def test_ldexp():
)


def test_clz():
def clz_np(x, dtype):
ceil_log2 = np.ceil(np.log2(x)).astype(dtype)
bits = int(dtype[-2:])
clz = bits - ceil_log2
clz[np.bitwise_and(x, x - 1) == 0] -= 1
return clz

for target in ["llvm", "vulkan"]:
if not tvm.testing.device_enabled("vulkan"):
continue

for dtype in ["int32", "int64"]:
m = te.var("m")
A = te.placeholder((m,), name="A", dtype=dtype)
B = te.compute((m,), lambda *i: tvm.tir.clz(A(*i)), name="B")
s = te.create_schedule(B.op)

if target == "vulkan":
bx, tx = s[B].split(B.op.axis[0], factor=64)

s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))

f = tvm.build(s, [A, B], target)
dev = tvm.device(target, 0)
n = 10

for high in [10, 100, 1000, 10000, 100000, 1000000]:
a_np = np.random.randint(1, high=high, size=(n,)).astype(dtype)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.zeros((n,)).astype("int32"), dev)
f(a, b)
ref = clz_np(a_np, dtype)
np.testing.assert_equal(b.asnumpy(), ref)


if __name__ == "__main__":
test_nearbyint()
test_unary_intrin()
test_round_intrinsics_on_int()
test_binary_intrin()
test_ldexp()
test_clz()

0 comments on commit cc79e8f

Please sign in to comment.