Skip to content

Commit

Permalink
[TOPI] Fix GPU Dynamic Op Schedule (#7117)
Browse files Browse the repository at this point in the history
* Fix GPU dynamic op schedules

* Fix dynamic shape nms

* Fix

* Fix test format
  • Loading branch information
kevinthesun authored Dec 17, 2020
1 parent fb8de5a commit bad149e
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 10 deletions.
7 changes: 6 additions & 1 deletion python/tvm/topi/cuda/conv2d_transpose_nchw.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,10 @@ def _callback(op):
##### space definition begin #####
n, f, y, x = s[conv].op.axis
rc = s[conv].op.reduce_axis[0]
cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
# TODO(@kevinthesun): Support tuning/optimization for dynamic shape.
bs = pad_data.shape[0]
n_tuning_axis = n if isinstance(bs, tvm.tir.IntImm) else 1
cfg.define_split("tile_n", cfg.axis(n_tuning_axis), num_outputs=4)
cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
Expand All @@ -194,6 +197,8 @@ def _callback(op):

if cfg.is_fallback:
N, F, Y, X = get_const_tuple(conv.shape)
if not isinstance(N, int):
N = 1
_fallback_schedule(N, F, Y, X)

##### space definition end #####
Expand Down
13 changes: 12 additions & 1 deletion python/tvm/topi/cuda/injective.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,16 @@ def schedule_injective_from_existing(sch, out):
# bandwidth.
vector_width = 4 if out.dtype == "float16" else 1

is_dynamic_output = False
for dim in out.shape:
if not isinstance(dim, tvm.tir.IntImm):
is_dynamic_output = True
break

out_len = utils.prod(out.shape)

try:
const_size = utils.get_const_int(utils.prod(out.shape))
const_size = utils.get_const_int(out_len)
need_block_split = const_size > max_block * num_thread * vector_width
except ValueError:
need_block_split = False
Expand All @@ -61,6 +69,9 @@ def schedule_injective_from_existing(sch, out):
sch[out].bind(bx, te.thread_axis("blockIdx.x"))
sch[out].bind(tx, te.thread_axis("threadIdx.x"))
else:
# Use fewer threads for dynamic shape ops to avoid runtime error.
if is_dynamic_output:
num_thread //= 2
bx, tx = sch[out].split(fused, factor=num_thread)
sch[out].bind(tx, te.thread_axis("threadIdx.x"))
sch[out].bind(bx, te.thread_axis("blockIdx.x"))
Expand Down
49 changes: 46 additions & 3 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from tvm.tir import if_then_else
from .sort import argsort, argsort_thrust
from .. import tag


def cuda_atomic_add_rule(op):
Expand Down Expand Up @@ -95,7 +94,7 @@ def rearrange_indices_out_ir(data, output, valid_box_count):
with ib.new_scope():
i = te.thread_axis("blockIdx.x")
ib.scope_attr(i, "thread_extent", batch_size)
valid_idx = ib.allocate("int32", (1), name="valid_idx", scope="local")
valid_idx = ib.allocate("int32", (1,), name="valid_idx", scope="local")
valid_idx[0] = 0
with ib.for_range(0, num_anchors, name="j") as j:
with ib.if_scope(data[i, j] >= 0):
Expand Down Expand Up @@ -654,6 +653,35 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
return ib.get()


def _fetch_score_ir(data, score, axis):
    """
    Build the IR that copies one score value per box from data into score.
    This routine is required for dynamic shape nms.

    Parameters
    ----------
    data :
        Input buffer handed to ``ib.buffer_ptr``. Its first three shape
        entries are read as (batch_size, num_anchors, elem_length).

    score :
        Output buffer handed to ``ib.buffer_ptr``. One element is written
        for each flat index below batch_size * num_anchors.

    axis :
        Offset added to the start of each elem_length-sized record to
        locate the value that is copied.

    Returns
    -------
    The statement produced by ``ib.get()`` for the IR built here.
    """
    n_batch, n_anchor, record_len = data.shape[0], data.shape[1], data.shape[2]

    builder = tvm.tir.ir_builder.create()

    src = builder.buffer_ptr(data)
    dst = builder.buffer_ptr(score)

    # Nothing is emitted for the launch unless there is at least one anchor.
    with builder.if_scope(n_anchor > 0):
        threads_per_block = int(tvm.target.Target.current(allow_none=False).max_num_threads)
        total = n_batch * n_anchor

        thread_x = te.thread_axis("threadIdx.x")
        block_x = te.thread_axis("blockIdx.x")
        builder.scope_attr(thread_x, "thread_extent", threads_per_block)
        # One extra block covers the remainder when total is not a multiple
        # of threads_per_block; the bound check below masks the overshoot.
        builder.scope_attr(block_x, "thread_extent", total // threads_per_block + 1)

        flat_idx = block_x * threads_per_block + thread_x
        with builder.if_scope(flat_idx < total):
            # Both buffers are addressed with a flat index here.
            dst[flat_idx] = src[flat_idx * record_len + axis]

    return builder.get()


def non_max_suppression(
data,
valid_count,
Expand Down Expand Up @@ -754,7 +782,22 @@ def non_max_suppression(
)
score_axis = score_index
score_shape = (batch_size, num_anchors)
score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8)
score_tensor = te.extern(
[score_shape],
[data],
lambda ins, outs: _fetch_score_ir(
ins[0],
outs[0],
score_axis,
),
dtype=[data.dtype],
in_buffers=[data_buf],
out_buffers=[score_buf],
name="fetch_score",
tag="fetch_score",
)
target = tvm.target.Target.current()
if (
target
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,9 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int
tag="topk_gpu",
)

if isinstance(k, tvm.tir.IntImm):
k = k.value

if not isinstance(k, int) or k > 0:
beg = [0] * ndim
end = data.shape[:-1] + [k if isinstance(k, int) else tvm.te.size_var("dim")]
Expand Down
9 changes: 9 additions & 0 deletions src/runtime/contrib/thrust/thrust.cu
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
if (value_dtype == "int32") {
thrust_stable_sort_by_key<int, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "int64") {
thrust_stable_sort_by_key<int, int64_t>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<int, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
Expand All @@ -215,6 +218,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
if (value_dtype == "int32") {
thrust_stable_sort_by_key<int64_t, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "int64") {
thrust_stable_sort_by_key<int64_t, int64_t>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<int64_t, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
Expand All @@ -225,6 +231,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
if (value_dtype == "int32") {
thrust_stable_sort_by_key<float, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "int64") {
thrust_stable_sort_by_key<float, int64_t>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<float, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
Expand Down
17 changes: 15 additions & 2 deletions src/runtime/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, In
std::vector<int> codes(arity);
runtime::TVMArgsSetter setter(values.data(), codes.data());
int idx = 0;
bool is_empty_output = false;
for (Index i = 0; i < arg_count; i++) {
if (const auto* dt_cell = args[i].as<ADTObj>()) {
for (size_t fi = 0; fi < dt_cell->size; ++fi) {
Expand All @@ -254,12 +255,24 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, In
}
} else {
auto nd_array = Downcast<NDArray>(args[i]);
// We can safely skip CallPacked if there is only one
// output and it is empty.
if (i == arg_count - 1 && output_size == 1) {
for (const auto& dim : nd_array.Shape()) {
if (!dim) {
is_empty_output = true;
break;
}
}
}
setter(idx++, nd_array);
}
}

TVMRetValue rv;
func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
if (!is_empty_output) {
TVMRetValue rv;
func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
}
}

void VirtualMachine::LoadExecutable(const Executable* exec) {
Expand Down
29 changes: 26 additions & 3 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def check_result(
for kind in ["debug", "vm"]:
targets = targets or tvm.testing.enabled_targets()
for tgt, ctx in targets:
print(tgt)
if disable_targets and tgt in disable_targets:
continue
if kind == "debug" and (only_vm or ctx.device_type != tvm.cpu().device_type):
Expand Down Expand Up @@ -199,6 +200,15 @@ def test_any_concat():
ref = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0)
check_result([x_np, y_np], mod, ref)

num_inputs = 25
x = [relay.var("x", shape=(relay.Any(),), dtype="float32") for _ in range(num_inputs)]
z = relay.op.concatenate(x, axis=0)
mod = tvm.IRModule()
mod["main"] = relay.Function(x, z)
x_np = [np.random.uniform(size=(1,)).astype("float32") for _ in range(num_inputs)]
ref = np.concatenate(x_np, axis=0)
check_result(x_np, mod, ref)


def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape=False):
x = relay.var("x", shape=x_shape, dtype="float32")
Expand Down Expand Up @@ -572,9 +582,7 @@ def verify_any_conv2d_transpose_nchw(
mod["main"] = relay.Function([data, kernel], y)
data_np = np.random.uniform(size=static_data_shape).astype(dtype)
kernel_np = np.random.uniform(size=kernel_shape).astype(dtype)
check_result(
[data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=[("llvm", tvm.cpu())]
)
check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True)


# TODO(@kevinthesun): Support dynamic input height and width.
Expand Down Expand Up @@ -1430,6 +1438,21 @@ def test_non_max_suppression():
disable_targets=["nvptx"],
)

np_data = np.zeros((1, 0, 6)).astype("float32")
np_valid_count = np.array([0]).astype("int32")
np_indices = np.zeros((1, 0)).astype("int32")
np_max_output_size = -1
np_indices_result = np.zeros((1, 0))
np_valid_box_count = np.array([[0]]).astype("int32")

check_result(
[np_data, np_valid_count, np_indices, np_max_output_size],
mod,
[np_indices_result, np_valid_box_count],
only_vm=False,
disable_targets=["nvptx"],
)


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit bad149e

Please sign in to comment.