diff --git a/inference/core/active_learning/configuration.py b/inference/core/active_learning/configuration.py
index ec76cae05..d337bee38 100644
--- a/inference/core/active_learning/configuration.py
+++ b/inference/core/active_learning/configuration.py
@@ -6,9 +6,18 @@
     RoboflowProjectMetadata,
     SamplingMethod,
 )
-from inference.core.active_learning.sampling import initialize_random_sampling
-from inference.core.cache.base import BaseCache
+from inference.core.active_learning.samplers.close_to_threshold import (
+    initialize_close_to_threshold_sampling,
+)
+from inference.core.active_learning.samplers.contains_classes import (
+    initialize_classes_based_sampling,
+)
+from inference.core.active_learning.samplers.number_of_detections import (
+    initialize_detections_number_based_sampling,
+)
+from inference.core.active_learning.samplers.random import initialize_random_sampling
 from inference.core.env import ACTIVE_LEARNING_ENABLED
+from inference.core.exceptions import ActiveLearningConfigurationError
 from inference.core.roboflow_api import (
     get_roboflow_active_learning_configuration,
     get_roboflow_dataset_type,
@@ -16,7 +25,12 @@
 )
 from inference.core.utils.roboflow import get_model_id_chunks
 
-TYPE2SAMPLING_INITIALIZERS = {"random_sampling": initialize_random_sampling}
+TYPE2SAMPLING_INITIALIZERS = {
+    "random": initialize_random_sampling,
+    "close_to_threshold": initialize_close_to_threshold_sampling,
+    "classes_based": initialize_classes_based_sampling,
+    "detections_number_based": initialize_detections_number_based_sampling,
+}
 
 
 def prepare_active_learning_configuration(
@@ -93,4 +107,9 @@ def initialize_sampling_methods(
             continue
         initializer = TYPE2SAMPLING_INITIALIZERS[sampling_type]
         result.append(initializer(sampling_strategy_config))
+    names = set(m.name for m in result)
+    if len(names) != len(result):
+        raise ActiveLearningConfigurationError(
+            "Detected duplicated Active Learning strategy names."
+ ) return result diff --git a/inference/core/active_learning/samplers/__init__.py b/inference/core/active_learning/samplers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/inference/core/active_learning/samplers/close_to_threshold.py b/inference/core/active_learning/samplers/close_to_threshold.py new file mode 100644 index 000000000..e1801ad32 --- /dev/null +++ b/inference/core/active_learning/samplers/close_to_threshold.py @@ -0,0 +1,227 @@ +import random +from functools import partial +from typing import Any, Dict, Optional, Set + +import numpy as np + +from inference.core.active_learning.entities import ( + Prediction, + PredictionType, + SamplingMethod, +) +from inference.core.constants import ( + CLASSIFICATION_TASK, + INSTANCE_SEGMENTATION_TASK, + KEYPOINTS_DETECTION_TASK, + OBJECT_DETECTION_TASK, +) +from inference.core.exceptions import ActiveLearningConfigurationError + +ELIGIBLE_PREDICTION_TYPES = { + CLASSIFICATION_TASK, + INSTANCE_SEGMENTATION_TASK, + KEYPOINTS_DETECTION_TASK, + OBJECT_DETECTION_TASK, +} + + +def initialize_close_to_threshold_sampling( + strategy_config: Dict[str, Any] +) -> SamplingMethod: + try: + selected_class_names = strategy_config.get("selected_class_names") + if selected_class_names is not None: + selected_class_names = set(selected_class_names) + sample_function = partial( + sample_close_to_threshold, + selected_class_names=selected_class_names, + threshold=strategy_config["threshold"], + epsilon=strategy_config["epsilon"], + only_top_classes=strategy_config.get("only_top_classes", False), + minimum_objects_close_to_threshold=strategy_config.get( + "minimum_objects_close_to_threshold", + 1, + ), + probability=strategy_config["probability"], + ) + return SamplingMethod( + name=strategy_config["name"], + sample=sample_function, + ) + except KeyError as error: + raise ActiveLearningConfigurationError( + f"In configuration of `close_to_threshold_sampling` missing key detected: {error}." 
+ ) from error + + +def sample_close_to_threshold( + image: np.ndarray, + prediction: Prediction, + prediction_type: PredictionType, + selected_class_names: Optional[Set[str]], + threshold: float, + epsilon: float, + only_top_classes: bool, + minimum_objects_close_to_threshold: int, + probability: float, +) -> bool: + if is_prediction_a_stub(prediction=prediction): + return False + if prediction_type not in ELIGIBLE_PREDICTION_TYPES: + return False + close_to_threshold = prediction_is_close_to_threshold( + prediction=prediction, + prediction_type=prediction_type, + selected_class_names=selected_class_names, + threshold=threshold, + epsilon=epsilon, + only_top_classes=only_top_classes, + minimum_objects_close_to_threshold=minimum_objects_close_to_threshold, + ) + if not close_to_threshold: + return False + return random.random() < probability + + +def is_prediction_a_stub(prediction: Prediction) -> bool: + return prediction.get("is_stub", False) + + +def prediction_is_close_to_threshold( + prediction: Prediction, + prediction_type: PredictionType, + selected_class_names: Optional[Set[str]], + threshold: float, + epsilon: float, + only_top_classes: bool, + minimum_objects_close_to_threshold: int, +) -> bool: + if CLASSIFICATION_TASK not in prediction_type: + return detections_are_close_to_threshold( + prediction=prediction, + selected_class_names=selected_class_names, + threshold=threshold, + epsilon=epsilon, + minimum_objects_close_to_threshold=minimum_objects_close_to_threshold, + ) + checker = multi_label_classification_prediction_is_close_to_threshold + if "top" in prediction: + checker = multi_class_classification_prediction_is_close_to_threshold + return checker( + prediction=prediction, + selected_class_names=selected_class_names, + threshold=threshold, + epsilon=epsilon, + only_top_classes=only_top_classes, + ) + + +def multi_class_classification_prediction_is_close_to_threshold( + prediction: Prediction, + selected_class_names: Optional[Set[str]], + threshold: float, + epsilon: float, + only_top_classes: bool, +) -> bool: + if only_top_classes: + return ( + multi_class_classification_prediction_is_close_to_threshold_for_top_class( + prediction=prediction, + selected_class_names=selected_class_names, + threshold=threshold, + epsilon=epsilon, + ) + ) + for prediction_details in prediction["predictions"]: + if class_to_be_excluded( + class_name=prediction_details["class"], + selected_class_names=selected_class_names, + ): + continue + if is_close_to_threshold( + value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon + ): + return True + return False + + +def multi_class_classification_prediction_is_close_to_threshold_for_top_class( + prediction: Prediction, + selected_class_names: Optional[Set[str]], + threshold: float, + epsilon: float, +) -> bool: + if ( + selected_class_names is not None + and prediction["top"] not in selected_class_names + ): + return False + return abs(prediction["confidence"] - threshold) < epsilon + + +def multi_label_classification_prediction_is_close_to_threshold( + prediction: Prediction, + selected_class_names: Optional[Set[str]], + threshold: float, + epsilon: float, + only_top_classes: bool, +) -> bool: + predicted_classes = set(prediction["predicted_classes"]) + for class_name, prediction_details in prediction["predictions"].items(): + if only_top_classes and class_name not in predicted_classes: + continue + if class_to_be_excluded( + class_name=class_name, selected_class_names=selected_class_names + ): + continue + if 
is_close_to_threshold( + value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon + ): + return True + return False + + +def detections_are_close_to_threshold( + prediction: Prediction, + selected_class_names: Optional[Set[str]], + threshold: float, + epsilon: float, + minimum_objects_close_to_threshold: int, +) -> bool: + detections_close_to_threshold = count_detections_close_to_threshold( + prediction=prediction, + selected_class_names=selected_class_names, + threshold=threshold, + epsilon=epsilon, + ) + return detections_close_to_threshold >= minimum_objects_close_to_threshold + + +def count_detections_close_to_threshold( + prediction: Prediction, + selected_class_names: Optional[Set[str]], + threshold: float, + epsilon: float, +) -> int: + counter = 0 + for prediction_details in prediction["predictions"]: + if class_to_be_excluded( + class_name=prediction_details["class"], + selected_class_names=selected_class_names, + ): + continue + if is_close_to_threshold( + value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon + ): + counter += 1 + return counter + + +def class_to_be_excluded( + class_name: str, selected_class_names: Optional[Set[str]] +) -> bool: + return selected_class_names is not None and class_name not in selected_class_names + + +def is_close_to_threshold(value: float, threshold: float, epsilon: float) -> bool: + return abs(value - threshold) < epsilon diff --git a/inference/core/active_learning/samplers/contains_classes.py b/inference/core/active_learning/samplers/contains_classes.py new file mode 100644 index 000000000..854dc3716 --- /dev/null +++ b/inference/core/active_learning/samplers/contains_classes.py @@ -0,0 +1,58 @@ +from functools import partial +from typing import Any, Dict, Set + +import numpy as np + +from inference.core.active_learning.entities import ( + Prediction, + PredictionType, + SamplingMethod, +) +from inference.core.active_learning.samplers.close_to_threshold import ( + sample_close_to_threshold, +) +from inference.core.constants import CLASSIFICATION_TASK +from inference.core.exceptions import ActiveLearningConfigurationError + +ELIGIBLE_PREDICTION_TYPES = {CLASSIFICATION_TASK} + + +def initialize_classes_based_sampling( + strategy_config: Dict[str, Any] +) -> SamplingMethod: + try: + sample_function = partial( + sample_based_on_classes, + selected_class_names=set(strategy_config["selected_class_names"]), + probability=strategy_config["probability"], + ) + return SamplingMethod( + name=strategy_config["name"], + sample=sample_function, + ) + except KeyError as error: + raise ActiveLearningConfigurationError( + f"In configuration of `classes_based_sampling` missing key detected: {error}." 
+ ) from error + + +def sample_based_on_classes( + image: np.ndarray, + prediction: Prediction, + prediction_type: PredictionType, + selected_class_names: Set[str], + probability: float, +) -> bool: + if prediction_type not in ELIGIBLE_PREDICTION_TYPES: + return False + return sample_close_to_threshold( + image=image, + prediction=prediction, + prediction_type=prediction_type, + selected_class_names=selected_class_names, + threshold=0.5, + epsilon=1.0, + only_top_classes=True, + minimum_objects_close_to_threshold=1, + probability=probability, + ) diff --git a/inference/core/active_learning/samplers/number_of_detections.py b/inference/core/active_learning/samplers/number_of_detections.py new file mode 100644 index 000000000..4b80351ba --- /dev/null +++ b/inference/core/active_learning/samplers/number_of_detections.py @@ -0,0 +1,107 @@ +import random +from functools import partial +from typing import Any, Dict, Optional, Set + +import numpy as np + +from inference.core.active_learning.entities import ( + Prediction, + PredictionType, + SamplingMethod, +) +from inference.core.active_learning.samplers.close_to_threshold import ( + count_detections_close_to_threshold, + is_prediction_a_stub, +) +from inference.core.constants import ( + INSTANCE_SEGMENTATION_TASK, + KEYPOINTS_DETECTION_TASK, + OBJECT_DETECTION_TASK, +) +from inference.core.exceptions import ActiveLearningConfigurationError + +ELIGIBLE_PREDICTION_TYPES = { + INSTANCE_SEGMENTATION_TASK, + KEYPOINTS_DETECTION_TASK, + OBJECT_DETECTION_TASK, +} + + +def initialize_detections_number_based_sampling( + strategy_config: Dict[str, Any] +) -> SamplingMethod: + try: + more_than = strategy_config.get("more_than") + less_than = strategy_config.get("less_than") + ensure_range_configuration_is_valid(more_than=more_than, less_than=less_than) + selected_class_names = strategy_config.get("selected_class_names") + if selected_class_names is not None: + selected_class_names = set(selected_class_names) + sample_function = partial( + sample_based_on_detections_number, + less_than=less_than, + more_than=more_than, + selected_class_names=selected_class_names, + probability=strategy_config["probability"], + ) + return SamplingMethod( + name=strategy_config["name"], + sample=sample_function, + ) + except KeyError as error: + raise ActiveLearningConfigurationError( + f"In configuration of `detections_number_based_sampling` missing key detected: {error}." 
+        ) from error
+
+
+def sample_based_on_detections_number(
+    image: np.ndarray,
+    prediction: Prediction,
+    prediction_type: PredictionType,
+    more_than: Optional[int],
+    less_than: Optional[int],
+    selected_class_names: Optional[Set[str]],
+    probability: float,
+) -> bool:
+    if is_prediction_a_stub(prediction=prediction):
+        return False
+    if prediction_type not in ELIGIBLE_PREDICTION_TYPES:
+        return False
+    detections_close_to_threshold = count_detections_close_to_threshold(
+        prediction=prediction,
+        selected_class_names=selected_class_names,
+        threshold=0.5,
+        epsilon=1.0,
+    )
+    if is_in_range(
+        value=detections_close_to_threshold, less_than=less_than, more_than=more_than
+    ):
+        return random.random() < probability
+    return False
+
+
+def is_in_range(
+    value: int,
+    more_than: Optional[int],
+    less_than: Optional[int],
+) -> bool:
+    # checks value > more_than and value < less_than; a bound left as None is open-ended
+    less_than_satisfied, more_than_satisfied = less_than is None, more_than is None
+    if less_than is not None and value < less_than:
+        less_than_satisfied = True
+    if more_than is not None and value > more_than:
+        more_than_satisfied = True
+    return less_than_satisfied and more_than_satisfied
+
+
+def ensure_range_configuration_is_valid(
+    more_than: Optional[int],
+    less_than: Optional[int],
+) -> None:
+    if more_than is None or less_than is None:
+        return None
+    if more_than >= less_than:
+        raise ActiveLearningConfigurationError(
+            f"Misconfiguration of detections number sampling: "
+            f"`more_than` parameter ({more_than}) >= `less_than` ({less_than})."
+        )
diff --git a/inference/core/active_learning/samplers/random.py b/inference/core/active_learning/samplers/random.py
new file mode 100644
index 000000000..42df157e5
--- /dev/null
+++ b/inference/core/active_learning/samplers/random.py
@@ -0,0 +1,37 @@
+import random
+from functools import partial
+from typing import Any, Dict
+
+import numpy as np
+
+from inference.core.active_learning.entities import (
+    Prediction,
+    PredictionType,
+    SamplingMethod,
+)
+from inference.core.exceptions import ActiveLearningConfigurationError
+
+
+def initialize_random_sampling(strategy_config: Dict[str, Any]) -> SamplingMethod:
+    try:
+        sample_function = partial(
+            sample_randomly,
+            traffic_percentage=strategy_config["traffic_percentage"],
+        )
+        return SamplingMethod(
+            name=strategy_config["name"],
+            sample=sample_function,
+        )
+    except KeyError as error:
+        raise ActiveLearningConfigurationError(
+            f"In configuration of `random_sampling` missing key detected: {error}."
+ ) from error + + +def sample_randomly( + image: np.ndarray, + prediction: Prediction, + prediction_type: PredictionType, + traffic_percentage: float, +) -> bool: + return random.random() < traffic_percentage diff --git a/inference/core/active_learning/sampling.py b/inference/core/active_learning/sampling.py deleted file mode 100644 index f639a7da9..000000000 --- a/inference/core/active_learning/sampling.py +++ /dev/null @@ -1,33 +0,0 @@ -import random -from functools import partial -from typing import Any, Dict - -import numpy as np - -from inference.core.active_learning.entities import ( - Prediction, - PredictionType, - SamplingMethod, -) - - -def initialize_random_sampling(strategy_config: Dict[str, Any]) -> SamplingMethod: - sample_function = partial( - sample_randomly, - traffic_percentage=strategy_config["traffic_percentage"], - ) - return SamplingMethod( - name=strategy_config["name"], - sample=sample_function, - ) - - -def sample_randomly( - image: np.ndarray, - prediction: Prediction, - prediction_type: PredictionType, - traffic_percentage: float, -) -> bool: - if random.random() >= traffic_percentage: - return False - return True diff --git a/inference/core/constants.py b/inference/core/constants.py index 383e36845..a17766247 100644 --- a/inference/core/constants.py +++ b/inference/core/constants.py @@ -1,3 +1,4 @@ CLASSIFICATION_TASK = "classification" OBJECT_DETECTION_TASK = "object-detection" INSTANCE_SEGMENTATION_TASK = "instance-segmentation" +KEYPOINTS_DETECTION_TASK = "keypoints-detection" diff --git a/inference/core/exceptions.py b/inference/core/exceptions.py index 0be636c2f..41f509e28 100644 --- a/inference/core/exceptions.py +++ b/inference/core/exceptions.py @@ -172,3 +172,7 @@ class PredictionFormatNotSupported(ActiveLearningError): class ActiveLearningConfigurationDecodingError(ActiveLearningError): pass + + +class ActiveLearningConfigurationError(ActiveLearningError): + pass diff --git a/inference/core/roboflow_api.py b/inference/core/roboflow_api.py index e6698eddd..b2b04bd30 100644 --- a/inference/core/roboflow_api.py +++ b/inference/core/roboflow_api.py @@ -212,15 +212,52 @@ def get_roboflow_active_learning_configuration( "sampling_strategies": [ { "name": "default_strategy", - "type": "random_sampling", + "type": "random", "traffic_percentage": 0.1, # float 0-1 - "tags": ["c", "d"], # Optional + "tags": ["random-traffic"], # Optional "limits": [ # Optional {"type": "minutely", "value": 10}, {"type": "hourly", "value": 100}, {"type": "daily", "value": 1000}, ], - } + }, + { + "name": "hard_examples", + "type": "close_to_threshold", + "threshold": 0.3, + "epsilon": 0.3, + "probability": 0.3, + "tags": ["hard-case"], + "limits": [ + {"type": "minutely", "value": 10}, + {"type": "hourly", "value": 100}, + {"type": "daily", "value": 1000}, + ], + }, + { + "name": "multiple_detections", + "type": "detections_number_based", + "probability": 0.2, + "more_than": 3, + "tags": ["crowded"], + "limits": [ + {"type": "minutely", "value": 10}, + {"type": "hourly", "value": 100}, + {"type": "daily", "value": 1000}, + ], + }, + { + "name": "underrepresented_classes", + "type": "classes_based", + "selected_class_names": ["cat"], + "probability": 1.0, + "tags": ["hard-classes"], + "limits": [ + {"type": "minutely", "value": 10}, + {"type": "hourly", "value": 100}, + {"type": "daily", "value": 1000}, + ], + }, ], "batching_strategy": { "batches_name_prefix": "al_batch", diff --git a/tests/inference/unit_tests/core/active_learning/samplers/__init__.py 
b/tests/inference/unit_tests/core/active_learning/samplers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py b/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py new file mode 100644 index 000000000..6013ba050 --- /dev/null +++ b/tests/inference/unit_tests/core/active_learning/samplers/test_close_to_threshold.py @@ -0,0 +1,794 @@ +from unittest import mock +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from inference.core.active_learning.samplers.close_to_threshold import ( + is_close_to_threshold, + count_detections_close_to_threshold, + detections_are_close_to_threshold, + multi_label_classification_prediction_is_close_to_threshold, + multi_class_classification_prediction_is_close_to_threshold, + class_to_be_excluded, + prediction_is_close_to_threshold, + is_prediction_a_stub, + sample_close_to_threshold, + initialize_close_to_threshold_sampling, +) +from inference.core.active_learning.samplers import close_to_threshold +from inference.core.constants import ( + CLASSIFICATION_TASK, + OBJECT_DETECTION_TASK, +) +from inference.core.exceptions import ActiveLearningConfigurationError + +OBJECT_DETECTION_PREDICTION = { + "predictions": [ + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.9, + "class": "a", + "class_id": 1, + }, + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.7, + "class": "b", + "class_id": 1, + }, + ] +} +INSTANCE_SEGMENTATION_PREDICTION = { + "predictions": [ + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.9, + "class": "a", + "class_id": 1, + "points": [{"x": 207.0453125, "y": 106.559375}], + }, + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.7, + "class": "b", + "class_id": 1, + "points": [{"x": 207.0453125, "y": 106.559375}], + }, + ] +} +KEYPOINTS_PREDICTION = { + "predictions": [ + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.9, + "class": "a", + "class_id": 1, + "keypoints": [ + { + "x": 207.0453125, + "y": 106.559375, + "confidence": 0.92, + "class_id": 1, + "class": "eye", + } + ], + }, + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.7, + "class": "b", + "class_id": 1, + "keypoints": [ + { + "x": 207.0453125, + "y": 106.559375, + "confidence": 0.92, + "class_id": 1, + "class": "eye", + } + ], + }, + ] +} + +MULTI_LABEL_CLASSIFICATION_PREDICTION = { + "image": {"width": 416, "height": 416}, + "predictions": { + "cat": {"confidence": 0.97}, + "dog": {"confidence": 0.03}, + }, + "predicted_classes": ["cat"], +} + +MULTI_CLASS_CLASSIFICATION_PREDICTION = { + "image": {"width": 3487, "height": 2039}, + "predictions": [ + {"class": "Ambulance", "class_id": 0, "confidence": 0.6}, + {"class": "Limousine", "class_id": 16, "confidence": 0.3}, + {"class": "Helicopter", "class_id": 15, "confidence": 0.1}, + ], + "top": "Ambulance", + "confidence": 0.6, +} + + +@pytest.mark.parametrize( + "value, threshold, epsilon", + [ + (0.45, 0.5, 0.06), + (0.55, 0.5, 0.06), + ], +) +def test_close_to_threshold_when_value_is_close( + value: float, threshold: float, epsilon: float +) -> None: + # when + result = is_close_to_threshold(value=value, threshold=threshold, epsilon=epsilon) + + # then + assert result is True + + +@pytest.mark.parametrize( + "value, threshold, epsilon", + [ + (0.44, 0.5, 
0.05),
+        (0.56, 0.5, 0.05),
+    ],
+)
+def test_close_to_threshold_when_value_is_not_close(
+    value: float, threshold: float, epsilon: float
+) -> None:
+    # when
+    result = is_close_to_threshold(value=value, threshold=threshold, epsilon=epsilon)
+
+    # then
+    assert result is False
+
+
+def test_class_to_be_excluded_when_classes_not_selected() -> None:
+    # when
+    result = class_to_be_excluded(class_name="some", selected_class_names=None)
+
+    # then
+    assert result is False
+
+
+def test_class_to_be_excluded_when_classes_selected_and_specific_class_matches() -> (
+    None
+):
+    # when
+    result = class_to_be_excluded(class_name="a", selected_class_names={"a", "b", "c"})
+
+    # then
+    assert result is False
+
+
+def test_class_to_be_excluded_when_classes_selected_and_specific_class_does_not_match() -> (
+    None
+):
+    # when
+    result = class_to_be_excluded(class_name="d", selected_class_names={"a", "b", "c"})
+
+    # then
+    assert result is True
+
+
+def test_count_detections_close_to_threshold_when_no_detections_in_prediction() -> None:
+    # given
+    prediction = {"predictions": []}
+
+    # when
+    result = count_detections_close_to_threshold(
+        prediction=prediction,
+        selected_class_names=None,
+        threshold=0.5,
+        epsilon=1.0,
+    )
+
+    # then
+    assert result == 0
+
+
+@pytest.mark.parametrize(
+    "prediction",
+    [
+        OBJECT_DETECTION_PREDICTION,
+        INSTANCE_SEGMENTATION_PREDICTION,
+        KEYPOINTS_PREDICTION,
+    ],
+)
+def test_count_detections_close_to_threshold_when_no_selected_class_names_pointed(
+    prediction: dict,
+) -> None:
+    # when
+    result = count_detections_close_to_threshold(
+        prediction=prediction,
+        selected_class_names=None,
+        threshold=0.5,
+        epsilon=1.0,
+    )
+
+    # then
+    assert result == 2
+
+
+@pytest.mark.parametrize(
+    "prediction",
+    [
+        OBJECT_DETECTION_PREDICTION,
+        INSTANCE_SEGMENTATION_PREDICTION,
+        KEYPOINTS_PREDICTION,
+    ],
+)
+def test_count_detections_close_to_threshold_when_selected_class_names_filter_out_predictions(
+    prediction: dict,
+) -> None:
+    # when
+    result = count_detections_close_to_threshold(
+        prediction=prediction,
+        selected_class_names={"a", "c"},
+        threshold=0.5,
+        epsilon=1.0,
+    )
+
+    # then
+    assert result == 1
+
+
+@pytest.mark.parametrize(
+    "prediction",
+    [
+        OBJECT_DETECTION_PREDICTION,
+        INSTANCE_SEGMENTATION_PREDICTION,
+        KEYPOINTS_PREDICTION,
+    ],
+)
+def test_count_detections_close_to_threshold_when_threshold_filters_out_predictions(
+    prediction: dict,
+) -> None:
+    # when
+    result = count_detections_close_to_threshold(
+        prediction=prediction,
+        selected_class_names=None,
+        threshold=0.6,
+        epsilon=0.15,
+    )
+
+    # then
+    assert result == 1
+
+
+@pytest.mark.parametrize(
+    "prediction",
+    [
+        OBJECT_DETECTION_PREDICTION,
+        INSTANCE_SEGMENTATION_PREDICTION,
+        KEYPOINTS_PREDICTION,
+    ],
+)
+def test_detection_prediction_is_close_to_threshold_when_minimum_objects_criterion_met(
+    prediction: dict,
+) -> None:
+    # when
+    result = detections_are_close_to_threshold(
+        prediction=prediction,
+        selected_class_names=None,
+        threshold=0.6,
+        epsilon=0.15,
+        minimum_objects_close_to_threshold=1,
+    )
+
+    # then
+    assert result is True
+
+
+@pytest.mark.parametrize(
+    "prediction",
+    [
+        OBJECT_DETECTION_PREDICTION,
+        INSTANCE_SEGMENTATION_PREDICTION,
+        KEYPOINTS_PREDICTION,
+    ],
+)
+def test_detection_prediction_is_close_to_threshold_when_minimum_objects_criterion_not_met(
+    prediction: dict,
+) -> None:
+    # when
+    result = detections_are_close_to_threshold(
+        prediction=prediction,
+        selected_class_names=None,
+        threshold=0.6,
+        epsilon=0.15,
+        
minimum_objects_close_to_threshold=2, + ) + + # then + assert result is False + + +def test_multi_label_classification_prediction_is_close_to_threshold_when_top_class_meet_criteria() -> ( + None +): + # when + result = multi_label_classification_prediction_is_close_to_threshold( + prediction=MULTI_LABEL_CLASSIFICATION_PREDICTION, + selected_class_names=None, + threshold=0.95, + epsilon=0.05, + only_top_classes=True, + ) + + # then + assert result is True + + +def test_multi_label_classification_prediction_is_close_to_threshold_when_non_top_class_meet_threshold_but_filtered_out_by_top_classes() -> ( + None +): + # when + result = multi_label_classification_prediction_is_close_to_threshold( + prediction=MULTI_LABEL_CLASSIFICATION_PREDICTION, + selected_class_names=None, + threshold=0.05, + epsilon=0.05, + only_top_classes=True, + ) + + # then + assert result is False + + +def test_multi_label_classification_prediction_is_close_to_threshold_when_non_top_class_meet_threshold_but_filtered_out_by_class_names() -> ( + None +): + # when + result = multi_label_classification_prediction_is_close_to_threshold( + prediction=MULTI_LABEL_CLASSIFICATION_PREDICTION, + selected_class_names={"cat", "tiger"}, + threshold=0.05, + epsilon=0.05, + only_top_classes=False, + ) + + # then + assert result is False + + +def test_multi_label_classification_prediction_is_close_to_threshold_when_non_top_class_meet_criteria() -> ( + None +): + # when + result = multi_label_classification_prediction_is_close_to_threshold( + prediction=MULTI_LABEL_CLASSIFICATION_PREDICTION, + selected_class_names=None, + threshold=0.05, + epsilon=0.05, + only_top_classes=False, + ) + + # then + assert result is True + + +def test_multi_class_classification_prediction_is_close_to_threshold_for_top_class_when_classes_are_not_selected_and_threshold_met() -> ( + None +): + # when + result = multi_class_classification_prediction_is_close_to_threshold( + prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION, + selected_class_names=None, + threshold=0.6, + epsilon=0.1, + only_top_classes=True, + ) + + # then + assert result is True + + +def test_multi_class_classification_prediction_is_close_to_threshold_for_top_class_when_classes_are_not_selected_and_threshold_not_met() -> ( + None +): + # when + result = multi_class_classification_prediction_is_close_to_threshold( + prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION, + selected_class_names=None, + threshold=0.8, + epsilon=0.1, + only_top_classes=True, + ) + + # then + assert result is False + + +def test_multi_class_classification_prediction_is_close_to_threshold_for_top_class_when_classes_are_selected_and_top_class_does_not_match() -> ( + None +): + # when + result = multi_class_classification_prediction_is_close_to_threshold( + prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION, + selected_class_names={"Limousine", "Helicopter"}, + threshold=0.6, + epsilon=0.1, + only_top_classes=True, + ) + + # then + assert result is False + + +def test_multi_class_classification_prediction_is_close_to_threshold_for_top_class_when_classes_are_selected_and_top_class_matches() -> ( + None +): + # when + result = multi_class_classification_prediction_is_close_to_threshold( + prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION, + selected_class_names={"Ambulance"}, + threshold=0.6, + epsilon=0.1, + only_top_classes=True, + ) + + # then + assert result is True + + +def test_multi_class_classification_prediction_is_close_to_threshold_not_only_for_top_class_when_classes_not_selected() -> ( + None +): + # when + result = 
multi_class_classification_prediction_is_close_to_threshold(
+        prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION,
+        selected_class_names=None,
+        threshold=0.3,
+        epsilon=0.1,
+        only_top_classes=False,
+    )
+
+    # then
+    assert result is True
+
+
+def test_multi_class_classification_prediction_is_close_to_threshold_not_only_for_top_class_when_classes_not_selected_and_no_match_expected() -> (
+    None
+):
+    # when
+    result = multi_class_classification_prediction_is_close_to_threshold(
+        prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION,
+        selected_class_names=None,
+        threshold=0.5,
+        epsilon=0.05,
+        only_top_classes=False,
+    )
+
+    # then
+    assert result is False
+
+
+def test_multi_class_classification_prediction_is_close_to_threshold_not_only_for_top_class_when_classes_are_selected() -> (
+    None
+):
+    # when
+    result = multi_class_classification_prediction_is_close_to_threshold(
+        prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION,
+        selected_class_names={"Ambulance", "Helicopter"},
+        threshold=0.3,
+        epsilon=0.1,
+        only_top_classes=False,
+    )
+
+    # then
+    assert result is False
+
+
+def test_prediction_is_close_to_threshold_for_detection_prediction() -> None:
+    # when
+    result = prediction_is_close_to_threshold(
+        prediction=OBJECT_DETECTION_PREDICTION,
+        prediction_type=OBJECT_DETECTION_TASK,
+        selected_class_names=None,
+        threshold=0.9,
+        epsilon=0.05,
+        only_top_classes=False,
+        minimum_objects_close_to_threshold=1,
+    )
+
+    # then
+    assert result is True
+
+
+def test_prediction_is_close_to_threshold_for_classification_prediction() -> None:
+    # when
+    result = prediction_is_close_to_threshold(
+        prediction=MULTI_LABEL_CLASSIFICATION_PREDICTION,
+        prediction_type=CLASSIFICATION_TASK,
+        selected_class_names=None,
+        threshold=0.05,
+        epsilon=0.05,
+        only_top_classes=False,
+        minimum_objects_close_to_threshold=1,
+    )
+
+    # then
+    assert result is True
+
+
+@pytest.mark.parametrize(
+    "prediction",
+    [
+        OBJECT_DETECTION_PREDICTION,
+        KEYPOINTS_PREDICTION,
+        INSTANCE_SEGMENTATION_PREDICTION,
+        MULTI_CLASS_CLASSIFICATION_PREDICTION,
+        MULTI_LABEL_CLASSIFICATION_PREDICTION,
+    ],
+)
+def test_is_prediction_a_stub_when_prediction_is_not_a_stub(prediction: dict) -> None:
+    # when
+    result = is_prediction_a_stub(prediction=prediction)
+
+    # then
+    assert result is False
+
+
+def test_is_prediction_a_stub_when_prediction_is_a_stub() -> None:
+    # given
+    prediction = {"is_stub": True}
+
+    # when
+    result = is_prediction_a_stub(prediction=prediction)
+
+    # then
+    assert result is True
+
+
+def test_sample_close_to_threshold_when_prediction_is_stub() -> None:
+    # when
+    result = sample_close_to_threshold(
+        image=np.zeros((128, 128, 3), dtype=np.uint8),
+        prediction={"is_stub": True},
+        prediction_type=CLASSIFICATION_TASK,
+        selected_class_names=None,
+        threshold=0.5,
+        epsilon=0.1,
+        only_top_classes=True,
+        minimum_objects_close_to_threshold=1,
+        probability=1.0,
+    )
+
+    # then
+    assert result is False
+
+
+def test_sample_close_to_threshold_when_prediction_type_is_unknown() -> None:
+    # when
+    result = sample_close_to_threshold(
+        image=np.zeros((128, 128, 3), dtype=np.uint8),
+        prediction={"is_stub": True},
+        prediction_type="unknown",
+        selected_class_names=None,
+        threshold=0.5,
+        epsilon=0.1,
+        only_top_classes=True,
+        minimum_objects_close_to_threshold=1,
+        probability=1.0,
+    )
+
+    # then
+    assert result is False
+
+
+def test_sample_close_to_threshold_when_prediction_is_not_close_to_threshold() -> None:
+    # when
+    result = sample_close_to_threshold(
+        image=np.zeros((128, 128, 3), dtype=np.uint8),
+        
prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION, + prediction_type=CLASSIFICATION_TASK, + selected_class_names={"Ambulance"}, + threshold=0.8, + epsilon=0.1, + only_top_classes=True, + minimum_objects_close_to_threshold=1, + probability=1.0, + ) + + # then + assert result is False + + +@mock.patch.object(close_to_threshold.random, "random") +def test_sample_close_to_threshold_when_prediction_is_close_to_threshold( + random_mock: MagicMock, +) -> None: + # given + random_mock.return_value = 0.49 + + # when + result = sample_close_to_threshold( + image=np.zeros((128, 128, 3), dtype=np.uint8), + prediction=MULTI_CLASS_CLASSIFICATION_PREDICTION, + prediction_type=CLASSIFICATION_TASK, + selected_class_names={"Ambulance"}, + threshold=0.6, + epsilon=0.1, + only_top_classes=True, + minimum_objects_close_to_threshold=1, + probability=0.5, + ) + + # then + assert result is True + + +def test_initialize_close_to_threshold_sampling() -> None: + # given + strategy_config = { + "name": "ambulance_high_confidence", + "selected_class_names": ["Ambulance"], + "threshold": 0.75, + "epsilon": 0.25, + "probability": 1.0, + } + + # when + sampling_method = initialize_close_to_threshold_sampling( + strategy_config=strategy_config + ) + result = sampling_method.sample( + np.zeros((128, 128, 3), dtype=np.uint8), + MULTI_CLASS_CLASSIFICATION_PREDICTION, + CLASSIFICATION_TASK, + ) + + # then + assert result is True + assert sampling_method.name == "ambulance_high_confidence" + + +@mock.patch.object(close_to_threshold, "partial") +def test_initialize_close_to_threshold_sampling_when_classes_not_selected( + partial_mock: MagicMock, +) -> None: + # given + strategy_config = { + "name": "ambulance_high_confidence", + "threshold": 0.75, + "epsilon": 0.25, + "probability": 0.6, + } + + # when + _ = initialize_close_to_threshold_sampling(strategy_config=strategy_config) + + # then + partial_mock.assert_called_once_with( + sample_close_to_threshold, + selected_class_names=None, + threshold=0.75, + epsilon=0.25, + only_top_classes=False, + minimum_objects_close_to_threshold=1, + probability=0.6, + ) + + +@mock.patch.object(close_to_threshold, "partial") +def test_initialize_close_to_threshold_sampling_when_classes_selected( + partial_mock: MagicMock, +) -> None: + # given + strategy_config = { + "name": "ambulance_high_confidence", + "selected_class_names": ["Ambulance", "Helicopter"], + "threshold": 0.75, + "epsilon": 0.25, + "probability": 0.6, + } + + # when + _ = initialize_close_to_threshold_sampling(strategy_config=strategy_config) + + # then + partial_mock.assert_called_once_with( + sample_close_to_threshold, + selected_class_names={"Ambulance", "Helicopter"}, + threshold=0.75, + epsilon=0.25, + only_top_classes=False, + minimum_objects_close_to_threshold=1, + probability=0.6, + ) + + +@mock.patch.object(close_to_threshold, "partial") +def test_initialize_close_to_threshold_sampling_when_only_top_classes_mode_enabled( + partial_mock: MagicMock, +) -> None: + # given + strategy_config = { + "name": "ambulance_high_confidence", + "selected_class_names": ["Ambulance"], + "threshold": 0.75, + "epsilon": 0.25, + "probability": 0.6, + "only_top_classes": True, + } + + # when + _ = initialize_close_to_threshold_sampling(strategy_config=strategy_config) + + # then + partial_mock.assert_called_once_with( + sample_close_to_threshold, + selected_class_names={"Ambulance"}, + threshold=0.75, + epsilon=0.25, + only_top_classes=True, + minimum_objects_close_to_threshold=1, + probability=0.6, + ) + + 
+@mock.patch.object(close_to_threshold, "partial") +def test_initialize_close_to_threshold_sampling_when_objects_close_to_threshold_specified( + partial_mock: MagicMock, +) -> None: + # given + strategy_config = { + "name": "ambulance_high_confidence", + "selected_class_names": ["Ambulance"], + "threshold": 0.75, + "epsilon": 0.25, + "probability": 0.6, + "minimum_objects_close_to_threshold": 6, + } + + # when + _ = initialize_close_to_threshold_sampling(strategy_config=strategy_config) + + # then + partial_mock.assert_called_once_with( + sample_close_to_threshold, + selected_class_names={"Ambulance"}, + threshold=0.75, + epsilon=0.25, + only_top_classes=False, + minimum_objects_close_to_threshold=6, + probability=0.6, + ) + + +def test_initialize_close_to_threshold_sampling_when_configuration_key_missing() -> ( + None +): + # given + strategy_config = { + "name": "ambulance_high_confidence", + "threshold": 0.75, + "epsilon": 0.25, + } + + # when + with pytest.raises(ActiveLearningConfigurationError): + _ = initialize_close_to_threshold_sampling(strategy_config=strategy_config) diff --git a/tests/inference/unit_tests/core/active_learning/samplers/test_contains_classes.py b/tests/inference/unit_tests/core/active_learning/samplers/test_contains_classes.py new file mode 100644 index 000000000..c6291bfbe --- /dev/null +++ b/tests/inference/unit_tests/core/active_learning/samplers/test_contains_classes.py @@ -0,0 +1,139 @@ +from unittest import mock +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from inference.core.active_learning.samplers.contains_classes import ( + sample_based_on_classes, + initialize_classes_based_sampling, +) +from inference.core.active_learning.samplers import contains_classes +from inference.core.constants import OBJECT_DETECTION_TASK, CLASSIFICATION_TASK +from inference.core.exceptions import ActiveLearningConfigurationError + + +MULTI_LABEL_CLASSIFICATION_PREDICTION = { + "image": {"width": 416, "height": 416}, + "predictions": { + "cat": {"confidence": 0.97}, + "dog": {"confidence": 0.03}, + }, + "predicted_classes": ["cat"], +} + + +def test_sample_based_on_classes_for_detection_predictions() -> None: + # given + prediction = { + "predictions": [ + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.9, + "class": "a", + "class_id": 1, + } + ] + } + + # when + result = sample_based_on_classes( + image=np.zeros((128, 128, 3), dtype=np.uint8), + prediction=prediction, + prediction_type=OBJECT_DETECTION_TASK, + selected_class_names={"a", "b", "c"}, + probability=1.0, + ) + + # then + assert result is False + + +def test_sample_based_on_classes_for_classification_prediction_when_classes_detected() -> ( + None +): + # when + result = sample_based_on_classes( + image=np.zeros((128, 128, 3), dtype=np.uint8), + prediction=MULTI_LABEL_CLASSIFICATION_PREDICTION, + prediction_type=CLASSIFICATION_TASK, + selected_class_names={"cat"}, + probability=1.0, + ) + + # then + assert result is True + + +def test_sample_based_on_classes_for_classification_prediction_when_classes_not_detected() -> ( + None +): + # when + result = sample_based_on_classes( + image=np.zeros((128, 128, 3), dtype=np.uint8), + prediction=MULTI_LABEL_CLASSIFICATION_PREDICTION, + prediction_type=CLASSIFICATION_TASK, + selected_class_names={"dog"}, + probability=1.0, + ) + + # then + assert result is False + + +def test_initialize_classes_based_sampling() -> None: + # given + strategy_config = { + "name": "detect_dogs", + "probability": 1.0, + 
"selected_class_names": ["dog"], + } + + # when + sampling_method = initialize_classes_based_sampling(strategy_config=strategy_config) + result = sampling_method.sample( + np.zeros((128, 128, 3), dtype=np.uint8), + MULTI_LABEL_CLASSIFICATION_PREDICTION, + CLASSIFICATION_TASK, + ) + + # then + assert result is False + + +@mock.patch.object(contains_classes, "partial") +def test_initialize_classes_based_sampling_against_parameters_correctness( + partial_mock: MagicMock, +) -> None: + # given + strategy_config = { + "name": "detect_dogs", + "probability": 1.0, + "selected_class_names": ["dog"], + } + + # when + _ = initialize_classes_based_sampling(strategy_config=strategy_config) + + # then + partial_mock.assert_called_once_with( + sample_based_on_classes, + selected_class_names={"dog"}, + probability=1.0, + ) + + +def test_initialize_classes_based_sampling_when_configuration_key_missing() -> None: + # given + strategy_config = { + "name": "detect_dogs", + "selected_class_names": ["dog"], + "minimum_objects": 10, + } + + # when + with pytest.raises(ActiveLearningConfigurationError): + _ = initialize_classes_based_sampling(strategy_config=strategy_config) diff --git a/tests/inference/unit_tests/core/active_learning/samplers/test_number_of_detections.py b/tests/inference/unit_tests/core/active_learning/samplers/test_number_of_detections.py new file mode 100644 index 000000000..2f7721640 --- /dev/null +++ b/tests/inference/unit_tests/core/active_learning/samplers/test_number_of_detections.py @@ -0,0 +1,286 @@ +from typing import Optional +from unittest import mock +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from inference.core.active_learning.samplers.number_of_detections import ( + is_in_range, + sample_based_on_detections_number, + initialize_detections_number_based_sampling, +) +from inference.core.active_learning.samplers import number_of_detections +from inference.core.constants import CLASSIFICATION_TASK, OBJECT_DETECTION_TASK +from inference.core.exceptions import ActiveLearningConfigurationError + +OBJECT_DETECTION_PREDICTION = { + "predictions": [ + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.9, + "class": "a", + "class_id": 1, + }, + { + "x": 784.5, + "y": 397.5, + "width": 187.0, + "height": 309.0, + "confidence": 0.7, + "class": "b", + "class_id": 1, + }, + ] +} + + +@pytest.mark.parametrize( + "value, more_than, less_than", + [ + (1, None, 2), + (2, 1, 5), + (5, 4, None), + (1, None, None), + ], +) +def test_is_in_range_value_meets_condition( + value: int, more_than: Optional[int], less_than: Optional[int] +) -> None: + # when + result = is_in_range(value=value, more_than=more_than, less_than=less_than) + + # then + assert result is True + + +@pytest.mark.parametrize( + "value, more_than, less_than", + [ + (1, 2, None), + (2, 5, 1), + (5, None, 4), + ], +) +def test_is_in_range_value_does_not_meet_condition( + value: int, more_than: Optional[int], less_than: Optional[int] +) -> None: + # when + result = is_in_range(value=value, more_than=more_than, less_than=less_than) + + # then + assert result is False + + +def test_sample_based_on_detections_number_when_classification_prediction_given() -> ( + None +): + # given + prediction = { + "image": {"width": 416, "height": 416}, + "predictions": { + "cat": {"confidence": 0.97}, + "dog": {"confidence": 0.03}, + }, + "predicted_classes": ["cat"], + } + + # when + result = sample_based_on_detections_number( + image=np.zeros((128, 128, 3), dtype=np.uint8), + 
prediction=prediction,
+        prediction_type=CLASSIFICATION_TASK,
+        less_than=None,
+        more_than=None,
+        selected_class_names=None,
+        probability=1.0,
+    )
+
+    # then
+    assert result is False
+
+
+def test_sample_based_on_detections_number_when_stub_prediction_given() -> None:
+    # when
+    result = sample_based_on_detections_number(
+        image=np.zeros((128, 128, 3), dtype=np.uint8),
+        prediction={"is_stub": True},
+        prediction_type=CLASSIFICATION_TASK,
+        less_than=None,
+        more_than=None,
+        selected_class_names=None,
+        probability=1.0,
+    )
+
+    # then
+    assert result is False
+
+
+@mock.patch.object(number_of_detections.random, "random")
+def test_sample_based_on_detections_number_when_detections_in_range_and_sampling_succeeds(
+    random_mock: MagicMock,
+) -> None:
+    # given
+    random_mock.return_value = 0.29
+
+    # when
+    result = sample_based_on_detections_number(
+        image=np.zeros((128, 128, 3), dtype=np.uint8),
+        prediction=OBJECT_DETECTION_PREDICTION,
+        prediction_type=OBJECT_DETECTION_TASK,
+        more_than=1,
+        less_than=3,
+        selected_class_names={"a", "b"},
+        probability=0.3,
+    )
+
+    # then
+    assert result is True
+
+
+@mock.patch.object(number_of_detections.random, "random")
+def test_sample_based_on_detections_number_when_detections_in_range_and_sampling_fails(
+    random_mock: MagicMock,
+) -> None:
+    # given
+    random_mock.return_value = 0.31
+
+    # when
+    result = sample_based_on_detections_number(
+        image=np.zeros((128, 128, 3), dtype=np.uint8),
+        prediction=OBJECT_DETECTION_PREDICTION,
+        prediction_type=OBJECT_DETECTION_TASK,
+        more_than=1,
+        less_than=3,
+        selected_class_names={"a", "b"},
+        probability=0.3,
+    )
+
+    # then
+    assert result is False
+
+
+def test_sample_based_on_detections_number_when_detections_not_in_range() -> None:
+    # when
+    result = sample_based_on_detections_number(
+        image=np.zeros((128, 128, 3), dtype=np.uint8),
+        prediction=OBJECT_DETECTION_PREDICTION,
+        prediction_type=OBJECT_DETECTION_TASK,
+        more_than=2,
+        less_than=3,
+        selected_class_names={"a"},
+        probability=0.3,
+    )
+
+    # then
+    assert result is False
+
+
+def test_initialize_detections_number_based_sampling() -> None:
+    # given
+    strategy_config = {
+        "name": "two_detections",
+        "less_than": 3,
+        "more_than": 1,
+        "selected_class_names": {"a", "b"},
+        "probability": 1.0,
+    }
+
+    # when
+    sampling_method = initialize_detections_number_based_sampling(
+        strategy_config=strategy_config
+    )
+    result = sampling_method.sample(
+        np.zeros((128, 128, 3), dtype=np.uint8),
+        OBJECT_DETECTION_PREDICTION,
+        OBJECT_DETECTION_TASK,
+    )
+
+    # then
+    assert result is True
+    assert sampling_method.name == "two_detections"
+
+
+@mock.patch.object(number_of_detections, "partial")
+def test_initialize_detections_number_based_sampling_when_optional_values_not_given(
+    partial_mock: MagicMock,
+) -> None:
+    # given
+    strategy_config = {
+        "name": "two_detections",
+        "probability": 1.0,
+    }
+
+    # when
+    _ = initialize_detections_number_based_sampling(strategy_config=strategy_config)
+
+    # then
+    partial_mock.assert_called_once_with(
+        sample_based_on_detections_number,
+        less_than=None,
+        more_than=None,
+        selected_class_names=None,
+        probability=1.0,
+    )
+
+
+@mock.patch.object(number_of_detections, "partial")
+def test_initialize_detections_number_based_sampling_when_optional_values_given(
+    partial_mock: MagicMock,
+) -> None:
+    # given
+    strategy_config = {
+        "name": "two_detections",
+        "probability": 1.0,
+        "less_than": 10,
+        "more_than": 5,
+        "selected_class_names": ["a", "b"],
+    }
+
+    # when
+    _ = 
initialize_detections_number_based_sampling(strategy_config=strategy_config) + + # then + partial_mock.assert_called_once_with( + sample_based_on_detections_number, + less_than=10, + more_than=5, + selected_class_names={"a", "b"}, + probability=1.0, + ) + + +def test_initialize_detections_number_based_sampling_when_required_value_missing() -> ( + None +): + # given + strategy_config = { + "name": "two_detections", + "less_than": 10, + "more_than": 5, + "selected_class_names": ["a", "b"], + } + + # when + with pytest.raises(ActiveLearningConfigurationError): + _ = initialize_detections_number_based_sampling(strategy_config=strategy_config) + + +def test_initialize_detections_number_based_sampling_when_malformed_config_given() -> ( + None +): + # given + strategy_config = { + "name": "two_detections", + "less_than": 5, + "more_than": 6, + "probability": 1.0, + "selected_class_names": ["a", "b"], + } + + # when + with pytest.raises(ActiveLearningConfigurationError): + _ = initialize_detections_number_based_sampling(strategy_config=strategy_config) diff --git a/tests/inference/unit_tests/core/active_learning/test_sampling.py b/tests/inference/unit_tests/core/active_learning/samplers/test_random.py similarity index 84% rename from tests/inference/unit_tests/core/active_learning/test_sampling.py rename to tests/inference/unit_tests/core/active_learning/samplers/test_random.py index 6ca7c1662..22b74567d 100644 --- a/tests/inference/unit_tests/core/active_learning/test_sampling.py +++ b/tests/inference/unit_tests/core/active_learning/samplers/test_random.py @@ -4,11 +4,12 @@ import numpy as np import pytest -from inference.core.active_learning.sampling import initialize_random_sampling -from inference.core.active_learning import sampling +from inference.core.active_learning.samplers.random import initialize_random_sampling +from inference.core.active_learning.samplers import random +from inference.core.exceptions import ActiveLearningConfigurationError -@mock.patch.object(sampling.random, "random") +@mock.patch.object(random.random, "random") def test_initialize_random_sampling_when_config_is_valid( random_mock: MagicMock, ) -> None: @@ -57,7 +58,7 @@ def test_initialize_random_sampling_when_strategy_name_is_not_present() -> None: } # when - with pytest.raises(KeyError): + with pytest.raises(ActiveLearningConfigurationError): _ = initialize_random_sampling(strategy_config=strategy_config) @@ -75,5 +76,5 @@ def test_initialize_random_sampling_when_traffic_percentage_is_not_present() -> } # when - with pytest.raises(KeyError): + with pytest.raises(ActiveLearningConfigurationError): _ = initialize_random_sampling(strategy_config=strategy_config) diff --git a/tests/inference/unit_tests/core/active_learning/test_configuration.py b/tests/inference/unit_tests/core/active_learning/test_configuration.py index 58eafcc13..cd0f33e03 100644 --- a/tests/inference/unit_tests/core/active_learning/test_configuration.py +++ b/tests/inference/unit_tests/core/active_learning/test_configuration.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock import numpy as np +import pytest from inference.core.active_learning.configuration import ( initialize_sampling_methods, @@ -18,6 +19,7 @@ StrategyLimit, StrategyLimitType, ) +from inference.core.exceptions import ActiveLearningConfigurationError def test_initialize_sampling_methods() -> None: @@ -25,9 +27,9 @@ def test_initialize_sampling_methods() -> None: sampling_strategies_configs = [ { "name": "default_strategy", - "type": "random_sampling", + "type": "random", 
"traffic_percentage": 0.5, - "tags": ["c", "d"], + "tags": ["a"], "limits": [ {"type": "minutely", "value": 10}, {"type": "hourly", "value": 100}, @@ -35,6 +37,44 @@ def test_initialize_sampling_methods() -> None: ], }, {"type": "non-existing"}, + { + "name": "hard_examples", + "type": "close_to_threshold", + "selected_class_names": ["a", "b"], + "threshold": 0.25, + "epsilon": 0.1, + "probability": 0.5, + "tags": ["b"], + "limits": [ + {"type": "minutely", "value": 10}, + {"type": "hourly", "value": 100}, + {"type": "daily", "value": 1000}, + ], + }, + { + "name": "underrepresented_classes", + "type": "classes_based", + "selected_class_names": ["a"], + "probability": 0.5, + "tags": ["hard_classes"], + "limits": [ + {"type": "minutely", "value": 10}, + {"type": "hourly", "value": 100}, + {"type": "daily", "value": 1000}, + ], + }, + { + "name": "low_detections", + "type": "detections_number_based", + "probability": 0.5, + "less_than": 3, + "tags": ["empty"], + "limits": [ + {"type": "minutely", "value": 10}, + {"type": "hourly", "value": 100}, + {"type": "daily", "value": 1000}, + ], + }, ] # when @@ -43,13 +83,19 @@ def test_initialize_sampling_methods() -> None: ) # then - assert len(result) == 1 - assert result[0].name == "default_strategy" - _ = result[0].sample( # test if sampling executed correctly - np.zeros((128, 128, 3), dtype=np.ndarray), - {"some": "prediction"}, - "object-detection", - ) + assert len(result) == 4 + assert [r.name for r in result] == [ + "default_strategy", + "hard_examples", + "underrepresented_classes", + "low_detections", + ] + for strategy in result: + _ = strategy.sample( # test if sampling executed correctly + np.zeros((128, 128, 3), dtype=np.ndarray), + {"is_stub": "True"}, + "object-detection", + ) @mock.patch.object(configuration, "get_roboflow_active_learning_configuration") @@ -148,7 +194,7 @@ def test_prepare_active_learning_configuration_when_active_learning_enabled( "sampling_strategies": [ { "name": "default_strategy", - "type": "random_sampling", + "type": "random", "traffic_percentage": 0.1, "tags": ["c", "d"], "limits": [ @@ -200,3 +246,39 @@ def test_prepare_active_learning_configuration_when_active_learning_enabled( tags=["a", "b"], strategies_tags={"default_strategy": ["c", "d"]}, ) + + +def test_test_initialize_sampling_methods_when_duplicate_names_detected() -> None: + # given + sampling_strategies_configs = [ + { + "name": "default_strategy", + "type": "random", + "traffic_percentage": 0.5, + "tags": ["a"], + "limits": [ + {"type": "minutely", "value": 10}, + {"type": "hourly", "value": 100}, + {"type": "daily", "value": 1000}, + ], + }, + { + "name": "default_strategy", + "type": "close_to_threshold", + "selected_class_names": ["a", "b"], + "threshold": 0.25, + "epsilon": 0.1, + "probability": 0.5, + "tags": ["b"], + "limits": [ + {"type": "minutely", "value": 10}, + {"type": "hourly", "value": 100}, + {"type": "daily", "value": 1000}, + ], + }, + ] + + with pytest.raises(ActiveLearningConfigurationError): + _ = initialize_sampling_methods( + sampling_strategies_configs=sampling_strategies_configs + )