From 4769ab4396ff892437ace106cc48da098fd80035 Mon Sep 17 00:00:00 2001 From: aimet-pasitpk Date: Mon, 11 Mar 2024 15:27:51 +0700 Subject: [PATCH] rename data splitting functions --- aimet_ml/model_selection/__init__.py | 2 +- aimet_ml/model_selection/splits.py | 114 +++++++++++++-------------- tests/model_selection/test_splits.py | 20 +++-- 3 files changed, 71 insertions(+), 65 deletions(-) diff --git a/aimet_ml/model_selection/__init__.py b/aimet_ml/model_selection/__init__.py index 8b61eff..b2bb26f 100644 --- a/aimet_ml/model_selection/__init__.py +++ b/aimet_ml/model_selection/__init__.py @@ -1 +1 @@ -from .splits import get_splitter, join_cols, split_dataset, split_dataset_v2, stratified_group_split +from .splits import get_splitter, join_cols, split_dataset, split_dataset_single_test, stratified_group_split diff --git a/aimet_ml/model_selection/splits.py b/aimet_ml/model_selection/splits.py index 2239d2f..193c397 100644 --- a/aimet_ml/model_selection/splits.py +++ b/aimet_ml/model_selection/splits.py @@ -126,125 +126,125 @@ def stratified_group_split( def split_dataset( dataset_df: pd.DataFrame, - test_fraction: Union[float, int] = 0.2, - val_n_splits: int = 5, + val_fraction: Union[float, int] = 0.1, + test_n_splits: int = 5, stratify_cols: Optional[Collection[str]] = None, group_cols: Optional[Collection[str]] = None, - test_split_name: str = "test", - dev_split_name: str = "dev", train_split_name_format: str = "train_fold_{}", val_split_name_format: str = "val_fold_{}", + test_split_name_format: str = "test_fold_{}", random_seed: int = 1414, ) -> Dict[str, pd.DataFrame]: """ - Split a dataset into development, test, and cross-validation sets with stratification and grouping. + Split a dataset into k-fold cross-validation sets with stratification and grouping. - The dataset will be split into a development set and a test set. The development set will then be further - split into k-fold cross-validation sets, each containing its own training and validation sets. - The final data splits include a test set, k training sets, and k validation sets. + The dataset will be split into k-fold cross-validation sets, each containing development and test sets. + For each fold, the development set will be further split into training and validation sets. + The final data splits include k test sets, k training sets, and k validation sets. Args: dataset_df (pd.DataFrame): The input DataFrame to be split. - test_fraction (Union[float, int], optional): The fraction of data to be used for testing. + val_fraction (Union[float, int], optional): The fraction of data to be used for validation. If a float is given, it's rounded to the nearest fraction. If an integer (n) is given, the fraction is calculated as 1/n. - Defaults to 0.2. - val_n_splits (int, optional): Number of cross-validation splits. Defaults to 5. + Defaults to 0.1. + test_n_splits (int, optional): Number of cross-validation splits. Defaults to 5. stratify_cols (Collection[str], optional): Column names for stratification. Defaults to None. group_cols (Collection[str], optional): Column names for grouping. Defaults to None. - test_split_name (str, optional): Name for the test split. Defaults to "test". - dev_split_name (str, optional): Name for the development split. Defaults to "dev". train_split_name_format (str, optional): Format for naming training splits. Defaults to "train_fold_{}". val_split_name_format (str, optional): Format for naming validation splits. Defaults to "val_fold_{}". + test_split_name_format (str, optional): Format for naming validation splits. Defaults to "test_fold_{}". random_seed (int, optional): Random seed for reproducibility. Defaults to 1414. Returns: Dict[str, pd.DataFrame]: A dictionary containing the split DataFrames. """ - if val_n_splits <= 1: - raise ValueError("val_n_splits must be greater than 1") + if test_n_splits <= 1: + raise ValueError("test_n_splits must be greater than 1") data_splits = dict() - # split into dev and test datasets - dev_dataset_df, test_dataset_df = stratified_group_split( - dataset_df, test_fraction, stratify_cols, group_cols, random_seed - ) - data_splits[dev_split_name] = dev_dataset_df - data_splits[test_split_name] = test_dataset_df - # cross-validation split - k_fold_splitter = get_splitter(stratify_cols, group_cols, val_n_splits, random_seed) + k_fold_splitter = get_splitter(stratify_cols, group_cols, test_n_splits, random_seed) - dev_stratify = join_cols(dev_dataset_df, stratify_cols) if stratify_cols else None - dev_groups = join_cols(dev_dataset_df, group_cols) if group_cols else None + stratify = join_cols(dataset_df, stratify_cols) if stratify_cols else None + groups = join_cols(dataset_df, group_cols) if group_cols else None - for n, (train_rows, val_rows) in enumerate( - k_fold_splitter.split(X=dev_dataset_df, y=dev_stratify, groups=dev_groups) - ): + for n, (dev_rows, test_rows) in enumerate(k_fold_splitter.split(X=dataset_df, y=stratify, groups=groups)): k = n + 1 - data_splits[train_split_name_format.format(k)] = dev_dataset_df.iloc[train_rows].reset_index(drop=True) - data_splits[val_split_name_format.format(k)] = dev_dataset_df.iloc[val_rows].reset_index(drop=True) + data_splits[test_split_name_format.format(k)] = dataset_df.iloc[test_rows].reset_index(drop=True) + + # split into training and validation sets + dev_dataset_df = dataset_df.iloc[dev_rows].reset_index(drop=True) + train_dataset_df, val_dataset_df = stratified_group_split( + dev_dataset_df, val_fraction, stratify_cols, group_cols, random_seed + ) + data_splits[train_split_name_format.format(k)] = train_dataset_df + data_splits[val_split_name_format.format(k)] = val_dataset_df return data_splits -def split_dataset_v2( +def split_dataset_single_test( dataset_df: pd.DataFrame, - val_fraction: Union[float, int] = 0.1, - test_n_splits: int = 5, + test_fraction: Union[float, int] = 0.2, + val_n_splits: int = 5, stratify_cols: Optional[Collection[str]] = None, group_cols: Optional[Collection[str]] = None, + test_split_name: str = "test", + dev_split_name: str = "dev", train_split_name_format: str = "train_fold_{}", val_split_name_format: str = "val_fold_{}", - test_split_name_format: str = "test_fold_{}", random_seed: int = 1414, ) -> Dict[str, pd.DataFrame]: """ - Split a dataset into k-fold cross-validation sets with stratification and grouping. + Split a dataset into development, test, and cross-validation sets with stratification and grouping. - The dataset will be split into k-fold cross-validation sets, each containing development and test sets. - For each fold, the development set will be further split into training and validation sets. - The final data splits include k test sets, k training sets, and k validation sets. + The dataset will be split into a development set and a test set. The development set will then be further + split into k-fold cross-validation sets, each containing its own training and validation sets. + The final data splits include a test set, k training sets, and k validation sets. Args: dataset_df (pd.DataFrame): The input DataFrame to be split. - val_fraction (Union[float, int], optional): The fraction of data to be used for validation. + test_fraction (Union[float, int], optional): The fraction of data to be used for testing. If a float is given, it's rounded to the nearest fraction. If an integer (n) is given, the fraction is calculated as 1/n. - Defaults to 0.1. - test_n_splits (int, optional): Number of cross-validation splits. Defaults to 5. + Defaults to 0.2. + val_n_splits (int, optional): Number of cross-validation splits. Defaults to 5. stratify_cols (Collection[str], optional): Column names for stratification. Defaults to None. group_cols (Collection[str], optional): Column names for grouping. Defaults to None. + test_split_name (str, optional): Name for the test split. Defaults to "test". + dev_split_name (str, optional): Name for the development split. Defaults to "dev". train_split_name_format (str, optional): Format for naming training splits. Defaults to "train_fold_{}". val_split_name_format (str, optional): Format for naming validation splits. Defaults to "val_fold_{}". - test_split_name_format (str, optional): Format for naming validation splits. Defaults to "test_fold_{}". random_seed (int, optional): Random seed for reproducibility. Defaults to 1414. Returns: Dict[str, pd.DataFrame]: A dictionary containing the split DataFrames. """ - if test_n_splits <= 1: - raise ValueError("test_n_splits must be greater than 1") + if val_n_splits <= 1: + raise ValueError("val_n_splits must be greater than 1") data_splits = dict() + # split into dev and test datasets + dev_dataset_df, test_dataset_df = stratified_group_split( + dataset_df, test_fraction, stratify_cols, group_cols, random_seed + ) + data_splits[dev_split_name] = dev_dataset_df + data_splits[test_split_name] = test_dataset_df + # cross-validation split - k_fold_splitter = get_splitter(stratify_cols, group_cols, test_n_splits, random_seed) + k_fold_splitter = get_splitter(stratify_cols, group_cols, val_n_splits, random_seed) - stratify = join_cols(dataset_df, stratify_cols) if stratify_cols else None - groups = join_cols(dataset_df, group_cols) if group_cols else None + dev_stratify = join_cols(dev_dataset_df, stratify_cols) if stratify_cols else None + dev_groups = join_cols(dev_dataset_df, group_cols) if group_cols else None - for n, (dev_rows, test_rows) in enumerate(k_fold_splitter.split(X=dataset_df, y=stratify, groups=groups)): + for n, (train_rows, val_rows) in enumerate( + k_fold_splitter.split(X=dev_dataset_df, y=dev_stratify, groups=dev_groups) + ): k = n + 1 - data_splits[test_split_name_format.format(k)] = dataset_df.iloc[test_rows].reset_index(drop=True) - - # split into training and validation sets - dev_dataset_df = dataset_df.iloc[dev_rows].reset_index(drop=True) - train_dataset_df, val_dataset_df = stratified_group_split( - dev_dataset_df, val_fraction, stratify_cols, group_cols, random_seed - ) - data_splits[train_split_name_format.format(k)] = train_dataset_df - data_splits[val_split_name_format.format(k)] = val_dataset_df + data_splits[train_split_name_format.format(k)] = dev_dataset_df.iloc[train_rows].reset_index(drop=True) + data_splits[val_split_name_format.format(k)] = dev_dataset_df.iloc[val_rows].reset_index(drop=True) return data_splits diff --git a/tests/model_selection/test_splits.py b/tests/model_selection/test_splits.py index 41af3be..016a17c 100644 --- a/tests/model_selection/test_splits.py +++ b/tests/model_selection/test_splits.py @@ -5,7 +5,13 @@ import pytest from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold -from aimet_ml.model_selection import get_splitter, join_cols, split_dataset, split_dataset_v2, stratified_group_split +from aimet_ml.model_selection import ( + get_splitter, + join_cols, + split_dataset, + split_dataset_single_test, + stratified_group_split, +) def validate_splits( @@ -201,7 +207,7 @@ def test_stratified_group_split( ), ], ) -def test_split_dataset( +def test_split_dataset_single_test( test_fraction: Union[float, int], val_n_splits: int, stratify_cols: Optional[Collection[str]], @@ -214,7 +220,7 @@ def test_split_dataset( sample_df: pd.DataFrame, ): """ - Test the split_dataset function. + Test the split_dataset_single_test function. Args: test_fraction (Union[float, int], optional): The fraction of data to be used for testing. @@ -231,7 +237,7 @@ def test_split_dataset( sample_df (pd.DataFrame): A sample DataFrame. """ with expectation: - data_splits = split_dataset( + data_splits = split_dataset_single_test( sample_df, test_fraction, val_n_splits, @@ -284,7 +290,7 @@ def test_split_dataset( ), ], ) -def test_split_dataset_v2( +def test_split_dataset( val_fraction: Union[float, int], test_n_splits: int, stratify_cols: Optional[Collection[str]], @@ -296,7 +302,7 @@ def test_split_dataset_v2( sample_df: pd.DataFrame, ): """ - Test the split_dataset_v2 function. + Test the split_dataset function. Args: val_fraction (Union[float, int], optional): The fraction of data to be used for validation. @@ -312,7 +318,7 @@ def test_split_dataset_v2( sample_df (pd.DataFrame): A sample DataFrame. """ with expectation: - data_splits = split_dataset_v2( + data_splits = split_dataset( sample_df, val_fraction, test_n_splits,