Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SDK/CLI: Add ability to call auto-annotation functions with a custom threshold #8688

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions changelog.d/20241112_201034_roman_aa_threshold.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
### Added

- \[SDK, CLI\] Added a `conf_threshold` parameter to
`cvat_sdk.auto_annotation.annotate_task`, which is passed as-is to the AA
function object via the context. The CLI equivalent is `auto-annotate
--conf-threshold`. This makes it easier to write and use AA functions that
support object filtering based on confidence levels
(<https://github.com/cvat-ai/cvat/pull/8688>)

- \[SDK\] Built-in auto-annotation functions now support object filtering by
confidence level
(<https://github.com/cvat-ai/cvat/pull/8688>)
17 changes: 16 additions & 1 deletion cvat-cli/src/cvat_cli/_internal/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@
from cvat_sdk.core.proxies.tasks import ResourceType

from .command_base import CommandGroup
from .parsers import BuildDictAction, parse_function_parameter, parse_label_arg, parse_resource_type
from .parsers import (
BuildDictAction,
parse_function_parameter,
parse_label_arg,
parse_resource_type,
parse_threshold,
)

COMMANDS = CommandGroup(description="Perform common operations related to CVAT tasks.")

Expand Down Expand Up @@ -463,6 +469,13 @@ def configure_parser(self, parser: argparse.ArgumentParser) -> None:
help="Allow the function to declare labels not configured in the task",
)

parser.add_argument(
"--conf-threshold",
type=parse_threshold,
help="Confidence threshold for filtering detections",
default=None,
)

def execute(
self,
client: Client,
Expand All @@ -473,6 +486,7 @@ def execute(
function_parameters: dict[str, Any],
clear_existing: bool = False,
allow_unmatched_labels: bool = False,
conf_threshold: Optional[float],
) -> None:
if function_module is not None:
function = importlib.import_module(function_module)
Expand All @@ -497,4 +511,5 @@ def execute(
pbar=DeferredTqdmProgressReporter(),
clear_existing=clear_existing,
allow_unmatched_labels=allow_unmatched_labels,
conf_threshold=conf_threshold,
)
11 changes: 11 additions & 0 deletions cvat-cli/src/cvat_cli/_internal/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ def parse_function_parameter(s: str) -> tuple[str, Any]:
return (key, value)


def parse_threshold(s: str) -> float:
try:
value = float(s)
except ValueError as e:
raise argparse.ArgumentTypeError("must be a number") from e

if not 0 <= value <= 1:
raise argparse.ArgumentTypeError("must be between 0 and 1")
return value


class BuildDictAction(argparse.Action):
def __init__(self, option_strings, dest, default=None, **kwargs):
super().__init__(option_strings, dest, default=default or {}, **kwargs)
Expand Down
15 changes: 13 additions & 2 deletions cvat-sdk/cvat_sdk/auto_annotation/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,10 @@ def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame:
shapes[:] = new_shapes


@attrs.frozen
@attrs.frozen(kw_only=True)
class _DetectionFunctionContextImpl(DetectionFunctionContext):
frame_name: str
conf_threshold: Optional[float] = None


def annotate_task(
Expand All @@ -233,6 +234,7 @@ def annotate_task(
pbar: Optional[ProgressReporter] = None,
clear_existing: bool = False,
allow_unmatched_labels: bool = False,
conf_threshold: Optional[float] = None,
) -> None:
"""
Downloads data for the task with the given ID, applies the given function to it
Expand Down Expand Up @@ -264,11 +266,17 @@ def annotate_task(
function declares a label in its spec that has no corresponding label in the task.
If it's set to true, then such labels are allowed, and any annotations returned by the
function that refer to this label are ignored. Otherwise, BadFunctionError is raised.

The conf_threshold parameter must be None or a number between 0 and 1. It will be passed
to the function as the conf_threshold attribute of the context object.
"""

if pbar is None:
pbar = NullProgressReporter()

if conf_threshold is not None and not 0 <= conf_threshold <= 1:
raise ValueError("conf_threshold must be None or a number between 0 and 1")

dataset = TaskDataset(client, task_id, load_annotations=False)

assert isinstance(function.spec, DetectionFunctionSpec)
Expand All @@ -285,7 +293,10 @@ def annotate_task(
with pbar.task(total=len(dataset.samples), unit="samples"):
for sample in pbar.iter(dataset.samples):
frame_shapes = function.detect(
_DetectionFunctionContextImpl(sample.frame_name), sample.media.load_image()
_DetectionFunctionContextImpl(
frame_name=sample.frame_name, conf_threshold=conf_threshold
),
sample.media.load_image(),
)
mapper.validate_and_remap(frame_shapes, sample.frame_index)
shapes.extend(frame_shapes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@ def spec(self) -> cvataa.DetectionFunctionSpec:
]
)

def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]:
def detect(
self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
conf_threshold = context.conf_threshold or 0
results = self._model([self._transforms(image)])

return [
cvataa.rectangle(label.item(), [x.item() for x in box])
for result in results
for box, label in zip(result["boxes"], result["labels"])
for box, label, score in zip(result["boxes"], result["labels"], result["scores"])
if score >= conf_threshold
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def spec(self) -> cvataa.DetectionFunctionSpec:
]
)

def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]:
def detect(
self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
conf_threshold = context.conf_threshold or 0
results = self._model([self._transforms(image)])

return [
Expand All @@ -51,7 +54,10 @@ def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeReq
],
)
for result in results
for keypoints, label in zip(result["keypoints"], result["labels"])
for keypoints, label, score in zip(
result["keypoints"], result["labels"], result["scores"]
)
if score >= conf_threshold
]


