diff --git a/CHANGELOG.md b/CHANGELOG.md index 340641fb7c..cf9807af26 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ### Added +- Add `Datumaro` annotation format support by @ashwinvaidya17 in https://github.com/openvinotoolkit/anomalib/pull/2377 - Add `AUPIMO` tutorials notebooks in https://github.com/openvinotoolkit/anomalib/pull/2330 and https://github.com/openvinotoolkit/anomalib/pull/2336 - Add `AUPIMO` metric by [jpcbertoldo](https://github.com/jpcbertoldo) in https://github.com/openvinotoolkit/anomalib/pull/1726 and refactored by [ashwinvaidya17](https://github.com/ashwinvaidya17) in https://github.com/openvinotoolkit/anomalib/pull/2329 diff --git a/configs/data/datumaro.yaml b/configs/data/datumaro.yaml new file mode 100644 index 0000000000..31867f34fa --- /dev/null +++ b/configs/data/datumaro.yaml @@ -0,0 +1,15 @@ +class_path: anomalib.data.Datumaro +init_args: + root: "datasets/datumaro" + train_batch_size: 32 + eval_batch_size: 32 + num_workers: 8 + image_size: null + transform: null + train_transform: null + eval_transform: null + test_split_mode: FROM_DIR + test_split_ratio: 0.2 + val_split_mode: FROM_TEST + val_split_ratio: 0.5 + seed: null diff --git a/src/anomalib/data/__init__.py b/src/anomalib/data/__init__.py index e7eaf11156..0ad469ac69 100644 --- a/src/anomalib/data/__init__.py +++ b/src/anomalib/data/__init__.py @@ -14,7 +14,7 @@ from .base import AnomalibDataModule, AnomalibDataset from .depth import DepthDataFormat, Folder3D, MVTec3D -from .image import BTech, Folder, ImageDataFormat, Kolektor, MVTec, Visa +from .image import BTech, Datumaro, Folder, ImageDataFormat, Kolektor, MVTec, Visa from .predict import PredictDataset from .utils import LabelName from .video import Avenue, ShanghaiTech, UCSDped, VideoDataFormat @@ -70,6 +70,7 @@ def get_datamodule(config: DictConfig | ListConfig | dict) -> AnomalibDataModule "VideoDataFormat", "get_datamodule", "BTech", + "Datumaro", "Folder", "Folder3D", "PredictDataset", diff --git a/src/anomalib/data/image/__init__.py b/src/anomalib/data/image/__init__.py index 0bea0f07ad..147db09418 100644 --- a/src/anomalib/data/image/__init__.py +++ b/src/anomalib/data/image/__init__.py @@ -9,6 +9,7 @@ from enum import Enum from .btech import BTech +from .datumaro import Datumaro from .folder import Folder from .kolektor import Kolektor from .mvtec import MVTec @@ -18,13 +19,14 @@ class ImageDataFormat(str, Enum): """Supported Image Dataset Types.""" - MVTEC = "mvtec" - MVTEC_3D = "mvtec_3d" BTECH = "btech" - KOLEKTOR = "kolektor" + DATUMARO = "datumaro" FOLDER = "folder" FOLDER_3D = "folder_3d" + KOLEKTOR = "kolektor" + MVTEC = "mvtec" + MVTEC_3D = "mvtec_3d" VISA = "visa" -__all__ = ["BTech", "Folder", "Kolektor", "MVTec", "Visa"] +__all__ = ["BTech", "Datumaro", "Folder", "Kolektor", "MVTec", "Visa"] diff --git a/src/anomalib/data/image/datumaro.py b/src/anomalib/data/image/datumaro.py new file mode 100644 index 0000000000..b4836990ec --- /dev/null +++ b/src/anomalib/data/image/datumaro.py @@ -0,0 +1,226 @@ +"""Dataloader for Datumaro format. + +Note: This currently only works for annotations exported from Intel Geti™. +""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import json +from pathlib import Path + +import pandas as pd +from torchvision.transforms.v2 import Transform + +from anomalib import TaskType +from anomalib.data.base import AnomalibDataModule, AnomalibDataset +from anomalib.data.utils import LabelName, Split, TestSplitMode, ValSplitMode + + +def make_datumaro_dataset(root: str | Path, split: str | Split | None = None) -> pd.DataFrame: + """Make Datumaro Dataset. + + Assumes the following directory structure: + + dataset + ├── annotations + │ └── default.json + └── images + └── default + ├── image1.jpg + ├── image2.jpg + └── ... + + Args: + root (str | Path): Path to the dataset root directory. + split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST. + Defaults to ``None``. + + Examples: + >>> root = Path("path/to/dataset") + >>> samples = make_datumaro_dataset(root) + >>> samples.head() + image_path label label_index split mask_path + 0 path/to/dataset... Normal 0 Split.TRAIN + 1 path/to/dataset... Normal 0 Split.TRAIN + 2 path/to/dataset... Normal 0 Split.TRAIN + 3 path/to/dataset... Normal 0 Split.TRAIN + 4 path/to/dataset... Normal 0 Split.TRAIN + + + Returns: + DataFrame: an output dataframe containing samples for the requested split (ie., train or test). + """ + annotation_file = Path(root) / "annotations" / "default.json" + with annotation_file.open() as f: + annotations = json.load(f) + + categories = annotations["categories"] + categories = {idx: label["name"] for idx, label in enumerate(categories["label"]["labels"])} + + samples = [] + for item in annotations["items"]: + image_path = Path(root) / "images" / "default" / item["image"]["path"] + label_index = item["annotations"][0]["label_id"] + label = categories[label_index] + samples.append({ + "image_path": str(image_path), + "label": label, + "label_index": label_index, + "split": None, + "mask_path": "", # mask is provided in the annotation file and is not on disk. + }) + samples_df = pd.DataFrame( + samples, + columns=["image_path", "label", "label_index", "split", "mask_path"], + index=range(len(samples)), + ) + # Create test/train split + # By default assign all "Normal" samples to train and all "Anomalous" samples to test + samples_df.loc[samples_df["label_index"] == LabelName.NORMAL, "split"] = Split.TRAIN + samples_df.loc[samples_df["label_index"] == LabelName.ABNORMAL, "split"] = Split.TEST + + # Get the data frame for the split. + if split: + samples_df = samples_df[samples_df.split == split].reset_index(drop=True) + + return samples_df + + +class DatumaroDataset(AnomalibDataset): + """Datumaro dataset class. + + Args: + task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``. + root (str | Path): Path to the dataset root directory. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST + Defaults to ``None``. + + + Examples: + .. code-block:: python + + from anomalib.data.image.datumaro import DatumaroDataset + from torchvision.transforms.v2 import Resize + + dataset = DatumaroDataset(root=root, + task="classification", + transform=Resize((256, 256)), + ) + print(dataset[0].keys()) + # Output: dict_keys(['dm_format_version', 'infos', 'categories', 'items']) + + """ + + def __init__( + self, + task: TaskType, + root: str | Path, + transform: Transform | None = None, + split: str | Split | None = None, + ) -> None: + super().__init__(task, transform) + self.split = split + self.samples = make_datumaro_dataset(root, split) + + +class Datumaro(AnomalibDataModule): + """Datumaro datamodule. + + Args: + root (str | Path): Path to the dataset root directory. + train_batch_size (int): Batch size for training dataloader. + Defaults to ``32``. + eval_batch_size (int): Batch size for evaluation dataloader. + Defaults to ``32``. + num_workers (int): Number of workers for dataloaders. + Defaults to ``8``. + task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``. + Defaults to ``TaskType.CLASSIFICATION``. Currently only supports classification. + image_size (tuple[int, int], optional): Size to which input images should be resized. + Defaults to ``None``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + train_transform (Transform, optional): Transforms that should be applied to the input images during training. + Defaults to ``None``. + eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation. + Defaults to ``None``. + test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained. + Defaults to ``TestSplitMode.FROM_DIR``. + test_split_ratio (float): Fraction of images from the train set that will be reserved for testing. + Defaults to ``0.2``. + val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. + Defaults to ``ValSplitMode.SAME_AS_TEST``. + val_split_ratio (float): Fraction of train or test images that will be reserved for validation. + Defaults to ``0.5``. + seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + Defualts to ``None``. + + Examples: + To create a Datumaro datamodule + + >>> from pathlib import Path + >>> from torchvision.transforms.v2 import Resize + >>> root = Path("path/to/dataset") + >>> datamodule = Datumaro(root, transform=Resize((256, 256))) + >>> datamodule.setup() + >>> i, data = next(enumerate(datamodule.train_dataloader())) + >>> data.keys() + dict_keys(['image_path', 'label', 'image']) + + >>> data["image"].shape + torch.Size([32, 3, 256, 256]) + """ + + def __init__( + self, + root: str | Path, + train_batch_size: int = 32, + eval_batch_size: int = 32, + num_workers: int = 8, + task: TaskType = TaskType.CLASSIFICATION, + image_size: tuple[int, int] | None = None, + transform: Transform | None = None, + train_transform: Transform | None = None, + eval_transform: Transform | None = None, + test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR, + test_split_ratio: float = 0.5, + val_split_mode: ValSplitMode | str = ValSplitMode.FROM_TEST, + val_split_ratio: float = 0.5, + seed: int | None = None, + ) -> None: + if task != TaskType.CLASSIFICATION: + msg = "Datumaro dataloader currently only supports classification task." + raise ValueError(msg) + super().__init__( + train_batch_size=train_batch_size, + eval_batch_size=eval_batch_size, + num_workers=num_workers, + val_split_mode=val_split_mode, + val_split_ratio=val_split_ratio, + test_split_mode=test_split_mode, + test_split_ratio=test_split_ratio, + image_size=image_size, + transform=transform, + train_transform=train_transform, + eval_transform=eval_transform, + seed=seed, + ) + self.root = root + self.task = task + + def _setup(self, _stage: str | None = None) -> None: + self.train_data = DatumaroDataset( + task=self.task, + root=self.root, + transform=self.train_transform, + split=Split.TRAIN, + ) + self.test_data = DatumaroDataset( + task=self.task, + root=self.root, + transform=self.eval_transform, + split=Split.TEST, + ) diff --git a/tests/helpers/data.py b/tests/helpers/data.py index 0ad699fb2f..60433df9eb 100644 --- a/tests/helpers/data.py +++ b/tests/helpers/data.py @@ -5,6 +5,7 @@ from __future__ import annotations +import json import shutil from contextlib import ContextDecorator from pathlib import Path @@ -319,6 +320,43 @@ def __init__( self.min_size = min_size self.image_generator = DummyImageGenerator(image_shape=image_shape, rng=self.rng) + def _generate_dummy_datumaro_dataset(self) -> None: + """Generates dummy Datumaro dataset in a temporary directory.""" + # generate images + image_root = self.dataset_root / "images" / "default" + image_root.mkdir(parents=True, exist_ok=True) + + file_names: list[str] = [] + + # Create normal images + for i in range(self.num_train + self.num_test): + label = LabelName.NORMAL + image_filename = image_root / f"normal_{i:03}.png" + file_names.append(image_filename) + self.image_generator.generate_image(label, image_filename) + + # Create abnormal images + for i in range(self.num_test): + label = LabelName.ABNORMAL + image_filename = image_root / f"abnormal_{i:03}.png" + file_names.append(image_filename) + self.image_generator.generate_image(label, image_filename) + + # create annotation file + annotation_file = self.dataset_root / "annotations" / "default.json" + annotation_file.parent.mkdir(parents=True, exist_ok=True) + annotations = { + "categories": {"label": {"labels": [{"name": "Normal"}, {"name": "Anomalous"}]}}, + "items": [], + } + for file_name in file_names: + annotations["items"].append({ + "annotations": [{"label_id": 1 if "abnormal" in str(file_name) else 0}], + "image": {"path": file_name.name}, + }) + with annotation_file.open("w") as f: + json.dump(annotations, f) + def _generate_dummy_mvtec_dataset( self, normal_dir: str = "good", diff --git a/tests/unit/data/image/test_datumaro.py b/tests/unit/data/image/test_datumaro.py new file mode 100644 index 0000000000..2aef9ae715 --- /dev/null +++ b/tests/unit/data/image/test_datumaro.py @@ -0,0 +1,39 @@ +"""Unit tests - Datumaro Datamodule.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +import pytest + +from anomalib import TaskType +from anomalib.data import Datumaro +from tests.unit.data.base.image import _TestAnomalibImageDatamodule + + +class TestDatumaro(_TestAnomalibImageDatamodule): + """Datumaro Datamodule Unit Tests.""" + + @pytest.fixture() + @staticmethod + def datamodule(dataset_path: Path, task_type: TaskType) -> Datumaro: + """Create and return a Datumaro datamodule.""" + if task_type != TaskType.CLASSIFICATION: + pytest.skip("Datumaro only supports classification tasks.") + + _datamodule = Datumaro( + root=dataset_path / "datumaro", + task=task_type, + train_batch_size=4, + eval_batch_size=4, + ) + _datamodule.setup() + + return _datamodule + + @pytest.fixture() + @staticmethod + def fxt_data_config_path() -> str: + """Return the path to the test data config.""" + return "configs/data/datumaro.yaml"