diff --git a/python/cuml/_thirdparty/sklearn/preprocessing/__init__.py b/python/cuml/_thirdparty/sklearn/preprocessing/__init__.py
index 3d8db2c1dc..5c034e0aa1 100644
--- a/python/cuml/_thirdparty/sklearn/preprocessing/__init__.py
+++ b/python/cuml/_thirdparty/sklearn/preprocessing/__init__.py
@@ -22,9 +22,14 @@
 from ._imputation import MissingIndicator
 from ._discretization import KBinsDiscretizer
 
+from ._function_transformer import FunctionTransformer
+
+from ._column_transformer import ColumnTransformer, \
+    make_column_transformer, make_column_selector
+
+
 __all__ = [
     'Binarizer',
-    'FunctionTransformer',
     'KBinsDiscretizer',
     'KernelCenterer',
     'LabelBinarizer',
@@ -41,6 +46,8 @@
     'StandardScaler',
     'SimpleImputer',
     'MissingIndicator',
+    'ColumnTransformer',
+    'FunctionTransformer',
     'add_dummy_feature',
     'PolynomialFeatures',
     'binarize',
@@ -52,4 +59,6 @@
     'label_binarize',
     'quantile_transform',
     'power_transform',
+    'make_column_selector',
+    'make_column_transformer'
 ]
diff --git a/python/cuml/_thirdparty/sklearn/preprocessing/_column_transformer.py b/python/cuml/_thirdparty/sklearn/preprocessing/_column_transformer.py
new file mode 100644
index 0000000000..be29025ebd
--- /dev/null
+++ b/python/cuml/_thirdparty/sklearn/preprocessing/_column_transformer.py
@@ -0,0 +1,1169 @@
+# Original authors from Scikit-Learn:
+#          Andreas Mueller
+#          Joris Van den Bossche
+# License: BSD
+
+# This code originates from the Scikit-Learn library,
+# it was since modified to allow GPU acceleration.
+# This code is under BSD 3 clause license.
+# Authors mentioned above do not endorse or promote this production.
+
+
+from itertools import chain
+from itertools import compress
+from joblib import Parallel
+import functools
+import timeit
+import numbers
+from sklearn.base import clone
+from sklearn.utils import Bunch
+from contextlib import contextmanager
+from collections import defaultdict
+import warnings
+
+from scipy import sparse as sp_sparse
+from cupy import sparse as cu_sparse
+import numpy as cpu_np
+import cupy as np
+import numba
+
+import cuml
+from cuml.internals.global_settings import _global_settings_data
+from cuml.common.array_sparse import SparseCumlArray
+from cuml.internals import _deprecate_pos_args
+from ..utils.skl_dependencies import TransformerMixin, BaseComposition, \
+    BaseEstimator
+from ..utils.validation import check_is_fitted
+from ....thirdparty_adapters import check_array
+from ..preprocessing import FunctionTransformer
+
+
+_ERR_MSG_1DCOLUMN = ("1D data passed to a transformer that expects 2D data. "
+                     "Try to specify the column selection as a list of one "
+                     "item instead of a scalar.")
+
+
+def issparse(X):
+    return sp_sparse.issparse(X) or cu_sparse.issparse(X)
+
+
+def _determine_key_type(key, accept_slice=True):
+    """Determine the data type of key.
+
+    Parameters
+    ----------
+    key : scalar, slice or array-like
+        The key from which we want to infer the data type.
+
+    accept_slice : bool, default=True
+        Whether or not to raise an error if the key is a slice.
+
+    Returns
+    -------
+    dtype : {'int', 'str', 'bool', None}
+        Returns the data type of key.
+    """
+    err_msg = ("No valid specification of the columns. 
Only a scalar, list or " + "slice of all integers or all strings, or boolean mask is " + "allowed") + + dtype_to_str = {int: 'int', str: 'str', bool: 'bool', np.bool_: 'bool'} + array_dtype_to_str = {'i': 'int', 'u': 'int', 'b': 'bool', 'O': 'str', + 'U': 'str', 'S': 'str'} + + if key is None: + return None + if isinstance(key, tuple(dtype_to_str.keys())): + try: + return dtype_to_str[type(key)] + except KeyError: + raise ValueError(err_msg) + if isinstance(key, slice): + if not accept_slice: + raise TypeError( + 'Only array-like or scalar are supported. ' + 'A Python slice was given.' + ) + if key.start is None and key.stop is None: + return None + key_start_type = _determine_key_type(key.start) + key_stop_type = _determine_key_type(key.stop) + if key_start_type is not None and key_stop_type is not None: + if key_start_type != key_stop_type: + raise ValueError(err_msg) + if key_start_type is not None: + return key_start_type + return key_stop_type + if isinstance(key, (list, tuple)): + unique_key = set(key) + key_type = {_determine_key_type(elt) for elt in unique_key} + if not key_type: + return None + if len(key_type) != 1: + raise ValueError(err_msg) + return key_type.pop() + if hasattr(key, 'dtype'): + try: + return array_dtype_to_str[key.dtype.kind] + except KeyError: + raise ValueError(err_msg) + raise ValueError(err_msg) + + +def _get_column_indices(X, key): + """Get feature column indices for input data X and key. + """ + n_columns = X.shape[1] + + key_dtype = _determine_key_type(key) + + if isinstance(key, (list, tuple)) and not key: + # we get an empty list + return [] + elif key_dtype in ('bool', 'int'): + # Convert key into positive indexes + try: + idx = _safe_indexing(np.arange(n_columns), key) + except IndexError as e: + raise ValueError( + 'all features must be in [0, {}] or [-{}, 0]' + .format(n_columns - 1, n_columns) + ) from e + return np.atleast_1d(idx).tolist() + elif key_dtype == 'str': + try: + all_columns = X.columns + except AttributeError: + raise ValueError("Specifying the columns using strings is only " + "supported for pandas DataFrames") + if isinstance(key, str): + columns = [key] + elif isinstance(key, slice): + start, stop = key.start, key.stop + if start is not None: + start = all_columns.get_loc(start) + if stop is not None: + # pandas indexing with strings is endpoint included + stop = all_columns.get_loc(stop) + 1 + else: + stop = n_columns + 1 + return list(range(n_columns)[slice(start, stop)]) + else: + columns = list(key) + + try: + column_indices = [] + for col in columns: + col_idx = all_columns.get_loc(col) + if not isinstance(col_idx, numbers.Integral): + raise ValueError(f"Selected columns, {columns}, are not " + "unique in dataframe") + column_indices.append(col_idx) + + except KeyError as e: + raise ValueError( + "A given column is not a column of the dataframe" + ) from e + + return column_indices + else: + raise ValueError("No valid specification of the columns. Only a " + "scalar, list or slice of all integers or all " + "strings, or boolean mask is allowed") + + +def _safe_indexing(X, indices, *, axis=0): + """Return rows, items or columns of X using indices. + + Parameters + ---------- + X : array-like, sparse-matrix, list, dataframes, series data + from which to sample rows, items or columns. `list` are only + supported when `axis=0`. + indices : bool, int, str, slice, array-like + - If `axis=0`, boolean and integer array-like, integer slice, + and scalar integer are supported. 
+          - If `axis=1`:
+            - to select a single column, `indices` can be of `int` type for
+              all `X` types and `str` only for dataframe. The selected subset
+              will be 1D, unless `X` is a sparse matrix in which case it will
+              be 2D.
+            - to select multiple columns, `indices` can be one of the
+              following: `list`, `array`, `slice`. The type used in
+              these containers can be one of the following: `int`, `bool` and
+              `str`. However, `str` is only supported when `X` is a dataframe.
+              The selected subset will be 2D.
+    axis : int, default=0
+        The axis along which `X` will be subsampled. `axis=0` will select
+        rows while `axis=1` will select columns.
+
+    Returns
+    -------
+    subset
+        Subset of X on axis 0 or 1.
+
+    Notes
+    -----
+    CSR, CSC, and LIL sparse matrices are supported. COO sparse matrices are
+    not supported.
+    """
+    if indices is None:
+        return X
+
+    if axis not in (0, 1):
+        raise ValueError(
+            "'axis' should be either 0 (to index rows) or 1 (to index "
+            "columns). Got {} instead.".format(axis)
+        )
+
+    indices_dtype = _determine_key_type(indices)
+
+    if axis == 0 and indices_dtype == 'str':
+        raise ValueError(
+            "String indexing is not supported with 'axis=0'"
+        )
+
+    if axis == 1 and X.ndim != 2:
+        raise ValueError(
+            "'X' should be a 2D NumPy array, 2D sparse matrix or pandas "
+            "dataframe when indexing the columns (i.e. 'axis=1'). "
+            "Got {} instead with {} dimension(s).".format(type(X), X.ndim)
+        )
+
+    if axis == 1 and indices_dtype == 'str' and not hasattr(X, 'loc'):
+        raise ValueError(
+            "Specifying the columns using strings is only supported for "
+            "pandas DataFrames"
+        )
+
+    if hasattr(X, "iloc"):
+        return _pandas_indexing(X, indices, indices_dtype, axis=axis)
+    elif hasattr(X, "shape"):
+        return _array_indexing(X, indices, indices_dtype, axis=axis)
+    else:
+        return _list_indexing(X, indices, indices_dtype)
+
+
+def _array_indexing(array, key, key_dtype, axis):
+    """Index an array or a sparse array"""
+    if issparse(array):
+        # check if we have a boolean array-like to do the proper indexing
+        if key_dtype == 'bool':
+            key = np.asarray(key)
+    if isinstance(key, tuple):
+        key = list(key)
+    if numba.cuda.is_cuda_array(array):
+        array = np.asarray(array)
+    return array[key] if axis == 0 else array[:, key]
+
+
+def _pandas_indexing(X, key, key_dtype, axis):
+    """Index a dataframe or a series"""
+    if hasattr(key, 'shape'):
+        # Work-around for indexing with read-only key in pandas
+        # FIXME: solved in pandas 0.25
+        key = np.asarray(key)
+        key = key if key.flags.writeable else key.copy()
+    elif isinstance(key, tuple):
+        key = list(key)
+    # check whether we should index with loc or iloc
+    indexer = X.iloc if key_dtype == 'int' else X.loc
+    return indexer[:, key] if axis else indexer[key]
+
+
+def _list_indexing(X, key, key_dtype):
+    """Index a Python list."""
+    if np.isscalar(key) or isinstance(key, slice):
+        # key is a slice or a scalar
+        return X[key]
+    if key_dtype == 'bool':
+        # key is a boolean array-like
+        return list(compress(X, key))
+    # key is an integer array-like
+    return [X[idx] for idx in key]
+
+
+def _transform_one(transformer, X, y, weight, **fit_params):
+    res = transformer.transform(X).to_output('cupy')
+    # if we have a weight for this transformer, multiply output
+    if weight is None:
+        return res
+    return res * weight
+
+
+def _fit_transform_one(transformer,
+                       X,
+                       y,
+                       weight,
+                       message_clsname='',
+                       message=None,
+                       **fit_params):
+    """
+    Fits ``transformer`` to ``X`` and ``y``. The transformed result is
+    returned with the fitted transformer. If ``weight`` is not ``None``, the
+    result will be multiplied by ``weight``.
+    """
+    with _print_elapsed_time(message_clsname, message):
+        with cuml.using_output_type("cupy"):
+            transformer.accept_sparse = True
+            if hasattr(transformer, 'fit_transform'):
+                res = transformer.fit_transform(X, y, **fit_params)
+            else:
+                res = transformer.fit(X, y, **fit_params).transform(X)
+
+    if weight is None:
+        return res, transformer
+    return res * weight, transformer
+
+
+def _name_estimators(estimators):
+    """Generate names for estimators."""
+
+    names = [
+        estimator
+        if isinstance(estimator, str) else type(estimator).__name__.lower()
+        for estimator in estimators
+    ]
+    namecount = defaultdict(int)
+    for est, name in zip(estimators, names):
+        namecount[name] += 1
+
+    for k, v in list(namecount.items()):
+        if v == 1:
+            del namecount[k]
+
+    for i in reversed(range(len(estimators))):
+        name = names[i]
+        if name in namecount:
+            names[i] += "-%d" % namecount[name]
+            namecount[name] -= 1
+
+    return list(zip(names, estimators))
+
+
+def delayed(function):
+    """Decorator used to capture the arguments of a function."""
+    @functools.wraps(function)
+    def delayed_function(*args, **kwargs):
+        return _FuncWrapper(function), args, kwargs
+    return delayed_function
+
+
+class _FuncWrapper:
+    """Load the global configuration before calling the function."""
+    def __init__(self, function):
+        self.function = function
+        self.config = _global_settings_data.shared_state
+        functools.update_wrapper(self, self.function)
+
+    def __call__(self, *args, **kwargs):
+        _global_settings_data.shared_state = self.config
+        return self.function(*args, **kwargs)
+
+
+@contextmanager
+def _print_elapsed_time(source, message=None):
+    """Log elapsed time to stdout when the context is exited.
+    Parameters
+    ----------
+    source : str
+        String indicating the source or the reference of the message.
+    message : str, default=None
+        Short message. If None, nothing will be printed.
+    Returns
+    -------
+    context_manager
+        Prints elapsed time upon exit if verbose.
+    """
+    if message is None:
+        yield
+    else:
+        start = timeit.default_timer()
+        yield
+        print(
+            _message_with_time(source, message,
+                               timeit.default_timer() - start))
+
+
+def _message_with_time(source, message, time):
+    """Create one line message for logging purposes.
+    Parameters
+    ----------
+    source : str
+        String indicating the source or the reference of the message.
+    message : str
+        Short message.
+    time : int
+        Time in seconds.
+    """
+    start_message = "[%s] " % source
+
+    # adapted from joblib.logger.short_format_time without the Windows -.1s
+    # adjustment
+    if time > 60:
+        time_str = "%4.1fmin" % (time / 60)
+    else:
+        time_str = " %5.1fs" % time
+    end_message = " %s, total=%s" % (message, time_str)
+    dots_len = (70 - len(start_message) - len(end_message))
+    return "%s%s%s" % (start_message, dots_len * '.', end_message)
+
+
+class ColumnTransformer(TransformerMixin, BaseComposition, BaseEstimator):
+    """Applies transformers to columns of an array or dataframe.
+
+    This estimator allows different columns or column subsets of the input
+    to be transformed separately, and the features generated by each
+    transformer are concatenated to form a single feature space.
+    This is useful for heterogeneous or columnar data, to combine several
+    feature extraction mechanisms or transformations into a single
+    transformer.
+
+    Parameters
+    ----------
+    transformers : list of tuples
+        List of (name, transformer, columns) tuples specifying the
+        transformer objects to be applied to subsets of the data.
+
+        name : str
+            Like in Pipeline and FeatureUnion, this allows the transformer
+            and its parameters to be set using ``set_params`` and searched
+            in grid search.
+        transformer : {'drop', 'passthrough'} or estimator
+            Estimator must support :term:`fit` and :term:`transform`.
+            Special-cased strings 'drop' and 'passthrough' are accepted as
+            well, to indicate to drop the columns or to pass them through
+            untransformed, respectively.
+        columns : str, array-like of str, int, array-like of int, \
+                array-like of bool, slice or callable
+            Indexes the data on its second axis. Integers are interpreted as
+            positional columns, while strings can reference DataFrame columns
+            by name. A scalar string or int should be used where
+            ``transformer`` expects X to be a 1d array-like (vector),
+            otherwise a 2d array will be passed to the transformer.
+            A callable is passed the input data `X` and can return any of the
+            above. To select multiple columns by name or dtype, you can use
+            :obj:`make_column_selector`.
+
+    remainder : {'drop', 'passthrough'} or estimator, default='drop'
+        By default, only the specified columns in `transformers` are
+        transformed and combined in the output, and the non-specified
+        columns are dropped. (default of ``'drop'``).
+        By specifying ``remainder='passthrough'``, all remaining columns that
+        were not specified in `transformers` will be automatically passed
+        through. This subset of columns is concatenated with the output of
+        the transformers.
+        By setting ``remainder`` to be an estimator, the remaining
+        non-specified columns will use the ``remainder`` estimator. The
+        estimator must support :term:`fit` and :term:`transform`.
+        Note that using this feature requires that the DataFrame columns
+        input at :term:`fit` and :term:`transform` have identical order.
+
+    sparse_threshold : float, default=0.3
+        If the output of the different transformers contains sparse matrices,
+        these will be stacked as a sparse matrix if the overall density is
+        lower than this value. Use ``sparse_threshold=0`` to always return
+        dense. When the transformed output consists of all dense data, the
+        stacked result will be dense, and this keyword will be ignored.
+
+    n_jobs : int, default=None
+        Number of jobs to run in parallel.
+        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
+        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
+        for more details.
+
+    transformer_weights : dict, default=None
+        Multiplicative weights for features per transformer. The output of
+        the transformer is multiplied by these weights. Keys are transformer
+        names, values the weights.
+
+    verbose : bool, default=False
+        If True, the time elapsed while fitting each transformer will be
+        printed as it is completed.
+
+    Attributes
+    ----------
+    transformers_ : list
+        The collection of fitted transformers as tuples of
+        (name, fitted_transformer, column). `fitted_transformer` can be an
+        estimator, 'drop', or 'passthrough'. In case there were no columns
+        selected, this will be the unfitted transformer.
+        If there are remaining columns, the final element is a tuple of the
+        form:
+        ('remainder', transformer, remaining_columns) corresponding to the
+        ``remainder`` parameter. If there are remaining columns, then
+        ``len(transformers_)==len(transformers)+1``, otherwise
+        ``len(transformers_)==len(transformers)``. 
+ + named_transformers_ : :class:`~sklearn.utils.Bunch` + Read-only attribute to access any transformer by given name. + Keys are transformer names and values are the fitted transformer + objects. + + sparse_output_ : bool + Boolean flag indicating whether the output of ``transform`` is a + sparse matrix or a dense numpy array, which depends on the output + of the individual transformers and the `sparse_threshold` keyword. + + Notes + ----- + The order of the columns in the transformed feature matrix follows the + order of how the columns are specified in the `transformers` list. + Columns of the original feature matrix that are not specified are + dropped from the resulting transformed feature matrix, unless specified + in the `passthrough` keyword. Those columns specified with `passthrough` + are added at the right to the output of the transformers. + + See Also + -------- + make_column_transformer : Convenience function for + combining the outputs of multiple transformer objects applied to + column subsets of the original feature space. + make_column_selector : Convenience function for selecting + columns based on datatype or the columns name with a regex pattern. + + Examples + -------- + >>> import cupy as cp + >>> from cuml.compose import ColumnTransformer + >>> from cuml.preprocessing import Normalizer + >>> ct = ColumnTransformer( + ... [("norm1", Normalizer(norm='l1'), [0, 1]), + ... ("norm2", Normalizer(norm='l1'), slice(2, 4))]) + >>> X = cp.array([[0., 1., 2., 2.], + ... [1., 1., 0., 1.]]) + >>> # Normalizer scales each row of X to unit norm. A separate scaling + >>> # is applied for the two first and two last elements of each + >>> # row independently. + >>> ct.fit_transform(X) + array([[0. , 1. , 0.5, 0.5], + [0.5, 0.5, 0. , 1. ]]) + + """ + _required_parameters = ['transformers'] + + @_deprecate_pos_args(version="0.20") + def __init__(self, + transformers=None, + remainder='drop', + sparse_threshold=0.3, + n_jobs=None, + transformer_weights=None, + verbose=False): + if not transformers: + warnings.warn('Transformers are required') + self.transformers = transformers + self.remainder = remainder + self.sparse_threshold = sparse_threshold + self.n_jobs = n_jobs + self.transformer_weights = transformer_weights + self.verbose = verbose + + @property + def _transformers(self): + """ + Internal list of transformer only containing the name and + transformers, dropping the columns. This is for the implementation + of get_params via BaseComposition._get_params which expects lists + of tuples of len 2. + """ + return [(name, trans) for name, trans, _ in self.transformers] + + @_transformers.setter + def _transformers(self, value): + self.transformers = [ + (name, trans, col) for ((name, trans), (_, _, col)) + in zip(value, self.transformers)] + + def get_params(self, deep=True): + """Get parameters for this estimator. + + Returns the parameters given in the constructor as well as the + estimators contained within the `transformers` of the + `ColumnTransformer`. + + Parameters + ---------- + deep : bool, default=True + If True, will return the parameters for this estimator and + contained subobjects that are estimators. + + Returns + ------- + params : dict + Parameter names mapped to their values. + """ + return self._get_params('_transformers', deep=deep) + + def set_params(self, **kwargs): + """Set the parameters of this estimator. + + Valid parameter keys can be listed with ``get_params()``. 
Note that you + can directly set the parameters of the estimators contained in + `transformers` of `ColumnTransformer`. + + Returns + ------- + self + """ + self._set_params('_transformers', **kwargs) + return self + + def _iter(self, fitted=False, replace_strings=False): + """ + Generate (name, trans, column, weight) tuples. + + If fitted=True, use the fitted transformers, else use the + user specified transformers updated with converted column names + and potentially appended with transformer for remainder. + + """ + if fitted: + transformers = self.transformers_ + else: + # interleave the validated column specifiers + transformers = [ + (name, trans, column) for (name, trans, _), column + in zip(self.transformers, self._columns) + ] + # add transformer tuple for remainder + if self._remainder[2] is not None: + transformers = chain(transformers, [self._remainder]) + get_weight = (self.transformer_weights or {}).get + + for name, trans, column in transformers: + if replace_strings: + # replace 'passthrough' with identity transformer and + # skip in case of 'drop' + if trans == 'passthrough': + with cuml.using_output_type("cupy"): + trans = FunctionTransformer(accept_sparse=True, + check_inverse=False) + elif trans == 'drop': + continue + elif _is_empty_column_selection(column): + continue + + yield (name, trans, column, get_weight(name)) + + def _validate_transformers(self): + if not self.transformers: + return + + names, transformers, _ = zip(*self.transformers) + + # validate names + self._validate_names(names) + + # validate estimators + for t in transformers: + if t in ('drop', 'passthrough'): + continue + if (not (hasattr(t, "fit") or hasattr(t, "fit_transform")) or not + hasattr(t, "transform")): + raise TypeError("All estimators should implement fit and " + "transform, or can be 'drop' or 'passthrough' " + "specifiers. '%s' (type %s) doesn't." % + (t, type(t))) + + def _validate_column_callables(self, X): + """ + Converts callable column specifications. + """ + columns = [] + for _, _, column in self.transformers: + if callable(column): + column = column(X) + columns.append(column) + self._columns = columns + + def _validate_remainder(self, X): + """ + Validates ``remainder`` and defines ``_remainder`` targeting + the remaining columns. + """ + is_transformer = ((hasattr(self.remainder, "fit") + or hasattr(self.remainder, "fit_transform")) + and hasattr(self.remainder, "transform")) + if (self.remainder not in ('drop', 'passthrough') + and not is_transformer): + raise ValueError( + "The remainder keyword needs to be one of 'drop', " + "'passthrough', or estimator. '%s' was passed instead" % + self.remainder) + + # Make it possible to check for reordered named columns on transform + self._has_str_cols = any(_determine_key_type(cols) == 'str' + for cols in self._columns) + if hasattr(X, 'columns'): + self._df_columns = X.columns + + self._n_features = X.shape[1] + cols = [] + for columns in self._columns: + cols.extend(_get_column_indices(X, columns)) + + remaining_idx = sorted(set(range(self._n_features)) - set(cols)) + self._remainder = ('remainder', self.remainder, remaining_idx or None) + + @property + def named_transformers_(self): + """Access the fitted transformer by name. + + Read-only attribute to access any transformer by given name. + Keys are transformer names and values are the fitted transformer + objects. 
+ + """ + # Use Bunch object to improve autocomplete + return Bunch(**{name: trans for name, trans, _ + in self.transformers_}) + + def get_feature_names(self): + """Get feature names from all transformers. + + Returns + ------- + feature_names : list of strings + Names of the features produced by transform. + """ + check_is_fitted(self) + feature_names = [] + for name, trans, column, _ in self._iter(fitted=True): + if trans == 'drop' or ( + hasattr(column, '__len__') and not len(column)): + continue + if trans == 'passthrough': + if hasattr(self, '_df_columns'): + if ((not isinstance(column, slice)) + and all(isinstance(col, str) for col in column)): + feature_names.extend(column) + else: + feature_names.extend(self._df_columns[column]) + else: + indices = np.arange(self._n_features) + feature_names.extend(['x%d' % i for i in indices[column]]) + continue + if not hasattr(trans, 'get_feature_names'): + raise AttributeError("Transformer %s (type %s) does not " + "provide get_feature_names." + % (str(name), type(trans).__name__)) + feature_names.extend([name + "__" + f for f in + trans.get_feature_names()]) + return feature_names + + def _update_fitted_transformers(self, transformers): + # transformers are fitted; excludes 'drop' cases + fitted_transformers = iter(transformers) + transformers_ = [] + + for name, old, column, _ in self._iter(): + if old == 'drop': + trans = 'drop' + elif old == 'passthrough': + # FunctionTransformer is present in list of transformers, + # so get next transformer, but save original string + next(fitted_transformers) + trans = 'passthrough' + elif _is_empty_column_selection(column): + trans = old + else: + trans = next(fitted_transformers) + transformers_.append((name, trans, column)) + + # sanity check that transformers is exhausted + assert not list(fitted_transformers) + self.transformers_ = transformers_ + + def _validate_output(self, result): + """ + Ensure that the output of each transformer is 2D. Otherwise + hstack can raise an error or produce incorrect results. + """ + names = [name for name, _, _, _ in self._iter(fitted=True, + replace_strings=True)] + for Xs, name in zip(result, names): + if not getattr(Xs, 'ndim', 0) == 2: + raise ValueError( + "The output of the '{0}' transformer should be 2D (scipy " + "matrix, array, or pandas DataFrame).".format(name)) + + def _log_message(self, name, idx, total): + if not self.verbose: + return None + return '(%d of %d) Processing %s' % (idx, total, name) + + def _fit_transform(self, X, y, func, fitted=False): + """ + Private function to fit and/or transform on demand. + + Return value (transformers and/or transformed X data) depends + on the passed function. + ``fitted=True`` ensures the fitted transformers are used. + """ + transformers = list( + self._iter(fitted=fitted, replace_strings=True)) + try: + return Parallel(n_jobs=self.n_jobs)( + delayed(func)( + transformer=clone(trans) if not fitted else trans, + X=_safe_indexing(X, column, axis=1), + y=y, + weight=weight, + message_clsname='ColumnTransformer', + message=self._log_message(name, idx, len(transformers))) + for idx, (name, trans, column, weight) in enumerate( + self._iter(fitted=fitted, replace_strings=True), 1)) + except ValueError as e: + if "Expected 2D array, got 1D array instead" in str(e): + raise ValueError(_ERR_MSG_1DCOLUMN) from e + else: + raise + + def fit(self, X, y=None) -> "ColumnTransformer": + """Fit all transformers using X. 
+ + Parameters + ---------- + X : {array-like, dataframe} of shape (n_samples, n_features) + Input data, of which specified subsets are used to fit the + transformers. + + y : array-like of shape (n_samples,...), default=None + Targets for supervised learning. + + Returns + ------- + self : ColumnTransformer + This estimator + + """ + # we use fit_transform to make sure to set sparse_output_ (for which we + # need the transformed data) to have consistent output type in predict + self.fit_transform(X, y=y) + return self + + def fit_transform(self, X, y=None) -> SparseCumlArray: + """Fit all transformers, transform the data and concatenate results. + + Parameters + ---------- + X : {array-like, dataframe} of shape (n_samples, n_features) + Input data, of which specified subsets are used to fit the + transformers. + + y : array-like of shape (n_samples,), default=None + Targets for supervised learning. + + Returns + ------- + X_t : {array-like, sparse matrix} of \ + shape (n_samples, sum_n_components) + hstack of results of transformers. sum_n_components is the + sum of n_components (output dimension) over transformers. If + any result is a sparse matrix, everything will be converted to + sparse matrices. + + """ + # TODO: this should be `feature_names_in_` when we start having it + if hasattr(X, "columns"): + self._feature_names_in = cpu_np.asarray(X.columns) + else: + self._feature_names_in = None + # set n_features_in_ attribute + self._check_n_features(X, reset=True) + self._validate_transformers() + self._validate_column_callables(X) + self._validate_remainder(X) + + result = self._fit_transform(X, y, _fit_transform_one) + + if not result: + self._update_fitted_transformers([]) + # All transformers are None + return np.zeros((X.shape[0], 0)) + + Xs, transformers = zip(*result) + + # determine if concatenated output will be sparse or not + if any(issparse(X) for X in Xs): + nnz = sum(X.nnz if issparse(X) else X.size for X in Xs) + total = sum(X.shape[0] * X.shape[1] if issparse(X) + else X.size for X in Xs) + density = nnz / total + self.sparse_output_ = density < self.sparse_threshold + else: + self.sparse_output_ = False + + self._update_fitted_transformers(transformers) + self._validate_output(Xs) + + return self._hstack(list(Xs)) + + def transform(self, X) -> SparseCumlArray: + """Transform X separately by each transformer, concatenate results. + + Parameters + ---------- + X : {array-like, dataframe} of shape (n_samples, n_features) + The data to be transformed by subset. + + Returns + ------- + X_t : {array-like, sparse matrix} of \ + shape (n_samples, sum_n_components) + hstack of results of transformers. sum_n_components is the + sum of n_components (output dimension) over transformers. If + any result is a sparse matrix, everything will be converted to + sparse matrices. + + """ + check_is_fitted(self) + if hasattr(X, "columns"): + X_feature_names = cpu_np.asarray(X.columns) + else: + X_feature_names = None + + self._check_n_features(X, reset=False) + if (self._feature_names_in is not None and + X_feature_names is not None and + cpu_np.any(self._feature_names_in != X_feature_names)): + raise RuntimeError( + "Given feature/column names do not match the ones for the " + "data given during fit." + ) + Xs = self._fit_transform(X, None, _transform_one, fitted=True) + self._validate_output(Xs) + + if not Xs: + # All transformers are None + return np.zeros((X.shape[0], 0)) + + return self._hstack(list(Xs)) + + def _hstack(self, Xs): + """Stacks Xs horizontally. 
+ + This allows subclasses to control the stacking behavior, while reusing + everything else from ColumnTransformer. + + Parameters + ---------- + Xs : list of {array-like, sparse matrix, dataframe} + """ + if self.sparse_output_: + try: + # since all columns should be numeric before stacking them + # in a sparse matrix, `check_array` is used for the + # dtype conversion if necessary. + converted_Xs = [check_array(X, + accept_sparse=True, + force_all_finite=False) + for X in Xs] + except ValueError as e: + raise ValueError( + "For a sparse output, all columns should " + "be a numeric or convertible to a numeric." + ) from e + + return cu_sparse.hstack(converted_Xs).tocsr() + else: + Xs = [f.toarray() if issparse(f) else f for f in Xs] + return np.hstack(Xs) + + +def _is_empty_column_selection(column): + """ + Return True if the column selection is empty (empty list or all-False + boolean array). + + """ + if hasattr(column, 'dtype') and np.issubdtype(column.dtype, np.bool_): + return not column.any() + elif hasattr(column, '__len__'): + return (len(column) == 0 or + all(isinstance(col, bool) for col in column) + and not any(column)) + else: + return False + + +def _get_transformer_list(estimators): + """ + Construct (name, trans, column) tuples from list + + """ + transformers, columns = zip(*estimators) + names, _ = zip(*_name_estimators(transformers)) + + transformer_list = list(zip(names, transformers, columns)) + return transformer_list + + +def make_column_transformer(*transformers, + remainder='drop', + sparse_threshold=0.3, + n_jobs=None, + verbose=False): + """Construct a ColumnTransformer from the given transformers. + + This is a shorthand for the ColumnTransformer constructor; it does not + require, and does not permit, naming the transformers. Instead, they will + be given names automatically based on their types. It also does not allow + weighting with ``transformer_weights``. + + Parameters + ---------- + *transformers : tuples + Tuples of the form (transformer, columns) specifying the + transformer objects to be applied to subsets of the data. + + transformer : {'drop', 'passthrough'} or estimator + Estimator must support :term:`fit` and :term:`transform`. + Special-cased strings 'drop' and 'passthrough' are accepted as + well, to indicate to drop the columns or to pass them through + untransformed, respectively. + columns : str, array-like of str, int, array-like of int, slice, \ + array-like of bool or callable + Indexes the data on its second axis. Integers are interpreted as + positional columns, while strings can reference DataFrame columns + by name. A scalar string or int should be used where + ``transformer`` expects X to be a 1d array-like (vector), + otherwise a 2d array will be passed to the transformer. + A callable is passed the input data `X` and can return any of the + above. To select multiple columns by name or dtype, you can use + :obj:`make_column_selector`. + + remainder : {'drop', 'passthrough'} or estimator, default='drop' + By default, only the specified columns in `transformers` are + transformed and combined in the output, and the non-specified + columns are dropped. (default of ``'drop'``). + By specifying ``remainder='passthrough'``, all remaining columns that + were not specified in `transformers` will be automatically passed + through. This subset of columns is concatenated with the output of + the transformers. + By setting ``remainder`` to be an estimator, the remaining + non-specified columns will use the ``remainder`` estimator. 
The
+        estimator must support :term:`fit` and :term:`transform`.
+
+    sparse_threshold : float, default=0.3
+        If the transformed output consists of a mix of sparse and dense data,
+        it will be stacked as a sparse matrix if the density is lower than
+        this value. Use ``sparse_threshold=0`` to always return dense.
+        When the transformed output consists of all sparse or all dense data,
+        the stacked result will be sparse or dense, respectively, and this
+        keyword will be ignored.
+
+    n_jobs : int, default=None
+        Number of jobs to run in parallel.
+        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
+        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
+        for more details.
+
+    verbose : bool, default=False
+        If True, the time elapsed while fitting each transformer will be
+        printed as it is completed.
+
+    Returns
+    -------
+    ct : ColumnTransformer
+
+    See Also
+    --------
+    ColumnTransformer : Class that allows combining the
+        outputs of multiple transformer objects used on column subsets
+        of the data into a single feature space.
+
+    Examples
+    --------
+    >>> from cuml.preprocessing import StandardScaler, OneHotEncoder
+    >>> from cuml.compose import make_column_transformer
+    >>> make_column_transformer(
+    ...     (StandardScaler(), ['numerical_column']),
+    ...     (OneHotEncoder(), ['categorical_column']))
+    ColumnTransformer(transformers=[('standardscaler', StandardScaler(...),
+                                     ['numerical_column']),
+                                    ('onehotencoder', OneHotEncoder(...),
+                                     ['categorical_column'])])
+
+    """
+    # transformer_weights keyword is not passed through because the user
+    # would need to know the automatically generated names of the transformers
+    transformer_list = _get_transformer_list(transformers)
+    return ColumnTransformer(transformer_list, n_jobs=n_jobs,
+                             remainder=remainder,
+                             sparse_threshold=sparse_threshold,
+                             verbose=verbose)
+
+
+class make_column_selector:
+    """Create a callable to select columns to be used with
+    :class:`ColumnTransformer`.
+
+    :func:`make_column_selector` can select columns based on datatype or the
+    columns name with a regex. When using multiple selection criteria, **all**
+    criteria must match for a column to be selected.
+
+    Parameters
+    ----------
+    pattern : str, default=None
+        Name of columns containing this regex pattern will be included. If
+        None, columns will not be selected based on pattern.
+
+    dtype_include : column dtype or list of column dtypes, default=None
+        A selection of dtypes to include. For more details, see
+        :meth:`pandas.DataFrame.select_dtypes`.
+
+    dtype_exclude : column dtype or list of column dtypes, default=None
+        A selection of dtypes to exclude. For more details, see
+        :meth:`pandas.DataFrame.select_dtypes`.
+
+    Returns
+    -------
+    selector : callable
+        Callable for column selection to be used by a
+        :class:`ColumnTransformer`.
+
+    See Also
+    --------
+    ColumnTransformer : Class that allows combining the
+        outputs of multiple transformer objects used on column subsets
+        of the data into a single feature space.
+
+    Examples
+    --------
+    >>> from cuml.preprocessing import StandardScaler, OneHotEncoder
+    >>> from cuml.preprocessing import make_column_transformer
+    >>> from cuml.preprocessing import make_column_selector
+    >>> import cupy as cp
+    >>> import cudf  # doctest: +SKIP
+    >>> X = cudf.DataFrame({'city': ['London', 'London', 'Paris', 'Sallisaw'],
+    ...                     'rating': [5, 3, 4, 5]})  # doctest: +SKIP
+    >>> ct = make_column_transformer(
+    ...       (StandardScaler(),
+    ...        make_column_selector(dtype_include=cp.number)),  # rating
+    ...       (OneHotEncoder(),
+    ...        make_column_selector(dtype_include=object)))  # city
+    >>> ct.fit_transform(X)  # doctest: +SKIP
+    array([[ 0.90453403,  1.        ,  0.        ,  0.        ],
+           [-1.50755672,  1.        ,  0.        ,  0.        ],
+           [-0.30151134,  0.        ,  1.        ,  0.        ],
+           [ 0.90453403,  0.        ,  0.        ,  1.        ]])
+    """
+    def __init__(self, pattern=None, *, dtype_include=None,
+                 dtype_exclude=None):
+        self.pattern = pattern
+        self.dtype_include = dtype_include
+        self.dtype_exclude = dtype_exclude
+
+    def __call__(self, df):
+        if not hasattr(df, 'iloc'):
+            raise ValueError("make_column_selector can only be applied to "
+                             "pandas dataframes")
+        df_row = df.iloc[:1]
+        if self.dtype_include is not None or self.dtype_exclude is not None:
+            df_row = df_row.select_dtypes(include=self.dtype_include,
+                                          exclude=self.dtype_exclude)
+        cols = df_row.columns
+        if self.pattern is not None:
+            cols = cols[cols.str.contains(self.pattern, regex=True)]
+        return cols.tolist()
diff --git a/python/cuml/_thirdparty/sklearn/preprocessing/_function_transformer.py b/python/cuml/_thirdparty/sklearn/preprocessing/_function_transformer.py
new file mode 100644
index 0000000000..1fea2d8add
--- /dev/null
+++ b/python/cuml/_thirdparty/sklearn/preprocessing/_function_transformer.py
@@ -0,0 +1,157 @@
+# This code originates from the Scikit-Learn library,
+# it was since modified to allow GPU acceleration.
+# This code is under BSD 3 clause license.
+# Authors mentioned above do not endorse or promote this production.
+
+
+import warnings
+
+import cuml
+from ....common.array_sparse import SparseCumlArray
+from ..utils.skl_dependencies import TransformerMixin, BaseEstimator
+from ..utils.validation import _allclose_dense_sparse
+from ....internals import _deprecate_pos_args
+
+
+def _identity(X):
+    """The identity function.
+    """
+    return X
+
+
+class FunctionTransformer(TransformerMixin, BaseEstimator):
+    """Constructs a transformer from an arbitrary callable.
+
+    A FunctionTransformer forwards its X (and optionally y) arguments to a
+    user-defined function or function object and returns the result of this
+    function. This is useful for stateless transformations such as taking the
+    log of frequencies, doing custom scaling, etc.
+
+    Note: If a lambda is used as the function, then the resulting
+    transformer will not be pickleable.
+
+    Parameters
+    ----------
+    func : callable, default=None
+        The callable to use for the transformation. This will be passed
+        the same arguments as transform, with args and kwargs forwarded.
+        If func is None, then func will be the identity function.
+
+    inverse_func : callable, default=None
+        The callable to use for the inverse transformation. This will be
+        passed the same arguments as inverse transform, with args and
+        kwargs forwarded. If inverse_func is None, then inverse_func
+        will be the identity function.
+
+    accept_sparse : bool, default=False
+        Indicate that func accepts a sparse matrix as input. Otherwise,
+        if accept_sparse is False, sparse matrix inputs will cause
+        an exception to be raised.
+
+    check_inverse : bool, default=True
+        Whether to check that ``func`` followed by ``inverse_func`` leads to
+        the original inputs. It can be used for a sanity check, raising a
+        warning when the condition is not fulfilled.
+
+    kw_args : dict, default=None
+        Dictionary of additional keyword arguments to pass to func.
+
+    inv_kw_args : dict, default=None
+        Dictionary of additional keyword arguments to pass to inverse_func. 
+ + Examples + -------- + >>> import cupy as cp + >>> from cuml.preprocessing import FunctionTransformer + >>> transformer = FunctionTransformer(cp.log1p) + >>> X = cp.array([[0, 1], [2, 3]]) + >>> transformer.transform(X) + array([[0. , 0.6931...], + [1.0986..., 1.3862...]]) + """ + + @_deprecate_pos_args(version="0.20") + def __init__(self, *, func=None, inverse_func=None, accept_sparse=False, + check_inverse=True, kw_args=None, inv_kw_args=None): + self.func = func + self.inverse_func = inverse_func + self.accept_sparse = accept_sparse + self.check_inverse = check_inverse + self.kw_args = kw_args + self.inv_kw_args = inv_kw_args + + def _check_input(self, X): + return self._validate_data(X, accept_sparse=self.accept_sparse) + + def _check_inverse_transform(self, X): + """Check that func and inverse_func are the inverse.""" + interval = max(1, X.shape[0] // 100) + selection = [i * interval for i in range(X.shape[0] // interval)] + with cuml.using_output_type("cupy"): + X_round_trip = self.inverse_transform(self.transform(X[selection])) + if not _allclose_dense_sparse(X[selection], X_round_trip): + warnings.warn("The provided functions are not strictly" + " inverse of each other. If you are sure you" + " want to proceed regardless, set" + " 'check_inverse=False'.", UserWarning) + + def fit(self, X, y=None) -> "FunctionTransformer": + """Fit transformer by checking X. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape (n_samples, n_features) + Input array. + + Returns + ------- + self + """ + X = self._check_input(X) + if (self.check_inverse and not (self.func is None or + self.inverse_func is None)): + self._check_inverse_transform(X) + return self + + def transform(self, X) -> SparseCumlArray: + """Transform X using the forward function. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape (n_samples, n_features) + Input array. + + Returns + ------- + X_out : {array-like, sparse matrix}, shape (n_samples, n_features) + Transformed input. + """ + return self._transform(X, func=self.func, kw_args=self.kw_args) + + def inverse_transform(self, X) -> SparseCumlArray: + """Transform X using the inverse function. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape (n_samples, n_features) + Input array. + + Returns + ------- + X_out : {array-like, sparse matrix}, shape (n_samples, n_features) + Transformed input. + """ + return self._transform(X, func=self.inverse_func, + kw_args=self.inv_kw_args) + + def _transform(self, X, func=None, kw_args=None): + X = self._check_input(X) + + if func is None: + func = _identity + + return func(X, **(kw_args if kw_args else {})) + + def _more_tags(self): + return {'stateless': True, + 'requires_y': False} diff --git a/python/cuml/_thirdparty/sklearn/utils/skl_dependencies.py b/python/cuml/_thirdparty/sklearn/utils/skl_dependencies.py index 30a28fd44d..ab76f4820e 100644 --- a/python/cuml/_thirdparty/sklearn/utils/skl_dependencies.py +++ b/python/cuml/_thirdparty/sklearn/utils/skl_dependencies.py @@ -163,3 +163,59 @@ def fit_transform(self, X, y=None, **fit_params): else: # fit method of arity 2 (supervised transformation) return self.fit(X, y, **fit_params).transform(X) + + +class BaseComposition: + """Handles parameter management for classifiers composed of named estimators. 
+ """ + + def _get_params(self, attr, deep=True): + out = super().get_params(deep=deep) + if not deep: + return out + estimators = getattr(self, attr) + out.update(estimators) + for name, estimator in estimators: + if hasattr(estimator, 'get_params'): + for key, value in estimator.get_params(deep=True).items(): + out['%s__%s' % (name, key)] = value + return out + + def _set_params(self, attr, **params): + # Ensure strict ordering of parameter setting: + # 1. All steps + if attr in params: + setattr(self, attr, params.pop(attr)) + # 2. Step replacement + items = getattr(self, attr) + names = [] + if items: + names, _ = zip(*items) + for name in list(params.keys()): + if '__' not in name and name in names: + self._replace_estimator(attr, name, params.pop(name)) + # 3. Step parameters and other initialisation arguments + super().set_params(**params) + return self + + def _replace_estimator(self, attr, name, new_val): + # assumes `name` is a valid estimator name + new_estimators = list(getattr(self, attr)) + for i, (estimator_name, _) in enumerate(new_estimators): + if estimator_name == name: + new_estimators[i] = (name, new_val) + break + setattr(self, attr, new_estimators) + + def _validate_names(self, names): + if len(set(names)) != len(names): + raise ValueError('Names provided are not unique: ' + '{0!r}'.format(list(names))) + invalid_names = set(names).intersection(self.get_params(deep=False)) + if invalid_names: + raise ValueError('Estimator names conflict with constructor ' + 'arguments: {0!r}'.format(sorted(invalid_names))) + invalid_names = [name for name in names if '__' in name] + if invalid_names: + raise ValueError('Estimator names must not contain __: got ' + '{0!r}'.format(invalid_names)) diff --git a/python/cuml/_thirdparty/sklearn/utils/validation.py b/python/cuml/_thirdparty/sklearn/utils/validation.py index 8cc6f044fc..62581b7cdc 100644 --- a/python/cuml/_thirdparty/sklearn/utils/validation.py +++ b/python/cuml/_thirdparty/sklearn/utils/validation.py @@ -17,6 +17,8 @@ import numbers import numpy as np +import cupy as cp +import cupy.sparse as sp from inspect import isclass from ....common.exceptions import NotFittedError @@ -229,3 +231,39 @@ def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all): if not attrs: raise NotFittedError(msg % {'name': type(estimator).__name__}) + + +def _allclose_dense_sparse(x, y, rtol=1e-7, atol=1e-9): + """Check allclose for sparse and dense data. + + Both x and y need to be either sparse or dense, they + can't be mixed. + + Parameters + ---------- + x : array-like or sparse matrix + First array to compare. + + y : array-like or sparse matrix + Second array to compare. + + rtol : float, optional + relative tolerance; see numpy.allclose + + atol : float, optional + absolute tolerance; see numpy.allclose. Note that the default here is + more tolerant than the default for numpy.testing.assert_allclose, where + atol=0. 
+ """ + if sp.issparse(x) and sp.issparse(y): + x = x.tocsr() + y = y.tocsr() + x.sum_duplicates() + y.sum_duplicates() + return (cp.array_equal(x.indices, y.indices) and + cp.array_equal(x.indptr, y.indptr) and + cp.allclose(x.data, y.data, rtol=rtol, atol=atol)) + elif not sp.issparse(x) and not sp.issparse(y): + return cp.allclose(x, y, rtol=rtol, atol=atol) + raise ValueError("Can only compare two sparse matrices, not a sparse " + "matrix and an array") diff --git a/python/cuml/experimental/preprocessing/__init__.py b/python/cuml/experimental/preprocessing/__init__.py index e6951b1f22..91a28952ad 100644 --- a/python/cuml/experimental/preprocessing/__init__.py +++ b/python/cuml/experimental/preprocessing/__init__.py @@ -22,10 +22,15 @@ from cuml.preprocessing import scale, minmax_scale, maxabs_scale, normalize, \ add_dummy_feature, binarize, robust_scale +from cuml._thirdparty.sklearn.preprocessing import ColumnTransformer, \ + FunctionTransformer, make_column_transformer, make_column_selector + __all__ = [ # Classes 'Binarizer', + 'ColumnTransformer', + 'FunctionTransformer', 'KBinsDiscretizer', 'MaxAbsScaler', 'MinMaxScaler', @@ -39,6 +44,8 @@ 'add_dummy_feature', 'binarize', 'minmax_scale', + 'make_column_selector', + 'make_column_transformer', 'maxabs_scale', 'normalize', 'robust_scale', diff --git a/python/cuml/preprocessing/LabelEncoder.py b/python/cuml/preprocessing/LabelEncoder.py index fa83179ee5..2791ebf1d3 100644 --- a/python/cuml/preprocessing/LabelEncoder.py +++ b/python/cuml/preprocessing/LabelEncoder.py @@ -17,7 +17,7 @@ import cudf import cupy as cp from cuml import Base - +from pandas import Series as pdSeries from cuml.common.exceptions import NotFittedError @@ -169,6 +169,9 @@ def fit(self, y, _classes=None): A fitted instance of itself to allow method chaining """ + if isinstance(y, pdSeries): + y = cudf.from_pandas(y) + self._validate_keywords() self.dtype = y.dtype if y.dtype != cp.dtype('O') else str @@ -204,6 +207,9 @@ def transform(self, y: cudf.Series) -> cudf.Series: KeyError if a category appears that was not seen in `fit` """ + if isinstance(y, pdSeries): + y = cudf.from_pandas(y) + self._check_is_fitted() y = y.astype('category') @@ -224,6 +230,9 @@ def fit_transform(self, y: cudf.Series, z=None) -> cudf.Series: This is functionally equivalent to (but faster than) `LabelEncoder().fit(y).transform(y)` """ + if isinstance(y, pdSeries): + y = cudf.from_pandas(y) + self.dtype = y.dtype if y.dtype != cp.dtype('O') else str y = y.astype('category') diff --git a/python/cuml/test/test_compose.py b/python/cuml/test/test_compose.py new file mode 100644 index 0000000000..f80c390662 --- /dev/null +++ b/python/cuml/test/test_compose.py @@ -0,0 +1,270 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import pytest
+
+import cudf
+import numpy as np
+from pandas import DataFrame as pdDataFrame
+from cudf import DataFrame as cuDataFrame
+
+from cuml.experimental.preprocessing import \
+    ColumnTransformer as cuColumnTransformer, \
+    make_column_transformer as cu_make_column_transformer, \
+    make_column_selector as cu_make_column_selector
+
+from sklearn.compose import \
+    ColumnTransformer as skColumnTransformer, \
+    make_column_transformer as sk_make_column_transformer, \
+    make_column_selector as sk_make_column_selector
+
+from cuml.test.test_preproc_utils import clf_dataset, \
+    sparse_clf_dataset  # noqa: F401
+
+from cuml.preprocessing import \
+    StandardScaler as cuStandardScaler, \
+    Normalizer as cuNormalizer, \
+    PolynomialFeatures as cuPolynomialFeatures, \
+    OneHotEncoder as cuOneHotEncoder
+
+from sklearn.preprocessing import \
+    StandardScaler as skStandardScaler, \
+    Normalizer as skNormalizer, \
+    PolynomialFeatures as skPolynomialFeatures, \
+    OneHotEncoder as skOneHotEncoder
+
+from cuml.test.test_preproc_utils import assert_allclose
+
+
+@pytest.mark.parametrize('remainder', ['drop', 'passthrough'])
+@pytest.mark.parametrize('transformer_weights', [None, {'scaler': 2.4,
+                                                        'normalizer': 1.8}])
+def test_column_transformer(clf_dataset, remainder,  # noqa: F811
+                            transformer_weights):
+    X_np, X = clf_dataset
+
+    sk_selec1 = [0, 2]
+    sk_selec2 = [1, 3]
+    cu_selec1 = sk_selec1
+    cu_selec2 = sk_selec2
+    if isinstance(X, (pdDataFrame, cuDataFrame)):
+        cu_selec1 = ['c'+str(i) for i in sk_selec1]
+        cu_selec2 = ['c'+str(i) for i in sk_selec2]
+
+    cu_transformers = [
+        ("scaler", cuStandardScaler(), cu_selec1),
+        ("normalizer", cuNormalizer(), cu_selec2)
+    ]
+
+    transformer = cuColumnTransformer(cu_transformers,
+                                      remainder=remainder,
+                                      transformer_weights=transformer_weights)
+    ft_X = transformer.fit_transform(X)
+    t_X = transformer.transform(X)
+    assert type(t_X) == type(X)
+
+    sk_transformers = [
+        ("scaler", skStandardScaler(), sk_selec1),
+        ("normalizer", skNormalizer(), sk_selec2)
+    ]
+
+    transformer = skColumnTransformer(sk_transformers,
+                                      remainder=remainder,
+                                      transformer_weights=transformer_weights)
+    sk_t_X = transformer.fit_transform(X_np)
+
+    assert_allclose(ft_X, sk_t_X)
+    assert_allclose(t_X, sk_t_X)
+
+
+@pytest.mark.parametrize('remainder', ['drop', 'passthrough'])
+@pytest.mark.parametrize('transformer_weights', [None, {'scaler': 2.4,
+                                                        'normalizer': 1.8}])
+@pytest.mark.parametrize('sparse_threshold', [0.2, 0.8])
+def test_column_transformer_sparse(sparse_clf_dataset, remainder,  # noqa: F811
+                                   transformer_weights, sparse_threshold):
+    X_np, X = sparse_clf_dataset
+
+    if X.format == 'csc':
+        pytest.xfail()
+    dataset_density = X.nnz / X.size
+
+    cu_transformers = [
+        ("scaler", cuStandardScaler(with_mean=False), [0, 2]),
+        ("normalizer", cuNormalizer(), [1, 3])
+    ]
+
+    transformer = cuColumnTransformer(cu_transformers,
+                                      remainder=remainder,
+                                      transformer_weights=transformer_weights,
+                                      sparse_threshold=sparse_threshold)
+    ft_X = transformer.fit_transform(X)
+    t_X = transformer.transform(X)
+    if dataset_density < sparse_threshold:
+        # Sparse input -> sparse output when dataset_density <
+        # sparse_threshold, else sparse input -> dense output
+        assert type(t_X) == type(X)
+
+    sk_transformers = [
+        ("scaler", skStandardScaler(with_mean=False), [0, 2]),
+        ("normalizer", skNormalizer(), [1, 3])
+    ]
+
+    transformer = skColumnTransformer(sk_transformers,
+                                      remainder=remainder,
+                                      transformer_weights=transformer_weights,
+                                      sparse_threshold=sparse_threshold)
+    sk_t_X = transformer.fit_transform(X_np)
+
+    assert_allclose(ft_X, sk_t_X)
+    assert_allclose(t_X, sk_t_X)
+
+
+@pytest.mark.parametrize('remainder', ['drop', 'passthrough'])
+def test_make_column_transformer(clf_dataset, remainder):  # noqa: F811
+    X_np, X = clf_dataset
+
+    sk_selec1 = [0, 2]
+    sk_selec2 = [1, 3]
+    cu_selec1 = sk_selec1
+    cu_selec2 = sk_selec2
+    if isinstance(X, (pdDataFrame, cuDataFrame)):
+        cu_selec1 = ['c'+str(i) for i in sk_selec1]
+        cu_selec2 = ['c'+str(i) for i in sk_selec2]
+
+    transformer = cu_make_column_transformer(
+        (cuStandardScaler(), cu_selec1),
+        (cuNormalizer(), cu_selec2),
+        remainder=remainder)
+
+    ft_X = transformer.fit_transform(X)
+    t_X = transformer.transform(X)
+    assert type(t_X) == type(X)
+
+    transformer = sk_make_column_transformer(
+        (skStandardScaler(), sk_selec1),
+        (skNormalizer(), sk_selec2),
+        remainder=remainder)
+    sk_t_X = transformer.fit_transform(X_np)
+
+    assert_allclose(ft_X, sk_t_X)
+    assert_allclose(t_X, sk_t_X)
+
+
+@pytest.mark.parametrize('remainder', ['drop', 'passthrough'])
+@pytest.mark.parametrize('sparse_threshold', [0.2, 0.8])
+def test_make_column_transformer_sparse(sparse_clf_dataset,  # noqa: F811
+                                        remainder, sparse_threshold):
+    X_np, X = sparse_clf_dataset
+
+    if X.format == 'csc':
+        pytest.xfail()
+    dataset_density = X.nnz / X.size
+
+    transformer = cu_make_column_transformer(
+        (cuStandardScaler(with_mean=False), [0, 2]),
+        (cuNormalizer(), [1, 3]),
+        remainder=remainder,
+        sparse_threshold=sparse_threshold)
+
+    ft_X = transformer.fit_transform(X)
+    t_X = transformer.transform(X)
+    if dataset_density < sparse_threshold:
+        # Sparse input -> sparse output when dataset_density <
+        # sparse_threshold, else sparse input -> dense output
+        assert type(t_X) == type(X)
+
+    transformer = sk_make_column_transformer(
+        (skStandardScaler(with_mean=False), [0, 2]),
+        (skNormalizer(), [1, 3]),
+        remainder=remainder,
+        sparse_threshold=sparse_threshold)
+
+    sk_t_X = transformer.fit_transform(X_np)
+
+    assert_allclose(ft_X, sk_t_X)
+    assert_allclose(t_X, sk_t_X)
+
+
+def test_column_transformer_get_feature_names(clf_dataset):  # noqa: F811
+    X_np, X = clf_dataset
+
+    cu_transformers = [
+        ("PolynomialFeatures", cuPolynomialFeatures(), [0, 2])
+    ]
+    transformer = cuColumnTransformer(cu_transformers)
+    transformer.fit_transform(X)
+    cu_feature_names = transformer.get_feature_names()
+
+    sk_transformers = [
+        ("PolynomialFeatures", skPolynomialFeatures(), [0, 2])
+    ]
+    transformer = skColumnTransformer(sk_transformers)
+    transformer.fit_transform(X_np)
+    sk_feature_names = transformer.get_feature_names()
+
+    assert cu_feature_names == sk_feature_names
+
+
+def test_column_transformer_named_transformers_(clf_dataset):  # noqa: F811
+    X_np, X = clf_dataset
+
+    cu_transformers = [
+        ("PolynomialFeatures", cuPolynomialFeatures(), [0, 2])
+    ]
+    transformer = cuColumnTransformer(cu_transformers)
+    transformer.fit_transform(X)
+    cu_named_transformers = transformer.named_transformers_
+
+    sk_transformers = [
+        ("PolynomialFeatures", skPolynomialFeatures(), [0, 2])
+    ]
+    transformer = skColumnTransformer(sk_transformers)
+    transformer.fit_transform(X_np)
+    sk_named_transformers = transformer.named_transformers_
+
+    assert cu_named_transformers.keys() == sk_named_transformers.keys()
+
+
+def test_make_column_selector():
+    X_np = pdDataFrame({'city': ['London', 'London', 'Paris', 'Sallisaw'],
+                        'rating': [5, 3, 4, 5],
+                        'temperature': [21., 21., 24., 28.]})
+    X = cudf.from_pandas(X_np)
+
+    cu_transformers = [
+        ("ohe", cuOneHotEncoder(),
+         
+
+
+def test_make_column_selector():
+    X_np = pdDataFrame({'city': ['London', 'London', 'Paris', 'Sallisaw'],
+                        'rating': [5, 3, 4, 5],
+                        'temperature': [21., 21., 24., 28.]})
+    X = cudf.from_pandas(X_np)
+
+    cu_transformers = [
+        ("ohe", cuOneHotEncoder(),
+            cu_make_column_selector(dtype_exclude=np.number)),
+        ("scaler", cuStandardScaler(),
+            cu_make_column_selector(dtype_include=np.integer)),
+        ("normalizer", cuNormalizer(),
+            cu_make_column_selector(pattern="temp"))
+    ]
+    transformer = cuColumnTransformer(cu_transformers, remainder='drop')
+    t_X = transformer.fit_transform(X)
+
+    sk_transformers = [
+        ("ohe", skOneHotEncoder(),
+            sk_make_column_selector(dtype_exclude=np.number)),
+        ("scaler", skStandardScaler(),
+            sk_make_column_selector(dtype_include=np.integer)),
+        ("normalizer", skNormalizer(),
+            sk_make_column_selector(pattern="temp"))
+    ]
+    transformer = skColumnTransformer(sk_transformers, remainder='drop')
+    sk_t_X = transformer.fit_transform(X_np)
+
+    assert_allclose(t_X, sk_t_X)
+    assert type(t_X) == type(X)
diff --git a/python/cuml/test/test_preproc_utils.py b/python/cuml/test/test_preproc_utils.py
index f2937b1173..b5f7212b04 100644
--- a/python/cuml/test/test_preproc_utils.py
+++ b/python/cuml/test/test_preproc_utils.py
@@ -87,7 +87,13 @@ def to_output_type(array, output_type, order='F'):
     if output_type == 'series' and len(array.shape) > 1:
         output_type = 'cudf'
 
-    return cuml_array.to_output(output_type)
+    output = cuml_array.to_output(output_type)
+
+    if output_type in ['dataframe', 'cudf']:
+        renaming = {i: 'c'+str(i) for i in range(output.shape[1])}
+        output = output.rename(columns=renaming)
+
+    return output
 
 
 def create_rand_clf(random_state):
@@ -96,7 +102,7 @@
                                  n_clusters_per_class=1,
                                  n_informative=12,
                                  n_classes=5,
-                                 order='F',
+                                 order='C',
                                  random_state=random_state)
     return clf
 
@@ -105,7 +111,7 @@
     blobs, _ = make_blobs(n_samples=500,
                           n_features=20,
                           centers=20,
-                          order='F',
+                          order='C',
                           random_state=random_state)
     return blobs
 
diff --git a/python/cuml/test/test_preprocessing.py b/python/cuml/test/test_preprocessing.py
index 88dc1bed9f..54152445e2 100644
--- a/python/cuml/test/test_preprocessing.py
+++ b/python/cuml/test/test_preprocessing.py
@@ -26,6 +28,8 @@
     RobustScaler as cuRobustScaler, \
     KBinsDiscretizer as cuKBinsDiscretizer, \
     MissingIndicator as cuMissingIndicator
+from cuml.experimental.preprocessing import \
+    FunctionTransformer as cuFunctionTransformer
 from cuml.preprocessing import scale as cu_scale, \
     minmax_scale as cu_minmax_scale, \
     maxabs_scale as cu_maxabs_scale, \
@@ -39,7 +41,9 @@
     Normalizer as skNormalizer, \
     Binarizer as skBinarizer, \
     PolynomialFeatures as skPolynomialFeatures, \
-    RobustScaler as skRobustScaler
+    RobustScaler as skRobustScaler, \
+    KBinsDiscretizer as skKBinsDiscretizer, \
+    FunctionTransformer as skFunctionTransformer
 from sklearn.preprocessing import scale as sk_scale, \
     minmax_scale as sk_minmax_scale, \
     maxabs_scale as sk_maxabs_scale, \
@@ -49,7 +53,6 @@
     robust_scale as sk_robust_scale
 from sklearn.impute import SimpleImputer as skSimpleImputer, \
     MissingIndicator as skMissingIndicator
-from sklearn.preprocessing import KBinsDiscretizer as skKBinsDiscretizer
 from cuml.test.test_preproc_utils import \
     clf_dataset, int_dataset, blobs_dataset, \
@@ -749,6 +752,44 @@ def test_missing_indicator_sparse(failure_logger,
     assert_allclose(t_X, sk_t_X)
 
 
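+# FunctionTransformer wraps a pair of callables: `func` is applied by
+# transform() and `inverse_func` by inverse_transform(). A minimal sketch,
+# assuming a 2D cupy array as input (illustrative, not part of the tests):
+#
+#     ft = cuFunctionTransformer(func=cp.exp, inverse_func=cp.log)
+#     Y = ft.fit_transform(cp.ones((4, 2)))      # element-wise exp
+#     assert cp.allclose(ft.inverse_transform(Y), cp.ones((4, 2)))
+#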
+def test_function_transformer(clf_dataset):  # noqa: F811
+    X_np, X = clf_dataset
+
+    transformer = cuFunctionTransformer(func=cp.exp, inverse_func=cp.log)
+    t_X = transformer.fit_transform(X)
+    r_X = transformer.inverse_transform(t_X)
+    assert type(t_X) == type(X)
+    assert type(r_X) == type(t_X)
+
+    transformer = skFunctionTransformer(func=np.exp, inverse_func=np.log)
+    sk_t_X = transformer.fit_transform(X_np)
+    sk_r_X = transformer.inverse_transform(sk_t_X)
+
+    assert_allclose(t_X, sk_t_X)
+    assert_allclose(r_X, sk_r_X)
+
+
+def test_function_transformer_sparse(sparse_clf_dataset):  # noqa: F811
+    X_np, X = sparse_clf_dataset
+
+    transformer = cuFunctionTransformer(func=lambda x: x * 2,
+                                        inverse_func=lambda x: x / 2,
+                                        accept_sparse=True)
+    t_X = transformer.fit_transform(X)
+    r_X = transformer.inverse_transform(t_X)
+    assert cp.sparse.issparse(t_X) or scipy.sparse.issparse(t_X)
+    assert cp.sparse.issparse(r_X) or scipy.sparse.issparse(r_X)
+
+    transformer = skFunctionTransformer(func=lambda x: x * 2,
+                                        inverse_func=lambda x: x / 2,
+                                        accept_sparse=True)
+    sk_t_X = transformer.fit_transform(X_np)
+    sk_r_X = transformer.inverse_transform(sk_t_X)
+
+    assert_allclose(t_X, sk_t_X)
+    assert_allclose(r_X, sk_r_X)
+
+
 def test__repr__():
     assert cuStandardScaler().__repr__() == 'StandardScaler()'
     assert cuMinMaxScaler().__repr__() == 'MinMaxScaler()'
diff --git a/python/cuml/thirdparty_adapters/adapters.py b/python/cuml/thirdparty_adapters/adapters.py
index c57f2d3958..6b67dd7b99 100644
--- a/python/cuml/thirdparty_adapters/adapters.py
+++ b/python/cuml/thirdparty_adapters/adapters.py
@@ -22,7 +22,8 @@
 from scipy import sparse as cpu_sparse
 from cupy import sparse as gpu_sparse
 
-from cudf.core import DataFrame as cuDataFrame
+from pandas import DataFrame as pdDataFrame
+from cudf import DataFrame as cuDataFrame
 
 numeric_types = [
     np.int8, np.int16, np.int32, np.int64,
@@ -106,12 +107,13 @@ def check_dtype(array, dtypes='numeric'):
     # fp16 is not supported, so remove from the list of dtypes if present
     dtypes = [d for d in dtypes if d != np.float16]
 
-    if not isinstance(array, cuDataFrame):
+    if not isinstance(array, (pdDataFrame, cuDataFrame)):
         if array.dtype not in dtypes:
             return dtypes[0]
     elif any([dt not in dtypes for dt in array.dtypes.tolist()]):
         return dtypes[0]
-    if not isinstance(array, cuDataFrame):
+
+    if not isinstance(array, (pdDataFrame, cuDataFrame)):
         return array.dtype
     else:
         return array.dtypes.tolist()[0]
@@ -222,7 +224,8 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
 
     correct_dtype = check_dtype(array, dtype)
 
-    if copy and not order and hasattr(array, 'flags'):
+    if (not isinstance(array, (pdDataFrame, cuDataFrame))
+            and copy and not order and hasattr(array, 'flags')):
         if array.flags['F_CONTIGUOUS']:
             order = 'F'
         elif array.flags['C_CONTIGUOUS']: