-
Notifications
You must be signed in to change notification settings - Fork 229
[MRG+2] Update repo to work with both new and old scikit-learn #313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3760c60
777f8d6
e5c240d
b754930
195da0a
0c780b0
bfea742
61755dd
4dac77b
ee856ad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
"""This file is for fixing imports due to different APIs | ||
depending on the scikit-learn version""" | ||
import sklearn | ||
from packaging import version | ||
SKLEARN_AT_LEAST_0_22 = (version.parse(sklearn.__version__) | ||
>= version.parse('0.22.0')) | ||
if SKLEARN_AT_LEAST_0_22: | ||
from sklearn.utils._testing import (set_random_state, | ||
assert_warns_message, | ||
ignore_warnings, | ||
assert_allclose_dense_sparse, | ||
_get_args) | ||
from sklearn.utils.estimator_checks import (_is_public_parameter | ||
as is_public_parameter) | ||
from sklearn.metrics._scorer import get_scorer | ||
else: | ||
from sklearn.utils.testing import (set_random_state, | ||
assert_warns_message, | ||
ignore_warnings, | ||
assert_allclose_dense_sparse, | ||
_get_args) | ||
from sklearn.utils.estimator_checks import is_public_parameter | ||
from sklearn.metrics.scorer import get_scorer | ||
|
||
# Public names re-exported by this shim module. Each name is listed exactly
# once (the original list repeated 'set_random_state').
__all__ = ['set_random_state', 'assert_warns_message', 'ignore_warnings',
           'assert_allclose_dense_sparse', '_get_args',
           'is_public_parameter', 'get_scorer']
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,71 +4,161 @@ | |
import metric_learn | ||
import numpy as np | ||
from sklearn import clone | ||
from sklearn.utils.testing import set_random_state | ||
from test.test_utils import ids_metric_learners, metric_learners, remove_y | ||
from metric_learn.sklearn_shims import set_random_state, SKLEARN_AT_LEAST_0_22 | ||
|
||
|
||
def remove_spaces(s):
    """Return ``s`` with every run of whitespace characters removed."""
    whitespace_run = re.compile(r'\s+')
    return whitespace_run.sub('', s)
|
||
|
||
def sk_repr_kwargs(def_kwargs, nndef_kwargs):
    """Build the argument string expected in an estimator's ``__repr__``.

    The expected string depends on the installed scikit-learn version, so
    that the string-representation tests can run against both old and new
    releases: when ``SKLEARN_AT_LEAST_0_22`` is true only the non-default
    arguments are listed; otherwise the default arguments are listed too,
    with the non-default values overriding them.

    Parameters
    ----------
    def_kwargs : dict
        Default keyword arguments of the estimator. Left unmodified.
    nndef_kwargs : dict
        Keyword arguments that were given non-default values.

    Returns
    -------
    str
        Comma-separated ``key=repr(value)`` pairs, with no spaces added
        between them.
    """
    # NOTE(review): scikit-learn's ``print_changed_only`` default appears to
    # have switched to True in 0.23 rather than 0.22 -- confirm that the
    # version flag used here matches the release that changed the repr.
    # Merge into a fresh dict so the caller's ``def_kwargs`` is never mutated.
    expected = {} if SKLEARN_AT_LEAST_0_22 else dict(def_kwargs)
    expected.update(nndef_kwargs)
    return ",".join(f"{key}={value!r}" for key, value in expected.items())
