Skip to content

Commit

Permalink
SDK/CLI: Add ability to call auto-annotation functions with a custom …
Browse files Browse the repository at this point in the history
…threshold

My main motivation is to support a future feature, but I think this is a
good thing in its own right.

While it's already possible to create an AA function that lets you customize the
threshold (by adding a creation parameter), confidence scoring is very
common in detection models, so it makes sense to make this easier to
support, both for the implementer of the function, and for its user.
  • Loading branch information
SpecLad committed Nov 14, 2024
1 parent 111feec commit 79c09ee
Show file tree
Hide file tree
Showing 11 changed files with 179 additions and 18 deletions.
9 changes: 9 additions & 0 deletions changelog.d/20241112_201034_roman_aa_threshold.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
### Added

- \[SDK, CLI\] Added a `threshold` parameter to `cvat_sdk.auto_annotation.annotate_task`,
which is passed as-is to the AA function object via the context. The CLI
equivalent is `auto-annotate --threshold`. This makes it easier to write
and use AA functions that support object filtering based on confidence
levels. Updated the builtin functions in `cvat_sdk.auto_annotation.functions`
to support filtering via this parameter
(<https://github.com/cvat-ai/cvat/pull/8688>)
17 changes: 16 additions & 1 deletion cvat-cli/src/cvat_cli/_internal/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@
from cvat_sdk.core.proxies.tasks import ResourceType

from .command_base import CommandGroup
from .parsers import BuildDictAction, parse_function_parameter, parse_label_arg, parse_resource_type
from .parsers import (
BuildDictAction,
parse_function_parameter,
parse_label_arg,
parse_resource_type,
parse_threshold,
)

COMMANDS = CommandGroup(description="Perform common operations related to CVAT tasks.")

Expand Down Expand Up @@ -463,6 +469,13 @@ def configure_parser(self, parser: argparse.ArgumentParser) -> None:
help="Allow the function to declare labels not configured in the task",
)

parser.add_argument(
"--threshold",
type=parse_threshold,
help="Confidence threshold for filtering detections",
default=None,
)

def execute(
self,
client: Client,
Expand All @@ -473,6 +486,7 @@ def execute(
function_parameters: dict[str, Any],
clear_existing: bool = False,
allow_unmatched_labels: bool = False,
threshold: Optional[float],
) -> None:
if function_module is not None:
function = importlib.import_module(function_module)
Expand All @@ -497,4 +511,5 @@ def execute(
pbar=DeferredTqdmProgressReporter(),
clear_existing=clear_existing,
allow_unmatched_labels=allow_unmatched_labels,
threshold=threshold,
)
11 changes: 11 additions & 0 deletions cvat-cli/src/cvat_cli/_internal/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ def parse_function_parameter(s: str) -> tuple[str, Any]:
return (key, value)


def parse_threshold(s: str) -> float:
try:
value = float(s)
except ValueError as e:
raise argparse.ArgumentTypeError("must be a number") from e

if not 0 <= value <= 1:
raise argparse.ArgumentTypeError("must be between 0 and 1")
return value


class BuildDictAction(argparse.Action):
def __init__(self, option_strings, dest, default=None, **kwargs):
super().__init__(option_strings, dest, default=default or {}, **kwargs)
Expand Down
13 changes: 11 additions & 2 deletions cvat-sdk/cvat_sdk/auto_annotation/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,10 @@ def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame:
shapes[:] = new_shapes


@attrs.frozen
@attrs.frozen(kw_only=True)
class _DetectionFunctionContextImpl(DetectionFunctionContext):
frame_name: str
threshold: Optional[float] = None


def annotate_task(
Expand All @@ -233,6 +234,7 @@ def annotate_task(
pbar: Optional[ProgressReporter] = None,
clear_existing: bool = False,
allow_unmatched_labels: bool = False,
threshold: Optional[float] = None,
) -> None:
"""
Downloads data for the task with the given ID, applies the given function to it
Expand Down Expand Up @@ -264,11 +266,17 @@ def annotate_task(
function declares a label in its spec that has no corresponding label in the task.
If it's set to true, then such labels are allowed, and any annotations returned by the
function that refer to this label are ignored. Otherwise, BadFunctionError is raised.
The threshold parameter must be None or a number between 0 and 1. It will be passed
to the function as the threshold attribute of the context object.
"""

if pbar is None:
pbar = NullProgressReporter()

if threshold is not None and not 0 <= threshold <= 1:
raise ValueError("threshold must be None or a number between 0 and 1")

dataset = TaskDataset(client, task_id, load_annotations=False)

assert isinstance(function.spec, DetectionFunctionSpec)
Expand All @@ -285,7 +293,8 @@ def annotate_task(
with pbar.task(total=len(dataset.samples), unit="samples"):
for sample in pbar.iter(dataset.samples):
frame_shapes = function.detect(
_DetectionFunctionContextImpl(sample.frame_name), sample.media.load_image()
_DetectionFunctionContextImpl(frame_name=sample.frame_name, threshold=threshold),
sample.media.load_image(),
)
mapper.validate_and_remap(frame_shapes, sample.frame_index)
shapes.extend(frame_shapes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@ def spec(self) -> cvataa.DetectionFunctionSpec:
]
)

def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]:
def detect(
self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
threshold = context.threshold or 0
results = self._model([self._transforms(image)])

return [
cvataa.rectangle(label.item(), [x.item() for x in box])
for result in results
for box, label in zip(result["boxes"], result["labels"])
for box, label, score in zip(result["boxes"], result["labels"], result["scores"])
if score >= threshold
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def spec(self) -> cvataa.DetectionFunctionSpec:
]
)

def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]:
def detect(
self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
threshold = context.threshold or 0
results = self._model([self._transforms(image)])

return [
Expand All @@ -51,7 +54,10 @@ def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeReq
],
)
for result in results
for keypoints, label in zip(result["keypoints"], result["labels"])
for keypoints, label, score in zip(
result["keypoints"], result["labels"], result["scores"]
)
if score >= threshold
]


