diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 9cab0db995c5d..ec1301844b877 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -487,6 +487,11 @@ Changelog categorical encoding based on target mean conditioned on the value of the category. :pr:`25334` by `Thomas Fan`_. +- |Enhancement| A new parameter `sparse_output` was added to + :class:`SplineTransformer`, available as of SciPy 1.8. If `sparse_output=True`, + :class:`SplineTransformer` returns a sparse CSR matrix. + :pr:`24145` by :user:`Christian Lorentzen `. + - |Enhancement| Adds a `feature_name_combiner` parameter to :class:`preprocessing.OneHotEncoder`. This specifies a custom callable to create feature names to be returned by :meth:`get_feature_names_out`. diff --git a/sklearn/preprocessing/_polynomial.py b/sklearn/preprocessing/_polynomial.py index 64ecb9864fae0..f379ee9135706 100644 --- a/sklearn/preprocessing/_polynomial.py +++ b/sklearn/preprocessing/_polynomial.py @@ -13,11 +13,11 @@ from ..base import BaseEstimator, TransformerMixin from ..utils import check_array +from ..utils.fixes import sp_version, parse_version from ..utils.validation import check_is_fitted, FLOAT_DTYPES, _check_sample_weight from ..utils.validation import _check_feature_names_in from ..utils._param_validation import Interval, StrOptions from ..utils.stats import _weighted_percentile -from ..utils.fixes import sp_version, parse_version from ._csr_polynomial_expansion import ( _csr_polynomial_expansion, @@ -574,8 +574,6 @@ def transform(self, X): return XP -# TODO: -# - sparse support (either scipy or own cython solution)? class SplineTransformer(TransformerMixin, BaseEstimator): """Generate univariate B-spline bases for features. @@ -635,8 +633,14 @@ class SplineTransformer(TransformerMixin, BaseEstimator): i.e. a column of ones. It acts as an intercept term in a linear models. order : {'C', 'F'}, default='C' - Order of output array. 'F' order is faster to compute, but may slow - down subsequent estimators. + Order of output array in the dense case. `'F'` order is faster to compute, but + may slow down subsequent estimators. + + sparse_output : bool, default=False + Will return sparse CSR matrix if set True else will return an array. This + option is only available with `scipy>=1.8`. + + .. versionadded:: 1.2 Attributes ---------- @@ -699,6 +703,7 @@ class SplineTransformer(TransformerMixin, BaseEstimator): ], "include_bias": ["boolean"], "order": [StrOptions({"C", "F"})], + "sparse_output": ["boolean"], } def __init__( @@ -710,6 +715,7 @@ def __init__( extrapolation="constant", include_bias=True, order="C", + sparse_output=False, ): self.n_knots = n_knots self.degree = degree @@ -717,6 +723,7 @@ def __init__( self.extrapolation = extrapolation self.include_bias = include_bias self.order = order + self.sparse_output = sparse_output @staticmethod def _get_base_knot_positions(X, n_knots=10, knots="uniform", sample_weight=None): @@ -843,6 +850,12 @@ def fit(self, X, y=None, sample_weight=None): elif not np.all(np.diff(base_knots, axis=0) > 0): raise ValueError("knots must be sorted without duplicates.") + if self.sparse_output and sp_version < parse_version("1.8.0"): + raise ValueError( + "Option sparse_output=True is only available with scipy>=1.8.0, " + f"but here scipy=={sp_version} is used." + ) + # number of knots for base interval n_knots = base_knots.shape[0] @@ -934,7 +947,7 @@ def transform(self, X): Returns ------- - XBS : ndarray of shape (n_samples, n_features * n_splines) + XBS : {ndarray, sparse matrix} of shape (n_samples, n_features * n_splines) The matrix of features, where n_splines is the number of bases elements of the B-splines, n_knots + degree - 1. """ @@ -946,6 +959,19 @@ def transform(self, X): n_splines = self.bsplines_[0].c.shape[1] degree = self.degree + # TODO: Remove this condition, once scipy 1.10 is the minimum version. + # Only scipy => 1.10 supports design_matrix(.., extrapolate=..). + # The default (implicit in scipy < 1.10) is extrapolate=False. + scipy_1_10 = sp_version >= parse_version("1.10.0") + # Note: self.bsplines_[0].extrapolate is True for extrapolation in + # ["periodic", "continue"] + if scipy_1_10: + use_sparse = self.sparse_output + kwargs_extrapolate = {"extrapolate": self.bsplines_[0].extrapolate} + else: + use_sparse = self.sparse_output and not self.bsplines_[0].extrapolate + kwargs_extrapolate = dict() + # Note that scipy BSpline returns float64 arrays and converts input # x=X[:, i] to c-contiguous float64. n_out = self.n_features_out_ + n_features * (1 - self.include_bias) @@ -953,7 +979,10 @@ def transform(self, X): dtype = X.dtype else: dtype = np.float64 - XBS = np.zeros((n_samples, n_out), dtype=dtype, order=self.order) + if use_sparse: + output_list = [] + else: + XBS = np.zeros((n_samples, n_out), dtype=dtype, order=self.order) for i in range(n_features): spl = self.bsplines_[i] @@ -972,20 +1001,53 @@ def transform(self, X): else: x = X[:, i] - XBS[:, (i * n_splines) : ((i + 1) * n_splines)] = spl(x) - - else: - xmin = spl.t[degree] - xmax = spl.t[-degree - 1] + if use_sparse: + XBS_sparse = BSpline.design_matrix( + x, spl.t, spl.k, **kwargs_extrapolate + ) + if self.extrapolation == "periodic": + # See the construction of coef in fit. We need to add the last + # degree spline basis function to the first degree ones and + # then drop the last ones. + # Note: See comment about SparseEfficiencyWarning below. + XBS_sparse = XBS_sparse.tolil() + XBS_sparse[:, :degree] += XBS_sparse[:, -degree:] + XBS_sparse = XBS_sparse[:, :-degree] + else: + XBS[:, (i * n_splines) : ((i + 1) * n_splines)] = spl(x) + else: # extrapolation in ("constant", "linear") + xmin, xmax = spl.t[degree], spl.t[-degree - 1] + # spline values at boundaries + f_min, f_max = spl(xmin), spl(xmax) mask = (xmin <= X[:, i]) & (X[:, i] <= xmax) - XBS[mask, (i * n_splines) : ((i + 1) * n_splines)] = spl(X[mask, i]) + if use_sparse: + mask_inv = ~mask + x = X[:, i].copy() + # Set some arbitrary values outside boundary that will be reassigned + # later. + x[mask_inv] = spl.t[self.degree] + XBS_sparse = BSpline.design_matrix(x, spl.t, spl.k) + # Note: Without converting to lil_matrix we would get: + # scipy.sparse._base.SparseEfficiencyWarning: Changing the sparsity + # structure of a csr_matrix is expensive. lil_matrix is more + # efficient. + if np.any(mask_inv): + XBS_sparse = XBS_sparse.tolil() + XBS_sparse[mask_inv, :] = 0 + else: + XBS[mask, (i * n_splines) : ((i + 1) * n_splines)] = spl(X[mask, i]) # Note for extrapolation: # 'continue' is already returned as is by scipy BSplines if self.extrapolation == "error": # BSpline with extrapolate=False does not raise an error, but - # output np.nan. - if np.any(np.isnan(XBS[:, (i * n_splines) : ((i + 1) * n_splines)])): + # outputs np.nan. + if (use_sparse and np.any(np.isnan(XBS_sparse.data))) or ( + not use_sparse + and np.any( + np.isnan(XBS[:, (i * n_splines) : ((i + 1) * n_splines)]) + ) + ): raise ValueError( "X contains values beyond the limits of the knots." ) @@ -995,21 +1057,29 @@ def transform(self, X): # Only the first degree and last degree number of splines # have non-zero values at the boundaries. - # spline values at boundaries - f_min = spl(xmin) - f_max = spl(xmax) mask = X[:, i] < xmin if np.any(mask): - XBS[mask, (i * n_splines) : (i * n_splines + degree)] = f_min[ - :degree - ] + if use_sparse: + # Note: See comment about SparseEfficiencyWarning above. + XBS_sparse = XBS_sparse.tolil() + XBS_sparse[mask, :degree] = f_min[:degree] + + else: + XBS[mask, (i * n_splines) : (i * n_splines + degree)] = f_min[ + :degree + ] mask = X[:, i] > xmax if np.any(mask): - XBS[ - mask, - ((i + 1) * n_splines - degree) : ((i + 1) * n_splines), - ] = f_max[-degree:] + if use_sparse: + # Note: See comment about SparseEfficiencyWarning above. + XBS_sparse = XBS_sparse.tolil() + XBS_sparse[mask, -degree:] = f_max[-degree:] + else: + XBS[ + mask, + ((i + 1) * n_splines - degree) : ((i + 1) * n_splines), + ] = f_max[-degree:] elif self.extrapolation == "linear": # Continue the degree first and degree last spline bases @@ -1018,8 +1088,6 @@ def transform(self, X): # Note that all others have derivative = value = 0 at the # boundaries. - # spline values at boundaries - f_min, f_max = spl(xmin), spl(xmax) # spline derivatives = slopes at boundaries fp_min, fp_max = spl(xmin, nu=1), spl(xmax, nu=1) # Compute the linear continuation. @@ -1030,16 +1098,57 @@ def transform(self, X): for j in range(degree): mask = X[:, i] < xmin if np.any(mask): - XBS[mask, i * n_splines + j] = ( - f_min[j] + (X[mask, i] - xmin) * fp_min[j] - ) + linear_extr = f_min[j] + (X[mask, i] - xmin) * fp_min[j] + if use_sparse: + # Note: See comment about SparseEfficiencyWarning above. + XBS_sparse = XBS_sparse.tolil() + XBS_sparse[mask, j] = linear_extr + else: + XBS[mask, i * n_splines + j] = linear_extr mask = X[:, i] > xmax if np.any(mask): k = n_splines - 1 - j - XBS[mask, i * n_splines + k] = ( - f_max[k] + (X[mask, i] - xmax) * fp_max[k] - ) + linear_extr = f_max[k] + (X[mask, i] - xmax) * fp_max[k] + if use_sparse: + # Note: See comment about SparseEfficiencyWarning above. + XBS_sparse = XBS_sparse.tolil() + XBS_sparse[mask, k : k + 1] = linear_extr[:, None] + else: + XBS[mask, i * n_splines + k] = linear_extr + + if use_sparse: + if not sparse.isspmatrix_csr(XBS_sparse): + XBS_sparse = XBS_sparse.tocsr() + output_list.append(XBS_sparse) + + if use_sparse: + # TODO: Remove this conditional error when the minimum supported version of + # SciPy is 1.9.2 + # `scipy.sparse.hstack` breaks in scipy<1.9.2 + # when `n_features_out_ > max_int32` + max_int32 = np.iinfo(np.int32).max + all_int32 = True + for mat in output_list: + all_int32 &= mat.indices.dtype == np.int32 + if ( + sp_version < parse_version("1.9.2") + and self.n_features_out_ > max_int32 + and all_int32 + ): + raise ValueError( + "In scipy versions `<1.9.2`, the function `scipy.sparse.hstack`" + " produces negative columns when:\n1. The output shape contains" + " `n_cols` too large to be represented by a 32bit signed" + " integer.\n. All sub-matrices to be stacked have indices of" + " dtype `np.int32`.\nTo avoid this error, either use a version" + " of scipy `>=1.9.2` or alter the `SplineTransformer`" + " transformer to produce fewer than 2^31 output features" + ) + XBS = sparse.hstack(output_list) + elif self.sparse_output: + # TODO: Remove ones scipy 1.10 is the minimum version. See comments above. + XBS = sparse.csr_matrix(XBS) if self.include_bias: return XBS diff --git a/sklearn/preprocessing/tests/test_polynomial.py b/sklearn/preprocessing/tests/test_polynomial.py index 727b31b793b1d..1062a3da820e7 100644 --- a/sklearn/preprocessing/tests/test_polynomial.py +++ b/sklearn/preprocessing/tests/test_polynomial.py @@ -35,6 +35,22 @@ def is_c_contiguous(a): assert np.isfortran(est(order="F").fit_transform(X)) +@pytest.mark.parametrize( + "params, err_msg", + [ + ({"knots": [[1]]}, r"Number of knots, knots.shape\[0\], must be >= 2."), + ({"knots": [[1, 1], [2, 2]]}, r"knots.shape\[1\] == n_features is violated"), + ({"knots": [[1], [0]]}, "knots must be sorted without duplicates."), + ], +) +def test_spline_transformer_input_validation(params, err_msg): + """Test that we raise errors for invalid input in SplineTransformer.""" + X = [[1], [2]] + + with pytest.raises(ValueError, match=err_msg): + SplineTransformer(**params).fit(X) + + @pytest.mark.parametrize("extrapolation", ["continue", "periodic"]) def test_spline_transformer_integer_knots(extrapolation): """Test that SplineTransformer accepts integer value knot positions.""" @@ -109,8 +125,7 @@ def test_split_transform_feature_names_extrapolation_degree(extrapolation, degre def test_spline_transformer_unity_decomposition(degree, n_knots, knots, extrapolation): """Test that B-splines are indeed a decomposition of unity. - Splines basis functions must sum up to 1 per row, if we stay in between - boundaries. + Splines basis functions must sum up to 1 per row, if we stay in between boundaries. """ X = np.linspace(0, 1, 100)[:, None] # make the boundaries 0 and 1 part of X_train, for sure. @@ -178,8 +193,7 @@ def test_spline_transformer_linear_regression(bias, intercept): def test_spline_transformer_get_base_knot_positions( knots, n_knots, sample_weight, expected_knots ): - # Check the behaviour to find the positions of the knots with and without - # `sample_weight` + """Check the behaviour to find knot positions with and without sample_weight.""" X = np.array([[0, 2], [0, 2], [2, 2], [3, 3], [4, 6], [5, 8], [6, 14]]) base_knots = SplineTransformer._get_base_knot_positions( X=X, knots=knots, n_knots=n_knots, sample_weight=sample_weight @@ -238,9 +252,7 @@ def test_spline_transformer_periodic_spline_backport(): def test_spline_transformer_periodic_splines_periodicity(): - """ - Test if shifted knots result in the same transformation up to permutation. - """ + """Test if shifted knots result in the same transformation up to permutation.""" X = np.linspace(0, 10, 101)[:, None] transformer_1 = SplineTransformer( @@ -349,9 +361,10 @@ def test_spline_transformer_extrapolation(bias, intercept, degree): n_knots=4, degree=degree, include_bias=bias, extrapolation="error" ) splt.fit(X) - with pytest.raises(ValueError): + msg = "X contains values beyond the limits of the knots" + with pytest.raises(ValueError, match=msg): splt.transform([[-10]]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=msg): splt.transform([[5]]) @@ -375,12 +388,94 @@ def test_spline_transformer_kbindiscretizer(): assert_allclose(splines, kbins, rtol=1e-13) +@pytest.mark.skipif( + sp_version < parse_version("1.8.0"), + reason="The option `sparse_output` is available as of scipy 1.8.0", +) +@pytest.mark.parametrize("degree", range(1, 3)) +@pytest.mark.parametrize("knots", ["uniform", "quantile"]) +@pytest.mark.parametrize( + "extrapolation", ["error", "constant", "linear", "continue", "periodic"] +) +@pytest.mark.parametrize("include_bias", [False, True]) +def test_spline_transformer_sparse_output( + degree, knots, extrapolation, include_bias, global_random_seed +): + rng = np.random.RandomState(global_random_seed) + X = rng.randn(200).reshape(40, 5) + + splt_dense = SplineTransformer( + degree=degree, + knots=knots, + extrapolation=extrapolation, + include_bias=include_bias, + sparse_output=False, + ) + splt_sparse = SplineTransformer( + degree=degree, + knots=knots, + extrapolation=extrapolation, + include_bias=include_bias, + sparse_output=True, + ) + + splt_dense.fit(X) + splt_sparse.fit(X) + + assert sparse.isspmatrix_csr(splt_sparse.transform(X)) + assert_allclose(splt_dense.transform(X), splt_sparse.transform(X).toarray()) + + # extrapolation regime + X_min = np.amin(X, axis=0) + X_max = np.amax(X, axis=0) + X_extra = np.r_[ + np.linspace(X_min - 5, X_min, 10), np.linspace(X_max, X_max + 5, 10) + ] + if extrapolation == "error": + msg = "X contains values beyond the limits of the knots" + with pytest.raises(ValueError, match=msg): + splt_dense.transform(X_extra) + msg = "Out of bounds" + with pytest.raises(ValueError, match=msg): + splt_sparse.transform(X_extra) + else: + assert_allclose( + splt_dense.transform(X_extra), splt_sparse.transform(X_extra).toarray() + ) + + +@pytest.mark.skipif( + sp_version >= parse_version("1.8.0"), + reason="The option `sparse_output` is available as of scipy 1.8.0", +) +def test_spline_transformer_sparse_output_raise_error_for_old_scipy(): + """Test that SplineTransformer with sparse=True raises for scipy<1.8.0.""" + X = [[1], [2]] + with pytest.raises(ValueError, match="scipy>=1.8.0"): + SplineTransformer(sparse_output=True).fit(X) + + @pytest.mark.parametrize("n_knots", [5, 10]) @pytest.mark.parametrize("include_bias", [True, False]) -@pytest.mark.parametrize("degree", [3, 5]) -def test_spline_transformer_n_features_out(n_knots, include_bias, degree): +@pytest.mark.parametrize("degree", [3, 4]) +@pytest.mark.parametrize( + "extrapolation", ["error", "constant", "linear", "continue", "periodic"] +) +@pytest.mark.parametrize("sparse_output", [False, True]) +def test_spline_transformer_n_features_out( + n_knots, include_bias, degree, extrapolation, sparse_output +): """Test that transform results in n_features_out_ features.""" - splt = SplineTransformer(n_knots=n_knots, degree=degree, include_bias=include_bias) + if sparse_output and sp_version < parse_version("1.8.0"): + pytest.skip("The option `sparse_output` is available as of scipy 1.8.0") + + splt = SplineTransformer( + n_knots=n_knots, + degree=degree, + include_bias=include_bias, + extrapolation=extrapolation, + sparse_output=sparse_output, + ) X = np.linspace(0, 1, 10)[:, None] splt.fit(X)