diff --git a/changelog.d/20241112_201034_roman_aa_threshold.md b/changelog.d/20241112_201034_roman_aa_threshold.md
new file mode 100644
index 000000000000..0a1da765badb
--- /dev/null
+++ b/changelog.d/20241112_201034_roman_aa_threshold.md
@@ -0,0 +1,12 @@
+### Added
+
+- \[SDK, CLI\] Added a `conf_threshold` parameter to
+  `cvat_sdk.auto_annotation.annotate_task`, which is passed as-is to the AA
+  function object via the context. The CLI equivalent is `auto-annotate
+  --conf-threshold`. This makes it easier to write and use AA functions that
+  support object filtering based on confidence levels
+  ()
+
+- \[SDK\] Built-in auto-annotation functions now support object filtering by
+  confidence level
+  ()
diff --git a/cvat-cli/src/cvat_cli/_internal/commands.py b/cvat-cli/src/cvat_cli/_internal/commands.py
index e86ef3b6350f..f49416c843e5 100644
--- a/cvat-cli/src/cvat_cli/_internal/commands.py
+++ b/cvat-cli/src/cvat_cli/_internal/commands.py
@@ -20,7 +20,13 @@
 from cvat_sdk.core.proxies.tasks import ResourceType
 
 from .command_base import CommandGroup
-from .parsers import BuildDictAction, parse_function_parameter, parse_label_arg, parse_resource_type
+from .parsers import (
+    BuildDictAction,
+    parse_function_parameter,
+    parse_label_arg,
+    parse_resource_type,
+    parse_threshold,
+)
 
 COMMANDS = CommandGroup(description="Perform common operations related to CVAT tasks.")
@@ -463,6 +469,13 @@ def configure_parser(self, parser: argparse.ArgumentParser) -> None:
             help="Allow the function to declare labels not configured in the task",
         )
 
+        parser.add_argument(
+            "--conf-threshold",
+            type=parse_threshold,
+            help="Confidence threshold for filtering detections",
+            default=None,
+        )
+
     def execute(
         self,
         client: Client,
@@ -473,6 +486,7 @@ def execute(
         function_parameters: dict[str, Any],
         clear_existing: bool = False,
         allow_unmatched_labels: bool = False,
+        conf_threshold: Optional[float],
     ) -> None:
         if function_module is not None:
             function = importlib.import_module(function_module)
@@ -497,4 +511,5 @@ def execute(
             pbar=DeferredTqdmProgressReporter(),
             clear_existing=clear_existing,
             allow_unmatched_labels=allow_unmatched_labels,
+            conf_threshold=conf_threshold,
         )
diff --git a/cvat-cli/src/cvat_cli/_internal/parsers.py b/cvat-cli/src/cvat_cli/_internal/parsers.py
index a66710a09f47..97dcb5b2668a 100644
--- a/cvat-cli/src/cvat_cli/_internal/parsers.py
+++ b/cvat-cli/src/cvat_cli/_internal/parsers.py
@@ -53,6 +53,17 @@ def parse_function_parameter(s: str) -> tuple[str, Any]:
     return (key, value)
 
 
+def parse_threshold(s: str) -> float:
+    try:
+        value = float(s)
+    except ValueError as e:
+        raise argparse.ArgumentTypeError("must be a number") from e
+
+    if not 0 <= value <= 1:
+        raise argparse.ArgumentTypeError("must be between 0 and 1")
+    return value
+
+
 class BuildDictAction(argparse.Action):
     def __init__(self, option_strings, dest, default=None, **kwargs):
         super().__init__(option_strings, dest, default=default or {}, **kwargs)
diff --git a/cvat-sdk/cvat_sdk/auto_annotation/driver.py b/cvat-sdk/cvat_sdk/auto_annotation/driver.py
index 0f3d82ea32ea..ffa5eab9879f 100644
--- a/cvat-sdk/cvat_sdk/auto_annotation/driver.py
+++ b/cvat-sdk/cvat_sdk/auto_annotation/driver.py
@@ -220,9 +220,10 @@ def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame:
         shapes[:] = new_shapes
 
 
-@attrs.frozen
+@attrs.frozen(kw_only=True)
 class _DetectionFunctionContextImpl(DetectionFunctionContext):
     frame_name: str
+    conf_threshold: Optional[float] = None
 
 
 def annotate_task(
@@ -233,6 +234,7 @@ def annotate_task(
     pbar: Optional[ProgressReporter] = None,
     clear_existing: bool = False,
     allow_unmatched_labels: bool = False,
+    conf_threshold: Optional[float] = None,
 ) -> None:
     """
     Downloads data for the task with the given ID, applies the given function to it
@@ -264,11 +266,17 @@ def annotate_task(
     function declares a label in its spec that has no corresponding label in the task.
     If it's set to true, then such labels are allowed, and any annotations returned by
     the function that refer to this label are ignored. Otherwise, BadFunctionError is raised.
+
+    The conf_threshold parameter must be None or a number between 0 and 1. It will be passed
+    to the function as the conf_threshold attribute of the context object.
     """
 
     if pbar is None:
         pbar = NullProgressReporter()
 
+    if conf_threshold is not None and not 0 <= conf_threshold <= 1:
+        raise ValueError("conf_threshold must be None or a number between 0 and 1")
+
     dataset = TaskDataset(client, task_id, load_annotations=False)
 
     assert isinstance(function.spec, DetectionFunctionSpec)
@@ -285,7 +293,10 @@ def annotate_task(
     with pbar.task(total=len(dataset.samples), unit="samples"):
         for sample in pbar.iter(dataset.samples):
             frame_shapes = function.detect(
-                _DetectionFunctionContextImpl(sample.frame_name), sample.media.load_image()
+                _DetectionFunctionContextImpl(
+                    frame_name=sample.frame_name, conf_threshold=conf_threshold
+                ),
+                sample.media.load_image(),
             )
             mapper.validate_and_remap(frame_shapes, sample.frame_index)
             shapes.extend(frame_shapes)
diff --git a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py
index d257cb7ec889..423db05adbcb 100644
--- a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py
+++ b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py
@@ -27,13 +27,17 @@ def spec(self) -> cvataa.DetectionFunctionSpec:
             ]
         )
 
-    def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]:
+    def detect(
+        self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
+    ) -> list[models.LabeledShapeRequest]:
+        conf_threshold = context.conf_threshold or 0
         results = self._model([self._transforms(image)])
 
         return [
             cvataa.rectangle(label.item(), [x.item() for x in box])
             for result in results
-            for box, label in zip(result["boxes"], result["labels"])
+            for box, label, score in zip(result["boxes"], result["labels"], result["scores"])
+            if score >= conf_threshold
         ]
diff --git a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py
index c7199b67738b..0756b0b1738c 100644
--- a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py
+++ b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py
@@ -35,7 +35,10 @@ def spec(self) -> cvataa.DetectionFunctionSpec:
             ]
         )
 
-    def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]:
+    def detect(
+        self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
+    ) -> list[models.LabeledShapeRequest]:
+        conf_threshold = context.conf_threshold or 0
         results = self._model([self._transforms(image)])
 
         return [
@@ -51,7 +54,10 @@ def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeReq
                 ],
             )
             for result in results
-            for keypoints, label in zip(result["keypoints"], result["labels"])
+            for keypoints, label, score in zip(
+                result["keypoints"], result["labels"], result["scores"]
+            )
+            if score >= conf_threshold
         ]
diff --git a/cvat-sdk/cvat_sdk/auto_annotation/interface.py b/cvat-sdk/cvat_sdk/auto_annotation/interface.py
index 20a21fe4a5cf..47e944a1de84 100644
--- a/cvat-sdk/cvat_sdk/auto_annotation/interface.py
+++ b/cvat-sdk/cvat_sdk/auto_annotation/interface.py
@@ -4,7 +4,7 @@
 import abc
 from collections.abc import Sequence
-from typing import Protocol
+from typing import Optional, Protocol
 
 import attrs
 import PIL.Image
@@ -50,7 +50,23 @@ def frame_name(self) -> str:
         The file name of the frame that the current image corresponds to in the dataset.
         """
-        ...
+
+    @property
+    @abc.abstractmethod
+    def conf_threshold(self) -> Optional[float]:
+        """
+        The confidence threshold that the function should use for filtering
+        detections.
+
+        If the function is able to estimate confidence levels, then:
+
+        * If this value is None, the function may apply a default threshold at its discretion.
+
+        * Otherwise, it will be a number between 0 and 1. The function must only return
+          objects with confidence levels greater than or equal to this value.
+
+        If the function is not able to estimate confidence levels, it can ignore this value.
+        """
 
 
 class DetectionFunction(Protocol):
diff --git a/site/content/en/docs/api_sdk/sdk/auto-annotation.md b/site/content/en/docs/api_sdk/sdk/auto-annotation.md
index 24e16c7e6218..f97759efd175 100644
--- a/site/content/en/docs/api_sdk/sdk/auto-annotation.md
+++ b/site/content/en/docs/api_sdk/sdk/auto-annotation.md
@@ -68,7 +68,12 @@ class TorchvisionDetectionFunction:
             ]
         )
 
-    def detect(self, context, image: PIL.Image.Image) -> List[models.LabeledShapeRequest]:
+    def detect(
+        self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
+    ) -> list[models.LabeledShapeRequest]:
+        # determine the threshold for filtering results
+        conf_threshold = context.conf_threshold or 0
+
         # convert the input into a form the model can understand
         transformed_image = [self._transforms(image)]
 
@@ -79,7 +84,8 @@ class TorchvisionDetectionFunction:
         return [
             cvataa.rectangle(label.item(), [x.item() for x in box])
             for result in results
-            for box, label in zip(result['boxes'], result['labels'])
+            for box, label, score in zip(result["boxes"], result["labels"], result["scores"])
+            if score >= conf_threshold
         ]
 
 # log into the CVAT server
@@ -112,9 +118,13 @@ that these objects must follow.
 `detect` must be a function/method accepting two parameters:
 
 - `context` (`DetectionFunctionContext`).
-  Contains information about the current image.
-  Currently `DetectionFunctionContext` only contains a single field, `frame_name`,
-  which contains the file name of the frame on the CVAT server.
+  Contains invocation parameters and information about the current image.
+  The following fields are available:
+
+  - `frame_name` (`str`). The file name of the frame on the CVAT server.
+  - `conf_threshold` (`float | None`). The confidence threshold that the function
+    should use to filter objects. If `None`, the function may apply a default
+    threshold at its discretion.
 
 - `image` (`PIL.Image.Image`).
   Contains image data.
@@ -195,6 +205,9 @@ If you use `allow_unmatched_label=True`, then such labels will be ignored,
 and any shapes referring to them will be dropped.
 Same logic applies to sub-label IDs.
 
+It's possible to pass a custom confidence threshold to the function via the
+`conf_threshold` parameter.
+
 `annotate_task` will raise a `BadFunctionError` exception if it detects
 that the function violated the AA function protocol.
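
Taken together, the SDK and docs changes above let the caller pick the threshold at invocation time. As a rough illustration (not part of this diff), a minimal SDK-side call might look like the sketch below; the server address, credentials, and task ID are placeholders, while `annotate_task`, `conf_threshold`, and the `torchvision_detection.create` factory are the APIs this patch touches or that its tests exercise:

```python
# Minimal usage sketch. Assumptions: a reachable CVAT server, an existing
# task, and placeholder credentials/IDs - adjust before running.
import cvat_sdk.auto_annotation as cvataa
from cvat_sdk import make_client
from cvat_sdk.auto_annotation.functions import torchvision_detection

with make_client("localhost", credentials=("user", "password")) as client:
    cvataa.annotate_task(
        client,
        12345,  # placeholder task ID
        torchvision_detection.create("fasterrcnn_resnet50_fpn_v2", "COCO_V1"),
        conf_threshold=0.8,  # drop detections the model scores below 0.8
    )
```

The CLI path added in commands.py is the equivalent of something like `cvat-cli auto-annotate <task id> --function-module cvat_sdk.auto_annotation.functions.torchvision_detection --conf-threshold 0.8`, with `parse_threshold` rejecting values outside `[0, 1]` before `annotate_task` is ever called.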
diff --git a/tests/python/cli/conf_threshold_function.py b/tests/python/cli/conf_threshold_function.py
new file mode 100644
index 000000000000..bcb1add2d660
--- /dev/null
+++ b/tests/python/cli/conf_threshold_function.py
@@ -0,0 +1,21 @@
+# Copyright (C) 2024 CVAT.ai Corporation
+#
+# SPDX-License-Identifier: MIT
+
+import cvat_sdk.auto_annotation as cvataa
+import cvat_sdk.models as models
+import PIL.Image
+
+spec = cvataa.DetectionFunctionSpec(
+    labels=[
+        cvataa.label_spec("car", 0),
+    ],
+)
+
+
+def detect(
+    context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
+) -> list[models.LabeledShapeRequest]:
+    return [
+        cvataa.rectangle(0, [context.conf_threshold, 1, 1, 1]),
+    ]
diff --git a/tests/python/cli/test_cli.py b/tests/python/cli/test_cli.py
index d6b19cfe0a3c..a039fd3744bc 100644
--- a/tests/python/cli/test_cli.py
+++ b/tests/python/cli/test_cli.py
@@ -347,3 +347,17 @@ def test_auto_annotate_with_parameters(self, fxt_new_task: Task):
         annotations = fxt_new_task.get_annotations()
 
         assert annotations.shapes
+
+    def test_auto_annotate_with_threshold(self, fxt_new_task: Task):
+        annotations = fxt_new_task.get_annotations()
+        assert not annotations.shapes
+
+        self.run_cli(
+            "auto-annotate",
+            str(fxt_new_task.id),
+            f"--function-module={__package__}.conf_threshold_function",
+            "--conf-threshold=0.75",
+        )
+
+        annotations = fxt_new_task.get_annotations()
+        assert annotations.shapes[0].points[0] == 0.75
diff --git a/tests/python/sdk/test_auto_annotation.py b/tests/python/sdk/test_auto_annotation.py
index ae4a0d711774..6fa96a5843f4 100644
--- a/tests/python/sdk/test_auto_annotation.py
+++ b/tests/python/sdk/test_auto_annotation.py
@@ -269,6 +269,44 @@ def detect(context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]:
             assert shapes[i].points == [5, 6, 7, 8]
             assert shapes[i].rotation == 10
 
+    def test_conf_threshold(self):
+        spec = cvataa.DetectionFunctionSpec(labels=[])
+
+        received_threshold = None
+
+        def detect(
+            context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
+        ) -> list[models.LabeledShapeRequest]:
+            nonlocal received_threshold
+            received_threshold = context.conf_threshold
+            return []
+
+        cvataa.annotate_task(
+            self.client,
+            self.task.id,
+            namespace(spec=spec, detect=detect),
+            conf_threshold=0.75,
+        )
+
+        assert received_threshold == 0.75
+
+        cvataa.annotate_task(
+            self.client,
+            self.task.id,
+            namespace(spec=spec, detect=detect),
+        )
+
+        assert received_threshold is None
+
+        for bad_threshold in [-0.1, 1.1]:
+            with pytest.raises(ValueError):
+                cvataa.annotate_task(
+                    self.client,
+                    self.task.id,
+                    namespace(spec=spec, detect=detect),
+                    conf_threshold=bad_threshold,
+                )
+
     def _test_bad_function_spec(self, spec: cvataa.DetectionFunctionSpec, exc_match: str) -> None:
         def detect(context, image):
             assert False
@@ -575,8 +613,9 @@ def forward(self, images: list[torch.Tensor]) -> list[dict]:
 
         return [
             {
-                "boxes": torch.tensor([[1, 2, 3, 4]]),
-                "labels": torch.tensor([self._label_id]),
+                "boxes": torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]),
+                "labels": torch.tensor([self._label_id, self._label_id]),
+                "scores": torch.tensor([0.75, 0.74]),
             }
         ]
@@ -599,15 +638,17 @@ def forward(self, images: list[torch.Tensor]) -> list[dict]:
 
         return [
             {
-                "labels": torch.tensor([self._label_id]),
+                "labels": torch.tensor([self._label_id, self._label_id]),
                 "keypoints": torch.tensor(
                     [
                         [
                             [hash(name) % 100, 0, 1 if name.startswith("right_") else 0]
                             for i, name in enumerate(self._keypoint_names)
-                        ]
+                        ],
+                        [[0, 0, 1] for i, name in enumerate(self._keypoint_names)],
                     ]
                 ),
+                "scores": torch.tensor([0.75, 0.74]),
             }
         ]
@@ -672,6 +713,7 @@ def test_torchvision_detection(self, monkeypatch: pytest.MonkeyPatch):
             self.task.id,
             td.create("fasterrcnn_resnet50_fpn_v2", "COCO_V1", test_param="expected_value"),
             allow_unmatched_labels=True,
+            conf_threshold=0.75,
         )
 
         annotations = self.task.get_annotations()
@@ -691,6 +733,7 @@ def test_torchvision_keypoint_detection(self, monkeypatch: pytest.MonkeyPatch):
             self.task.id,
             tkd.create("keypointrcnn_resnet50_fpn", "COCO_V1", test_param="expected_value"),
             allow_unmatched_labels=True,
+            conf_threshold=0.75,
        )
 
         annotations = self.task.get_annotations()
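
For function authors, the `conf_threshold` contract added in interface.py leaves the default up to the function when the caller passes nothing. Below is a hedged sketch of a custom AA function that follows that contract; the 0.5 fallback and the `_fake_inference` helper are illustrative placeholders, not part of the CVAT SDK:

```python
# Sketch of a custom AA function honoring the DetectionFunctionContext
# contract; _fake_inference and the 0.5 default are illustrative placeholders.
import cvat_sdk.auto_annotation as cvataa
import cvat_sdk.models as models
import PIL.Image

spec = cvataa.DetectionFunctionSpec(
    labels=[cvataa.label_spec("car", 0)],
)


def _fake_inference(image: PIL.Image.Image) -> list[tuple[int, list[float], float]]:
    # stand-in for a real model call; yields (label_id, box, score) tuples
    return [(0, [10.0, 10.0, 50.0, 50.0], 0.9), (0, [0.0, 0.0, 5.0, 5.0], 0.3)]


def detect(
    context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
    # If the caller didn't set a threshold, the function may pick its own default,
    # per the conf_threshold docstring in interface.py.
    threshold = context.conf_threshold if context.conf_threshold is not None else 0.5

    return [
        cvataa.rectangle(label_id, box)
        for label_id, box, score in _fake_inference(image)
        if score >= threshold
    ]
```

Note the design difference from the built-in functions in this diff: they use `context.conf_threshold or 0`, which is fine when the fallback is 0 (nothing gets filtered), but an explicit `is not None` check is the safer pattern whenever the function's own default is nonzero, since `0.0 or 0.5` would silently override a caller's explicit zero threshold.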