Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce additional active learning sampling strategies #148

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions inference/core/active_learning/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,31 @@
RoboflowProjectMetadata,
SamplingMethod,
)
from inference.core.active_learning.sampling import initialize_random_sampling
from inference.core.cache.base import BaseCache
from inference.core.active_learning.samplers.close_to_threshold import (
initialize_close_to_threshold_sampling,
)
from inference.core.active_learning.samplers.contains_classes import (
initialize_classes_based_sampling,
)
from inference.core.active_learning.samplers.number_of_detections import (
initialize_detections_number_based_sampling,
)
from inference.core.active_learning.samplers.random import initialize_random_sampling
from inference.core.env import ACTIVE_LEARNING_ENABLED
from inference.core.exceptions import ActiveLearningConfigurationError
from inference.core.roboflow_api import (
get_roboflow_active_learning_configuration,
get_roboflow_dataset_type,
get_roboflow_workspace,
)
from inference.core.utils.roboflow import get_model_id_chunks

TYPE2SAMPLING_INITIALIZERS = {"random_sampling": initialize_random_sampling}
TYPE2SAMPLING_INITIALIZERS = {
"random": initialize_random_sampling,
"close_to_threshold": initialize_close_to_threshold_sampling,
"classes_based": initialize_classes_based_sampling,
"detections_number_based": initialize_detections_number_based_sampling,
}


