Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

efficientNMSPlugin: support class-independent nms with new parameter class_agnostic #2645

Open
wants to merge 3 commits into
base: release/8.6
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,23 @@ versions:
- max_output_boxes
- background_class
- score_activation
- class_agnostic
- box_coding
attribute_types:
score_threshold: float32
iou_threshold: float32
max_output_boxes: int32
background_class: int32
score_activation: int32
class_agnostic: int32
box_coding: int32
attribute_length:
score_threshold: 1
iou_threshold: 1
max_output_boxes: 1
background_class: 1
score_activation: 1
class_agnostic: 1
box_coding: 1
attribute_options:
score_threshold:
Expand All @@ -40,6 +43,9 @@ versions:
score_activation:
- 0
- 1
class_agnostic:
- 0
- 1
box_coding:
- 0
- 1
Expand Down
5 changes: 5 additions & 0 deletions plugin/efficientNMSPlugin/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ The following four output tensors are generated:
|`int` |`max_output_boxes` |The maximum number of detections to output per image.
|`int` |`background_class` |The label ID for the background class. If there is no background class, set it to `-1`.
|`bool` |`score_activation` * |Set to true to apply sigmoid activation to the confidence scores during NMS operation.
|`bool` |`class_agnostic` |Set to true to do class-independent NMS; otherwise, boxes of different classes would be considered separately during NMS.
|`int` |`box_coding` |Coding type used for boxes (and anchors if applicable), 0 = BoxCorner, 1 = BoxCenterSize.

Parameters marked with a `*` have a non-negligible effect on runtime latency. See the [Performance Tuning](#performance-tuning) section below for more details on how to set them optimally.
Expand Down Expand Up @@ -134,6 +135,10 @@ The algorithm is highly sensitive to the selected `score_threshold` parameter. W

Depending on network configuration, it is usually more efficient to provide raw scores (pre-sigmoid) to the NMS plugin scores input, and enable the `score_activation` parameter. Doing so applies a sigmoid activation only to the last `max_output_boxes` selected scores, instead of all the predicted scores, largely reducing the computational cost.

#### Class Independent NMS

Some object detection networks/architectures like YOLO series need to use class-independent NMS operations. If `class_agnostic` is enabled, class-independent NMS is performed; otherwise, different classes would do NMS separately.

#### Using the Fused Box Decoder

When using networks with many anchors, such as EfficientDet or SSD, it may be more efficient to do box decoding within the NMS plugin. For this, pass the raw box predictions as the boxes input, and the default anchor coordinates as the optional third input to the plugin.
Expand Down
6 changes: 5 additions & 1 deletion plugin/efficientNMSPlugin/efficientNMSInference.cu
Original file line number Diff line number Diff line change
Expand Up @@ -314,12 +314,16 @@ __global__ void EfficientNMS(EfficientNMSParameters param, const int* topNumData

for (int tile = 0; tile < numTiles; tile++)
{
bool check_class = true;
samurdhikaru marked this conversation as resolved.
Show resolved Hide resolved
if (!param.classAgnostic)
check_class = threadClass[tile] == testClass;

// IOU
if (boxIdx[tile] > i && // Make sure two different boxes are being tested, and that it's a higher index;
boxIdx[tile] < numSelectedBoxes && // Make sure the box is within numSelectedBoxes;
blockState == 1 && // Signal that allows IOU checks to be performed;
threadState[tile] == 0 && // Make sure this box hasn't been either dropped or kept already;
threadClass[tile] == testClass && // Compare only boxes of matching classes;
check_class && // Compare only boxes of matching classes when classAgnostic is false;
lte_mp(threadScore[tile], testScore) && // Make sure the sorting order of scores is as expected;
IOU<T>(param, threadBox[tile], testBox) >= param.iouThreshold) // And... IOU overlap.
{
Expand Down
1 change: 1 addition & 0 deletions plugin/efficientNMSPlugin/efficientNMSParameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ struct EfficientNMSParameters
bool scoreSigmoid = false;
bool clipBoxes = false;
int boxCoding = 0;
bool classAgnostic = false;

// Related to NMS Internals
int numSelectedBoxes = 4096;
Expand Down
5 changes: 5 additions & 0 deletions plugin/efficientNMSPlugin/efficientNMSPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ EfficientNMSPluginCreator::EfficientNMSPluginCreator()
mPluginAttributes.emplace_back(PluginField("max_output_boxes", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("background_class", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("score_activation", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("class_agnostic", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("box_coding", nullptr, PluginFieldType::kINT32, 1));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
Expand Down Expand Up @@ -493,6 +494,10 @@ IPluginV2DynamicExt* EfficientNMSPluginCreator::createPlugin(const char* name, c
PLUGIN_VALIDATE(scoreSigmoid == 0 || scoreSigmoid == 1);
mParam.scoreSigmoid = static_cast<bool>(scoreSigmoid);
}
if (!strcmp(attrName, "class_agnostic"))
{
mParam.classAgnostic = *(static_cast<bool const*>(fields[i].data));
}
if (!strcmp(attrName, "box_coding"))
{
PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kINT32);
Expand Down
1 change: 1 addition & 0 deletions samples/python/detectron2/create_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ def NMS(self, boxes, scores, anchors, background_class, score_activation, max_pr
'score_threshold': max(0.01, score_threshold),
'iou_threshold': iou_threshold,
'score_activation': score_activation,
'class_agnostic': False,
'box_coding': 1,
}
)
Expand Down
1 change: 1 addition & 0 deletions samples/python/efficientdet/create_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def get_anchor_np(output_idx, op):
'score_threshold': max(0.01, score_threshold), # Keep threshold to at least 0.01 for better efficiency
'iou_threshold': iou_threshold,
'score_activation': True,
'class_agnostic': False,
'box_coding': 1,
}
nms_output_classes_dtype = np.int32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def NMS(self, box_net_tensor, class_net_tensor, anchors_tensor, background_class
'score_threshold': max(0.01, score_threshold),
'iou_threshold': iou_threshold,
'score_activation': score_activation,
'class_agnostic': False,
'box_coding': 1,
}
)
Expand Down