[NEW MODEL] Region-based Anomaly Detection #821

Merged on Jan 6, 2023 (302 commits)
Commits
25f503d
fix pylint issues
djdameln Oct 14, 2022
1245928
codacy
djdameln Oct 14, 2022
d9bd6e0
Merge branch 'main' into da/datamodules-alternative
djdameln Oct 14, 2022
0459a0d
update example dataset config in docs
djdameln Oct 14, 2022
30dc45a
fix test
djdameln Oct 14, 2022
85c475a
move base classes to separate files (avoid circular import)
djdameln Oct 14, 2022
cc32896
add base classes
djdameln Oct 14, 2022
23d4766
update docstring
djdameln Oct 14, 2022
e8d7998
fix imports
djdameln Oct 14, 2022
9c4e7bf
validation_split_mode -> val_split_mode
djdameln Oct 18, 2022
067d601
update docs
djdameln Oct 19, 2022
c84c99c
Update anomalib/data/base/dataset.py
djdameln Oct 21, 2022
b680d44
get length from self.samples
djdameln Oct 21, 2022
95c37b0
assert unique indices
djdameln Oct 21, 2022
3e77014
check is_setup for individual datasets
djdameln Oct 21, 2022
ede213a
remove assert in __getitem__
djdameln Oct 21, 2022
f5e2d24
Update anomalib/data/btech.py
djdameln Oct 21, 2022
d9e1369
clearer assert message
djdameln Oct 21, 2022
2e6bc60
clarify list inversion in comment
djdameln Oct 21, 2022
af0cd99
comments and typing
djdameln Oct 21, 2022
d508786
Merge branch 'da/datamodules-alternative' of https://github.com/openv…
djdameln Oct 21, 2022
c85713c
Merge branch 'main' into da/datamodules-alternative
djdameln Oct 21, 2022
5ee8480
validate contents of samples dataframe before setting
djdameln Oct 21, 2022
a5e876a
add file paths check
djdameln Oct 21, 2022
c490e30
add seed to random_split function
djdameln Oct 21, 2022
4808287
fix expected columns
djdameln Oct 24, 2022
10bbf9c
fix typo
djdameln Oct 24, 2022
471ca1b
add pedestrian and avenue datasets and video utils
djdameln Oct 26, 2022
81d3ca3
add seed parameter to datamodules
djdameln Oct 28, 2022
b372dd1
set global seed in test entrypoint
djdameln Oct 28, 2022
e07a12c
add NONE option to valsplitmode
djdameln Oct 28, 2022
ffdb47c
clarify setup behaviour in docstring
djdameln Oct 28, 2022
bafe984
Merge branch 'da/datamodules-alternative' into da/video-datasets
djdameln Oct 28, 2022
4bf7a63
Created rbad directory
samet-akcay Oct 21, 2022
16c0796
Keep refactoring region-extractor
samet-akcay Oct 21, 2022
eb7cef8
rename new_image_sizes to transformed_image_sizes
samet-akcay Oct 21, 2022
d589108
Renamed the variables in region extractor
samet-akcay Oct 22, 2022
1e9f637
post-process function in region extractor
samet-akcay Oct 23, 2022
5b96e9a
Refactored tile-boxes function
samet-akcay Oct 23, 2022
742dfe2
Added feature extractor
samet-akcay Oct 25, 2022
f55bb64
Add main.py
samet-akcay Oct 25, 2022
787186a
Added feature extractor to tests
samet-akcay Oct 25, 2022
e135aeb
Update the jupyter notebook
samet-akcay Oct 25, 2022
2d82384
Uncomment load weights from region.py
samet-akcay Oct 25, 2022
8a6bbc0
Add feature and region extractors
samet-akcay Oct 26, 2022
a1f53d6
Finished feature-extractor implementation
samet-akcay Oct 27, 2022
9f70322
Rename the algo as rkde
samet-akcay Oct 31, 2022
b21045b
New datamodules design (#572)
djdameln Oct 31, 2022
e3ad3f0
add basic visualization for video datasets
djdameln Nov 1, 2022
789e0a7
simplify ucsdped implementation
djdameln Nov 1, 2022
8ec3695
Merge branch 'da/video-datasets' of github.com:openvinotoolkit/anomal…
samet-akcay Nov 1, 2022
1e4b2f7
merge feature branch
djdameln Nov 1, 2022
d85dc2c
TODO: Investigate torch_model
samet-akcay Nov 2, 2022
44a86b8
add ucsd and avenue to __all__
djdameln Nov 2, 2022
7fb2a87
add default value for task
djdameln Nov 2, 2022
8deb8d9
add tests for ucsd and avenue
djdameln Nov 2, 2022
6c2dc62
add tests for video dataset and utils
djdameln Nov 2, 2022
1789e42
add download info for avenue dataset
djdameln Nov 2, 2022
5d47f61
add download info for ucsd pedestrian dataset
djdameln Nov 2, 2022
a9d534c
more consistent naming
djdameln Nov 3, 2022
7f7bfbc
fix path to masks folder in gt dir
djdameln Nov 3, 2022
4af1e46
pass original image in batch to facilitate visualization
djdameln Nov 3, 2022
13bb9bc
convert mask files for avenue
djdameln Nov 3, 2022
164c4dd
suppress warning due to torchvision bug
djdameln Nov 3, 2022
7523164
fix bug in avenue masks
djdameln Nov 3, 2022
db6769e
store visualizations for each video in separate folder
djdameln Nov 3, 2022
c5f1b60
rename parameters
djdameln Nov 3, 2022
efc8d28
add warning for clip_length > 1
djdameln Nov 3, 2022
011ba41
fix dataset tests
djdameln Nov 3, 2022
d93ebbc
fix labels tensor shape bug
djdameln Nov 3, 2022
04f45a7
add pyav to requirements
djdameln Nov 4, 2022
a09d7b5
Add TODO notes
samet-akcay Nov 7, 2022
92e2f36
add todo notes
samet-akcay Nov 7, 2022
65e7ed0
add description for avenue dataset
djdameln Nov 15, 2022
53bd70a
use pathlib
djdameln Nov 15, 2022
482b7c4
Update anomalib/data/avenue.py
djdameln Nov 15, 2022
4a8ddec
Update anomalib/data/avenue.py
djdameln Nov 15, 2022
85e83f0
Update anomalib/data/utils/video.py
djdameln Nov 15, 2022
b37f788
Update anomalib/data/base/video.py
djdameln Nov 15, 2022
527fcb0
Update anomalib/data/base/video.py
djdameln Nov 15, 2022
8665b8f
Update anomalib/data/ucsd_ped.py
djdameln Nov 15, 2022
2215941
import video dataset from base
djdameln Nov 15, 2022
2eb3ee7
fix bug when collecting ucsd samples
djdameln Nov 15, 2022
1acd6f5
clean up datamodules tests
djdameln Nov 15, 2022
d16cfb8
fix tests
djdameln Nov 16, 2022
64da65b
remove redundant test cases
djdameln Nov 16, 2022
44b70e5
add test case for normality model
djdameln Nov 18, 2022
ad36867
retrieve masks as numpy array
djdameln Nov 18, 2022
75009b9
use pathlib
djdameln Nov 18, 2022
cee65f8
variable name
djdameln Nov 18, 2022
0f4d65b
pathlib
djdameln Nov 18, 2022
f27b6f9
use preprocesser from arguments
djdameln Nov 18, 2022
1c2a3c1
fix indexing bug
djdameln Nov 18, 2022
f66c84e
Video Datamodules (#676)
djdameln Nov 18, 2022
1e1751e
Merge branch 'da/video-datasets' into feature/rbad
djdameln Nov 21, 2022
eeb7b6d
Merge branch 'feature/datamodules' into feature/rbad
djdameln Nov 21, 2022
6384227
properly handle batch processing
djdameln Nov 22, 2022
4c82e04
include batch index in rois tensor
djdameln Nov 22, 2022
0686ca2
return rkde results as lists
djdameln Nov 23, 2022
28154ff
update default rkde config
djdameln Nov 23, 2022
64008c3
add basic support for detection task
djdameln Nov 24, 2022
bcd03d4
use enum for task type
djdameln Nov 24, 2022
ca427fb
formatting
djdameln Nov 25, 2022
69b83b4
small bugfix
djdameln Nov 25, 2022
30c4368
add unit tests for bounding box conversion
djdameln Nov 25, 2022
3c0cfec
update error message
djdameln Nov 28, 2022
037c1e5
use as_tensor
djdameln Nov 28, 2022
abea835
typing and docstring
djdameln Nov 28, 2022
c060333
explicit keyword arguments
djdameln Nov 28, 2022
bf573d1
simplify bbox handling in video dataset
djdameln Nov 28, 2022
b7f1b66
docstring consistency
djdameln Nov 28, 2022
7f60ea2
add missing licenses
djdameln Nov 28, 2022
eb87358
add whitespace for readability
djdameln Nov 28, 2022
4c3a6b1
add missing license
djdameln Nov 28, 2022
cec6138
Update anomalib/data/utils/boxes.py
djdameln Nov 28, 2022
d13ce5b
Revert "Update anomalib/data/utils/boxes.py"
djdameln Nov 28, 2022
0e0dc80
add test case for custom collate function
djdameln Nov 28, 2022
5ead1ad
docstring
djdameln Nov 28, 2022
3d777da
Merge branch 'da/detection-task-type' into feature/rbad
djdameln Nov 28, 2022
44812d6
add integration tests for detection dataloading
djdameln Nov 29, 2022
d9304aa
extend and clean up datamodules tests
djdameln Nov 29, 2022
caf0867
add detection task type to visualizer tests
djdameln Nov 29, 2022
aac5a47
Update lightning_inference.py
djdameln Nov 29, 2022
67312fc
Merge branch 'feature/datamodules' into da/detection-task-type
djdameln Nov 29, 2022
d63a7b7
only show pred_boxes during inference
djdameln Nov 30, 2022
7ec5fa4
add detection support for torch inference
djdameln Nov 30, 2022
d74bf41
add detection support for openvino inference
djdameln Nov 30, 2022
39cf0ac
test inference for all task types
djdameln Dec 1, 2022
f3d00d8
pylint
djdameln Dec 1, 2022
ab6cb57
merge main
djdameln Dec 5, 2022
9962e8c
merge latest changes
djdameln Dec 5, 2022
cb06714
Make `val split ratio` configurable (#760)
djdameln Dec 5, 2022
5a055f2
merge feature branch
djdameln Dec 6, 2022
045d77f
Add support for Detection task type (#732)
djdameln Dec 6, 2022
ccec2f6
[Datamodules] Update deprecation messages (#764)
djdameln Dec 7, 2022
8dbcb40
Merge branch 'da/detection-task-type' into feature/rbad
djdameln Dec 13, 2022
b9b5975
Merge branch 'feature/datamodules' into feature/rbad
djdameln Dec 13, 2022
c5d6c84
update rkde
djdameln Dec 13, 2022
d210feb
Improve image source parsing for Folder dataset (#784)
djdameln Dec 13, 2022
67462e7
Synthetic anomaly for testing and validation (#634)
djdameln Dec 14, 2022
8141f2f
merge main
djdameln Dec 16, 2022
8601330
Bugfixes for Datamodules feature branch (#800)
djdameln Dec 19, 2022
663692e
Deprecate PreProcessor (#795)
djdameln Dec 19, 2022
d4afeee
Merge branch 'feature/datamodules' into feature/rbad
djdameln Dec 20, 2022
5446cd7
expose more parameters and fix wrong return format
djdameln Dec 20, 2022
5b2c310
fix tdd tests
djdameln Dec 20, 2022
96c9c82
update config
djdameln Dec 20, 2022
57d3b4e
[Datamodules] Fix bug in bbox score to image score conversion (#803)
djdameln Dec 20, 2022
a0d564c
Merge branch 'feature/datamodules' into feature/rbad
djdameln Dec 20, 2022
988c11c
update config
djdameln Dec 20, 2022
b239b38
apply pixel threshold to bbox detections
djdameln Dec 20, 2022
2bb5265
Merge branch 'da/detection-improvements' into feature/rbad
djdameln Dec 20, 2022
9c4a890
remove confidence threshold parameter from rkde
djdameln Dec 20, 2022
fbe3a1b
hardcode steepness and offset
djdameln Dec 20, 2022
2c8ff79
rename variable
djdameln Dec 20, 2022
608263c
remove unused parameters from config
djdameln Dec 20, 2022
192ba94
Improve handling of `test_split_mode='none'` and `val_split_mode='non…
djdameln Dec 20, 2022
1daf2c4
Merge branch 'feature/datamodules' into da/detection-improvements
djdameln Dec 20, 2022
c34d16c
Merge branch 'feature/datamodules' into feature/rbad
djdameln Dec 20, 2022
d3272f8
update config with new keys
djdameln Dec 21, 2022
17f587c
remove unused parameter
djdameln Dec 21, 2022
88fe40d
set device in rpn stage
djdameln Dec 21, 2022
d375ea1
move prediction format conversion to lightning model
djdameln Dec 21, 2022
016098c
clean up torch model
djdameln Dec 21, 2022
74d1dbd
move region- and feature-extractor to separate files
djdameln Dec 21, 2022
0edaa29
allow visualizing normal boxes
djdameln Dec 21, 2022
6bea776
Merge branch 'da/detection-improvements' into feature/rbad
djdameln Dec 21, 2022
e944263
refactor
djdameln Dec 21, 2022
5a5f197
WIP: simplify region extractor
djdameln Dec 21, 2022
9e5dd4a
simplify region extractor
djdameln Dec 22, 2022
61bb5a6
cleanup and docstrings
djdameln Dec 22, 2022
44e66ab
typing
djdameln Dec 22, 2022
035a772
expose max detections per image parameter
djdameln Dec 22, 2022
5e36181
explain configurable parameters
djdameln Dec 22, 2022
acf0bd7
fix wrong config value
djdameln Dec 22, 2022
538b9e5
remove unnecessary squeeze
djdameln Dec 22, 2022
8eece0d
box_likelihood -> rcnn_box_threshold
djdameln Dec 27, 2022
45e7104
update comments
djdameln Dec 27, 2022
d57059b
remove unnecessary typing
djdameln Dec 27, 2022
833f175
separate density estimation stage from torch model
djdameln Dec 27, 2022
be61815
improve readability
djdameln Dec 27, 2022
9ea074e
change default transform settings
djdameln Dec 27, 2022
b21f12c
fix to float transform
djdameln Dec 27, 2022
37d1db5
Merge branch 'feature/datamodules' into feature/rbad
djdameln Dec 27, 2022
2c40a11
simplify feature extractor
djdameln Dec 28, 2022
25b90cc
normalize box scores
djdameln Dec 28, 2022
47f8983
Merge branch 'da/detection-improvements' into feature/rbad
djdameln Dec 28, 2022
8f8b3bd
further simplify region extractor
djdameln Dec 28, 2022
bf141db
update comment
djdameln Dec 28, 2022
9f0da55
improve rpn configurability
djdameln Dec 28, 2022
28f11b9
remove unnecessary check
djdameln Dec 28, 2022
1fbdcc7
use enum for roi stage options
djdameln Dec 28, 2022
a19c6c8
use enum for feature scaling method
djdameln Dec 28, 2022
2e91d09
re-order parameters
djdameln Dec 28, 2022
0f46da7
clean up model dir
djdameln Dec 28, 2022
8a19223
fix bbox logic in base anomaly module
djdameln Dec 29, 2022
8b961a8
Merge branch 'da/detection-improvements' into feature/rbad
djdameln Dec 29, 2022
750e332
update key in output dict
djdameln Dec 29, 2022
ed239c7
boxes_scores -> box_scores
djdameln Dec 29, 2022
9e5d775
Merge branch 'da/detection-improvements' into feature/rbad
djdameln Dec 29, 2022
80ce495
remove notebook
djdameln Dec 29, 2022
74dfef9
add comments and todo
djdameln Dec 29, 2022
ced7bc9
Detection improvements (#820)
djdameln Dec 29, 2022
80c5683
Merge branch 'feature/datamodules' into feature/rbad
djdameln Dec 29, 2022
9761ea8
add readme
djdameln Dec 29, 2022
690cb1b
merge main
djdameln Dec 29, 2022
6abe6d9
Merge branch 'feature/datamodules' into feature/rbad
djdameln Dec 29, 2022
4cf8577
update changelog
djdameln Dec 29, 2022
f29a7d7
Merge branch 'feature/datamodules' into feature/rbad
djdameln Dec 29, 2022
9cd02b7
update changelog
djdameln Dec 29, 2022
89661ba
update csflow config to new format
djdameln Dec 29, 2022
6d71119
Merge branch 'feature/datamodules' into feature/rbad
djdameln Dec 29, 2022
8e862ef
initialize max_length as empty tensor
djdameln Dec 29, 2022
af68432
include RKDE in model tests
djdameln Dec 29, 2022
9114c7d
remove unused imports
djdameln Dec 29, 2022
1c903f4
line length
djdameln Dec 29, 2022
60621b2
Merge branch 'feature/datamodules' into feature/rbad
djdameln Dec 29, 2022
2f0c4a8
move kde classifier to shared location
djdameln Dec 29, 2022
63d2514
fix import
djdameln Dec 29, 2022
ece12d5
re-use RKDE classifier in DFKDE
djdameln Dec 29, 2022
5128fb5
remove old imports
djdameln Jan 2, 2023
2c3051e
docstrings
djdameln Jan 2, 2023
370760e
fix codacy issues
djdameln Jan 2, 2023
a519c30
load feature extractor weights from url
djdameln Jan 2, 2023
dbba22a
suppress bandit warnings
djdameln Jan 2, 2023
451caf4
use torch rng in augmenter
djdameln Jan 2, 2023
9c96e78
typing
djdameln Jan 4, 2023
b58f4e9
add fit method to torch model
djdameln Jan 4, 2023
1595819
fix typo
djdameln Jan 4, 2023
848b0aa
use enum when checking stage
djdameln Jan 4, 2023
e457b5e
use tuple instead of list
djdameln Jan 5, 2023
68ef76b
add missing params to docstring
djdameln Jan 5, 2023
1d23942
add missing licence information
djdameln Jan 5, 2023
d2fda44
COLS -> COLUMNS
djdameln Jan 5, 2023
0139d62
typing and variable naming
djdameln Jan 5, 2023
8914fa1
remove duplicate parameter in docstring
djdameln Jan 5, 2023
6b2dcc5
im_dir -> image_dir
djdameln Jan 5, 2023
2d24ed5
typing and docstring
djdameln Jan 5, 2023
6d39434
typing
djdameln Jan 5, 2023
6e3816f
ValSplitMode -> ValidationSplitMode
djdameln Jan 5, 2023
fad21b1
add missing licence
djdameln Jan 5, 2023
8df22c6
rename variable
djdameln Jan 5, 2023
ced0342
remove empty comment
djdameln Jan 5, 2023
96f9b5e
remove unused class attribute
djdameln Jan 5, 2023
0904529
[Detection] Compute box score when generating boxes from masks (#828)
djdameln Jan 6, 2023
4366f05
Merge branch 'feature/datamodules' into feature/rbad
djdameln Jan 6, 2023
a67c21b
revert val_split_mode -> validation_split_mode
djdameln Jan 6, 2023
1d5e2ba
Merge branch 'main' into feature/datamodules
djdameln Jan 6, 2023
f370b45
Merge branch 'feature/datamodules' into feature/rbad
djdameln Jan 6, 2023
6018bb2
Merge branch 'main' into feature/rbad
djdameln Jan 6, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

 ### Added

+- Add RKDE model implementation (https://github.com/openvinotoolkit/anomalib/pull/821)
 - Add Visual Anomaly (VisA) dataset adapter (<https://github.com/openvinotoolkit/anomalib/pull/824>)
 - Add Synthetic anomalous dataset for validation and testing (https://github.com/openvinotoolkit/anomalib/pull/822)
 - Add Detection task type support (https://github.com/openvinotoolkit/anomalib/pull/822)
15 changes: 15 additions & 0 deletions anomalib/data/utils/boxes.py
@@ -95,3 +95,18 @@ def boxes_to_anomaly_maps(boxes: Tensor, scores: Tensor, image_size: Tuple[int,
             im_map[box_idx, y_1 : y_2 + 1, x_1 : x_2 + 1] = score
         anomaly_maps[im_idx], _ = im_map.max(dim=0)
     return anomaly_maps
+
+
+def scale_boxes(boxes: Tensor, image_size: torch.Size, new_size: torch.Size) -> Tensor:
+    """Scale bbox coordinates to a new image size.
+
+    Args:
+        boxes (Tensor): Boxes of shape (N, 4) - (x1, y1, x2, y2).
+        image_size (Size): Size of the original image in which the bbox coordinates were retrieved.
+        new_size (Size): New image size to which the bbox coordinates will be scaled.
+
+    Returns:
+        Tensor: Updated boxes of shape (N, 4) - (x1, y1, x2, y2).
+    """
+    scale = Tensor([*new_size]) / Tensor([*image_size])
+    return boxes * scale.repeat(2).to(boxes.device)
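The arithmetic in `scale_boxes` is easy to check by hand. The sketch below mirrors it in plain Python so it runs without torch; `scale_boxes_plain` is a hypothetical helper written for illustration and is not part of the PR. Because `scale.repeat(2)` produces `[s0, s1, s0, s1]` and is multiplied against `(x1, y1, x2, y2)`, the first element of each size is applied to the x coordinates and the second to the y coordinates, so both sizes must be passed in an order that matches the box layout.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]


def scale_boxes_plain(
    boxes: Sequence[Box], image_size: Tuple[int, int], new_size: Tuple[int, int]
) -> List[Box]:
    """Torch-free mirror of scale_boxes (hypothetical helper, illustration only)."""
    # Element-wise ratio, same as Tensor([*new_size]) / Tensor([*image_size]).
    s0 = new_size[0] / image_size[0]
    s1 = new_size[1] / image_size[1]
    # Same effect as scale.repeat(2): [s0, s1, s0, s1] applied to (x1, y1, x2, y2).
    return [(x1 * s0, y1 * s1, x2 * s0, y2 * s1) for x1, y1, x2, y2 in boxes]


boxes = [(10.0, 20.0, 30.0, 40.0)]
# First size element doubles, second halves.
scaled = scale_boxes_plain(boxes, image_size=(100, 200), new_size=(200, 100))
print(scaled)  # [(20.0, 10.0, 60.0, 20.0)]
```

With a non-square resize like this one, swapping the order of the size elements changes the result, which makes the example a quick check of the ordering a caller assumes.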
3 changes: 3 additions & 0 deletions anomalib/models/__init__.py
@@ -22,6 +22,7 @@
 from anomalib.models.padim import Padim
 from anomalib.models.patchcore import Patchcore
 from anomalib.models.reverse_distillation import ReverseDistillation
+from anomalib.models.rkde import Rkde
 from anomalib.models.stfpm import Stfpm

 __all__ = [
@@ -35,6 +36,7 @@
     "Padim",
     "Patchcore",
     "ReverseDistillation",
+    "Rkde",
     "Stfpm",
 ]

@@ -84,6 +86,7 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
         "padim",
         "patchcore",
         "reverse_distillation",
+        "rkde",
         "stfpm",
     ]
     model: AnomalyModule
5 changes: 5 additions & 0 deletions anomalib/models/components/classification/__init__.py
@@ -0,0 +1,5 @@
+"""Classification modules."""
+
+from .kde_classifier import FeatureScalingMethod, KDEClassifier
+
+__all__ = ["KDEClassifier", "FeatureScalingMethod"]
162 changes: 162 additions & 0 deletions anomalib/models/components/classification/kde_classifier.py
@@ -0,0 +1,162 @@
+"""Kernel Density Estimation Classifier."""
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import random
+from enum import Enum
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor, nn
+
+from anomalib.models.components import PCA, GaussianKDE
+
+logger = logging.getLogger(__name__)
+
+
+class FeatureScalingMethod(str, Enum):
+    """Determines how the feature embeddings are scaled."""
+
+    NORM = "norm"  # scale to unit vector length
+    SCALE = "scale"  # scale to max length observed in training (preserve relative magnitude)
+
+
+class KDEClassifier(nn.Module):
+    """Classification module for KDE-based anomaly detection.
+
+    Args:
+        n_pca_components (int, optional): Number of PCA components. Defaults to 16.
+        feature_scaling_method (FeatureScalingMethod, optional): Scaling method applied to features before passing to
+            KDE. Options are `norm` (normalize to unit vector length) and `scale` (scale to max length observed in
+            training).
+        max_training_points (int, optional): Maximum number of training points to fit the KDE model. Defaults to 40000.
+    """
+
+    def __init__(
+        self,
+        n_pca_components: int = 16,
+        feature_scaling_method: FeatureScalingMethod = FeatureScalingMethod.SCALE,
+        max_training_points: int = 40000,
+    ) -> None:
+        super().__init__()
+
+        self.n_pca_components = n_pca_components
+        self.feature_scaling_method = feature_scaling_method
+        self.max_training_points = max_training_points
+
+        self.pca_model = PCA(n_components=self.n_pca_components)
+        self.kde_model = GaussianKDE()
+
+        self.register_buffer("max_length", torch.empty([]))
+        self.max_length = torch.empty([])
+
+    def pre_process(self, feature_stack: Tensor, max_length: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
+        """Pre-process the CNN features.
+
+        Args:
+            feature_stack (Tensor): Features extracted from CNN
+            max_length (Optional[Tensor]): Used to unit normalize the feature_stack vector. If ``max_length`` is not
+                provided, the length is calculated from the ``feature_stack``. Defaults to None.
+
+        Returns:
+            (Tuple): Stacked features and length
+        """
+
+        if max_length is None:
+            max_length = torch.max(torch.linalg.norm(feature_stack, ord=2, dim=1))
+
+        if self.feature_scaling_method == FeatureScalingMethod.NORM:
+            feature_stack /= torch.linalg.norm(feature_stack, ord=2, dim=1)[:, None]
+        elif self.feature_scaling_method == FeatureScalingMethod.SCALE:
+            feature_stack /= max_length
+        else:
+            raise RuntimeError("Unknown pre-processing mode. Available modes are: Normalized and Scale.")
+        return feature_stack, max_length
+
+    def fit(self, embeddings: Tensor) -> bool:
+        """Fit a kde model to embeddings.
+
+        Args:
+            embeddings (Tensor): Input embeddings to fit the model.
+
+        Returns:
+            Boolean confirming whether the training is successful.
+        """
+
+        if embeddings.shape[0] < self.n_pca_components:
+            logger.info("Not enough features to commit. Not making a model.")
+            return False
+
+        # if the number of embeddings exceeds max_training_points, select a random subset
+        if embeddings.shape[0] > self.max_training_points:
+            selected_idx = torch.tensor(random.sample(range(embeddings.shape[0]), self.max_training_points))
+            selected_features = embeddings[selected_idx]
+        else:
+            selected_features = embeddings
+
+        feature_stack = self.pca_model.fit_transform(selected_features)
+        feature_stack, max_length = self.pre_process(feature_stack)
+        self.max_length = max_length
+        self.kde_model.fit(feature_stack)
+
+        return True
+
+    def compute_kde_scores(self, features: Tensor, as_log_likelihood: Optional[bool] = False) -> Tensor:
+        """Compute the KDE scores.
+
+        The scores calculated from the KDE model are converted to densities. If `as_log_likelihood` is set to true then
+        the log of the scores is calculated.
+
+        Args:
+            features (Tensor): Features to be projected with the fitted PCA model and scored.
+            as_log_likelihood (Optional[bool], optional): If true, gets log likelihood scores. Defaults to False.
+
+        Returns:
+            (Tensor): Score
+        """
+
+        features = self.pca_model.transform(features)
+        features, _ = self.pre_process(features, self.max_length)
+        # Scores are always assumed to be passed as a density
+        kde_scores = self.kde_model(features)
+
+        # add small constant to avoid taking the log of zero
+        kde_scores += 1e-300
+
+        if as_log_likelihood:
+            kde_scores = torch.log(kde_scores)
+
+        return kde_scores
+
+    @staticmethod
+    def compute_probabilities(scores: Tensor) -> Tensor:
+        """Converts density scores to anomaly probabilities (see https://www.desmos.com/calculator/ifju7eesg7).
+
+        Args:
+            scores (Tensor): density of an image.
+
+        Returns:
+            probability that image with {density} is anomalous
+        """
+        return 1 / (1 + torch.exp(0.05 * (scores - 12)))
+
+    def predict(self, features: Tensor) -> Tensor:
+        """Predicts the probability that the features belong to the anomalous class.
+
+        Args:
+            features (Tensor): Feature from which the output probabilities are detected.
+
+        Returns:
+            Detection probabilities
+        """
+
+        scores = self.compute_kde_scores(features, as_log_likelihood=True)
+        probabilities = self.compute_probabilities(scores)
+
+        return probabilities
+
+    def forward(self, features: Tensor) -> Tensor:
+        """Make predictions on extracted features."""
+        return self.predict(features)
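Two numeric steps in `KDEClassifier` are worth tracing with small numbers: the feature scaling in `pre_process` and the sigmoid in `compute_probabilities`, which `predict` applies to log-likelihoods. The following is a minimal torch-free sketch, not the PR's implementation; `l2_norm`, `scale_features`, and `anomaly_probability` are hypothetical helper names, and the constants 0.05 and 12 are the ones hardcoded in the diff.

```python
import math
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def l2_norm(row: List[float]) -> float:
    """Euclidean length of one feature vector."""
    return math.sqrt(sum(value * value for value in row))


def scale_features(
    features: Matrix, method: str, max_length: Optional[float] = None
) -> Tuple[Matrix, float]:
    """Plain-Python mirror of the two scaling branches in pre_process (hypothetical helper)."""
    if max_length is None:
        # Longest vector in this set; at inference the value stored during fit is passed in.
        max_length = max(l2_norm(row) for row in features)
    if method == "norm":
        # Every vector is divided by its own length.
        scaled = [[value / l2_norm(row) for value in row] for row in features]
    elif method == "scale":
        # Every vector is divided by the same constant.
        scaled = [[value / max_length for value in row] for row in features]
    else:
        raise ValueError(f"unknown scaling method: {method}")
    return scaled, max_length


def anomaly_probability(log_density: float) -> float:
    """Same sigmoid as compute_probabilities: steepness 0.05, offset 12."""
    return 1.0 / (1.0 + math.exp(0.05 * (log_density - 12.0)))


features = [[3.0, 4.0], [6.0, 8.0]]  # vector lengths 5 and 10
normed, _ = scale_features(features, "norm")
scaled, max_length = scale_features(features, "scale")

print(normed)  # both rows end up with unit length
print(scaled)  # second row stays twice as long as the first
print(anomaly_probability(12.0))  # 0.5 exactly at the offset
```

The sigmoid decreases as the log-density grows, so low-density (unusual) features map to probabilities above 0.5 and high-density features map below it.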
13 changes: 7 additions & 6 deletions anomalib/models/dfkde/config.yaml
@@ -20,16 +20,17 @@ dataset:

 model:
   name: dfkde
+  # feature extraction params
   backbone: resnet18
+  layers:
+    - layer4
   pre_trained: true
+  # density estimation params
+  n_pca_components: 16
   max_training_points: 40000
-  pre_processing: scale
-  n_components: 16
-  threshold_steepness: 0.05
-  threshold_offset: 12
+  feature_scaling_method: scale # Determines how the feature embeddings are scaled. Options: [scale, norm]
+  # generic params
   normalization_method: min_max # options: [null, min_max, cdf]
-  layers:
-    - layer4

 metrics:
   image:
28 changes: 13 additions & 15 deletions anomalib/models/dfkde/lightning_model.py
@@ -6,11 +6,13 @@
 import logging
 from typing import List, Union

+import torch
 from omegaconf import DictConfig, ListConfig
 from pytorch_lightning.utilities.cli import MODEL_REGISTRY
 from torch import Tensor

 from anomalib.models.components import AnomalyModule
+from anomalib.models.components.classification import FeatureScalingMethod

 from .torch_model import DfkdeModel

@@ -39,23 +41,19 @@ def __init__(
         layers: List[str],
         backbone: str,
         pre_trained: bool = True,
+        n_pca_components: int = 16,
+        feature_scaling_method: FeatureScalingMethod = FeatureScalingMethod.SCALE,
         max_training_points: int = 40000,
-        pre_processing: str = "scale",
-        n_components: int = 16,
-        threshold_steepness: float = 0.05,
-        threshold_offset: int = 12,
     ):
         super().__init__()

         self.model = DfkdeModel(
             layers=layers,
             backbone=backbone,
             pre_trained=pre_trained,
-            n_comps=n_components,
-            pre_processing=pre_processing,
-            filter_count=max_training_points,
-            threshold_steepness=threshold_steepness,
-            threshold_offset=threshold_offset,
+            n_pca_components=n_pca_components,
+            feature_scaling_method=feature_scaling_method,
+            max_training_points=max_training_points,
         )

         self.embeddings: List[Tensor] = []
@@ -76,7 +74,7 @@ def training_step(self, batch, _batch_idx): # pylint: disable=arguments-differ
             Deep CNN features.
         """

-        embedding = self.model.get_features(batch["image"]).squeeze()
+        embedding = self.model(batch["image"])

         # NOTE: `self.embedding` appends each batch embedding to
         # store the training set embedding. We manually append these
@@ -89,8 +87,10 @@ def on_validation_start(self) -> None:
         # NOTE: Previous anomalib versions fit Gaussian at the end of the epoch.
         # This is not possible anymore with PyTorch Lightning v1.4.0 since validation
         # is run within train epoch.
+        embeddings = torch.vstack(self.embeddings)
+
         logger.info("Fitting a KDE model to the embedding collected from the training set.")
-        self.model.fit(self.embeddings)
+        self.model.classifier.fit(embeddings)

     def validation_step(self, batch, _): # pylint: disable=arguments-differ
         """Validation Step of DFKDE.
@@ -120,11 +120,9 @@ def __init__(self, hparams: Union[DictConfig, ListConfig]) -> None:
             layers=hparams.model.layers,
             backbone=hparams.model.backbone,
             pre_trained=hparams.model.pre_trained,
+            n_pca_components=hparams.model.n_pca_components,
+            feature_scaling_method=FeatureScalingMethod(hparams.model.feature_scaling_method),
             max_training_points=hparams.model.max_training_points,
-            pre_processing=hparams.model.pre_processing,
-            n_components=hparams.model.n_components,
-            threshold_steepness=hparams.model.threshold_steepness,
-            threshold_offset=hparams.model.threshold_offset,
         )
         self.hparams: Union[DictConfig, ListConfig] # type: ignore
         self.save_hyperparameters(hparams)
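The `FeatureScalingMethod(hparams.model.feature_scaling_method)` call in this hunk works because `FeatureScalingMethod` is a `str`-based `Enum`, so the plain string from `config.yaml` is looked up by value. The standalone sketch below redefines the enum locally (copied from the diff) so it runs without anomalib; the `invalid_rejected` flag is only there for illustration.

```python
from enum import Enum


class FeatureScalingMethod(str, Enum):
    """Local copy of the enum from kde_classifier.py, for illustration."""

    NORM = "norm"
    SCALE = "scale"


# A YAML value such as `feature_scaling_method: scale` arrives as the string "scale".
method = FeatureScalingMethod("scale")
print(method is FeatureScalingMethod.SCALE)  # True
print(method == "scale")  # True, because of the str mixin

# Lookup is by exact value, so a differently cased string is rejected.
invalid_rejected = False
try:
    FeatureScalingMethod("Scale")
except ValueError:
    invalid_rejected = True
print(invalid_rejected)  # True
```

Converting at the config boundary means a misspelled option fails when the model is constructed, rather than reaching the `RuntimeError` branch in `pre_process` later.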