Skip to content

Commit

Permalink
[FRONTEND][TFLITE] Add support for TFLite's regular NMS operator (apa…
Browse files Browse the repository at this point in the history
…che#15117)

This PR adds support for TFLite's regular NMS operator.

Open questions:

    1. How to properly test added functionality?
    Other NMS implementations, e.g., fast NMS, use a TF frozen graph from TF official website to convert a model to TFLite and keep NMS operations only. In order to create a similar test, we need to find an archive on TF official website that contains a frozen graph of a model compiled with --use-regular-nms=True flag. We haven't found it yet, so any help is appreciated.
    2. Regular NMS requires two sort operations:
        Sorting the scores after selecting scores above nms_score_threshold. This PR implements this with a simple bubble sort in order to prove the algorithm's semantics. We tried to replace it with tvm.contrib.sort.argsort. It works well when testing the regular NMS with run_tvm_graph as fast NMS test does or building and running the regular NMS with llvm target. At the same time, it fails to build (error is provided below) when target=ethos-u,cmsis-nn,c. It seems that __tvm_module_ctx variable is only being initialized when cpp runtime is chosen.
        The error:
        error: '__tvm_module_ctx' undeclared (first use in this function) 203 | if (TVMBackendGetFuncFromEnv(__tvm_module_ctx, "tvm.contrib.sort.argsort", &tvm_contrib_sort_argsort_packed) != 0) {
        Sorting the scores of previous and current NMS steps. There are two alternatives here:
            implement some sorting algorithm as part of hybrid script (to replace current bubble sort)
            save the result of each NMS step and use argsort after the hybrid script part. This approach has a drawback: it requires a significant amount of memory to store the results of each NMS step.
  • Loading branch information
ilyag-grovety authored Jul 31, 2023
1 parent 0556653 commit 619bb1d
Show file tree
Hide file tree
Showing 14 changed files with 695 additions and 49 deletions.
25 changes: 25 additions & 0 deletions include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,17 @@ struct MultiBoxTransformLocAttrs : public tvm::AttrsNode<MultiBoxTransformLocAtt
bool clip;
double threshold;
Array<IndexExpr> variances;
bool keep_background;

TVM_DECLARE_ATTRS(MultiBoxTransformLocAttrs, "relay.attrs.MultiBoxTransformLocAttrs") {
TVM_ATTR_FIELD(clip).set_default(true).describe("Clip out-of-boundary boxes.");
TVM_ATTR_FIELD(threshold).set_default(0.01).describe("Threshold to be a positive prediction.");
TVM_ATTR_FIELD(variances)
.set_default(Array<IndexExpr>({0.1f, 0.1f, 0.2f, 0.2f}))
.describe("Variances to be decoded from box regression output.");
TVM_ATTR_FIELD(keep_background)
.set_default(false)
.describe("Whether to keep boxes detected as background or not");
}
};

Expand Down Expand Up @@ -129,6 +133,27 @@ struct AllClassNonMaximumSuppressionAttrs
}
};

/*!
 * \brief Attributes used in regular_non_maximum_suppression operator.
 *
 * No field declares a default value in the attribute declaration below.
 */
struct RegularNonMaximumSuppressionAttrs
    : public tvm::AttrsNode<RegularNonMaximumSuppressionAttrs> {
  /*! \brief Maximum number of boxes selected per class. */
  int32_t max_detections_per_class;
  /*! \brief Maximum number of boxes selected overall. */
  int32_t max_detections;
  /*! \brief Number of classes, not counting background. */
  int32_t num_classes;
  /*! \brief IoU threshold used by the box overlap test. */
  double iou_threshold;
  /*! \brief Score threshold used to filter out low score boxes early. */
  double score_threshold;

  TVM_DECLARE_ATTRS(RegularNonMaximumSuppressionAttrs,
                    "relay.attrs.RegularNonMaximumSuppressionAttrs") {
    TVM_ATTR_FIELD(max_detections_per_class)
        .describe("The maximum number of output selected boxes per class.");
    TVM_ATTR_FIELD(max_detections).describe("The maximum number of output selected boxes.");
    TVM_ATTR_FIELD(num_classes).describe("The number of classes without background.");
    TVM_ATTR_FIELD(iou_threshold).describe("The IoU threshold for the box overlap test.");
    TVM_ATTR_FIELD(score_threshold)
        .describe("Score threshold to filter out low score boxes early.");
  }
};

