Skip to content

Commit

Permalink
batch issue fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 29, 2021
1 parent 19e3e84 commit 3fe91e8
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
15 changes: 12 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,10 +855,19 @@ def _impl(inputs, attr, params, mod):
nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32"))

# Fill in invalid entries with 0
box_range = _op.arange(_expr.const(0, dtype="int64"), _op.cast(max_total_size, "int64"), dtype="int64")
box_range = _op.arange(
_expr.const(0, dtype="int64"), _op.cast(max_total_size, "int64"), dtype="int64"
)
batch_size = indices_shape[0]
box_range = _op.tile(box_range, _op.const([batch_size]))
valid_mask = _op.cast(_op.less(box_range, _op.expand_dims(num_detections, axis=1)), "float32")

if isinstance(batch_size, tvm.tir.Any):
box_range_2d = _op.tile(box_range, _op.concatenate([batch_size, 1]))
else:
box_range_2d = _op.tile(box_range, _op.const([batch_size, 1]))

valid_mask = _op.cast(
_op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)), "float32"
)
nmsed_scores = nmsed_scores * valid_mask
nmsed_classes = nmsed_classes * valid_mask
nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2)
Expand Down
11 changes: 9 additions & 2 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,11 +1156,18 @@ def all_class_non_max_suppression(
selected_indices, selected_scores = collect_selected_indices_tf(
selected_indices, selected_scores, num_detections_per_batch, row_offsets
)

# TODO
# max_total_detections = reduction.max(num_total_detections)
# selected_scores = strided_slice(
# selected_scores, begin=[0, 0], end=[batch, reduction.max(num_total_detections)]
# selected_scores, begin=[0, 0], end=expand_dims(max_total_detections, 0)
# )
topk_indices = topk(selected_scores, k=max_detection_per_batch, axis=1, ret_type="indices")[0]
topk_indices = expand_dims(topk_indices, axis=0)
final_indices = gather_nd(selected_indices, topk_indices, batch_dims=1)

# TODO
# num_detections = minimum(num_total_detections, max_detection_per_batch)
return [final_indices, num_total_detections]
num_detections = num_total_detections

return [final_indices, num_detections]
2 changes: 1 addition & 1 deletion src/relay/op/vision/nms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ bool AllClassNMSRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
ICHECK(param->max_total_size) << "max_total_size required for tf mode";
Integer max_total_size = param->max_total_size.value();
std::vector<IndexExpr> oshape{batch, max_total_size, 2};
std::vector<IndexExpr> countshape{1};
std::vector<IndexExpr> countshape{batch};
fields.push_back(TensorType(oshape, DataType::Int(64)));
fields.push_back(TensorType(countshape, DataType::Int(64)));
}
Expand Down

0 comments on commit 3fe91e8

Please sign in to comment.