Skip to content

Commit

Permalink
Merge pull request #362 from thomas0125/fix_gridsearch
Browse files Browse the repository at this point in the history
Fix GridSearchCV to support sklearn>=1.3.0
  • Loading branch information
Eleven1Liu committed Mar 7, 2024
2 parents 7106704 + 5b04bc5 commit d9307a0
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
15 changes: 7 additions & 8 deletions libmultilabel/linear/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,32 +109,31 @@ class GridSearchCV(sklearn.model_selection.GridSearchCV):
The usage is similar to sklearn's, except that the parameter ``scoring`` is unavailable. Instead, specify ``scoring_metric`` in ``MultiLabelEstimator`` in the Pipeline.
Args:
pipeline (sklearn.pipeline.Pipeline): A sklearn Pipeline for grid search.
estimator (estimator object): An estimator for grid search.
param_grid (dict): Search space for a grid search containing a dictionary of
parameters and their corresponding list of candidate values.
n_jobs (int, optional): Number of CPU cores to use in parallel. Defaults to None.
"""

_required_parameters = ["pipeline", "param_grid"]
_required_parameters = ["estimator", "param_grid"]

def __init__(self, pipeline: sklearn.pipeline.Pipeline, param_grid: dict, n_jobs=None, **kwargs):
assert isinstance(pipeline, sklearn.pipeline.Pipeline)
def __init__(self, estimator, param_grid: dict, n_jobs=None, **kwargs):
if n_jobs is not None and n_jobs > 1:
param_grid = self._set_singlecore_options(pipeline, param_grid)
param_grid = self._set_singlecore_options(estimator, param_grid)
if "scoring" in kwargs.keys():
raise ValueError(
"Please specify the validation metric with `MultiLabelEstimator.scoring_metric` in the Pipeline instead of using the parameter `scoring`."
)

super().__init__(estimator=pipeline, n_jobs=n_jobs, param_grid=param_grid, **kwargs)
super().__init__(estimator=estimator, n_jobs=n_jobs, param_grid=param_grid, **kwargs)

def _set_singlecore_options(self, pipeline: sklearn.pipeline.Pipeline, param_grid: dict):
def _set_singlecore_options(self, estimator, param_grid: dict):
"""Set liblinear options to `-m 1`. The grid search option `n_jobs`
runs multiple processes in parallel. Using multithreaded liblinear
in conjunction with grid search oversubscribes the CPU and degrades
performance significantly.
"""
params = pipeline.get_params()
params = estimator.get_params()
for name, transform in params.items():
if isinstance(transform, MultiLabelEstimator):
regex = r"-m \d+"
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ liblinear-multicore
numba
pandas>1.3.0
PyYAML
scikit-learn==1.2.2
scikit-learn
scipy
tqdm
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = libmultilabel
version = 0.6.1
version = 0.6.2
author = LibMultiLabel Team
license = MIT License
license_file = LICENSE
Expand Down Expand Up @@ -29,7 +29,7 @@ install_requires =
numba
pandas>1.3.0
PyYAML
scikit-learn==1.2.2
scikit-learn
scipy
tqdm

Expand Down

0 comments on commit d9307a0

Please sign in to comment.