[CUDA] Parallel Cuda Mergesort #7099

Merged
merged 11 commits into from Dec 21, 2020

Changes from all commits
2 changes: 1 addition & 1 deletion python/tvm/driver/build_module.py
@@ -277,7 +277,7 @@ def _build_for_device(input_mod, target, target_host):
lambda f: "calling_conv" not in f.attrs
or f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH
),
- tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
+ tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)),
@masahi (Member) commented on Dec 15, 2020:
For the record, the segfault with nvptx was happening because the generated host code was calling intrinsics registered for nvptx, like __nv_log2 or __nv_ceil. The reason it worked on CUDA was pure coincidence: there are no CUDA intrinsics registered for fp64 log2 or ceil, so TVM fell back to the default lowering, which happens to be the right one (llvm).

This change fixes that issue.
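To illustrate the distinction (a minimal sketch, assuming the TVM ~0.7-era te/tvm.build Python API; the shapes and schedule are arbitrary, not from this PR):

import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), dtype="float64", name="A")
B = te.compute((n,), lambda i: tvm.tir.log2(A[i]), name="B")
s = te.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
# Host-side TIR (argument checks, kernel-launch logic) now gets target_host
# ("llvm") instead of the device target, so it no longer picks up
# device-only intrinsic lowerings such as __nv_log2.
f = tvm.build(s, [A, B], target="nvptx", target_host="llvm")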

tvm.tir.transform.LowerTVMBuiltin(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerCustomDatatypes(),
292 changes: 224 additions & 68 deletions python/tvm/topi/cuda/sort.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
- # pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+ # pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument, no-else-return
"""Sort related operators """
import tvm
from tvm import te
@@ -62,28 +62,43 @@ def traverse(op):
return s


- def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
+ def sort_ir(
+ data, values_out, values_out_swap, axis, is_ascend, indices_out=None, indices_out_swap=None
+ ):
"""Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.

Parameters
----------
data: Buffer
Buffer of input data to be sorted.

- output : Buffer
- Output buffer of indices of sorted tensor with same shape as data.
+ values_out : Buffer
+ Output buffer of values of sorted tensor with same shape as data.

values_out_swap : Buffer
Output buffer of values with same shape as data to use as swap.

axis : Int
Axis along which to sort the input tensor.

is_ascend : Boolean
Whether to sort in ascending or descending order.

indices_out : Buffer
Output buffer of indices of sorted tensor with same shape as data.

indices_out_swap : Buffer
Output buffer of indices with same shape as data to use as swap.

Returns
-------
stmt : Stmt
The result IR statement.
"""

def ceil_div(a, b):
return tvm.tir.indexdiv(a + b - 1, b)
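A quick plain-Python check (not part of the patch) of the ceiling-division identity this helper relies on, ceil(a / b) == (a + b - 1) // b for positive integers:

assert (10 + 4 - 1) // 4 == 3  # ceil(10 / 4) rounds up
assert (8 + 4 - 1) // 4 == 2   # exact multiples are unchanged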

axis_mul_before = 1
axis_mul_after = 1
shape = data.shape
@@ -94,64 +109,182 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
axis_mul_before *= value
elif i > axis:
axis_mul_after *= value
- max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)

ib = tvm.tir.ir_builder.create()

data = ib.buffer_ptr(data)
values_out = ib.buffer_ptr(values_out)
+ values_out_swap = ib.buffer_ptr(values_out_swap)
if indices_out is not None:
indices_out = ib.buffer_ptr(indices_out)
- nthread_tx = max_threads
- nthread_bx = shape[axis] // max_threads + 1
+ assert indices_out_swap is not None
+ indices_out_swap = ib.buffer_ptr(indices_out_swap)

- tx = te.thread_axis("threadIdx.x")
- bx = te.thread_axis("blockIdx.x")
- ib.scope_attr(tx, "thread_extent", nthread_tx)
- ib.scope_attr(bx, "thread_extent", nthread_bx)
- tid = bx * nthread_tx + tx
- temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", scope="local")
- if indices_out is not None:
- temp_index = ib.allocate(indices_out.dtype, (1,), name="temp_index", scope="local")
# Set up threading
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = ceil_div(shape[axis], max_threads)
nthread_by = axis_mul_before
nthread_bz = axis_mul_after

# Copy the data to initial output
with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * nthread_tx + tx

by = te.thread_axis("blockIdx.y")
bz = te.thread_axis("blockIdx.z")
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(bz, "thread_extent", nthread_bz)
idx = (by * shape[axis] + tid) * axis_mul_after + bz
with ib.if_scope(tid < shape[axis]):
values_out[idx] = data[idx]
if indices_out is not None:
indices_out[idx] = tvm.tir.generic.cast(tid, indices_out.dtype)
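A NumPy sanity check (not part of the patch) of the flat-index formula used here and in the kernels below: idx = (by * n + tid) * after + bz addresses element [by, tid, bz] of the data viewed as (axis_mul_before, n, axis_mul_after), with n = shape[axis].

import numpy as np

before, n, after = 3, 8, 5  # axis_mul_before, shape[axis], axis_mul_after
x = np.arange(before * n * after).reshape(before, n, after)
flat = x.reshape(-1)
for by in range(before):
    for tid in range(n):
        for bz in range(after):
            assert flat[(by * n + tid) * after + bz] == x[by, tid, bz]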

## We loop over the array doing mergesort from the bottom up.
## The outer loop runs on the host and launches a cuda kernel for each iteration
## of the algorithm.
## The basic idea is that at iteration 0, each thread sorts 2 elements.
## On iteration 1, each thread merges 2 sorted arrays of 2 elements,
## to deal with 4 total elements.
## On iteration 2, each thread merges 2 sorted arrays of 4 elements,
## to deal with 8 total elements. On iteration 3, each thread deals with 16 elements, etc.
## On the final iteration of the algorithm, one thread merges two sorted lists
## to sort the entire array.
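The same scheme in plain Python (an illustrative sketch, not part of the patch): width doubles on every pass and the two buffers ping-pong, exactly what the l2_width loop and the even flag below implement.

def bottom_up_mergesort(a):
    n = len(a)
    src, dst = list(a), [None] * n
    width = 2                                # pass 0: each "thread" sorts 2 elements
    while True:
        for start in range(0, n, width):     # one merge per thread in the kernel
            mid = min(start + width // 2, n)
            end = min(start + width, n)
            i, j = start, mid
            for k in range(start, end):
                # take from the left run while it wins, then drain the other run
                if i < mid and (j >= end or src[i] <= src[j]):
                    dst[k] = src[i]
                    i += 1
                else:
                    dst[k] = src[j]
                    j += 1
        src, dst = dst, src                  # ping-pong the buffers
        if width >= n:                       # a single run now covers the array
            break
        width *= 2
    return src

print(bottom_up_mergesort([5, 4, 3, 2, 1]))  # [1, 2, 3, 4, 5]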
lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], "float64"))), "int64"
)
with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << l2_width
# Define and launch the cuda kernel
with ib.new_scope():
i = ib.allocate("int64", (1,), name="i", scope="local")
j = ib.allocate("int64", (1,), name="j", scope="local")
start = ib.allocate("int64", (1,), name="start", scope="local")
middle = ib.allocate("int64", (1,), name="middle", scope="local")
end = ib.allocate("int64", (1,), name="end", scope="local")
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
# Reduce the number of blocks as the work per thread grows
ib.scope_attr(
bx,
"thread_extent",
tvm.tir.generic.cast(ceil_div(shape[axis], width * max_threads), "int32"),
)
tid = bx * nthread_tx + tx

by = te.thread_axis("blockIdx.y")
bz = te.thread_axis("blockIdx.z")
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(bz, "thread_extent", nthread_bz)

def compare(a, b):
"""
Compare a and b in proper ascending or descending order
"""
if is_ascend:
out = a <= b
else:
out = b <= a
return out

def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even):
"""
Merge the two sections of the array assigned to this thread
"""
# pylint: disable=arguments-out-of-order
# initialize iterators
i[0] = start
j[0] = middle
# set up indexes
base_idx = by * shape[axis] * axis_mul_after + bz
# iterate over the output loop
with ib.for_range(0, end - start) as k:
i_idx = base_idx + i[0] * axis_mul_after
j_idx = base_idx + j[0] * axis_mul_after
k_idx = base_idx + (k + start) * axis_mul_after

def swap_values(source, dest, source_idx, dest_idx):
def assign_i():
"""assign i value to current output"""
dest[k_idx] = source[i_idx]
if indices_out is not None:
dest_idx[k_idx] = source_idx[i_idx]
i[0] += 1

def assign_j():
"""assign j value to current output"""
dest[k_idx] = source[j_idx]
if indices_out is not None:
dest_idx[k_idx] = source_idx[j_idx]
j[0] += 1

## if both of the iterators are in range
with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < end)):
# compare them and insert whichever is next into the output
with ib.if_scope(compare(source[i_idx], source[j_idx])):
assign_i()
with ib.else_scope():
assign_j()
# otherwise, simply copy the remainder of the valid iterator to the output
with ib.else_scope():
with ib.if_scope(i[0] < middle):
assign_i()
with ib.else_scope():
assign_j()

# Switch which input is the source and which is the destination each iteration
with ib.if_scope(even):
swap_values(source, dest, source_idx, dest_idx)
with ib.else_scope():
swap_values(dest, source, dest_idx, source_idx)

def mergesort(source, dest, source_idx, dest_idx, size, width, even):
# calculate the start, mid, and end points of this section
start[0] = width * tid
with ib.if_scope(start[0] < size):
middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size)
end[0] = tvm.te.min(start[0] + width, size)
## merge the start->middle and middle->end arrays
bottom_up_merge(
source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even
)

- with ib.for_range(0, axis_mul_before) as i:
- with ib.for_range(0, axis_mul_after) as j:
- base_idx = i * shape[axis] * axis_mul_after + j
# Call the kernel
mergesort(
values_out,
values_out_swap,
indices_out,
indices_out_swap,
shape[axis],
width,
tvm.tir.indexmod(l2_width, 2) == 0,
)

## if the final sorted data ended up in the swap, copy it to the real output
with ib.if_scope(tvm.tir.indexmod(lim, 2) == 1):
with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * nthread_tx + tx

by = te.thread_axis("blockIdx.y")
bz = te.thread_axis("blockIdx.z")
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(bz, "thread_extent", nthread_bz)
+ idx = (by * shape[axis] + tid) * axis_mul_after + bz
with ib.if_scope(tid < shape[axis]):
- values_out[base_idx + tid * axis_mul_after] = data[base_idx + tid * axis_mul_after]
+ idx = (by * shape[axis] + tid) * axis_mul_after + bz
+ values_out[idx] = values_out_swap[idx]
if indices_out is not None:
- indices_out[base_idx + tid * axis_mul_after] = tvm.tir.generic.cast(
- tid, indices_out.dtype
- )
- ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
- idxd = tvm.tir.indexdiv
- idxm = tvm.tir.indexmod

- with ib.for_range(0, axis_mul_before) as i:
- with ib.for_range(0, axis_mul_after) as j:
- current_sort_num = shape[axis]
- base_idx = i * shape[axis] * axis_mul_after + j
- # OddEvenTransposeSort
- with ib.for_range(0, current_sort_num) as k:
- with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
- offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
- if is_ascend:
- cond = tvm.tir.all(
- 2 * tid + idxm(k, 2) + 1 < current_sort_num,
- values_out[offset] > values_out[offset + axis_mul_after],
- )
- else:
- cond = tvm.tir.all(
- 2 * tid + idxm(k, 2) + 1 < current_sort_num,
- values_out[offset] < values_out[offset + axis_mul_after],
- )
- with ib.if_scope(cond):
- temp_data[0] = values_out[offset]
- values_out[offset] = values_out[offset + axis_mul_after]
- values_out[offset + axis_mul_after] = temp_data[0]
- if indices_out is not None:
- temp_index[0] = indices_out[offset]
- indices_out[offset] = indices_out[offset + axis_mul_after]
- indices_out[offset + axis_mul_after] = temp_index[0]
- ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
+ indices_out[idx] = indices_out_swap[idx]
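Why the parity check matters, in plain Python (not part of the patch): after p ping-pong passes the sorted data sits in the original buffer when p is even and in the swap buffer when p is odd.

import math

n = 1000
passes = math.ceil(math.log2(n))    # "lim" in the IR above
result_in_swap = passes % 2 == 1    # when True, the copy-back kernel runs
print(passes, result_in_swap)       # 10 False: for n = 1000 no copy is needed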

return ib.get()

@@ -336,14 +469,13 @@ def sort(data, axis=-1, is_ascend=1):
out : tvm.te.Tensor
The output of this function.
"""
dtype = "float32"
value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
value_buf_swap = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf_swap", data_alignment=8)
out = te.extern(
[data.shape, data.shape],
[data],
- lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
- out_buffers=[value_buf, indices_buf],
+ lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], axis, is_ascend),
+ out_buffers=[value_buf, value_buf_swap],
name="sort_gpu",
tag="sort_gpu",
)[0]
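A hedged usage sketch for this wrapper (TVM ~0.7-era topi API; shapes are illustrative). The target scope matters because sort_ir reads max_num_threads from the current target:

import tvm
from tvm import te, topi

data = te.placeholder((1, 1024), name="data", dtype="float32")
with tvm.target.Target("cuda"):
    out = topi.cuda.sort(data, axis=-1, is_ascend=True)  # sorted values only
    s = topi.cuda.schedule_sort(out)
    f = tvm.build(s, [data, out], target="cuda")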
@@ -449,12 +581,24 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
)
else:
value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
+ value_swap_buf = tvm.tir.decl_buffer(
+ data.shape, data.dtype, "value_swap_buf", data_alignment=8
+ )
indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
+ indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf", data_alignment=8)
out = te.extern(
- [data.shape, data.shape],
+ [data.shape, data.shape, data.shape, data.shape],
[data],
- lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
- out_buffers=[value_buf, indices_buf],
+ lambda ins, outs: sort_ir(
+ ins[0],
+ outs[0],
+ outs[2],
+ axis,
+ is_ascend,
+ indices_out=outs[1],
+ indices_out_swap=outs[3],
+ ),
+ out_buffers=[value_buf, indices_buf, value_swap_buf, indices_swap_buf],
name="argsort_gpu",
tag="argsort_gpu",
)[1]
@@ -564,25 +708,37 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
axis = axis + ndim if axis < 0 else axis
assert 0 <= axis < ndim
values_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8)
+ values_swap_buf = tvm.tir.decl_buffer(
+ data.shape, data.dtype, "values_swap_buf", data_alignment=8
+ )
indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)
+ indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "indices_swap_buf", data_alignment=8)
if ret_type == "values":
output = te.extern(
- [data.shape],
+ [data.shape, data.shape],
[data],
- lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend),
- out_buffers=[values_buf],
+ lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], axis, is_ascend),
+ out_buffers=[values_buf, values_swap_buf],
name="topk_gpu",
tag="topk_gpu",
- )
+ )[0]
else:
output = te.extern(
- [data.shape, data.shape],
+ [data.shape, data.shape, data.shape, data.shape],
[data],
- lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
- out_buffers=[values_buf, indices_buf],
+ lambda ins, outs: sort_ir(
+ ins[0],
+ outs[0],
+ outs[2],
+ axis,
+ is_ascend,
+ indices_out=outs[1],
+ indices_out_swap=outs[3],
+ ),
+ out_buffers=[values_buf, indices_buf, values_swap_buf, indices_swap_buf],
name="topk_gpu",
tag="topk_gpu",
- )
+ )[0:2]
if isinstance(k, int) and k < 1:
if ret_type == "indices":
return output[1]
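A matching hedged sketch for topk (illustrative shapes; with ret_type="both" the wrapper returns the values/indices pair produced by the [0:2] slice above):

import tvm
from tvm import te, topi

data = te.placeholder((4, 100), name="data", dtype="float32")
with tvm.target.Target("cuda"):
    values, indices = topi.cuda.topk(
        data, k=5, axis=-1, ret_type="both", is_ascend=False, dtype="int64"
    )
    s = topi.cuda.schedule_topk([values, indices])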