Skip to content

Commit fa705eb

Browse files
gguzzyfcakyonCopilot
authored
Exclude classes from inference using pretrained or custom models (#1104)
Co-authored-by: fatih akyon <34196005+fcakyon@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 9320c73 commit fa705eb

File tree

1 file changed

+54
-2
lines changed

1 file changed

+54
-2
lines changed

sahi/predict.py

+54-2
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,24 @@
5757
logger = logging.getLogger(__name__)
5858

5959

60+
def filter_predictions(object_prediction_list, exclude_classes_by_name, exclude_classes_by_id):
61+
return [
62+
obj_pred
63+
for obj_pred in object_prediction_list
64+
if obj_pred.category.name not in (exclude_classes_by_name or [])
65+
and obj_pred.category.id not in (exclude_classes_by_id or [])
66+
]
67+
68+
6069
def get_prediction(
6170
image,
6271
detection_model,
6372
shift_amount: list = [0, 0],
6473
full_shape=None,
6574
postprocess: Optional[PostprocessPredictions] = None,
6675
verbose: int = 0,
76+
exclude_classes_by_name: Optional[List[str]] = None,
77+
exclude_classes_by_id: Optional[List[int]] = None,
6778
) -> PredictionResult:
6879
"""
6980
Function for performing prediction for given image using given detection_model.
@@ -81,7 +92,12 @@ def get_prediction(
8192
verbose: int
8293
0: no print (default)
8394
1: print prediction duration
84-
95+
exclude_classes_by_name: Optional[List[str]]
96+
None: if no classes are excluded
97+
List[str]: set of classes to exclude using its/their class label name/s
98+
exclude_classes_by_id: Optional[List[int]]
99+
None: if no classes are excluded
100+
List[int]: set of classes to exclude using one or more IDs
85101
Returns:
86102
A dict with fields:
87103
object_prediction_list: a list of ObjectPrediction
@@ -105,6 +121,7 @@ def get_prediction(
105121
full_shape=full_shape,
106122
)
107123
object_prediction_list: List[ObjectPrediction] = detection_model.object_prediction_list
124+
object_prediction_list = filter_predictions(object_prediction_list, exclude_classes_by_name, exclude_classes_by_id)
108125

109126
# postprocess matching predictions
110127
if postprocess is not None:
@@ -142,6 +159,8 @@ def get_sliced_prediction(
142159
auto_slice_resolution: bool = True,
143160
slice_export_prefix: Optional[str] = None,
144161
slice_dir: Optional[str] = None,
162+
exclude_classes_by_name: Optional[List[str]] = None,
163+
exclude_classes_by_id: Optional[List[int]] = None,
145164
) -> PredictionResult:
146165
"""
147166
Function for slice image + get predicion for each slice + combine predictions in full image.
@@ -191,7 +210,12 @@ def get_sliced_prediction(
191210
Prefix for the exported slices. Defaults to None.
192211
slice_dir: str
193212
Directory to save the slices. Defaults to None.
194-
213+
exclude_classes_by_name: Optional[List[str]]
214+
None: if no classes are excluded
215+
List[str]: set of classes to exclude using its/their class label name/s
216+
exclude_classes_by_id: Optional[List[int]]
217+
None: if no classes are excluded
218+
List[int]: set of classes to exclude using one or more IDs
195219
Returns:
196220
A Dict with fields:
197221
object_prediction_list: a list of sahi.prediction.ObjectPrediction
@@ -257,6 +281,8 @@ def get_sliced_prediction(
257281
slice_image_result.original_image_height,
258282
slice_image_result.original_image_width,
259283
],
284+
exclude_classes_by_name=exclude_classes_by_name,
285+
exclude_classes_by_id=exclude_classes_by_id,
260286
)
261287
# convert sliced predictions to full predictions
262288
for object_prediction in prediction_result.object_prediction_list:
@@ -278,6 +304,8 @@ def get_sliced_prediction(
278304
slice_image_result.original_image_width,
279305
],
280306
postprocess=None,
307+
exclude_classes_by_name=exclude_classes_by_name,
308+
exclude_classes_by_id=exclude_classes_by_id,
281309
)
282310
object_prediction_list.extend(prediction_result.object_prediction_list)
283311

