From 8a33c3a75ee2d511d3c218ab33500d30fe94cd86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Thu, 2 Nov 2023 17:04:42 +0100 Subject: [PATCH 01/15] Add sample close to threshold strategy --- .../core/active_learning/configuration.py | 11 +- inference/core/active_learning/sampling.py | 179 +++++++++++++++++- 2 files changed, 184 insertions(+), 6 deletions(-) diff --git a/inference/core/active_learning/configuration.py b/inference/core/active_learning/configuration.py index ec76cae05..d42b030eb 100644 --- a/inference/core/active_learning/configuration.py +++ b/inference/core/active_learning/configuration.py @@ -6,8 +6,10 @@ RoboflowProjectMetadata, SamplingMethod, ) -from inference.core.active_learning.sampling import initialize_random_sampling -from inference.core.cache.base import BaseCache +from inference.core.active_learning.sampling import ( + initialize_close_to_threshold_sampling, + initialize_random_sampling, +) from inference.core.env import ACTIVE_LEARNING_ENABLED from inference.core.roboflow_api import ( get_roboflow_active_learning_configuration, @@ -16,7 +18,10 @@ ) from inference.core.utils.roboflow import get_model_id_chunks -TYPE2SAMPLING_INITIALIZERS = {"random_sampling": initialize_random_sampling} +TYPE2SAMPLING_INITIALIZERS = { + "random_sampling": initialize_random_sampling, + "close_to_threshold_sampling": initialize_close_to_threshold_sampling, +} def prepare_active_learning_configuration( diff --git a/inference/core/active_learning/sampling.py b/inference/core/active_learning/sampling.py index f639a7da9..5d316e1fb 100644 --- a/inference/core/active_learning/sampling.py +++ b/inference/core/active_learning/sampling.py @@ -1,6 +1,6 @@ import random from functools import partial -from typing import Any, Dict +from typing import Any, Dict, Set import numpy as np @@ -9,6 +9,7 @@ PredictionType, SamplingMethod, ) +from inference.core.constants import CLASSIFICATION_TASK def initialize_random_sampling(strategy_config: Dict[str, Any]) -> SamplingMethod: @@ -28,6 +29,178 @@ def sample_randomly( prediction_type: PredictionType, traffic_percentage: float, ) -> bool: - if random.random() >= traffic_percentage: + return random.random() < traffic_percentage + + +def initialize_close_to_threshold_sampling( + strategy_config: Dict[str, Any] +) -> SamplingMethod: + sample_function = partial( + sample_close_to_threshold, + selected_class_names=strategy_config["selected_class_names"], + threshold=strategy_config["threshold"], + epsilon=strategy_config["epsilon"], + only_top_classes=strategy_config["only_top_classes"], + minimum_objects_close_to_threshold=strategy_config[ + "minimum_objects_close_to_threshold" + ], + probability=strategy_config["probability"], + ) + return SamplingMethod( + name=strategy_config["name"], + sample=sample_function, + ) + + +def sample_close_to_threshold( + image: np.ndarray, + prediction: Prediction, + prediction_type: PredictionType, + selected_class_names: Set[str], + threshold: float, + epsilon: float, + only_top_classes: bool, + minimum_objects_close_to_threshold: int, + probability: float, +) -> bool: + if is_prediction_a_stub(prediction=prediction): + return False + is_close_to_threshold = prediction_is_close_to_threshold( + prediction=prediction, + prediction_type=prediction_type, + selected_class_names=selected_class_names, + threshold=threshold, + epsilon=epsilon, + only_top_classes=only_top_classes, + minimum_objects_close_to_threshold=minimum_objects_close_to_threshold, + ) + if not is_close_to_threshold: return False - return True + return random.random() < probability + + +def is_prediction_a_stub(prediction: Prediction) -> bool: + return prediction.get("is_stub", False) + + +def prediction_is_close_to_threshold( + prediction: Prediction, + prediction_type: PredictionType, + selected_class_names: Set[str], + threshold: float, + epsilon: float, + only_top_classes: bool, + minimum_objects_close_to_threshold: int, +) -> bool: + if CLASSIFICATION_TASK not in prediction_type: + return detection_prediction_is_close_to_threshold( + prediction=prediction, + selected_class_names=selected_class_names, + threshold=threshold, + epsilon=epsilon, + minimum_objects_close_to_threshold=minimum_objects_close_to_threshold, + ) + checker = multi_label_classification_prediction_is_close_to_threshold + if "top" in prediction: + checker = multi_class_classification_prediction_is_close_to_threshold + return checker( + prediction=prediction, + selected_class_names=selected_class_names, + threshold=threshold, + epsilon=epsilon, + only_top_classes=only_top_classes, + ) + + +def multi_class_classification_prediction_is_close_to_threshold( + prediction: Prediction, + selected_class_names: Set[str], + threshold: float, + epsilon: float, + only_top_classes: bool, +) -> bool: + if only_top_classes: + return ( + multi_class_classification_prediction_is_close_to_threshold_for_top_class( + prediction=prediction, + selected_class_names=selected_class_names, + threshold=threshold, + epsilon=epsilon, + ) + ) + for prediction_details in prediction["predictions"]: + if prediction_details["class"] not in selected_class_names: + continue + if close_to_threshold( + value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon + ): + return True + return False + + +def multi_class_classification_prediction_is_close_to_threshold_for_top_class( + prediction: Prediction, + selected_class_names: Set[str], + threshold: float, + epsilon: float, +) -> bool: + if prediction["top"] not in selected_class_names: + return False + return abs(prediction["confidence"] - threshold) < epsilon + + +def multi_label_classification_prediction_is_close_to_threshold( + prediction: Prediction, + selected_class_names: Set[str], + threshold: float, + epsilon: float, + only_top_classes: bool, +) -> bool: + predicted_classes = set(prediction["predicted_classes"]) + for class_name, prediction_details in prediction["predictions"].items(): + if only_top_classes and class_name not in predicted_classes: + continue + if class_name not in selected_class_names: + continue + if close_to_threshold( + value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon + ): + return True + return False + + +def detection_prediction_is_close_to_threshold( + prediction: Prediction, + selected_class_names: Set[str], + threshold: float, + epsilon: float, + minimum_objects_close_to_threshold: int, +) -> bool: + detections_close_to_threshold = count_detections_close_to_threshold( + prediction=prediction, + selected_class_names=selected_class_names, + threshold=threshold, + epsilon=epsilon, + ) + return detections_close_to_threshold >= minimum_objects_close_to_threshold + + +def count_detections_close_to_threshold( + prediction: Prediction, + selected_class_names: Set[str], + threshold: float, + epsilon: float, +) -> int: + counter = 0 + for prediction_details in prediction["predictions"]: + if prediction_details["class"] not in selected_class_names: + continue + if close_to_threshold( + value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon + ): + counter += 1 + return counter + + +def close_to_threshold(value: float, threshold: float, epsilon: float) -> bool: + return abs(value - threshold) < epsilon From 400d3c96c68716b23d0681d2b63fd4baa30af2b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Thu, 2 Nov 2023 17:11:58 +0100 Subject: [PATCH 02/15] Push sampling logic into separate package --- .../core/active_learning/configuration.py | 4 +-- .../core/active_learning/samplers/__init__.py | 0 .../close_to_threshold.py} | 20 ------------ .../core/active_learning/samplers/random.py | 31 +++++++++++++++++++ .../core/active_learning/samplers/__init__.py | 0 .../test_random.py} | 6 ++-- 6 files changed, 36 insertions(+), 25 deletions(-) create mode 100644 inference/core/active_learning/samplers/__init__.py rename inference/core/active_learning/{sampling.py => samplers/close_to_threshold.py} (91%) create mode 100644 inference/core/active_learning/samplers/random.py create mode 100644 tests/inference/unit_tests/core/active_learning/samplers/__init__.py rename tests/inference/unit_tests/core/active_learning/{test_sampling.py => samplers/test_random.py} (91%) diff --git a/inference/core/active_learning/configuration.py b/inference/core/active_learning/configuration.py index d42b030eb..18a799117 100644 --- a/inference/core/active_learning/configuration.py +++ b/inference/core/active_learning/configuration.py @@ -6,10 +6,10 @@ RoboflowProjectMetadata, SamplingMethod, ) -from inference.core.active_learning.sampling import ( +from inference.core.active_learning.samplers.close_to_threshold import ( initialize_close_to_threshold_sampling, - initialize_random_sampling, ) +from inference.core.active_learning.samplers.random import initialize_random_sampling from inference.core.env import ACTIVE_LEARNING_ENABLED from inference.core.roboflow_api import ( get_roboflow_active_learning_configuration, diff --git a/inference/core/active_learning/samplers/__init__.py b/inference/core/active_learning/samplers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/inference/core/active_learning/sampling.py b/inference/core/active_learning/samplers/close_to_threshold.py similarity index 91% rename from inference/core/active_learning/sampling.py rename to inference/core/active_learning/samplers/close_to_threshold.py index 5d316e1fb..23825f1e7 100644 --- a/inference/core/active_learning/sampling.py +++ b/inference/core/active_learning/samplers/close_to_threshold.py @@ -12,26 +12,6 @@ from inference.core.constants import CLASSIFICATION_TASK -def initialize_random_sampling(strategy_config: Dict[str, Any]) -> SamplingMethod: - sample_function = partial( - sample_randomly, - traffic_percentage=strategy_config["traffic_percentage"], - ) - return SamplingMethod( - name=strategy_config["name"], - sample=sample_function, - ) - - -def sample_randomly( - image: np.ndarray, - prediction: Prediction, - prediction_type: PredictionType, - traffic_percentage: float, -) -> bool: - return random.random() < traffic_percentage - - def initialize_close_to_threshold_sampling( strategy_config: Dict[str, Any] ) -> SamplingMethod: diff --git a/inference/core/active_learning/samplers/random.py b/inference/core/active_learning/samplers/random.py new file mode 100644 index 000000000..f906a2e40 --- /dev/null +++ b/inference/core/active_learning/samplers/random.py @@ -0,0 +1,31 @@ +import random +from functools import partial +from typing import Any, Dict + +import numpy as np + +from inference.core.active_learning.entities import ( + Prediction, + PredictionType, + SamplingMethod, +) + + +def initialize_random_sampling(strategy_config: Dict[str, Any]) -> SamplingMethod: + sample_function = partial( + sample_randomly, + traffic_percentage=strategy_config["traffic_percentage"], + ) + return SamplingMethod( + name=strategy_config["name"], + sample=sample_function, + ) + + +def sample_randomly( + image: np.ndarray, + prediction: Prediction, + prediction_type: PredictionType, + traffic_percentage: float, +) -> bool: + return random.random() < traffic_percentage diff --git a/tests/inference/unit_tests/core/active_learning/samplers/__init__.py b/tests/inference/unit_tests/core/active_learning/samplers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/inference/unit_tests/core/active_learning/test_sampling.py b/tests/inference/unit_tests/core/active_learning/samplers/test_random.py similarity index 91% rename from tests/inference/unit_tests/core/active_learning/test_sampling.py rename to tests/inference/unit_tests/core/active_learning/samplers/test_random.py index 6ca7c1662..a2417a95f 100644 --- a/tests/inference/unit_tests/core/active_learning/test_sampling.py +++ b/tests/inference/unit_tests/core/active_learning/samplers/test_random.py @@ -4,11 +4,11 @@ import numpy as np import pytest -from inference.core.active_learning.sampling import initialize_random_sampling -from inference.core.active_learning import sampling +from inference.core.active_learning.samplers.random import initialize_random_sampling +from inference.core.active_learning.samplers import random -@mock.patch.object(sampling.random, "random") +@mock.patch.object(random.random, "random") def test_initialize_random_sampling_when_config_is_valid( random_mock: MagicMock, ) -> None: From 2b363a4494c42ff99008490cd5d4d4ff052b4260 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Thu, 2 Nov 2023 17:34:05 +0100 Subject: [PATCH 03/15] Add sampling based on class names --- .../samplers/contains_classes.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 inference/core/active_learning/samplers/contains_classes.py diff --git a/inference/core/active_learning/samplers/contains_classes.py b/inference/core/active_learning/samplers/contains_classes.py new file mode 100644 index 000000000..1d79e0dfc --- /dev/null +++ b/inference/core/active_learning/samplers/contains_classes.py @@ -0,0 +1,49 @@ +from functools import partial +from typing import Any, Dict, Set + +import numpy as np + +from inference.core.active_learning.entities import ( + Prediction, + PredictionType, + SamplingMethod, +) +from inference.core.active_learning.samplers.close_to_threshold import ( + sample_close_to_threshold, +) + + +def initialize_classes_based_sampling( + strategy_config: Dict[str, Any] +) -> SamplingMethod: + sample_function = partial( + sample_close_to_threshold, + selected_class_names=strategy_config["selected_class_names"], + minimum_objects=strategy_config["minimum_objects"], + probability=strategy_config["probability"], + ) + return SamplingMethod( + name=strategy_config["name"], + sample=sample_function, + ) + + +def sample_based_on_classes( + image: np.ndarray, + prediction: Prediction, + prediction_type: PredictionType, + selected_class_names: Set[str], + minimum_objects: int, + probability: float, +) -> bool: + return sample_close_to_threshold( + image=image, + prediction=prediction, + prediction_type=prediction_type, + selected_class_names=selected_class_names, + threshold=0.5, + epsilon=1.0, + only_top_classes=True, + minimum_objects_close_to_threshold=minimum_objects, + probability=probability, + ) From 9d988d418151be9a135da98d0afe12438cc4bf39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Thu, 2 Nov 2023 18:06:24 +0100 Subject: [PATCH 04/15] Add sampling based on number of detections --- .../samplers/close_to_threshold.py | 40 ++++++---- .../samplers/contains_classes.py | 2 +- .../samplers/number_of_detections.py | 77 +++++++++++++++++++ inference/core/exceptions.py | 4 + 4 files changed, 107 insertions(+), 16 deletions(-) create mode 100644 inference/core/active_learning/samplers/number_of_detections.py diff --git a/inference/core/active_learning/samplers/close_to_threshold.py b/inference/core/active_learning/samplers/close_to_threshold.py index 23825f1e7..290ef43ca 100644 --- a/inference/core/active_learning/samplers/close_to_threshold.py +++ b/inference/core/active_learning/samplers/close_to_threshold.py @@ -1,6 +1,6 @@ import random from functools import partial -from typing import Any, Dict, Set +from typing import Any, Dict, Optional, Set import numpy as np @@ -15,15 +15,19 @@ def initialize_close_to_threshold_sampling( strategy_config: Dict[str, Any] ) -> SamplingMethod: + selected_class_names = strategy_config.get("selected_class_names") + if selected_class_names is not None: + selected_class_names = set(selected_class_names) sample_function = partial( sample_close_to_threshold, - selected_class_names=strategy_config["selected_class_names"], + selected_class_names=selected_class_names, threshold=strategy_config["threshold"], epsilon=strategy_config["epsilon"], - only_top_classes=strategy_config["only_top_classes"], - minimum_objects_close_to_threshold=strategy_config[ - "minimum_objects_close_to_threshold" - ], + only_top_classes=strategy_config.get("only_top_classes", True), + minimum_objects_close_to_threshold=strategy_config.get( + "minimum_objects_close_to_threshold", + 1, + ), probability=strategy_config["probability"], ) return SamplingMethod( @@ -36,7 +40,7 @@ def sample_close_to_threshold( image: np.ndarray, prediction: Prediction, prediction_type: PredictionType, - selected_class_names: Set[str], + selected_class_names: Optional[Set[str]], threshold: float, epsilon: float, only_top_classes: bool, @@ -66,7 +70,7 @@ def is_prediction_a_stub(prediction: Prediction) -> bool: def prediction_is_close_to_threshold( prediction: Prediction, prediction_type: PredictionType, - selected_class_names: Set[str], + selected_class_names: Optional[Set[str]], threshold: float, epsilon: float, only_top_classes: bool, @@ -120,18 +124,21 @@ def multi_class_classification_prediction_is_close_to_threshold( def multi_class_classification_prediction_is_close_to_threshold_for_top_class( prediction: Prediction, - selected_class_names: Set[str], + selected_class_names: Optional[Set[str]], threshold: float, epsilon: float, ) -> bool: - if prediction["top"] not in selected_class_names: + if ( + selected_class_names is not None + and prediction["top"] not in selected_class_names + ): return False return abs(prediction["confidence"] - threshold) < epsilon def multi_label_classification_prediction_is_close_to_threshold( prediction: Prediction, - selected_class_names: Set[str], + selected_class_names: Optional[Set[str]], threshold: float, epsilon: float, only_top_classes: bool, @@ -140,7 +147,7 @@ def multi_label_classification_prediction_is_close_to_threshold( for class_name, prediction_details in prediction["predictions"].items(): if only_top_classes and class_name not in predicted_classes: continue - if class_name not in selected_class_names: + if selected_class_names is not None and class_name not in selected_class_names: continue if close_to_threshold( value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon @@ -151,7 +158,7 @@ def multi_label_classification_prediction_is_close_to_threshold( def detection_prediction_is_close_to_threshold( prediction: Prediction, - selected_class_names: Set[str], + selected_class_names: Optional[Set[str]], threshold: float, epsilon: float, minimum_objects_close_to_threshold: int, @@ -167,13 +174,16 @@ def detection_prediction_is_close_to_threshold( def count_detections_close_to_threshold( prediction: Prediction, - selected_class_names: Set[str], + selected_class_names: Optional[Set[str]], threshold: float, epsilon: float, ) -> int: counter = 0 for prediction_details in prediction["predictions"]: - if prediction_details["class"] not in selected_class_names: + if ( + selected_class_names is not None + and prediction_details["class"] not in selected_class_names + ): continue if close_to_threshold( value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon diff --git a/inference/core/active_learning/samplers/contains_classes.py b/inference/core/active_learning/samplers/contains_classes.py index 1d79e0dfc..50929ca4e 100644 --- a/inference/core/active_learning/samplers/contains_classes.py +++ b/inference/core/active_learning/samplers/contains_classes.py @@ -19,7 +19,7 @@ def initialize_classes_based_sampling( sample_function = partial( sample_close_to_threshold, selected_class_names=strategy_config["selected_class_names"], - minimum_objects=strategy_config["minimum_objects"], + minimum_objects=strategy_config.get("minimum_objects", 1), probability=strategy_config["probability"], ) return SamplingMethod( diff --git a/inference/core/active_learning/samplers/number_of_detections.py b/inference/core/active_learning/samplers/number_of_detections.py new file mode 100644 index 000000000..8fdad1ea1 --- /dev/null +++ b/inference/core/active_learning/samplers/number_of_detections.py @@ -0,0 +1,77 @@ +import random +from functools import partial +from typing import Any, Dict, Optional, Set + +import numpy as np + +from inference.core.active_learning.entities import ( + Prediction, + PredictionType, + SamplingMethod, +) +from inference.core.active_learning.samplers.close_to_threshold import ( + count_detections_close_to_threshold, + is_prediction_a_stub, + sample_close_to_threshold, +) +from inference.core.constants import CLASSIFICATION_TASK +from inference.core.exceptions import ActiveLearningConfigurationError + + +def initialize_detections_number_based_sampling( + strategy_config: Dict[str, Any] +) -> SamplingMethod: + less_than_objects = strategy_config.get("less_than_objects") + more_than_objects = strategy_config.get("more_than_objects") + both_nones = less_than_objects is None and more_than_objects is None + both_has_values = less_than_objects is not None and more_than_objects is not None + if both_nones or both_has_values: + raise ActiveLearningConfigurationError( + f"Only one from `less_than_objects` and `more_than_objects` values must be set." + ) + selected_class_names = strategy_config.get("selected_class_names") + if selected_class_names is not None: + selected_class_names = set(selected_class_names) + sample_function = partial( + sample_close_to_threshold, + less_than_objects=less_than_objects, + more_than_objects=more_than_objects, + selected_class_names=selected_class_names, + probability=strategy_config["probability"], + ) + return SamplingMethod( + name=strategy_config["name"], + sample=sample_function, + ) + + +def sample_based_on_detections_number( + image: np.ndarray, + prediction: Prediction, + prediction_type: PredictionType, + less_than_objects: Optional[int], + more_than_objects: Optional[int], + selected_class_names: Optional[Set[str]], + probability: float, +) -> bool: + if CLASSIFICATION_TASK in prediction_type: + return False + if is_prediction_a_stub(prediction=prediction): + return False + detections_close_to_threshold = count_detections_close_to_threshold( + prediction=prediction, + selected_class_names=selected_class_names, + threshold=0.5, + epsilon=1.0, + ) + if ( + less_than_objects is not None + and detections_close_to_threshold >= less_than_objects + ): + return False + if ( + more_than_objects is not None + and detections_close_to_threshold <= more_than_objects + ): + return False + return random.random() < probability diff --git a/inference/core/exceptions.py b/inference/core/exceptions.py index 0be636c2f..41f509e28 100644 --- a/inference/core/exceptions.py +++ b/inference/core/exceptions.py @@ -172,3 +172,7 @@ class PredictionFormatNotSupported(ActiveLearningError): class ActiveLearningConfigurationDecodingError(ActiveLearningError): pass + + +class ActiveLearningConfigurationError(ActiveLearningError): + pass From 8ad7966044a78bf21b241dc5f33de9c3de9f3d3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Thu, 2 Nov 2023 18:14:49 +0100 Subject: [PATCH 05/15] Add automatic initialisation for new sampling strategies --- inference/core/active_learning/configuration.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/inference/core/active_learning/configuration.py b/inference/core/active_learning/configuration.py index 18a799117..dffbec111 100644 --- a/inference/core/active_learning/configuration.py +++ b/inference/core/active_learning/configuration.py @@ -9,6 +9,12 @@ from inference.core.active_learning.samplers.close_to_threshold import ( initialize_close_to_threshold_sampling, ) +from inference.core.active_learning.samplers.contains_classes import ( + initialize_classes_based_sampling, +) +from inference.core.active_learning.samplers.number_of_detections import ( + initialize_detections_number_based_sampling, +) from inference.core.active_learning.samplers.random import initialize_random_sampling from inference.core.env import ACTIVE_LEARNING_ENABLED from inference.core.roboflow_api import ( @@ -21,6 +27,8 @@ TYPE2SAMPLING_INITIALIZERS = { "random_sampling": initialize_random_sampling, "close_to_threshold_sampling": initialize_close_to_threshold_sampling, + "classes_based_sampling": initialize_classes_based_sampling, + "detections_number_based_sampling": initialize_detections_number_based_sampling, } From e3a3f2d536f1fc87055dc037516c0609d66a25f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Thu, 2 Nov 2023 18:48:59 +0100 Subject: [PATCH 06/15] Add basic test for close-to-threshold sampling --- .../samplers/test_close_to_threshold.py | 377 ++++++++++++++++++ 1 file changed, 377 insertions(+) create mode 100644 tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py diff --git a/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py b/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py new file mode 100644 index 000000000..f5ae1029b --- /dev/null +++ b/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py @@ -0,0 +1,377 @@ +import pytest + +from inference.core.active_learning.samplers.close_to_threshold import ( + close_to_threshold, + count_detections_close_to_threshold, + detection_prediction_is_close_to_threshold, + multi_label_classification_prediction_is_close_to_threshold, + multi_class_classification_prediction_is_close_to_threshold_for_top_class, +) + +OBJECT_DETECTION_PREDICTION = { + "predictions": [ + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.9, + "class": "a", + "class_id": 1, + }, + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.7, + "class": "b", + "class_id": 1, + }, + ] +} +INSTANCE_SEGMENTATION_PREDICTION = { + "predictions": [ + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.9, + "class": "a", + "class_id": 1, + "points": [{"x": 207.0453125, "y": 106.559375}], + }, + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.7, + "class": "b", + "class_id": 1, + "points": [{"x": 207.0453125, "y": 106.559375}], + }, + ] +} +KEYPOINTS_PREDICTION = { + "predictions": [ + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.9, + "class": "a", + "class_id": 1, + "keypoints": [ + { + "x": 207.0453125, + "y": 106.559375, + "confidence": 0.92, + "class_id": 1, + "class": "eye", + } + ], + }, + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.7, + "class": "b", + "class_id": 1, + "keypoints": [ + { + "x": 207.0453125, + "y": 106.559375, + "confidence": 0.92, + "class_id": 1, + "class": "eye", + } + ], + }, + ] +} + +MULTI_LABEL_CLASSIFICATION_PREDICTION = { + "image": {"width": 416, "height": 416}, + "predictions": { + "cat": {"confidence": 0.97}, + "dog": {"confidence": 0.03}, + }, + "predicted_classes": ["cat"], +} + +MULTI_CLASS_CLASSIFICATION_PREDICTION = { + "image": {"width": 3487, "height": 2039}, + "predictions": [ + {"class": "Ambulance", "class_id": 0, "confidence": 0.6}, + {"class": "Limousine", "class_id": 16, "confidence": 0.3}, + {"class": "Helicopter", "class_id": 15, "confidence": 0.1}, + ], + "top": "Ambulance", + "confidence": 0.6, +} + + +@pytest.mark.parametrize( + "value, threshold, epsilon", + [ + (0.45, 0.5, 0.06), + (0.55, 0.5, 0.06), + ], +) +def test_close_to_threshold_when_value_is_close( + value: float, threshold: float, epsilon: float +) -> None: + # when + result = close_to_threshold(value=value, threshold=threshold, epsilon=epsilon) + + # then + assert result is True + + +@pytest.mark.parametrize( + "value, threshold, epsilon", + [ + (0.44, 0.5, 0.05), + (0.56, 0.5, 0.05), + ], +) +def test_close_to_threshold_when_value_is_not_close( + value: float, threshold: float, epsilon: float +) -> None: + # when + result = close_to_threshold(value=value, threshold=threshold, epsilon=epsilon) + + # then + assert result is False + + +def test_count_detections_close_to_threshold_when_no_detections_in_prediction() -> None: + # given + prediction = {"predictions": []} + + # when + result = count_detections_close_to_threshold( + prediction=prediction, + selected_class_names=None, + threshold=0.5, + epsilon=1.0, + ) + + # then + assert result == 0 + + +@pytest.mark.parametrize( + "prediction", + [ + OBJECT_DETECTION_PREDICTION, + INSTANCE_SEGMENTATION_PREDICTION, + KEYPOINTS_PREDICTION, + ], +) +def test_count_detections_close_to_threshold_when_no_selected_class_names_pointed( + prediction: dict, +) -> None: + # when + result = count_detections_close_to_threshold( + prediction=prediction, + selected_class_names=None, + threshold=0.5, + epsilon=1.0, + ) + + # then + assert result == 2 + + +@pytest.mark.parametrize( + "prediction", + [ + OBJECT_DETECTION_PREDICTION, + INSTANCE_SEGMENTATION_PREDICTION, + KEYPOINTS_PREDICTION, + ], +) +def test_count_detections_close_to_threshold_when_no_selected_class_names_filter_out_predictions( + prediction: dict, +) -> None: + # when + result = count_detections_close_to_threshold( + prediction=prediction, + selected_class_names={"a", "c"}, + threshold=0.5, + epsilon=1.0, + ) + + # then + assert result == 1 + + +@pytest.mark.parametrize( + "prediction", + [ + OBJECT_DETECTION_PREDICTION, + INSTANCE_SEGMENTATION_PREDICTION, + KEYPOINTS_PREDICTION, + ], +) +def test_count_detections_close_to_threshold_when_no_selected_threshold_filter_out_predictions( + prediction: dict, +) -> None: + # when + result = count_detections_close_to_threshold( + prediction=prediction, + selected_class_names=None, + threshold=0.6, + epsilon=0.15, + ) + + # then + assert result == 1 + + +@pytest.mark.parametrize( + "prediction", + [ + OBJECT_DETECTION_PREDICTION, + INSTANCE_SEGMENTATION_PREDICTION, + KEYPOINTS_PREDICTION, + ], +) +def test_detection_prediction_is_close_to_threshold_when_minimum_objects_criterion_met( + prediction: dict, +) -> None: + # when + result = detection_prediction_is_close_to_threshold( + prediction=prediction, + selected_class_names=None, + threshold=0.6, + epsilon=0.15, + minimum_objects_close_to_threshold=1, + ) + + # then + assert result is True + + +@pytest.mark.parametrize( + "prediction", + [ + OBJECT_DETECTION_PREDICTION, + INSTANCE_SEGMENTATION_PREDICTION, + KEYPOINTS_PREDICTION, + ], +) +def test_detection_prediction_is_close_to_threshold_when_minimum_objects_criterion_not_met( + prediction: dict, +) -> None: + # when + result = detection_prediction_is_close_to_threshold( + prediction=prediction, + selected_class_names=None, + threshold=0.6, + epsilon=0.15, + minimum_objects_close_to_threshold=2, + ) + + # then + assert result is False + + +def test_multi_label_classification_prediction_is_close_to_threshold_when_top_class_meet_criteria() -> ( + None +): + # when + result = multi_label_classification_prediction_is_close_to_threshold( + prediction=MULTI_LABEL_CLASSIFICATION_PREDICTION, + selected_class_names=None, + threshold=0.95, + epsilon=0.05, + only_top_classes=True, + ) + + # then + assert result is True + + +def test_multi_label_classification_prediction_is_close_to_threshold_when_non_top_class_meet_threshold_but_filtered_out_by_top_classes() -> ( + None +): + # when + result = multi_label_classification_prediction_is_close_to_threshold( + prediction=MULTI_LABEL_CLASSIFICATION_PREDICTION, + selected_class_names=None, + threshold=0.05, + epsilon=0.05, + only_top_classes=True, + ) + + # then + assert result is False + + +def test_multi_label_classification_prediction_is_close_to_threshold_when_non_top_class_meet_threshold_but_filtered_out_by_class_names() -> ( + None +): + # when + result = multi_label_classification_prediction_is_close_to_threshold( + prediction=MULTI_LABEL_CLASSIFICATION_PREDICTION, + selected_class_names={"cat", "tiger"}, + threshold=0.05, + epsilon=0.05, + only_top_classes=False, + ) + + # then + assert result is False + + +def test_multi_label_classification_prediction_is_close_to_threshold_when_non_top_class_meet_criteria() -> ( + None +): + # when + result = multi_label_classification_prediction_is_close_to_threshold( + prediction=MULTI_LABEL_CLASSIFICATION_PREDICTION, + selected_class_names=None, + threshold=0.05, + epsilon=0.05, + only_top_classes=False, + ) + + # then + assert result is True + + +def test_multi_class_classification_prediction_is_close_to_threshold_for_top_class_when_classes_are_not_selected_and_threshold_met() -> ( + None +): + # when + result = multi_class_classification_prediction_is_close_to_threshold_for_top_class( + prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION, + selected_class_names=None, + threshold=0.6, + epsilon=0.1, + ) + + # then + assert result is True + + +def test_multi_class_classification_prediction_is_close_to_threshold_for_top_class_when_classes_are_not_selected_and_threshold_not_met() -> ( + None +): + # when + result = multi_class_classification_prediction_is_close_to_threshold_for_top_class( + prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION, + selected_class_names=None, + threshold=0.8, + epsilon=0.1, + ) + + # then + assert result is False From e90734ea8d38202be28e64458bc2df01b004cbb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Fri, 3 Nov 2023 12:34:07 +0100 Subject: [PATCH 07/15] Add tests for close_to_threshold sampling module --- .../samplers/close_to_threshold.py | 57 ++- .../samplers/number_of_detections.py | 16 +- inference/core/constants.py | 1 + .../samplers/test_close_to_threshold.py | 472 +++++++++++++++++- 4 files changed, 518 insertions(+), 28 deletions(-) diff --git a/inference/core/active_learning/samplers/close_to_threshold.py b/inference/core/active_learning/samplers/close_to_threshold.py index 290ef43ca..1328aa82c 100644 --- a/inference/core/active_learning/samplers/close_to_threshold.py +++ b/inference/core/active_learning/samplers/close_to_threshold.py @@ -9,7 +9,19 @@ PredictionType, SamplingMethod, ) -from inference.core.constants import CLASSIFICATION_TASK +from inference.core.constants import ( + CLASSIFICATION_TASK, + INSTANCE_SEGMENTATION_TASK, + KEYPOINTS_DETECTION_TASK, + OBJECT_DETECTION_TASK, +) + +ELIGIBLE_PREDICTION_TYPES = { + CLASSIFICATION_TASK, + INSTANCE_SEGMENTATION_TASK, + KEYPOINTS_DETECTION_TASK, + OBJECT_DETECTION_TASK, +} def initialize_close_to_threshold_sampling( @@ -23,7 +35,7 @@ def initialize_close_to_threshold_sampling( selected_class_names=selected_class_names, threshold=strategy_config["threshold"], epsilon=strategy_config["epsilon"], - only_top_classes=strategy_config.get("only_top_classes", True), + only_top_classes=strategy_config.get("only_top_classes", False), minimum_objects_close_to_threshold=strategy_config.get( "minimum_objects_close_to_threshold", 1, @@ -49,7 +61,9 @@ def sample_close_to_threshold( ) -> bool: if is_prediction_a_stub(prediction=prediction): return False - is_close_to_threshold = prediction_is_close_to_threshold( + if prediction_type not in ELIGIBLE_PREDICTION_TYPES: + return False + close_to_threshold = prediction_is_close_to_threshold( prediction=prediction, prediction_type=prediction_type, selected_class_names=selected_class_names, @@ -58,7 +72,7 @@ def sample_close_to_threshold( only_top_classes=only_top_classes, minimum_objects_close_to_threshold=minimum_objects_close_to_threshold, ) - if not is_close_to_threshold: + if not close_to_threshold: return False return random.random() < probability @@ -77,7 +91,7 @@ def prediction_is_close_to_threshold( minimum_objects_close_to_threshold: int, ) -> bool: if CLASSIFICATION_TASK not in prediction_type: - return detection_prediction_is_close_to_threshold( + return detections_are_close_to_threshold( prediction=prediction, selected_class_names=selected_class_names, threshold=threshold, @@ -98,7 +112,7 @@ def prediction_is_close_to_threshold( def multi_class_classification_prediction_is_close_to_threshold( prediction: Prediction, - selected_class_names: Set[str], + selected_class_names: Optional[Set[str]], threshold: float, epsilon: float, only_top_classes: bool, @@ -113,9 +127,12 @@ def multi_class_classification_prediction_is_close_to_threshold( ) ) for prediction_details in prediction["predictions"]: - if prediction_details["class"] not in selected_class_names: + if class_to_be_excluded( + class_name=prediction_details["class"], + selected_class_names=selected_class_names, + ): continue - if close_to_threshold( + if is_close_to_threshold( value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon ): return True @@ -147,16 +164,18 @@ def multi_label_classification_prediction_is_close_to_threshold( for class_name, prediction_details in prediction["predictions"].items(): if only_top_classes and class_name not in predicted_classes: continue - if selected_class_names is not None and class_name not in selected_class_names: + if class_to_be_excluded( + class_name=class_name, selected_class_names=selected_class_names + ): continue - if close_to_threshold( + if is_close_to_threshold( value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon ): return True return False -def detection_prediction_is_close_to_threshold( +def detections_are_close_to_threshold( prediction: Prediction, selected_class_names: Optional[Set[str]], threshold: float, @@ -180,17 +199,23 @@ def count_detections_close_to_threshold( ) -> int: counter = 0 for prediction_details in prediction["predictions"]: - if ( - selected_class_names is not None - and prediction_details["class"] not in selected_class_names + if class_to_be_excluded( + class_name=prediction_details["class"], + selected_class_names=selected_class_names, ): continue - if close_to_threshold( + if is_close_to_threshold( value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon ): counter += 1 return counter -def close_to_threshold(value: float, threshold: float, epsilon: float) -> bool: +def class_to_be_excluded( + class_name: str, selected_class_names: Optional[Set[str]] +) -> bool: + return selected_class_names is not None and class_name not in selected_class_names + + +def is_close_to_threshold(value: float, threshold: float, epsilon: float) -> bool: return abs(value - threshold) < epsilon diff --git a/inference/core/active_learning/samplers/number_of_detections.py b/inference/core/active_learning/samplers/number_of_detections.py index 8fdad1ea1..f8c50e3e2 100644 --- a/inference/core/active_learning/samplers/number_of_detections.py +++ b/inference/core/active_learning/samplers/number_of_detections.py @@ -14,9 +14,19 @@ is_prediction_a_stub, sample_close_to_threshold, ) -from inference.core.constants import CLASSIFICATION_TASK +from inference.core.constants import ( + INSTANCE_SEGMENTATION_TASK, + KEYPOINTS_DETECTION_TASK, + OBJECT_DETECTION_TASK, +) from inference.core.exceptions import ActiveLearningConfigurationError +ELIGIBLE_PREDICTION_TYPES = { + INSTANCE_SEGMENTATION_TASK, + KEYPOINTS_DETECTION_TASK, + OBJECT_DETECTION_TASK, +} + def initialize_detections_number_based_sampling( strategy_config: Dict[str, Any] @@ -54,10 +64,10 @@ def sample_based_on_detections_number( selected_class_names: Optional[Set[str]], probability: float, ) -> bool: - if CLASSIFICATION_TASK in prediction_type: - return False if is_prediction_a_stub(prediction=prediction): return False + if prediction_type not in ELIGIBLE_PREDICTION_TYPES: + return False detections_close_to_threshold = count_detections_close_to_threshold( prediction=prediction, selected_class_names=selected_class_names, diff --git a/inference/core/constants.py b/inference/core/constants.py index 383e36845..a17766247 100644 --- a/inference/core/constants.py +++ b/inference/core/constants.py @@ -1,3 +1,4 @@ CLASSIFICATION_TASK = "classification" OBJECT_DETECTION_TASK = "object-detection" INSTANCE_SEGMENTATION_TASK = "instance-segmentation" +KEYPOINTS_DETECTION_TASK = "keypoints-detection" diff --git a/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py b/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py index f5ae1029b..b7f26ff53 100644 --- a/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py +++ b/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py @@ -1,11 +1,29 @@ +from typing import Optional, Set +from unittest import mock +from unittest.mock import MagicMock + +import numpy as np import pytest +from inference.core.active_learning.entities import PredictionType from inference.core.active_learning.samplers.close_to_threshold import ( - close_to_threshold, + is_close_to_threshold, count_detections_close_to_threshold, - detection_prediction_is_close_to_threshold, + detections_are_close_to_threshold, multi_label_classification_prediction_is_close_to_threshold, - multi_class_classification_prediction_is_close_to_threshold_for_top_class, + multi_class_classification_prediction_is_close_to_threshold, + class_to_be_excluded, + prediction_is_close_to_threshold, + is_prediction_a_stub, + sample_close_to_threshold, + initialize_close_to_threshold_sampling, +) +from inference.core.active_learning.samplers import close_to_threshold +from inference.core.constants import ( + CLASSIFICATION_TASK, + INSTANCE_SEGMENTATION_TASK, + KEYPOINTS_DETECTION_TASK, + OBJECT_DETECTION_TASK, ) OBJECT_DETECTION_PREDICTION = { @@ -127,7 +145,7 @@ def test_close_to_threshold_when_value_is_close( value: float, threshold: float, epsilon: float ) -> None: # when - result = close_to_threshold(value=value, threshold=threshold, epsilon=epsilon) + result = is_close_to_threshold(value=value, threshold=threshold, epsilon=epsilon) # then assert result is True @@ -144,12 +162,40 @@ def test_close_to_threshold_when_value_is_not_close( value: float, threshold: float, epsilon: float ) -> None: # when - result = close_to_threshold(value=value, threshold=threshold, epsilon=epsilon) + result = is_close_to_threshold(value=value, threshold=threshold, epsilon=epsilon) + + # then + assert result is False + + +def test_class_to_be_excluded_when_classes_not_selected() -> None: + # when + result = class_to_be_excluded(class_name="some", selected_class_names=None) # then assert result is False +def test_class_to_be_excluded_when_classes_selected_and_specific_class_matches() -> ( + None +): + # when + result = class_to_be_excluded(class_name="a", selected_class_names={"a", "b", "c"}) + + # then + assert result is False + + +def test_class_to_be_excluded_when_classes_selected_and_specific_class_does_not_matche() -> ( + None +): + # when + result = class_to_be_excluded(class_name="d", selected_class_names={"a", "b", "c"}) + + # then + assert result is True + + def test_count_detections_close_to_threshold_when_no_detections_in_prediction() -> None: # given prediction = {"predictions": []} @@ -247,7 +293,7 @@ def test_detection_prediction_is_close_to_threshold_when_minimum_objects_criteri prediction: dict, ) -> None: # when - result = detection_prediction_is_close_to_threshold( + result = detections_are_close_to_threshold( prediction=prediction, selected_class_names=None, threshold=0.6, @@ -271,7 +317,7 @@ def test_detection_prediction_is_close_to_threshold_when_minimum_objects_criteri prediction: dict, ) -> None: # when - result = detection_prediction_is_close_to_threshold( + result = detections_are_close_to_threshold( prediction=prediction, selected_class_names=None, threshold=0.6, @@ -351,11 +397,12 @@ def test_multi_class_classification_prediction_is_close_to_threshold_for_top_cla None ): # when - result = multi_class_classification_prediction_is_close_to_threshold_for_top_class( + result = multi_class_classification_prediction_is_close_to_threshold( prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION, selected_class_names=None, threshold=0.6, epsilon=0.1, + only_top_classes=True, ) # then @@ -366,12 +413,419 @@ def test_multi_class_classification_prediction_is_close_to_threshold_for_top_cla None ): # when - result = multi_class_classification_prediction_is_close_to_threshold_for_top_class( + result = multi_class_classification_prediction_is_close_to_threshold( prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION, selected_class_names=None, threshold=0.8, epsilon=0.1, + only_top_classes=True, + ) + + # then + assert result is False + + +def test_multi_class_classification_prediction_is_close_to_threshold_for_top_class_when_classes_are_selected_and_top_class_does_not_match() -> ( + None +): + # when + result = multi_class_classification_prediction_is_close_to_threshold( + prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION, + selected_class_names={"Limousine", "Helicopter"}, + threshold=0.6, + epsilon=0.1, + only_top_classes=True, + ) + + # then + assert result is False + + +def test_multi_class_classification_prediction_is_close_to_threshold_for_top_class_when_classes_are_selected_and_top_class_matches() -> ( + None +): + # when + result = multi_class_classification_prediction_is_close_to_threshold( + prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION, + selected_class_names={"Ambulance"}, + threshold=0.6, + epsilon=0.1, + only_top_classes=True, + ) + + # then + assert result is True + + +def test_multi_class_classification_prediction_is_close_to_threshold_not_only_for_top_class_when_classes_not_selected() -> ( + None +): + # when + result = multi_class_classification_prediction_is_close_to_threshold( + prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION, + selected_class_names=None, + threshold=0.3, + epsilon=0.1, + only_top_classes=False, + ) + + # then + assert result is True + + +def test_multi_class_classification_prediction_is_close_to_threshold_not_only_for_top_class_when_classes_not_selected_and_no_match_expected() -> ( + None +): + # when + result = multi_class_classification_prediction_is_close_to_threshold( + prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION, + selected_class_names=None, + threshold=0.5, + epsilon=0.05, + only_top_classes=False, + ) + + # then + assert result is False + + +def test_multi_class_classification_prediction_is_close_to_threshold_not_only_for_top_class_when_classes_are_selected() -> ( + None +): + # when + result = multi_class_classification_prediction_is_close_to_threshold( + prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION, + selected_class_names={"Ambulance", "Helicopter"}, + threshold=0.3, + epsilon=0.1, + only_top_classes=False, ) # then assert result is False + + +@pytest.mark.parametrize( + "prediction, prediction_type, selected_class_names, threshold, epsilon, only_top_classes, " + "minimum_objects_close_to_threshold, expected_result", + [ + ( + OBJECT_DETECTION_PREDICTION, + OBJECT_DETECTION_TASK, + None, + 0.9, + 0.05, + False, + 1, + True, + ), + ( + INSTANCE_SEGMENTATION_PREDICTION, + INSTANCE_SEGMENTATION_TASK, + None, + 0.9, + 0.05, + False, + 2, + False, + ), + ( + KEYPOINTS_PREDICTION, + KEYPOINTS_DETECTION_TASK, + None, + 0.8, + 0.05, + False, + 1, + False, + ), + ( + MULTI_CLASS_CLASSIFICATION_PREDICTION, + CLASSIFICATION_TASK, + None, + 0.6, + 0.1, + True, + 1, + True, + ), + ( + MULTI_LABEL_CLASSIFICATION_PREDICTION, + CLASSIFICATION_TASK, + None, + 0.05, + 0.05, + False, + 1, + True, + ), + ], +) +def test_prediction_is_close_to_threshold( + prediction: dict, + prediction_type: PredictionType, + selected_class_names: Optional[Set[str]], + threshold: float, + epsilon: float, + only_top_classes: bool, + minimum_objects_close_to_threshold: int, + expected_result: bool, +) -> None: + # when + result = prediction_is_close_to_threshold( + prediction=prediction, + prediction_type=prediction_type, + selected_class_names=selected_class_names, + threshold=threshold, + epsilon=epsilon, + only_top_classes=only_top_classes, + minimum_objects_close_to_threshold=minimum_objects_close_to_threshold, + ) + + # then + assert result is expected_result + + +@pytest.mark.parametrize( + "prediction", + [ + OBJECT_DETECTION_PREDICTION, + KEYPOINTS_PREDICTION, + INSTANCE_SEGMENTATION_PREDICTION, + MULTI_CLASS_CLASSIFICATION_PREDICTION, + MULTI_LABEL_CLASSIFICATION_PREDICTION, + ], +) +def test_is_prediction_a_stub_when_prediction_is_not_a_stub(prediction: dict) -> None: + # when + result = is_prediction_a_stub(prediction=prediction) + + # then + assert result is False + + +def test_is_prediction_a_stub_when_prediction_is_a_stub() -> None: + # given + prediction = {"is_stub": True} + + # when + result = is_prediction_a_stub(prediction=prediction) + + # then + assert result is True + + +def test_sample_close_to_threshold_when_prediction_is_sub() -> None: + # when + result = sample_close_to_threshold( + image=np.zeros((128, 128, 3), dtype=np.uint8), + prediction={"is_stub": True}, + prediction_type=CLASSIFICATION_TASK, + selected_class_names=None, + threshold=0.5, + epsilon=0.1, + only_top_classes=True, + minimum_objects_close_to_threshold=1, + probability=1.0, + ) + + # then + assert result is False + + +def test_sample_close_to_threshold_when_prediction_type_is_unknown() -> None: + # when + result = sample_close_to_threshold( + image=np.zeros((128, 128, 3), dtype=np.uint8), + prediction={"is_stub": True}, + prediction_type="unknown", + selected_class_names=None, + threshold=0.5, + epsilon=0.1, + only_top_classes=True, + minimum_objects_close_to_threshold=1, + probability=1.0, + ) + + # then + assert result is False + + +def test_sample_close_to_threshold_when_prediction_is_not_close_to_threshold() -> None: + # when + result = sample_close_to_threshold( + image=np.zeros((128, 128, 3), dtype=np.uint8), + prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION, + prediction_type=CLASSIFICATION_TASK, + selected_class_names={"Ambulance"}, + threshold=0.8, + epsilon=0.1, + only_top_classes=True, + minimum_objects_close_to_threshold=1, + probability=1.0, + ) + + # then + assert result is False + + +@mock.patch.object(close_to_threshold.random, "random") +def test_sample_close_to_threshold_when_prediction_is_close_to_threshold( + random_mock: MagicMock, +) -> None: + # given + random_mock.return_value = 0.49 + + # when + result = sample_close_to_threshold( + image=np.zeros((128, 128, 3), dtype=np.uint8), + prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION, + prediction_type=CLASSIFICATION_TASK, + selected_class_names={"Ambulance"}, + threshold=0.6, + epsilon=0.1, + only_top_classes=True, + minimum_objects_close_to_threshold=1, + probability=0.5, + ) + + # then + assert result is True + + +def test_initialize_close_to_threshold_sampling() -> None: + # given + strategy_config = { + "name": "ambulance_high_confidence", + "selected_class_names": ["Ambulance"], + "threshold": 0.75, + "epsilon": 0.25, + "probability": 1.0, + } + + # when + sampling_method = initialize_close_to_threshold_sampling( + strategy_config=strategy_config + ) + result = sampling_method.sample( + np.zeros((128, 128, 3), dtype=np.uint8), + MULTI_CLASS_CLASSIFICATION_PREDICTION, + CLASSIFICATION_TASK, + ) + + # then + assert result is True + assert sampling_method.name == "ambulance_high_confidence" + + +@mock.patch.object(close_to_threshold, "partial") +def test_initialize_close_to_threshold_sampling_when_classes_not_selected( + partial_mock: MagicMock, +) -> None: + # given + strategy_config = { + "name": "ambulance_high_confidence", + "threshold": 0.75, + "epsilon": 0.25, + "probability": 0.6, + } + + # when + _ = initialize_close_to_threshold_sampling(strategy_config=strategy_config) + + # then + partial_mock.assert_called_once_with( + sample_close_to_threshold, + selected_class_names=None, + threshold=0.75, + epsilon=0.25, + only_top_classes=False, + minimum_objects_close_to_threshold=1, + probability=0.6, + ) + + +@mock.patch.object(close_to_threshold, "partial") +def test_initialize_close_to_threshold_sampling_when_classes_selected( + partial_mock: MagicMock, +) -> None: + # given + strategy_config = { + "name": "ambulance_high_confidence", + "selected_class_names": ["Ambulance", "Helicopter"], + "threshold": 0.75, + "epsilon": 0.25, + "probability": 0.6, + } + + # when + _ = initialize_close_to_threshold_sampling(strategy_config=strategy_config) + + # then + partial_mock.assert_called_once_with( + sample_close_to_threshold, + selected_class_names={"Ambulance", "Helicopter"}, + threshold=0.75, + epsilon=0.25, + only_top_classes=False, + minimum_objects_close_to_threshold=1, + probability=0.6, + ) + + +@mock.patch.object(close_to_threshold, "partial") +def test_initialize_close_to_threshold_sampling_when_only_top_classes_mode_enabled( + partial_mock: MagicMock, +) -> None: + # given + strategy_config = { + "name": "ambulance_high_confidence", + "selected_class_names": ["Ambulance"], + "threshold": 0.75, + "epsilon": 0.25, + "probability": 0.6, + "only_top_classes": True, + } + + # when + _ = initialize_close_to_threshold_sampling(strategy_config=strategy_config) + + # then + partial_mock.assert_called_once_with( + sample_close_to_threshold, + selected_class_names={"Ambulance"}, + threshold=0.75, + epsilon=0.25, + only_top_classes=True, + minimum_objects_close_to_threshold=1, + probability=0.6, + ) + + +@mock.patch.object(close_to_threshold, "partial") +def test_initialize_close_to_threshold_sampling_when_objects_close_to_threshold_specified( + partial_mock: MagicMock, +) -> None: + # given + strategy_config = { + "name": "ambulance_high_confidence", + "selected_class_names": ["Ambulance"], + "threshold": 0.75, + "epsilon": 0.25, + "probability": 0.6, + "minimum_objects_close_to_threshold": 6, + } + + # when + _ = initialize_close_to_threshold_sampling(strategy_config=strategy_config) + + # then + partial_mock.assert_called_once_with( + sample_close_to_threshold, + selected_class_names={"Ambulance"}, + threshold=0.75, + epsilon=0.25, + only_top_classes=False, + minimum_objects_close_to_threshold=6, + probability=0.6, + ) From a16d3530bd5bccaebc3b5d1629de41cc643bfd5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Fri, 3 Nov 2023 12:37:02 +0100 Subject: [PATCH 08/15] Simplify tests --- .../samplers/test_close_to_threshold.py | 99 +++++-------------- 1 file changed, 25 insertions(+), 74 deletions(-) diff --git a/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py b/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py index b7f26ff53..506da6221 100644 --- a/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py +++ b/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py @@ -505,85 +505,36 @@ def test_multi_class_classification_prediction_is_close_to_threshold_not_only_fo assert result is False -@pytest.mark.parametrize( - "prediction, prediction_type, selected_class_names, threshold, epsilon, only_top_classes, " - "minimum_objects_close_to_threshold, expected_result", - [ - ( - OBJECT_DETECTION_PREDICTION, - OBJECT_DETECTION_TASK, - None, - 0.9, - 0.05, - False, - 1, - True, - ), - ( - INSTANCE_SEGMENTATION_PREDICTION, - INSTANCE_SEGMENTATION_TASK, - None, - 0.9, - 0.05, - False, - 2, - False, - ), - ( - KEYPOINTS_PREDICTION, - KEYPOINTS_DETECTION_TASK, - None, - 0.8, - 0.05, - False, - 1, - False, - ), - ( - MULTI_CLASS_CLASSIFICATION_PREDICTION, - CLASSIFICATION_TASK, - None, - 0.6, - 0.1, - True, - 1, - True, - ), - ( - MULTI_LABEL_CLASSIFICATION_PREDICTION, - CLASSIFICATION_TASK, - None, - 0.05, - 0.05, - False, - 1, - True, - ), - ], -) -def test_prediction_is_close_to_threshold( - prediction: dict, - prediction_type: PredictionType, - selected_class_names: Optional[Set[str]], - threshold: float, - epsilon: float, - only_top_classes: bool, - minimum_objects_close_to_threshold: int, - expected_result: bool, -) -> None: +def test_prediction_is_close_to_threshold_for_detection_prediction() -> None: # when result = prediction_is_close_to_threshold( - prediction=prediction, - prediction_type=prediction_type, - selected_class_names=selected_class_names, - threshold=threshold, - epsilon=epsilon, - only_top_classes=only_top_classes, - minimum_objects_close_to_threshold=minimum_objects_close_to_threshold, + prediction=OBJECT_DETECTION_PREDICTION, + prediction_type=OBJECT_DETECTION_TASK, + selected_class_names=None, + threshold=0.9, + epsilon=0.05, + only_top_classes=False, + minimum_objects_close_to_threshold=1, ) # then - assert result is expected_result + assert result is True + + +def test_prediction_is_close_to_threshold_for_classification_prediction() -> None: + # when + result = prediction_is_close_to_threshold( + prediction=MULTI_LABEL_CLASSIFICATION_PREDICTION, + prediction_type=CLASSIFICATION_TASK, + selected_class_names=None, + threshold=0.05, + epsilon=0.05, + only_top_classes=False, + minimum_objects_close_to_threshold=1, + ) + + # then + assert result is True @pytest.mark.parametrize( From f821db076c46c526bc59e9ed13a99f1d3ec6eb87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Fri, 3 Nov 2023 12:37:24 +0100 Subject: [PATCH 09/15] Simplify tests --- .../core/active_learning/samplers/test_close_to_threshold.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py b/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py index 506da6221..692332df0 100644 --- a/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py +++ b/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py @@ -1,11 +1,9 @@ -from typing import Optional, Set from unittest import mock from unittest.mock import MagicMock import numpy as np import pytest -from inference.core.active_learning.entities import PredictionType from inference.core.active_learning.samplers.close_to_threshold import ( is_close_to_threshold, count_detections_close_to_threshold, @@ -21,8 +19,6 @@ from inference.core.active_learning.samplers import close_to_threshold from inference.core.constants import ( CLASSIFICATION_TASK, - INSTANCE_SEGMENTATION_TASK, - KEYPOINTS_DETECTION_TASK, OBJECT_DETECTION_TASK, ) From dfdc2ac13160a52e0ad5f354444681d952e0a304 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Fri, 3 Nov 2023 13:01:25 +0100 Subject: [PATCH 10/15] Add tests for classes based sampling --- .../samplers/close_to_threshold.py | 44 ++-- .../samplers/contains_classes.py | 26 ++- .../samplers/number_of_detections.py | 49 +++-- .../core/active_learning/samplers/random.py | 22 +- .../samplers/test_close_to_threshold.py | 16 ++ .../samplers/test_contains_classes.py | 192 ++++++++++++++++++ .../active_learning/samplers/test_random.py | 5 +- 7 files changed, 294 insertions(+), 60 deletions(-) create mode 100644 tests/inference/unit_tests/core/active_learning/samplers/test_contains_classes.py diff --git a/inference/core/active_learning/samplers/close_to_threshold.py b/inference/core/active_learning/samplers/close_to_threshold.py index 1328aa82c..e1801ad32 100644 --- a/inference/core/active_learning/samplers/close_to_threshold.py +++ b/inference/core/active_learning/samplers/close_to_threshold.py @@ -15,6 +15,7 @@ KEYPOINTS_DETECTION_TASK, OBJECT_DETECTION_TASK, ) +from inference.core.exceptions import ActiveLearningConfigurationError ELIGIBLE_PREDICTION_TYPES = { CLASSIFICATION_TASK, @@ -27,25 +28,30 @@ def initialize_close_to_threshold_sampling( strategy_config: Dict[str, Any] ) -> SamplingMethod: - selected_class_names = strategy_config.get("selected_class_names") - if selected_class_names is not None: - selected_class_names = set(selected_class_names) - sample_function = partial( - sample_close_to_threshold, - selected_class_names=selected_class_names, - threshold=strategy_config["threshold"], - epsilon=strategy_config["epsilon"], - only_top_classes=strategy_config.get("only_top_classes", False), - minimum_objects_close_to_threshold=strategy_config.get( - "minimum_objects_close_to_threshold", - 1, - ), - probability=strategy_config["probability"], - ) - return SamplingMethod( - name=strategy_config["name"], - sample=sample_function, - ) + try: + selected_class_names = strategy_config.get("selected_class_names") + if selected_class_names is not None: + selected_class_names = set(selected_class_names) + sample_function = partial( + sample_close_to_threshold, + selected_class_names=selected_class_names, + threshold=strategy_config["threshold"], + epsilon=strategy_config["epsilon"], + only_top_classes=strategy_config.get("only_top_classes", False), + minimum_objects_close_to_threshold=strategy_config.get( + "minimum_objects_close_to_threshold", + 1, + ), + probability=strategy_config["probability"], + ) + return SamplingMethod( + name=strategy_config["name"], + sample=sample_function, + ) + except KeyError as error: + raise ActiveLearningConfigurationError( + f"In configuration of `close_to_threshold_sampling` missing key detected: {error}." + ) from error def sample_close_to_threshold( diff --git a/inference/core/active_learning/samplers/contains_classes.py b/inference/core/active_learning/samplers/contains_classes.py index 50929ca4e..defb86f64 100644 --- a/inference/core/active_learning/samplers/contains_classes.py +++ b/inference/core/active_learning/samplers/contains_classes.py @@ -11,21 +11,27 @@ from inference.core.active_learning.samplers.close_to_threshold import ( sample_close_to_threshold, ) +from inference.core.exceptions import ActiveLearningConfigurationError def initialize_classes_based_sampling( strategy_config: Dict[str, Any] ) -> SamplingMethod: - sample_function = partial( - sample_close_to_threshold, - selected_class_names=strategy_config["selected_class_names"], - minimum_objects=strategy_config.get("minimum_objects", 1), - probability=strategy_config["probability"], - ) - return SamplingMethod( - name=strategy_config["name"], - sample=sample_function, - ) + try: + sample_function = partial( + sample_based_on_classes, + selected_class_names=set(strategy_config["selected_class_names"]), + minimum_objects=strategy_config.get("minimum_objects", 1), + probability=strategy_config["probability"], + ) + return SamplingMethod( + name=strategy_config["name"], + sample=sample_function, + ) + except KeyError as error: + raise ActiveLearningConfigurationError( + f"In configuration of `classes_based_sampling` missing key detected: {error}." + ) from error def sample_based_on_classes( diff --git a/inference/core/active_learning/samplers/number_of_detections.py b/inference/core/active_learning/samplers/number_of_detections.py index f8c50e3e2..16bd1cc0a 100644 --- a/inference/core/active_learning/samplers/number_of_detections.py +++ b/inference/core/active_learning/samplers/number_of_detections.py @@ -31,28 +31,35 @@ def initialize_detections_number_based_sampling( strategy_config: Dict[str, Any] ) -> SamplingMethod: - less_than_objects = strategy_config.get("less_than_objects") - more_than_objects = strategy_config.get("more_than_objects") - both_nones = less_than_objects is None and more_than_objects is None - both_has_values = less_than_objects is not None and more_than_objects is not None - if both_nones or both_has_values: - raise ActiveLearningConfigurationError( - f"Only one from `less_than_objects` and `more_than_objects` values must be set." + try: + less_than_objects = strategy_config.get("less_than_objects") + more_than_objects = strategy_config.get("more_than_objects") + both_nones = less_than_objects is None and more_than_objects is None + both_has_values = ( + less_than_objects is not None and more_than_objects is not None ) - selected_class_names = strategy_config.get("selected_class_names") - if selected_class_names is not None: - selected_class_names = set(selected_class_names) - sample_function = partial( - sample_close_to_threshold, - less_than_objects=less_than_objects, - more_than_objects=more_than_objects, - selected_class_names=selected_class_names, - probability=strategy_config["probability"], - ) - return SamplingMethod( - name=strategy_config["name"], - sample=sample_function, - ) + if both_nones or both_has_values: + raise ActiveLearningConfigurationError( + f"Only one from `less_than_objects` and `more_than_objects` values must be set." + ) + selected_class_names = strategy_config.get("selected_class_names") + if selected_class_names is not None: + selected_class_names = set(selected_class_names) + sample_function = partial( + sample_close_to_threshold, + less_than_objects=less_than_objects, + more_than_objects=more_than_objects, + selected_class_names=selected_class_names, + probability=strategy_config["probability"], + ) + return SamplingMethod( + name=strategy_config["name"], + sample=sample_function, + ) + except KeyError as error: + raise ActiveLearningConfigurationError( + f"In configuration of `detections_number_based_sampling` missing key detected: {error}." + ) from error def sample_based_on_detections_number( diff --git a/inference/core/active_learning/samplers/random.py b/inference/core/active_learning/samplers/random.py index f906a2e40..42df157e5 100644 --- a/inference/core/active_learning/samplers/random.py +++ b/inference/core/active_learning/samplers/random.py @@ -9,17 +9,23 @@ PredictionType, SamplingMethod, ) +from inference.core.exceptions import ActiveLearningConfigurationError def initialize_random_sampling(strategy_config: Dict[str, Any]) -> SamplingMethod: - sample_function = partial( - sample_randomly, - traffic_percentage=strategy_config["traffic_percentage"], - ) - return SamplingMethod( - name=strategy_config["name"], - sample=sample_function, - ) + try: + sample_function = partial( + sample_randomly, + traffic_percentage=strategy_config["traffic_percentage"], + ) + return SamplingMethod( + name=strategy_config["name"], + sample=sample_function, + ) + except KeyError as error: + raise ActiveLearningConfigurationError( + f"In configuration of `random_sampling` missing key detected: {error}." + ) from error def sample_randomly( diff --git a/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py b/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py index 692332df0..6013ba050 100644 --- a/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py +++ b/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py @@ -21,6 +21,7 @@ CLASSIFICATION_TASK, OBJECT_DETECTION_TASK, ) +from inference.core.exceptions import ActiveLearningConfigurationError OBJECT_DETECTION_PREDICTION = { "predictions": [ @@ -776,3 +777,18 @@ def test_initialize_close_to_threshold_sampling_when_objects_close_to_threshold_ minimum_objects_close_to_threshold=6, probability=0.6, ) + + +def test_initialize_close_to_threshold_sampling_when_configuration_key_missing() -> ( + None +): + # given + strategy_config = { + "name": "ambulance_high_confidence", + "threshold": 0.75, + "epsilon": 0.25, + } + + # when + with pytest.raises(ActiveLearningConfigurationError): + _ = initialize_close_to_threshold_sampling(strategy_config=strategy_config) diff --git a/tests/inference/unit_tests/core/active_learning/samplers/test_contains_classes.py b/tests/inference/unit_tests/core/active_learning/samplers/test_contains_classes.py new file mode 100644 index 000000000..dfe0714d8 --- /dev/null +++ b/tests/inference/unit_tests/core/active_learning/samplers/test_contains_classes.py @@ -0,0 +1,192 @@ +from unittest import mock +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from inference.core.active_learning.samplers.contains_classes import ( + sample_based_on_classes, + initialize_classes_based_sampling, +) +from inference.core.active_learning.samplers import contains_classes +from inference.core.constants import OBJECT_DETECTION_TASK, CLASSIFICATION_TASK +from inference.core.exceptions import ActiveLearningConfigurationError + +OBJECT_DETECTION_PREDICTION = { + "predictions": [ + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.9, + "class": "a", + "class_id": 1, + }, + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.7, + "class": "c", + "class_id": 1, + }, + ] +} +MULTI_LABEL_CLASSIFICATION_PREDICTION = { + "image": {"width": 416, "height": 416}, + "predictions": { + "cat": {"confidence": 0.97}, + "dog": {"confidence": 0.03}, + }, + "predicted_classes": ["cat"], +} + + +def test_sample_based_on_classes_for_detection_predictions_when_classes_detected() -> ( + None +): + # when + result = sample_based_on_classes( + image=np.zeros((128, 128, 3), dtype=np.uint8), + prediction=OBJECT_DETECTION_PREDICTION, + prediction_type=OBJECT_DETECTION_TASK, + selected_class_names={"a", "b", "c"}, + minimum_objects=2, + probability=1.0, + ) + + # then + assert result is True + + +def test_sample_based_on_classes_for_detection_predictions_when_classes_not_detected() -> ( + None +): + # when + result = sample_based_on_classes( + image=np.zeros((128, 128, 3), dtype=np.uint8), + prediction=OBJECT_DETECTION_PREDICTION, + prediction_type=OBJECT_DETECTION_TASK, + selected_class_names={"a", "b"}, + minimum_objects=2, + probability=1.0, + ) + + # then + assert result is False + + +def test_sample_based_on_classes_for_classification_prediction_when_classes_detected() -> ( + None +): + # when + result = sample_based_on_classes( + image=np.zeros((128, 128, 3), dtype=np.uint8), + prediction=MULTI_LABEL_CLASSIFICATION_PREDICTION, + prediction_type=CLASSIFICATION_TASK, + selected_class_names={"cat"}, + minimum_objects=1, + probability=1.0, + ) + + # then + assert result is True + + +def test_sample_based_on_classes_for_classification_prediction_when_classes_not_detected() -> ( + None +): + # when + result = sample_based_on_classes( + image=np.zeros((128, 128, 3), dtype=np.uint8), + prediction=MULTI_LABEL_CLASSIFICATION_PREDICTION, + prediction_type=CLASSIFICATION_TASK, + selected_class_names={"dog"}, + minimum_objects=1, + probability=1.0, + ) + + # then + assert result is False + + +def test_initialize_classes_based_sampling() -> None: + # given + strategy_config = { + "name": "detect_dogs", + "probability": 1.0, + "selected_class_names": ["dog"], + } + + # when + sampling_method = initialize_classes_based_sampling(strategy_config=strategy_config) + result = sampling_method.sample( + np.zeros((128, 128, 3), dtype=np.uint8), + MULTI_LABEL_CLASSIFICATION_PREDICTION, + CLASSIFICATION_TASK, + ) + + # then + assert result is False + + +@mock.patch.object(contains_classes, "partial") +def test_initialize_classes_based_sampling_when_minimum_objects_specified( + partial_mock: MagicMock, +) -> None: + # given + strategy_config = { + "name": "detect_dogs", + "probability": 1.0, + "selected_class_names": ["dog"], + } + + # when + _ = initialize_classes_based_sampling(strategy_config=strategy_config) + + # then + partial_mock.assert_called_once_with( + sample_based_on_classes, + selected_class_names={"dog"}, + minimum_objects=1, + probability=1.0, + ) + + +@mock.patch.object(contains_classes, "partial") +def test_initialize_classes_based_sampling_when_minimum_objects_not_specified( + partial_mock: MagicMock, +) -> None: + # given + strategy_config = { + "name": "detect_dogs", + "probability": 0.5, + "selected_class_names": ["dog"], + "minimum_objects": 10, + } + + # when + _ = initialize_classes_based_sampling(strategy_config=strategy_config) + + # then + partial_mock.assert_called_once_with( + sample_based_on_classes, + selected_class_names={"dog"}, + minimum_objects=10, + probability=0.5, + ) + + +def test_initialize_classes_based_sampling_when_configuration_key_missing() -> None: + # given + strategy_config = { + "name": "detect_dogs", + "selected_class_names": ["dog"], + "minimum_objects": 10, + } + + # when + with pytest.raises(ActiveLearningConfigurationError): + _ = initialize_classes_based_sampling(strategy_config=strategy_config) diff --git a/tests/inference/unit_tests/core/active_learning/samplers/test_random.py b/tests/inference/unit_tests/core/active_learning/samplers/test_random.py index a2417a95f..22b74567d 100644 --- a/tests/inference/unit_tests/core/active_learning/samplers/test_random.py +++ b/tests/inference/unit_tests/core/active_learning/samplers/test_random.py @@ -6,6 +6,7 @@ from inference.core.active_learning.samplers.random import initialize_random_sampling from inference.core.active_learning.samplers import random +from inference.core.exceptions import ActiveLearningConfigurationError @mock.patch.object(random.random, "random") @@ -57,7 +58,7 @@ def test_initialize_random_sampling_when_strategy_name_is_not_present() -> None: } # when - with pytest.raises(KeyError): + with pytest.raises(ActiveLearningConfigurationError): _ = initialize_random_sampling(strategy_config=strategy_config) @@ -75,5 +76,5 @@ def test_initialize_random_sampling_when_traffic_percentage_is_not_present() -> } # when - with pytest.raises(KeyError): + with pytest.raises(ActiveLearningConfigurationError): _ = initialize_random_sampling(strategy_config=strategy_config) From f0f0b91cd4594706be6fab2415d5db486add0939 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Fri, 3 Nov 2023 14:20:30 +0100 Subject: [PATCH 11/15] Add tests for detections number based sampling --- .../samplers/contains_classes.py | 9 +- .../samplers/number_of_detections.py | 65 ++-- .../samplers/test_contains_classes.py | 91 ++---- .../samplers/test_number_of_detections.py | 286 ++++++++++++++++++ 4 files changed, 350 insertions(+), 101 deletions(-) create mode 100644 tests/inference/unit_tests/core/active_learning/samplers/test_number_of_detections.py diff --git a/inference/core/active_learning/samplers/contains_classes.py b/inference/core/active_learning/samplers/contains_classes.py index defb86f64..854dc3716 100644 --- a/inference/core/active_learning/samplers/contains_classes.py +++ b/inference/core/active_learning/samplers/contains_classes.py @@ -11,8 +11,11 @@ from inference.core.active_learning.samplers.close_to_threshold import ( sample_close_to_threshold, ) +from inference.core.constants import CLASSIFICATION_TASK from inference.core.exceptions import ActiveLearningConfigurationError +ELIGIBLE_PREDICTION_TYPES = {CLASSIFICATION_TASK} + def initialize_classes_based_sampling( strategy_config: Dict[str, Any] @@ -21,7 +24,6 @@ def initialize_classes_based_sampling( sample_function = partial( sample_based_on_classes, selected_class_names=set(strategy_config["selected_class_names"]), - minimum_objects=strategy_config.get("minimum_objects", 1), probability=strategy_config["probability"], ) return SamplingMethod( @@ -39,9 +41,10 @@ def sample_based_on_classes( prediction: Prediction, prediction_type: PredictionType, selected_class_names: Set[str], - minimum_objects: int, probability: float, ) -> bool: + if prediction_type not in ELIGIBLE_PREDICTION_TYPES: + return False return sample_close_to_threshold( image=image, prediction=prediction, @@ -50,6 +53,6 @@ def sample_based_on_classes( threshold=0.5, epsilon=1.0, only_top_classes=True, - minimum_objects_close_to_threshold=minimum_objects, + minimum_objects_close_to_threshold=1, probability=probability, ) diff --git a/inference/core/active_learning/samplers/number_of_detections.py b/inference/core/active_learning/samplers/number_of_detections.py index 16bd1cc0a..4b80351ba 100644 --- a/inference/core/active_learning/samplers/number_of_detections.py +++ b/inference/core/active_learning/samplers/number_of_detections.py @@ -12,7 +12,6 @@ from inference.core.active_learning.samplers.close_to_threshold import ( count_detections_close_to_threshold, is_prediction_a_stub, - sample_close_to_threshold, ) from inference.core.constants import ( INSTANCE_SEGMENTATION_TASK, @@ -32,23 +31,16 @@ def initialize_detections_number_based_sampling( strategy_config: Dict[str, Any] ) -> SamplingMethod: try: - less_than_objects = strategy_config.get("less_than_objects") - more_than_objects = strategy_config.get("more_than_objects") - both_nones = less_than_objects is None and more_than_objects is None - both_has_values = ( - less_than_objects is not None and more_than_objects is not None - ) - if both_nones or both_has_values: - raise ActiveLearningConfigurationError( - f"Only one from `less_than_objects` and `more_than_objects` values must be set." - ) + more_than = strategy_config.get("more_than") + less_than = strategy_config.get("less_than") + ensure_range_configuration_is_valid(more_than=more_than, less_than=less_than) selected_class_names = strategy_config.get("selected_class_names") if selected_class_names is not None: selected_class_names = set(selected_class_names) sample_function = partial( - sample_close_to_threshold, - less_than_objects=less_than_objects, - more_than_objects=more_than_objects, + sample_based_on_detections_number, + less_than=less_than, + more_than=more_than, selected_class_names=selected_class_names, probability=strategy_config["probability"], ) @@ -66,8 +58,8 @@ def sample_based_on_detections_number( image: np.ndarray, prediction: Prediction, prediction_type: PredictionType, - less_than_objects: Optional[int], - more_than_objects: Optional[int], + more_than: Optional[int], + less_than: Optional[int], selected_class_names: Optional[Set[str]], probability: float, ) -> bool: @@ -81,14 +73,35 @@ def sample_based_on_detections_number( threshold=0.5, epsilon=1.0, ) - if ( - less_than_objects is not None - and detections_close_to_threshold >= less_than_objects - ): - return False - if ( - more_than_objects is not None - and detections_close_to_threshold <= more_than_objects + if is_in_range( + value=detections_close_to_threshold, less_than=less_than, more_than=more_than ): - return False - return random.random() < probability + return random.random() < probability + return False + + +def is_in_range( + value: int, + more_than: Optional[int], + less_than: Optional[int], +) -> bool: + # calculates value > more_than and value < less_than, with optional borders of range + less_than_satisfied, more_than_satisfied = less_than is None, more_than is None + if less_than is not None and value < less_than: + less_than_satisfied = True + if more_than is not None and value > more_than: + more_than_satisfied = True + return less_than_satisfied and more_than_satisfied + + +def ensure_range_configuration_is_valid( + more_than: Optional[int], + less_than: Optional[int], +) -> None: + if more_than is None or less_than is None: + return None + if more_than >= less_than: + raise ActiveLearningConfigurationError( + f"Misconfiguration of detections number sampling: " + f"`more_than` parameter ({more_than}) >= `less_than` ({less_than})." + ) diff --git a/tests/inference/unit_tests/core/active_learning/samplers/test_contains_classes.py b/tests/inference/unit_tests/core/active_learning/samplers/test_contains_classes.py index dfe0714d8..c6291bfbe 100644 --- a/tests/inference/unit_tests/core/active_learning/samplers/test_contains_classes.py +++ b/tests/inference/unit_tests/core/active_learning/samplers/test_contains_classes.py @@ -12,28 +12,7 @@ from inference.core.constants import OBJECT_DETECTION_TASK, CLASSIFICATION_TASK from inference.core.exceptions import ActiveLearningConfigurationError -OBJECT_DETECTION_PREDICTION = { - "predictions": [ - { - "x": 784.5, - "y": 397.5, - "width": 187.0, - "height": 309.0, - "confidence": 0.9, - "class": "a", - "class_id": 1, - }, - { - "x": 784.5, - "y": 397.5, - "width": 187.0, - "height": 309.0, - "confidence": 0.7, - "class": "c", - "class_id": 1, - }, - ] -} + MULTI_LABEL_CLASSIFICATION_PREDICTION = { "image": {"width": 416, "height": 416}, "predictions": { @@ -44,33 +23,28 @@ } -def test_sample_based_on_classes_for_detection_predictions_when_classes_detected() -> ( - None -): - # when - result = sample_based_on_classes( - image=np.zeros((128, 128, 3), dtype=np.uint8), - prediction=OBJECT_DETECTION_PREDICTION, - prediction_type=OBJECT_DETECTION_TASK, - selected_class_names={"a", "b", "c"}, - minimum_objects=2, - probability=1.0, - ) - - # then - assert result is True - +def test_sample_based_on_classes_for_detection_predictions() -> None: + # given + prediction = { + "predictions": [ + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.9, + "class": "a", + "class_id": 1, + } + ] + } -def test_sample_based_on_classes_for_detection_predictions_when_classes_not_detected() -> ( - None -): # when result = sample_based_on_classes( image=np.zeros((128, 128, 3), dtype=np.uint8), - prediction=OBJECT_DETECTION_PREDICTION, + prediction=prediction, prediction_type=OBJECT_DETECTION_TASK, - selected_class_names={"a", "b"}, - minimum_objects=2, + selected_class_names={"a", "b", "c"}, probability=1.0, ) @@ -87,7 +61,6 @@ def test_sample_based_on_classes_for_classification_prediction_when_classes_dete prediction=MULTI_LABEL_CLASSIFICATION_PREDICTION, prediction_type=CLASSIFICATION_TASK, selected_class_names={"cat"}, - minimum_objects=1, probability=1.0, ) @@ -104,7 +77,6 @@ def test_sample_based_on_classes_for_classification_prediction_when_classes_not_ prediction=MULTI_LABEL_CLASSIFICATION_PREDICTION, prediction_type=CLASSIFICATION_TASK, selected_class_names={"dog"}, - minimum_objects=1, probability=1.0, ) @@ -133,7 +105,7 @@ def test_initialize_classes_based_sampling() -> None: @mock.patch.object(contains_classes, "partial") -def test_initialize_classes_based_sampling_when_minimum_objects_specified( +def test_initialize_classes_based_sampling_against_parameters_correctness( partial_mock: MagicMock, ) -> None: # given @@ -150,35 +122,10 @@ def test_initialize_classes_based_sampling_when_minimum_objects_specified( partial_mock.assert_called_once_with( sample_based_on_classes, selected_class_names={"dog"}, - minimum_objects=1, probability=1.0, ) -@mock.patch.object(contains_classes, "partial") -def test_initialize_classes_based_sampling_when_minimum_objects_not_specified( - partial_mock: MagicMock, -) -> None: - # given - strategy_config = { - "name": "detect_dogs", - "probability": 0.5, - "selected_class_names": ["dog"], - "minimum_objects": 10, - } - - # when - _ = initialize_classes_based_sampling(strategy_config=strategy_config) - - # then - partial_mock.assert_called_once_with( - sample_based_on_classes, - selected_class_names={"dog"}, - minimum_objects=10, - probability=0.5, - ) - - def test_initialize_classes_based_sampling_when_configuration_key_missing() -> None: # given strategy_config = { diff --git a/tests/inference/unit_tests/core/active_learning/samplers/test_number_of_detections.py b/tests/inference/unit_tests/core/active_learning/samplers/test_number_of_detections.py new file mode 100644 index 000000000..2f7721640 --- /dev/null +++ b/tests/inference/unit_tests/core/active_learning/samplers/test_number_of_detections.py @@ -0,0 +1,286 @@ +from typing import Optional +from unittest import mock +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from inference.core.active_learning.samplers.number_of_detections import ( + is_in_range, + sample_based_on_detections_number, + initialize_detections_number_based_sampling, +) +from inference.core.active_learning.samplers import number_of_detections +from inference.core.constants import CLASSIFICATION_TASK, OBJECT_DETECTION_TASK +from inference.core.exceptions import ActiveLearningConfigurationError + +OBJECT_DETECTION_PREDICTION = { + "predictions": [ + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.9, + "class": "a", + "class_id": 1, + }, + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.7, + "class": "b", + "class_id": 1, + }, + ] +} + + +@pytest.mark.parametrize( + "value, more_than, less_than", + [ + (1, None, 2), + (2, 1, 5), + (5, 4, None), + (1, None, None), + ], +) +def test_is_in_range_value_meets_condition( + value: int, more_than: Optional[int], less_than: Optional[int] +) -> None: + # when + result = is_in_range(value=value, more_than=more_than, less_than=less_than) + + # then + assert result is True + + +@pytest.mark.parametrize( + "value, more_than, less_than", + [ + (1, 2, None), + (2, 5, 1), + (5, None, 4), + ], +) +def test_is_in_range_value_does_not_meet_condition( + value: int, more_than: Optional[int], less_than: Optional[int] +) -> None: + # when + result = is_in_range(value=value, more_than=more_than, less_than=less_than) + + # then + assert result is False + + +def test_sample_based_on_detections_number_when_classification_prediction_given() -> ( + None +): + # given + prediction = { + "image": {"width": 416, "height": 416}, + "predictions": { + "cat": {"confidence": 0.97}, + "dog": {"confidence": 0.03}, + }, + "predicted_classes": ["cat"], + } + + # when + result = sample_based_on_detections_number( + image=np.zeros((128, 128, 3), dtype=np.uint8), + prediction=prediction, + prediction_type=CLASSIFICATION_TASK, + less_than=None, + more_than=None, + selected_class_names=None, + probability=1.0, + ) + + # then + assert result is False + + +def test_sample_based_on_detections_number_when_stub_prediction_given() -> None: + # when + result = sample_based_on_detections_number( + image=np.zeros((128, 128, 3), dtype=np.uint8), + prediction={"is_stub": True}, + prediction_type=CLASSIFICATION_TASK, + less_than=None, + more_than=None, + selected_class_names=None, + probability=1.0, + ) + + # then + assert result is False + + +@mock.patch.object(number_of_detections.random, "random") +def test_sample_based_on_detections_number_when_detections_in_range_and_sampling_succeeds( + random_mock: MagicMock, +) -> None: + # given + random_mock.return_value = 0.29 + + # when + result = sample_based_on_detections_number( + image=np.zeros((128, 128, 3), dtype=np.uint8), + prediction=OBJECT_DETECTION_PREDICTION, + prediction_type=OBJECT_DETECTION_TASK, + more_than=1, + less_than=3, + selected_class_names={"a", "b"}, + probability=0.3, + ) + + # then + assert result is True + + +@mock.patch.object(number_of_detections.random, "random") +def test_sample_based_on_detections_number_when_detections_in_range_and_sampling_fails( + random_mock: MagicMock, +) -> None: + # given + random_mock.return_value = 0.31 + + # when + result = sample_based_on_detections_number( + image=np.zeros((128, 128, 3), dtype=np.uint8), + prediction=OBJECT_DETECTION_PREDICTION, + prediction_type=OBJECT_DETECTION_TASK, + more_than=1, + less_than=3, + selected_class_names={"a", "b"}, + probability=0.3, + ) + + # then + assert result is False + + +def test_sample_based_on_detections_number_when_detections_not_in_range() -> None: + # when + result = sample_based_on_detections_number( + image=np.zeros((128, 128, 3), dtype=np.uint8), + prediction=OBJECT_DETECTION_PREDICTION, + prediction_type=OBJECT_DETECTION_TASK, + more_than=2, + less_than=3, + selected_class_names={"a"}, + probability=0.3, + ) + + # then + assert result is False + + +def test_initialize_detections_number_based_sampling() -> None: + # given + strategy_config = { + "name": "two_detections", + "less_than": 3, + "more_than": 1, + "selected_class_names": {"a", "b"}, + "probability": 1.0, + } + + # when + sampling_method = initialize_detections_number_based_sampling( + strategy_config=strategy_config + ) + result = sampling_method.sample( + np.zeros((128, 128, 3), dtype=np.uint8), + OBJECT_DETECTION_PREDICTION, + OBJECT_DETECTION_TASK, + ) + + # then + assert result is True + assert sampling_method.name == "two_detections" + + +@mock.patch.object(number_of_detections, "partial") +def test_test_initialize_detections_number_based_sampling_when_optional_values_not_given( + partial_mock: MagicMock, +) -> None: + # given + strategy_config = { + "name": "two_detections", + "probability": 1.0, + } + + # when + _ = initialize_detections_number_based_sampling(strategy_config=strategy_config) + + # then + partial_mock.assert_called_once_with( + sample_based_on_detections_number, + less_than=None, + more_than=None, + selected_class_names=None, + probability=1.0, + ) + + +@mock.patch.object(number_of_detections, "partial") +def test_test_initialize_detections_number_based_sampling_when_optional_values_given( + partial_mock: MagicMock, +) -> None: + # given + strategy_config = { + "name": "two_detections", + "probability": 1.0, + "less_than": 10, + "more_than": 5, + "selected_class_names": ["a", "b"], + } + + # when + _ = initialize_detections_number_based_sampling(strategy_config=strategy_config) + + # then + partial_mock.assert_called_once_with( + sample_based_on_detections_number, + less_than=10, + more_than=5, + selected_class_names={"a", "b"}, + probability=1.0, + ) + + +def test_initialize_detections_number_based_sampling_when_required_value_missing() -> ( + None +): + # given + strategy_config = { + "name": "two_detections", + "less_than": 10, + "more_than": 5, + "selected_class_names": ["a", "b"], + } + + # when + with pytest.raises(ActiveLearningConfigurationError): + _ = initialize_detections_number_based_sampling(strategy_config=strategy_config) + + +def test_initialize_detections_number_based_sampling_when_malformed_config_given() -> ( + None +): + # given + strategy_config = { + "name": "two_detections", + "less_than": 5, + "more_than": 6, + "probability": 1.0, + "selected_class_names": ["a", "b"], + } + + # when + with pytest.raises(ActiveLearningConfigurationError): + _ = initialize_detections_number_based_sampling(strategy_config=strategy_config) From ea1f173add47a4035fdf7fca6d29951b1afae80e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Fri, 3 Nov 2023 14:34:55 +0100 Subject: [PATCH 12/15] Update tests for configuration --- .../core/active_learning/configuration.py | 14 ++- inference/core/roboflow_api.py | 4 +- .../active_learning/test_configuration.py | 102 ++++++++++++++++-- 3 files changed, 104 insertions(+), 16 deletions(-) diff --git a/inference/core/active_learning/configuration.py b/inference/core/active_learning/configuration.py index dffbec111..d337bee38 100644 --- a/inference/core/active_learning/configuration.py +++ b/inference/core/active_learning/configuration.py @@ -17,6 +17,7 @@ ) from inference.core.active_learning.samplers.random import initialize_random_sampling from inference.core.env import ACTIVE_LEARNING_ENABLED +from inference.core.exceptions import ActiveLearningConfigurationError from inference.core.roboflow_api import ( get_roboflow_active_learning_configuration, get_roboflow_dataset_type, @@ -25,10 +26,10 @@ from inference.core.utils.roboflow import get_model_id_chunks TYPE2SAMPLING_INITIALIZERS = { - "random_sampling": initialize_random_sampling, - "close_to_threshold_sampling": initialize_close_to_threshold_sampling, - "classes_based_sampling": initialize_classes_based_sampling, - "detections_number_based_sampling": initialize_detections_number_based_sampling, + "random": initialize_random_sampling, + "close_to_threshold": initialize_close_to_threshold_sampling, + "classes_based": initialize_classes_based_sampling, + "detections_number_based": initialize_detections_number_based_sampling, } @@ -106,4 +107,9 @@ def initialize_sampling_methods( continue initializer = TYPE2SAMPLING_INITIALIZERS[sampling_type] result.append(initializer(sampling_strategy_config)) + names = set(m.name for m in result) + if len(names) != len(result): + raise ActiveLearningConfigurationError( + "Detected duplication of Active Learning strategies names." + ) return result diff --git a/inference/core/roboflow_api.py b/inference/core/roboflow_api.py index e6698eddd..7c108e91f 100644 --- a/inference/core/roboflow_api.py +++ b/inference/core/roboflow_api.py @@ -212,9 +212,9 @@ def get_roboflow_active_learning_configuration( "sampling_strategies": [ { "name": "default_strategy", - "type": "random_sampling", + "type": "random", "traffic_percentage": 0.1, # float 0-1 - "tags": ["c", "d"], # Optional + "tags": ["random_traffic"], # Optional "limits": [ # Optional {"type": "minutely", "value": 10}, {"type": "hourly", "value": 100}, diff --git a/tests/inference/unit_tests/core/active_learning/test_configuration.py b/tests/inference/unit_tests/core/active_learning/test_configuration.py index 58eafcc13..cd0f33e03 100644 --- a/tests/inference/unit_tests/core/active_learning/test_configuration.py +++ b/tests/inference/unit_tests/core/active_learning/test_configuration.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock import numpy as np +import pytest from inference.core.active_learning.configuration import ( initialize_sampling_methods, @@ -18,6 +19,7 @@ StrategyLimit, StrategyLimitType, ) +from inference.core.exceptions import ActiveLearningConfigurationError def test_initialize_sampling_methods() -> None: @@ -25,9 +27,9 @@ def test_initialize_sampling_methods() -> None: sampling_strategies_configs = [ { "name": "default_strategy", - "type": "random_sampling", + "type": "random", "traffic_percentage": 0.5, - "tags": ["c", "d"], + "tags": ["a"], "limits": [ {"type": "minutely", "value": 10}, {"type": "hourly", "value": 100}, @@ -35,6 +37,44 @@ def test_initialize_sampling_methods() -> None: ], }, {"type": "non-existing"}, + { + "name": "hard_examples", + "type": "close_to_threshold", + "selected_class_names": ["a", "b"], + "threshold": 0.25, + "epsilon": 0.1, + "probability": 0.5, + "tags": ["b"], + "limits": [ + {"type": "minutely", "value": 10}, + {"type": "hourly", "value": 100}, + {"type": "daily", "value": 1000}, + ], + }, + { + "name": "underrepresented_classes", + "type": "classes_based", + "selected_class_names": ["a"], + "probability": 0.5, + "tags": ["hard_classes"], + "limits": [ + {"type": "minutely", "value": 10}, + {"type": "hourly", "value": 100}, + {"type": "daily", "value": 1000}, + ], + }, + { + "name": "low_detections", + "type": "detections_number_based", + "probability": 0.5, + "less_than": 3, + "tags": ["empty"], + "limits": [ + {"type": "minutely", "value": 10}, + {"type": "hourly", "value": 100}, + {"type": "daily", "value": 1000}, + ], + }, ] # when @@ -43,13 +83,19 @@ def test_initialize_sampling_methods() -> None: ) # then - assert len(result) == 1 - assert result[0].name == "default_strategy" - _ = result[0].sample( # test if sampling executed correctly - np.zeros((128, 128, 3), dtype=np.ndarray), - {"some": "prediction"}, - "object-detection", - ) + assert len(result) == 4 + assert [r.name for r in result] == [ + "default_strategy", + "hard_examples", + "underrepresented_classes", + "low_detections", + ] + for strategy in result: + _ = strategy.sample( # test if sampling executed correctly + np.zeros((128, 128, 3), dtype=np.ndarray), + {"is_stub": "True"}, + "object-detection", + ) @mock.patch.object(configuration, "get_roboflow_active_learning_configuration") @@ -148,7 +194,7 @@ def test_prepare_active_learning_configuration_when_active_learning_enabled( "sampling_strategies": [ { "name": "default_strategy", - "type": "random_sampling", + "type": "random", "traffic_percentage": 0.1, "tags": ["c", "d"], "limits": [ @@ -200,3 +246,39 @@ def test_prepare_active_learning_configuration_when_active_learning_enabled( tags=["a", "b"], strategies_tags={"default_strategy": ["c", "d"]}, ) + + +def test_test_initialize_sampling_methods_when_duplicate_names_detected() -> None: + # given + sampling_strategies_configs = [ + { + "name": "default_strategy", + "type": "random", + "traffic_percentage": 0.5, + "tags": ["a"], + "limits": [ + {"type": "minutely", "value": 10}, + {"type": "hourly", "value": 100}, + {"type": "daily", "value": 1000}, + ], + }, + { + "name": "default_strategy", + "type": "close_to_threshold", + "selected_class_names": ["a", "b"], + "threshold": 0.25, + "epsilon": 0.1, + "probability": 0.5, + "tags": ["b"], + "limits": [ + {"type": "minutely", "value": 10}, + {"type": "hourly", "value": 100}, + {"type": "daily", "value": 1000}, + ], + }, + ] + + with pytest.raises(ActiveLearningConfigurationError): + _ = initialize_sampling_methods( + sampling_strategies_configs=sampling_strategies_configs + ) From 5135b95bf9218159336b676de765edd7a037d71f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Fri, 3 Nov 2023 14:39:57 +0100 Subject: [PATCH 13/15] Added dummy api response for new strategies --- inference/core/roboflow_api.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/inference/core/roboflow_api.py b/inference/core/roboflow_api.py index 7c108e91f..4dff8c2b8 100644 --- a/inference/core/roboflow_api.py +++ b/inference/core/roboflow_api.py @@ -220,7 +220,32 @@ def get_roboflow_active_learning_configuration( {"type": "hourly", "value": 100}, {"type": "daily", "value": 1000}, ], - } + }, + { + "name": "hard_examples", + "type": "close_to_threshold", + "threshold": 0.3, + "epsilon": 0.3, + "probability": 1.0, + "tags": ["hard_case"], + "limits": [ + {"type": "minutely", "value": 10}, + {"type": "hourly", "value": 100}, + {"type": "daily", "value": 1000}, + ], + }, + { + "name": "multiple_detections", + "type": "detections_number_based", + "probability": 1.0, + "more_than": 3, + "tags": ["crowded"], + "limits": [ + {"type": "minutely", "value": 10}, + {"type": "hourly", "value": 100}, + {"type": "daily", "value": 1000}, + ], + }, ], "batching_strategy": { "batches_name_prefix": "al_batch", From 9f397879339d1e37fb1a6f09844bb6e92c32d509 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Fri, 3 Nov 2023 14:40:39 +0100 Subject: [PATCH 14/15] Added dummy api response for new strategies --- inference/core/roboflow_api.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/inference/core/roboflow_api.py b/inference/core/roboflow_api.py index 4dff8c2b8..66c0da9c3 100644 --- a/inference/core/roboflow_api.py +++ b/inference/core/roboflow_api.py @@ -246,6 +246,18 @@ def get_roboflow_active_learning_configuration( {"type": "daily", "value": 1000}, ], }, + { + "name": "underrepresented_classes", + "type": "classes_based", + "selected_class_names": ["cat"], + "probability": 1.0, + "tags": ["hard_classes"], + "limits": [ + {"type": "minutely", "value": 10}, + {"type": "hourly", "value": 100}, + {"type": "daily", "value": 1000}, + ], + }, ], "batching_strategy": { "batches_name_prefix": "al_batch", From 628564607943821e139021acd496bca5a44f7a84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Fri, 3 Nov 2023 15:16:28 +0100 Subject: [PATCH 15/15] Added dummy api response for new strategies --- inference/core/roboflow_api.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/inference/core/roboflow_api.py b/inference/core/roboflow_api.py index 66c0da9c3..b2b04bd30 100644 --- a/inference/core/roboflow_api.py +++ b/inference/core/roboflow_api.py @@ -214,7 +214,7 @@ def get_roboflow_active_learning_configuration( "name": "default_strategy", "type": "random", "traffic_percentage": 0.1, # float 0-1 - "tags": ["random_traffic"], # Optional + "tags": ["random-traffic"], # Optional "limits": [ # Optional {"type": "minutely", "value": 10}, {"type": "hourly", "value": 100}, @@ -226,8 +226,8 @@ def get_roboflow_active_learning_configuration( "type": "close_to_threshold", "threshold": 0.3, "epsilon": 0.3, - "probability": 1.0, - "tags": ["hard_case"], + "probability": 0.3, + "tags": ["hard-case"], "limits": [ {"type": "minutely", "value": 10}, {"type": "hourly", "value": 100}, @@ -237,7 +237,7 @@ def get_roboflow_active_learning_configuration( { "name": "multiple_detections", "type": "detections_number_based", - "probability": 1.0, + "probability": 0.2, "more_than": 3, "tags": ["crowded"], "limits": [ @@ -251,7 +251,7 @@ def get_roboflow_active_learning_configuration( "type": "classes_based", "selected_class_names": ["cat"], "probability": 1.0, - "tags": ["hard_classes"], + "tags": ["hard-classes"], "limits": [ {"type": "minutely", "value": 10}, {"type": "hourly", "value": 100},