def prepare_active_learning_configuration(
Expand Down Expand Up @@ -93,4 +107,9 @@ def initialize_sampling_methods(
continue
initializer = TYPE2SAMPLING_INITIALIZERS[sampling_type]
result.append(initializer(sampling_strategy_config))
names = set(m.name for m in result)
if len(names) != len(result):
raise ActiveLearningConfigurationError(
"Detected duplication of Active Learning strategies names."
)
return result
Empty file.
227 changes: 227 additions & 0 deletions inference/core/active_learning/samplers/close_to_threshold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
import random
from functools import partial
from typing import Any, Dict, Optional, Set

import numpy as np

from inference.core.active_learning.entities import (
Prediction,
PredictionType,
SamplingMethod,
)
from inference.core.constants import (
CLASSIFICATION_TASK,
INSTANCE_SEGMENTATION_TASK,
KEYPOINTS_DETECTION_TASK,
OBJECT_DETECTION_TASK,
)
from inference.core.exceptions import ActiveLearningConfigurationError

ELIGIBLE_PREDICTION_TYPES = {
CLASSIFICATION_TASK,
INSTANCE_SEGMENTATION_TASK,
KEYPOINTS_DETECTION_TASK,
OBJECT_DETECTION_TASK,
}


def initialize_close_to_threshold_sampling(
strategy_config: Dict[str, Any]
) -> SamplingMethod:
try:
selected_class_names = strategy_config.get("selected_class_names")
if selected_class_names is not None:
selected_class_names = set(selected_class_names)
sample_function = partial(
sample_close_to_threshold,
selected_class_names=selected_class_names,
threshold=strategy_config["threshold"],
epsilon=strategy_config["epsilon"],
only_top_classes=strategy_config.get("only_top_classes", False),
minimum_objects_close_to_threshold=strategy_config.get(
"minimum_objects_close_to_threshold",
1,
),
probability=strategy_config["probability"],
)
return SamplingMethod(
name=strategy_config["name"],
sample=sample_function,
)
except KeyError as error:
raise ActiveLearningConfigurationError(
f"In configuration of `close_to_threshold_sampling` missing key detected: {error}."
) from error


def sample_close_to_threshold(
image: np.ndarray,
prediction: Prediction,
prediction_type: PredictionType,
selected_class_names: Optional[Set[str]],
threshold: float,
epsilon: float,
only_top_classes: bool,
minimum_objects_close_to_threshold: int,
probability: float,
) -> bool:
if is_prediction_a_stub(prediction=prediction):
return False
if prediction_type not in ELIGIBLE_PREDICTION_TYPES:
return False
close_to_threshold = prediction_is_close_to_threshold(
prediction=prediction,
prediction_type=prediction_type,
selected_class_names=selected_class_names,
threshold=threshold,
epsilon=epsilon,
only_top_classes=only_top_classes,
minimum_objects_close_to_threshold=minimum_objects_close_to_threshold,
)
if not close_to_threshold:
return False
return random.random() < probability


def is_prediction_a_stub(prediction: Prediction) -> bool:
return prediction.get("is_stub", False)


def prediction_is_close_to_threshold(
prediction: Prediction,
prediction_type: PredictionType,
selected_class_names: Optional[Set[str]],
threshold: float,
epsilon: float,
only_top_classes: bool,
minimum_objects_close_to_threshold: int,
) -> bool:
if CLASSIFICATION_TASK not in prediction_type:
return detections_are_close_to_threshold(
prediction=prediction,
selected_class_names=selected_class_names,
threshold=threshold,
epsilon=epsilon,
minimum_objects_close_to_threshold=minimum_objects_close_to_threshold,
)
checker = multi_label_classification_prediction_is_close_to_threshold
if "top" in prediction:
checker = multi_class_classification_prediction_is_close_to_threshold
return checker(
prediction=prediction,
selected_class_names=selected_class_names,
threshold=threshold,
epsilon=epsilon,
only_top_classes=only_top_classes,
)


def multi_class_classification_prediction_is_close_to_threshold(
prediction: Prediction,
selected_class_names: Optional[Set[str]],
threshold: float,
epsilon: float,
only_top_classes: bool,
) -> bool:
if only_top_classes:
return (
multi_class_classification_prediction_is_close_to_threshold_for_top_class(
prediction=prediction,
selected_class_names=selected_class_names,
threshold=threshold,
epsilon=epsilon,
)
)
for prediction_details in prediction["predictions"]:
if class_to_be_excluded(
class_name=prediction_details["class"],
selected_class_names=selected_class_names,
):
continue
if is_close_to_threshold(
value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon
):
return True
return False


def multi_class_classification_prediction_is_close_to_threshold_for_top_class(
prediction: Prediction,
selected_class_names: Optional[Set[str]],
threshold: float,
epsilon: float,
) -> bool:
if (
selected_class_names is not None
and prediction["top"] not in selected_class_names
):
return False
return abs(prediction["confidence"] - threshold) < epsilon


def multi_label_classification_prediction_is_close_to_threshold(
prediction: Prediction,
selected_class_names: Optional[Set[str]],
threshold: float,
epsilon: float,
only_top_classes: bool,
) -> bool:
predicted_classes = set(prediction["predicted_classes"])
for class_name, prediction_details in prediction["predictions"].items():
if only_top_classes and class_name not in predicted_classes:
continue
if class_to_be_excluded(
class_name=class_name, selected_class_names=selected_class_names
):
continue
if is_close_to_threshold(
value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon
):
return True
return False


def detections_are_close_to_threshold(
prediction: Prediction,
selected_class_names: Optional[Set[str]],
threshold: float,
epsilon: float,
minimum_objects_close_to_threshold: int,
) -> bool:
detections_close_to_threshold = count_detections_close_to_threshold(
prediction=prediction,
selected_class_names=selected_class_names,
threshold=threshold,
epsilon=epsilon,
)
return detections_close_to_threshold >= minimum_objects_close_to_threshold


def count_detections_close_to_threshold(
prediction: Prediction,
selected_class_names: Optional[Set[str]],
threshold: float,
epsilon: float,
) -> int:
counter = 0
for prediction_details in prediction["predictions"]:
if class_to_be_excluded(
class_name=prediction_details["class"],
selected_class_names=selected_class_names,
):
continue
if is_close_to_threshold(
value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon
):
counter += 1
return counter


def class_to_be_excluded(
class_name: str, selected_class_names: Optional[Set[str]]
) -> bool:
return selected_class_names is not None and class_name not in selected_class_names


def is_close_to_threshold(value: float, threshold: float, epsilon: float) -> bool:
return abs(value - threshold) < epsilon
58 changes: 58 additions & 0 deletions inference/core/active_learning/samplers/contains_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from functools import partial
from typing import Any, Dict, Set

import numpy as np

from inference.core.active_learning.entities import (
Prediction,
PredictionType,
SamplingMethod,
)
from inference.core.active_learning.samplers.close_to_threshold import (
sample_close_to_threshold,
)
from inference.core.constants import CLASSIFICATION_TASK
from inference.core.exceptions import ActiveLearningConfigurationError

ELIGIBLE_PREDICTION_TYPES = {CLASSIFICATION_TASK}


def initialize_classes_based_sampling(
strategy_config: Dict[str, Any]
) -> SamplingMethod:
try:
sample_function = partial(
sample_based_on_classes,
selected_class_names=set(strategy_config["selected_class_names"]),
probability=strategy_config["probability"],
)
return SamplingMethod(
name=strategy_config["name"],
sample=sample_function,
)
except KeyError as error:
raise ActiveLearningConfigurationError(
f"In configuration of `classes_based_sampling` missing key detected: {error}."
) from error


def sample_based_on_classes(
image: np.ndarray,
prediction: Prediction,
prediction_type: PredictionType,
selected_class_names: Set[str],
probability: float,
) -> bool:
if prediction_type not in ELIGIBLE_PREDICTION_TYPES:
return False
return sample_close_to_threshold(
image=image,
prediction=prediction,
prediction_type=prediction_type,
selected_class_names=selected_class_names,
threshold=0.5,
epsilon=1.0,
only_top_classes=True,
minimum_objects_close_to_threshold=1,
probability=probability,
)
Loading