Skip to content

Commit

Permalink
Return data from the backend fit method (#835)
Browse files Browse the repository at this point in the history
* Return data from the backend fit method

* Update docstrings for gmm and svm fit methods
  • Loading branch information
mauicv authored Jul 13, 2023
1 parent 28cb562 commit c580354
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 8 deletions.
10 changes: 9 additions & 1 deletion alibi_detect/od/_gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,16 @@ def fit(
Defaults to ``'kmeans'``.
verbose
Verbosity level used to fit the detector. Used for both ``'sklearn'`` and ``'pytorch'`` backends. Defaults to ``0``.
Returns
-------
Dictionary with fit results. The dictionary contains the following keys depending on the backend used:
- converged: bool indicating whether EM algorithm converged.
- n_iter: number of EM iterations performed. Only returned if `backend` is ``'sklearn'``.
- n_epochs: number of gradient descent iterations performed. Only returned if `backend` is ``'pytorch'``.
- lower_bound: log-likelihood lower bound.
"""
self.backend.fit(
return self.backend.fit(
self.backend._to_backend_dtype(x_ref),
**self.backend.format_fit_kwargs(locals())
)
Expand Down
9 changes: 8 additions & 1 deletion alibi_detect/od/_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,15 @@ def fit(
verbose
Verbosity level during training. ``0`` is silent, ``1`` prints fit status. If using `bgd`, fit displays a
progress bar. Otherwise, if using `sgd` then we output the Sklearn `SGDOneClassSVM.fit()` logs.
Returns
-------
Dictionary with fit results. The dictionary contains the following keys depending on the optimization used:
- converged: `bool` indicating whether training converged.
- n_iter: number of iterations performed.
- lower_bound: loss lower bound. Only returned if `optimization` is ``'bgd'``.
"""
self.backend.fit(
return self.backend.fit(
self.backend._to_backend_dtype(x_ref),
**self.backend.format_fit_kwargs(locals())
)
Expand Down
10 changes: 9 additions & 1 deletion alibi_detect/od/pytorch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ def to_frontend_dtype(self):
return result


def _tensor_to_frontend_dtype(x: Union[torch.Tensor, np.ndarray, float]) -> Union[np.ndarray, float]:
if isinstance(x, torch.Tensor):
x = x.cpu().detach().numpy()
if isinstance(x, np.ndarray) and x.ndim == 0:
x = x.item()
return x # type: ignore[return-value]


def _raise_type_error(x):
raise TypeError(f'x is type={type(x)} but must be one of TorchOutlierDetectorOutput or a torch Tensor')

Expand All @@ -52,7 +60,7 @@ def to_frontend_dtype(x: Union[torch.Tensor, TorchOutlierDetectorOutput]) -> Uni

return {
'TorchOutlierDetectorOutput': lambda x: x.to_frontend_dtype(),
'Tensor': lambda x: x.cpu().detach().numpy()
'Tensor': _tensor_to_frontend_dtype
}.get(
x.__class__.__name__,
_raise_type_error
Expand Down
6 changes: 3 additions & 3 deletions alibi_detect/od/pytorch/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def fit( # type: ignore[override]
Returns
-------
Dictionary with fit results. The dictionary contains the following keys:
- converged: bool indicating whether EM algorithm converged.
- n_iter: number of EM iterations performed.
- converged: bool indicating whether training converged.
- n_epochs: number of gradient descent iterations performed.
- lower_bound: log-likelihood lower bound.
"""
self.model = GMMModel(self.n_components, x_ref.shape[-1]).to(self.device)
Expand Down Expand Up @@ -132,7 +132,7 @@ def fit( # type: ignore[override]
self._set_fitted()
return {
'converged': converged,
'lower_bound': min_loss,
'lower_bound': self._to_frontend_dtype(min_loss),
'n_epochs': epoch
}

Expand Down
4 changes: 2 additions & 2 deletions alibi_detect/od/pytorch/svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def fit( # type: ignore[override]
Returns
-------
Dictionary with fit results. The dictionary contains the following keys:
- converged: bool indicating whether training converged.
- converged: `bool` indicating whether training converged.
- n_iter: number of iterations performed.
- lower_bound: loss lower bound.
"""
Expand Down Expand Up @@ -338,7 +338,7 @@ def fit( # type: ignore[override]
self._set_fitted()
return {
'converged': converged,
'lower_bound': min_loss,
'lower_bound': self._to_frontend_dtype(min_loss),
'n_iter': iter
}

Expand Down
16 changes: 16 additions & 0 deletions alibi_detect/od/tests/test__gmm/test__gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,19 @@ def test_gmm_torchscript(tmp_path):
ts_gmm = torch.load(tmp_path / 'gmm.pt')
y = ts_gmm(x)
assert torch.all(y == torch.tensor([False, True]))


@pytest.mark.parametrize('backend', ['pytorch', 'sklearn'])
def test_gmm_fit(backend):
    """Test the results returned by the GMM detector `fit` method.

    Fits a single-component GMM on 1000 samples drawn from a 2-d Gaussian and checks that the returned
    dictionary reports convergence and holds a `lower_bound` that is a Python `float` smaller than 1.
    """
    gmm = GMM(n_components=1, backend=backend)
    mean = [8, 8]
    cov = [[2., 0.], [0., 1.]]
    x_ref = torch.tensor(np.random.multivariate_normal(mean, cov, 1000))
    fit_results = gmm.fit(x_ref, tol=0.01, batch_size=32)
    # The backend must hand back a plain Python float rather than a tensor or numpy scalar.
    assert isinstance(fit_results['lower_bound'], float)
    assert fit_results['converged']
    assert fit_results['lower_bound'] < 1
25 changes: 25 additions & 0 deletions alibi_detect/od/tests/test__svm/test__svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,28 @@ def test_svm_torchscript(tmp_path):
ts_svm = torch.load(tmp_path / 'svm.pt')
y = ts_svm(x)
assert torch.all(y == torch.tensor([False, True]))


@pytest.mark.parametrize('optimization', ['sgd', 'bgd'])
def test_svm_fit(optimization):
    """Test the results returned by the SVM detector `fit` method.

    Fits the detector on 1000 samples drawn from a 2-d Gaussian and checks that the returned dictionary
    reports convergence in fewer than 100 iterations. When a `lower_bound` is returned it must be smaller
    than 1, and for the `bgd` optimization it must be a Python `float`.
    """
    kernel = GaussianRBF(torch.tensor(1.))
    svm = SVM(
        n_components=10,
        kernel=kernel,
        nu=0.01,
        optimization=optimization,
    )
    mean = [8, 8]
    cov = [[2., 0.], [0., 1.]]
    x_ref = torch.tensor(np.random.multivariate_normal(mean, cov, 1000))
    fit_results = svm.fit(x_ref, tol=0.01)
    assert fit_results['converged']
    assert fit_results['n_iter'] < 100
    # Default of 0 keeps this check a no-op when no lower bound is returned.
    assert fit_results.get('lower_bound', 0) < 1
    # 'sgd' optimization does not return lower bound
    if optimization == 'bgd':
        assert isinstance(fit_results['lower_bound'], float)

0 comments on commit c580354

Please sign in to comment.