Skip to content

Commit 2c968f5

Browse files
author
pfinashx
committed
Merge branch 'develop' into pf/adding_anomaly_training_tests
2 parents 05f1b66 + 73ec03e commit 2c968f5

File tree

39 files changed

+1033
-334
lines changed

39 files changed

+1033
-334
lines changed

external/anomaly/ote_anomalib/callbacks/inference.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -58,23 +58,27 @@ def on_predict_epoch_end(self, _trainer: pl.Trainer, pl_module: AnomalyModule, o
5858
for dataset_item, pred_score, pred_label, anomaly_map, pred_mask in zip(
5959
self.ote_dataset, pred_scores, pred_labels, anomaly_maps, pred_masks
6060
):
61-
label = self.anomalous_label if pred_label else self.normal_label
62-
probability = (1 - pred_score) if pred_score < 0.5 else pred_score
63-
dataset_item.append_labels([ScoredLabel(label=label, probability=float(probability))])
61+
probability = pred_score if pred_label else 1 - pred_score
62+
if self.task_type == TaskType.ANOMALY_CLASSIFICATION:
63+
label = self.anomalous_label if pred_label else self.normal_label
6464
if self.task_type == TaskType.ANOMALY_DETECTION:
65-
dataset_item.append_annotations(
66-
annotations=create_detection_annotation_from_anomaly_heatmap(
67-
hard_prediction=pred_mask,
68-
soft_prediction=anomaly_map,
69-
label_map=self.label_map,
70-
)
65+
annotations = create_detection_annotation_from_anomaly_heatmap(
66+
hard_prediction=pred_mask,
67+
soft_prediction=anomaly_map,
68+
label_map=self.label_map,
7169
)
70+
dataset_item.append_annotations(annotations)
71+
label = self.normal_label if len(annotations) == 0 else self.anomalous_label
7272
elif self.task_type == TaskType.ANOMALY_SEGMENTATION:
73-
mask = pred_mask.squeeze().astype(np.uint8)
74-
dataset_item.append_annotations(
75-
create_annotation_from_segmentation_map(mask, anomaly_map.squeeze(), self.label_map)
73+
annotations = create_annotation_from_segmentation_map(
74+
hard_prediction=pred_mask.squeeze().astype(np.uint8),
75+
soft_prediction=anomaly_map.squeeze(),
76+
label_map=self.label_map,
7677
)
78+
dataset_item.append_annotations(annotations)
79+
label = self.normal_label if len(annotations) == 0 else self.anomalous_label
7780

81+
dataset_item.append_labels([ScoredLabel(label=label, probability=float(probability))])
7882
dataset_item.append_metadata_item(
7983
ResultMediaEntity(
8084
name="Anomaly Map",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
# Copyright (C) 2020-2022 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
5+
"""Create MVTec AD (CC BY-NC-SA 4.0) JSON Annotations for OTE CLI.
6+
7+
Description:
8+
This script converts MVTec AD dataset masks to OTE CLI annotation format for
9+
classification, detection and segmentation tasks.
10+
11+
License:
12+
MVTec AD dataset is released under the Creative Commons
13+
Attribution-NonCommercial-ShareAlike 4.0 International License
14+
(CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/).
15+
16+
Reference:
17+
- Paul Bergmann, Kilian Batzner, Michael Fauser, David Sattlegger, Carsten Steger:
18+
The MVTec Anomaly Detection Dataset: A Comprehensive Real-World Dataset for
19+
Unsupervised Anomaly Detection; in: International Journal of Computer Vision
20+
129(4):1038-1059, 2021, DOI: 10.1007/s11263-020-01400-4.
21+
22+
- Paul Bergmann, Michael Fauser, David Sattlegger, Carsten Steger: MVTec AD —
23+
A Comprehensive Real-World Dataset for Unsupervised Anomaly Detection;
24+
in: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR),
25+
9584-9592, 2019, DOI: 10.1109/CVPR.2019.00982.
26+
27+
Example:
28+
Assume that MVTec AD dataset is located in "./data/anomaly/MVTec/" from the root
29+
directory in training_extensions. JSON annotations could be created by running the
30+
following:
31+
32+
>>> import os
>>> os.getcwd()
33+
'~/training_extensions'
34+
>>> os.listdir("./data/anomaly")
35+
['detection', 'shapes', 'segmentation', 'MVTec', 'classification']
36+
37+
The following script will generate the classification, detection and segmentation
38+
JSON annotations to each category in ./data/anomaly/MVTec dataset.
39+
40+
>>> python external/anomaly/ote_anomalib/data/create_mvtec_ad_json_annotations.py \
41+
... --data_path ./data/anomaly/MVTec/
42+
"""
43+
44+
import json
45+
import os
46+
from argparse import ArgumentParser, Namespace
47+
from pathlib import Path
48+
from typing import Any, Dict, List, Optional
49+
50+
import cv2
51+
import pandas as pd
52+
from anomalib.data.mvtec import make_mvtec_dataset
53+
54+
55+
def create_bboxes_from_mask(mask_path: str) -> List[List[float]]:
    """Create normalized bounding boxes from a binary mask.

    Each connected component found in the mask (other than the background)
    produces one box. Coordinates are divided by the mask width/height, so
    they are expressed relative to the image size.

    Args:
        mask_path (str): Path to binary mask.

    Raises:
        OSError: When the mask image cannot be read from ``mask_path``.

    Returns:
        List[List[float]]: Bounding boxes as ``[x1, y1, x2, y2]`` lists.
    """
    # pylint: disable=too-many-locals

    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread reports an unreadable file by returning None instead of raising,
        # which would otherwise surface as an opaque AttributeError on ``mask.shape``.
        raise OSError(f"Could not read mask image: {mask_path}")
    height, width = mask.shape

    bboxes: List[List[float]] = []
    _, _, stats, _ = cv2.connectedComponentsWithStats(mask)

    # The first row of the stats always describes the background, so it is skipped.
    # The last column is the area of the connected component, which is not needed.
    for comp_x, comp_y, comp_w, comp_h, _area in stats[1:]:
        x1 = float(comp_x / width)
        y1 = float(comp_y / height)
        x2 = float((comp_x + comp_w) / width)
        y2 = float((comp_y + comp_h) / height)

        bboxes.append([x1, y1, x2, y2])

    return bboxes
88+
89+
90+
def create_polygons_from_mask(mask_path: str) -> List[List[float]]:
    """Create a normalized polygon from a binary mask.

    Only the first contour reported by ``cv2.findContours`` is converted; any
    further contours in the mask are not included in the result.
    NOTE(review): confirm with the annotation consumer whether masks with
    several disjoint regions should yield several polygons.

    Args:
        mask_path (str): Path to binary mask.

    Raises:
        OSError: When the mask image cannot be read from ``mask_path``.

    Returns:
        List[List[float]]: Polygon points as ``[x, y]`` pairs, normalized by the
            mask width and height. Empty when the mask contains no contour.
    """
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread reports an unreadable file by returning None instead of raising.
        raise OSError(f"Could not read mask image: {mask_path}")
    height, width = mask.shape

    # findContours returns (contours, hierarchy) on OpenCV 2.x/4.x but
    # (image, contours, hierarchy) on 3.x; indexing from the end works for both.
    contours = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        # A mask without any foreground has no contour to convert.
        return []

    # Each contour entry wraps a single (x, y) pair in an extra dimension.
    first_contour = contours[0]
    return [[float(x / width), float(y / height)] for entry in first_contour for (x, y) in entry]
106+
107+
108+
def create_classification_json_items(pd_items: pd.DataFrame) -> Dict[str, Any]:
    """Create JSON items for the classification task.

    Image and mask paths are stored relative to the category root found in the
    ``path`` column of each sample.

    Args:
        pd_items (pd.DataFrame): MVTec AD samples in pandas DataFrame object.
            The ``path``, ``image_path``, ``label`` and ``mask_path`` columns are read.

    Returns:
        Dict[str, Any]: MVTec AD classification JSON items with ``image_path``,
            ``label`` and ``masks`` entries keyed by the stringified sample index.
            ``masks`` only holds entries for samples whose label is not "good".
    """
    json_items: Dict[str, Any] = {"image_path": {}, "label": {}, "masks": {}}
    for index, pd_item in pd_items.iterrows():
        key = str(index)
        root = pd_item["path"]

        # Strip only the leading occurrence of the root: a relative root such as
        # "cable" also occurs later in "cable/test/cable_swap/000.png", and removing
        # every occurrence would corrupt the stored path. [1:] drops the separator.
        json_items["image_path"][key] = pd_item["image_path"].replace(root, "", 1)[1:]
        json_items["label"][key] = pd_item["label"]
        if pd_item["label"] != "good":
            json_items["masks"][key] = pd_item["mask_path"].replace(root, "", 1)[1:]

    return json_items
125+
126+
127+
def create_detection_json_items(pd_items: pd.DataFrame) -> Dict[str, Any]:
    """Create JSON items for the detection task.

    Image paths are stored relative to the category root found in the ``path``
    column of each sample. Bounding boxes are computed from the mask of every
    sample whose label is not "good".

    Args:
        pd_items (pd.DataFrame): MVTec AD samples in pandas DataFrame object.
            The ``path``, ``image_path``, ``label`` and ``mask_path`` columns are read.

    Returns:
        Dict[str, Any]: MVTec AD detection JSON items with ``image_path``,
            ``label`` and ``bboxes`` entries keyed by the stringified sample index.
    """
    json_items: Dict[str, Any] = {"image_path": {}, "label": {}, "bboxes": {}}
    for index, pd_item in pd_items.iterrows():
        key = str(index)

        # Strip only the leading occurrence of the root: a relative root such as
        # "pill" also occurs later in "pill/test/pill_type/000.png", and removing
        # every occurrence would corrupt the stored path. [1:] drops the separator.
        json_items["image_path"][key] = pd_item["image_path"].replace(pd_item["path"], "", 1)[1:]
        json_items["label"][key] = pd_item["label"]
        if pd_item["label"] != "good":
            json_items["bboxes"][key] = create_bboxes_from_mask(pd_item["mask_path"])

    return json_items
144+
145+
146+
def create_segmentation_json_items(pd_items: pd.DataFrame) -> Dict[str, Any]:
    """Create JSON items for the segmentation task.

    Image paths are stored relative to the category root found in the ``path``
    column of each sample. Polygon points are computed from the mask of every
    sample whose label is not "good".

    Args:
        pd_items (pd.DataFrame): MVTec AD samples in pandas DataFrame object.
            The ``path``, ``image_path``, ``label`` and ``mask_path`` columns are read.

    Returns:
        Dict[str, Any]: MVTec AD segmentation JSON items with ``image_path``,
            ``label`` and ``masks`` entries keyed by the stringified sample index.
    """
    json_items: Dict[str, Any] = {"image_path": {}, "label": {}, "masks": {}}
    for index, pd_item in pd_items.iterrows():
        key = str(index)

        # Strip only the leading occurrence of the root: a relative root such as
        # "cable" also occurs later in "cable/test/cable_swap/000.png", and removing
        # every occurrence would corrupt the stored path. [1:] drops the separator.
        json_items["image_path"][key] = pd_item["image_path"].replace(pd_item["path"], "", 1)[1:]
        json_items["label"][key] = pd_item["label"]
        if pd_item["label"] != "good":
            json_items["masks"][key] = create_polygons_from_mask(pd_item["mask_path"])

    return json_items
163+
164+
165+
def save_json_items(json_items: Dict[str, Any], file: str) -> None:
    """Serialize the annotation items to a JSON file.

    Args:
        json_items (Dict[str, Any]): MVTec AD JSON items
        file (str): Destination path of the JSON file; it is opened for writing
            with UTF-8 encoding.
    """
    with open(file, "w", encoding="utf-8") as handle:
        json.dump(json_items, fp=handle)
174+
175+
176+
def create_task_annotations(task: str, data_path: str, annotation_path: str) -> None:
    """Create MVTec AD category annotations for a given task.

    One ``<split>.json`` file is written for each of the train, val and test
    splits into the ``<annotation_path>/<task>`` directory, which is created
    when missing.

    Args:
        task (str): Task type to save annotations.
        data_path (str): Path to MVTec AD category.
        annotation_path (str): Path to save MVTec AD category JSON annotation items.

    Raises:
        ValueError: When task is not classification, detection or segmentation.
    """
    # Resolve the converter once, before touching the file system, so that an
    # unknown task fails without leaving an empty directory behind.
    if task == "classification":
        create_json_items = create_classification_json_items
    elif task == "detection":
        create_json_items = create_detection_json_items
    elif task == "segmentation":
        create_json_items = create_segmentation_json_items
    else:
        raise ValueError(f"Unknown task {task}. Available tasks are classification, detection and segmentation.")

    # Root the output at the requested annotation path (not at the data path), so
    # a user-supplied annotation location is actually honoured.
    task_annotation_path = os.path.join(annotation_path, task)
    os.makedirs(task_annotation_path, exist_ok=True)

    for split in ["train", "val", "test"]:
        df_items = make_mvtec_dataset(path=Path(data_path), create_validation_set=True, split=split)
        json_items = create_json_items(df_items)
        save_json_items(json_items, os.path.join(task_annotation_path, f"{split}.json"))
204+
205+
206+
def create_mvtec_ad_category_annotations(data_path: str, annotation_path: str) -> None:
    """Generate the annotations of one MVTec AD category for every supported task.

    The classification, detection and segmentation annotations are created in
    that order by delegating to ``create_task_annotations``.

    Args:
        data_path (str): Path to MVTec AD category.
        annotation_path (str): Path to save MVTec AD category JSON annotation items.
    """
    supported_tasks = ("classification", "detection", "segmentation")
    for task_name in supported_tasks:
        create_task_annotations(task=task_name, data_path=data_path, annotation_path=annotation_path)
215+
216+
217+
def create_mvtec_ad_annotations(mvtec_data_path: str, mvtec_annotation_path: Optional[str] = None) -> None:
    """Generate JSON annotations for every category of the MVTec AD dataset.

    A progress line is printed for each category before its annotations are
    created.

    Args:
        mvtec_data_path (str): Path to MVTec AD dataset.
        mvtec_annotation_path (Optional[str], optional): Path to save JSON annotations.
            Defaults to None, in which case the dataset path is used.
    """
    # Without an explicit destination, annotations are rooted at the dataset path.
    annotation_root = mvtec_data_path if mvtec_annotation_path is None else mvtec_annotation_path

    category_names = (
        "bottle",
        "cable",
        "capsule",
        "carpet",
        "grid",
        "hazelnut",
        "leather",
        "metal_nut",
        "pill",
        "screw",
        "tile",
        "toothbrush",
        "transistor",
        "wood",
        "zipper",
    )

    for name in category_names:
        print(f"Creating annotations for {name}")
        create_mvtec_ad_category_annotations(
            os.path.join(mvtec_data_path, name),
            os.path.join(annotation_root, name),
        )
250+
251+
252+
def get_args() -> Namespace:
    """Parse the command line options of the script.

    Two options are recognised: ``--data_path`` (defaults to
    "./data/anomaly/MVTec/") and the optional ``--annotation_path``.

    Returns:
        Namespace: Parsed arguments.
    """
    cli = ArgumentParser()
    option_specs = (
        ("--data_path", {"type": str, "default": "./data/anomaly/MVTec/", "help": "Path to Mvtec AD dataset."}),
        ("--annotation_path", {"type": str, "required": False, "help": "Path to create OTE CLI annotations."}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    parsed = cli.parse_args()
    return parsed
262+
263+
264+
def main():
    """Entry point: read the command line options and create the MVTec AD annotations."""
    cli_args = get_args()
    data_root = cli_args.data_path
    annotation_root = cli_args.annotation_path
    create_mvtec_ad_annotations(mvtec_data_path=data_root, mvtec_annotation_path=annotation_root)


# Only run the conversion when the file is executed as a script.
if __name__ == "__main__":
    main()

external/anomaly/ote_anomalib/data/data.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@
2222
import numpy as np
2323
from anomalib.pre_processing import PreProcessor
2424
from omegaconf import DictConfig, ListConfig
25-
from ote_anomalib.data.utils import (
26-
contains_anomalous_images,
27-
split_local_global_dataset,
28-
)
2925
from ote_anomalib.logging import get_logger
3026
from ote_sdk.entities.datasets import DatasetEntity
3127
from ote_sdk.entities.model_template import TaskType
3228
from ote_sdk.entities.shapes.polygon import Polygon
3329
from ote_sdk.entities.subset import Subset
30+
from ote_sdk.utils.dataset_utils import (
31+
contains_anomalous_images,
32+
split_local_global_dataset,
33+
)
3434
from ote_sdk.utils.segmentation_utils import mask_from_dataset_item
3535
from pytorch_lightning.core.datamodule import LightningDataModule
3636
from torch import Tensor

external/anomaly/ote_anomalib/data/mvtec.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,21 @@
1-
"""OTE MVTec Dataset facilitate OTE Anomaly Training."""
1+
"""OTE MVTec Dataset to facilitate OTE Anomaly Training.
2+
3+
License:
4+
MVTec AD dataset is released under the Creative Commons
5+
Attribution-NonCommercial-ShareAlike 4.0 International License
6+
(CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/).
7+
8+
Reference:
9+
- Paul Bergmann, Kilian Batzner, Michael Fauser, David Sattlegger, Carsten Steger:
10+
The MVTec Anomaly Detection Dataset: A Comprehensive Real-World Dataset for
11+
Unsupervised Anomaly Detection; in: International Journal of Computer Vision
12+
129(4):1038-1059, 2021, DOI: 10.1007/s11263-020-01400-4.
13+
14+
- Paul Bergmann, Michael Fauser, David Sattlegger, Carsten Steger: MVTec AD —
15+
A Comprehensive Real-World Dataset for Unsupervised Anomaly Detection;
16+
in: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR),
17+
9584-9592, 2019, DOI: 10.1109/CVPR.2019.00982.
18+
"""
219

320
# Copyright (C) 2021 Intel Corporation
421
#

0 commit comments

Comments
 (0)