Skip to content

Commit

Permalink
add max_total_size to attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 29, 2021
1 parent 7218b2f commit 4a4b8df
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 14 deletions.
9 changes: 6 additions & 3 deletions include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,14 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionA
/*! \brief Attributes used in non_maximum_suppression operator */
struct AllClassNonMaximumSuppressionAttrs
: public tvm::AttrsNode<AllClassNonMaximumSuppressionAttrs> {
Optional<Integer> max_total_size;
std::string output_format;
TVM_ATTR_FIELD(output_format).set_default("onnx").describe(
"Output format. onnx or tensorflow");

TVM_DECLARE_ATTRS(AllClassNonMaximumSuppressionAttrs,
"relay.attrs.AllClassNonMaximumSuppressionAttrs") {}
"relay.attrs.AllClassNonMaximumSuppressionAttrs") {
TVM_ATTR_FIELD(max_total_size).set_default(NullValue<Integer>()).describe("TODO");
TVM_ATTR_FIELD(output_format).set_default("onnx").describe("Output format. onnx or tensorflow");
}
};

/*! \brief Attributes used in roi_align operators */
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,7 @@ def _impl(inputs, attr, params, mod):
# Get parameter values
boxes = inputs[0]
scores = inputs[1]

try:
max_output_boxes_per_class = int(
np.atleast_1d(inputs[2].data.asnumpy().astype("int64"))[0]
Expand All @@ -809,6 +810,7 @@ def _impl(inputs, attr, params, mod):
)
except Exception:
max_output_boxes_per_class = inputs[2]

max_total_size = inputs[3]
iou_threshold = np.atleast_1d(inputs[4].data.numpy())[0]
score_threshold = np.atleast_1d(inputs[5].data.numpy())[0]
Expand All @@ -832,6 +834,7 @@ def _impl(inputs, attr, params, mod):
max_output_boxes_per_class,
iou_threshold,
score_threshold,
max_total_size,
output_format="tensorflow",
)

Expand Down
12 changes: 11 additions & 1 deletion python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,7 +1095,17 @@ def _compute_nms(attrs, inputs, out_type):
max_output_size = inputs[2]
iou_threshold = inputs[3]
score_threshold = inputs[4]
return topi_compute(inputs[0], inputs[1], max_output_size, iou_threshold, score_threshold)
max_total_size = attrs.max_total_size
output_format = attrs.output_format
return topi_compute(
inputs[0],
inputs[1],
max_output_size,
iou_threshold,
score_threshold,
max_total_size,
output_format,
)

return _compute_nms

Expand Down
16 changes: 14 additions & 2 deletions python/tvm/relay/op/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,13 @@ def non_max_suppression(


def all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class=-1, iou_threshold=-1.0, score_threshold=-1.0, output_format="onnx"
boxes,
scores,
max_output_boxes_per_class=-1,
iou_threshold=-1.0,
score_threshold=-1.0,
max_total_size=None,
output_format="onnx",
):
"""Non-maximum suppression operator for object detection, corresponding to ONNX
NonMaxSuppression and TensorFlow combined_non_max_suppression.
Expand Down Expand Up @@ -195,6 +201,12 @@ def all_class_non_max_suppression(
score_threshold = expr.const(score_threshold, "float32")

out = _make.all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format
boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
max_total_size,
output_format,
)
return expr.TupleWrapper(out, 2)
4 changes: 2 additions & 2 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,7 @@ def collect_selected_indices_tf(selected_indices, selected_scores, num_detection


def all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format="onnx"
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, max_total_size, output_format="onnx"
):
"""Non-maximum suppression operator for object detection, corresponding to ONNX
NonMaxSuppression and TensorFlow combined_non_max_suppression.
Expand Down Expand Up @@ -1127,7 +1127,7 @@ def all_class_non_max_suppression(
)
return [selected_indices, num_total_detections]

max_detection_per_batch = 100
max_detection_per_batch = max_total_size

selected_indices, selected_scores, num_detections = run_all_class_nms(
boxes,
Expand Down
27 changes: 21 additions & 6 deletions src/relay/op/vision/nms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,19 +157,34 @@ bool AllClassNMSRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
num_total_boxes = batch * num_classes * num_boxes;
}

// assign output type
const auto* param = attrs.as<AllClassNonMaximumSuppressionAttrs>();
CHECK(param);

std::vector<Type> fields;
std::vector<IndexExpr> oshape{num_total_boxes, 3};
fields.push_back(TensorType(oshape, DataType::Int(64)));
std::vector<IndexExpr> countshape{1};
fields.push_back(TensorType(countshape, DataType::Int(64)));
if (param->output_format == "onnx") {
std::vector<IndexExpr> oshape{num_total_boxes, 3};
std::vector<IndexExpr> countshape{1};
fields.push_back(TensorType(oshape, DataType::Int(64)));
fields.push_back(TensorType(countshape, DataType::Int(64)));
} else {
ICHECK(param->max_total_size) << "max_total_size required for tf mode";
Integer max_total_size = param->max_total_size.value();
std::vector<IndexExpr> oshape{batch, max_total_size, 2};
std::vector<IndexExpr> countshape{1};
fields.push_back(TensorType(oshape, DataType::Int(64)));
fields.push_back(TensorType(countshape, DataType::Int(64)));
}

reporter->Assign(types[5], TupleType(Array<Type>(fields)));
return true;
}

Expr MakeAllClassNMS(Expr boxes, Expr scores, Expr max_output_boxes_per_class, Expr iou_threshold,
Expr score_threshold) {
Expr score_threshold, Optional<Integer> max_total_size = NullValue<Integer>(),
std::string output_format = "onnx") {
auto attrs = make_object<AllClassNonMaximumSuppressionAttrs>();
attrs->max_total_size = std::move(max_total_size);
attrs->output_format = std::move(output_format);
static const Op& op = Op::Get("vision.all_class_non_max_suppression");
return Call(op, {boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold},
Attrs(attrs), {});
Expand Down

0 comments on commit 4a4b8df

Please sign in to comment.