Skip to content

Commit

Permalink
[topi] enable fp16 sort for arm (#4084)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu authored Oct 8, 2019
1 parent ec375a8 commit 1c56c72
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 22 deletions.
28 changes: 27 additions & 1 deletion src/contrib/sort/sort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms")

// Currently only supports input dtype to be float32.
CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype "
"to be float32.";
"to be float.";
#if (__ARM_FP16_FORMAT_IEEE != 1)
CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype "
"to be float32.";
#endif
CHECK_LT(axis, input->ndim) << "Axis out of boundary for "
"input ndim " << input->ndim;

Expand All @@ -98,9 +100,25 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms")
sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx)));
}
if (is_ascend) {
#if (__ARM_FP16_FORMAT_IEEE == 1)
if (dtype.bits == 16) {
std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<__fp16>);
} else {
#endif
std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>);
#if (__ARM_FP16_FORMAT_IEEE == 1)
}
#endif
} else {
#if (__ARM_FP16_FORMAT_IEEE == 1)
if (dtype.bits == 16) {
std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<__fp16>);
} else {
#endif
std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
#if (__ARM_FP16_FORMAT_IEEE == 1)
}
#endif
}
for (int32_t k = 0; k < input->shape[axis]; ++k) {
*(static_cast<int32_t *>(output->data) + base_idx + k * axis_mul_after)
Expand Down Expand Up @@ -192,6 +210,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
#if (__ARM_FP16_FORMAT_IEEE == 1)
} else if (data_dtype == "float16") {
if (out_dtype == "float16") {
argsort<__fp16, __fp16>(input, output, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
#endif
} else if (data_dtype == "int32") {
if (out_dtype == "int32") {
argsort<int32_t, int32_t>(input, output, axis, is_ascend);
Expand Down
56 changes: 35 additions & 21 deletions topi/python/topi/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ..sort import argsort

@hybrid.script
def hybrid_rearrange_out(data):
def hybrid_rearrange_out(data, one):
"""Hybrid routine to rearrange nms output to
move all valid entries to top.
Expand All @@ -32,6 +32,9 @@ def hybrid_rearrange_out(data):
NMS output. 3-D tensor with shape
[batch_size, num_anchors, 6].
one: tvm.const
Constant one with the same dtype as data.
Returns
-------
output : tvm.Tensor or numpy NDArray
Expand All @@ -55,12 +58,12 @@ def hybrid_rearrange_out(data):
valid_idx += 1
if j >= valid_idx:
for k in range(elem_length):
output[i, j, k] = -1.0
output[i, j, k] = -one
return output


@hybrid.script
def hybrid_get_valid_counts(data, score_threshold, id_index, score_index):
def hybrid_get_valid_counts(data, score_threshold, id_index, score_index, one):
"""Hybrid routine to get valid count of bounding boxes
given a score threshold. Also moves valid boxes to the
top of input data.
Expand All @@ -80,6 +83,9 @@ def hybrid_get_valid_counts(data, score_threshold, id_index, score_index):
score_index: tvm.const
Index of the scores/confidence of boxes.
one: tvm.const
Constant one with the same dtype as data.
Returns
-------
out_tensor : tvm.Tensor or numpy NDArray
Expand Down Expand Up @@ -107,7 +113,7 @@ def hybrid_get_valid_counts(data, score_threshold, id_index, score_index):
valid_count[i] += 1
if j >= valid_count[i]:
for k in range(box_data_length):
out_tensor[i, j, k] = -1.0
out_tensor[i, j, k] = -one
return valid_count, out_tensor

@tvm.target.generic_func
Expand Down Expand Up @@ -138,17 +144,18 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
valid_count : tvm.Tensor
1-D tensor for valid number of boxes.
"""
score_threshold_const = tvm.const(score_threshold, "float32")
score_threshold_const = tvm.const(score_threshold, data.dtype)
id_index_const = tvm.const(id_index, "int32")
score_index_const = tvm.const(score_index, "int32")
return hybrid_get_valid_counts(data, score_threshold_const,
id_index_const, score_index_const)
id_index_const, score_index_const,
tvm.const(1, data.dtype))


@hybrid.script
def hybrid_nms(data, sorted_index, valid_count,
max_output_size, iou_threshold, force_suppress,
top_k, coord_start, id_index, score_index):
top_k, coord_start, id_index, score_index, zero, one):
"""Hybrid routing for non-maximum suppression.
Parameters
Expand Down Expand Up @@ -186,6 +193,12 @@ def hybrid_nms(data, sorted_index, valid_count,
score_index: tvm.const
Index of the scores/confidence of boxes.
zero: tvm.const
Constant zero with the same dtype as data.
one: tvm.const
Constant one with the same dtype as data.
Returns
-------
output : tvm.Tensor
Expand All @@ -200,8 +213,7 @@ def hybrid_nms(data, sorted_index, valid_count,
box_indices = output_tensor((batch_size, num_anchors), "int32")
output = output_tensor((batch_size,
num_anchors,
box_data_length,),
data.dtype)
box_data_length,), data.dtype)

for i in range(batch_size):
if iou_threshold > 0:
Expand All @@ -217,7 +229,7 @@ def hybrid_nms(data, sorted_index, valid_count,
if 0 < top_k < valid_count[i]:
for j in parallel(valid_count[i] - nkeep):
for k in range(box_data_length):
output[i, j + nkeep, k] = -1.0
output[i, j + nkeep, k] = -one
box_indices[i, j + nkeep] = -1
# Apply nms
box_start_idx = coord_start
Expand All @@ -243,15 +255,15 @@ def hybrid_nms(data, sorted_index, valid_count,
b_b = output[batch_idx, box_b_idx, box_start_idx + 3]
b_l = output[batch_idx, box_b_idx, box_start_idx]
b_r = output[batch_idx, box_b_idx, box_start_idx + 2]
w = max(0.0, min(a_r, b_r) - max(a_l, b_l))
h = max(0.0, min(a_b, b_b) - max(a_t, b_t))
w = max(zero, min(a_r, b_r) - max(a_l, b_l))
h = max(zero, min(a_b, b_b) - max(a_t, b_t))
area = h * w
u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area
iou = 0.0 if u <= 0.0 else area / u
iou = zero if u <= zero else area / u
if iou >= iou_threshold:
output[i, k, score_index] = -1.0
output[i, k, score_index] = -one
if id_index >= 0:
output[i, k, id_index] = -1.0
output[i, k, id_index] = -one
box_indices[i, k] = -1
else:
for j in parallel(valid_count[i]):
Expand All @@ -261,16 +273,16 @@ def hybrid_nms(data, sorted_index, valid_count,
# Set invalid entry to be -1
for j in parallel(num_anchors - valid_count[i]):
for k in range(box_data_length):
output[i, j + valid_count[i], k] = -1.0
output[i, j + valid_count[i], k] = -one
box_indices[i, j + valid_count[i]] = -1
# Only return max_output_size valid boxes
num_valid_boxes = 0
if max_output_size > 0:
for j in parallel(valid_count[i]):
if output[i, j, 0] >= 0:
if output[i, j, 0] >= zero:
if num_valid_boxes == max_output_size:
for k in range(box_data_length):
output[i, j, k] = -1.0
output[i, j, k] = -one
box_indices[i, j] = -1
else:
num_valid_boxes += 1
Expand Down Expand Up @@ -356,13 +368,15 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
out, box_indices = hybrid_nms(data, sort_tensor, valid_count,
tvm.const(max_output_size, dtype="int32"),
tvm.const(iou_threshold, dtype="float32"),
tvm.const(iou_threshold, dtype=data.dtype),
tvm.const(force_suppress, dtype="bool"),
tvm.const(top_k, dtype="int32"),
tvm.const(coord_start, dtype="int32"),
tvm.const(id_index, dtype="int32"),
tvm.const(score_index, dtype="int32"))
tvm.const(score_index, dtype="int32"),
zero=tvm.const(0, dtype=data.dtype),
one=tvm.const(1, dtype=data.dtype))
if not return_indices and invalid_to_bottom:
out = hybrid_rearrange_out(out)
out = hybrid_rearrange_out(out, one=tvm.const(1, dtype=data.dtype))

return box_indices if return_indices else out

0 comments on commit 1c56c72

Please sign in to comment.