/*! \brief Attributes used in roi_align operators */
struct ROIAlignAttrs : public tvm::AttrsNode<ROIAlignAttrs> {
Array<IndexExpr> pooled_size;
Expand Down
45 changes: 32 additions & 13 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3443,12 +3443,7 @@ def convert_detection_postprocess(self, op):
flexbuffer = op.CustomOptionsAsNumpy().tobytes()
custom_options = FlexBufferDecoder(flexbuffer).decode()

if "use_regular_nms" in custom_options:
if custom_options["use_regular_nms"]:
raise tvm.error.OpAttributeUnImplemented(
"use_regular_nms=True is not yet supported for operator "
"TFLite_Detection_PostProcess."
)
use_regular_nms = "use_regular_nms" in custom_options and custom_options["use_regular_nms"]

inputs = self.get_input_tensors(op)
assert len(inputs) == 3, "inputs length should be 3"
Expand Down Expand Up @@ -3481,15 +3476,14 @@ def convert_detection_postprocess(self, op):
input_zero_point=inputs[2].qnn_params["zero_point"],
)

# reshape the cls_pred and loc_prob tensors so
# they can be consumed by multibox_transform_loc
cls_pred = _op.transpose(cls_pred, [0, 2, 1])
# loc_prob coords are in yxhw format
# need to convert to xywh
loc_coords = _op.split(loc_prob, 4, axis=2)
loc_prob = _op.concatenate(
[loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2
)
# reshape loc_prob tensor so it can be consumed by
# multibox_transform_loc
loc_prob = _op.reshape(loc_prob, [batch_size, anchor_boxes * 4])

# anchor coords are in yxhw format
Expand All @@ -3511,13 +3505,41 @@ def convert_detection_postprocess(self, op):
# attributes for multibox_transform_loc
multibox_transform_loc_attrs = {}
multibox_transform_loc_attrs["clip"] = False
multibox_transform_loc_attrs["threshold"] = custom_options["nms_score_threshold"]
multibox_transform_loc_attrs["threshold"] = (
0.0 if use_regular_nms else custom_options["nms_score_threshold"]
)
multibox_transform_loc_attrs["variances"] = (
1 / custom_options["x_scale"],
1 / custom_options["y_scale"],
1 / custom_options["w_scale"],
1 / custom_options["h_scale"],
)
multibox_transform_loc_attrs["keep_background"] = use_regular_nms

ret = _op.vision.multibox_transform_loc(
# reshape cls_pred so it can be consumed by
# multibox_transform_loc
_op.transpose(cls_pred, [0, 2, 1]),
loc_prob,
anchor_expr,
**multibox_transform_loc_attrs,
)

if use_regular_nms:
# box coordinates need to be converted from ltrb to (ymin, xmin, ymax, xmax)
_, transformed_boxes = _op.split(ret[0], (2,), axis=2)
box_l, box_t, box_r, box_b = _op.split(transformed_boxes, 4, axis=2)
transformed_boxes = _op.concatenate([box_t, box_l, box_b, box_r], axis=2)

return _op.vision.regular_non_max_suppression(
boxes=transformed_boxes,
scores=cls_pred,
max_detections_per_class=custom_options["detections_per_class"],
max_detections=custom_options["max_detections"],
num_classes=custom_options["num_classes"],
iou_threshold=custom_options["nms_iou_threshold"],
score_threshold=custom_options["nms_score_threshold"],
)

# attributes for non_max_suppression
non_max_suppression_attrs = {}
Expand All @@ -3528,9 +3550,6 @@ def convert_detection_postprocess(self, op):
non_max_suppression_attrs["max_output_size"] = custom_options["max_detections"]
non_max_suppression_attrs["invalid_to_bottom"] = False

ret = _op.vision.multibox_transform_loc(
cls_pred, loc_prob, anchor_expr, **multibox_transform_loc_attrs
)
ret = _op.vision.non_max_suppression(ret[0], ret[1], ret[1], **non_max_suppression_attrs)
ret = _op.vision.get_valid_counts(ret, 0)
valid_count = ret[0]
Expand Down
34 changes: 33 additions & 1 deletion python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,7 +1184,10 @@ def _compute_multibox_transform_loc(attrs, inputs, _):
clip = bool(get_const_int(attrs.clip))
threshold = get_const_float(attrs.threshold)
variances = get_float_tuple(attrs.variances)
return topi_compute(inputs[0], inputs[1], inputs[2], clip, threshold, variances)
keep_background = bool(get_const_int(attrs.keep_background))
return topi_compute(
inputs[0], inputs[1], inputs[2], clip, threshold, variances, keep_background
)

return _compute_multibox_transform_loc

Expand Down Expand Up @@ -1316,6 +1319,35 @@ def all_class_nms_strategy(attrs, inputs, out_type, target):
return strategy


def wrap_compute_regular_nms(topi_compute):
    """Adapt a regular-NMS topi compute to the (attrs, inputs, out_type) signature.

    Parameters
    ----------
    topi_compute : callable
        Called positionally with the first two entries of ``inputs`` followed by
        ``max_detections_per_class``, ``max_detections``, ``num_classes``,
        ``iou_threshold`` and ``score_threshold`` read from ``attrs``.

    Returns
    -------
    callable
        A function ``(attrs, inputs, out_type)`` that returns whatever
        ``topi_compute`` returns. ``out_type`` is accepted but not used.
    """

    def _compute_nms(attrs, inputs, out_type):
        first_input, second_input = inputs[0], inputs[1]
        # Attribute order must match the positional order expected by topi_compute.
        nms_params = (
            attrs.max_detections_per_class,
            attrs.max_detections,
            attrs.num_classes,
            attrs.iou_threshold,
            attrs.score_threshold,
        )
        return topi_compute(first_input, second_input, *nms_params)

    return _compute_nms


@override_native_generic_func("regular_non_max_suppression_strategy")
def regular_nms_strategy(attrs, inputs, out_type, target):
    """Generic strategy for regular NMS.

    Builds an OpStrategy holding a single implementation named
    "regular_nms.generic", pairing the topi regular NMS compute with the
    generic NMS schedule. None of the arguments are read here.
    """
    nms_strategy = _op.OpStrategy()
    compute_fn = wrap_compute_regular_nms(topi.vision.regular_non_max_suppression)
    schedule_fn = wrap_topi_schedule(topi.generic.schedule_nms)
    nms_strategy.add_implementation(compute_fn, schedule_fn, name="regular_nms.generic")
    return nms_strategy


# roi_align
def wrap_compute_roi_align(topi_compute):
"""wrap roi_align topi compute"""
Expand Down
30 changes: 30 additions & 0 deletions python/tvm/relay/op/vision/_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
reg.register_strategy("vision.all_class_non_max_suppression", strategy.all_class_nms_strategy)
reg.register_pattern("vision.all_class_non_max_suppression", OpPattern.OPAQUE)

# regular_non_max_suppression: attach its strategy and register it with the OPAQUE pattern
reg.register_strategy("vision.regular_non_max_suppression", strategy.regular_nms_strategy)
reg.register_pattern("vision.regular_non_max_suppression", OpPattern.OPAQUE)


@script
def _get_valid_counts_shape_func(data_shape):
Expand Down Expand Up @@ -122,6 +125,33 @@ def all_class_nms_shape_func(attrs, inputs, _):
return _all_class_nms_shape_func_tf(inputs[0], inputs[1])


@script
def _regular_nms_shape_func(boxes_shape, scores_shape, attrs):
    # Shape function for regular NMS. Produces four int64 shape tensors:
    #   boxes   -> (boxes_shape[0], max_detections, 4)
    #   classes -> (boxes_shape[0], max_detections)
    #   scores  -> (boxes_shape[0], max_detections)
    #   count   -> (boxes_shape[0],)
    # scores_shape is accepted but not read.
    boxes_out = output_tensor((3,), "int64")
    classes_out = output_tensor((2,), "int64")
    scores_out = output_tensor((2,), "int64")
    count_out = output_tensor((1,), "int64")

    # Leading dimension of every output is copied from boxes_shape[0].
    boxes_out[0] = boxes_shape[0]
    classes_out[0] = boxes_shape[0]
    scores_out[0] = boxes_shape[0]
    count_out[0] = boxes_shape[0]

    # Second dimension of the per-detection outputs is max_detections.
    boxes_out[1] = int64(attrs.max_detections)
    classes_out[1] = int64(attrs.max_detections)
    scores_out[1] = int64(attrs.max_detections)

    # Each box carries four coordinates.
    boxes_out[2] = int64(4)

    return boxes_out, classes_out, scores_out, count_out


@reg.register_shape_func("vision.regular_non_max_suppression", False)
def regular_nms_shape_func(attrs, inputs, _):
    """Shape function registered for vision.regular_non_max_suppression.

    Forwards the first two entries of ``inputs`` together with ``attrs`` to
    ``_regular_nms_shape_func``; the third argument is ignored.
    """
    first_shape, second_shape = inputs[0], inputs[1]
    return _regular_nms_shape_func(first_shape, second_shape, attrs)


@script
def _roi_align_shape_func_nchw(data_shape, rois_shape, pooled_size):
out = output_tensor((4,), "int64")
Expand Down
22 changes: 20 additions & 2 deletions python/tvm/relay/op/vision/multibox.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,13 @@ def multibox_prior(


def multibox_transform_loc(
cls_prob, loc_pred, anchor, clip=True, threshold=0.01, variances=(0.1, 0.1, 0.2, 0.2)
cls_prob,
loc_pred,
anchor,
clip=True,
threshold=0.01,
variances=(0.1, 0.1, 0.2, 0.2),
keep_background=False,
):
"""Location transformation for multibox detection
Expand All @@ -77,10 +83,22 @@ def multibox_transform_loc(
variances : Tuple of float, optional
variances to be decoded from box regression output.
keep_background : boolean, optional
Whether to keep boxes detected as background or not.
Returns
-------
ret : tuple of tvm.relay.Expr
"""
return expr.TupleWrapper(
_make.multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances), 2
_make.multibox_transform_loc(
cls_prob,
loc_pred,
anchor,
clip,
threshold,
variances,
keep_background,
),
2,
)
59 changes: 59 additions & 0 deletions python/tvm/relay/op/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,62 @@ def all_class_non_max_suppression(
return expr.TupleWrapper(out, 2)

return expr.TupleWrapper(out, 3)


def regular_non_max_suppression(
    boxes,
    scores,
    max_detections_per_class,
    max_detections,
    num_classes,
    iou_threshold,
    score_threshold,
):
    """Regular non-maximum suppression operator for object detection, corresponding to
    TFLite's regular NMS. NMS is performed for each class separately.

    Parameters
    ----------
    boxes : relay.Expr
        3-D tensor with shape (batch_size, num_boxes, 4). The four values in boxes
        encode (ymin, xmin, ymax, xmax) coordinates of a box.

    scores : relay.Expr
        3-D tensor with shape (batch_size, num_boxes, num_classes_with_background).

    max_detections_per_class : int
        The maximum number of output selected boxes per class.

    max_detections : int
        The maximum number of output selected boxes.

    num_classes : int
        The number of classes without background.

    iou_threshold : float
        IoU test threshold.

    score_threshold : float
        Score threshold to filter out low score boxes early.

    Returns
    -------
    out : relay.Tuple
        A tuple of four tensors, in this order:
        `detection_boxes` of size `(batch_size, max_detections, 4)`,
        `detection_classes` of size `(batch_size, max_detections)`,
        `detection_scores` of size `(batch_size, max_detections)`, and
        `num_detections` of size `(batch_size,)` representing the total number
        of selected boxes per batch.
    """
    nms_call = _make.regular_non_max_suppression(
        boxes,
        scores,
        max_detections_per_class,
        max_detections,
        num_classes,
        iou_threshold,
        score_threshold,
    )
    # The operator yields four outputs, so wrap the call as a 4-tuple.
    return expr.TupleWrapper(nms_call, 4)
Loading

0 comments on commit 619bb1d

Please sign in to comment.