Skip to content

Commit e5c6590

Browse files
lorentzenchr, jjerphan, ogrisel
authored
ENH add sparse output to SplineTransformer (#24145)
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 314f7ba commit e5c6590

File tree

3 files changed

+254
-45
lines changed

3 files changed

+254
-45
lines changed

doc/whats_new/v1.3.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,11 @@ Changelog
566566
categorical encoding based on target mean conditioned on the value of the
567567
category. :pr:`25334` by `Thomas Fan`_.
568568

569+
- |Enhancement| A new parameter `sparse_output` was added to
570+
:class:`SplineTransformer`, available as of SciPy 1.8. If `sparse_output=True`,
571+
:class:`SplineTransformer` returns a sparse CSR matrix.
572+
:pr:`24145` by :user:`Christian Lorentzen <lorentzenchr>`.
573+
569574
- |Enhancement| Adds a `feature_name_combiner` parameter to
570575
:class:`preprocessing.OneHotEncoder`. This specifies a custom callable to create
571576
feature names to be returned by :meth:`get_feature_names_out`.

sklearn/preprocessing/_polynomial.py

Lines changed: 142 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313

1414
from ..base import BaseEstimator, TransformerMixin
1515
from ..utils import check_array
16+
from ..utils.fixes import sp_version, parse_version
1617
from ..utils.validation import check_is_fitted, FLOAT_DTYPES, _check_sample_weight
1718
from ..utils.validation import _check_feature_names_in
1819
from ..utils._param_validation import Interval, StrOptions
1920
from ..utils.stats import _weighted_percentile
20-
from ..utils.fixes import sp_version, parse_version
2121

2222
from ._csr_polynomial_expansion import (
2323
_csr_polynomial_expansion,
@@ -574,8 +574,6 @@ def transform(self, X):
574574
return XP
575575

576576

577-
# TODO:
578-
# - sparse support (either scipy or own cython solution)?
579577
class SplineTransformer(TransformerMixin, BaseEstimator):
580578
"""Generate univariate B-spline bases for features.
581579
@@ -635,8 +633,14 @@ class SplineTransformer(TransformerMixin, BaseEstimator):
635633
i.e. a column of ones. It acts as an intercept term in a linear models.
636634
637635
order : {'C', 'F'}, default='C'
638-
Order of output array. 'F' order is faster to compute, but may slow
639-
down subsequent estimators.
636+
Order of output array in the dense case. `'F'` order is faster to compute, but
637+
may slow down subsequent estimators.
638+
639+
sparse_output : bool, default=False
640+
If `sparse_output=True`, returns a sparse CSR matrix, else a dense array. This
641+
option is only available with `scipy>=1.8`.
642+
643+
.. versionadded:: 1.3
640644
641645
Attributes
642646
----------
@@ -699,6 +703,7 @@ class SplineTransformer(TransformerMixin, BaseEstimator):
699703
],
700704
"include_bias": ["boolean"],
701705
"order": [StrOptions({"C", "F"})],
706+
"sparse_output": ["boolean"],
702707
}
703708

704709
def __init__(
@@ -710,13 +715,15 @@ def __init__(
710715
extrapolation="constant",
711716
include_bias=True,
712717
order="C",
718+
sparse_output=False,
713719
):
714720
self.n_knots = n_knots
715721
self.degree = degree
716722
self.knots = knots
717723
self.extrapolation = extrapolation
718724
self.include_bias = include_bias
719725
self.order = order
726+
self.sparse_output = sparse_output
720727

721728
@staticmethod
722729
def _get_base_knot_positions(X, n_knots=10, knots="uniform", sample_weight=None):
@@ -843,6 +850,12 @@ def fit(self, X, y=None, sample_weight=None):
843850
elif not np.all(np.diff(base_knots, axis=0) > 0):
844851
raise ValueError("knots must be sorted without duplicates.")
845852

853+
if self.sparse_output and sp_version < parse_version("1.8.0"):
854+
raise ValueError(
855+
"Option sparse_output=True is only available with scipy>=1.8.0, "
856+
f"but here scipy=={sp_version} is used."
857+
)
858+
846859
# number of knots for base interval
847860
n_knots = base_knots.shape[0]
848861

@@ -934,7 +947,7 @@ def transform(self, X):
934947
935948
Returns
936949
-------
937-
XBS : ndarray of shape (n_samples, n_features * n_splines)
950+
XBS : {ndarray, sparse matrix} of shape (n_samples, n_features * n_splines)
938951
The matrix of features, where n_splines is the number of bases
939952
elements of the B-splines, n_knots + degree - 1.
940953
"""
@@ -946,14 +959,30 @@ def transform(self, X):
946959
n_splines = self.bsplines_[0].c.shape[1]
947960
degree = self.degree
948961

962+
# TODO: Remove this condition, once scipy 1.10 is the minimum version.
963+
# Only scipy >= 1.10 supports design_matrix(.., extrapolate=..).
964+
# The default (implicit in scipy < 1.10) is extrapolate=False.
965+
scipy_1_10 = sp_version >= parse_version("1.10.0")
966+
# Note: self.bsplines_[0].extrapolate is True for extrapolation in
967+
# ["periodic", "continue"]
968+
if scipy_1_10:
969+
use_sparse = self.sparse_output
970+
kwargs_extrapolate = {"extrapolate": self.bsplines_[0].extrapolate}
971+
else:
972+
use_sparse = self.sparse_output and not self.bsplines_[0].extrapolate
973+
kwargs_extrapolate = dict()
974+
949975
# Note that scipy BSpline returns float64 arrays and converts input
950976
# x=X[:, i] to c-contiguous float64.
951977
n_out = self.n_features_out_ + n_features * (1 - self.include_bias)
952978
if X.dtype in FLOAT_DTYPES:
953979
dtype = X.dtype
954980
else:
955981
dtype = np.float64
956-
XBS = np.zeros((n_samples, n_out), dtype=dtype, order=self.order)
982+
if use_sparse:
983+
output_list = []
984+
else:
985+
XBS = np.zeros((n_samples, n_out), dtype=dtype, order=self.order)
957986

958987
for i in range(n_features):
959988
spl = self.bsplines_[i]
@@ -972,20 +1001,53 @@ def transform(self, X):
9721001
else:
9731002
x = X[:, i]
9741003

975-
XBS[:, (i * n_splines) : ((i + 1) * n_splines)] = spl(x)
976-
977-
else:
978-
xmin = spl.t[degree]
979-
xmax = spl.t[-degree - 1]
1004+
if use_sparse:
1005+
XBS_sparse = BSpline.design_matrix(
1006+
x, spl.t, spl.k, **kwargs_extrapolate
1007+
)
1008+
if self.extrapolation == "periodic":
1009+
# See the construction of coef in fit. We need to add the last
1010+
# degree spline basis function to the first degree ones and
1011+
# then drop the last ones.
1012+
# Note: See comment about SparseEfficiencyWarning below.
1013+
XBS_sparse = XBS_sparse.tolil()
1014+
XBS_sparse[:, :degree] += XBS_sparse[:, -degree:]
1015+
XBS_sparse = XBS_sparse[:, :-degree]
1016+
else:
1017+
XBS[:, (i * n_splines) : ((i + 1) * n_splines)] = spl(x)
1018+
else: # extrapolation in ("constant", "linear")
1019+
xmin, xmax = spl.t[degree], spl.t[-degree - 1]
1020+
# spline values at boundaries
1021+
f_min, f_max = spl(xmin), spl(xmax)
9801022
mask = (xmin <= X[:, i]) & (X[:, i] <= xmax)
981-
XBS[mask, (i * n_splines) : ((i + 1) * n_splines)] = spl(X[mask, i])
1023+
if use_sparse:
1024+
mask_inv = ~mask
1025+
x = X[:, i].copy()
1026+
# Set some arbitrary values outside boundary that will be reassigned
1027+
# later.
1028+
x[mask_inv] = spl.t[self.degree]
1029+
XBS_sparse = BSpline.design_matrix(x, spl.t, spl.k)
1030+
# Note: Without converting to lil_matrix we would get:
1031+
# scipy.sparse._base.SparseEfficiencyWarning: Changing the sparsity
1032+
# structure of a csr_matrix is expensive. lil_matrix is more
1033+
# efficient.
1034+
if np.any(mask_inv):
1035+
XBS_sparse = XBS_sparse.tolil()
1036+
XBS_sparse[mask_inv, :] = 0
1037+
else:
1038+
XBS[mask, (i * n_splines) : ((i + 1) * n_splines)] = spl(X[mask, i])
9821039

9831040
# Note for extrapolation:
9841041
# 'continue' is already returned as is by scipy BSplines
9851042
if self.extrapolation == "error":
9861043
# BSpline with extrapolate=False does not raise an error, but
987-
# output np.nan.
988-
if np.any(np.isnan(XBS[:, (i * n_splines) : ((i + 1) * n_splines)])):
1044+
# outputs np.nan.
1045+
if (use_sparse and np.any(np.isnan(XBS_sparse.data))) or (
1046+
not use_sparse
1047+
and np.any(
1048+
np.isnan(XBS[:, (i * n_splines) : ((i + 1) * n_splines)])
1049+
)
1050+
):
9891051
raise ValueError(
9901052
"X contains values beyond the limits of the knots."
9911053
)
@@ -995,21 +1057,29 @@ def transform(self, X):
9951057
# Only the first degree and last degree number of splines
9961058
# have non-zero values at the boundaries.
9971059

998-
# spline values at boundaries
999-
f_min = spl(xmin)
1000-
f_max = spl(xmax)
10011060
mask = X[:, i] < xmin
10021061
if np.any(mask):
1003-
XBS[mask, (i * n_splines) : (i * n_splines + degree)] = f_min[
1004-
:degree
1005-
]
1062+
if use_sparse:
1063+
# Note: See comment about SparseEfficiencyWarning above.
1064+
XBS_sparse = XBS_sparse.tolil()
1065+
XBS_sparse[mask, :degree] = f_min[:degree]
1066+
1067+
else:
1068+
XBS[mask, (i * n_splines) : (i * n_splines + degree)] = f_min[
1069+
:degree
1070+
]
10061071

10071072
mask = X[:, i] > xmax
10081073
if np.any(mask):
1009-
XBS[
1010-
mask,
1011-
((i + 1) * n_splines - degree) : ((i + 1) * n_splines),
1012-
] = f_max[-degree:]
1074+
if use_sparse:
1075+
# Note: See comment about SparseEfficiencyWarning above.
1076+
XBS_sparse = XBS_sparse.tolil()
1077+
XBS_sparse[mask, -degree:] = f_max[-degree:]
1078+
else:
1079+
XBS[
1080+
mask,
1081+
((i + 1) * n_splines - degree) : ((i + 1) * n_splines),
1082+
] = f_max[-degree:]
10131083

10141084
elif self.extrapolation == "linear":
10151085
# Continue the degree first and degree last spline bases
@@ -1018,8 +1088,6 @@ def transform(self, X):
10181088
# Note that all others have derivative = value = 0 at the
10191089
# boundaries.
10201090

1021-
# spline values at boundaries
1022-
f_min, f_max = spl(xmin), spl(xmax)
10231091
# spline derivatives = slopes at boundaries
10241092
fp_min, fp_max = spl(xmin, nu=1), spl(xmax, nu=1)
10251093
# Compute the linear continuation.
@@ -1030,16 +1098,57 @@ def transform(self, X):
10301098
for j in range(degree):
10311099
mask = X[:, i] < xmin
10321100
if np.any(mask):
1033-
XBS[mask, i * n_splines + j] = (
1034-
f_min[j] + (X[mask, i] - xmin) * fp_min[j]
1035-
)
1101+
linear_extr = f_min[j] + (X[mask, i] - xmin) * fp_min[j]
1102+
if use_sparse:
1103+
# Note: See comment about SparseEfficiencyWarning above.
1104+
XBS_sparse = XBS_sparse.tolil()
1105+
XBS_sparse[mask, j] = linear_extr
1106+
else:
1107+
XBS[mask, i * n_splines + j] = linear_extr
10361108

10371109
mask = X[:, i] > xmax
10381110
if np.any(mask):
10391111
k = n_splines - 1 - j
1040-
XBS[mask, i * n_splines + k] = (
1041-
f_max[k] + (X[mask, i] - xmax) * fp_max[k]
1042-
)
1112+
linear_extr = f_max[k] + (X[mask, i] - xmax) * fp_max[k]
1113+
if use_sparse:
1114+
# Note: See comment about SparseEfficiencyWarning above.
1115+
XBS_sparse = XBS_sparse.tolil()
1116+
XBS_sparse[mask, k : k + 1] = linear_extr[:, None]
1117+
else:
1118+
XBS[mask, i * n_splines + k] = linear_extr
1119+
1120+
if use_sparse:
1121+
if not sparse.isspmatrix_csr(XBS_sparse):
1122+
XBS_sparse = XBS_sparse.tocsr()
1123+
output_list.append(XBS_sparse)
1124+
1125+
if use_sparse:
1126+
# TODO: Remove this conditional error when the minimum supported version of
1127+
# SciPy is 1.9.2
1128+
# `scipy.sparse.hstack` breaks in scipy<1.9.2
1129+
# when `n_features_out_ > max_int32`
1130+
max_int32 = np.iinfo(np.int32).max
1131+
all_int32 = True
1132+
for mat in output_list:
1133+
all_int32 &= mat.indices.dtype == np.int32
1134+
if (
1135+
sp_version < parse_version("1.9.2")
1136+
and self.n_features_out_ > max_int32
1137+
and all_int32
1138+
):
1139+
raise ValueError(
1140+
"In scipy versions `<1.9.2`, the function `scipy.sparse.hstack`"
1141+
" produces negative columns when:\n1. The output shape contains"
1142+
" `n_cols` too large to be represented by a 32bit signed"
1143+
" integer.\n. All sub-matrices to be stacked have indices of"
1144+
" dtype `np.int32`.\nTo avoid this error, either use a version"
1145+
" of scipy `>=1.9.2` or alter the `SplineTransformer`"
1146+
" transformer to produce fewer than 2^31 output features"
1147+
)
1148+
XBS = sparse.hstack(output_list)
1149+
elif self.sparse_output:
1150+
# TODO: Remove once scipy 1.10 is the minimum version. See comments above.
1151+
XBS = sparse.csr_matrix(XBS)
10431152

10441153
if self.include_bias:
10451154
return XBS

0 commit comments

Comments
 (0)