Skip to content

Commit

Permalink
use clz for ceil_log2 when compiling for vk
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 16, 2021
1 parent aabc763 commit a70dd1d
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 11 deletions.
7 changes: 3 additions & 4 deletions python/tvm/topi/cuda/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from tvm.contrib.thrust import can_use_rocthrust, can_use_thrust

from .. import tag
from ..math import cast
from ..math import cast, ceil_log2
from ..transform import expand_dims, reshape, squeeze, transpose
from ..utils import ceil_div, get_const_int, prod, swap
from .injective import schedule_injective_from_existing
Expand Down Expand Up @@ -103,9 +103,8 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i

# The following algorithm performs parallel exclusive scan
# Up Sweep of exclusive scan
lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int64"
)
lim = ceil_log2(scan_axis_size)

with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << l2_width

Expand Down
10 changes: 3 additions & 7 deletions python/tvm/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..transform import strided_slice, transpose
from .. import tag
from ..utils import ceil_div, swap
from ..math import cast
from ..math import cast, ceil_log2


def _schedule_sort(outs):
Expand Down Expand Up @@ -238,9 +238,7 @@ def compare(a, b):
return out

# Sort the lower levels of the merge using odd-even sort, it's fast for small inputs
lower_lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(block_size, "float64"))), "int64"
)
lower_lim = ceil_log2(block_size)

_odd_even_sort(
ib,
Expand All @@ -254,9 +252,7 @@ def compare(a, b):
values_swap,
)

upper_lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64"
)
upper_lim = ceil_log2(size)

def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, diag, step_count):
first = ib.allocate("int64", (1,), name="first", scope="local")
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,3 +742,25 @@ def fast_erf(x):
The result.
"""
return cpp.fast_erf(x, x.dtype, tag.ELEMWISE)


def ceil_log2(x):
"""TODO"""
if not isinstance(x, tvm.tir.PrimExpr):
x = tvm.tir.const(x)

if "float" in x.dtype:
return tvm.tir.ceil(tvm.tir.log2(x))

if "vulkan" in tvm.target.Target.current().kind.name:
# SPIR-V does not support log2 on fp64. Instead, we compute ceil_log2 via clz
clz = tvm.tir.clz(x)
bits = int(x.dtype[-2:])
ceil_log2 = tvm.tir.if_then_else(x & (x - 1) == 0, bits - clz - 1, bits - clz)

if ceil_log2.dtype != x.dtype:
return cast(ceil_log2, x.dtype)

return ceil_log2

return cast(tvm.tir.ceil(tvm.tir.log2(cast(x, "float64"))), x.dtype)

0 comments on commit a70dd1d

Please sign in to comment.