Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions boruta/boruta_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
from __future__ import print_function, division
import numpy as np
import scipy as sp
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import check_random_state, check_X_y
from sklearn.base import BaseEstimator
from sklearn.feature_selection import SelectorMixin
from sklearn.utils.validation import check_is_fitted
import warnings


class BorutaPy(BaseEstimator, SelectorMixin):
class BorutaPy(SelectorMixin, BaseEstimator):
"""
Improved Python implementation of the Boruta R package.

Expand Down Expand Up @@ -74,7 +75,7 @@ class BorutaPy(BaseEstimator, SelectorMixin):
Parameters
----------

estimator : object
estimator : object, default = RandomForestClassifier()
A supervised learning estimator, with a 'fit' method that returns the
feature_importances_ attribute. Important features must correspond to
high absolute values in the feature_importances_.
Expand Down Expand Up @@ -192,7 +193,7 @@ class BorutaPy(BaseEstimator, SelectorMixin):
Journal of Statistical Software, Vol. 36, Issue 11, Sep 2010
"""

def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
def __init__(self, estimator=RandomForestClassifier(), n_estimators=1000, perc=100, alpha=0.05,
two_step=True, max_iter=100, random_state=None, verbose=0,
early_stopping=False, n_iter_no_change=20):
self.estimator = estimator
Expand Down
11 changes: 11 additions & 0 deletions boruta/test/test_sklearn_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from sklearn.utils.estimator_checks import check_estimator
from boruta import BorutaPy

# def test_check_estimator():
# return check_estimator(BorutaPy())
#
# #TODO: delete after fixing
# from sklearn.utils.estimator_checks import parametrize_with_checks
# @parametrize_with_checks([BorutaPy()])
# def test_sklearn_compatible(estimator, check):
# check(estimator)