diff --git a/alibi_detect/od/_gmm.py b/alibi_detect/od/_gmm.py index 5f0d90fbc..53ee14cb2 100644 --- a/alibi_detect/od/_gmm.py +++ b/alibi_detect/od/_gmm.py @@ -130,8 +130,16 @@ def fit( Defaults to ``'kmeans'``. verbose Verbosity level used to fit the detector. Used for both ``'sklearn'`` and ``'pytorch'`` backends. Defaults to ``0``. + + Returns + ------- + Dictionary with fit results. The dictionary contains the following keys depending on the backend used: + - converged: bool indicating whether EM algorithm converged. + - n_iter: number of EM iterations performed. Only returned if `backend` is ``'sklearn'``. + - n_epochs: number of gradient descent iterations performed. Only returned if `backend` is ``'pytorch'``. + - lower_bound: log-likelihood lower bound. """ - self.backend.fit( + return self.backend.fit( self.backend._to_backend_dtype(x_ref), **self.backend.format_fit_kwargs(locals()) ) diff --git a/alibi_detect/od/_svm.py b/alibi_detect/od/_svm.py index 8c5b38cc5..a60066ff5 100644 --- a/alibi_detect/od/_svm.py +++ b/alibi_detect/od/_svm.py @@ -147,8 +147,15 @@ def fit( verbose Verbosity level during training. ``0`` is silent, ``1`` prints fit status. If using `bgd`, fit displays a progress bar. Otherwise, if using `sgd` then we output the Sklearn `SGDOneClassSVM.fit()` logs. + + Returns + ------- + Dictionary with fit results. The dictionary contains the following keys depending on the optimization used: + - converged: `bool` indicating whether training converged. + - n_iter: number of iterations performed. + - lower_bound: loss lower bound. Only returned for the `bgd`. """ - self.backend.fit( + return self.backend.fit( self.backend._to_backend_dtype(x_ref), **self.backend.format_fit_kwargs(locals()) ) diff --git a/alibi_detect/od/pytorch/base.py b/alibi_detect/od/pytorch/base.py index 747f6c995..ccb1543ed 100644 --- a/alibi_detect/od/pytorch/base.py +++ b/alibi_detect/od/pytorch/base.py @@ -31,6 +31,14 @@ def to_frontend_dtype(self): return result +def _tensor_to_frontend_dtype(x: Union[torch.Tensor, np.ndarray, float]) -> Union[np.ndarray, float]: + if isinstance(x, torch.Tensor): + x = x.cpu().detach().numpy() + if isinstance(x, np.ndarray) and x.ndim == 0: + x = x.item() + return x # type: ignore[return-value] + + def _raise_type_error(x): raise TypeError(f'x is type={type(x)} but must be one of TorchOutlierDetectorOutput or a torch Tensor') @@ -52,7 +60,7 @@ def to_frontend_dtype(x: Union[torch.Tensor, TorchOutlierDetectorOutput]) -> Uni return { 'TorchOutlierDetectorOutput': lambda x: x.to_frontend_dtype(), - 'Tensor': lambda x: x.cpu().detach().numpy() + 'Tensor': _tensor_to_frontend_dtype }.get( x.__class__.__name__, _raise_type_error diff --git a/alibi_detect/od/pytorch/gmm.py b/alibi_detect/od/pytorch/gmm.py index ba4f3e47a..474311b9a 100644 --- a/alibi_detect/od/pytorch/gmm.py +++ b/alibi_detect/od/pytorch/gmm.py @@ -75,8 +75,8 @@ def fit( # type: ignore[override] Returns ------- Dictionary with fit results. The dictionary contains the following keys: - - converged: bool indicating whether EM algorithm converged. - - n_iter: number of EM iterations performed. + - converged: bool indicating whether training converged. + - n_epochs: number of gradient descent iterations performed. - lower_bound: log-likelihood lower bound. """ self.model = GMMModel(self.n_components, x_ref.shape[-1]).to(self.device) @@ -132,7 +132,7 @@ def fit( # type: ignore[override] self._set_fitted() return { 'converged': converged, - 'lower_bound': min_loss, + 'lower_bound': self._to_frontend_dtype(min_loss), 'n_epochs': epoch } diff --git a/alibi_detect/od/pytorch/svm.py b/alibi_detect/od/pytorch/svm.py index 094499776..081896beb 100644 --- a/alibi_detect/od/pytorch/svm.py +++ b/alibi_detect/od/pytorch/svm.py @@ -272,7 +272,7 @@ def fit( # type: ignore[override] Returns ------- Dictionary with fit results. The dictionary contains the following keys: - - converged: bool indicating whether training converged. + - converged: `bool` indicating whether training converged. - n_iter: number of iterations performed. - lower_bound: loss lower bound. """ @@ -338,7 +338,7 @@ def fit( # type: ignore[override] self._set_fitted() return { 'converged': converged, - 'lower_bound': min_loss, + 'lower_bound': self._to_frontend_dtype(min_loss), 'n_iter': iter } diff --git a/alibi_detect/od/tests/test__gmm/test__gmm.py b/alibi_detect/od/tests/test__gmm/test__gmm.py index c210a134b..c4d724557 100644 --- a/alibi_detect/od/tests/test__gmm/test__gmm.py +++ b/alibi_detect/od/tests/test__gmm/test__gmm.py @@ -117,3 +117,19 @@ def test_gmm_torchscript(tmp_path): ts_gmm = torch.load(tmp_path / 'gmm.pt') y = ts_gmm(x) assert torch.all(y == torch.tensor([False, True])) + + +@pytest.mark.parametrize('backend', ['pytorch', 'sklearn']) +def test_gmm_fit(backend): + """Test GMM detector fit method. + + Tests detector checks for convergence and stops early if it does. + """ + gmm = GMM(n_components=1, backend=backend) + mean = [8, 8] + cov = [[2., 0.], [0., 1.]] + x_ref = torch.tensor(np.random.multivariate_normal(mean, cov, 1000)) + fit_results = gmm.fit(x_ref, tol=0.01, batch_size=32) + assert isinstance(fit_results['lower_bound'], float) + assert fit_results['converged'] + assert fit_results['lower_bound'] < 1 diff --git a/alibi_detect/od/tests/test__svm/test__svm.py b/alibi_detect/od/tests/test__svm/test__svm.py index eb5e991bb..f76cc4e58 100644 --- a/alibi_detect/od/tests/test__svm/test__svm.py +++ b/alibi_detect/od/tests/test__svm/test__svm.py @@ -207,3 +207,28 @@ def test_svm_torchscript(tmp_path): ts_svm = torch.load(tmp_path / 'svm.pt') y = ts_svm(x) assert torch.all(y == torch.tensor([False, True])) + + +@pytest.mark.parametrize('optimization', ['sgd', 'bgd']) +def test_svm_fit(optimization): + """Test SVM detector fit method. + + Tests pytorch detector checks for convergence and stops early if it does. + """ + kernel = GaussianRBF(torch.tensor(1.)) + svm = SVM( + n_components=10, + kernel=kernel, + nu=0.01, + optimization=optimization, + ) + mean = [8, 8] + cov = [[2., 0.], [0., 1.]] + x_ref = torch.tensor(np.random.multivariate_normal(mean, cov, 1000)) + fit_results = svm.fit(x_ref, tol=0.01) + assert fit_results['converged'] + assert fit_results['n_iter'] < 100 + assert fit_results.get('lower_bound', 0) < 1 + # 'sgd' optimization does not return lower bound + if optimization == 'bgd': + assert isinstance(fit_results['lower_bound'], float)