Skip to content

Commit

Permalink
SDK/CLI: improve mask support in the auto-annotation functionality
Browse files Browse the repository at this point in the history
While it is already possible to output mask shapes from AA functions, which
the driver _will_ accept, it's not convenient to do so. Improve the
practicalities of it in several ways:

* Add `mask` and `polygon` helpers to the interface module.

* Add a helper function to encode masks into the format CVAT expects.

* Add a built-in torchvision-based instance segmentation function.

* Add an equivalent of the `conv_mask_to_poly` parameter for Nuclio
  functions.

Add another extra for the `masks` module, because NumPy is a fairly
beefy dependency that most SDK users probably will not need (and
conversely, I don't think we can implement `encode_mask` efficiently
without using NumPy).
  • Loading branch information
SpecLad committed Nov 20, 2024
1 parent 3eec9fe commit 00222de
Show file tree
Hide file tree
Showing 20 changed files with 505 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ jobs:
- name: Install SDK
run: |
pip3 install -r ./tests/python/requirements.txt \
-e './cvat-sdk[pytorch]' -e ./cvat-cli \
-e './cvat-sdk[masks,pytorch]' -e ./cvat-cli \
--extra-index-url https://download.pytorch.org/whl/cpu
- name: Running REST API and SDK tests
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ jobs:
- name: Install SDK
run: |
pip3 install -r ./tests/python/requirements.txt \
-e './cvat-sdk[pytorch]' -e ./cvat-cli \
-e './cvat-sdk[masks,pytorch]' -e ./cvat-cli \
--extra-index-url https://download.pytorch.org/whl/cpu
- name: Run REST API and SDK tests
Expand Down
13 changes: 13 additions & 0 deletions changelog.d/20241120_143739_roman_aa_masks.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
### Added

