Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rename data splitting functions #8

Merged
merged 1 commit into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aimet_ml/model_selection/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .splits import get_splitter, join_cols, split_dataset, split_dataset_v2, stratified_group_split
from .splits import get_splitter, join_cols, split_dataset, split_dataset_single_test, stratified_group_split
114 changes: 57 additions & 57 deletions aimet_ml/model_selection/splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,125 +126,125 @@ def stratified_group_split(

def split_dataset(
dataset_df: pd.DataFrame,
test_fraction: Union[float, int] = 0.2,
val_n_splits: int = 5,
val_fraction: Union[float, int] = 0.1,
test_n_splits: int = 5,
stratify_cols: Optional[Collection[str]] = None,
group_cols: Optional[Collection[str]] = None,
test_split_name: str = "test",
dev_split_name: str = "dev",
train_split_name_format: str = "train_fold_{}",
val_split_name_format: str = "val_fold_{}",
test_split_name_format: str = "test_fold_{}",
random_seed: int = 1414,
) -> Dict[str, pd.DataFrame]:
"""
Split a dataset into development, test, and cross-validation sets with stratification and grouping.
Split a dataset into k-fold cross-validation sets with stratification and grouping.

The dataset will be split into a development set and a test set. The development set will then be further
split into k-fold cross-validation sets, each containing its own training and validation sets.
The final data splits include a test set, k training sets, and k validation sets.
The dataset will be split into k-fold cross-validation sets, each containing development and test sets.
For each fold, the development set will be further split into training and validation sets.
The final data splits include k test sets, k training sets, and k validation sets.

Args:
dataset_df (pd.DataFrame): The input DataFrame to be split.
test_fraction (Union[float, int], optional): The fraction of data to be used for testing.
val_fraction (Union[float, int], optional): The fraction of data to be used for validation.
If a float is given, it's rounded to the nearest fraction.
If an integer (n) is given, the fraction is calculated as 1/n.
Defaults to 0.2.
val_n_splits (int, optional): Number of cross-validation splits. Defaults to 5.
Defaults to 0.1.
test_n_splits (int, optional): Number of cross-validation splits. Defaults to 5.
stratify_cols (Collection[str], optional): Column names for stratification. Defaults to None.
group_cols (Collection[str], optional): Column names for grouping. Defaults to None.
test_split_name (str, optional): Name for the test split. Defaults to "test".
dev_split_name (str, optional): Name for the development split. Defaults to "dev".
train_split_name_format (str, optional): Format for naming training splits. Defaults to "train_fold_{}".
val_split_name_format (str, optional): Format for naming validation splits. Defaults to "val_fold_{}".
test_split_name_format (str, optional): Format for naming validation splits. Defaults to "test_fold_{}".
random_seed (int, optional): Random seed for reproducibility. Defaults to 1414.

Returns:
Dict[str, pd.DataFrame]: A dictionary containing the split DataFrames.
"""
if val_n_splits <= 1:
raise ValueError("val_n_splits must be greater than 1")
if test_n_splits <= 1:
raise ValueError("test_n_splits must be greater than 1")

data_splits = dict()

# split into dev and test datasets
dev_dataset_df, test_dataset_df = stratified_group_split(
dataset_df, test_fraction, stratify_cols, group_cols, random_seed
)
data_splits[dev_split_name] = dev_dataset_df
data_splits[test_split_name] = test_dataset_df

# cross-validation split
k_fold_splitter = get_splitter(stratify_cols, group_cols, val_n_splits, random_seed)
k_fold_splitter = get_splitter(stratify_cols, group_cols, test_n_splits, random_seed)

dev_stratify = join_cols(dev_dataset_df, stratify_cols) if stratify_cols else None
dev_groups = join_cols(dev_dataset_df, group_cols) if group_cols else None
stratify = join_cols(dataset_df, stratify_cols) if stratify_cols else None
groups = join_cols(dataset_df, group_cols) if group_cols else None

for n, (train_rows, val_rows) in enumerate(
k_fold_splitter.split(X=dev_dataset_df, y=dev_stratify, groups=dev_groups)
):
for n, (dev_rows, test_rows) in enumerate(k_fold_splitter.split(X=dataset_df, y=stratify, groups=groups)):
k = n + 1
data_splits[train_split_name_format.format(k)] = dev_dataset_df.iloc[train_rows].reset_index(drop=True)
data_splits[val_split_name_format.format(k)] = dev_dataset_df.iloc[val_rows].reset_index(drop=True)
data_splits[test_split_name_format.format(k)] = dataset_df.iloc[test_rows].reset_index(drop=True)

# split into training and validation sets
dev_dataset_df = dataset_df.iloc[dev_rows].reset_index(drop=True)
train_dataset_df, val_dataset_df = stratified_group_split(
dev_dataset_df, val_fraction, stratify_cols, group_cols, random_seed
)
data_splits[train_split_name_format.format(k)] = train_dataset_df
data_splits[val_split_name_format.format(k)] = val_dataset_df

return data_splits


def split_dataset_v2(
def split_dataset_single_test(
dataset_df: pd.DataFrame,
val_fraction: Union[float, int] = 0.1,
test_n_splits: int = 5,
test_fraction: Union[float, int] = 0.2,
val_n_splits: int = 5,
stratify_cols: Optional[Collection[str]] = None,
group_cols: Optional[Collection[str]] = None,
test_split_name: str = "test",
dev_split_name: str = "dev",
train_split_name_format: str = "train_fold_{}",
val_split_name_format: str = "val_fold_{}",
test_split_name_format: str = "test_fold_{}",
random_seed: int = 1414,
) -> Dict[str, pd.DataFrame]:
"""
Split a dataset into k-fold cross-validation sets with stratification and grouping.
Split a dataset into development, test, and cross-validation sets with stratification and grouping.

The dataset will be split into k-fold cross-validation sets, each containing development and test sets.
For each fold, the development set will be further split into training and validation sets.
The final data splits include k test sets, k training sets, and k validation sets.
The dataset will be split into a development set and a test set. The development set will then be further
split into k-fold cross-validation sets, each containing its own training and validation sets.
The final data splits include a test set, k training sets, and k validation sets.

Args:
dataset_df (pd.DataFrame): The input DataFrame to be split.
val_fraction (Union[float, int], optional): The fraction of data to be used for validation.
test_fraction (Union[float, int], optional): The fraction of data to be used for testing.
If a float is given, it's rounded to the nearest fraction.
If an integer (n) is given, the fraction is calculated as 1/n.
Defaults to 0.1.
test_n_splits (int, optional): Number of cross-validation splits. Defaults to 5.
Defaults to 0.2.
val_n_splits (int, optional): Number of cross-validation splits. Defaults to 5.
stratify_cols (Collection[str], optional): Column names for stratification. Defaults to None.
group_cols (Collection[str], optional): Column names for grouping. Defaults to None.
test_split_name (str, optional): Name for the test split. Defaults to "test".
dev_split_name (str, optional): Name for the development split. Defaults to "dev".
train_split_name_format (str, optional): Format for naming training splits. Defaults to "train_fold_{}".
val_split_name_format (str, optional): Format for naming validation splits. Defaults to "val_fold_{}".
test_split_name_format (str, optional): Format for naming validation splits. Defaults to "test_fold_{}".
random_seed (int, optional): Random seed for reproducibility. Defaults to 1414.

Returns:
Dict[str, pd.DataFrame]: A dictionary containing the split DataFrames.
"""
if test_n_splits <= 1:
raise ValueError("test_n_splits must be greater than 1")
if val_n_splits <= 1:
raise ValueError("val_n_splits must be greater than 1")

data_splits = dict()

# split into dev and test datasets
dev_dataset_df, test_dataset_df = stratified_group_split(
dataset_df, test_fraction, stratify_cols, group_cols, random_seed
)
data_splits[dev_split_name] = dev_dataset_df
data_splits[test_split_name] = test_dataset_df

# cross-validation split
k_fold_splitter = get_splitter(stratify_cols, group_cols, test_n_splits, random_seed)
k_fold_splitter = get_splitter(stratify_cols, group_cols, val_n_splits, random_seed)

stratify = join_cols(dataset_df, stratify_cols) if stratify_cols else None
groups = join_cols(dataset_df, group_cols) if group_cols else None
dev_stratify = join_cols(dev_dataset_df, stratify_cols) if stratify_cols else None
dev_groups = join_cols(dev_dataset_df, group_cols) if group_cols else None

for n, (dev_rows, test_rows) in enumerate(k_fold_splitter.split(X=dataset_df, y=stratify, groups=groups)):
for n, (train_rows, val_rows) in enumerate(
k_fold_splitter.split(X=dev_dataset_df, y=dev_stratify, groups=dev_groups)
):
k = n + 1
data_splits[test_split_name_format.format(k)] = dataset_df.iloc[test_rows].reset_index(drop=True)

# split into training and validation sets
dev_dataset_df = dataset_df.iloc[dev_rows].reset_index(drop=True)
train_dataset_df, val_dataset_df = stratified_group_split(
dev_dataset_df, val_fraction, stratify_cols, group_cols, random_seed
)
data_splits[train_split_name_format.format(k)] = train_dataset_df
data_splits[val_split_name_format.format(k)] = val_dataset_df
data_splits[train_split_name_format.format(k)] = dev_dataset_df.iloc[train_rows].reset_index(drop=True)
data_splits[val_split_name_format.format(k)] = dev_dataset_df.iloc[val_rows].reset_index(drop=True)

return data_splits
20 changes: 13 additions & 7 deletions tests/model_selection/test_splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
import pytest
from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold

from aimet_ml.model_selection import get_splitter, join_cols, split_dataset, split_dataset_v2, stratified_group_split
from aimet_ml.model_selection import (
get_splitter,
join_cols,
split_dataset,
split_dataset_single_test,
stratified_group_split,
)


def validate_splits(
Expand Down Expand Up @@ -201,7 +207,7 @@ def test_stratified_group_split(
),
],
)
def test_split_dataset(
def test_split_dataset_single_test(
test_fraction: Union[float, int],
val_n_splits: int,
stratify_cols: Optional[Collection[str]],
Expand All @@ -214,7 +220,7 @@ def test_split_dataset(
sample_df: pd.DataFrame,
):
"""
Test the split_dataset function.
Test the split_dataset_single_test function.

Args:
test_fraction (Union[float, int], optional): The fraction of data to be used for testing.
Expand All @@ -231,7 +237,7 @@ def test_split_dataset(
sample_df (pd.DataFrame): A sample DataFrame.
"""
with expectation:
data_splits = split_dataset(
data_splits = split_dataset_single_test(
sample_df,
test_fraction,
val_n_splits,
Expand Down Expand Up @@ -284,7 +290,7 @@ def test_split_dataset(
),
],
)
def test_split_dataset_v2(
def test_split_dataset(
val_fraction: Union[float, int],
test_n_splits: int,
stratify_cols: Optional[Collection[str]],
Expand All @@ -296,7 +302,7 @@ def test_split_dataset_v2(
sample_df: pd.DataFrame,
):
"""
Test the split_dataset_v2 function.
Test the split_dataset function.

Args:
val_fraction (Union[float, int], optional): The fraction of data to be used for validation.
Expand All @@ -312,7 +318,7 @@ def test_split_dataset_v2(
sample_df (pd.DataFrame): A sample DataFrame.
"""
with expectation:
data_splits = split_dataset_v2(
data_splits = split_dataset(
sample_df,
val_fraction,
test_n_splits,
Expand Down
Loading