diff --git a/python/cuml/internals/input_utils.py b/python/cuml/internals/input_utils.py index 76c1d8a5b7..e8567e9c42 100644 --- a/python/cuml/internals/input_utils.py +++ b/python/cuml/internals/input_utils.py @@ -15,6 +15,7 @@ # from collections import namedtuple +from typing import Literal from cuml.internals.array import CumlArray from cuml.internals.array_sparse import SparseCumlArray @@ -46,6 +47,7 @@ cp_ndarray = gpu_only_import_from("cupy", "ndarray") CudfSeries = gpu_only_import_from("cudf", "Series") CudfDataFrame = gpu_only_import_from("cudf", "DataFrame") +CudfIndex = gpu_only_import_from("cudf", "Index") DaskCudfSeries = gpu_only_import_from("dask_cudf", "Series") DaskCudfDataFrame = gpu_only_import_from("dask_cudf", "DataFrame") np_ndarray = cpu_only_import_from("numpy", "ndarray") @@ -64,6 +66,7 @@ nvtx_annotate = gpu_only_import_from("nvtx", "annotate", alt=null_decorator) PandasSeries = cpu_only_import_from("pandas", "Series") PandasDataFrame = cpu_only_import_from("pandas", "DataFrame") +PandasIndex = cpu_only_import_from("pandas", "Index") cuml_array = namedtuple("cuml_array", "array n_rows n_cols dtype") @@ -73,6 +76,7 @@ np_ndarray: "numpy", PandasSeries: "pandas", PandasDataFrame: "pandas", + PandasIndex: "pandas", } @@ -80,6 +84,7 @@ _input_type_to_str[cp_ndarray] = "cupy" _input_type_to_str[CudfSeries] = "cudf" _input_type_to_str[CudfDataFrame] = "cudf" + _input_type_to_str[CudfIndex] = "cudf" _input_type_to_str[NumbaDeviceNDArrayBase] = "numba" except UnavailableError: pass @@ -160,9 +165,15 @@ def get_supported_input_type(X): if isinstance(X, PandasSeries): return PandasSeries + if isinstance(X, PandasIndex): + return PandasIndex + if isinstance(X, CudfDataFrame): return CudfDataFrame + if isinstance(X, CudfIndex): + return CudfIndex + try: if numba_cuda.devicearray.is_cuda_ndarray(X): return numba_cuda.devicearray.DeviceNDArrayBase @@ -205,6 +216,21 @@ def determine_array_type(X): return _input_type_to_str.get(gen_type, None) +def determine_df_obj_type(X): + if X is None: + return None + + # Get the generic type + gen_type = get_supported_input_type(X) + + if gen_type in (CudfDataFrame, PandasDataFrame): + return "dataframe" + elif gen_type in (CudfSeries, PandasSeries): + return "series" + + return None + + def determine_array_dtype(X): if X is None: @@ -575,3 +601,27 @@ def sparse_scipy_to_cp(sp, dtype): v = cp.asarray(values, dtype=dtype) return cupyx.scipy.sparse.coo_matrix((v, (r, c)), sp.shape) + + +def output_to_df_obj_like( + X_out: CumlArray, X_in, output_type: Literal["series", "dataframe"] +): + """Cast CumlArray `X_out` to the dataframe / series type as `X_in` + `CumlArray` abstracts away the dataframe / series metadata, when API + methods needs to return a dataframe / series matching original input + metadata, this function can copy input metadata to output. + """ + + if output_type not in ["series", "dataframe"]: + raise ValueError( + f'output_type must be either "series" or "dataframe" : {output_type}' + ) + + out = None + if output_type == "series": + out = X_out.to_output("series") + out.name = X_in.name + elif output_type == "dataframe": + out = X_out.to_output("dataframe") + out.columns = X_in.columns + return out diff --git a/python/cuml/model_selection/_split.py b/python/cuml/model_selection/_split.py index cb58db4f5f..0727f82c82 100644 --- a/python/cuml/model_selection/_split.py +++ b/python/cuml/model_selection/_split.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,10 +13,16 @@ # limitations under the License. # -from typing import Optional, Union +from typing import Optional, Union, List, Tuple from cuml.common import input_to_cuml_array -from cuml.internals.array import array_to_memory_order +from cuml.internals.input_utils import ( + determine_array_type, + determine_df_obj_type, + output_to_df_obj_like, +) +from cuml.internals.mem_type import MemoryType +from cuml.internals.array import array_to_memory_order, CumlArray from cuml.internals.safe_imports import ( cpu_only_import, gpu_only_import, @@ -31,68 +37,48 @@ cuda = gpu_only_import_from("numba", "cuda") -def _stratify_split( - X, stratify, labels, n_train, n_test, x_numba, y_numba, random_state -): +def _compute_stratify_split_indices( + indices: cp.ndarray, + stratify: CumlArray, + n_train: int, + n_test: int, + random_state: cp.random.RandomState, +) -> Tuple[cp.ndarray]: """ - Function to perform a stratified split based on stratify column. + Compute the indices for stratified split based on stratify keys. Based on scikit-learn stratified split implementation. Parameters ---------- - X, y: Shuffled input data and labels - stratify: column to be stratified on. + indices: cupy array + Indices used to shuffle input data + stratify: CumlArray + Keys used for stratification n_train: Number of samples in train set n_test: number of samples in test set - x_numba: Determines whether the data should be converted to numba - y_numba: Determines whether the labales should be converted to numba + random_state: cupy RandomState + Random state used for shuffling stratify keys Returns ------- - X_train, X_test: Data X divided into train and test sets - y_train, y_test: Labels divided into train and test sets + train_indices, test_indices: + Indices of inputs from which train and test sets are gathered """ - x_cudf = False - labels_cudf = False - - if isinstance(X, cudf.DataFrame): - x_cudf = True - elif hasattr(X, "__cuda_array_interface__"): - X = cp.asarray(X) - x_order = array_to_memory_order(X) - # labels and stratify will be only cp arrays - if isinstance(labels, cudf.Series): - labels_cudf = True - labels = labels.values - elif hasattr(labels, "__cuda_array_interface__"): - labels = cp.asarray(labels) - elif isinstance(stratify, cudf.DataFrame): - # ensuring it has just one column - if labels.shape[1] != 1: - raise ValueError( - "Expected one column for labels, but found df" - "with shape = %d" % (labels.shape) - ) - labels_cudf = True - labels = labels[0].values + if indices.ndim != 1: + raise ValueError( + "Expected one one dimension for indices, but found array" + "with shape = %d" % (indices.shape) + ) - labels_order = array_to_memory_order(labels) + if stratify.ndim != 1: + raise ValueError( + "Expected one one dimension for stratify, but found array" + "with shape = %d" % (stratify.shape) + ) # Converting to cupy array removes the need to add an if-else block # for startify column - if isinstance(stratify, cudf.Series): - stratify = stratify.values - elif hasattr(stratify, "__cuda_array_interface__"): - stratify = cp.asarray(stratify) - elif isinstance(stratify, cudf.DataFrame): - # ensuring it has just one column - if stratify.shape[1] != 1: - raise ValueError( - "Expected one column, but found column" - "with shape = %d" % (stratify.shape) - ) - stratify = stratify[0].values classes, stratify_indices = cp.unique(stratify, return_inverse=True) @@ -112,84 +98,31 @@ def _stratify_split( "equal to the number of classes = %d" % (n_train, n_classes) ) - class_indices = cp.split( + # List of length n_classes. Each element contains indices of that class. + class_indices: List[cp.ndarray] = cp.split( cp.argsort(stratify_indices), cp.cumsum(class_counts)[:-1].tolist() ) - X_train = None - - # random_state won't be None or int, that's handled earlier - if isinstance(random_state, np.random.RandomState): - random_state = cp.random.RandomState(seed=random_state.get_state()[1]) - # Break ties n_i = _approximate_mode(class_counts, n_train, random_state) class_counts_remaining = class_counts - n_i t_i = _approximate_mode(class_counts_remaining, n_test, random_state) + train_indices_partials = [] + test_indices_partials = [] for i in range(n_classes): permutation = random_state.permutation(class_counts[i].item()) perm_indices_class_i = class_indices[i].take(permutation) - y_train_i = cp.array( - labels[perm_indices_class_i[: n_i[i]]], order=labels_order - ) - y_test_i = cp.array( - labels[perm_indices_class_i[n_i[i] : n_i[i] + t_i[i]]], - order=labels_order, + train_indices_partials.append(perm_indices_class_i[: n_i[i]]) + test_indices_partials.append( + perm_indices_class_i[n_i[i] : n_i[i] + t_i[i]] ) - if hasattr(X, "__cuda_array_interface__") or isinstance( - X, cupyx.scipy.sparse.csr_matrix - ): - X_train_i = cp.array( - X[perm_indices_class_i[: n_i[i]]], order=x_order - ) - X_test_i = cp.array( - X[perm_indices_class_i[n_i[i] : n_i[i] + t_i[i]]], - order=x_order, - ) - if X_train is None: - X_train = cp.array(X_train_i, order=x_order) - y_train = cp.array(y_train_i, order=labels_order) - X_test = cp.array(X_test_i, order=x_order) - y_test = cp.array(y_test_i, order=labels_order) - else: - X_train = cp.concatenate([X_train, X_train_i], axis=0) - X_test = cp.concatenate([X_test, X_test_i], axis=0) - y_train = cp.concatenate([y_train, y_train_i], axis=0) - y_test = cp.concatenate([y_test, y_test_i], axis=0) - - elif x_cudf: - X_train_i = X.iloc[perm_indices_class_i[: n_i[i]]] - X_test_i = X.iloc[perm_indices_class_i[n_i[i] : n_i[i] + t_i[i]]] - - if X_train is None: - X_train = X_train_i - y_train = y_train_i - X_test = X_test_i - y_test = y_test_i - else: - X_train = cudf.concat([X_train, X_train_i], ignore_index=False) - X_test = cudf.concat([X_test, X_test_i], ignore_index=False) - y_train = cp.concatenate([y_train, y_train_i], axis=0) - y_test = cp.concatenate([y_test, y_test_i], axis=0) - - if x_numba: - X_train = cuda.as_cuda_array(X_train) - X_test = cuda.as_cuda_array(X_test) - elif x_cudf: - X_train = cudf.DataFrame(X_train) - X_test = cudf.DataFrame(X_test) - - if y_numba: - y_train = cuda.as_cuda_array(y_train) - y_test = cuda.as_cuda_array(y_test) - elif labels_cudf: - y_train = cudf.Series(y_train) - y_test = cudf.Series(y_test) - - return X_train, X_test, y_train, y_test + train_indices = cp.concatenate(train_indices_partials, axis=0) + test_indices = cp.concatenate(test_indices_partials, axis=0) + + return indices[train_indices], indices[test_indices] def _approximate_mode(class_counts, n_draws, rng): @@ -332,103 +265,78 @@ def train_test_split( string" ) - # todo: this check will be replaced with upcoming improvements - # to input_utils - # + x_order = array_to_memory_order(X) + X_arr, X_row, *_ = input_to_cuml_array(X, order=x_order) if y is not None: - if not hasattr(X, "__cuda_array_interface__") and not isinstance( - X, cudf.DataFrame - ): - raise TypeError( - "X needs to be either a cuDF DataFrame, Series or \ - a cuda_array_interface compliant array." - ) - - if not hasattr(y, "__cuda_array_interface__") and not isinstance( - y, cudf.DataFrame - ): - raise TypeError( - "y needs to be either a cuDF DataFrame, Series or \ - a cuda_array_interface compliant array." - ) - - if X.shape[0] != y.shape[0]: + y_order = array_to_memory_order(y) + y_arr, y_row, *_ = input_to_cuml_array(y, order=y_order) + if X_row != y_row: raise ValueError( "X and y must have the same first dimension" - "(found {} and {})".format(X.shape[0], y.shape[0]) - ) - else: - if not hasattr(X, "__cuda_array_interface__") and not isinstance( - X, cudf.DataFrame - ): - raise TypeError( - "X needs to be either a cuDF DataFrame, Series or \ - a cuda_array_interface compliant object." + f"(found {X_row} and {y_row})" ) if isinstance(train_size, float): if not 0 <= train_size <= 1: raise ValueError( "proportion train_size should be between" - "0 and 1 (found {})".format(train_size) + f"0 and 1 (found {train_size})" ) if isinstance(train_size, int): - if not 0 <= train_size <= X.shape[0]: + if not 0 <= train_size <= X_row: raise ValueError( "Number of instances train_size should be between 0 and the" - "first dimension of X (found {})".format(train_size) + f"first dimension of X (found {train_size})" ) if isinstance(test_size, float): if not 0 <= test_size <= 1: raise ValueError( "proportion test_size should be between" - "0 and 1 (found {})".format(train_size) + f"0 and 1 (found {train_size})" ) if isinstance(test_size, int): - if not 0 <= test_size <= X.shape[0]: + if not 0 <= test_size <= X_row: raise ValueError( "Number of instances test_size should be between 0 and the" - "first dimension of X (found {})".format(test_size) + f"first dimension of X (found {test_size})" ) - x_numba = cuda.devicearray.is_cuda_ndarray(X) - y_numba = cuda.devicearray.is_cuda_ndarray(y) - # Determining sizes of splits if isinstance(train_size, float): - train_size = int(X.shape[0] * train_size) + train_size = int(X_row * train_size) if test_size is None: if train_size is None: - train_size = int(X.shape[0] * 0.75) + train_size = int(X_row * 0.75) - test_size = X.shape[0] - train_size + test_size = X_row - train_size if isinstance(test_size, float): - test_size = int(X.shape[0] * test_size) + test_size = int(X_row * test_size) if train_size is None: - train_size = X.shape[0] - test_size + train_size = X_row - test_size elif isinstance(test_size, int): if train_size is None: - train_size = X.shape[0] - test_size + train_size = X_row - test_size + # Compute training set and test set indices if shuffle: - # Shuffle the data + idxs = cp.arange(X_row) + + # Compute shuffle indices if random_state is None or isinstance(random_state, int): - idxs = cp.arange(X.shape[0]) random_state = cp.random.RandomState(seed=random_state) - elif isinstance(random_state, cp.random.RandomState): - idxs = cp.arange(X.shape[0]) - elif isinstance(random_state, np.random.RandomState): - idxs = np.arange(X.shape[0]) + random_state = cp.random.RandomState( + seed=random_state.get_state()[1] + ) - else: + elif not isinstance(random_state, cp.random.RandomState): raise TypeError( "`random_state` must be an int, NumPy RandomState \ or CuPy RandomState." @@ -436,77 +344,74 @@ def train_test_split( random_state.shuffle(idxs) - if isinstance(X, cudf.DataFrame) or isinstance(X, cudf.Series): - X = X.iloc[idxs] - - elif hasattr(X, "__cuda_array_interface__"): - # numba (and therefore rmm device_array) does not support - # fancy indexing - X = cp.asarray(X)[idxs] - - if isinstance(y, cudf.DataFrame) or isinstance(y, cudf.Series): - y = y.iloc[idxs] - - elif hasattr(y, "__cuda_array_interface__"): - y = cp.asarray(y)[idxs] - if stratify is not None: - if isinstance(stratify, cudf.DataFrame) or isinstance( - stratify, cudf.Series - ): - stratify = stratify.iloc[idxs] + stratify, *_ = input_to_cuml_array(stratify) + stratify = stratify[idxs] - elif hasattr(stratify, "__cuda_array_interface__"): - stratify = cp.asarray(stratify)[idxs] - - split_return = _stratify_split( - X, + (train_indices, test_indices,) = _compute_stratify_split_indices( + idxs, stratify, - y, train_size, test_size, - x_numba, - y_numba, random_state, ) - return split_return - # If not stratified, perform train_test_split splicing - x_order = array_to_memory_order(X) + else: + train_indices = idxs[:train_size] + test_indices = idxs[-1 * test_size :] + else: + train_indices = range(0, train_size) + test_indices = range(-1 * test_size, 0) + + # Gather from indices + X_train = X_arr[train_indices] + X_test = X_arr[test_indices] + if y is not None: + y_train = y_arr[train_indices] + y_test = y_arr[test_indices] - if y is None: - y_order = None + # Coerce output to original input type + if ty := determine_df_obj_type(X): + x_type = ty else: - y_order = array_to_memory_order(y) + x_type = determine_array_type(X) + + if ty := determine_df_obj_type(y): + y_type = ty + else: + y_type = determine_array_type(y) - if hasattr(X, "__cuda_array_interface__") or isinstance( - X, cupyx.scipy.sparse.csr_matrix - ): - X_train = cp.array(X[0:train_size], order=x_order) - X_test = cp.array(X[-1 * test_size :], order=x_order) - if y is not None: - y_train = cp.array(y[0:train_size], order=y_order) - y_test = cp.array(y[-1 * test_size :], order=y_order) - elif isinstance(X, cudf.DataFrame): - X_train = X.iloc[0:train_size] - X_test = X.iloc[-1 * test_size :] - if y is not None: - if isinstance(y, cudf.Series): - y_train = y.iloc[0:train_size] - y_test = y.iloc[-1 * test_size :] - elif hasattr(y, "__cuda_array_interface__") or isinstance( - y, cupyx.scipy.sparse.csr_matrix - ): - y_train = cp.array(y[0:train_size], order=y_order) - y_test = cp.array(y[-1 * test_size :], order=y_order) - - if x_numba: - X_train = cuda.as_cuda_array(X_train) - X_test = cuda.as_cuda_array(X_test) - - if y_numba: - y_train = cuda.as_cuda_array(y_train) - y_test = cuda.as_cuda_array(y_test) + if x_type in ("series", "dataframe"): + X_train = output_to_df_obj_like(X_train, X, x_type) + X_test = output_to_df_obj_like(X_test, X, x_type) + + if determine_array_type(X.index) == "pandas": + if isinstance(train_indices, cp.ndarray): + train_indices = train_indices.get() + if isinstance(test_indices, cp.ndarray): + test_indices = test_indices.get() + + X_train.index = X.index[train_indices] + X_test.index = X.index[test_indices] + else: + X_train = X_train.to_output(x_type) + X_test = X_test.to_output(x_type) + + if y_type in ("series", "dataframe"): + y_train = output_to_df_obj_like(y_train, y, y_type) + y_test = output_to_df_obj_like(y_test, y, y_type) + + if determine_array_type(y.index) == "pandas": + if isinstance(train_indices, cp.ndarray): + train_indices = train_indices.get() + if isinstance(test_indices, cp.ndarray): + test_indices = test_indices.get() + + y_train.index = y.index[train_indices] + y_test.index = y.index[test_indices] + elif y_type is not None: + y_train = y_train.to_output(y_type) + y_test = y_test.to_output(y_type) if y is not None: return X_train, X_test, y_train, y_test diff --git a/python/cuml/tests/test_train_test_split.py b/python/cuml/tests/test_train_test_split.py index e0f450176b..c6b1ec0a87 100644 --- a/python/cuml/tests/test_train_test_split.py +++ b/python/cuml/tests/test_train_test_split.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,14 +23,34 @@ cudf = gpu_only_import("cudf") cp = gpu_only_import("cupy") np = cpu_only_import("numpy") +pd = cpu_only_import("pandas") cuda = gpu_only_import_from("numba", "cuda") -test_array_input_types = ["numba", "cupy"] - test_seeds = ["int", "cupy", "numpy"] +@pytest.fixture( + params=[cuda.to_device, cp.asarray, cudf, pd], + ids=["to_numba", "to_cupy", "to_cudf", "to_pandas"], +) +def convert_to_type(request): + if request.param in (cudf, pd): + + def ctor(X): + if isinstance(X, cp.ndarray) and request.param == pd: + X = X.get() + + if X.ndim > 1: + return request.param.DataFrame(X) + else: + return request.param.Series(X) + + return ctor + + return request.param + + @pytest.mark.parametrize("train_size", [0.2, 0.6, 0.8]) @pytest.mark.parametrize("shuffle", [True, False]) def test_split_dataframe(train_size, shuffle): @@ -153,21 +173,23 @@ def test_random_state(seed_type): assert y_test.equals(y_test2) -@pytest.mark.parametrize("type", test_array_input_types) +@pytest.mark.parametrize( + "X, y", + [ + (np.arange(-100, 0), np.arange(100)), + ( + np.zeros((100, 10)) + np.arange(100).reshape(100, 1), + np.arange(100).reshape(100, 1), + ), + ], +) @pytest.mark.parametrize("test_size", [0.2, 0.4, None]) @pytest.mark.parametrize("train_size", [0.6, 0.8, None]) @pytest.mark.parametrize("shuffle", [True, False]) -def test_array_split(type, test_size, train_size, shuffle): - X = np.zeros((100, 10)) + np.arange(100).reshape(100, 1) - y = np.arange(100).reshape(100, 1) - - if type == "cupy": - X = cp.asarray(X) - y = cp.asarray(y) +def test_array_split(X, y, convert_to_type, test_size, train_size, shuffle): - if type == "numba": - X = cuda.to_device(X) - y = cuda.to_device(y) + X = convert_to_type(X) + y = convert_to_type(y) X_train, X_test, y_train, y_test = train_test_split( X, @@ -251,17 +273,19 @@ def test_split_df_single_argument(test_size, train_size, shuffle): assert X_test.shape[0] == (int)(X.shape[0] * test_size) -@pytest.mark.parametrize("type", test_array_input_types) +@pytest.mark.parametrize( + "X", + [np.arange(-100, 0), np.zeros((100, 10)) + np.arange(100).reshape(100, 1)], +) @pytest.mark.parametrize("test_size", [0.2, 0.4, None]) @pytest.mark.parametrize("train_size", [0.6, 0.8, None]) @pytest.mark.parametrize("shuffle", [True, False]) -def test_split_array_single_argument(type, test_size, train_size, shuffle): - X = np.zeros((100, 10)) + np.arange(100).reshape(100, 1) - if type == "cupy": - X = cp.asarray(X) +def test_split_array_single_argument( + X, convert_to_type, test_size, train_size, shuffle +): + + X = convert_to_type(X) - if type == "numba": - X = cuda.to_device(X) X_train, X_test = train_test_split( X, train_size=train_size, @@ -293,20 +317,14 @@ def test_split_array_single_argument(type, test_size, train_size, shuffle): assert X_rec == X -@pytest.mark.parametrize("type", test_array_input_types) @pytest.mark.parametrize("test_size", [0.2, 0.4, None]) @pytest.mark.parametrize("train_size", [0.6, 0.8, None]) -def test_stratified_split(type, test_size, train_size): +def test_stratified_split(convert_to_type, test_size, train_size): # For more tolerance and reliable estimates X, y = make_classification(n_samples=10000) - if type == "cupy": - X = cp.asarray(X) - y = cp.asarray(y) - - if type == "numba": - X = cuda.to_device(X) - y = cuda.to_device(y) + X = convert_to_type(X) + y = convert_to_type(y) def counts(y): _, y_indices = cp.unique(y, return_inverse=True)