-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
SDK/CLI: improve mask support in the auto-annotation functionality
While it is already possible to output mask shapes from AA functions, which the driver _will_ accept, it's not convenient to do so. Improve the practicalities of it in several ways: * Add `mask` and `polygon` helpers to the interface module. * Add a helper function to encode masks into the format CVAT expects. * Add a built-in torchvision-based instance segmentation function. * Add an equivalent of the `conv_mask_to_poly` parameter for Nuclio functions. Add another extra for the `masks` module, because NumPy is a fairly beefy dependency that most SDK users probably will not need (and conversely, I don't think we can implement `encode_mask` efficiently without using NumPy).
- Loading branch information
Showing
20 changed files
with
505 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
### Added | ||
|
||
- \[SDK\] Added new auto-annotation helpers (`mask`, `polygon`, `encode_mask`) | ||
to support AA functions that return masks or polygons | ||
(<https://github.com/cvat-ai/cvat/pull/8724>) | ||
|
||
- \[SDK\] Added a new built-in auto-annotation function, | ||
`torchvision_instance_segmentation` | ||
(<https://github.com/cvat-ai/cvat/pull/8724>) | ||
|
||
- \[SDK, CLI\] Added a new auto-annotation parameter, `conv_mask_to_poly` | ||
(`--conv-mask-to-poly` in the CLI) | ||
(<https://github.com/cvat-ai/cvat/pull/8724>) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
26 changes: 26 additions & 0 deletions
26
cvat-sdk/cvat_sdk/auto_annotation/functions/_torchvision.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright (C) 2024 CVAT.ai Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
from functools import cached_property | ||
|
||
import torchvision.models | ||
|
||
import cvat_sdk.auto_annotation as cvataa | ||
|
||
|
||
class TorchvisionFunction: | ||
def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None: | ||
weights_enum = torchvision.models.get_model_weights(model_name) | ||
self._weights = weights_enum[weights_name] | ||
self._transforms = self._weights.transforms() | ||
self._model = torchvision.models.get_model(model_name, weights=self._weights, **kwargs) | ||
self._model.eval() | ||
|
||
@cached_property | ||
def spec(self) -> cvataa.DetectionFunctionSpec: | ||
return cvataa.DetectionFunctionSpec( | ||
labels=[ | ||
cvataa.label_spec(cat, i) for i, cat in enumerate(self._weights.meta["categories"]) | ||
] | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
68 changes: 68 additions & 0 deletions
68
cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_instance_segmentation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Copyright (C) 2024 CVAT.ai Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
import math | ||
|
||
import PIL.Image | ||
from skimage import measure | ||
from torch import Tensor | ||
|
||
import cvat_sdk.auto_annotation as cvataa | ||
import cvat_sdk.models as models | ||
from cvat_sdk.masks import encode_mask | ||
|
||
from ._torchvision import TorchvisionFunction | ||
|
||
|
||
class _TorchvisionInstanceSegmentationFunction(TorchvisionFunction): | ||
def detect( | ||
self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image | ||
) -> list[models.LabeledShapeRequest]: | ||
conf_threshold = context.conf_threshold or 0 | ||
results = self._model([self._transforms(image)]) | ||
|
||
return [ | ||
shape | ||
for result in results | ||
for box, mask, label, score in zip( | ||
result["boxes"], result["masks"], result["labels"], result["scores"] | ||
) | ||
if score >= conf_threshold | ||
for shape in self._generate_shapes(context, box, mask, label) | ||
] | ||
|
||
def _generate_shapes( | ||
self, context: cvataa.DetectionFunctionContext, box: Tensor, mask: Tensor, label: Tensor | ||
) -> list[models.LabeledShapeRequest]: | ||
LEVEL = 0.5 | ||
|
||
if context.conv_mask_to_poly: | ||
# Since we treat mask values of exactly LEVEL as true, we'd like them | ||
# to also be considered high by find_contours. And for that, the level | ||
# parameter must be slightly less than LEVEL. | ||
contours = measure.find_contours( | ||
mask[0].detach().numpy(), level=math.nextafter(LEVEL, 0) | ||
) | ||
if not contours: | ||
return [] | ||
|
||
contour = contours[0] | ||
if len(contour) < 3: | ||
return [] | ||
|
||
contour = measure.approximate_polygon(contour, tolerance=2.5) | ||
|
||
return [ | ||
cvataa.polygon( | ||
label.item(), | ||
contour[:, ::-1].ravel().tolist(), | ||
) | ||
] | ||
else: | ||
return [ | ||
cvataa.mask(label.item(), encode_mask((mask[0] >= LEVEL).numpy(), box.tolist())) | ||
] | ||
|
||
|
||
create = _TorchvisionInstanceSegmentationFunction |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.