Expand Down
20 changes: 18 additions & 2 deletions cvat-sdk/cvat_sdk/auto_annotation/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import abc
from collections.abc import Sequence
from typing import Protocol
from typing import Optional, Protocol

import attrs
import PIL.Image
Expand Down Expand Up @@ -50,7 +50,23 @@ def frame_name(self) -> str:
The file name of the frame that the current image corresponds to in
the dataset.
"""
...

@property
@abc.abstractmethod
def conf_threshold(self) -> Optional[float]:
"""
The confidence threshold that the function should use for filtering
detections.

If the function is able to estimate confidence levels, then:

* If this value is None, the function may apply a default threshold at its discretion.

* Otherwise, it will be a number between 0 and 1. The function must only return
objects with confidence levels greater than or equal to this value.

If the function is not able to estimate confidence levels, it can ignore this value.
"""


class DetectionFunction(Protocol):
Expand Down
23 changes: 18 additions & 5 deletions site/content/en/docs/api_sdk/sdk/auto-annotation.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,12 @@ class TorchvisionDetectionFunction:
]
)

def detect(self, context, image: PIL.Image.Image) -> List[models.LabeledShapeRequest]:
def detect(
self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
# determine the threshold for filtering results
conf_threshold = context.conf_threshold or 0

# convert the input into a form the model can understand
transformed_image = [self._transforms(image)]

Expand All @@ -79,7 +84,8 @@ class TorchvisionDetectionFunction:
return [
cvataa.rectangle(label.item(), [x.item() for x in box])
for result in results
for box, label in zip(result['boxes'], result['labels'])
for box, label, score in zip(result["boxes"], result["labels"], result["scores"])
if score >= conf_threshold
]

# log into the CVAT server
Expand Down Expand Up @@ -112,9 +118,13 @@ that these objects must follow.
`detect` must be a function/method accepting two parameters:

- `context` (`DetectionFunctionContext`).
Contains information about the current image.
Currently `DetectionFunctionContext` only contains a single field, `frame_name`,
which contains the file name of the frame on the CVAT server.
Contains invocation parameters and information about the current image.
The following fields are available:

- `frame_name` (`str`). The file name of the frame on the CVAT server.
- `conf_threshold` (`float | None`). The confidence threshold that the function
should use to filter objects. If `None`, the function may apply a default
threshold at its discretion.

- `image` (`PIL.Image.Image`).
Contains image data.
Expand Down Expand Up @@ -195,6 +205,9 @@ If you use `allow_unmatched_label=True`, then such labels will be ignored,
and any shapes referring to them will be dropped.
Same logic applies to sub-label IDs.

It's possible to pass a custom confidence threshold to the function via the
`conf_threshold` parameter.

`annotate_task` will raise a `BadFunctionError` exception
if it detects that the function violated the AA function protocol.

Expand Down
21 changes: 21 additions & 0 deletions tests/python/cli/conf_threshold_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (C) 2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

import cvat_sdk.auto_annotation as cvataa
import cvat_sdk.models as models
import PIL.Image

spec = cvataa.DetectionFunctionSpec(
labels=[
cvataa.label_spec("car", 0),
],
)


def detect(
context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
return [
cvataa.rectangle(0, [context.conf_threshold, 1, 1, 1]),
]
14 changes: 14 additions & 0 deletions tests/python/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,17 @@ def test_auto_annotate_with_parameters(self, fxt_new_task: Task):

annotations = fxt_new_task.get_annotations()
assert annotations.shapes

def test_auto_annotate_with_threshold(self, fxt_new_task: Task):
annotations = fxt_new_task.get_annotations()
assert not annotations.shapes

self.run_cli(
"auto-annotate",
str(fxt_new_task.id),
f"--function-module={__package__}.conf_threshold_function",
"--conf-threshold=0.75",
)

annotations = fxt_new_task.get_annotations()
assert annotations.shapes[0].points[0] == 0.75
51 changes: 47 additions & 4 deletions tests/python/sdk/test_auto_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,44 @@ def detect(context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]:
assert shapes[i].points == [5, 6, 7, 8]
assert shapes[i].rotation == 10

def test_conf_threshold(self):
spec = cvataa.DetectionFunctionSpec(labels=[])

received_threshold = None

def detect(
context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
nonlocal received_threshold
received_threshold = context.conf_threshold
return []

cvataa.annotate_task(
self.client,
self.task.id,
namespace(spec=spec, detect=detect),
conf_threshold=0.75,
)

assert received_threshold == 0.75

cvataa.annotate_task(
self.client,
self.task.id,
namespace(spec=spec, detect=detect),
)

assert received_threshold is None

for bad_threshold in [-0.1, 1.1]:
with pytest.raises(ValueError):
cvataa.annotate_task(
self.client,
self.task.id,
namespace(spec=spec, detect=detect),
conf_threshold=bad_threshold,
)

def _test_bad_function_spec(self, spec: cvataa.DetectionFunctionSpec, exc_match: str) -> None:
def detect(context, image):
assert False
Expand Down Expand Up @@ -575,8 +613,9 @@ def forward(self, images: list[torch.Tensor]) -> list[dict]:

return [
{
"boxes": torch.tensor([[1, 2, 3, 4]]),
"labels": torch.tensor([self._label_id]),
"boxes": torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]),
"labels": torch.tensor([self._label_id, self._label_id]),
"scores": torch.tensor([0.75, 0.74]),
}
]

Expand All @@ -599,15 +638,17 @@ def forward(self, images: list[torch.Tensor]) -> list[dict]:

return [
{
"labels": torch.tensor([self._label_id]),
"labels": torch.tensor([self._label_id, self._label_id]),
"keypoints": torch.tensor(
[
[
[hash(name) % 100, 0, 1 if name.startswith("right_") else 0]
for i, name in enumerate(self._keypoint_names)
]
],
[[0, 0, 1] for i, name in enumerate(self._keypoint_names)],
]
),
"scores": torch.tensor([0.75, 0.74]),
}
]

Expand Down Expand Up @@ -672,6 +713,7 @@ def test_torchvision_detection(self, monkeypatch: pytest.MonkeyPatch):
self.task.id,
td.create("fasterrcnn_resnet50_fpn_v2", "COCO_V1", test_param="expected_value"),
allow_unmatched_labels=True,
conf_threshold=0.75,
)

annotations = self.task.get_annotations()
Expand All @@ -691,6 +733,7 @@ def test_torchvision_keypoint_detection(self, monkeypatch: pytest.MonkeyPatch):
self.task.id,
tkd.create("keypointrcnn_resnet50_fpn", "COCO_V1", test_param="expected_value"),
allow_unmatched_labels=True,
conf_threshold=0.75,
)

annotations = self.task.get_annotations()
Expand Down
Loading