Skip to content

Commit f2a97f7

Browse files
committed
fix bugs
1 parent 8219688 commit f2a97f7

File tree

3 files changed

+21
-54
lines changed

3 files changed

+21
-54
lines changed

supervision/detection/core.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -1323,19 +1323,17 @@ def with_nms(
13231323
return self[indices]
13241324

13251325
def with_soft_nms(
1326-
self, threshold: float = 0.5, class_agnostic: bool = False, sigma: float = 0.5
1326+
self, sigma: float = 0.5, class_agnostic: bool = False
13271327
) -> Detections:
13281328
"""
13291329
Perform soft non-maximum suppression on the current set of object detections.
13301330
13311331
Args:
1332-
threshold (float): The intersection-over-union threshold
1333-
to use for non-maximum suppression. Defaults to 0.5.
1332+
sigma (float): The sigma value to use for the soft non-maximum suppression
1333+
algorithm. Defaults to 0.5.
13341334
class_agnostic (bool): Whether to perform class-agnostic
13351335
non-maximum suppression. If True, the class_id of each detection
13361336
will be ignored. Defaults to False.
1337-
sigma (float): The sigma value to use for the soft non-maximum suppression
1338-
algorithm. Defaults to 0.5.
13391337
13401338
Returns:
13411339
Detections: A new Detections object containing the subset of detections
@@ -1370,13 +1368,12 @@ def with_soft_nms(
13701368
soft_confidences = mask_soft_non_max_suppression(
13711369
predictions=predictions,
13721370
masks=self.mask,
1373-
iou_threshold=threshold,
13741371
sigma=sigma,
13751372
)
13761373
self.confidence = soft_confidences
13771374
else:
1378-
indices, soft_confidences = box_soft_non_max_suppression(
1379-
predictions=predictions, iou_threshold=threshold, sigma=sigma
1375+
soft_confidences = box_soft_non_max_suppression(
1376+
predictions=predictions, sigma=sigma
13801377
)
13811378
self.confidence = soft_confidences
13821379

supervision/detection/overlap_filter.py

+13-21
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray:
3939

4040

4141
def __prepare_data_for_mask_nms(
42-
iou_threshold: float,
4342
mask_dimension: int,
4443
masks: np.ndarray,
4544
predictions: np.ndarray,
@@ -48,8 +47,6 @@ def __prepare_data_for_mask_nms(
4847
Get IOUs from mask. Prepare the data for non-max suppression.
4948
5049
Args:
51-
iou_threshold (float): The intersection-over-union threshold
52-
to use for non-maximum suppression.
5350
mask_dimension (int): The dimension to which the masks should be
5451
resized before computing IOU values.
5552
masks (np.ndarray): A 3D array of binary masks corresponding to the predictions.
@@ -68,10 +65,6 @@ def __prepare_data_for_mask_nms(
6865
AssertionError: If `iou_threshold` is not within the closed range from
6966
`0` to `1`.
7067
"""
71-
assert 0 <= iou_threshold <= 1, (
72-
"Value of `iou_threshold` must be in the closed range from 0 to 1, "
73-
f"{iou_threshold} given."
74-
)
7568
rows, columns = predictions.shape
7669

7770
if columns == 5:
@@ -117,8 +110,12 @@ def mask_non_max_suppression(
117110
AssertionError: If `iou_threshold` is not within the closed
118111
range from `0` to `1`.
119112
"""
113+
assert 0 <= iou_threshold <= 1, (
114+
"Value of `iou_threshold` must be in the closed range from 0 to 1, "
115+
f"{iou_threshold} given."
116+
)
120117
_, categories, ious, rows, sort_index = __prepare_data_for_mask_nms(
121-
iou_threshold, mask_dimension, masks, predictions
118+
mask_dimension, masks, predictions
122119
)
123120

124121
keep = np.ones(rows, dtype=bool)
@@ -133,7 +130,6 @@ def mask_non_max_suppression(
133130
def mask_soft_non_max_suppression(
134131
predictions: np.ndarray,
135132
masks: np.ndarray,
136-
iou_threshold: float = 0.5,
137133
mask_dimension: int = 640,
138134
sigma: float = 0.5,
139135
) -> np.ndarray:
@@ -160,7 +156,7 @@ def mask_soft_non_max_suppression(
160156
0 < sigma < 1
161157
), f"Value of `sigma` must be greater than 0 and less than 1, {sigma} given."
162158
predictions, categories, ious, rows, sort_index = __prepare_data_for_mask_nms(
163-
iou_threshold, mask_dimension, masks, predictions
159+
mask_dimension, masks, predictions
164160
)
165161

166162
not_this_row = np.ones(rows)
@@ -175,14 +171,12 @@ def mask_soft_non_max_suppression(
175171

176172

177173
def __prepare_data_for_box_nsm(
178-
iou_threshold: float, predictions: np.ndarray
174+
predictions: np.ndarray,
179175
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, np.ndarray]:
180176
"""
181177
Prepare the data for non-max suppression.
182178
183179
Args:
184-
iou_threshold (float): The intersection-over-union threshold
185-
to use for non-maximum suppression.
186180
predictions (np.ndarray): An array of object detection predictions in the
187181
format of `(x_min, y_min, x_max, y_max, score)` or
188182
`(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or `(N, 6)`,
@@ -198,10 +192,6 @@ def __prepare_data_for_box_nsm(
198192
199193
200194
"""
201-
assert 0 <= iou_threshold <= 1, (
202-
"Value of `iou_threshold` must be in the closed range from 0 to 1, "
203-
f"{iou_threshold} given."
204-
)
205195
rows, columns = predictions.shape
206196

207197
# add column #5 - category filled with zeros for agnostic nms
@@ -240,9 +230,11 @@ def box_non_max_suppression(
240230
AssertionError: If `iou_threshold` is not within the
241231
closed range from `0` to `1`.
242232
"""
243-
_, categories, ious, rows, sort_index = __prepare_data_for_box_nsm(
244-
iou_threshold, predictions
233+
assert 0 <= iou_threshold <= 1, (
234+
"Value of `iou_threshold` must be in the closed range from 0 to 1, "
235+
f"{iou_threshold} given."
245236
)
237+
_, categories, ious, rows, sort_index = __prepare_data_for_box_nsm(predictions)
246238

247239
keep = np.ones(rows, dtype=bool)
248240
for index, (iou, category) in enumerate(zip(ious, categories)):
@@ -258,7 +250,7 @@ def box_non_max_suppression(
258250

259251

260252
def box_soft_non_max_suppression(
261-
predictions: np.ndarray, iou_threshold: float = 0.5, sigma: float = 0.5
253+
predictions: np.ndarray, sigma: float = 0.5
262254
) -> np.ndarray:
263255
"""
264256
Perform Soft Non-Maximum Suppression (Soft-NMS) on object detection predictions.
@@ -283,7 +275,7 @@ def box_soft_non_max_suppression(
283275
0 < sigma < 1
284276
), f"Value of `sigma` must be greater than 0 and less than 1, {sigma} given."
285277
predictions, categories, ious, rows, sort_index = __prepare_data_for_box_nsm(
286-
iou_threshold, predictions
278+
predictions
287279
)
288280

289281
not_this_row = np.ones(rows)

test/detection/test_overlap_filter.py

+3-25
Original file line numberDiff line numberDiff line change
@@ -246,25 +246,22 @@ def test_box_non_max_suppression(
246246

247247

248248
@pytest.mark.parametrize(
249-
"predictions, iou_threshold, sigma, expected_result, exception",
249+
"predictions, sigma, expected_result, exception",
250250
[
251251
(
252252
np.empty(shape=(0, 5)),
253-
0.5,
254253
0.1,
255254
np.array([]),
256255
DoesNotRaise(),
257256
), # single box with no category
258257
(
259258
np.array([[10.0, 10.0, 40.0, 40.0, 0.8]]),
260-
0.5,
261259
0.8,
262260
np.array([0.8]),
263261
DoesNotRaise(),
264262
), # single box with no category
265263
(
266264
np.array([[10.0, 10.0, 40.0, 40.0, 0.8, 0]]),
267-
0.5,
268265
0.9,
269266
np.array([0.8]),
270267
DoesNotRaise(),
@@ -276,7 +273,6 @@ def test_box_non_max_suppression(
276273
[15.0, 15.0, 40.0, 40.0, 0.9],
277274
]
278275
),
279-
0.5,
280276
0.2,
281277
np.array([0.07176137, 0.9]),
282278
DoesNotRaise(),
@@ -288,7 +284,6 @@ def test_box_non_max_suppression(
288284
[15.0, 15.0, 40.0, 40.0, 0.9, 1],
289285
]
290286
),
291-
0.5,
292287
0.3,
293288
np.array([0.8, 0.9]),
294289
DoesNotRaise(),
@@ -300,7 +295,6 @@ def test_box_non_max_suppression(
300295
[15.0, 15.0, 40.0, 40.0, 0.9, 0],
301296
]
302297
),
303-
0.5,
304298
0.9,
305299
np.array([0.46814354, 0.9]),
306300
DoesNotRaise(),
@@ -313,7 +307,6 @@ def test_box_non_max_suppression(
313307
[10.0, 10.0, 40.0, 50.0, 0.85],
314308
]
315309
),
316-
0.5,
317310
0.7,
318311
np.array([0.42648529, 0.9, 0.53109062]),
319312
DoesNotRaise(),
@@ -327,7 +320,6 @@ def test_box_non_max_suppression(
327320
]
328321
),
329322
0.5,
330-
0.5,
331323
np.array([0.8, 0.9, 0.85]),
332324
DoesNotRaise(),
333325
), # three boxes with same category
@@ -339,7 +331,6 @@ def test_box_non_max_suppression(
339331
[10.0, 10.0, 40.0, 50.0, 0.85, 1],
340332
]
341333
),
342-
0.5,
343334
0.9,
344335
np.array([0.55491779, 0.9, 0.85]),
345336
DoesNotRaise(),
@@ -348,15 +339,12 @@ def test_box_non_max_suppression(
348339
)
349340
def test_box_soft_non_max_suppression(
350341
predictions: np.ndarray,
351-
iou_threshold: float,
352342
sigma: float,
353343
expected_result: Optional[np.ndarray],
354344
exception: Exception,
355345
) -> None:
356346
with exception:
357-
result = box_soft_non_max_suppression(
358-
predictions=predictions, iou_threshold=iou_threshold, sigma=sigma
359-
)
347+
result = box_soft_non_max_suppression(predictions=predictions, sigma=sigma)
360348
np.testing.assert_almost_equal(result, expected_result, decimal=5)
361349

362350

@@ -567,12 +555,11 @@ def test_mask_non_max_suppression(
567555

568556

569557
@pytest.mark.parametrize(
570-
"predictions, masks, iou_threshold, sigma, expected_result, exception",
558+
"predictions, masks, sigma, expected_result, exception",
571559
[
572560
(
573561
np.empty((0, 6)),
574562
np.empty((0, 5, 5)),
575-
0.5,
576563
0.1,
577564
np.array([]),
578565
DoesNotRaise(),
@@ -590,7 +577,6 @@ def test_mask_non_max_suppression(
590577
]
591578
]
592579
),
593-
0.5,
594580
0.2,
595581
np.array([0.8]),
596582
DoesNotRaise(),
@@ -608,7 +594,6 @@ def test_mask_non_max_suppression(
608594
]
609595
]
610596
),
611-
0.5,
612597
0.99,
613598
np.array([0.8]),
614599
DoesNotRaise(),
@@ -633,7 +618,6 @@ def test_mask_non_max_suppression(
633618
],
634619
]
635620
),
636-
0.5,
637621
0.8,
638622
np.array([0.8, 0.9]),
639623
DoesNotRaise(),
@@ -658,7 +642,6 @@ def test_mask_non_max_suppression(
658642
],
659643
]
660644
),
661-
0.4,
662645
0.6,
663646
np.array([0.3831756, 0.9]),
664647
DoesNotRaise(),
@@ -683,7 +666,6 @@ def test_mask_non_max_suppression(
683666
],
684667
]
685668
),
686-
0.5,
687669
0.9,
688670
np.array([0.8, 0.9]),
689671
DoesNotRaise(),
@@ -721,7 +703,6 @@ def test_mask_non_max_suppression(
721703
],
722704
]
723705
),
724-
0.5,
725706
0.3,
726707
np.array([0.02853919, 0.85, 0.9]),
727708
DoesNotRaise(),
@@ -759,7 +740,6 @@ def test_mask_non_max_suppression(
759740
],
760741
]
761742
),
762-
0.5,
763743
0.1,
764744
np.array([0.8, 0.85, 0.9]),
765745
DoesNotRaise(),
@@ -769,7 +749,6 @@ def test_mask_non_max_suppression(
769749
def test_mask_soft_non_max_suppression(
770750
predictions: np.ndarray,
771751
masks: np.ndarray,
772-
iou_threshold: float,
773752
sigma: float,
774753
expected_result: Optional[np.ndarray],
775754
exception: Exception,
@@ -778,7 +757,6 @@ def test_mask_soft_non_max_suppression(
778757
result = mask_soft_non_max_suppression(
779758
predictions=predictions,
780759
masks=masks,
781-
iou_threshold=iou_threshold,
782760
sigma=sigma,
783761
)
784762
np.testing.assert_almost_equal(result, expected_result, decimal=6)

0 commit comments

Comments
 (0)