|
||
|
||
class TestStringRepr(unittest.TestCase): | ||
|
||
def test_covariance(self): | ||
def_kwargs = {'preprocessor': None} | ||
nndef_kwargs = {} | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual(remove_spaces(str(metric_learn.Covariance())), | ||
remove_spaces("Covariance()")) | ||
remove_spaces(f"Covariance({merged_kwargs})")) | ||
|
||
def test_lmnn(self): | ||
def_kwargs = {'convergence_tol': 0.001, 'init': 'auto', 'k': 3, | ||
'learn_rate': 1e-07, 'max_iter': 1000, 'min_iter': 50, | ||
'n_components': None, 'preprocessor': None, | ||
'random_state': None, 'regularization': 0.5, | ||
'verbose': False} | ||
nndef_kwargs = {'convergence_tol': 0.01, 'k': 6} | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual( | ||
remove_spaces(str(metric_learn.LMNN(convergence_tol=0.01, k=6))), | ||
remove_spaces("LMNN(convergence_tol=0.01, k=6)")) | ||
remove_spaces(f"LMNN({merged_kwargs})")) | ||
|
||
def test_nca(self): | ||
def_kwargs = {'init': 'auto', 'max_iter': 100, 'n_components': None, | ||
'preprocessor': None, 'random_state': None, 'tol': None, | ||
'verbose': False} | ||
nndef_kwargs = {'max_iter': 42} | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual(remove_spaces(str(metric_learn.NCA(max_iter=42))), | ||
remove_spaces("NCA(max_iter=42)")) | ||
remove_spaces(f"NCA({merged_kwargs})")) | ||
|
||
def test_lfda(self): | ||
def_kwargs = {'embedding_type': 'weighted', 'k': None, | ||
'n_components': None, 'preprocessor': None} | ||
nndef_kwargs = {'k': 2} | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual(remove_spaces(str(metric_learn.LFDA(k=2))), | ||
remove_spaces("LFDA(k=2)")) | ||
remove_spaces(f"LFDA({merged_kwargs})")) | ||
|
||
def test_itml(self): | ||
def_kwargs = {'convergence_threshold': 0.001, 'gamma': 1.0, | ||
'max_iter': 1000, 'preprocessor': None, | ||
'prior': 'identity', 'random_state': None, 'verbose': False} | ||
nndef_kwargs = {'gamma': 0.5} | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual(remove_spaces(str(metric_learn.ITML(gamma=0.5))), | ||
remove_spaces("ITML(gamma=0.5)")) | ||
remove_spaces(f"ITML({merged_kwargs})")) | ||
def_kwargs = {'convergence_threshold': 0.001, 'gamma': 1.0, | ||
'max_iter': 1000, 'num_constraints': None, | ||
'preprocessor': None, 'prior': 'identity', | ||
'random_state': None, 'verbose': False} | ||
nndef_kwargs = {'num_constraints': 7} | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual( | ||
remove_spaces(str(metric_learn.ITML_Supervised(num_constraints=7))), | ||
remove_spaces("ITML_Supervised(num_constraints=7)")) | ||
remove_spaces(f"ITML_Supervised({merged_kwargs})")) | ||
|
||
def test_lsml(self): | ||
def_kwargs = {'max_iter': 1000, 'preprocessor': None, 'prior': 'identity', | ||
'random_state': None, 'tol': 0.001, 'verbose': False} | ||
nndef_kwargs = {'tol': 0.1} | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual(remove_spaces(str(metric_learn.LSML(tol=0.1))), | ||
remove_spaces("LSML(tol=0.1)")) | ||
remove_spaces(f"LSML({merged_kwargs})")) | ||
def_kwargs = {'max_iter': 1000, 'num_constraints': None, | ||
'preprocessor': None, 'prior': 'identity', | ||
'random_state': None, 'tol': 0.001, 'verbose': False, | ||
'weights': None} | ||
nndef_kwargs = {'verbose': True} | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual( | ||
remove_spaces(str(metric_learn.LSML_Supervised(verbose=True))), | ||
remove_spaces("LSML_Supervised(verbose=True)")) | ||
remove_spaces(f"LSML_Supervised({merged_kwargs})")) | ||
|
||
def test_sdml(self): | ||
def_kwargs = {'balance_param': 0.5, 'preprocessor': None, | ||
'prior': 'identity', 'random_state': None, | ||
'sparsity_param': 0.01, 'verbose': False} | ||
nndef_kwargs = {'verbose': True} | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual(remove_spaces(str(metric_learn.SDML(verbose=True))), | ||
remove_spaces("SDML(verbose=True)")) | ||
remove_spaces(f"SDML({merged_kwargs})")) | ||
def_kwargs = {'balance_param': 0.5, 'num_constraints': None, | ||
'preprocessor': None, 'prior': 'identity', | ||
'random_state': None, 'sparsity_param': 0.01, | ||
'verbose': False} | ||
nndef_kwargs = {'sparsity_param': 0.5} | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual( | ||
remove_spaces(str(metric_learn.SDML_Supervised(sparsity_param=0.5))), | ||
remove_spaces("SDML_Supervised(sparsity_param=0.5)")) | ||
remove_spaces(f"SDML_Supervised({merged_kwargs})")) | ||
|
||
def test_rca(self): | ||
def_kwargs = {'n_components': None, 'preprocessor': None} | ||
nndef_kwargs = {'n_components': 3} | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual(remove_spaces(str(metric_learn.RCA(n_components=3))), | ||
remove_spaces("RCA(n_components=3)")) | ||
remove_spaces(f"RCA({merged_kwargs})")) | ||
def_kwargs = {'chunk_size': 2, 'n_components': None, 'num_chunks': 100, | ||
'preprocessor': None, 'random_state': None} | ||
nndef_kwargs = {'num_chunks': 5} | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual( | ||
remove_spaces(str(metric_learn.RCA_Supervised(num_chunks=5))), | ||
remove_spaces("RCA_Supervised(num_chunks=5)")) | ||
remove_spaces(f"RCA_Supervised({merged_kwargs})")) | ||
|
||
def test_mlkr(self): | ||
def_kwargs = {'init': 'auto', 'max_iter': 1000, | ||
'n_components': None, 'preprocessor': None, | ||
'random_state': None, 'tol': None, 'verbose': False} | ||
nndef_kwargs = {'max_iter': 777} | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual(remove_spaces(str(metric_learn.MLKR(max_iter=777))), | ||
remove_spaces("MLKR(max_iter=777)")) | ||
remove_spaces(f"MLKR({merged_kwargs})")) | ||
|
||
def test_mmc(self): | ||
def_kwargs = {'convergence_threshold': 0.001, 'diagonal': False, | ||
'diagonal_c': 1.0, 'init': 'identity', 'max_iter': 100, | ||
'max_proj': 10000, 'preprocessor': None, | ||
'random_state': None, 'verbose': False} | ||
nndef_kwargs = {'diagonal': True} | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual(remove_spaces(str(metric_learn.MMC(diagonal=True))), | ||
remove_spaces("MMC(diagonal=True)")) | ||
remove_spaces(f"MMC({merged_kwargs})")) | ||
def_kwargs = {'convergence_threshold': 1e-06, 'diagonal': False, | ||
'diagonal_c': 1.0, 'init': 'identity', 'max_iter': 100, | ||
'max_proj': 10000, 'num_constraints': None, | ||
'preprocessor': None, 'random_state': None, | ||
'verbose': False} | ||
nndef_kwargs = {'max_iter': 1} | ||
merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) | ||
self.assertEqual( | ||
remove_spaces(str(metric_learn.MMC_Supervised(max_iter=1))), | ||
remove_spaces("MMC_Supervised(max_iter=1)")) | ||
remove_spaces(f"MMC_Supervised({merged_kwargs})")) | ||
|
||
|
||
@pytest.mark.parametrize('estimator, build_dataset', metric_learners, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,10 +4,9 @@ | |
from sklearn.base import TransformerMixin | ||
from sklearn.pipeline import make_pipeline | ||
from sklearn.utils import check_random_state | ||
from sklearn.utils.estimator_checks import is_public_parameter | ||
from sklearn.utils.testing import (assert_allclose_dense_sparse, | ||
set_random_state) | ||
|
||
from metric_learn.sklearn_shims import (assert_allclose_dense_sparse, | ||
set_random_state, _get_args, | ||
is_public_parameter, get_scorer) | ||
from metric_learn import (Covariance, LFDA, LMNN, MLKR, NCA, | ||
ITML_Supervised, LSML_Supervised, | ||
MMC_Supervised, RCA_Supervised, SDML_Supervised, | ||
|
@@ -16,8 +15,6 @@ | |
import numpy as np | ||
from sklearn.model_selection import (cross_val_score, cross_val_predict, | ||
train_test_split, KFold) | ||
from sklearn.metrics.scorer import get_scorer | ||
from sklearn.utils.testing import _get_args | ||
from test.test_utils import (metric_learners, ids_metric_learners, | ||
mock_preprocessor, tuples_learners, | ||
ids_tuples_learners, pairs_learners, | ||
|
@@ -52,37 +49,37 @@ def __init__(self, sparsity_param=0.01, | |
|
||
class TestSklearnCompat(unittest.TestCase): | ||
def test_covariance(self): | ||
check_estimator(Covariance) | ||
check_estimator(Covariance()) | ||
|
||
def test_lmnn(self): | ||
check_estimator(LMNN) | ||
check_estimator(LMNN()) | ||
|
||
def test_lfda(self): | ||
check_estimator(LFDA) | ||
check_estimator(LFDA()) | ||
|
||
def test_mlkr(self): | ||
check_estimator(MLKR) | ||
check_estimator(MLKR()) | ||
|
||
def test_nca(self): | ||
check_estimator(NCA) | ||
check_estimator(NCA()) | ||
|
||
def test_lsml(self): | ||
check_estimator(LSML_Supervised) | ||
check_estimator(LSML_Supervised()) | ||
|
||
def test_itml(self): | ||
check_estimator(ITML_Supervised) | ||
check_estimator(ITML_Supervised()) | ||
|
||
def test_mmc(self): | ||
check_estimator(MMC_Supervised) | ||
check_estimator(MMC_Supervised()) | ||
|
||
def test_sdml(self): | ||
check_estimator(Stable_SDML_Supervised) | ||
check_estimator(Stable_SDML_Supervised()) | ||
|
||
def test_rca(self): | ||
check_estimator(Stable_RCA_Supervised) | ||
check_estimator(Stable_RCA_Supervised()) | ||
|
||
def test_scml(self): | ||
check_estimator(SCML_Supervised) | ||
check_estimator(SCML_Supervised()) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here scikit-learn raised an error saying that checks should be run on estimator instances, not classes. |
||
|
||
|
||
RNG = check_random_state(0) | ||
|
@@ -121,7 +118,8 @@ def test_array_like_inputs(estimator, build_dataset, with_preprocessor): | |
|
||
# we subsample the data for the test to be more efficient | ||
input_data, _, labels, _ = train_test_split(input_data, labels, | ||
train_size=20) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The test was failing here because some classes had too few samples for LMNN (see comments above), so in the end I created a toy example with a few more samples (which I guess makes sense, because the role of this particular test is not to test edge cases but rather to check that array-like objects work with our estimators). |
||
train_size=40, | ||
random_state=42) | ||
X = X[:10] | ||
|
||
estimator = clone(estimator) | ||
|
@@ -160,7 +158,7 @@ def test_various_scoring_on_tuples_learners(estimator, build_dataset, | |
with_preprocessor): | ||
"""Tests that scikit-learn's scoring returns something finite, | ||
for other scoring than default scoring. (List of scikit-learn's scores can be | ||
found in sklearn.metrics.scorer). For each type of output (predict, | ||
found in sklearn.metrics._scorer). For each type of output (predict, | ||
predict_proba, decision_function), we test a bunch of scores. | ||
We only test on pairs learners because quadruplets don't have a y argument. | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps add a note here to clarify this additional test's purpose
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, done