diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 5d3798e3d27b..6dbaf02191c8 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -23,7 +23,7 @@ from tvm.contrib.thrust import can_use_rocthrust, can_use_thrust from .. import tag -from ..math import cast +from ..math import cast, ceil_log2 from ..transform import expand_dims, reshape, squeeze, transpose from ..utils import ceil_div, get_const_int, prod, swap from .injective import schedule_injective_from_existing @@ -103,9 +103,8 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i # The following algorithm performs parallel exclusive scan # Up Sweep of exclusive scan - lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int64" - ) + lim = ceil_log2(scan_axis_size) + with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << l2_width diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 93e4d3feccc7..25cc7a4e2cfb 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -23,7 +23,7 @@ from ..transform import strided_slice, transpose from .. 
def ceil_log2(x):
    """Compute ceil(log2(x)), with an integer-only code path for Vulkan.

    SPIR-V does not support log2 on fp64. When the current target is Vulkan,
    the integer result is therefore derived from the count-leading-zeros
    (clz) intrinsic instead of going through a float64 log2.

    Parameters
    ----------
    x : tvm.tir.PrimExpr or Python scalar
        Input argument. A value that is not already a PrimExpr is wrapped
        with ``tvm.tir.const``.

    Returns
    -------
    y : tvm.tir.PrimExpr
        The result. For a float input this is ``ceil(log2(x))`` as produced
        by ``tvm.tir.ceil``; for an integer input the result is cast to the
        dtype of ``x``.
    """
    if not isinstance(x, tvm.tir.PrimExpr):
        x = tvm.tir.const(x)

    if "float" in x.dtype:
        return tvm.tir.ceil(tvm.tir.log2(x))

    # Target.current() may return None when no target scope is active; in
    # that case fall through to the generic float64 path instead of raising
    # AttributeError on `.kind`.
    target = tvm.target.Target.current()
    if target is not None and "vulkan" in target.kind.name:
        clz = tvm.tir.clz(x)
        # Read the bit width as everything after "int" so that 8-bit dtypes
        # ("int8"/"uint8") parse as well as the two-digit widths.
        bits = int(x.dtype.rpartition("int")[2])
        # For an exact power of two, ceil(log2(x)) is the index of the single
        # set bit; otherwise it is one more than the index of the top bit.
        res = tvm.tir.if_then_else(x & (x - 1) == 0, bits - clz - 1, bits - clz)

        if res.dtype != x.dtype:
            return cast(res, x.dtype)

        return res

    return cast(tvm.tir.ceil(tvm.tir.log2(cast(x, "float64"))), x.dtype)