From 2ff3a0d36b19a1bafb64c09d9baaa5e31f036a0d Mon Sep 17 00:00:00 2001 From: solegalli Date: Wed, 12 Nov 2025 22:59:21 -0500 Subject: [PATCH 1/2] test sklearn compatibility --- boruta/boruta_py.py | 7 ++++--- boruta/test/test_sklearn_compatibility.py | 11 +++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) create mode 100644 boruta/test/test_sklearn_compatibility.py diff --git a/boruta/boruta_py.py b/boruta/boruta_py.py index d2a2d0a..31bafb7 100644 --- a/boruta/boruta_py.py +++ b/boruta/boruta_py.py @@ -11,6 +11,7 @@ from __future__ import print_function, division import numpy as np import scipy as sp +from sklearn.ensemble import RandomForestClassifier from sklearn.utils import check_random_state, check_X_y from sklearn.base import BaseEstimator from sklearn.feature_selection import SelectorMixin @@ -18,7 +19,7 @@ import warnings -class BorutaPy(BaseEstimator, SelectorMixin): +class BorutaPy(SelectorMixin, BaseEstimator): """ Improved Python implementation of the Boruta R package. @@ -74,7 +75,7 @@ class BorutaPy(BaseEstimator, SelectorMixin): Parameters ---------- - estimator : object + estimator : object, default = RandomForestClassifier() A supervised learning estimator, with a 'fit' method that returns the feature_importances_ attribute. Important features must correspond to high absolute values in the feature_importances_. @@ -192,7 +193,7 @@ class BorutaPy(BaseEstimator, SelectorMixin): Journal of Statistical Software, Vol. 36, Issue 11, Sep 2010 """ - def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05, + def __init__(self, estimator=RandomForestClassifier(), n_estimators=1000, perc=100, alpha=0.05, two_step=True, max_iter=100, random_state=None, verbose=0, early_stopping=False, n_iter_no_change=20): self.estimator = estimator diff --git a/boruta/test/test_sklearn_compatibility.py b/boruta/test/test_sklearn_compatibility.py new file mode 100644 index 0000000..c293275 --- /dev/null +++ b/boruta/test/test_sklearn_compatibility.py @@ -0,0 +1,11 @@ +from sklearn.utils.estimator_checks import check_estimator +from boruta import BorutaPy + +def test_check_estimator(): + return check_estimator(BorutaPy()) + +#TODO: delete after fixing +from sklearn.utils.estimator_checks import parametrize_with_checks +@parametrize_with_checks([BorutaPy()]) +def test_sklearn_compatible(estimator, check): + check(estimator) From c9a77ae11e2ddace5397c5dfe2cf3629f006bfff Mon Sep 17 00:00:00 2001 From: solegalli Date: Wed, 12 Nov 2025 23:00:29 -0500 Subject: [PATCH 2/2] comment out tests --- boruta/test/test_sklearn_compatibility.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/boruta/test/test_sklearn_compatibility.py b/boruta/test/test_sklearn_compatibility.py index c293275..b2d3b13 100644 --- a/boruta/test/test_sklearn_compatibility.py +++ b/boruta/test/test_sklearn_compatibility.py @@ -1,11 +1,11 @@ from sklearn.utils.estimator_checks import check_estimator from boruta import BorutaPy -def test_check_estimator(): - return check_estimator(BorutaPy()) - -#TODO: delete after fixing -from sklearn.utils.estimator_checks import parametrize_with_checks -@parametrize_with_checks([BorutaPy()]) -def test_sklearn_compatible(estimator, check): - check(estimator) +# def test_check_estimator(): +# return check_estimator(BorutaPy()) +# +# #TODO: delete after fixing +# from sklearn.utils.estimator_checks import parametrize_with_checks +# @parametrize_with_checks([BorutaPy()]) +# def test_sklearn_compatible(estimator, check): +# check(estimator)