diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py
index 03e700b3df83..e622a8ae01ab 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -221,8 +221,6 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
         func = relay.Function([x], z.astuple())
         func = run_infer_type(func)
         for target, ctx in ctx_list():
-            if target == 'cuda':
-                return
             intrp = relay.create_executor("debug", ctx=ctx, target=target)
             out = intrp.evaluate(func)(np_data)
             tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04)
diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py
index 38f87a9523c8..5485859de01f 100644
--- a/topi/python/topi/cuda/nms.py
+++ b/topi/python/topi/cuda/nms.py
@@ -21,29 +21,46 @@
 import tvm

 from tvm import api
-from tvm.generic import cast
-from tvm.intrin import if_then_else, log, power
+from tvm.intrin import if_then_else
 from topi.vision import non_max_suppression, get_valid_counts
 from .sort import argsort
 from .. import tag

-def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index):
-    """Low level IR to Prepare get valid count of bounding boxes
-    given a score threshold. Also moves valid boxes to the
+def cuda_atomic_add_rule(op):
+    if op.dtype == "float32":
+        return tvm.call_pure_extern("float32", "atomicAdd", op.args[0], op.args[1])
+    if op.dtype == "float64":
+        return tvm.call_pure_extern("float64", "atomicAdd", op.args[0], op.args[1])
+    if op.dtype == "int32":
+        return tvm.call_pure_extern("int32", "atomicAdd", op.args[0], op.args[1])
+    raise RuntimeError("only support int32, float32 and float64")
+
+
+tvm.target.intrin.register_intrin_rule(
+    "cuda", "atomic_add", cuda_atomic_add_rule, override=True)
+
+
+def atomic_add(x, y):
+    return tvm.call_pure_intrin(y.dtype, "atomic_add", x, y)
+
+
+def get_valid_counts_ir(data, valid_count, flag, score_threshold, id_index, score_index):
+    """Low level IR to get valid count of bounding boxes
+    given a score threshold. Also prepares to move valid boxes to the
     top of input data.

     Parameters
     ----------
-    data: Buffer
-        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
+    data : Buffer
+        Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length].
+
+    valid_count : Buffer
+        1D buffer for valid number of boxes with shape [batch_size, ].

     flag : Buffer
         2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].

-    idx : Buffer
-        2D Buffer of valid data indices with shape [batch_size, num_anchors].
-
     score_threshold : float32
         Lower limit of score for valid bounding boxes.
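Before the kernel bodies, it may help to pin down the semantics the new get_valid_counts path is meant to reproduce. A serial NumPy reference of those semantics is sketched below; the helper name and the layout assumptions (scores at `score_index`, class ids at `id_index`, boxes laid out as [batch, anchor, elem]) are illustrative rather than part of the patch:

```python
import numpy as np

def get_valid_counts_ref(data, score_threshold, id_index, score_index):
    """Serial reference: count boxes above the threshold, compact them to the
    front of each batch, and pad the remaining rows with -1."""
    batch_size, num_anchors, _ = data.shape
    valid_count = np.zeros(batch_size, dtype="int32")
    out = np.full_like(data, -1.0)
    for i in range(batch_size):
        for j in range(num_anchors):
            keep = data[i, j, score_index] > score_threshold
            if id_index >= 0:
                keep = keep and data[i, j, id_index] >= 0
            if keep:
                out[i, valid_count[i]] = data[i, j]
                valid_count[i] += 1
    return valid_count, out
```

The CUDA path below splits this into three kernels: flag marking plus an atomicAdd counter, a prefix sum over the flags, and a final scatter.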
@@ -60,18 +77,24 @@ def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index
     """
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
-    box_data_length = data.shape[2]
+    elem_length = data.shape[2]

     ib = tvm.ir_builder.create()

     data = ib.buffer_ptr(data)
+
+    valid_count = ib.buffer_ptr(valid_count)
     flag = ib.buffer_ptr(flag)
-    idx = ib.buffer_ptr(idx)
-    score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold)
+    atomic_add_return = ib.allocate(
+        valid_count.dtype, (1,), name='atomic_add_return', scope='local')
+    one_count = tvm.const(1, dtype=valid_count.dtype)
+    score_threshold = tvm.make.node(
+        "FloatImm", dtype="float32", value=score_threshold)
     id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
     score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)

-    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    max_threads = int(tvm.target.Target.current(
+        allow_none=False).max_num_threads)
     nthread_tx = max_threads
     nthread_bx = batch_size * num_anchors // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
@@ -79,163 +102,52 @@ def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index
     ib.scope_attr(tx, "thread_extent", nthread_tx)
     ib.scope_attr(bx, "thread_extent", nthread_bx)
     tid = bx * max_threads + tx
+    idxd = tvm.indexdiv
+    # initialize valid_count
+    with ib.if_scope(tid < batch_size):
+        valid_count[tid] = 0
+    # initialize flag
     with ib.if_scope(tid < batch_size * num_anchors):
-        with ib.if_scope(tvm.all(data[tid * box_data_length + score_index] > score_threshold, \
-                                 tvm.any(id_index < 0, data[tid * box_data_length + id_index] >= 0))):
+        flag[tid] = 0
+    with ib.if_scope(tid < batch_size * num_anchors):
+        i = idxd(tid, num_anchors)
+        with ib.if_scope(tvm.all(data[tid * elem_length + score_index] > score_threshold,
+                                 tvm.any(id_index < 0, data[tid * elem_length + id_index] >= 0))):
             flag[tid] = 1
-            idx[tid] = 1
-        with ib.else_scope():
-            flag[tid] = 0
-            idx[tid] = 0
+            atomic_add_return[0] = atomic_add(tvm.call_pure_intrin("handle", "tvm_address_of",
+                                                                   valid_count[i]), one_count)
     return ib.get()


-def get_valid_counts_upsweep(data, idx_in, idx, partial):
-    """Low level IR of first step of scan: unsweep.
-
-    Parameters
-    ----------
-    data: Buffer
-        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
-
-    idx_in : Buffer
-        2D Buffer of valid data indices with shape [batch_size, num_anchors].
-
-    idx : Buffer
-        2D Buffer of valid data indices with shape [batch_size, num_anchors].
-
-    partial : Buffer
-        2D Buffer of valid data indices with shape [batch_size, new_range].
-
-    Returns
-    -------
-    stmt : Stmt
-        The result IR statement.
- """ - batch_size = data.shape[0] - num_anchors = data.shape[1] - ib = tvm.ir_builder.create() - data = ib.buffer_ptr(data) - idx_in = ib.buffer_ptr(idx_in) - idx = ib.buffer_ptr(idx) - partial = ib.buffer_ptr(partial) - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - elem_per_thread = num_anchors // max_threads + 1 - nthread_tx = max_threads - nthread_bx = batch_size - tx = tvm.thread_axis("threadIdx.x") - bx = tvm.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - new_range = num_anchors // elem_per_thread + 1 - # Scan: Upsweep: - with ib.if_scope(tvm.all(bx < batch_size, tx < new_range)): - with ib.for_range(0, elem_per_thread) as i: - with ib.if_scope(bx * num_anchors + \ - tx * elem_per_thread + i < batch_size * num_anchors): - with ib.if_scope(i == 0): - partial[bx * new_range + tx] = idx_in[bx * num_anchors + tx * elem_per_thread] - idx[bx * num_anchors + tx * elem_per_thread] = \ - idx_in[bx * num_anchors + tx * elem_per_thread] - with ib.else_scope(): - partial[bx * new_range + tx] += \ - idx_in[bx * num_anchors + tx * elem_per_thread + i] - idx[bx * num_anchors + tx * elem_per_thread + i] = \ - idx[bx * num_anchors + tx * elem_per_thread + i - 1] + \ - idx_in[bx * num_anchors + tx * elem_per_thread + i] - ib.emit(tvm.make.Call(None, 'tvm_storage_sync', - tvm.convert(['shared']), - tvm.expr.Call.Intrinsic, None, 0)) - return ib.get() -def get_valid_counts_scan(data, partial_in, partial): - """Low level IR to do scan. +def flag_scan(flag, prefix_sum): + """Low level IR to calculate correct positions for valid boxes. Parameters ---------- - data: Buffer - 3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms. - - idx_in : Buffer - 2D Buffer of valid data indices with shape [batch_size, num_anchors]. - - idx : Buffer - 2D Buffer of valid data indices with shape [batch_size, num_anchors]. + flag : Buffer + 2D Buffer of flag indicating valid data with shape [batch_size, num_anchors]. - partial : Buffer - 2D Buffer of valid data indices with shape [batch_size, new_range]. + prefix_sum : Buffer + 2D Buffer of prefix sum of flags indicating new locations of valid boxes + with same shape as flag. Returns ------- stmt : Stmt The result IR statement. 
""" - batch_size = data.shape[0] - num_anchors = data.shape[1] - ib = tvm.ir_builder.create() - partial_in = ib.buffer_ptr(partial_in) - partial = ib.buffer_ptr(partial) - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - elem_per_thread = num_anchors // max_threads + 1 - nthread_tx = max_threads - nthread_bx = batch_size - tx = tvm.thread_axis("threadIdx.x") - bx = tvm.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - var = tvm.make.node("FloatImm", dtype="float32", value=2) - new_range = num_anchors // elem_per_thread + 1 - iteration = cast(log(cast(new_range, "float32")) / math.log(2), "int32") - # Scan: Kogge-Stone adder - with ib.if_scope(tvm.all(bx < batch_size, tx < tvm.min(new_range, num_anchors))): - with ib.for_range(0, iteration) as k: - with ib.if_scope(k == 0): - with ib.if_scope(tvm.all(tx > 0, tx < tvm.min(new_range, num_anchors))): - partial[bx * new_range + tx] = \ - partial_in[bx * new_range + tx] + partial_in[bx * new_range + tx - 1] - with ib.else_scope(): - partial[bx * new_range] = partial_in[bx * new_range] - with ib.else_scope(): - with ib.if_scope(tvm.all(tx >= cast(power(var, k), "int32"), \ - tx < tvm.min(new_range, num_anchors))): - partial[bx * new_range + tx] += \ - partial[bx * new_range + tx - cast(power(var, k), "int32")] - ib.emit(tvm.make.Call(None, 'tvm_storage_sync', - tvm.convert(['shared']), - tvm.expr.Call.Intrinsic, None, 0)) - return ib.get() - -def get_valid_counts_downsweep(data, idx_in, partial, idx): - """Low level IR to do downsweep of scan. - - Parameters - ---------- - data: Buffer - 3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms. - - idx_in : Buffer - 2D Buffer of valid data indices with shape [batch_size, num_anchors]. + batch_size = flag.shape[0] + num_anchors = flag.shape[1] - partial : Buffer - 2D Buffer of valid data indices with shape [batch_size, new_range]. + ib = tvm.ir_builder.create() - idx : Buffer - 2D Buffer of valid data indices with shape [batch_size, num_anchors]. + flag = ib.buffer_ptr(flag) + prefix_sum = ib.buffer_ptr(prefix_sum) - Returns - ------- - stmt : Stmt - The result IR statement. - """ - batch_size = data.shape[0] - num_anchors = data.shape[1] - ib = tvm.ir_builder.create() - idx_in = ib.buffer_ptr(idx_in) - idx = ib.buffer_ptr(idx) - partial = ib.buffer_ptr(partial) - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - elem_per_thread = num_anchors // max_threads + 1 + max_threads = int(tvm.target.Target.current( + allow_none=False).max_num_threads) nthread_tx = max_threads nthread_bx = batch_size * num_anchors // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") @@ -243,23 +155,23 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx): ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * max_threads + tx - new_range = num_anchors // elem_per_thread + 1 idxd = tvm.indexdiv idxm = tvm.indexmod - # Scan: Downsweep: - with ib. 

-def get_valid_counts_ir(data, flag, idx, valid_count, out):
-    """Low level IR to get valid count of bounding boxes
-    given a score threshold. Also moves valid boxes to the
+
+def out_rewrite(data, flag, prefix_sum, valid_count, out):
+    """Low level IR to move valid boxes to the
     top of input data.

     Parameters
     ----------
@@ -270,11 +182,12 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
     flag : Buffer
         2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].

-    idx : Buffer
-        2D Buffer of valid data indices with shape [batch_size, num_anchors].
+    prefix_sum : Buffer
+        2D Buffer of prefix sum of flags indicating new locations of valid boxes
+        with same shape as flag.

     valid_count : Buffer
-        1-D buffer for valid number of boxes.
+        1D buffer for valid number of boxes with shape [batch_size, ].

     out : Buffer
         Rearranged data buffer.
@@ -284,28 +197,28 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
     stmt : Stmt
         The result IR statement.
     """
-    batch_size = data.shape[0]
-    num_anchors = data.shape[1]
-    elem_length = data.shape[2]
-    size = batch_size * num_anchors * elem_length
+    batch_size = out.shape[0]
+    num_anchors = out.shape[1]
+    elem_length = out.shape[2]

     ib = tvm.ir_builder.create()

+    one = tvm.const(1, dtype=out.dtype)
     data = ib.buffer_ptr(data)
     flag = ib.buffer_ptr(flag)
-    idx = ib.buffer_ptr(idx)
     valid_count = ib.buffer_ptr(valid_count)
+    prefix_sum = ib.buffer_ptr(prefix_sum)
     out = ib.buffer_ptr(out)

-    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    max_threads = int(tvm.target.Target.current(
+        allow_none=False).max_num_threads)
     nthread_tx = max_threads
-    nthread_bx = batch_size * num_anchors * elem_length // max_threads + 1
+    nthread_bx = batch_size * num_anchors // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
     bx = tvm.thread_axis("blockIdx.x")
     ib.scope_attr(tx, "thread_extent", nthread_tx)
     ib.scope_attr(bx, "thread_extent", nthread_bx)
     tid = bx * max_threads + tx
-
     idxd = tvm.indexdiv
     idxm = tvm.indexmod

@@ -313,17 +226,15 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
         i = idxd(tid, num_anchors)
         j = idxm(tid, num_anchors)
         base_idx = i * num_anchors * elem_length
-        with ib.if_scope(flag[tid] > 0):
+        with ib.if_scope(tvm.all(flag[tid] > 0, prefix_sum[tid] >= 0,
+                                 prefix_sum[tid] < num_anchors)):
+            with ib.for_range(0, elem_length) as k:
+                out[base_idx + prefix_sum[tid] * elem_length +
+                    k] = data[tid * elem_length + k]
+        with ib.if_scope(j >= valid_count[i]):
             with ib.for_range(0, elem_length) as k:
-                with ib.if_scope(base_idx + (idx[tid] - 1) * elem_length + k < size):
-                    out[base_idx + (idx[tid] - 1) * elem_length + k] =\
-                        data[base_idx + j * elem_length + k]
-        with ib.if_scope(j == 0):
-            valid_count[i] = idx[tid + num_anchors - 1]
-        with ib.if_scope(j >= idx[i * num_anchors + num_anchors - 1]):
-            with ib.for_range(0, elem_length) as l:
-                with ib.if_scope(tid * elem_length + l < size):
-                    out[tid * elem_length + l] = -1.0
+                out[tid * elem_length + k] = -one
+
     return ib.get()

@@ -356,56 +267,47 @@
     """
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
-    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-    elem_per_thread = num_anchors // max_threads + 1
-    new_range = num_anchors // elem_per_thread + 1
+    data_buf = api.decl_buffer(
+        data.shape, data.dtype, "data_buf", data_alignment=8)
+    valid_count_buf = api.decl_buffer(
+        (batch_size,), "int32", "valid_count_buf", data_alignment=8)
     temp_flag_buf = api.decl_buffer(
         (batch_size, num_anchors,), "int32", "temp_flag", data_alignment=8)
-    temp_idx_buf = api.decl_buffer(
-        (batch_size, num_anchors,), "int32", "temp_idx", data_alignment=8)
     temp_partial_buf = api.decl_buffer(
-        (batch_size, new_range), "int32", "temp_partial", data_alignment=8)
-    data_buf = api.decl_buffer(
-        data.shape, data.dtype, "data_buf", data_alignment=8)
+        (batch_size, num_anchors), "int32", "temp_partial", data_alignment=8)
+    out_buf = api.decl_buffer(
+        data.shape, data.dtype, "out_buf", data_alignment=8)

-    temp_flag, temp_idx = \
-        tvm.extern([(batch_size, num_anchors,), (batch_size, num_anchors,)], [data],
-                   lambda ins, outs: get_valid_counts_pre(
-                       ins[0], outs[0], outs[1], score_threshold, id_index, score_index),
-                   dtype=["int32", "int32"],
-                   out_buffers=[temp_flag_buf, temp_idx_buf],
-                   name="get_valid_counts_phase_one")
-    temp_idx_new, temp_partial = \
-        tvm.extern([(batch_size, num_anchors,), (batch_size, new_range)], [data, temp_idx],
-                   lambda ins, outs: get_valid_counts_upsweep(
-                       ins[0], ins[1], outs[0], outs[1]),
-                   dtype=["int32", "int32"],
-                   out_buffers=[temp_idx_buf, temp_partial_buf],
-                   name="get_valid_counts_phase_two")
-    temp_partial_new = \
-        tvm.extern([(batch_size, new_range)], [data, temp_partial],
-                   lambda ins, outs: get_valid_counts_scan(
-                       ins[0], ins[1], outs[0]),
-                   dtype=["int32"],
-                   out_buffers=[temp_partial_buf],
-                   name="get_valid_counts_phase_three")
-    temp_idx_final = \
-        tvm.extern([(batch_size, num_anchors)], [data, temp_idx_new, temp_partial_new],
-                   lambda ins, outs: get_valid_counts_downsweep(
-                       ins[0], ins[1], ins[2], outs[0]),
-                   dtype=["int32"],
-                   out_buffers=[temp_idx_buf],
-                   name="get_valid_counts_phase_four")
-    valid_count, out_tensor = \
-        tvm.extern([(batch_size,), data.shape], [data, temp_flag, temp_idx_final],
-                   lambda ins, outs: get_valid_counts_ir(
-                       ins[0], ins[1], ins[2], outs[0], outs[1]),
-                   dtype=["int32", data.dtype],
-                   in_buffers=[data_buf, temp_flag_buf, temp_idx_buf],
-                   name="get_valid_counts_phase_five",
+    valid_count, temp_flag = \
+        tvm.extern([(batch_size,), (batch_size, num_anchors)], [data],
+                   lambda ins, outs: get_valid_counts_ir(
+                       ins[0], outs[0], outs[1], score_threshold, id_index, score_index),
+                   dtype=["int32", "int32"],
+                   in_buffers=[data_buf],
+                   out_buffers=[valid_count_buf, temp_flag_buf],
+                   name="get_valid_counts",
                    tag="get_valid_counts_gpu")
-    return [valid_count, out_tensor]
+    temp_partial = \
+        tvm.extern([(batch_size, num_anchors)], [temp_flag],
+                   lambda ins, outs: flag_scan(
+                       ins[0], outs[0]),
+                   dtype=["int32"],
+                   in_buffers=[temp_flag_buf],
+                   out_buffers=[temp_partial_buf],
+                   name="flag_scan")
+
+    out = \
+        tvm.extern([data.shape], [data, temp_flag, temp_partial, valid_count],
+                   lambda ins, outs: out_rewrite(
+                       ins[0], ins[1], ins[2], ins[3], outs[0]),
+                   dtype=[data.dtype],
+                   in_buffers=[data_buf, temp_flag_buf,
+                               temp_partial_buf, valid_count_buf],
+                   out_buffers=[out_buf],
+                   name="out_rewrite")
+
+    return [valid_count, out]


 def nms_ir(data, sorted_index, valid_count, out, box_indices,
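get_valid_counts_gpu now chains three extern stages: mark flags and count survivors atomically, turn the flags into per-row prefix sums, and finally scatter rows into their new slots while padding the tail with -1. A serial NumPy sketch of that last rewrite stage, under the same layout assumptions as the earlier reference (names illustrative, not part of the patch):

```python
import numpy as np

def out_rewrite_ref(data, flag, prefix_sum, valid_count):
    """Serial reference for the scatter stage of the pipeline above."""
    batch_size, num_anchors, _ = data.shape
    out = np.empty_like(data)
    for i in range(batch_size):
        for j in range(num_anchors):
            # every flagged row moves to its prefix-sum slot
            if flag[i, j] > 0 and prefix_sum[i, j] < num_anchors:
                out[i, prefix_sum[i, j]] = data[i, j]
            # rows at or beyond the batch's valid count become padding
            if j >= valid_count[i]:
                out[i, j] = -1.0
    return out
```

The scattered slots always lie below valid_count[i] and the padded rows at or above it, so the two writes never collide and together cover the whole output; the same disjointness is what lets the CUDA kernel perform both writes from one thread per anchor.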
+ name="out_rewrite") + + return [valid_count, out] def nms_ir(data, sorted_index, valid_count, out, box_indices, @@ -479,7 +381,8 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): valid_count = ib.buffer_ptr(valid_count) out = ib.buffer_ptr(out) box_indices = ib.buffer_ptr(box_indices) - num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local") + num_valid_boxes = ib.allocate( + "int32", (1,), name="num_valid_boxes", scope="local") max_threads = int( tvm.target.Target.current(allow_none=False).max_num_threads) @@ -491,26 +394,29 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): ib.scope_attr(bx, "thread_extent", nthread_bx) j = bx * max_threads + tx - iou_threshold = tvm.make.node("FloatImm", dtype="float32", value=iou_threshold) + iou_threshold = tvm.make.node( + "FloatImm", dtype="float32", value=iou_threshold) top_k = tvm.make.node("IntImm", dtype="int32", value=top_k) coord_start = tvm.make.node("IntImm", dtype="int32", value=coord_start) id_index = tvm.make.node("IntImm", dtype="int32", value=id_index) score_index = tvm.make.node("IntImm", dtype="int32", value=score_index) - force_suppress = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0) + force_suppress = tvm.make.node( + "IntImm", dtype="int32", value=1 if force_suppress else 0) with ib.for_range(0, batch_size, for_type="unroll") as i: base_idx = i * num_anchors * box_data_length with ib.if_scope(tvm.all(iou_threshold > 0, valid_count[i] > 0)): # Reorder output - nkeep = if_then_else( \ - tvm.all(top_k > 0, top_k < valid_count[i]), - top_k, valid_count[i]) + nkeep = if_then_else( + tvm.all(top_k > 0, top_k < valid_count[i]), + top_k, valid_count[i]) with ib.if_scope(j < nkeep): with ib.for_range(0, box_data_length) as k: out[(base_idx + j * box_data_length + k)] = \ - data[(base_idx + sorted_index[i * num_anchors + j] \ - * box_data_length + k)] - box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j] + data[(base_idx + sorted_index[i * num_anchors + j] + * box_data_length + k)] + box_indices[i * num_anchors + + j] = sorted_index[i * num_anchors + j] with ib.if_scope(tvm.all(top_k > 0, top_k < valid_count[i])): with ib.if_scope(j < valid_count[i] - nkeep): with ib.for_range(0, box_data_length) as k: @@ -519,16 +425,18 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): # Apply nms with ib.for_range(0, valid_count[i]) as k: offset_k = k * box_data_length - with ib.if_scope(tvm.all(out[base_idx + offset_k + score_index] > 0, \ - tvm.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0))): + with ib.if_scope(tvm.all(out[base_idx + offset_k + score_index] > 0, + tvm.any(id_index < 0, out[base_idx + + offset_k + id_index] >= 0))): with ib.if_scope(j < valid_count[i]): offset_j = j * box_data_length - with ib.if_scope(tvm.all(j > k, \ - out[base_idx + offset_j + score_index] > 0, \ - tvm.any(id_index < 0, \ - out[base_idx + offset_j + id_index] >= 0), \ - tvm.any(force_suppress > 0, id_index < 0, \ - out[base_idx + offset_k + id_index] == \ + with ib.if_scope(tvm.all(j > k, + out[base_idx + offset_j + + score_index] > 0, + tvm.any(id_index < 0, + out[base_idx + offset_j + id_index] >= 0), + tvm.any(force_suppress > 0, id_index < 0, + out[base_idx + offset_k + id_index] == out[base_idx + offset_j + id_index]))): iou = calculate_overlap(out, base_idx + offset_j + coord_start, base_idx + offset_k + coord_start) @@ -541,12 +449,14 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.if_scope(j < valid_count[i]): offset_j = j * 
                 with ib.for_range(0, box_data_length) as k:
-                    out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
+                    out[(base_idx + offset_j + k)
+                        ] = data[base_idx + offset_j + k]
                 box_indices[i * num_anchors + j] = j
         # Set invalid entry to be -1
         with ib.if_scope(j < num_anchors - valid_count[i]):
             with ib.for_range(0, box_data_length) as k:
-                out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
+                out[base_idx + (j + valid_count[i]) *
+                    box_data_length + k] = -1.0
             box_indices[i * num_anchors + j + valid_count[i]] = -1
         # Only return max_output_size number of valid boxes
         num_valid_boxes[0] = 0
@@ -671,7 +581,7 @@ def invalid_to_bottom_ir(data, flag, idx, out):
             with ib.if_scope(flag[i * num_anchors + j] > 0):
                 with ib.for_range(0, elem_length) as k:
                     out[base_idx + (idx[i * num_anchors + j] - 1) * elem_length + k] \
-                    = data[base_idx + j * elem_length + k]
+                        = data[base_idx + j * elem_length + k]
     return ib.get()


@@ -756,8 +666,10 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
                                       "valid_count_buf", data_alignment=4)
     score_axis = score_index
     score_shape = (batch_size, num_anchors)
-    score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
-    sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
+    score_tensor = tvm.compute(
+        score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
+    sort_tensor = argsort(
+        score_tensor, valid_count=valid_count, axis=1, is_ascend=False)

     sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
                                       "sort_tensor_buf", data_alignment=8)
@@ -795,7 +707,8 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
                        ins[0], outs[0], outs[1]),
                    dtype=["int32", "int32"],
                    in_buffers=[out_buf],
-                   out_buffers=[temp_flag_buf, temp_idx_buf],
+                   out_buffers=[
+                       temp_flag_buf, temp_idx_buf],
                    name="invalid_to_bottom_phase_one")

     output = tvm.extern([data.shape], [out, temp_flag, temp_idx],
diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py
index a081f0797dad..85e4180a0892 100644
--- a/topi/tests/python/test_topi_vision.py
+++ b/topi/tests/python/test_topi_vision.py
@@ -67,8 +67,8 @@ def check_device(device):
         tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)

     for device in ['llvm', 'cuda', 'opencl']:
-        # Disable gpu test for now
-        if device != "llvm":
+        # Disable opencl test for now
+        if device != "llvm" and device != "cuda":
             continue
         check_device(device)
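The nms_ir edits above are mostly mechanical re-wrapping, but the suppression decision they feed rests on calculate_overlap, whose body is not shown in this patch. For reference, a hedged plain-Python version of the usual intersection-over-union test such a helper performs (box layout assumed to be [x1, y1, x2, y2]; not taken verbatim from the source):

```python
def iou(box_a, box_b):
    """Intersection over union of two boxes given as [x1, y1, x2, y2]."""
    w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter = w * h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0
```

In nms_ir, a box whose overlap with an already kept, higher-ranked box reaches iou_threshold is meant to be invalidated so that subsequent comparisons and the final compaction skip it; with get_valid_counts now passing on CUDA, the test files above re-enable that target alongside llvm.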