Expand Down
20 changes: 18 additions & 2 deletions cvat-sdk/cvat_sdk/auto_annotation/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import abc
from collections.abc import Sequence
from typing import Protocol
from typing import Optional, Protocol

import attrs
import PIL.Image
Expand Down Expand Up @@ -50,7 +50,23 @@ def frame_name(self) -> str:
The file name of the frame that the current image corresponds to in
the dataset.
"""
...

@property
@abc.abstractmethod
def threshold(self) -> Optional[float]:
"""
The confidence threshold that the function should use for filtering
detections.
If the function is able to estimate confidence levels, then:
* If this value is None, the function may apply a default threshold at its discretion.
* Otherwise, it will be a number between 0 and 1. The function must only return
objects with confidence levels greater than or equal to this value.
If the function is not able to estimate confidence levels, it can ignore this value.
"""


class DetectionFunction(Protocol):
Expand Down
23 changes: 18 additions & 5 deletions site/content/en/docs/api_sdk/sdk/auto-annotation.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,12 @@ class TorchvisionDetectionFunction:
]
)

def detect(self, context, image: PIL.Image.Image) -> List[models.LabeledShapeRequest]:
def detect(
self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
# determine the threshold for filtering results
threshold = context.threshold or 0

# convert the input into a form the model can understand
transformed_image = [self._transforms(image)]

Expand All @@ -79,7 +84,8 @@ class TorchvisionDetectionFunction:
return [
cvataa.rectangle(label.item(), [x.item() for x in box])
for result in results
for box, label in zip(result['boxes'], result['labels'])
for box, label, score in zip(result["boxes"], result["labels"], result["scores"])
if score >= threshold
]

# log into the CVAT server
Expand Down Expand Up @@ -112,9 +118,13 @@ that these objects must follow.
`detect` must be a function/method accepting two parameters:

- `context` (`DetectionFunctionContext`).
Contains information about the current image.
Currently `DetectionFunctionContext` only contains a single field, `frame_name`,
which contains the file name of the frame on the CVAT server.
Contains invocation parameters and information about the current image.
The following fields are available:

- `frame_name` (`str`). The file name of the frame on the CVAT server.
- `threshold` (`float | None`). The confidence threshold that the function
should use to filter objects. If `None`, the function may apply a default
threshold at its discretion.

- `image` (`PIL.Image.Image`).
Contains image data.
Expand Down Expand Up @@ -195,6 +205,9 @@ If you use `allow_unmatched_label=True`, then such labels will be ignored,
and any shapes referring to them will be dropped.
Same logic applies to sub-label IDs.

It's possible to pass a custom confidence threshold to the function via the
`threshold` parameter.

`annotate_task` will raise a `BadFunctionError` exception
if it detects that the function violated the AA function protocol.

Expand Down
14 changes: 14 additions & 0 deletions tests/python/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,17 @@ def test_auto_annotate_with_parameters(self, fxt_new_task: Task):

annotations = fxt_new_task.get_annotations()
assert annotations.shapes

def test_auto_annotate_with_threshold(self, fxt_new_task: Task):
annotations = fxt_new_task.get_annotations()
assert not annotations.shapes

self.run_cli(
"auto-annotate",
str(fxt_new_task.id),
f"--function-module={__package__}.threshold_function",
"--threshold=0.75",
)

annotations = fxt_new_task.get_annotations()
assert annotations.shapes[0].points[0] == 0.75
21 changes: 21 additions & 0 deletions tests/python/cli/threshold_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (C) 2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

import cvat_sdk.auto_annotation as cvataa
import cvat_sdk.models as models
import PIL.Image

spec = cvataa.DetectionFunctionSpec(
labels=[
cvataa.label_spec("car", 0),
],
)


def detect(
context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
return [
cvataa.rectangle(0, [context.threshold, 1, 1, 1]),
]
51 changes: 47 additions & 4 deletions tests/python/sdk/test_auto_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,44 @@ def detect(context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]:
assert shapes[i].points == [5, 6, 7, 8]
assert shapes[i].rotation == 10

def test_threshold(self):
spec = cvataa.DetectionFunctionSpec(labels=[])

received_threshold = None

def detect(
context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
nonlocal received_threshold
received_threshold = context.threshold
return []

cvataa.annotate_task(
self.client,
self.task.id,
namespace(spec=spec, detect=detect),
threshold=0.75,
)

assert received_threshold == 0.75

cvataa.annotate_task(
self.client,
self.task.id,
namespace(spec=spec, detect=detect),
)

assert received_threshold is None

for bad_threshold in [-0.1, 1.1]:
with pytest.raises(ValueError):
cvataa.annotate_task(
self.client,
self.task.id,
namespace(spec=spec, detect=detect),
threshold=bad_threshold,
)

def _test_bad_function_spec(self, spec: cvataa.DetectionFunctionSpec, exc_match: str) -> None:
def detect(context, image):
assert False
Expand Down Expand Up @@ -575,8 +613,9 @@ def forward(self, images: list[torch.Tensor]) -> list[dict]:

return [
{
"boxes": torch.tensor([[1, 2, 3, 4]]),
"labels": torch.tensor([self._label_id]),
"boxes": torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]),
"labels": torch.tensor([self._label_id, self._label_id]),
"scores": torch.tensor([0.75, 0.74]),
}
]

Expand All @@ -599,15 +638,17 @@ def forward(self, images: list[torch.Tensor]) -> list[dict]:

return [
{
"labels": torch.tensor([self._label_id]),
"labels": torch.tensor([self._label_id, self._label_id]),
"keypoints": torch.tensor(
[
[
[hash(name) % 100, 0, 1 if name.startswith("right_") else 0]
for i, name in enumerate(self._keypoint_names)
]
],
[[0, 0, 1] for i, name in enumerate(self._keypoint_names)],
]
),
"scores": torch.tensor([0.75, 0.74]),
}
]

Expand Down Expand Up @@ -672,6 +713,7 @@ def test_torchvision_detection(self, monkeypatch: pytest.MonkeyPatch):
self.task.id,
td.create("fasterrcnn_resnet50_fpn_v2", "COCO_V1", test_param="expected_value"),
allow_unmatched_labels=True,
threshold=0.75,
)

annotations = self.task.get_annotations()
Expand All @@ -691,6 +733,7 @@ def test_torchvision_keypoint_detection(self, monkeypatch: pytest.MonkeyPatch):
self.task.id,
tkd.create("keypointrcnn_resnet50_fpn", "COCO_V1", test_param="expected_value"),
allow_unmatched_labels=True,
threshold=0.75,
)

annotations = self.task.get_annotations()
Expand Down

0 comments on commit 79c09ee

Please sign in to comment.