Add KNN outlier detector #677

Merged on Mar 17, 2023 (122 commits)
Commits
df5dfff
Initial commit
mauicv Oct 11, 2022
4a4d19d
Add transforms
mauicv Oct 14, 2022
e729855
Minor progress commit
mauicv Nov 14, 2022
b2ba7ab
Add transforms and fitted transforms
mauicv Nov 17, 2022
1c0c7c0
Fix flake8 errors
mauicv Nov 17, 2022
2bf0056
Add accumulator into KNNTorch backend
mauicv Nov 21, 2022
c7f4825
Add BaseTorchDetector functionality
mauicv Nov 21, 2022
02ac5bc
Add torchscript tests for knn backend module
mauicv Nov 22, 2022
9068ecc
Fix GaussianRBF knn kernel test
mauicv Nov 23, 2022
e95c884
Minor correction
mauicv Nov 23, 2022
47eb6c7
Rewrite knn outlier detector
mauicv Nov 23, 2022
4b201df
Surface errors if for unfit detectors
mauicv Nov 24, 2022
5e3bc7b
Merge backend test features into test_knn_backend
mauicv Nov 25, 2022
e272c74
Make knn tests better
mauicv Nov 25, 2022
bc7de79
Remove test file
mauicv Nov 25, 2022
ef1eb83
Fix mypy errors
mauicv Nov 25, 2022
945a15a
Import Literal from typing_extensions for python version compatibility
mauicv Nov 25, 2022
7363cb5
Add docstrings for backend ensemble and knn objects
mauicv Nov 28, 2022
6ca8062
Add docstrings for base torch outlier detector class
mauicv Nov 28, 2022
4674f8e
Add docstrings for kNN detector
mauicv Nov 28, 2022
07f0ffc
Minor fixes
mauicv Nov 28, 2022
5e0c3f7
Minor fixes
mauicv Nov 28, 2022
e884e94
Add docstrings for outlier detector base class
mauicv Nov 28, 2022
756def5
Fix mypy issue and test
mauicv Nov 28, 2022
96009e7
Reorder imports
mauicv Nov 28, 2022
b6ac822
Replace normaliser with normalizer
mauicv Nov 28, 2022
3dc8408
Add optional dependency tests
mauicv Nov 28, 2022
322989e
Add make_moons dataset tests for ensemble and single kNN detectors
mauicv Nov 29, 2022
68e3a45
Fix minor mypy incompatibiity
mauicv Nov 29, 2022
77ffdab
Add line breaks in return docstrings
mauicv Nov 29, 2022
111b40f
Add torch.jit.is_scripting checks for unscriptable control flow
mauicv Dec 14, 2022
42893bb
Add torch device logic for knn backend
mauicv Jan 3, 2023
f74dfc8
Ensure cuda output tensors are converted to cpu
mauicv Jan 3, 2023
0754d35
Set default device for knn to None
mauicv Jan 3, 2023
d0d04d6
Place tensors on correct device in transform
mauicv Jan 3, 2023
09412c4
Update default knn ensemble aggregator and normalizer values
mauicv Jan 3, 2023
65fbeac
Add tests for aggregator and normalizer default values
mauicv Jan 4, 2023
633636a
Merge branch 'master' into feature/knn-outlier-detector
mauicv Jan 4, 2023
4cd2eba
Remove Optional type from aggregator
mauicv Jan 4, 2023
98d2ba5
Change X -> x throughout
mauicv Jan 4, 2023
3d57b6d
Change anomaly -> outlier throughout
mauicv Jan 4, 2023
bc01ddf
Improve fpr description
mauicv Jan 4, 2023
47566b8
Update PValNormalizer docstring
mauicv Jan 4, 2023
bda15fd
Add custom error types
mauicv Jan 4, 2023
bacbfcc
Test pval and shift and scale normalizer output values
mauicv Jan 4, 2023
ea436b0
Test aggregator output values
mauicv Jan 4, 2023
c3af856
Remove unneeded NotImplemnentedErrors from ABC abstract methods
mauicv Jan 4, 2023
351e5a7
Fix typos
mauicv Jan 4, 2023
a190cc3
Fix method typo
mauicv Jan 4, 2023
0daa98f
Fix docstrings for KNNTorch
mauicv Jan 4, 2023
8d2fa3b
Set api signatures to accept np.ndarray and not List types
mauicv Jan 4, 2023
af8ab43
Fix mypy error
mauicv Jan 4, 2023
89c4815
Move to numpy logic from OutlierDetectorOutput dataclass to base class
mauicv Jan 5, 2023
c909f4c
Refactor init knn logic
mauicv Jan 5, 2023
bc7a771
Refator str to aggregator and normalizer methods to backend
mauicv Jan 5, 2023
dfd613d
Align kNN output with other outlier detectors
mauicv Jan 6, 2023
3fb167b
Refactor backend.pytorch into pytorch module
mauicv Jan 6, 2023
f28cb0f
Fix optional dependency tests
mauicv Jan 6, 2023
6ef8c62
Add backticks do docstrings
mauicv Jan 10, 2023
76b31d7
Update docstrings
mauicv Jan 11, 2023
6fb600c
reword numpy to torch tensor in transform object docstrings
mauicv Jan 11, 2023
cf6bba1
Update return type hints
mauicv Jan 11, 2023
96f9be0
Add hasattr check in _accumulator method
mauicv Jan 16, 2023
e936b36
Add singlular dispatch pattern for _to_numpy method
mauicv Jan 16, 2023
ce28b8e
Replace alibi_detect.utils._types imports with typing_extension
mauicv Jan 17, 2023
789a1d3
Replace singular_dispatch_method with singular_dispatch
mauicv Jan 17, 2023
2d7bdeb
Merge branch 'master' into feature/knn-outlier-detector
mauicv Jan 18, 2023
7bc86dc
Merge branch 'master' into feature/knn-outlier-detector
mauicv Jan 20, 2023
40a0b27
Fix minor PR suggested changes
mauicv Jan 24, 2023
a831642
Make knn object private
mauicv Jan 24, 2023
aa334bd
Improve the kNN detector docstrings
mauicv Jan 24, 2023
d06673f
Rename aggregator to ensembler
mauicv Jan 24, 2023
0c674f0
Merge branch 'master' into feature/knn-outlier-detector
mauicv Jan 24, 2023
3d69469
Add experimental module
mauicv Jan 31, 2023
901ad53
Add _to_numpy as a static method on base backend class
mauicv Feb 1, 2023
049f93f
Remove OutlierDetector
mauicv Feb 3, 2023
c1a2ea8
Add comments to test and remove duplicate test
mauicv Feb 3, 2023
885201f
Add further comments to _knn tests
mauicv Feb 3, 2023
8667af5
Correct method name spelling
mauicv Feb 3, 2023
fb7584a
Fix return in docstring
mauicv Feb 3, 2023
cddb470
Add no grad decorator to backend methods
mauicv Feb 3, 2023
04da0a2
Use docstring in test instead of comments
mauicv Feb 13, 2023
a677e30
Address minor pr comments
mauicv Feb 13, 2023
b227db4
Remove singular dispatch pattern
mauicv Feb 15, 2023
705bc95
Merge branch 'master' into feature/knn-outlier-detector
mauicv Feb 15, 2023
46bb747
Fit ensembler in infer_threshold step not fit step
mauicv Feb 22, 2023
a3e89e3
Add further documentation to knn detector
mauicv Feb 22, 2023
b08572d
Refactor exceptions into seperate file and add base class
mauicv Feb 28, 2023
70921b7
Rename exceptions to be consistent with alibi
mauicv Feb 28, 2023
87c8ab7
Remove __future__ annotations imports and unused _types imports
mauicv Mar 7, 2023
21a5fc0
Add link from base protocols to transform docstrings
mauicv Mar 7, 2023
75a7fb6
Rename transform_protocols and others to PascalCase
mauicv Mar 7, 2023
f4a13d3
Remove private methods from public __init__
mauicv Mar 7, 2023
186cde2
Fix missing comma bug
mauicv Mar 7, 2023
adb1f51
Add metion about torch device in KNNTorch and KNN docstrings
mauicv Mar 7, 2023
789b922
Make argument captialization consitent
mauicv Mar 7, 2023
bdaf4d4
Add missing raise statment
mauicv Mar 7, 2023
d539063
Remove constructors from mixins
mauicv Mar 8, 2023
7633d8b
Change code formatting to be more readable
mauicv Mar 8, 2023
fc713c3
Revert "Remove constructors from mixins"
mauicv Mar 8, 2023
1e3d4d8
Remove constructors from FitMixinTorch
mauicv Mar 8, 2023
ea20b21
Remove private method pattern in ensemble mixins
mauicv Mar 8, 2023
b1302dd
Fix spelling mistakes
mauicv Mar 8, 2023
09cf9f7
Add np.isclose for sum to one check
mauicv Mar 8, 2023
37025ed
Fix mutable default issue
mauicv Mar 8, 2023
b48b0a0
Expose _knn docstrings in the experimental namespace
mauicv Mar 9, 2023
8e423a0
Fix minor spelling mistake
mauicv Mar 9, 2023
e0deeaa
Add self return types
mauicv Mar 14, 2023
0a4046a
Make fit an abstract method on FitMixin
mauicv Mar 14, 2023
c4c7004
Add tests to check correct errors raised in KNNTorch backend
mauicv Mar 14, 2023
ffffbe8
Fix fxiture scope error in tests
mauicv Mar 14, 2023
02657ec
Catch and throw less confusing errors from backend components
mauicv Mar 15, 2023
5b754cd
Remove return self statments for consistency with old detectors
mauicv Mar 15, 2023
912afbd
Reword docstring for _catch_error dectorator
mauicv Mar 15, 2023
d880f92
Remove autodoc comment
mauicv Mar 15, 2023
8d98ee2
Add value error for invalid choice of fpr
mauicv Mar 16, 2023
95d0e3c
Set default PValNormalizer
mauicv Mar 16, 2023
12d22f7
Cast outlier booleans to ints
mauicv Mar 16, 2023
767aff9
Rewrite the 1st paragraph of the knn detector docstring
mauicv Mar 16, 2023
467aa6a
Minor change
mauicv Mar 16, 2023
d57d1bd
Add docstrings for the ensemble tests
mauicv Mar 16, 2023
a535e8d
Rename x_ref to x in infer_threshold
mauicv Mar 17, 2023
8 changes: 4 additions & 4 deletions alibi_detect/base.py
@@ -87,14 +87,14 @@ def predict(self, X: np.ndarray):

 class FitMixin(ABC):
     @abstractmethod
-    def fit(self, X: np.ndarray) -> None:
-        pass
+    def fit(self, *args, **kwargs) -> None:
+        ...


 class ThresholdMixin(ABC):
     @abstractmethod
-    def infer_threshold(self, X: np.ndarray) -> None:
-        pass
+    def infer_threshold(self, *args, **kwargs) -> None:
+        ...


# "Large artefacts" - to save memory these are skipped in _set_config(), but added back in get_config()
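
The change to `FitMixin` and `ThresholdMixin` loosens the abstract signatures to `*args, **kwargs`, which lets each concrete detector choose its own parameters (the kNN detector later in this diff defines `infer_threshold(self, x, fpr)`). A minimal sketch of the pattern follows. The class names `FitMixinSketch` and `MeanDetector` are hypothetical and use plain lists instead of arrays; they are not part of the library.

```python
from abc import ABC, abstractmethod


class FitMixinSketch(ABC):
    """Hypothetical stand-in for a mixin with a permissive abstract signature."""

    @abstractmethod
    def fit(self, *args, **kwargs) -> None:
        ...


class MeanDetector(FitMixinSketch):
    """Toy subclass whose `fit` takes one concrete positional argument."""

    def fit(self, x_ref) -> None:
        # Store the mean of the reference data as the fitted state.
        self.mean = sum(x_ref) / len(x_ref)


detector = MeanDetector()
detector.fit([1.0, 2.0, 3.0])
print(detector.mean)  # 2.0
```

The abstract method still forces subclasses to implement `fit`; only the argument list is left open, so type checkers no longer flag subclasses whose parameters differ from the base class.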
63 changes: 63 additions & 0 deletions alibi_detect/exceptions.py
@@ -0,0 +1,63 @@
"""This module defines the Alibi Detect exception hierarchy and common exceptions used across the library."""
from typing_extensions import Literal
from typing import Callable
from abc import ABC
from functools import wraps


class AlibiDetectException(Exception, ABC):
    def __init__(self, message: str) -> None:
        """Abstract base class of all alibi detect errors.

        Parameters
        ----------
        message
            The error message.
        """
        super().__init__(message)


class NotFittedError(AlibiDetectException):
    def __init__(self, object_name: str) -> None:
        """Exception raised when a transform is not fitted.

        Parameters
        ----------
        object_name
            The name of the unfit object.
        """
        message = f'{object_name} has not been fit!'
        super().__init__(message)


class ThresholdNotInferredError(AlibiDetectException):
    def __init__(self, object_name: str) -> None:
        """Exception raised when a threshold has not been inferred for an outlier detector.

        Parameters
        ----------
        object_name
            The name of the object for which no threshold has been inferred.
        """
        message = f'{object_name} has no threshold set, call `infer_threshold` to fit one!'
        super().__init__(message)


def _catch_error(err_name: Literal['NotFittedError', 'ThresholdNotInferredError']) -> Callable:
    """Decorator to catch errors and raise a more informative error message.

    Note: This decorator should only be used on detector frontend methods. It catches errors raised by
    backend components and re-raises them with error messages corresponding to the specific detector frontend.
    This is done to avoid exposing the backend components to the user.
    """
    error_type = globals()[err_name]

    def decorate(f):
        @wraps(f)
        def applicator(self, *args, **kwargs):
            try:
                return f(self, *args, **kwargs)
            except error_type as err:
                raise error_type(self.__class__.__name__) from err
        return applicator
    return decorate
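
To illustrate how the `_catch_error` decorator renames a backend error after the frontend that the user called, here is a self-contained sketch. It is an approximation, not the library code: the decorator below takes the exception class directly instead of looking the name up in `globals()`, and `ToyBackend` and `ToyDetector` are hypothetical classes invented for the example.

```python
from functools import wraps


class NotFittedError(Exception):
    """Toy version of the exception: the message names the unfit object."""

    def __init__(self, object_name: str) -> None:
        super().__init__(f'{object_name} has not been fit!')


def catch_error(error_type):
    """Re-raise `error_type` using the class name of the decorated method's owner."""
    def decorate(f):
        @wraps(f)
        def applicator(self, *args, **kwargs):
            try:
                return f(self, *args, **kwargs)
            except error_type as err:
                raise error_type(self.__class__.__name__) from err
        return applicator
    return decorate


class ToyBackend:
    """Hypothetical backend that always reports itself as unfit."""

    def score(self, x):
        raise NotFittedError(self.__class__.__name__)


class ToyDetector:
    """Hypothetical frontend wrapping the backend."""

    def __init__(self) -> None:
        self.backend = ToyBackend()

    @catch_error(NotFittedError)
    def score(self, x):
        return self.backend.score(x)


try:
    ToyDetector().score([0.0])
except NotFittedError as err:
    message = str(err)

print(message)  # ToyDetector has not been fit!
```

Because the new exception is raised with `from err`, the original backend error remains available in the traceback as the cause, while the message the user sees refers to the detector they actually called.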
3 changes: 2 additions & 1 deletion alibi_detect/od/__init__.py
@@ -4,6 +4,7 @@
 from .mahalanobis import Mahalanobis
 from .sr import SpectralResidual

+
 OutlierAEGMM = import_optional('alibi_detect.od.aegmm', names=['OutlierAEGMM'])
 OutlierAE = import_optional('alibi_detect.od.ae', names=['OutlierAE'])
 OutlierVAE = import_optional('alibi_detect.od.vae', names=['OutlierVAE'])
@@ -22,5 +23,5 @@
     "OutlierSeq2Seq",
     "SpectralResidual",
     "LLR",
-    "OutlierProphet"
+    "OutlierProphet",
 ]
221 changes: 221 additions & 0 deletions alibi_detect/od/_knn.py
@@ -0,0 +1,221 @@
from typing import Callable, Union, Optional, Dict, Any, List, Tuple
from typing import TYPE_CHECKING

import numpy as np

from typing_extensions import Literal
from alibi_detect.base import outlier_prediction_dict
from alibi_detect.exceptions import _catch_error as catch_error
from alibi_detect.od.base import TransformProtocol, TransformProtocolType
from alibi_detect.base import BaseDetector, FitMixin, ThresholdMixin
from alibi_detect.od.pytorch import KNNTorch, Ensembler
from alibi_detect.od.base import get_aggregator, get_normalizer, NormalizerLiterals, AggregatorLiterals
from alibi_detect.utils.frameworks import BackendValidator
from alibi_detect.version import __version__


if TYPE_CHECKING:
    import torch


backends = {
    'pytorch': (KNNTorch, Ensembler)
}


class KNN(BaseDetector, FitMixin, ThresholdMixin):
    def __init__(
        self,
        k: Union[int, np.ndarray, List[int], Tuple[int]],
        kernel: Optional[Callable] = None,
        normalizer: Optional[Union[TransformProtocolType, NormalizerLiterals]] = 'PValNormalizer',
        aggregator: Union[TransformProtocol, AggregatorLiterals] = 'AverageAggregator',
        backend: Literal['pytorch'] = 'pytorch',
        device: Optional[Union[Literal['cuda', 'gpu', 'cpu'], 'torch.device']] = None,
    ) -> None:
        """
        k-Nearest Neighbors (kNN) outlier detector.

        The kNN detector is a non-parametric method for outlier detection. The detector scores each instance
        based on the distance to its neighbors. Instances with a large distance to their neighbors are more
        likely to be outliers.

        The detector can be initialized with `k` as a single value or an array of values. If `k` is a single value
        then the outlier score is the distance/kernel similarity to the k-th nearest neighbor. If `k` is an array of
        values then the outlier score is the distance/kernel similarity to each of the specified `k` neighbors.
        In the latter case, an `aggregator` must be specified to aggregate the scores.

        Note that, in the multiple k case, a normalizer can be provided. If a normalizer is passed then it is fit in
        the `infer_threshold` method and so this method must be called before the `predict` method. If this is not
        done an exception is raised. If `k` is a single value then the predict method can be called without first
        calling `infer_threshold` but only scores will be returned and not outlier predictions.

        Parameters
        ----------
        k
            Number of nearest neighbors to compute distance to. `k` can be a single value or
            an array of integers. If an array is passed, an aggregator is required to aggregate
            the scores. If `k` is a single value the outlier score is the distance/kernel
            similarity to the `k`-th nearest neighbor. If `k` is a list then it returns the
            distance/kernel similarity to each of the specified `k` neighbors.
        kernel
            Kernel function to use for outlier detection. If ``None``, `torch.cdist` is used.
            Otherwise if a kernel is specified then instead of using `torch.cdist` the kernel
            defines the k nearest neighbor distance.
        normalizer
            Normalizer to use for outlier detection. If ``None``, no normalization is applied.
            For a list of available normalizers, see :mod:`alibi_detect.od.pytorch.ensemble`.
        aggregator
            Aggregator to use for outlier detection. Can be set to ``None`` if `k` is a single
            value. For a list of available aggregators, see :mod:`alibi_detect.od.pytorch.ensemble`.
        backend
            Backend used for outlier detection. Defaults to ``'pytorch'``. Options are ``'pytorch'``.
        device
            Device type used. The default tries to use the GPU and falls back on CPU if needed.
            Can be specified by passing either ``'cuda'``, ``'gpu'``, ``'cpu'`` or an instance of
            ``torch.device``.

        Raises
        ------
        ValueError
            If `k` is an array and `aggregator` is None.
        NotImplementedError
            If choice of `backend` is not implemented.
        """
        super().__init__()

        backend_str: str = backend.lower()
        BackendValidator(
            backend_options={'pytorch': ['pytorch']},
            construct_name=self.__class__.__name__
        ).verify_backend(backend_str)

        # Look up the backend with the validated, lower-cased name.
        backend_cls, ensembler_cls = backends[backend_str]
        ensembler = None

        if aggregator is None and isinstance(k, (list, np.ndarray, tuple)):
            raise ValueError('If `k` is a `np.ndarray`, `list` or `tuple`, '
                             'the `aggregator` argument cannot be ``None``.')

        if isinstance(k, (list, np.ndarray, tuple)):
            ensembler = ensembler_cls(
                normalizer=get_normalizer(normalizer),
                aggregator=get_aggregator(aggregator)
            )

        self.backend = backend_cls(k, kernel=kernel, ensembler=ensembler, device=device)

        # set metadata
        self.meta['detector_type'] = 'outlier'
        self.meta['data_type'] = 'numeric'
        self.meta['online'] = False

    def fit(self, x_ref: np.ndarray) -> None:
        """Fit the detector on reference data.

        Parameters
        ----------
        x_ref
            Reference data used to fit the detector.
        """
        self.backend.fit(self.backend._to_tensor(x_ref))

    @catch_error('NotFittedError')
    @catch_error('ThresholdNotInferredError')
    def score(self, x: np.ndarray) -> np.ndarray:
        """Score `x` instances using the detector.

        Computes the k nearest neighbor distance/kernel similarity for each instance in `x`. If `k` is a single
        value then this is the score otherwise if `k` is an array of values then the score is aggregated using
        the ensembler.

        Parameters
        ----------
        x
            Data to score. The shape of `x` should be `(n_instances, n_features)`.

        Raises
        ------
        NotFittedError
            If called before detector has been fit.
        ThresholdNotInferredError
            If k is a list and a threshold was not inferred.

        Returns
        -------
        Outlier scores. The shape of the scores is `(n_instances,)`. The higher the score, the more anomalous the \
        instance.
        """
        score = self.backend.score(self.backend._to_tensor(x))
        score = self.backend._ensembler(score)
        return self.backend._to_numpy(score)

    @catch_error('NotFittedError')
    def infer_threshold(self, x: np.ndarray, fpr: float) -> None:
        """Infer the threshold for the kNN detector.

        The threshold is computed so that the outlier detector would incorrectly classify `fpr` proportion of the
        reference data as outliers.

        Parameters
        ----------
        x
            Reference data used to infer the threshold.
        fpr
            False positive rate used to infer the threshold. The false positive rate is the proportion of
            instances in `x` that are incorrectly classified as outliers. The false positive rate should
            be in the range ``(0, 1)``.

        Raises
        ------
        ValueError
            Raised if `fpr` is not in ``(0, 1)``.
        NotFittedError
            If called before detector has been fit.
        """
        self.backend.infer_threshold(self.backend._to_tensor(x), fpr)

    @catch_error('NotFittedError')
    @catch_error('ThresholdNotInferredError')
    def predict(self, x: np.ndarray) -> Dict[str, Any]:
        """Predict whether the instances in `x` are outliers or not.

        Scores the instances in `x` and if the threshold was inferred, returns the outlier labels and p-values as well.

        Parameters
        ----------
        x
            Data to predict. The shape of `x` should be `(n_instances, n_features)`.

        Raises
        ------
        NotFittedError
            If called before detector has been fit.
        ThresholdNotInferredError
            If k is a list and a threshold was not inferred.

        Returns
        -------
        Dictionary with keys 'data' and 'meta'. 'data' contains the outlier scores. If threshold inference was \
        performed, 'data' also contains the threshold value, outlier labels and p-values. The shape of the scores is \
        `(n_instances,)`. The higher the score, the more anomalous the instance. 'meta' contains information about \
        the detector.
        """
        outputs = self.backend.predict(self.backend._to_tensor(x))
        output = outlier_prediction_dict()
        output['data'] = {
            **output['data'],
            **self.backend._to_numpy(outputs)
        }
        output['meta'] = {
            **output['meta'],
            'name': self.__class__.__name__,
            'detector_type': 'outlier',
            'online': False,
            'version': __version__,
        }
        return output
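
The docstrings above describe two ideas: the outlier score is the distance to the k-th nearest reference point, and the threshold is chosen so that roughly `fpr` of the reference data would be flagged. The sketch below shows both in plain Python with Euclidean distance. It is an illustration under assumptions, not the `KNNTorch` backend: the helper names `knn_scores` and `infer_threshold`, the quantile rule, and the toy data are all invented for the example.

```python
import math


def knn_scores(x, x_ref, k):
    """Distance from each row of `x` to its k-th nearest row of `x_ref`."""
    scores = []
    for point in x:
        dists = sorted(math.dist(point, ref) for ref in x_ref)
        scores.append(dists[k - 1])
    return scores


def infer_threshold(ref_scores, fpr):
    """Pick a score such that roughly `fpr` of `ref_scores` lie above it."""
    if not 0 < fpr < 1:
        raise ValueError('fpr must be in (0, 1)')
    ordered = sorted(ref_scores)
    # Index of the empirical (1 - fpr) quantile, clamped to a valid position.
    idx = min(len(ordered) - 1, math.ceil((1 - fpr) * len(ordered)) - 1)
    return ordered[max(idx, 0)]


x_ref = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)]
x = [(0.5, 0.5), (10.0, 10.0)]

# Scoring the reference set against itself: with k=2 the zero self-distance
# takes the first slot, so each score is the distance to the nearest other point.
threshold = infer_threshold(knn_scores(x_ref, x_ref, k=2), fpr=0.25)

scores = knn_scores(x, x_ref, k=2)
is_outlier = [int(s > threshold) for s in scores]
print(is_outlier)  # [0, 1]
```

The point in the middle of the unit square scores below the threshold, while the distant point scores well above it. When `k` is a list, the detector in this PR produces one such score per `k` and combines them with the normalizer and aggregator instead.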
75 changes: 75 additions & 0 deletions alibi_detect/od/base.py
@@ -0,0 +1,75 @@
from alibi_detect.utils.missing_optional_dependency import import_optional

from typing import Union
from typing_extensions import Literal, Protocol, runtime_checkable


# Use Protocols instead of base classes for the backend associated objects. This is a bit more flexible and allows us to
# avoid the torch/tensorflow imports in the base class.
@runtime_checkable
class TransformProtocol(Protocol):
    """Protocol for transformer objects.

    The :py:obj:`~alibi_detect.od.pytorch.ensemble.BaseTransformTorch` object provides abstract methods for
    objects that map between `torch` tensors. This protocol models the interface of the `BaseTransformTorch`
    class.
    """
    def transform(self, x):
        pass


@runtime_checkable
class FittedTransformProtocol(TransformProtocol, Protocol):
    """Protocol for fitted transformer objects.

    This protocol models the joint interface of the :py:obj:`~alibi_detect.od.pytorch.ensemble.BaseTransformTorch`
    class and the :py:obj:`~alibi_detect.od.pytorch.ensemble.FitMixinTorch` class. These objects are transforms that
    must be fit before use."""
    def fit(self, x_ref):
        pass

    def set_fitted(self):
        pass

    def check_fitted(self):
        pass

TransformProtocolType = Union[TransformProtocol, FittedTransformProtocol]
NormalizerLiterals = Literal['PValNormalizer', 'ShiftAndScaleNormalizer']
AggregatorLiterals = Literal['TopKAggregator', 'AverageAggregator',
                             'MaxAggregator', 'MinAggregator']


PValNormalizer, ShiftAndScaleNormalizer, TopKAggregator, AverageAggregator, \
    MaxAggregator, MinAggregator = import_optional(
        'alibi_detect.od.pytorch.ensemble',
        ['PValNormalizer', 'ShiftAndScaleNormalizer', 'TopKAggregator',
         'AverageAggregator', 'MaxAggregator', 'MinAggregator']
    )


def get_normalizer(normalizer: Union[TransformProtocolType, NormalizerLiterals]) -> TransformProtocol:
    if isinstance(normalizer, str):
        try:
            # Index the mapping directly so that an unknown name raises KeyError.
            return {
                'PValNormalizer': PValNormalizer,
                'ShiftAndScaleNormalizer': ShiftAndScaleNormalizer,
            }[normalizer]()
        except KeyError:
            raise NotImplementedError(f'Normalizer {normalizer} not implemented.')
    return normalizer


def get_aggregator(aggregator: Union[TransformProtocol, AggregatorLiterals]) -> TransformProtocol:
    if isinstance(aggregator, str):
        try:
            # Index the mapping directly so that an unknown name raises KeyError.
            return {
                'TopKAggregator': TopKAggregator,
                'AverageAggregator': AverageAggregator,
                'MaxAggregator': MaxAggregator,
                'MinAggregator': MinAggregator,
            }[aggregator]()
        except KeyError:
            raise NotImplementedError(f'Aggregator {aggregator} not implemented.')
    return aggregator
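
The `get_normalizer` and `get_aggregator` helpers resolve a string name to an instance and pass any other object through unchanged. A minimal sketch of that lookup follows. The two aggregator classes here are hypothetical stand-ins that operate on lists of lists, and `REGISTRY` and `resolve` are names invented for the example, not the torch-based classes in `alibi_detect.od.pytorch.ensemble`.

```python
class AverageAggregator:
    """Hypothetical stand-in: averages the scores in each row."""

    def transform(self, rows):
        return [sum(row) / len(row) for row in rows]


class MaxAggregator:
    """Hypothetical stand-in: takes the largest score in each row."""

    def transform(self, rows):
        return [max(row) for row in rows]


REGISTRY = {'AverageAggregator': AverageAggregator, 'MaxAggregator': MaxAggregator}


def resolve(aggregator):
    """Return an instance for a known name; pass other objects through unchanged."""
    if isinstance(aggregator, str):
        try:
            return REGISTRY[aggregator]()
        except KeyError:
            raise NotImplementedError(f'Aggregator {aggregator} not implemented.')
    return aggregator


print(resolve('MaxAggregator').transform([[1.0, 3.0], [2.0, 0.5]]))  # [3.0, 2.0]
```

Indexing the dictionary with `[]` is what makes the `except KeyError` branch reachable. `dict.get` returns `None` for an unknown key, and calling `None` raises `TypeError`, which that handler would not catch.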