-
Notifications
You must be signed in to change notification settings - Fork 3.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TST make sklearn integration test compatible with 0.24 #3533
Changes from 6 commits
55d35b6
c38d478
f3fab61
93be6f4
d1b80f3
debdb08
8a1cc9d
a8d7473
6ecf05d
3e984b4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,20 +4,21 @@ | |
import math | ||
import os | ||
import unittest | ||
import warnings | ||
|
||
import lightgbm as lgb | ||
import numpy as np | ||
import pytest | ||
from sklearn import __version__ as sk_version | ||
from sklearn.base import clone | ||
from sklearn.datasets import load_svmlight_file, make_multilabel_classification | ||
from sklearn.exceptions import SkipTestWarning | ||
from sklearn.metrics import log_loss, mean_squared_error | ||
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, train_test_split | ||
from sklearn.multioutput import (MultiOutputClassifier, ClassifierChain, MultiOutputRegressor, | ||
RegressorChain) | ||
from sklearn.utils.estimator_checks import (_yield_all_checks, SkipTest, | ||
check_parameters_default_constructible) | ||
from sklearn.utils.estimator_checks import ( | ||
check_parameters_default_constructible, | ||
parametrize_with_checks, | ||
) | ||
from sklearn.utils.validation import check_is_fitted | ||
|
||
from .utils import load_boston, load_breast_cancer, load_digits, load_iris, load_linnerud | ||
|
@@ -452,29 +453,6 @@ def test_feature_importances_type(self): | |
importance_gain_top1 = sorted(importances_gain, reverse=True)[0] | ||
self.assertNotEqual(importance_split_top1, importance_gain_top1) | ||
|
||
# sklearn <0.19 cannot accept instance, but many tests could be passed only with min_data=1 and min_data_in_bin=1 | ||
@unittest.skipIf(sk_version < '0.19.0', 'scikit-learn version is less than 0.19') | ||
def test_sklearn_integration(self): | ||
# we cannot use `check_estimator` directly since there is no skip test mechanism | ||
for name, estimator in ((lgb.sklearn.LGBMClassifier.__name__, lgb.sklearn.LGBMClassifier), | ||
(lgb.sklearn.LGBMRegressor.__name__, lgb.sklearn.LGBMRegressor)): | ||
check_parameters_default_constructible(name, estimator) | ||
# we cannot leave default params (see https://github.com/microsoft/LightGBM/issues/833) | ||
estimator = estimator(min_child_samples=1, min_data_in_bin=1) | ||
for check in _yield_all_checks(name, estimator): | ||
check_name = check.func.__name__ if hasattr(check, 'func') else check.__name__ | ||
if check_name == 'check_estimators_nan_inf': | ||
continue # skip test because LightGBM deals with nan | ||
elif check_name == "check_no_attributes_set_in_init": | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should be handled by the `allow_nan` tag. |
||
# skip test because scikit-learn incorrectly asserts that | ||
# private attributes cannot be set in __init__ | ||
# (see https://github.com/microsoft/LightGBM/issues/2628) | ||
continue | ||
try: | ||
check(name, estimator) | ||
except SkipTest as message: | ||
warnings.warn(message, SkipTestWarning) | ||
|
||
@unittest.skipIf(not lgb.compat.PANDAS_INSTALLED, 'pandas is not installed') | ||
def test_pandas_categorical(self): | ||
import pandas as pd | ||
|
@@ -1149,3 +1127,24 @@ def test_check_is_fitted(self): | |
rnk.fit(X, y, group=np.ones(X.shape[0])) | ||
for model in models: | ||
check_is_fitted(model) | ||
|
||
|
||
def _tested_estimators(): | ||
for Estimator in [lgb.sklearn.LGBMClassifier, lgb.sklearn.LGBMRegressor]: | ||
yield Estimator() | ||
|
||
|
||
@pytest.mark.skipif( | ||
sk_version < "0.23.0", reason="scikit-learn version is less than 0.23" | ||
) | ||
@parametrize_with_checks(list(_tested_estimators())) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Only the … [comment truncated in this view] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. However, it needs to be a plain function, not one embedded within a class. |
||
def test_sklearn_integration(estimator, check, request): | ||
estimator.set_params(min_child_samples=1, min_data_in_bin=1) | ||
check(estimator) | ||
|
||
|
||
@pytest.mark.parametrize("estimator", list(_tested_estimators())) | ||
def test_parameters_default_constructible(estimator): | ||
name, Estimator = estimator.__class__.__name__, estimator.__class__ | ||
# Test that estimators are default-constructible | ||
check_parameters_default_constructible(name, Estimator) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be handled by the `allow_nan` tag.