diff --git a/docs/predict.md b/docs/predict.md index 8e27f0c3b..da8ee7a75 100644 --- a/docs/predict.md +++ b/docs/predict.md @@ -69,3 +69,31 @@ result = predict( ) ``` + +- Exclude custom classes on inference: + +```python +from sahi.predict import get_sliced_prediction +from sahi import AutoDetectionModel + +# init a model +detection_model = AutoDetectionModel.from_pretrained(...) + +# define the class names to exclude from custom model inference +exclude_classes_by_name = ["car"] + +# or exclude classes by its custom id +exclude_classes_by_id = [0] + +result = get_sliced_prediction( + image, + detection_model, + slice_height = 256, + slice_width = 256, + overlap_height_ratio = 0.2, + overlap_width_ratio = 0.2, + exclude_classes_by_name = exclude_classes_by_name + # exclude_classes_by_id = exclude_classes_by_id +) + +```