Refactor to use ir_builder directly
ymwangg committed Feb 23, 2021
1 parent 1553d48 commit 14811bf
Showing 5 changed files with 198 additions and 101 deletions.
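
This commit replaces the te.hybrid kernels in python/tvm/topi/cuda/unique.py with TIR emitted directly through tvm.tir.ir_builder and wired into the tensor expression graph with te.extern. Every kernel below follows the same launch pattern; here is a minimal sketch of that pattern on a hypothetical element-wise copy (the copy kernel and the fixed block size are illustrative assumptions, not code from this commit):

import tvm
from tvm import te

def _copy_ir(src, dst):
    # Emit TIR that copies src into dst, one element per thread.
    ib = tvm.tir.ir_builder.create()
    src_ptr = ib.buffer_ptr(src)  # typed accessor for the input buffer
    dst_ptr = ib.buffer_ptr(dst)  # typed accessor for the output buffer
    n = src.shape[0]
    max_threads = 32  # assumed block size for this sketch
    with ib.new_scope():
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", max_threads)
        ib.scope_attr(bx, "thread_extent", (n + max_threads - 1) // max_threads)
        tid = bx * max_threads + tx
        with ib.if_scope(tid < n):  # guard the ragged last block
            dst_ptr[tid] = src_ptr[tid]
    return ib.get()

src = te.placeholder((16,), dtype="int32", name="src")
dst = te.extern(
    [src.shape],
    [src],
    lambda ins, outs: _copy_ir(ins[0], outs[0]),
    dtype=["int32"],
    name="copy_example",
)
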
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/pytorch.py
@@ -2167,7 +2167,7 @@ def is_floating_point(self, inputs, input_types):
def unique(self, inputs, input_types):
assert len(inputs) == 4
[data, is_sorted, return_inverse, return_counts] = inputs
if is_sorted == False:
if not is_sorted:
logging.warning("TVM always assumes sorted=True for torch.unique")
is_sorted = True
if return_counts:
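
A hedged sketch of how this converter is reached from PyTorch; the toy module and input are my own illustration, while torch.jit.trace and relay.frontend.from_pytorch are the real entry points:

import torch
from tvm import relay

class UniqueModel(torch.nn.Module):
    def forward(self, x):
        # sorted=False is the case that triggers the warning above;
        # TVM still computes the sorted result.
        return torch.unique(x, sorted=False, return_inverse=True)

inp = torch.tensor([3, 1, 2, 1])
traced = torch.jit.trace(UniqueModel().eval(), inp)
mod, params = relay.frontend.from_pytorch(traced, [("x", [4])])
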
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_transform.py
@@ -145,7 +145,7 @@ def compute_cumsum(attrs, inputs, output_type):

@_reg.register_compute("unique")
def compute_unique(attrs, inputs, output_type):
"""Compute definition of cumsum"""
"""Compute definition of unique"""
return topi.unique(inputs[0], attrs.sorted, attrs.return_counts)


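
The registration above routes the Relay "unique" op to this TOPI compute; the schedule is picked separately via the op strategy. A construction-only sketch at the Relay level (the exact signature and return fields of relay.unique may differ across TVM versions, so treat this as an assumption):

from tvm import relay

x = relay.var("x", shape=(6,), dtype="int32")
result = relay.unique(x, is_sorted=True, return_counts=True)
func = relay.Function([x], result.astuple())
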
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/__init__.py
@@ -58,4 +58,4 @@
from . import tensorcore_alter_op
from .argwhere import *
from .scan import *
from .unique import *
from .unique import *
287 changes: 192 additions & 95 deletions python/tvm/topi/cuda/unique.py
@@ -16,62 +16,83 @@
# under the License.
# pylint: disable=invalid-name, no-else-return
"""Unique operator"""
from tvm import te, tir
import tvm

from ...te import hybrid
from .scan import cumsum
from .sort import sort, argsort
from tvm import te
import tvm
from ..utils import ceil_div
from .nms import atomic_add


@hybrid.script
def _calc_adjacent_diff(data):
output = output_tensor(data.shape, "int32")
idx = allocate((1,), "int32", "local")
i_extent = min(data.shape[0], max_num_threads(False))
j_extent = ceil_div(data.shape[0], i_extent)
for i in bind("threadIdx.x", i_extent):
for j in range(j_extent):
idx[0] = j * i_extent + i
if idx[0] == 0:
output[0] = int32(0)
elif idx[0] < data.shape[0]:
output[idx[0]] = int32(1) if data[idx[0]] != data[idx[0] - 1] else int32(0)
return output
def _calc_adjacent_diff_ir(data, adjacent_diff):
ib = tvm.tir.ir_builder.create()
data_ptr = ib.buffer_ptr(data)
adjacent_diff_ptr = ib.buffer_ptr(adjacent_diff)
batch_size = data.shape[0]
max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(batch_size, max_threads)
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size):
with ib.if_scope(tid == 0):
adjacent_diff_ptr[tid] = 0
with ib.else_scope():
with ib.if_scope(data_ptr[tid] != data_ptr[tid - 1]):
adjacent_diff_ptr[tid] = 1
with ib.else_scope():
adjacent_diff_ptr[tid] = 0
return ib.get()
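
For reference, a NumPy sketch of what _calc_adjacent_diff_ir computes (my illustration, not commit code): adjacent_diff marks the position where each new run of equal values starts in the sorted input.

import numpy as np

def adjacent_diff_ref(sorted_data):
    out = np.zeros(sorted_data.shape[0], dtype="int32")
    # out[0] stays 0; out[i] is 1 exactly where a new group begins
    out[1:] = (sorted_data[1:] != sorted_data[:-1]).astype("int32")
    return out

adjacent_diff_ref(np.array([1, 1, 2, 3, 3]))  # -> [0, 0, 1, 1, 0]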


@hybrid.script
def _calc_num_unique(data):
output = output_tensor((1,), "int32")
for i in bind("threadIdx.x", 1):
output[0] = data[data.shape[0] - 1] + int32(1)
output[i] = data[data.shape[0] - 1] + int32(1)
return output
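
_calc_num_unique stays a one-thread hybrid script. Its input is the inclusive scan computed in unique() below, whose last element is the zero-based id of the last group, so the unique count is that value plus one. Worked numbers (my own illustration):

import numpy as np

adjacent_diff = np.array([0, 0, 1, 1, 0], dtype="int32")  # from sorted [1, 1, 2, 3, 3]
inc_scan = np.cumsum(adjacent_diff)                       # [0, 0, 1, 2, 2]
num_unique = int(inc_scan[-1]) + 1                        # 3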


@hybrid.script
def _calc_unique_sorted(data, argsorted_indices, inc_scan):
unique_elements = output_tensor(data.shape, data.dtype)
indices = output_tensor(data.shape, "int32")
idx = allocate((1,), "int32", "local")
i_extent = min(data.shape[0], max_num_threads(False))
j_extent = ceil_div(data.shape[0], i_extent)
for i in bind("threadIdx.x", i_extent):
for j in range(j_extent):
idx[0] = j * i_extent + i
if idx[0] < data.shape[0]:
indices[argsorted_indices[idx[0]]] = inc_scan[idx[0]]
if idx[0] == 0 or inc_scan[idx[0]] != inc_scan[idx[0] - 1]:
unique_elements[inc_scan[idx[0]]] = data[argsorted_indices[idx[0]]]
return unique_elements, indices
def _calc_unique_sorted_ir(data, argsorted_indices, inc_scan, unique_elements, indices):
ib = tvm.tir.ir_builder.create()
data_ptr = ib.buffer_ptr(data)
argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices)
inc_scan_ptr = ib.buffer_ptr(inc_scan)
unique_elements_ptr = ib.buffer_ptr(unique_elements)
indices_ptr = ib.buffer_ptr(indices)

batch_size = data.shape[0]
max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(batch_size, max_threads)
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size):
indices_ptr[argsorted_indices_ptr[tid]] = inc_scan_ptr[tid]
with ib.if_scope(tid == 0):
unique_elements_ptr[inc_scan_ptr[tid]] = data_ptr[argsorted_indices_ptr[tid]]
with ib.else_scope():
with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]):
unique_elements_ptr[inc_scan_ptr[tid]] = data_ptr[argsorted_indices_ptr[tid]]
return ib.get()
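
A NumPy reference for the scatter this kernel performs, one loop iteration per GPU thread (an illustrative sketch; only the first num_unique slots of unique_elements end up meaningful):

import numpy as np

def unique_sorted_ref(data, argsorted_indices, inc_scan):
    n = data.shape[0]
    unique_elements = np.zeros(n, dtype=data.dtype)
    indices = np.zeros(n, dtype="int32")
    for tid in range(n):
        # data[argsorted_indices[tid]] is the tid-th element in sorted order
        indices[argsorted_indices[tid]] = inc_scan[tid]
        if tid == 0 or inc_scan[tid] != inc_scan[tid - 1]:
            # the first thread of each group writes the unique value
            unique_elements[inc_scan[tid]] = data[argsorted_indices[tid]]
    return unique_elements, indices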


def _calc_counts_sorted_ir(inc_scan, counts):
ib = tvm.tir.ir_builder.create()
inc_scan_ptr = ib.buffer_ptr(inc_scan)
counts_ptr = ib.buffer_ptr(counts)

batch_size = inc_scan.shape[0]
max_threads = min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(batch_size, max_threads)
@@ -102,65 +123,83 @@ def _calc_counts_sorted_ir(inc_scan, counts):
return ib.get()


@hybrid.script
def _calc_first_occurence(argsorted_indices, inc_scan):
first_occurence = output_tensor(argsorted_indices.shape, "int32")
idx = allocate((1,), "int32", "local")
i_extent = min(argsorted_indices.shape[0], max_num_threads(False))
j_extent = ceil_div(argsorted_indices.shape[0], i_extent)
for i in bind("threadIdx.x", i_extent):
for j in range(j_extent):
idx[0] = j * i_extent + i
if idx[0] < argsorted_indices.shape[0]:
first_occurence[idx[0]] = argsorted_indices.shape[0]
for i in bind("threadIdx.x", i_extent):
for j in range(j_extent):
idx[0] = j * i_extent + i
if idx[0] < argsorted_indices.shape[0]:
if idx[0] == 0 or inc_scan[idx[0]] != inc_scan[idx[0] - 1]:
first_occurence[inc_scan[idx[0]]] = argsorted_indices[idx[0]]
return first_occurence

def _calc_first_occurence_ir(argsorted_indices, inc_scan, first_occurence):
ib = tvm.tir.ir_builder.create()
argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices)
inc_scan_ptr = ib.buffer_ptr(inc_scan)
first_occurence_ptr = ib.buffer_ptr(first_occurence)
batch_size = argsorted_indices.shape[0]
max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(batch_size, max_threads)
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size):
first_occurence_ptr[tid] = batch_size
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(batch_size, max_threads)
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size):
with ib.if_scope(tid == 0):
first_occurence_ptr[inc_scan_ptr[tid]] = argsorted_indices_ptr[tid]
with ib.else_scope():
with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]):
first_occurence_ptr[inc_scan_ptr[tid]] = argsorted_indices_ptr[tid]
return ib.get()
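
Note the two new_scope() blocks: the first launch fills every slot of first_occurence with the sentinel value batch_size, the second scatters the first index of each group. Splitting the work into two launches is what orders the fill before the scatter, since CUDA provides no global barrier across thread blocks. A NumPy reference (my sketch):

import numpy as np

def first_occurence_ref(argsorted_indices, inc_scan):
    n = argsorted_indices.shape[0]
    first_occurence = np.full(n, n, dtype="int32")  # first launch: sentinel fill
    for tid in range(n):                            # second launch: scatter
        if tid == 0 or inc_scan[tid] != inc_scan[tid - 1]:
            first_occurence[inc_scan[tid]] = argsorted_indices[tid]
    return first_occurence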

@hybrid.script
def _calc_unique_unsorted(data, argsorted_indices, inc_scan, index_converter):
unique_elements = output_tensor(data.shape, data.dtype)
indices = output_tensor(data.shape, "int32")
for i in parallel(data.shape[0]):
new_unique_idx = index_converter[inc_scan[i]]
new_data_idx = argsorted_indices[i]
indices[new_data_idx] = new_unique_idx
if i == 0 or inc_scan[i] != inc_scan[i - 1]:
unique_elements[new_unique_idx] = data[new_data_idx]
return unique_elements, indices

def _calc_unique_unsorted_ir(
data, argsorted_indices, inc_scan, index_converter, unique_elements, indices
):
ib = tvm.tir.ir_builder.create()
data_ptr = ib.buffer_ptr(data)
argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices)
inc_scan_ptr = ib.buffer_ptr(inc_scan)
index_converter_ptr = ib.buffer_ptr(index_converter)
unique_elements_ptr = ib.buffer_ptr(unique_elements)
indices_ptr = ib.buffer_ptr(indices)

@hybrid.script
def _calc_unique_unsorted(data, argsorted_indices, inc_scan, index_converter):
unique_elements = output_tensor(data.shape, data.dtype)
indices = output_tensor(data.shape, "int32")
idx = allocate((1,), "int32", "local")
i_extent = min(data.shape[0], max_num_threads(False))
j_extent = ceil_div(data.shape[0], i_extent)
for i in bind("threadIdx.x", i_extent):
for j in range(j_extent):
idx[0] = j * i_extent + i
if idx[0] < data.shape[0]:
indices[argsorted_indices[idx[0]]] = index_converter[inc_scan[idx[0]]]
if idx[0] == 0 or inc_scan[idx[0]] != inc_scan[idx[0] - 1]:
unique_elements[index_converter[inc_scan[idx[0]]]] = data[
argsorted_indices[idx[0]]
batch_size = data.shape[0]
max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(batch_size, max_threads)
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size):
indices_ptr[argsorted_indices_ptr[tid]] = index_converter_ptr[inc_scan_ptr[tid]]
with ib.if_scope(tid == 0):
unique_elements_ptr[index_converter_ptr[inc_scan_ptr[tid]]] = data_ptr[
argsorted_indices_ptr[tid]
]
with ib.else_scope():
with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]):
unique_elements_ptr[index_converter_ptr[inc_scan_ptr[tid]]] = data_ptr[
argsorted_indices_ptr[tid]
]
return unique_elements, indices
return ib.get()
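
index_converter, built in unique() below by argsorting first_occurence twice, maps each sorted-order group id to that group's rank by first appearance in the original input; this is what makes the output order match the unsorted semantics of torch.unique. A worked illustration (my own numbers):

import numpy as np

data = np.array([3, 1, 3, 2], dtype="int32")
# sorted groups are 1, 2, 3; their first occurrences in data are 1, 3, 0
first_occurence = np.array([1, 3, 0], dtype="int32")
# an argsort of an argsort converts positions into ranks
index_converter = np.argsort(np.argsort(first_occurence))  # -> [1, 2, 0]
# group 0 (value 1) lands in output slot 1, group 2 (value 3) in slot 0,
# so unique_elements comes out [3, 1, 2]: first-appearance order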


def _calc_counts_unsorted_ir(inc_scan, index_converter, counts):
ib = tvm.tir.ir_builder.create()
inc_scan_ptr = ib.buffer_ptr(inc_scan)
index_converter_ptr = ib.buffer_ptr(index_converter)
counts_ptr = ib.buffer_ptr(counts)

batch_size = inc_scan.shape[0]
max_threads = min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads)
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(batch_size, max_threads)
@@ -234,17 +273,55 @@ def unique(data, is_sorted=True, return_counts=False):

sorted_data = sort(data)
argsorted_indices = argsort(data, dtype="int32")
adjacent_diff = _calc_adjacent_diff(sorted_data)
# calculate adjacent difference
sorted_data_buf = tvm.tir.decl_buffer(
data.shape, data.dtype, "sorted_data_buf", data_alignment=8
)
adjacent_diff_buf = tvm.tir.decl_buffer(
data.shape, "int32", "adjacent_diff_buf", data_alignment=8
)
adjacent_diff = te.extern(
[data.shape],
[sorted_data],
lambda ins, outs: _calc_adjacent_diff_ir(ins[0], outs[0]),
dtype=["int32"],
in_buffers=[sorted_data_buf],
out_buffers=[adjacent_diff_buf],
name="_calc_adjacent_diff",
tag="_calc_adjacent_diff_gpu",
)
# calculate inclusive scan
inc_scan = cumsum(adjacent_diff, dtype="int32", exclusive=0)
# calculate number of unique elements
num_unique_elements = _calc_num_unique(inc_scan)
# declare buffers
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
argsorted_indices_buf = tvm.tir.decl_buffer(
data.shape, "int32", "argsorted_indices_buf", data_alignment=8
)
inc_scan_buf = tvm.tir.decl_buffer(data.shape, "int32", "inc_scan_buf", data_alignment=8)
unique_elements_buf = tvm.tir.decl_buffer(
data.shape, data.dtype, "unique_elements_buf", data_alignment=8
)
inverse_indices_buf = tvm.tir.decl_buffer(
data.shape, "int32", "inverse_indices_buf", data_alignment=8
)
if is_sorted:
unique_elements, inverse_indices = _calc_unique_sorted(data, argsorted_indices, inc_scan)
# calculate unique elements and inverse indices
unique_elements, inverse_indices = te.extern(
[data.shape, data.shape],
[data, argsorted_indices, inc_scan],
lambda ins, outs: _calc_unique_sorted_ir(*ins, *outs),
dtype=[data.dtype, "int32"],
in_buffers=[data_buf, argsorted_indices_buf, inc_scan_buf],
out_buffers=[unique_elements_buf, inverse_indices_buf],
name="_calc_unique_sorted",
tag="_calc_unique_sorted_gpu",
)
if not return_counts:
return [unique_elements, inverse_indices, num_unique_elements]
else:
inc_scan_buf = tvm.tir.decl_buffer(
data.shape, "int32", "inc_scan_buf", data_alignment=8
)
# calculate counts of unique elements
counts_buf = tvm.tir.decl_buffer(data.shape, "int32", "counts_buf", data_alignment=8)
counts = te.extern(
[data.shape],
@@ -258,21 +335,41 @@
)
return [unique_elements, inverse_indices, num_unique_elements, counts]
else:
first_occurence = _calc_first_occurence(argsorted_indices, inc_scan)
# calculate first occurence
first_occurence_buf = tvm.tir.decl_buffer(
data.shape, "int32", "first_occurence_buf", data_alignment=8
)
first_occurence = te.extern(
[data.shape],
[argsorted_indices, inc_scan],
lambda ins, outs: _calc_first_occurence_ir(ins[0], ins[1], outs[0]),
dtype=["int32"],
in_buffers=[argsorted_indices_buf, inc_scan_buf],
out_buffers=[first_occurence_buf],
name="_calc_first_occurence",
tag="_calc_first_occurence_gpu",
)
# calculate index converter by sorting unique elements by their first occurence
argsorted_first_occurence = argsort(first_occurence, dtype="int32")
index_converter = argsort(argsorted_first_occurence, dtype="int32")
unique_elements, inverse_indices = _calc_unique_unsorted(
data, argsorted_indices, inc_scan, index_converter
# calculate unique elements and inverse indices
index_converter_buf = tvm.tir.decl_buffer(
data.shape, "int32", "index_converter_buf", data_alignment=8
)
unique_elements, inverse_indices = te.extern(
[data.shape, data.shape],
[data, argsorted_indices, inc_scan, index_converter],
lambda ins, outs: _calc_unique_unsorted_ir(*ins, *outs),
dtype=[data.dtype, "int32"],
in_buffers=[data_buf, argsorted_indices_buf, inc_scan_buf, index_converter_buf],
out_buffers=[unique_elements_buf, inverse_indices_buf],
name="_calc_unique_unsorted",
tag="_calc_unique_unsorted_gpu",
)
if not return_counts:
return [unique_elements, inverse_indices, num_unique_elements]
else:
inc_scan_buf = tvm.tir.decl_buffer(
data.shape, "int32", "inc_scan_buf", data_alignment=8
)
index_converter_buf = tvm.tir.decl_buffer(
data.shape, "int32", "index_converter_buf", data_alignment=8
)
# calculate counts of unique elements
counts_buf = tvm.tir.decl_buffer(data.shape, "int32", "counts_buf", data_alignment=8)
counts = te.extern(
[data.shape],
…
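
Expected end-to-end behavior of the kernels above on a small input (my illustration; running them needs a CUDA-enabled build, so only the expected values are shown):

import numpy as np

data = np.array([2, 1, 2, 3, 1], dtype="int32")
# is_sorted=True : unique = [1, 2, 3], inverse = [1, 0, 1, 2, 0],
#                  num_unique = 3, counts = [2, 2, 1]
# is_sorted=False: unique = [2, 1, 3], inverse = [0, 1, 0, 2, 1],
#                  num_unique = 3, counts = [2, 2, 1]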
