Skip to content

Commit

Permalink
API: improve parameter checking for StratifiedBootstrapCV
Browse files Browse the repository at this point in the history
  • Loading branch information
j-ittner committed Jul 27, 2021
1 parent ba5dede commit eb1ab17
Showing 1 changed file with 33 additions and 15 deletions.
48 changes: 33 additions & 15 deletions src/facet/validation/_validation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Core implementation of :mod:`facet.validation`.
"""

import warnings
from abc import ABCMeta, abstractmethod
from typing import Generator, Iterator, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -67,31 +67,49 @@ def get_n_splits(
:param groups: for compatibility only, not used
:return: the number of splits
"""

for arg_name, arg in ("X", X), ("y", y), ("groups", groups):
if arg is not None:
warnings.warn(
f"arg {arg_name} is not used but got {arg_name}={arg!r}",
stacklevel=2,
)

return self.n_splits

# noinspection PyPep8Naming
def split(
self,
X: Union[np.ndarray, pd.DataFrame],
y: Union[np.ndarray, pd.Series, pd.DataFrame, None] = None,
groups: Sequence = None,
groups: Union[np.ndarray, pd.Series, pd.DataFrame, None] = None,
) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]:
"""
Generate indices to split data into training and test set.
:param X: features
:param y: target
:param groups: not used
:param y: target: target variable for supervised learning problems,
used as labels for stratification
:param groups: ignored; exists for compatibility
:return: a generator yielding `(train, test)` tuples where
train and test are numpy arrays with train and test indices, respectively
"""

n = len(X)

if y is not None and n != len(y):
raise ValueError("args X and y must have the same length")
if n < 2:
raise ValueError("args X and y must have a length of at least 2")
raise ValueError("arg X must have at least 2 rows")

if y is None:
raise ValueError(
"no target variable specified in arg y as labels for stratification"
)

if n != len(y):
raise ValueError("args X and y must have the same length")

if groups is not None:
warnings.warn(f"ignoring arg groups={groups!r}", stacklevel=2)

rs = check_random_state(self.random_state)
indices = np.arange(n)
Expand All @@ -114,9 +132,9 @@ def _select_train_indices(
y: Union[np.ndarray, pd.Series, pd.DataFrame, None],
) -> np.ndarray:
"""
:param y: target
:param n_samples: number of indices to sample
:param random_state: random state object to be used for random sampling
:param y: labels for stratification
:return: an array of integer indices with shape ``[n_samples]``
"""
pass
Expand Down Expand Up @@ -167,16 +185,16 @@ def _select_train_indices(
random_state: np.random.RandomState,
y: Union[np.ndarray, pd.Series, pd.DataFrame, None],
) -> np.ndarray:
if y is None:
raise ValueError("arg y must be specified")
if not (
isinstance(y, pd.Series) or (isinstance(y, np.ndarray) and y.ndim == 1)
):
raise ValueError("arg y must be a Series or a 1d numpy array")
if isinstance(y, pd.Series):
y = y.values
elif not (isinstance(y, np.ndarray) and y.ndim == 1):
raise ValueError(
"target labels must be provided as a Series or a 1d numpy array"
)

return (
pd.Series(np.arange(len(y)))
.groupby(by=y.values if isinstance(y, pd.Series) else y)
.groupby(by=y)
.apply(
lambda group: group.sample(
n=len(group), replace=True, random_state=random_state
Expand Down

0 comments on commit eb1ab17

Please sign in to comment.