Skip to content

Commit

Permalink
Rename threshold to conf_threshold for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
SpecLad committed Nov 14, 2024
1 parent 79c09ee commit 06be1ca
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 19 deletions.
6 changes: 3 additions & 3 deletions cvat-cli/src/cvat_cli/_internal/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ def configure_parser(self, parser: argparse.ArgumentParser) -> None:
)

parser.add_argument(
"--threshold",
"--conf-threshold",
type=parse_threshold,
help="Confidence threshold for filtering detections",
default=None,
Expand All @@ -486,7 +486,7 @@ def execute(
function_parameters: dict[str, Any],
clear_existing: bool = False,
allow_unmatched_labels: bool = False,
threshold: Optional[float],
conf_threshold: Optional[float],
) -> None:
if function_module is not None:
function = importlib.import_module(function_module)
Expand All @@ -511,5 +511,5 @@ def execute(
pbar=DeferredTqdmProgressReporter(),
clear_existing=clear_existing,
allow_unmatched_labels=allow_unmatched_labels,
threshold=threshold,
conf_threshold=conf_threshold,
)
14 changes: 7 additions & 7 deletions cvat-sdk/cvat_sdk/auto_annotation/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame:
@attrs.frozen(kw_only=True)
class _DetectionFunctionContextImpl(DetectionFunctionContext):
frame_name: str
threshold: Optional[float] = None
conf_threshold: Optional[float] = None


def annotate_task(
Expand All @@ -234,7 +234,7 @@ def annotate_task(
pbar: Optional[ProgressReporter] = None,
clear_existing: bool = False,
allow_unmatched_labels: bool = False,
threshold: Optional[float] = None,
conf_threshold: Optional[float] = None,
) -> None:
"""
Downloads data for the task with the given ID, applies the given function to it
Expand Down Expand Up @@ -267,15 +267,15 @@ def annotate_task(
If it's set to true, then such labels are allowed, and any annotations returned by the
function that refer to this label are ignored. Otherwise, BadFunctionError is raised.
The threshold parameter must be None or a number between 0 and 1. It will be passed
to the function as the threshold attribute of the context object.
The conf_threshold parameter must be None or a number between 0 and 1. It will be passed
to the function as the conf_threshold attribute of the context object.
"""

if pbar is None:
pbar = NullProgressReporter()

if threshold is not None and not 0 <= threshold <= 1:
raise ValueError("threshold must be None or a number between 0 and 1")
if conf_threshold is not None and not 0 <= conf_threshold <= 1:
raise ValueError("conf_threshold must be None or a number between 0 and 1")

dataset = TaskDataset(client, task_id, load_annotations=False)

Expand All @@ -293,7 +293,7 @@ def annotate_task(
with pbar.task(total=len(dataset.samples), unit="samples"):
for sample in pbar.iter(dataset.samples):
frame_shapes = function.detect(
_DetectionFunctionContextImpl(frame_name=sample.frame_name, threshold=threshold),
_DetectionFunctionContextImpl(frame_name=sample.frame_name, conf_threshold=conf_threshold),
sample.media.load_image(),
)
mapper.validate_and_remap(frame_shapes, sample.frame_index)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ def spec(self) -> cvataa.DetectionFunctionSpec:
def detect(
self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
threshold = context.threshold or 0
conf_threshold = context.conf_threshold or 0
results = self._model([self._transforms(image)])

return [
cvataa.rectangle(label.item(), [x.item() for x in box])
for result in results
for box, label, score in zip(result["boxes"], result["labels"], result["scores"])
if score >= threshold
if score >= conf_threshold
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def spec(self) -> cvataa.DetectionFunctionSpec:
def detect(
self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
threshold = context.threshold or 0
conf_threshold = context.conf_threshold or 0
results = self._model([self._transforms(image)])

return [
Expand All @@ -57,7 +57,7 @@ def detect(
for keypoints, label, score in zip(
result["keypoints"], result["labels"], result["scores"]
)
if score >= threshold
if score >= conf_threshold
]


Expand Down
2 changes: 1 addition & 1 deletion cvat-sdk/cvat_sdk/auto_annotation/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def frame_name(self) -> str:

@property
@abc.abstractmethod
def threshold(self) -> Optional[float]:
def conf_threshold(self) -> Optional[float]:
"""
The confidence threshold that the function should use for filtering
detections.
Expand Down
8 changes: 4 additions & 4 deletions site/content/en/docs/api_sdk/sdk/auto-annotation.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class TorchvisionDetectionFunction:
self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
) -> list[models.LabeledShapeRequest]:
# determine the threshold for filtering results
threshold = context.threshold or 0
conf_threshold = context.conf_threshold or 0

# convert the input into a form the model can understand
transformed_image = [self._transforms(image)]
Expand All @@ -85,7 +85,7 @@ class TorchvisionDetectionFunction:
cvataa.rectangle(label.item(), [x.item() for x in box])
for result in results
for box, label, score in zip(result["boxes"], result["labels"], result["scores"])
if score >= threshold
if score >= conf_threshold
]

# log into the CVAT server
Expand Down Expand Up @@ -122,7 +122,7 @@ that these objects must follow.
The following fields are available:

- `frame_name` (`str`). The file name of the frame on the CVAT server.
- `threshold` (`float | None`). The confidence threshold that the function
- `conf_threshold` (`float | None`). The confidence threshold that the function
should use to filter objects. If `None`, the function may apply a default
threshold at its discretion.

Expand Down Expand Up @@ -206,7 +206,7 @@ and any shapes referring to them will be dropped.
Same logic applies to sub-label IDs.

It's possible to pass a custom confidence threshold to the function via the
`threshold` parameter.
`conf_threshold` parameter.

`annotate_task` will raise a `BadFunctionError` exception
if it detects that the function violated the AA function protocol.
Expand Down

0 comments on commit 06be1ca

Please sign in to comment.