@@ -380,6 +408,8 @@ def predict(
380408
verbose: int = 1,
381409
return_dict: bool = False,
382410
force_postprocess_type: bool = False,
411+
exclude_classes_by_name: Optional[List[str]] = None,
412+
exclude_classes_by_id: Optional[List[int]] = None,
383413
**kwargs,
384414
):
385415
"""
@@ -466,6 +496,12 @@ def predict(
466496
If True, returns a dict with 'export_dir' field.
467497
force_postprocess_type: bool
468498
If True, auto postprocess check will e disabled
499+
exclude_classes_by_name: Optional[List[str]]
500+
None: if no classes are excluded
501+
List[str]: set of classes to exclude using its/their class label name/s
502+
exclude_classes_by_id: Optional[List[int]]
503+
None: if no classes are excluded
504+
List[int]: set of classes to exclude using one or more IDs
469505
"""
470506
# assert prediction type
471507
if no_standard_prediction and no_sliced_prediction:
@@ -574,6 +610,8 @@ def predict(
574610
postprocess_match_threshold=postprocess_match_threshold,
575611
postprocess_class_agnostic=postprocess_class_agnostic,
576612
verbose=1 if verbose else 0,
613+
exclude_classes_by_name=exclude_classes_by_name,
614+
exclude_classes_by_id=exclude_classes_by_id,
577615
)
578616
object_prediction_list = prediction_result.object_prediction_list
579617
if prediction_result.durations_in_seconds:
@@ -587,6 +625,8 @@ def predict(
587625
full_shape=None,
588626
postprocess=None,
589627
verbose=0,
628+
exclude_classes_by_name=exclude_classes_by_name,
629+
exclude_classes_by_id=exclude_classes_by_id,
590630
)
591631
object_prediction_list = prediction_result.object_prediction_list
592632

@@ -753,6 +793,8 @@ def predict_fiftyone(
753793
postprocess_match_threshold: float = 0.5,
754794
postprocess_class_agnostic: bool = False,
755795
verbose: int = 1,
796+
exclude_classes_by_name: Optional[List[str]] = None,
797+
exclude_classes_by_id: Optional[List[int]] = None,
756798
):
757799
"""
758800
Performs prediction for all present images in given folder.
@@ -811,6 +853,12 @@ def predict_fiftyone(
811853
verbose: int
812854
0: no print
813855
1: print slice/prediction durations, number of slices, model loading/file exporting durations
856+
exclude_classes_by_name: Optional[List[str]]
857+
None: if no classes are excluded
858+
List[str]: set of classes to exclude using its/their class label name/s
859+
exclude_classes_by_id: Optional[List[int]]
860+
None: if no classes are excluded
861+
List[int]: set of classes to exclude using one or more IDs
814862
"""
815863
check_requirements(["fiftyone"])
816864

@@ -863,6 +911,8 @@ def predict_fiftyone(
863911
postprocess_match_metric=postprocess_match_metric,
864912
postprocess_class_agnostic=postprocess_class_agnostic,
865913
verbose=verbose,
914+
exclude_classes_by_name=exclude_classes_by_name,
915+
exclude_classes_by_id=exclude_classes_by_id,
866916
)
867917
durations_in_seconds["slice"] += prediction_result.durations_in_seconds["slice"]
868918
else:
@@ -874,6 +924,8 @@ def predict_fiftyone(
874924
full_shape=None,
875925
postprocess=None,
876926
verbose=0,
927+
exclude_classes_by_name=exclude_classes_by_name,
928+
exclude_classes_by_id=exclude_classes_by_id,
877929
)
878930
durations_in_seconds["prediction"] += prediction_result.durations_in_seconds["prediction"]
879931

0 commit comments

Comments
 (0)