Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

πŸ”¨ Add default metrics to Engine #1769

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
368 changes: 56 additions & 312 deletions notebooks/000_getting_started/001_getting_started.ipynb

Large diffs are not rendered by default.

11 changes: 8 additions & 3 deletions src/anomalib/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class Engine:
Defaults to None.
pixel_metrics (str | list[str] | None, optional): Pixel metrics to be used for evaluation.
Defaults to None.
visualization_handlers (BaseVisualizationGenerator | list[BaseVisualizationGenerator] | None):
visualizers (BaseVisualizationGenerator | list[BaseVisualizationGenerator] | None):
Visualization parameters. Defaults to None.
**kwargs: PyTorch Lightning Trainer arguments.
"""
Expand Down Expand Up @@ -139,8 +139,13 @@ def __init__(
self.normalization = normalization
self.threshold = threshold
self.task = TaskType(task)
self.image_metric_names = image_metrics
self.pixel_metric_names = pixel_metrics
self.image_metric_names = image_metrics if image_metrics else ["AUROC", "F1Score"]

# pixel metrics are only used for segmentation tasks.
self.pixel_metric_names = None
if self.task == TaskType.SEGMENTATION:
self.pixel_metric_names = pixel_metrics if pixel_metrics is not None else ["AUROC", "F1Score"]

self.visualizers = visualizers

self.save_image = save_image
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/model/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from anomalib.models import AnomalyModule, get_available_models, get_model


def models() -> list[str]:
def models() -> set[str]:
"""Return all available models."""
return get_available_models()

Expand Down
Loading