57
57
logger = logging .getLogger (__name__ )
58
58
59
59
60
+ def filter_predictions (object_prediction_list , exclude_classes_by_name , exclude_classes_by_id ):
61
+ return [
62
+ obj_pred
63
+ for obj_pred in object_prediction_list
64
+ if obj_pred .category .name not in (exclude_classes_by_name or [])
65
+ and obj_pred .category .id not in (exclude_classes_by_id or [])
66
+ ]
67
+
68
+
60
69
def get_prediction (
61
70
image ,
62
71
detection_model ,
63
72
shift_amount : list = [0 , 0 ],
64
73
full_shape = None ,
65
74
postprocess : Optional [PostprocessPredictions ] = None ,
66
75
verbose : int = 0 ,
76
+ exclude_classes_by_name : Optional [List [str ]] = None ,
77
+ exclude_classes_by_id : Optional [List [int ]] = None ,
67
78
) -> PredictionResult :
68
79
"""
69
80
Function for performing prediction for given image using given detection_model.
@@ -81,7 +92,12 @@ def get_prediction(
81
92
verbose: int
82
93
0: no print (default)
83
94
1: print prediction duration
84
-
95
+ exclude_classes_by_name: Optional[List[str]]
96
+ None: if no classes are excluded
97
+ List[str]: set of classes to exclude using its/their class label name/s
98
+ exclude_classes_by_id: Optional[List[int]]
99
+ None: if no classes are excluded
100
+ List[int]: set of classes to exclude using one or more IDs
85
101
Returns:
86
102
A dict with fields:
87
103
object_prediction_list: a list of ObjectPrediction
@@ -105,6 +121,7 @@ def get_prediction(
105
121
full_shape = full_shape ,
106
122
)
107
123
object_prediction_list : List [ObjectPrediction ] = detection_model .object_prediction_list
124
+ object_prediction_list = filter_predictions (object_prediction_list , exclude_classes_by_name , exclude_classes_by_id )
108
125
109
126
# postprocess matching predictions
110
127
if postprocess is not None :
@@ -142,6 +159,8 @@ def get_sliced_prediction(
142
159
auto_slice_resolution : bool = True ,
143
160
slice_export_prefix : Optional [str ] = None ,
144
161
slice_dir : Optional [str ] = None ,
162
+ exclude_classes_by_name : Optional [List [str ]] = None ,
163
+ exclude_classes_by_id : Optional [List [int ]] = None ,
145
164
) -> PredictionResult :
146
165
"""
147
166
Function for slice image + get predicion for each slice + combine predictions in full image.
@@ -191,7 +210,12 @@ def get_sliced_prediction(
191
210
Prefix for the exported slices. Defaults to None.
192
211
slice_dir: str
193
212
Directory to save the slices. Defaults to None.
194
-
213
+ exclude_classes_by_name: Optional[List[str]]
214
+ None: if no classes are excluded
215
+ List[str]: set of classes to exclude using its/their class label name/s
216
+ exclude_classes_by_id: Optional[List[int]]
217
+ None: if no classes are excluded
218
+ List[int]: set of classes to exclude using one or more IDs
195
219
Returns:
196
220
A Dict with fields:
197
221
object_prediction_list: a list of sahi.prediction.ObjectPrediction
@@ -257,6 +281,8 @@ def get_sliced_prediction(
257
281
slice_image_result .original_image_height ,
258
282
slice_image_result .original_image_width ,
259
283
],
284
+ exclude_classes_by_name = exclude_classes_by_name ,
285
+ exclude_classes_by_id = exclude_classes_by_id ,
260
286
)
261
287
# convert sliced predictions to full predictions
262
288
for object_prediction in prediction_result .object_prediction_list :
@@ -278,6 +304,8 @@ def get_sliced_prediction(
278
304
slice_image_result .original_image_width ,
279
305
],
280
306
postprocess = None ,
307
+ exclude_classes_by_name = exclude_classes_by_name ,
308
+ exclude_classes_by_id = exclude_classes_by_id ,
281
309
)
282
310
object_prediction_list .extend (prediction_result .object_prediction_list )
283
311
@@ -380,6 +408,8 @@ def predict(
380
408
verbose : int = 1 ,
381
409
return_dict : bool = False ,
382
410
force_postprocess_type : bool = False ,
411
+ exclude_classes_by_name : Optional [List [str ]] = None ,
412
+ exclude_classes_by_id : Optional [List [int ]] = None ,
383
413
** kwargs ,
384
414
):
385
415
"""
@@ -466,6 +496,12 @@ def predict(
466
496
If True, returns a dict with 'export_dir' field.
467
497
force_postprocess_type: bool
468
498
If True, auto postprocess check will e disabled
499
+ exclude_classes_by_name: Optional[List[str]]
500
+ None: if no classes are excluded
501
+ List[str]: set of classes to exclude using its/their class label name/s
502
+ exclude_classes_by_id: Optional[List[int]]
503
+ None: if no classes are excluded
504
+ List[int]: set of classes to exclude using one or more IDs
469
505
"""
470
506
# assert prediction type
471
507
if no_standard_prediction and no_sliced_prediction :
@@ -574,6 +610,8 @@ def predict(
574
610
postprocess_match_threshold = postprocess_match_threshold ,
575
611
postprocess_class_agnostic = postprocess_class_agnostic ,
576
612
verbose = 1 if verbose else 0 ,
613
+ exclude_classes_by_name = exclude_classes_by_name ,
614
+ exclude_classes_by_id = exclude_classes_by_id ,
577
615
)
578
616
object_prediction_list = prediction_result .object_prediction_list
579
617
if prediction_result .durations_in_seconds :
@@ -587,6 +625,8 @@ def predict(
587
625
full_shape = None ,
588
626
postprocess = None ,
589
627
verbose = 0 ,
628
+ exclude_classes_by_name = exclude_classes_by_name ,
629
+ exclude_classes_by_id = exclude_classes_by_id ,
590
630
)
591
631
object_prediction_list = prediction_result .object_prediction_list
592
632
@@ -753,6 +793,8 @@ def predict_fiftyone(
753
793
postprocess_match_threshold : float = 0.5 ,
754
794
postprocess_class_agnostic : bool = False ,
755
795
verbose : int = 1 ,
796
+ exclude_classes_by_name : Optional [List [str ]] = None ,
797
+ exclude_classes_by_id : Optional [List [int ]] = None ,
756
798
):
757
799
"""
758
800
Performs prediction for all present images in given folder.
@@ -811,6 +853,12 @@ def predict_fiftyone(
811
853
verbose: int
812
854
0: no print
813
855
1: print slice/prediction durations, number of slices, model loading/file exporting durations
856
+ exclude_classes_by_name: Optional[List[str]]
857
+ None: if no classes are excluded
858
+ List[str]: set of classes to exclude using its/their class label name/s
859
+ exclude_classes_by_id: Optional[List[int]]
860
+ None: if no classes are excluded
861
+ List[int]: set of classes to exclude using one or more IDs
814
862
"""
815
863
check_requirements (["fiftyone" ])
816
864
@@ -863,6 +911,8 @@ def predict_fiftyone(
863
911
postprocess_match_metric = postprocess_match_metric ,
864
912
postprocess_class_agnostic = postprocess_class_agnostic ,
865
913
verbose = verbose ,
914
+ exclude_classes_by_name = exclude_classes_by_name ,
915
+ exclude_classes_by_id = exclude_classes_by_id ,
866
916
)
867
917
durations_in_seconds ["slice" ] += prediction_result .durations_in_seconds ["slice" ]
868
918
else :
@@ -874,6 +924,8 @@ def predict_fiftyone(
874
924
full_shape = None ,
875
925
postprocess = None ,
876
926
verbose = 0 ,
927
+ exclude_classes_by_name = exclude_classes_by_name ,
928
+ exclude_classes_by_id = exclude_classes_by_id ,
877
929
)
878
930
durations_in_seconds ["prediction" ] += prediction_result .durations_in_seconds ["prediction" ]
879
931
0 commit comments