Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI][SPIRV] Cast to float32 not float64 before log2 in sort/scan #7669

Merged
merged 8 commits into from
Apr 17, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions python/tvm/topi/cuda/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from tvm.contrib.thrust import can_use_rocthrust, can_use_thrust

from .. import tag
from ..math import cast
from ..math import cast, ceil_log2
from ..transform import expand_dims, reshape, squeeze, transpose
from ..utils import ceil_div, get_const_int, prod, swap
from .injective import schedule_injective_from_existing
Expand Down Expand Up @@ -103,9 +103,8 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i

# The following algorithm performs parallel exclusive scan
# Up Sweep of exclusive scan
lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int64"
)
lim = ceil_log2(scan_axis_size)

with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << l2_width

Expand Down
10 changes: 3 additions & 7 deletions python/tvm/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..transform import strided_slice, transpose
from .. import tag
from ..utils import ceil_div, swap
from ..math import cast
from ..math import cast, ceil_log2


def _schedule_sort(outs):
Expand Down Expand Up @@ -238,9 +238,7 @@ def compare(a, b):
return out

# Sort the lower levels of the merge using odd-even sort, it's fast for small inputs
lower_lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(block_size, "float64"))), "int64"
)
lower_lim = ceil_log2(block_size)

_odd_even_sort(
ib,
Expand All @@ -254,9 +252,7 @@ def compare(a, b):
values_swap,
)

upper_lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64"
)
upper_lim = ceil_log2(size)

def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, diag, step_count):
first = ib.allocate("int64", (1,), name="first", scope="local")
Expand Down
34 changes: 34 additions & 0 deletions python/tvm/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,3 +742,37 @@ def fast_erf(x):
The result.
"""
return cpp.fast_erf(x, x.dtype, tag.ELEMWISE)


def ceil_log2(x):
"""Compute integer ceil log2 with a special code path for vulkan
SPIR-V does not support log2 on fp64. Instead, we compute integer ceil_log2 via clz
intrinsic when the target is vulkan.

Parameters
----------
x : tvm.te.Tensor
Input argument.

Returns
-------
y : tvm.te.Tensor
The result.
"""
if not isinstance(x, tvm.tir.PrimExpr):
x = tvm.tir.const(x)

if "float" in x.dtype:
return tvm.tir.ceil(tvm.tir.log2(x))

if "vulkan" in tvm.target.Target.current().kind.name:
clz = tvm.tir.clz(x)
bits = int(x.dtype[-2:])
ceil_log2 = tvm.tir.if_then_else(x & (x - 1) == 0, bits - clz - 1, bits - clz)

if ceil_log2.dtype != x.dtype:
return cast(ceil_log2, x.dtype)

return ceil_log2

return cast(tvm.tir.ceil(tvm.tir.log2(cast(x, "float64"))), x.dtype)
10 changes: 10 additions & 0 deletions tests/python/unittest/test_target_codegen_spirv.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,16 @@ def test_pushconstants():

check_mod(mod, x_np, res_np)

# One 64 bit and one 32 bit constants
dtype = "int32"
x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
mod = tvm.IRModule()
mod["main"] = relay.Function([x], relay.cumsum(x))
x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
res_np = np.cumsum(x_np)

check_mod(mod, x_np, res_np)


def test_unique():
if not tvm.testing.device_enabled("vulkan"):
Expand Down