Skip to content

Commit

Permalink
making binop argument tir function
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jan 26, 2021
1 parent ac53832 commit 395215c
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions python/tvm/topi/cuda/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@
from .injective import schedule_injective_from_existing


binop_name_to_func = {"sum": tvm.tir.generic.add}
def _get_thrust_func_name(tvmop):
    """Return the name of the thrust packed function implementing ``tvmop``.

    Parameters
    ----------
    tvmop : callable
        The binary operator to look up. Only ``tvm.tir.generic.add`` has an
        entry in the lookup table.

    Returns
    -------
    name : str
        The packed function name registered for ``tvmop``
        (``"tvm.contrib.thrust.sum_scan"`` for ``tvm.tir.generic.add``).
    """
    # Table of binary operators that have a thrust scan counterpart.
    supported_ops = {tvm.tir.generic.add: "tvm.contrib.thrust.sum_scan"}
    # Fail with a readable message rather than a bare KeyError on lookup.
    assert tvmop in supported_ops, "{} not supported by thrust".format(tvmop)
    return supported_ops[tvmop]


def exclusive_scan_ir(data, output, reduction=None, binop="sum"):
def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add):
"""Low level IR to do exclusive sum scan along rows of 2D input.
Parameters
Expand Down Expand Up @@ -93,7 +96,6 @@ def exclusive_scan_ir(data, output, reduction=None, binop="sum"):
lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int64"
)
op = binop_name_to_func[binop]
with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << l2_width

Expand All @@ -118,7 +120,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop="sum"):
middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
end[0] = tvm.te.min(start[0] + width, scan_axis_size)
with ib.if_scope(middle[0] < scan_axis_size):
output[by * scan_axis_size + end[0] - 1] = op(
output[by * scan_axis_size + end[0] - 1] = binop(
output[by * scan_axis_size + end[0] - 1],
output[by * scan_axis_size + middle[0] - 1],
)
Expand Down Expand Up @@ -161,13 +163,13 @@ def exclusive_scan_ir(data, output, reduction=None, binop="sum"):
output[by * scan_axis_size + middle[0] - 1] = output[
by * scan_axis_size + end[0] - 1
]
output[by * scan_axis_size + end[0] - 1] = op(
output[by * scan_axis_size + end[0] - 1] = binop(
output[by * scan_axis_size + end[0] - 1], tmp[0]
)
return ib.get()


def get_reduction_from_exclusive_scan(data, ex_scan_output, binop="sum"):
def get_reduction_from_exclusive_scan(data, ex_scan_output, binop=tvm.tir.generic.add):
"""Return the sum of the last element of data and the exclusive scan output.
This is the reduction of data along each row (for 2-D case).
Expand Down Expand Up @@ -213,7 +215,7 @@ def ir(data, data_ex_scan, reduction):
tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size):
with ib.if_scope(scan_axis_size > 0):
reduction[tid] = binop_name_to_func[binop](
reduction[tid] = binop(
data_ex_scan[tid * scan_axis_size + scan_axis_size - 1],
data[tid, scan_axis_size - 1],
)
Expand Down Expand Up @@ -248,7 +250,9 @@ def is_thrust_available():
return get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True) is not None


def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False, binop="sum"):
def scan_thrust(
data, output_dtype, exclusive=True, return_reduction=False, binop=tvm.tir.generic.add
):
"""Do exclusive or inclusive scan on 1D or multidimensional input, using thrust.
Parameters
Expand Down Expand Up @@ -281,12 +285,12 @@ def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False, bino
"""
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8)
binop_to_thrust_func_name = {"sum": "tvm.contrib.thrust.sum_scan"}

output = te.extern(
[data.shape],
[data],
lambda ins, outs: tvm.tir.call_packed(
binop_to_thrust_func_name[binop], ins[0], outs[0], exclusive
_get_thrust_func_name(binop), ins[0], outs[0], exclusive
),
dtype=[output_dtype],
in_buffers=[data_buf],
Expand All @@ -303,7 +307,9 @@ def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False, bino
return output


def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None, binop="sum"):
def exclusive_scan(
data, axis=-1, return_reduction=False, output_dtype=None, binop=tvm.tir.generic.add
):
"""Do exclusive scan on 1D or multidimensional input.
Parameters
Expand Down Expand Up @@ -411,7 +417,7 @@ def do_scan(data, output_dtype):
return output


def inclusive_scan(data, axis=-1, output_dtype=None, binop="sum"):
def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add):
"""Do inclusive scan on 1D or multidimensional input.
Parameters
Expand All @@ -438,7 +444,7 @@ def inclusive_scan(data, axis=-1, output_dtype=None, binop="sum"):
if output_dtype is not None and data.dtype != output_dtype and output_dtype != "":
data = cast(data, output_dtype)

return binop_name_to_func[binop](data, ex_scan)
return binop(data, ex_scan)


def schedule_scan(outs):
Expand Down Expand Up @@ -498,4 +504,4 @@ def cumsum(data, axis=None, dtype=None):
axis = 0
data = reshape(data, (prod(data.shape),))
axis = get_const_int(axis)
return inclusive_scan(data, axis, output_dtype=dtype, binop="sum")
return inclusive_scan(data, axis, output_dtype=dtype, binop=tvm.tir.generic.add)

0 comments on commit 395215c

Please sign in to comment.