- \[SDK\] Added new auto-annotation helpers (`mask`, `polygon`, `encode_mask`)
to support AA functions that return masks or polygons
(<https://github.com/cvat-ai/cvat/pull/8724>)

- \[SDK\] Added a new built-in auto-annotation function,
`torchvision_instance_segmentation`
(<https://github.com/cvat-ai/cvat/pull/8724>)

- \[SDK, CLI\] Added a new auto-annotation parameter, `conv_mask_to_poly`
(`--conv-mask-to-poly` in the CLI)
(<https://github.com/cvat-ai/cvat/pull/8724>)
8 changes: 8 additions & 0 deletions cvat-cli/src/cvat_cli/_internal/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,12 @@ def configure_parser(self, parser: argparse.ArgumentParser) -> None:
default=None,
)

parser.add_argument(
"--conv-mask-to-poly",
action="store_true",
help="Convert mask shapes to polygon shapes",
)

def execute(
self,
client: Client,
Expand All @@ -487,6 +493,7 @@ def execute(
clear_existing: bool = False,
allow_unmatched_labels: bool = False,
conf_threshold: Optional[float],
conv_mask_to_poly: bool,
) -> None:
if function_module is not None:
function = importlib.import_module(function_module)
Expand All @@ -512,4 +519,5 @@ def execute(
clear_existing=clear_existing,
allow_unmatched_labels=allow_unmatched_labels,
conf_threshold=conf_threshold,
conv_mask_to_poly=conv_mask_to_poly,
)
9 changes: 8 additions & 1 deletion cvat-sdk/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@ To install a prebuilt package, run the following command in the terminal:
pip install cvat-sdk
```

To use the PyTorch adapter, request the `pytorch` extra:
To use the `cvat_sdk.masks` module, request the `masks` extra:

```bash
pip install "cvat-sdk[masks]"
```

To use the PyTorch adapter or the built-in PyTorch-based auto-annotation functions,
request the `pytorch` extra:

```bash
pip install "cvat-sdk[pytorch]"
Expand Down
19 changes: 19 additions & 0 deletions cvat-sdk/cvat_sdk/auto_annotation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,27 @@
keypoint,
keypoint_spec,
label_spec,
mask,
polygon,
rectangle,
shape,
skeleton,
skeleton_label_spec,
)

__all__ = [
"annotate_task",
"BadFunctionError",
"DetectionFunction",
"DetectionFunctionContext",
"DetectionFunctionSpec",
"keypoint_spec",
"keypoint",
"label_spec",
"mask",
"polygon",
"rectangle",
"shape",
"skeleton_label_spec",
"skeleton",
]
20 changes: 18 additions & 2 deletions cvat-sdk/cvat_sdk/auto_annotation/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,11 @@ def __init__(
ds_labels: Sequence[models.ILabel],
*,
allow_unmatched_labels: bool,
conv_mask_to_poly: bool,
) -> None:
self._logger = logger
self._allow_unmatched_labels = allow_unmatched_labels
self._conv_mask_to_poly = conv_mask_to_poly

ds_labels_by_name = {ds_label.name: ds_label for ds_label in ds_labels}

Expand Down Expand Up @@ -217,13 +219,19 @@ def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame:
if getattr(shape, "elements", None):
raise BadFunctionError("function output non-skeleton shape with elements")

if shape.type.value == "mask" and self._conv_mask_to_poly:
raise BadFunctionError(
"function output mask shape despite conv_mask_to_poly=True"
)

shapes[:] = new_shapes


@attrs.frozen(kw_only=True)
class _DetectionFunctionContextImpl(DetectionFunctionContext):
frame_name: str
conf_threshold: Optional[float] = None
conv_mask_to_poly: bool = False


def annotate_task(
Expand All @@ -235,6 +243,7 @@ def annotate_task(
clear_existing: bool = False,
allow_unmatched_labels: bool = False,
conf_threshold: Optional[float] = None,
conv_mask_to_poly: bool = False,
) -> None:
"""
Downloads data for the task with the given ID, applies the given function to it
Expand Down Expand Up @@ -268,7 +277,11 @@ def annotate_task(
function that refer to this label are ignored. Otherwise, BadFunctionError is raised.
The conf_threshold parameter must be None or a number between 0 and 1. It will be passed
to the function as the conf_threshold attribute of the context object.
to the AA function as the conf_threshold attribute of the context object.
The conv_mask_to_poly parameter will be passed to the AA function as the conv_mask_to_poly
attribute of the context object. If it's true, and the AA function returns any mask shapes,
BadFunctionError will be raised.
"""

if pbar is None:
Expand All @@ -286,6 +299,7 @@ def annotate_task(
function.spec.labels,
dataset.labels,
allow_unmatched_labels=allow_unmatched_labels,
conv_mask_to_poly=conv_mask_to_poly,
)

shapes = []
Expand All @@ -294,7 +308,9 @@ def annotate_task(
for sample in pbar.iter(dataset.samples):
frame_shapes = function.detect(
_DetectionFunctionContextImpl(
frame_name=sample.frame_name, conf_threshold=conf_threshold
frame_name=sample.frame_name,
conf_threshold=conf_threshold,
conv_mask_to_poly=conv_mask_to_poly,
),
sample.media.load_image(),
)
Expand Down
26 changes: 26 additions & 0 deletions cvat-sdk/cvat_sdk/auto_annotation/functions/_torchvision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (C) 2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from functools import cached_property

import torchvision.models

import cvat_sdk.auto_annotation as cvataa


class TorchvisionFunction:
def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None:
weights_enum = torchvision.models.get_model_weights(model_name)
self._weights = weights_enum[weights_name]
self._transforms = self._weights.transforms()
self._model = torchvision.models.get_model(model_name, weights=self._weights, **kwargs)
self._model.eval()

@cached_property
def spec(self) -> cvataa.DetectionFunctionSpec:
return cvataa.DetectionFunctionSpec(
labels=[
cvataa.label_spec(cat, i) for i, cat in enumerate(self._weights.meta["categories"])
]
)
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,15 @@
#
# SPDX-License-Identifier: MIT

from functools import cached_property

import PIL.Image
import torchvision.models

import cvat_sdk.auto_annotation as cvataa
import cvat_sdk.models as models

from ._torchvision import TorchvisionFunction

class _TorchvisionDetectionFunction:
def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None:
weights_enum = torchvision.models.get_model_weights(model_name)
self._weights = weights_enum[weights_name]
self._transforms = self._weights.transforms()
self._model = torchvision.models.get_model(model_name, weights=self._weights, **kwargs)
self._model.eval()

@cached_property
def spec(self) -> cvataa.DetectionFunctionSpec:
return cvataa.DetectionFunctionSpec(
labels=[
cvataa.label_spec(cat, i) for i, cat in enumerate(self._weights.meta["categories"])
]
)

class _TorchvisionDetectionFunction(TorchvisionFunction):
def detect(
self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright (C) 2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

import math

import PIL.Image
from skimage import measure
from torch import Tensor

import cvat_sdk.auto_annotation as cvataa
import cvat_sdk.models as models
from cvat_sdk.masks import encode_mask

from ._torchvision import TorchvisionFunction


class _TorchvisionInstanceSegmentationFunction(TorchvisionFunction):
def detect(
self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
conf_threshold = context.conf_threshold or 0
results = self._model([self._transforms(image)])

return [
shape
for result in results
for box, mask, label, score in zip(
result["boxes"], result["masks"], result["labels"], result["scores"]
)
if score >= conf_threshold
for shape in self._generate_shapes(context, box, mask, label)
]

def _generate_shapes(
self, context: cvataa.DetectionFunctionContext, box: Tensor, mask: Tensor, label: Tensor
) -> list[models.LabeledShapeRequest]:
LEVEL = 0.5

if context.conv_mask_to_poly:
# Since we treat mask values of exactly LEVEL as true, we'd like them
# to also be considered high by find_contours. And for that, the level
# parameter must be slightly less than LEVEL.
contours = measure.find_contours(
mask[0].detach().numpy(), level=math.nextafter(LEVEL, 0)
)
if not contours:
return []

contour = contours[0]
if len(contour) < 3:
return []

contour = measure.approximate_polygon(contour, tolerance=2.5)

return [
cvataa.polygon(
label.item(),
contour[:, ::-1].ravel().tolist(),
)
]
else:
return [
cvataa.mask(label.item(), encode_mask((mask[0] >= LEVEL).numpy(), box.tolist()))
]


create = _TorchvisionInstanceSegmentationFunction
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,14 @@
from functools import cached_property

import PIL.Image
import torchvision.models

import cvat_sdk.auto_annotation as cvataa
import cvat_sdk.models as models

from ._torchvision import TorchvisionFunction

class _TorchvisionKeypointDetectionFunction:
def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None:
weights_enum = torchvision.models.get_model_weights(model_name)
self._weights = weights_enum[weights_name]
self._transforms = self._weights.transforms()
self._model = torchvision.models.get_model(model_name, weights=self._weights, **kwargs)
self._model.eval()

class _TorchvisionKeypointDetectionFunction(TorchvisionFunction):
@cached_property
def spec(self) -> cvataa.DetectionFunctionSpec:
return cvataa.DetectionFunctionSpec(
Expand Down
25 changes: 25 additions & 0 deletions cvat-sdk/cvat_sdk/auto_annotation/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ def conf_threshold(self) -> Optional[float]:
If the function is not able to estimate confidence levels, it can ignore this value.
"""

@property
@abc.abstractmethod
def conv_mask_to_poly(self) -> bool:
"""
If this is true, the function must convert any mask shapes to polygon shapes
before returning them.
If the function does not return any mask shapes, then it can ignore this value.
"""


class DetectionFunction(Protocol):
"""
Expand Down Expand Up @@ -168,6 +178,21 @@ def rectangle(label_id: int, points: Sequence[float], **kwargs) -> models.Labele
return shape(label_id, type="rectangle", points=points, **kwargs)


def polygon(label_id: int, points: Sequence[float], **kwargs) -> models.LabeledShapeRequest:
"""Helper factory function for LabeledShapeRequest with frame=0 and type="polygon"."""
return shape(label_id, type="polygon", points=points, **kwargs)


def mask(label_id: int, points: Sequence[float], **kwargs) -> models.LabeledShapeRequest:
"""
Helper factory function for LabeledShapeRequest with frame=0 and type="mask".
It's recommended to use the cvat.masks.encode_mask function to build the
points argument.
"""
return shape(label_id, type="mask", points=points, **kwargs)


def skeleton(
label_id: int, elements: Sequence[models.SubLabeledShapeRequest], **kwargs
) -> models.LabeledShapeRequest:
Expand Down
Loading

0 comments on commit 00222de

Please sign in to comment.