From 6de26d7ed1eba30360b5e4a7f72d2dbeab47a072 Mon Sep 17 00:00:00 2001
From: Matthias Feurer
Date: Sat, 2 May 2015 13:21:17 +0200
Subject: [PATCH] Feature: weighting for imbalanced classes

---
 ParamSklearn/base.py                               |   4 +-
 ParamSklearn/classification.py                     |  16 ++-
 .../components/classification/adaboost.py          |   5 +-
 .../classification/decision_tree.py                |   4 +-
 .../components/classification/extra_trees.py       |   4 +-
 .../{liblinear.py => liblinear_svc.py}             |   0
 .../classification/passive_aggresive.py            |   1 -
 .../classification/random_forest.py                |   4 +-
 .../components/classification/ridge.py             |   7 +-
 ParamSklearn/components/classification/sgd.py      |   7 +-
 .../components/preprocessing/balancing.py          | 113 +++++++++++++++++
 ...extra_trees_preproc_for_classification.py}      |   8 +-
 .../components/preprocessing/imputation.py         |   9 +-
 ...inear.py => liblinear_svc_preprocessor.py}      |   5 +-
 .../preprocessing/random_trees_embedding.py        |   2 +-
 .../components/preprocessing/select_rates.py       |  11 +-
 ParamSklearn/implementations/Imputation.py         |  11 +-
 misc/support_for_imbalanced_classes.txt            |  22 ++++
 source/first_steps.rst                             |   4 +-
 .../classification/test_liblinear.py               |   2 +-
 .../preprocessing/test_balancing.py                | 117 ++++++++++++++++++
 .../preprocessing/test_extra_trees.py              |   2 +-
 .../preprocessing/test_liblinear.py                |   2 +-
 tests/test_classification.py                       |  40 +++---
 ..._create_searchspace_util_classification.py      |   2 +-
 tests/test_textclassification.py                   |   2 +-
 26 files changed, 341 insertions(+), 63 deletions(-)
 rename ParamSklearn/components/classification/{liblinear.py => liblinear_svc.py} (100%)
 create mode 100644 ParamSklearn/components/preprocessing/balancing.py
 rename ParamSklearn/components/preprocessing/{extra_trees.py => extra_trees_preproc_for_classification.py} (96%)
 rename ParamSklearn/components/preprocessing/{liblinear.py => liblinear_svc_preprocessor.py} (97%)
 create mode 100644 misc/support_for_imbalanced_classes.txt
 create mode 100644 tests/components/preprocessing/test_balancing.py

diff --git a/ParamSklearn/base.py b/ParamSklearn/base.py
index 0b728faadc..91b9651596 100644
--- a/ParamSklearn/base.py
+++ b/ParamSklearn/base.py
@@ -88,8 +88,8 @@ def fit(self, X, Y, fit_params=None, init_params=None):

         # seperate the init parameters for the single methods
         init_params_per_method = defaultdict(dict)
-        if init_params is not None:
-            for init_param, value in init_params:
+        if init_params is not None and len(init_params) != 0:
+            for init_param, value in init_params.items():
                 method, param = init_param.split(":")
                 init_params_per_method[method][param] = value
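The base.py hunk fixes a real bug in addition to handling empty dicts: iterating a dict directly yields bare key strings, so `for init_param, value in init_params` could never unpack, while `.items()` yields the (key, value) pairs the loop expects. A minimal sketch of the "<component>:<parameter>" routing convention this loop implements — the example values here are illustrative, not part of the patch:

    from collections import defaultdict

    init_params = {'ridge:class_weight': {0: 0.4, 1: 1.6},
                   'liblinear_svc:class_weight': 'auto'}

    init_params_per_method = defaultdict(dict)
    for init_param, value in init_params.items():
        # 'ridge:class_weight' -> ('ridge', 'class_weight')
        method, param = init_param.split(":")
        init_params_per_method[method][param] = value

    # init_params_per_method['ridge'] == {'class_weight': {0: 0.4, 1: 1.6}}
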
diff --git a/ParamSklearn/classification.py b/ParamSklearn/classification.py
index f7851624b8..6f21555878 100644
--- a/ParamSklearn/classification.py
+++ b/ParamSklearn/classification.py
@@ -11,6 +11,7 @@
 from ParamSklearn import components as components
 from ParamSklearn.base import ParamSklearnBaseEstimator
 from ParamSklearn.util import SPARSE
+from ParamSklearn.components.preprocessing.balancing import Balancing
 import ParamSklearn.create_searchspace_util

@@ -61,9 +62,19 @@ class ParamSklearnClassifier(ClassifierMixin, ParamSklearnBaseEstimator):
     """

     def fit(self, X, Y, fit_params=None, init_params=None):
+        self.num_targets = 1 if len(Y.shape) == 1 else Y.shape[1]
+
+        # Weighting samples has to be done here, not in the components
+        if self.configuration['balancing:strategy'].value == 'weighting':
+            balancing = Balancing(strategy='weighting')
+            init_params, fit_params = balancing.get_weights(
+                Y, self.configuration['classifier'].value,
+                self.configuration['preprocessor'].value,
+                init_params, fit_params)
+
         super(ParamSklearnClassifier, self).fit(X, Y, fit_params=fit_params,
                                                 init_params=init_params)
-        self.num_targets = 1 if len(Y.shape) == 1 else Y.shape[1]
+
         return self

     def predict_proba(self, X, batch_size=None):
@@ -415,4 +426,5 @@ def _get_estimator_components():

     @staticmethod
     def _get_pipeline():
-        return ["imputation", "rescaling", "__preprocessor__", "__estimator__"]
\ No newline at end of file
+        return ["imputation", "rescaling", "balancing", "__preprocessor__",
+                "__estimator__"]
\ No newline at end of file
diff --git a/ParamSklearn/components/classification/adaboost.py b/ParamSklearn/components/classification/adaboost.py
index b8381ff082..4974d2de75 100644
--- a/ParamSklearn/components/classification/adaboost.py
+++ b/ParamSklearn/components/classification/adaboost.py
@@ -25,7 +25,7 @@ def __init__(self, n_estimators, learning_rate, algorithm='SAMME.R',

         self.estimator = None

-    def fit(self, X, Y):
+    def fit(self, X, Y, sample_weight=None):
         base_estimator = sklearn.tree.DecisionTreeClassifier(max_depth=self.max_depth)

         self.estimator = sklearn.ensemble.AdaBoostClassifier(
@@ -34,9 +34,8 @@ def fit(self, X, Y):
             learning_rate=self.learning_rate,
             algorithm=self.algorithm,
-            random_state=self.random_state
-        )
-        self.estimator.fit(X, Y)
+            random_state=self.random_state)
+        self.estimator.fit(X, Y, sample_weight=sample_weight)
         return self

     def predict(self, X):
diff --git a/ParamSklearn/components/classification/decision_tree.py b/ParamSklearn/components/classification/decision_tree.py
index 1c6c9ff06f..d0f8a98357 100644
--- a/ParamSklearn/components/classification/decision_tree.py
+++ b/ParamSklearn/components/classification/decision_tree.py
@@ -35,7 +35,7 @@ def __init__(self, criterion, max_features, max_depth,
         self.random_state = random_state
         self.estimator = None

-    def fit(self, X, y):
+    def fit(self, X, y, sample_weight=None):
         self.estimator = DecisionTreeClassifier(
             criterion=self.criterion,
             max_depth=self.max_depth,
@@ -43,7 +43,7 @@ def fit(self, X, y):
             min_samples_leaf=self.min_samples_leaf,
             max_leaf_nodes=self.max_leaf_nodes,
             random_state=self.random_state)
-        self.estimator.fit(X, y)
+        self.estimator.fit(X, y, sample_weight=sample_weight)
         return self

     def predict(self, X):
diff --git a/ParamSklearn/components/classification/extra_trees.py b/ParamSklearn/components/classification/extra_trees.py
index c4cf44a5af..1d923bc7d1 100644
--- a/ParamSklearn/components/classification/extra_trees.py
+++ b/ParamSklearn/components/classification/extra_trees.py
@@ -59,7 +59,7 @@ def __init__(self, n_estimators, criterion, min_samples_leaf,
         self.verbose = int(verbose)
         self.estimator = None

-    def fit(self, X, Y):
+    def fit(self, X, Y, sample_weight=None):
         num_features = X.shape[1]
         max_features = int(float(self.max_features) * (np.log(num_features) + 1))
         # Use at most half of the features
@@ -78,7 +78,7 @@ def fit(self, X, Y):
         while len(self.estimator.estimators_) < self.n_estimators:
             tmp = self.estimator  # TODO copy ?
             tmp.n_estimators += self.estimator_increment
-            tmp.fit(X, Y)
+            tmp.fit(X, Y, sample_weight=sample_weight)
             self.estimator = tmp
         return self

diff --git a/ParamSklearn/components/classification/liblinear.py b/ParamSklearn/components/classification/liblinear_svc.py
similarity index 100%
rename from ParamSklearn/components/classification/liblinear.py
rename to ParamSklearn/components/classification/liblinear_svc.py
diff --git a/ParamSklearn/components/classification/passive_aggresive.py b/ParamSklearn/components/classification/passive_aggresive.py
index 070845bfd7..61d95dd617 100644
--- a/ParamSklearn/components/classification/passive_aggresive.py
+++ b/ParamSklearn/components/classification/passive_aggresive.py
@@ -4,7 +4,6 @@
 from HPOlibConfigSpace.hyperparameters import UniformFloatHyperparameter, \
     CategoricalHyperparameter, UnParametrizedHyperparameter, \
     UniformIntegerHyperparameter
-from HPOlibConfigSpace.conditions import EqualsCondition

 from ParamSklearn.components.classification_base import \
     ParamSklearnClassificationAlgorithm
diff --git a/ParamSklearn/components/classification/random_forest.py b/ParamSklearn/components/classification/random_forest.py
index 7c45259c94..e3f13e71a7 100644
--- a/ParamSklearn/components/classification/random_forest.py
+++ b/ParamSklearn/components/classification/random_forest.py
@@ -28,7 +28,7 @@ def __init__(self, n_estimators, criterion, max_features,
         self.n_jobs = n_jobs
         self.estimator = None

-    def fit(self, X, Y):
+    def fit(self, X, Y, sample_weight=None):
         self.n_estimators = int(self.n_estimators)

         if self.max_depth == "None":
@@ -67,7 +67,7 @@ def fit(self, X, Y):
         while len(self.estimator.estimators_) < self.n_estimators:
             tmp = self.estimator  # TODO I think we need to copy here!
             tmp.n_estimators += self.estimator_increment
-            tmp.fit(X, Y)
+            tmp.fit(X, Y, sample_weight=sample_weight)
             self.estimator = tmp
         return self

diff --git a/ParamSklearn/components/classification/ridge.py b/ParamSklearn/components/classification/ridge.py
index ac7d4966a2..6cfbc01be9 100644
--- a/ParamSklearn/components/classification/ridge.py
+++ b/ParamSklearn/components/classification/ridge.py
@@ -13,17 +13,20 @@


 class Ridge(ParamSklearnClassificationAlgorithm):
-    def __init__(self, alpha, fit_intercept, tol, random_state=None):
+    def __init__(self, alpha, fit_intercept, tol, class_weight=None,
+                 random_state=None):
         self.alpha = float(alpha)
         self.fit_intercept = bool(fit_intercept)
         self.tol = float(tol)
+        self.class_weight = class_weight
         self.random_state = random_state
         self.estimator = None

     def fit(self, X, Y):
         self.estimator = RidgeClassifier(alpha=self.alpha,
                                          fit_intercept=self.fit_intercept,
-                                         tol=self.tol)
+                                         tol=self.tol,
+                                         class_weight=self.class_weight)
         self.estimator.fit(X, Y)
         return self
diff --git a/ParamSklearn/components/classification/sgd.py b/ParamSklearn/components/classification/sgd.py
index d1aecf3547..38f673eaad 100644
--- a/ParamSklearn/components/classification/sgd.py
+++ b/ParamSklearn/components/classification/sgd.py
@@ -13,7 +13,7 @@

 class SGD(ParamSklearnClassificationAlgorithm):
     def __init__(self, loss, penalty, alpha, fit_intercept, n_iter,
-                 learning_rate, class_weight, l1_ratio=0.15, epsilon=0.1,
+                 learning_rate, class_weight=None, l1_ratio=0.15, epsilon=0.1,
                  eta0=0.01, power_t=0.5, random_state=None):
         self.loss = loss
         self.penalty = penalty
@@ -111,10 +111,6 @@ def get_hyperparameter_search_space(dataset_properties=None):
             ["optimal", "invscaling", "constant"], default="optimal")
         eta0 = UniformFloatHyperparameter("eta0", 10**-7, 0.1, default=0.01)
         power_t = UniformFloatHyperparameter("power_t", 1e-5, 1, default=0.5)
-        # This does not allow for other resampling methods!
-        class_weight = CategoricalHyperparameter("class_weight",
-                                                 ["None", "auto"],
-                                                 default="None")
         cs = ConfigurationSpace()
         cs.add_hyperparameter(loss)
         cs.add_hyperparameter(penalty)
@@ -126,7 +122,6 @@ def get_hyperparameter_search_space(dataset_properties=None):
         cs.add_hyperparameter(learning_rate)
         cs.add_hyperparameter(eta0)
         cs.add_hyperparameter(power_t)
-        cs.add_hyperparameter(class_weight)

         # TODO add passive/aggressive here, although not properly documented?
         elasticnet = EqualsCondition(l1_ratio, penalty, "elasticnet")
diff --git a/ParamSklearn/components/preprocessing/balancing.py b/ParamSklearn/components/preprocessing/balancing.py
new file mode 100644
index 0000000000..bde52ea0c5
--- /dev/null
+++ b/ParamSklearn/components/preprocessing/balancing.py
@@ -0,0 +1,113 @@
+import numpy as np
+
+from HPOlibConfigSpace.configuration_space import ConfigurationSpace
+from HPOlibConfigSpace.hyperparameters import CategoricalHyperparameter
+
+from ParamSklearn.components.preprocessor_base import \
+    ParamSklearnPreprocessingAlgorithm
+from ParamSklearn.util import DENSE, SPARSE, INPUT
+
+
+class Balancing(ParamSklearnPreprocessingAlgorithm):
+    def __init__(self, strategy, random_state=None):
+        self.strategy = strategy
+
+    def fit(self, X, y=None):
+        raise NotImplementedError()
+
+    def transform(self, X):
+        raise NotImplementedError()
+
+    def get_weights(self, Y, classifier, preprocessor, init_params, fit_params):
+        if init_params is None:
+            init_params = {}
+
+        if fit_params is None:
+            fit_params = {}
+
+        # Classifiers which require sample weights:
+        # We can have adaboost in here, because in the fit method,
+        # the sample weights are normalized:
+        # https://github.com/scikit-learn/scikit-learn/blob/0.15.X/sklearn/ensemble/weight_boosting.py#L121
+        clf_ = ['adaboost', 'decision_tree', 'extra_trees', 'random_forest',
+                'gradient_boosting']
+        pre_ = ['extra_trees_preproc_for_classification']
+        if classifier in clf_ or preprocessor in pre_:
+            if len(Y.shape) > 1:
+                offsets = [2 ** i for i in range(Y.shape[1])]
+                Y_ = np.sum(Y * offsets, axis=1)
+            else:
+                Y_ = Y
+
+            unique, counts = np.unique(Y_, return_counts=True)
+            cw = 1. / counts
+            cw = cw / np.mean(cw)
+
+            sample_weights = np.ones(Y_.shape)
+
+            for i, ue in enumerate(unique):
+                mask = Y_ == ue
+                sample_weights[mask] *= cw[i]
+
+            if classifier in clf_:
+                fit_params['%s:sample_weight' % classifier] = sample_weights
+            if preprocessor in pre_:
+                fit_params['%s:sample_weight' % preprocessor] = sample_weights
+
+        # Classifiers which can adjust sample weights themselves via the
+        # argument `class_weight`
+        clf_ = ['liblinear_svc', 'libsvm_svc', 'sgd']
+        pre_ = ['liblinear_svc_preprocessor']
+        if classifier in clf_:
+            init_params['%s:class_weight' % classifier] = 'auto'
+        if preprocessor in pre_:
+            init_params['%s:class_weight' % preprocessor] = 'auto'
+
+        clf_ = ['ridge']
+        if classifier in clf_:
+            class_weights = {}
+
+            unique, counts = np.unique(Y, return_counts=True)
+            cw = 1. / counts
+            cw = cw / np.mean(cw)
+
+            for i, ue in enumerate(unique):
+                class_weights[ue] = cw[i]
+
+            if classifier in clf_:
+                init_params['%s:class_weight' % classifier] = class_weights
+
+        return init_params, fit_params
+
+    @staticmethod
+    def get_properties():
+        return {'shortname': 'Balancing',
+                'name': 'Balancing Imbalanced Class Distributions',
+                'handles_missing_values': True,
+                'handles_nominal_values': True,
+                'handles_numerical_features': True,
+                'prefers_data_scaled': False,
+                'prefers_data_normalized': False,
+                'handles_regression': False,
+                'handles_classification': True,
+                'handles_multiclass': True,
+                'handles_multilabel': True,
+                'is_deterministic': True,
+                'handles_sparse': True,
+                'handles_dense': True,
+                'input': (DENSE, SPARSE),
+                'output': INPUT,
+                'preferred_dtype': None}
+
+    @staticmethod
+    def get_hyperparameter_search_space(dataset_properties=None):
+        # TODO add replace by zero!
+        strategy = CategoricalHyperparameter(
+            "strategy", ["none", "weighting"], default="none")
+        cs = ConfigurationSpace()
+        cs.add_hyperparameter(strategy)
+        return cs
+
+    def __str__(self):
+        name = self.get_properties()['name']
+        return "ParamSklearn %s" % name
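The weighting scheme in get_weights is reciprocal class frequency, rescaled so the per-class weights average to one. A worked example with the 80/20 binary split used in the new unit tests (plain Python, outside the component, for illustration):

    import numpy as np

    Y = np.array([0] * 80 + [1] * 20)
    unique, counts = np.unique(Y, return_counts=True)  # counts == [80, 20]
    cw = 1. / counts                                   # [0.0125, 0.05]
    cw = cw / np.mean(cw)                              # [0.4, 1.6]

    sample_weights = np.ones(Y.shape)
    for i, ue in enumerate(unique):
        sample_weights[Y == ue] *= cw[i]

    # The majority class gets weight 0.4, the minority class 1.6, so both
    # classes contribute the same total weight: 80 * 0.4 == 20 * 1.6 == 32.
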
diff --git a/ParamSklearn/components/preprocessing/extra_trees.py b/ParamSklearn/components/preprocessing/extra_trees_preproc_for_classification.py
similarity index 96%
rename from ParamSklearn/components/preprocessing/extra_trees.py
rename to ParamSklearn/components/preprocessing/extra_trees_preproc_for_classification.py
index f0787dca7f..254eb2fb00 100644
--- a/ParamSklearn/components/preprocessing/extra_trees.py
+++ b/ParamSklearn/components/preprocessing/extra_trees_preproc_for_classification.py
@@ -7,7 +7,7 @@

 from ParamSklearn.components.preprocessor_base import \
     ParamSklearnPreprocessingAlgorithm
-from ParamSklearn.util import DENSE, PREDICTIONS
+from ParamSklearn.util import DENSE, INPUT

 # get our own forests to replace the sklearn ones
 from ParamSklearn.implementations import forest
@@ -60,7 +60,7 @@ def __init__(self, n_estimators, criterion, min_samples_leaf,
         self.verbose = int(verbose)
         self.preprocessor = None

-    def fit(self, X, Y):
+    def fit(self, X, Y, sample_weight=None):
         num_features = X.shape[1]
         max_features = int(
             float(self.max_features) * (np.log(num_features) + 1))
@@ -80,7 +80,7 @@ def fit(self, X, Y):
         while len(self.preprocessor.estimators_) < self.n_estimators:
             tmp = self.preprocessor  # TODO copy ?
             tmp.n_estimators += self.estimator_increment
-            tmp.fit(X, Y)
+            tmp.fit(X, Y, sample_weight=sample_weight)
             self.preprocessor = tmp
         return self

@@ -106,7 +106,7 @@ def get_properties():
                 'is_deterministic': True,
                 'handles_sparse': False,
                 'input': (DENSE, ),
-                'output': PREDICTIONS,
+                'output': INPUT,
                 # TODO find out what is best used here!
                 # But rather fortran or C-contiguous?
                 'preferred_dtype': np.float32}
diff --git a/ParamSklearn/components/preprocessing/imputation.py b/ParamSklearn/components/preprocessing/imputation.py
index df2005f021..81f3a8274c 100644
--- a/ParamSklearn/components/preprocessing/imputation.py
+++ b/ParamSklearn/components/preprocessing/imputation.py
@@ -1,4 +1,5 @@
-import ParamSklearn.implementations.Imputation
+#import ParamSklearn.implementations.Imputation
+import sklearn.preprocessing

 from HPOlibConfigSpace.configuration_space import ConfigurationSpace
 from HPOlibConfigSpace.hyperparameters import CategoricalHyperparameter
@@ -13,9 +14,9 @@ def __init__(self, strategy, random_state=None):
         self.strategy = strategy

     def fit(self, X, y=None):
-        self.preprocessor = ParamSklearn.implementations.Imputation.Imputer(
-            strategy=self.strategy, copy=False, dtype=X.dtype)
-        self.preprocessor.fit(X)
+        self.preprocessor = sklearn.preprocessing.Imputer(
+            strategy=self.strategy, copy=False)  #, dtype=X.dtype)
+        self.preprocessor = self.preprocessor.fit(X)
         return self

     def transform(self, X):
diff --git a/ParamSklearn/components/preprocessing/liblinear.py b/ParamSklearn/components/preprocessing/liblinear_svc_preprocessor.py
similarity index 97%
rename from ParamSklearn/components/preprocessing/liblinear.py
rename to ParamSklearn/components/preprocessing/liblinear_svc_preprocessor.py
index d6f03b131a..fc087fd619 100644
--- a/ParamSklearn/components/preprocessing/liblinear.py
+++ b/ParamSklearn/components/preprocessing/liblinear_svc_preprocessor.py
@@ -8,8 +8,7 @@

 from ParamSklearn.components.preprocessor_base import \
     ParamSklearnPreprocessingAlgorithm
-from ParamSklearn.implementations.util import softmax
-from ParamSklearn.util import SPARSE, DENSE, PREDICTIONS
+from ParamSklearn.util import SPARSE, DENSE, INPUT


 class LibLinear_Preprocessor(ParamSklearnPreprocessingAlgorithm):
@@ -73,7 +72,7 @@ def get_properties():
                 # this here suggests so http://scikit-learn.org/stable/modules/svm.html#tips-on-practical-use
                 'handles_sparse': True,
                 'input': (SPARSE, DENSE),
-                'output': PREDICTIONS,
+                'output': INPUT,
                 # TODO find out what is best used here!
                 'preferred_dtype': None}

diff --git a/ParamSklearn/components/preprocessing/random_trees_embedding.py b/ParamSklearn/components/preprocessing/random_trees_embedding.py
index 6c6a0d9c76..4640680ec9 100644
--- a/ParamSklearn/components/preprocessing/random_trees_embedding.py
+++ b/ParamSklearn/components/preprocessing/random_trees_embedding.py
@@ -39,7 +39,7 @@ def fit(self, X, Y=None):
             n_jobs=self.n_jobs,
             random_state=self.random_state
         )
-        self.preprocessor.fit(X)
+        self.preprocessor.fit(X, Y)
         return self

     def transform(self, X):
diff --git a/ParamSklearn/components/preprocessing/select_rates.py b/ParamSklearn/components/preprocessing/select_rates.py
index 284b5c222b..582a8d038f 100644
--- a/ParamSklearn/components/preprocessing/select_rates.py
+++ b/ParamSklearn/components/preprocessing/select_rates.py
@@ -35,7 +35,16 @@ def fit(self, X, y):
     def transform(self, X):
         if self.preprocessor is None:
             raise NotImplementedError()
-        Xt = self.preprocessor.transform(X)
+        try:
+            Xt = self.preprocessor.transform(X)
+        except ValueError as e:
+            if "zero-size array to reduction operation maximum which has no " \
+                    "identity" in e.message:
+                raise ValueError(
+                    "%s removed all features." % self.__class__.__name__)
+            else:
+                raise e
+
         if Xt.shape[1] == 0:
             raise ValueError(
                 "%s removed all features." % self.__class__.__name__)
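The string matched in the new except-branch is the message numpy itself raises when a reduction is applied to an empty array, which the wrapped univariate selector can trigger once it has filtered out every feature. A quick way to see the exact message (numpy only, illustrative):

    import numpy as np

    try:
        np.array([]).max()
    except ValueError as e:
        # zero-size array to reduction operation maximum which has no identity
        print(e)

With the guard in place, both failure paths surface the same "%s removed all features." ValueError, so callers only have one error to handle. (Note that `e.message` is Python 2 only; `str(e)` would work on both 2 and 3.)
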
diff --git a/ParamSklearn/implementations/Imputation.py b/ParamSklearn/implementations/Imputation.py
index 1115f2c2a5..300f96c284 100644
--- a/ParamSklearn/implementations/Imputation.py
+++ b/ParamSklearn/implementations/Imputation.py
@@ -115,8 +115,9 @@ class Imputer(BaseEstimator, TransformerMixin):
         - If `axis=0`, then impute along columns.
         - If `axis=1`, then impute along rows.

-    dtype : np.dtype
-        Determines the dtype of the transformed array.
+    dtype : np.dtype (default=np.float64)
+        Determines the dtype of the transformed array if it is dense. Has no
+        effect otherwise.

     verbose : integer, optional (default=0)
         Controls the verbosity of the imputer.
@@ -183,14 +184,18 @@ def fit(self, X, y=None):
         # transform(X), the imputation data will be computed in transform()
         # when the imputation is done per sample (i.e., when axis=1).
         if self.axis == 0:
-            X = atleast2d_or_csc(X, dtype=self.dtype, force_all_finite=False)
+            if sparse.issparse(X):
+                X = atleast2d_or_csc(X, dtype=np.float64,
+                                     force_all_finite=False)
             self.statistics_ = self._sparse_fit(X,
                                                 self.strategy,
                                                 self.missing_values,
                                                 self.axis)
         else:
+            X = atleast2d_or_csc(X, dtype=self.dtype,
+                                 force_all_finite=False)
             self.statistics_ = self._dense_fit(X,
                                                self.strategy,
                                                self.missing_values,
diff --git a/misc/support_for_imbalanced_classes.txt b/misc/support_for_imbalanced_classes.txt
new file mode 100644
index 0000000000..e69229c5a0
--- /dev/null
+++ b/misc/support_for_imbalanced_classes.txt
@@ -0,0 +1,22 @@
+AdaBoost: Sample weights. If None, the sample weights are initialized to 1 / n_samples.
+Bernoulli_NB: Weights applied to individual samples (1. for unweighted).
+DecisionTree: Sample weights. If None, then samples are equally weighted. Splits that would create child nodes with net zero or negative weight are ignored while searching for a split in each node. In the case of classification, splits are also ignored if they would result in any single class carrying a negative weight in either child node.
+ExtraTrees: Sample weights. If None, then samples are equally weighted. Splits that would create child nodes with net zero or negative weight are ignored while searching for a split in each node. In the case of classification, splits are also ignored if they would result in any single class carrying a negative weight in either child node.
+GaussianNB: -
+GB: -
+kNN: -
+LDA: priors : array, optional, shape = [n_classes] ?
+LibLinear: class_weight : {dict, ‘auto’}, optional
+SVC: class_weight : {dict, ‘auto’}, optional; Per-sample weights. Rescale C per sample. Higher weights force the classifier to put more emphasis on these points.
+MultinomialNB: -
+PA: sample_weight : array-like, shape = [n_samples], optional
+QDA: -
+RF: sample_weight : array-like, shape = [n_samples] or None
+RidgeClassifier: class_weight : dict, optional
+SGD: class_weight : dict, {class_label
+
+
+
+
+Preprocessors:
+
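The notes file splits estimators into those taking per-sample weights in fit() and those taking a class_weight argument at construction time — exactly the split get_weights makes between fit_params and init_params. A minimal sketch of the two routes against scikit-learn 0.15-era APIs (class_weight='auto' was later replaced by 'balanced', so treat this as era-specific):

    import numpy as np
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.linear_model import SGDClassifier

    X = np.random.RandomState(1).rand(100, 3)
    y = np.array([0] * 80 + [1] * 20)

    # Route 1 (fit_params): per-sample weights passed to fit().
    weights = np.where(y == 0, 0.4, 1.6)
    DecisionTreeClassifier().fit(X, y, sample_weight=weights)

    # Route 2 (init_params): class weights passed to the constructor.
    SGDClassifier(class_weight='auto').fit(X, y)
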
diff --git a/source/first_steps.rst b/source/first_steps.rst
index c1c06c9597..2285ddb2be 100644
--- a/source/first_steps.rst
+++ b/source/first_steps.rst
@@ -18,10 +18,10 @@ configuration on the iris dataset.
     >>> np.random.seed(1)
     >>> np.random.shuffle(indices)
     >>> configuration_space = ParamSklearnClassifier.get_hyperparameter_search_space()
-    >>> sampler = RandomSampler(configuration_space, 5)
+    >>> sampler = RandomSampler(configuration_space, 3)
     >>> configuration = sampler.sample_configuration()
     >>> cls = ParamSklearnClassifier(configuration, random_state=1)
     >>> cls = cls.fit(X[indices[:100]], Y[indices[:100]])
     >>> predictions = cls.predict(X[indices[100:]])
     >>> sklearn.metrics.accuracy_score(predictions, Y[indices[100:]])
-    0.95999999999999996
+    0.90000000000000002
diff --git a/tests/components/classification/test_liblinear.py b/tests/components/classification/test_liblinear.py
index 7242f946ab..167397fd33 100644
--- a/tests/components/classification/test_liblinear.py
+++ b/tests/components/classification/test_liblinear.py
@@ -1,6 +1,6 @@
 import unittest

-from ParamSklearn.components.classification.liblinear import LibLinear_SVC
+from ParamSklearn.components.classification.liblinear_svc import LibLinear_SVC
 from ParamSklearn.util import _test_classifier

diff --git a/tests/components/preprocessing/test_balancing.py b/tests/components/preprocessing/test_balancing.py
new file mode 100644
index 0000000000..d916778b9e
--- /dev/null
+++ b/tests/components/preprocessing/test_balancing.py
@@ -0,0 +1,117 @@
+__author__ = 'feurerm'
+
+import unittest
+
+import numpy as np
+import sklearn.metrics
+
+from HPOlibConfigSpace.hyperparameters import InactiveHyperparameter
+
+from ParamSklearn.components.preprocessing.balancing import Balancing
+from ParamSklearn.classification import ParamSklearnClassifier
+from ParamSklearn.components.classification.adaboost import AdaboostClassifier
+from ParamSklearn.components.classification.decision_tree import DecisionTree
+from ParamSklearn.components.classification.extra_trees import ExtraTreesClassifier
+from ParamSklearn.components.classification.gradient_boosting import GradientBoostingClassifier
+from ParamSklearn.components.classification.random_forest import RandomForest
+from ParamSklearn.components.classification.liblinear_svc import LibLinear_SVC
+from ParamSklearn.components.classification.libsvm_svc import LibSVM_SVC
+from ParamSklearn.components.classification.sgd import SGD
+from ParamSklearn.components.classification.ridge import Ridge
+from ParamSklearn.components.preprocessing\
+    .extra_trees_preproc_for_classification import ExtraTreesPreprocessor
+from ParamSklearn.components.preprocessing.liblinear_svc_preprocessor import LibLinear_Preprocessor
+from ParamSklearn.components.preprocessing.random_trees_embedding import RandomTreesEmbedding
+from ParamSklearn.util import get_dataset
+
+
+class BalancingComponentTest(unittest.TestCase):
+    def test_balancing_get_weights_treed_single_label(self):
+        Y = np.array([0] * 80 + [1] * 20)
+        balancing = Balancing(strategy='weighting')
+        init_params, fit_params = balancing.get_weights(
+            Y, 'random_forest', None, None, None)
+        self.assertTrue(np.allclose(fit_params['random_forest:sample_weight'],
+                                    np.array([0.4] * 80 + [1.6] * 20)))
+        init_params, fit_params = balancing.get_weights(
+            Y, None, 'extra_trees_preproc_for_classification', None, None)
+        self.assertTrue(np.allclose(fit_params['extra_trees_preproc_for_classification:sample_weight'],
+                                    np.array([0.4] * 80 + [1.6] * 20)))
+
+    def test_balancing_get_weights_treed_multilabel(self):
+        Y = np.array([[0, 0, 0]] * 100 + [[1, 0, 0]] * 100 + [[0, 1, 0]] * 100 +
+                     [[1, 1, 0]] * 100 + [[0, 0, 1]] * 100 + [[1, 0, 1]] * 10)
+        balancing = Balancing(strategy='weighting')
+        init_params, fit_params = balancing.get_weights(
+            Y, 'random_forest', None, None, None)
+        self.assertTrue(np.allclose(fit_params['random_forest:sample_weight'],
+                                    np.array([0.4] * 500 + [4.0] * 10)))
+        init_params, fit_params = balancing.get_weights(
+            Y, None, 'extra_trees_preproc_for_classification', None, None)
+        self.assertTrue(np.allclose(fit_params['extra_trees_preproc_for_classification:sample_weight'],
+                                    np.array([0.4] * 500 + [4.0] * 10)))
+
+    def test_balancing_get_weights_svm_sgd(self):
+        Y = np.array([0] * 80 + [1] * 20)
+        balancing = Balancing(strategy='weighting')
+        init_params, fit_params = balancing.get_weights(
+            Y, 'libsvm_svc', None, None, None)
+        self.assertEqual(("libsvm_svc:class_weight", "auto"),
+                         init_params.items()[0])
+        init_params, fit_params = balancing.get_weights(
+            Y, None, 'liblinear_svc_preprocessor', None, None)
+        self.assertEqual(("liblinear_svc_preprocessor:class_weight", "auto"),
+                         init_params.items()[0])
+
+    def test_balancing_get_weights_ridge(self):
+        Y = np.array([0] * 80 + [1] * 20)
+        balancing = Balancing(strategy='weighting')
+        init_params, fit_params = balancing.get_weights(
+            Y, 'ridge', None, None, None)
+        self.assertAlmostEqual(0.4, init_params['ridge:class_weight'][0])
+        self.assertAlmostEqual(1.6, init_params['ridge:class_weight'][1])
+
+    def test_weighting_effect(self):
+        for name, clf, acc_no_weighting, acc_weighting in \
+                [('adaboost', AdaboostClassifier, 0.692, 0.719),
+                 ('decision_tree', DecisionTree, 0.712, 0.668),
+                 ('extra_trees', ExtraTreesClassifier, 0.910, 0.913),
+                 ('random_forest', RandomForest, 0.896, 0.895),
+                 ('libsvm_svc', LibSVM_SVC, 0.915, 0.937),
+                 ('liblinear_svc', LibLinear_SVC, 0.920, 0.923),
+                 ('sgd', SGD, 0.879, 0.906),
+                 ('ridge', Ridge, 0.868, 0.880)]:
+            for strategy, acc in [('none', acc_no_weighting),
+                                  ('weighting', acc_weighting)]:
+                X_train, Y_train, X_test, Y_test = get_dataset(dataset='digits')
+                cs = ParamSklearnClassifier.get_hyperparameter_search_space(
+                    include_estimators=[name])
+                default = cs.get_default_configuration()
+                default.values['balancing:strategy'].value = strategy
+                classifier = ParamSklearnClassifier(default, random_state=1)
+                predictor = classifier.fit(X_train, Y_train)
+                predictions = predictor.predict(X_test)
+                self.assertAlmostEqual(acc,
+                    sklearn.metrics.accuracy_score(predictions, Y_test),
+                    places=3)
+
+        for name, pre, acc_no_weighting, acc_weighting in \
+                [('extra_trees_preproc_for_classification',
+                  ExtraTreesPreprocessor, 0.900, 0.908),
+                 ('liblinear_svc_preprocessor', LibLinear_Preprocessor,
+                  0.907, 0.882)]:
+            for strategy, acc in [('none', acc_no_weighting),
+                                  ('weighting', acc_weighting)]:
+                X_train, Y_train, X_test, Y_test = get_dataset(dataset='digits')
+
+                cs = ParamSklearnClassifier.get_hyperparameter_search_space(
+                    include_estimators=['sgd'], include_preprocessors=[name])
+                default = cs.get_default_configuration()
+                default.values['balancing:strategy'].value = strategy
+                classifier = ParamSklearnClassifier(default, random_state=1)
+                predictor = classifier.fit(X_train, Y_train)
+                predictions = predictor.predict(X_test)
+                self.assertAlmostEqual(acc,
+                    sklearn.metrics.accuracy_score(
+                        predictions, Y_test),
+                    places=3)
\ No newline at end of file
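The multilabel tests depend on the offset trick in get_weights: each row of Y is collapsed to one integer via powers of two, so every distinct label combination is counted as its own class. Worked through for the six combinations used above:

    import numpy as np

    Y = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0],
                  [1, 1, 0], [0, 0, 1], [1, 0, 1]])
    offsets = [2 ** i for i in range(Y.shape[1])]  # [1, 2, 4]
    print(np.sum(Y * offsets, axis=1))             # [0 1 2 3 4 5]

With 100 samples of each of the first five combinations and 10 of the last, the reciprocal counts normalized to mean one give 0.4 for the frequent combinations and 4.0 for the rare one — the values asserted in the test.
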
diff --git a/tests/components/preprocessing/test_extra_trees.py b/tests/components/preprocessing/test_extra_trees.py
index 3826f45b0e..2e912475f6 100644
--- a/tests/components/preprocessing/test_extra_trees.py
+++ b/tests/components/preprocessing/test_extra_trees.py
@@ -1,7 +1,7 @@
 import unittest

 from sklearn.linear_model import RidgeClassifier
-from ParamSklearn.components.preprocessing.extra_trees import \
+from ParamSklearn.components.preprocessing.extra_trees_preproc_for_classification import \
     ExtraTreesPreprocessor
 from ParamSklearn.util import _test_preprocessing, PreprocessingTestCase, \
     get_dataset
diff --git a/tests/components/preprocessing/test_liblinear.py b/tests/components/preprocessing/test_liblinear.py
index e7532d8aab..a6c1b394ae 100644
--- a/tests/components/preprocessing/test_liblinear.py
+++ b/tests/components/preprocessing/test_liblinear.py
@@ -1,7 +1,7 @@
 import unittest

 from sklearn.linear_model import RidgeClassifier
-from ParamSklearn.components.preprocessing.liblinear import \
+from ParamSklearn.components.preprocessing.liblinear_svc_preprocessor import \
     LibLinear_Preprocessor
 from ParamSklearn.util import _test_preprocessing, PreprocessingTestCase, \
     get_dataset
diff --git a/tests/test_classification.py b/tests/test_classification.py
index 6514242851..46e6dd67b5 100644
--- a/tests/test_classification.py
+++ b/tests/test_classification.py
@@ -126,8 +126,8 @@ def test_configurations(self):
     def test_configurations_sparse(self):
         cs = ParamSklearnClassifier.get_hyperparameter_search_space(
             dataset_properties={'sparse': True})
-        sampler = RandomSampler(cs, 123456)
-        for i in range(1000):
+        sampler = RandomSampler(cs, 1)
+        for i in range(10):
             config = sampler.sample_configuration()
             X_train, Y_train, X_test, Y_test = get_dataset(dataset='digits',
                                                            make_sparse=True)
@@ -178,10 +178,10 @@ def test_get_hyperparameter_search_space(self):
         self.assertIsInstance(cs, ConfigurationSpace)
         conditions = cs.get_conditions()
         hyperparameters = cs.get_hyperparameters()
-        self.assertEqual(130, len(hyperparameters))
+        self.assertEqual(146, len(hyperparameters))
         # The four parameters which are always active are classifier,
         # preprocessor, imputation strategy and scaling strategy
-        self.assertEqual(len(hyperparameters) - 4, len(conditions))
+        self.assertEqual(len(hyperparameters) - 5, len(conditions))

     def test_get_hyperparameter_search_space_include_exclude_models(self):
         cs = ParamSklearnClassifier.get_hyperparameter_search_space(
@@ -205,6 +205,7 @@ def test_get_hyperparameter_search_space_include_exclude_models(self):

     def test_get_hyperparameter_search_space_only_forbidden_combinations(self):
         self.assertRaisesRegexp(ValueError, "Default Configuration:\n"
+                                            "  balancing:strategy, Value: none\n"
                                             "  classifier, Value: multinomial_nb\n"
                                             "  imputation:strategy, Value: mean\n"
                                             "  multinomial_nb:alpha, Value: 1.000000\n"
@@ -222,23 +223,24 @@ def test_get_hyperparameter_search_space_only_forbidden_combinations(self):

         # It must also be catched that no classifiers which can handle sparse
         # data are located behind the densifier
         self.assertRaisesRegexp(ValueError, "Configuration:\n"
-                                            "  classifier, Value: liblinear\n"
+                                            "  balancing:strategy, Value: none\n"
+                                            "  classifier, Value: liblinear_svc\n"
                                             "  imputation:strategy, Value: mean\n"
-                                            "  liblinear:C, Value: 1.000000\n"
-                                            "  liblinear:class_weight, Value: None\n"
-                                            "  liblinear:dual, Constant: False\n"
-                                            "  liblinear:fit_intercept, Constant: True\n"
-                                            "  liblinear:intercept_scaling, Constant: 1\n"
-                                            "  liblinear:loss, Value: l2\n"
-                                            "  liblinear:multi_class, Constant: ovr\n"
-                                            "  liblinear:penalty, Value: l2\n"
-                                            "  liblinear:tol, Value: 0.000100\n"
+                                            "  liblinear_svc:C, Value: 1.000000\n"
+                                            "  liblinear_svc:class_weight, Value: None\n"
+                                            "  liblinear_svc:dual, Constant: False\n"
+                                            "  liblinear_svc:fit_intercept, Constant: True\n"
+                                            "  liblinear_svc:intercept_scaling, Constant: 1\n"
+                                            "  liblinear_svc:loss, Value: l2\n"
+                                            "  liblinear_svc:multi_class, Constant: ovr\n"
+                                            "  liblinear_svc:penalty, Value: l2\n"
+                                            "  liblinear_svc:tol, Value: 0.000100\n"
                                             "  preprocessor, Value: densifier\n"
                                             "  rescaling:strategy, Value: min/max\n"
-                                            "violates forbidden clause \(Forbidden: classifier == liblinear &&"
+                                            "violates forbidden clause \(Forbidden: classifier == liblinear_svc &&"
                                             " Forbidden: preprocessor == densifier\)",
                                 ParamSklearnClassifier.get_hyperparameter_search_space,
-                                include_estimators=['liblinear'],
+                                include_estimators=['liblinear_svc'],
                                 include_preprocessors=['densifier'],
                                 dataset_properties={'sparse': True})
@@ -309,7 +311,8 @@ def test_predict_batched_sparse(self):
         # Densifier + RF is the only combination that easily tests sparse
         # data with multilabel classification!
         config = Configuration(cs,
-                               hyperparameters={"classifier": "random_forest",
+                               hyperparameters={"balancing:strategy": "none",
+                                                "classifier": "random_forest",
                                                 "imputation:strategy": "mean",
                                                 "preprocessor": "densifier",
                                                 'random_forest:bootstrap': 'True',
@@ -392,7 +395,8 @@ def test_predict_proba_batched_sparse(self):
         # Densifier + RF is the only combination that easily tests sparse
         # data with multilabel classification!
         config = Configuration(cs,
-                               hyperparameters={"classifier": "random_forest",
+                               hyperparameters={"balancing:strategy": "none",
+                                                "classifier": "random_forest",
                                                 "imputation:strategy": "mean",
                                                 "preprocessor": "densifier",
                                                 'random_forest:bootstrap': 'True',
diff --git a/tests/test_create_searchspace_util_classification.py b/tests/test_create_searchspace_util_classification.py
index 8bba36527f..16eaea42c4 100644
--- a/tests/test_create_searchspace_util_classification.py
+++ b/tests/test_create_searchspace_util_classification.py
@@ -6,7 +6,7 @@
 from HPOlibConfigSpace.configuration_space import ConfigurationSpace
 from HPOlibConfigSpace.hyperparameters import CategoricalHyperparameter

-from ParamSklearn.components.classification.liblinear import LibLinear_SVC
+from ParamSklearn.components.classification.liblinear_svc import LibLinear_SVC
 from ParamSklearn.components.classification.random_forest import RandomForest
 from ParamSklearn.components.preprocessing.pca import PCA
diff --git a/tests/test_textclassification.py b/tests/test_textclassification.py
index fd9427537f..21d304f9bf 100644
--- a/tests/test_textclassification.py
+++ b/tests/test_textclassification.py
@@ -11,7 +11,7 @@ def test_get_hyperparameter_search_space(self):
         self.assertIsInstance(cs, ConfigurationSpace)
         conditions = cs.get_conditions()
         hyperparameters = cs.get_hyperparameters()
-        self.assertEqual(129, len(hyperparameters))
+        self.assertEqual(145, len(hyperparameters))
         # The four parameters which are always active are classifier,
         # preprocessor and imputation strategy
         self.assertEqual(len(hyperparameters) - 3, len(conditions))
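
With the patch applied, balancing is a first-class hyperparameter, so enabling it from user code follows the same pattern as the new tests. A sketch — get_dataset is the test helper from ParamSklearn.util, and mutating Configuration.values mirrors the tests rather than a documented public API:

    import sklearn.metrics
    from ParamSklearn.classification import ParamSklearnClassifier
    from ParamSklearn.util import get_dataset

    X_train, Y_train, X_test, Y_test = get_dataset(dataset='digits')

    cs = ParamSklearnClassifier.get_hyperparameter_search_space(
        include_estimators=['libsvm_svc'])
    default = cs.get_default_configuration()
    default.values['balancing:strategy'].value = 'weighting'  # default is 'none'

    classifier = ParamSklearnClassifier(default, random_state=1)
    predictions = classifier.fit(X_train, Y_train).predict(X_test)
    print(sklearn.metrics.accuracy_score(predictions, Y_test))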