Skip to content

Commit

Permalink
UPD: add raise value error with wrong cv parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
thibaultcordier committed Dec 20, 2023
1 parent 1fa8b7c commit f46f117
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
17 changes: 15 additions & 2 deletions mapie/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from numpy.random import RandomState
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import BaseCrossValidator, KFold, ShuffleSplit
from sklearn.model_selection import (BaseCrossValidator, KFold, LeaveOneOut,
ShuffleSplit)
from sklearn.utils.validation import check_is_fitted

from mapie._typing import ArrayLike, NDArray
Expand Down Expand Up @@ -483,10 +484,22 @@ def test_check_cv_same_split_no_random_state(cv: BaseCrossValidator) -> None:
("split", True), (KFold(5), False),
(ShuffleSplit(1), True),
(ShuffleSplit(2), False),
(object(), False)
(LeaveOneOut(), False),
]
)
def test_check_no_agg_cv(cv_result: Tuple) -> None:
"""Test that if `check_no_agg_cv` function returns the expected result."""
array = ["prefit", "split"]
cv, result = cv_result
np.testing.assert_almost_equal(check_no_agg_cv(X_toy, cv, array), result)


@pytest.mark.parametrize("cv", [object()])
def test_check_no_agg_cv_value_error(cv: Any) -> None:
"""Test that if `check_no_agg_cv` function raises value error."""
array = ["prefit", "split"]
with pytest.raises(
ValueError,
match=r"Allowed values must have the `get_n_splits` method"
):
check_no_agg_cv(X_toy, cv, array)
8 changes: 6 additions & 2 deletions mapie/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,14 @@ def check_no_agg_cv(
return cv in no_agg_cv_array
elif isinstance(cv, int):
return cv == 1
if hasattr(cv, "get_n_splits"):
elif hasattr(cv, "get_n_splits"):
return cv.get_n_splits(X) == 1
else:
return False
raise ValueError(
"Invalid cv argument. "
"Allowed values must have the `get_n_splits` method "
"with zero or one parameter (X)."
)


def check_alpha(
Expand Down

0 comments on commit f46f117

Please sign in to comment.