From e6a7ec65c785bf74fd9a44ce26a06343a3a7a426 Mon Sep 17 00:00:00 2001 From: Delaney Mackenzie Date: Mon, 13 Dec 2021 01:55:59 -0500 Subject: [PATCH] ENH: Implemented performance speedup for binary ReliefF + bug fixes --- performance_tests.py | 179 +++++++++++ run_performance_benchmark.sh | 2 + skrebate/relieff.py | 571 ++++++++++++++++++++++++++--------- skrebate/scoring_utils.py | 10 +- skrebate/surf.py | 4 +- tests.py | 22 +- 6 files changed, 622 insertions(+), 166 deletions(-) create mode 100644 performance_tests.py create mode 100755 run_performance_benchmark.sh diff --git a/performance_tests.py b/performance_tests.py new file mode 100644 index 0000000..5dbffc7 --- /dev/null +++ b/performance_tests.py @@ -0,0 +1,179 @@ +from skrebate import ReliefF, SURF, SURFstar, MultiSURF, MultiSURFstar +from sklearn.pipeline import make_pipeline +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor +from sklearn.impute import SimpleImputer +from sklearn.model_selection import cross_val_score +import pandas as pd +import numpy as np +import warnings + +import timeit + +warnings.filterwarnings('ignore') + +np.random.seed(3249083) + +genetic_data = pd.read_csv( + 'data/GAMETES_Epistasis_2-Way_20atts_0.4H_EDM-1_1.tsv.gz', sep='\t', compression='gzip') +# genetic_data = genetic_data.sample(frac=0.25) + +genetic_data_cont_endpoint = pd.read_csv( + 'data/GAMETES_Epistasis_2-Way_continuous_endpoint_a_20s_1600her_0.4__maf_0.2_EDM-2_01.tsv.gz', sep='\t', compression='gzip') +genetic_data_cont_endpoint.rename(columns={'Class': 'class'}, inplace=True) +genetic_data_cont_endpoint = genetic_data_cont_endpoint.sample(frac=0.25) + +genetic_data_mixed_attributes = pd.read_csv( + 'data/GAMETES_Epistasis_2-Way_mixed_attribute_a_20s_1600her_0.4__maf_0.2_EDM-2_01.tsv.gz', sep='\t', compression='gzip') +genetic_data_mixed_attributes.rename(columns={'Class': 'class'}, inplace=True) +genetic_data_mixed_attributes = genetic_data_mixed_attributes.sample(frac=0.25) + +genetic_data_missing_values = pd.read_csv( + 'data/GAMETES_Epistasis_2-Way_missing_values_0.1_a_20s_1600her_0.4__maf_0.2_EDM-2_01.tsv.gz', sep='\t', compression='gzip') +genetic_data_missing_values.rename(columns={'Class': 'class'}, inplace=True) +genetic_data_missing_values = genetic_data_missing_values.sample(frac=0.25) + +genetic_data_multiclass = pd.read_csv('data/3Class_Datasets_Loc_2_01.txt', sep='\t') +genetic_data_multiclass.rename(columns={'Class': 'class'}, inplace=True) +genetic_data_multiclass = genetic_data_multiclass.sample(frac=0.25) + + +features, labels = genetic_data.drop('class', axis=1).values, genetic_data['class'].values +headers = list(genetic_data.drop("class", axis=1)) + +features_cont_endpoint, labels_cont_endpoint = genetic_data_cont_endpoint.drop( + 'class', axis=1).values, genetic_data_cont_endpoint['class'].values +headers_cont_endpoint = list(genetic_data_cont_endpoint.drop("class", axis=1)) + +features_mixed_attributes, labels_mixed_attributes = genetic_data_mixed_attributes.drop( + 'class', axis=1).values, genetic_data_mixed_attributes['class'].values +headers_mixed_attributes = list(genetic_data_mixed_attributes.drop("class", axis=1)) + +features_missing_values, labels_missing_values = genetic_data_missing_values.drop( + 'class', axis=1).values, genetic_data_missing_values['class'].values +headers_missing_values = list(genetic_data_missing_values.drop("class", axis=1)) + +features_multiclass, labels_multiclass = genetic_data_multiclass.drop( + 'class', axis=1).values, 
genetic_data_multiclass['class'].values +headers_multiclass = list(genetic_data_multiclass.drop("class", axis=1)) + + + +# Basic Parallelization Tests and Core binary data and discrete feature data testing (Focus on ReliefF only for efficiency)------------------------------------------------------------ +def test_relieff(): + """Check: Data (Binary Endpoint, Discrete Features): ReliefF works in a sklearn pipeline""" + np.random.seed(49082) + + alg = ReliefF(n_features_to_select=2, n_neighbors=10) + alg.fit(features, labels) + + +def test_relieff_parallel(): + """Check: Data (Binary Endpoint, Discrete Features): ReliefF works in a sklearn pipeline when ReliefF is parallelized""" + # Note that the rebate algorithm cannot be parallelized with both the random forest and the cross validation all at once. If the rebate algorithm is parallelized, the cross-validation scoring cannot be. + np.random.seed(49082) + + alg = ReliefF(n_features_to_select=2, n_neighbors=10, n_jobs=-1) + alg.fit(features, labels) + + +def test_relieffpercent(): + """Check: Data (Binary Endpoint, Discrete Features): ReliefF with % neighbors works in a sklearn pipeline""" + np.random.seed(49082) + + alg = ReliefF(n_features_to_select=2, n_neighbors=0.1) + alg.fit(features, labels) + + +def test_surf(): + """Check: Data (Binary Endpoint, Discrete Features): SURF works in a sklearn pipeline""" + np.random.seed(240932) + + alg = SURF(n_features_to_select=2) + alg.fit(features, labels) + + +def test_surf_parallel(): + """Check: Data (Binary Endpoint, Discrete Features): SURF works in a sklearn pipeline when SURF is parallelized""" + np.random.seed(240932) + + alg = SURF(n_features_to_select=2, n_jobs=-1) + alg.fit(features, labels) + + +def test_surfstar(): + """Check: Data (Binary Endpoint, Discrete Features): SURF* works in a sklearn pipelined""" + np.random.seed(9238745) + + alg = SURFstar(n_features_to_select=2) + alg.fit(features, labels) + + +def test_surfstar_parallel(): + """Check: Data (Binary Endpoint, Discrete Features): SURF* works in a sklearn pipeline when SURF* is parallelized""" + np.random.seed(9238745) + + alg = SURFstar(n_features_to_select=2, n_jobs=-1) + alg.fit(features, labels) + + +def test_multisurfstar(): + """Check: Data (Binary Endpoint, Discrete Features): MultiSURF* works in a sklearn pipeline""" + np.random.seed(320931) + + alg = MultiSURFstar(n_features_to_select=2) + alg.fit(features, labels) + + +def test_multisurfstar_parallel(): + """Check: Data (Binary Endpoint, Discrete Features): MultiSURF* works in a sklearn pipeline when MultiSURF* is parallelized""" + np.random.seed(320931) + + alg = MultiSURFstar(n_features_to_select=2, n_jobs=-1) + alg.fit(features, labels) + + +def test_multisurf(): + """Check: Data (Binary Endpoint, Discrete Features): MultiSURF works in a sklearn pipeline""" + np.random.seed(320931) + + alg = MultiSURF(n_features_to_select=2) + alg.fit(features, labels) + + +def test_multisurf_parallel(): + """Check: Data (Binary Endpoint, Discrete Features): MultiSURF works in a sklearn pipeline when MultiSURF is parallelized""" + np.random.seed(320931) + + alg = MultiSURF(n_features_to_select=2, n_jobs=-1) + alg.fit(features, labels) + + +test_cases = [ + test_relieff, + test_relieff_parallel, + test_relieffpercent, + # test_surf, + # test_surf_parallel, + # test_surfstar, + # test_surfstar_parallel, + # test_multisurfstar, + # test_multisurfstar_parallel, + # test_multisurf, + # test_multisurf_parallel +] + +if __name__ == '__main__': + timing_df = 
pd.DataFrame(columns=['test_case', 'mean', 'std']) + + for test_case in test_cases: + timing = timeit.repeat(test_case, number=1, repeat=5) + # ignore the first test to avoid high initial overhead to compile numba + # functions with small datasets + timing = timing[1:] + print(test_case.__name__, np.mean(timing), np.std(timing)) + d = {'test_case' : test_case.__name__, 'mean' : np.mean(timing), 'std' : np.std(timing)} + timing_df = timing_df.append(d, ignore_index = True) + + print(timing_df) + + timing_df.to_csv('timing_benchmarks.csv') diff --git a/run_performance_benchmark.sh b/run_performance_benchmark.sh new file mode 100755 index 0000000..8a1c258 --- /dev/null +++ b/run_performance_benchmark.sh @@ -0,0 +1,2 @@ +python -m cProfile -o perf_data.pstats performance_tests.py +gprof2dot -f pstats perf_data.pstats | dot -Tpng -o perfgraph.png diff --git a/skrebate/relieff.py b/skrebate/relieff.py index b190056..10f8da1 100644 --- a/skrebate/relieff.py +++ b/skrebate/relieff.py @@ -7,28 +7,115 @@ - Ryan J. Urbanowicz (ryanurb@upenn.edu) - Weixuan Fu (weixuanf@upenn.edu) - and many more generous open source contributors -Permission is hereby granted, free of charge, to any person obtaining a copy of this software -and associated documentation files (the "Software"), to deal in the Software without restriction, -including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, -subject to the following conditions: -The above copyright notice and this permission notice shall be included in all copies or substantial -portions of the Software. -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT -LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. -IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE -SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished +to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR +THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" from __future__ import print_function import numpy as np import time -import warnings import sys from sklearn.base import BaseEstimator -from joblib import Parallel, delayed -from .scoring_utils import get_row_missing, ReliefF_compute_scores, get_row_missing_iter +from joblib import Parallel, delayed, cpu_count +from numba import jit +from .scoring_utils import ( + get_row_missing, + ReliefF_compute_scores, + get_row_missing_iter +) + + +def chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i:i + n] + + +@jit(nopython=True) +def find_neighbors_binary(inst, datalen, distance_array, n_neighbors, y): + """ Identify k nearest hits and k nearest misses for given instance.""" + + dist_vect = np.zeros(datalen) + for j in range(datalen): + if inst != j: + # code to utilize the lower triangle dist matrix + if inst < j: + dist_vect[j] = distance_array[j][inst] + else: + dist_vect[j] = distance_array[inst][j] + else: + # Ensures that target instance is never selected as neighbor. + dist_vect[j] = sys.maxsize + + nn_list = [0] * n_neighbors * 2 + match_count = 0 + miss_count = 0 + for nn_index in np.argsort(dist_vect): + if nn_index == inst: + continue + if y[inst] == y[nn_index]: # Hit neighbor identified + if match_count >= n_neighbors: + continue + nn_list[match_count+miss_count] = nn_index + match_count += 1 + else: # Miss neighbor identified + if miss_count >= n_neighbors: + continue + nn_list[match_count+miss_count] = nn_index + miss_count += 1 + + if match_count >= n_neighbors and miss_count >= n_neighbors: + break + + # print(n_neighbors, match_count, miss_count) + + return np.array(nn_list) + + +# interestingly, this is sometimes faster without numba +def compute_score_binary(inst_x, nns_x, inst_y, nns_y): + """ + Compute the ReliefF score for a binary dataset + + Input: + inst - numpy array of instance values + nns_x - numpy 2d array containing instance values for + nearest neighbors + inst_y - class value for instance + nns_y - numpy array of class values for nearest neighbors + + Returns: + numpy array with a score for each attribute + """ + + # get diffs + diffs = (inst_x != nns_x) * 1 + + # compare classes to determine hits and misses + hits = np.array((nns_y == inst_y), ndmin=2).T + + # for hits, subtract 1 for each difference + # for misses, add 1 for each difference + + # convert to -1 for hits (to penalize diffs) + # and 1 for misses (to reward far differences) + hits = hits * -2 + 1 + diffs *= hits + + return np.nansum(diffs, axis=0) class ReliefF(BaseEstimator): @@ -38,19 +125,33 @@ class ReliefF(BaseEstimator): Igor et al. Overcoming the myopia of inductive learning algorithms with RELIEFF (1997), Applied Intelligence, 7(1), p39-55""" - """Note that ReliefF class establishes core functionality that is inherited by all other Relief-based algorithms. - Assumes: * There are no missing values in the label/outcome/dependent variable. - * For ReliefF, the setting of k is <= to the number of instances that have the least frequent class label + """Note that ReliefF class establishes core functionality that is inherited + by all other Relief-based algorithms. + Assumes: * There are no missing values in the + label/outcome/dependent variable. + * For ReliefF, the setting of k is <= to the number + of instances that have the least frequent class label (binary and multiclass endpoint data. 
""" - def __init__(self, n_features_to_select=10, n_neighbors=100, discrete_threshold=10, verbose=False, n_jobs=1,weight_final_scores=False,rank_absolute=False): - """Sets up ReliefF to perform feature selection. Note that an approximation of the original 'Relief' - algorithm may be run by setting 'n_features_to_select' to 1. Also note that the original Relief parameter 'm' - is not included in this software. 'm' specifies the number of random training instances out of 'n' (total - training instances) used to update feature scores. Since scores are most representative when m=n, all - available training instances are utilized in all Relief-based algorithm score updates here. If the user - wishes to utilize a smaller 'm' in Relief-based scoring, simply pass any of these algorithms a subset of the - original training dataset samples. + def __init__(self, + n_features_to_select=10, + n_neighbors=100, + discrete_threshold=10, + verbose=False, + n_jobs=1, + weight_final_scores=False, + rank_absolute=False): + """Sets up ReliefF to perform feature selection. Note that an + approximation of the original 'Relief' algorithm may be run by setting + 'n_features_to_select' to 1. Also note that the original Relief + parameter 'm' is not included in this software. 'm' specifies the + number of random training instances out of 'n' + (total training instances) used to update feature scores. + Since scores are most representative when m=n, all available training + instances are utilized in all Relief-based algorithm score updates + here. If the user wishes to utilize a smaller 'm' in Relief-based + scoring, simply pass any of these algorithms a subset of the original + training dataset samples. Parameters ---------- n_features_to_select: int (default: 10) @@ -58,23 +159,27 @@ def __init__(self, n_features_to_select=10, n_neighbors=100, discrete_threshold= retain after feature selection is applied. n_neighbors: int or float (default: 100) The number of neighbors to consider when assigning feature - importance scores. If a float number is provided, that percentage of - training samples is used as the number of neighbors. + importance scores. If a float number is provided, that percentage + of training samples is used as the number of neighbors. More neighbors results in more accurate scores, but takes longer. discrete_threshold: int (default: 10) Value used to determine if a feature is discrete or continuous. - If the number of unique levels in a feature is > discrete_threshold, then it is - considered continuous, or discrete otherwise. + If the number of unique levels in a feature is > + discrete_threshold, then it is considered continuous, + or discrete otherwise. verbose: bool (default: False) If True, output timing of distance array and scoring n_jobs: int (default: 1) - The number of cores to dedicate to computing the scores with joblib. - Assigning this parameter to -1 will dedicate as many cores as are available on your system. - We recommend setting this parameter to -1 to speed up the algorithm as much as possible. + The number of cores to dedicate to computing the scores + with joblib. Assigning this parameter to -1 will dedicate as + many cores as are available on your system. We recommend setting + this parameter to -1 to speed up the algorithm as much as possible. weight_final_scores: bool (default: False) - Whether to multiply given weights (in fit) to final scores. Only applicable if weights are given. + Whether to multiply given weights (in fit) to final scores. 
+ Only applicable if weights are given. rank_absolute: bool (default: False) - Whether to give top features as by ranking features by absolute value. + Whether to give top features as by ranking + features by absolute value. """ self.n_features_to_select = n_features_to_select self.n_neighbors = n_neighbors @@ -84,9 +189,9 @@ def __init__(self, n_features_to_select=10, n_neighbors=100, discrete_threshold= self.weight_final_scores = weight_final_scores self.rank_absolute = rank_absolute - #=========================================================================# def fit(self, X, y, weights=None): - """Scikit-learn required: Computes the feature importance scores from the training data. + """Scikit-learn required: Computes the feature importance scores from + the training data. Parameters ---------- X: array-like {n_samples, n_features} @@ -99,34 +204,43 @@ def fit(self, X, y, weights=None): ------- Copy of the ReliefF instance """ - self._X = X # matrix of predictive variables ('independent variables') - self._y = y # vector of values for outcome variable ('dependent variable') + self._X = X # matrix of predictive variables + # ('independent variables') + self._y = y # vector of values for outcome variable + # ('dependent variable') if isinstance(weights, np.ndarray): if isinstance(weights, np.ndarray): if len(weights) != len(X[0]): - raise Exception('Dimension of weights param must match number of features') + raise Exception('Dimension of weights \ + param must match number of features') elif isinstance(weights, list): if len(weights) != len(X[0]): - raise Exception('Dimension of weights param must match number of features') + raise Exception('Dimension of weights \ + param must match number of features') weights = np.ndarray(weights) else: raise Exception('weights param must be numpy array or list') self._weights = weights - # Set up the properties for ReliefF ------------------------------------------------------------------------------------- + # Set up the properties for ReliefF self._datalen = len(self._X) # Number of training instances ('n') - """"Below: Handles special case where user requests that a proportion of training instances be neighbors for - ReliefF rather than a specified 'k' number of neighbors. Note that if k is specified, then k 'hits' and k - 'misses' will be used to update feature scores. Thus total number of neighbors is 2k. If instead a proportion - is specified (say 0.1 out of 1000 instances) this represents the total number of neighbors (e.g. 100). In this - case, k would be set to 50 (i.e. 50 hits and 50 misses). """ + """"Below: Handles special case where user requests that a proportion + of training instances be neighbors for ReliefF rather than a specified + 'k' number of neighbors. Note that if k is specified, then k 'hits' + and k 'misses' will be used to update feature scores. Thus total + number of neighbors is 2k. If instead a proportion is specified + (say 0.1 out of 1000 instances) this represents the total number of + neighbors (e.g. 100). In this case, k would be set to 50 + (i.e. 50 hits and 50 misses). 
""" if hasattr(self, 'n_neighbors') and type(self.n_neighbors) is float: - # Halve the number of neighbors because ReliefF uses n_neighbors matches and n_neighbors misses + # Halve the number of neighbors because ReliefF uses n_neighbors + # matches and n_neighbors misses self.n_neighbors = int(self.n_neighbors * self._datalen * 0.5) - # Number of unique outcome (label) values (used to determine outcome variable type) + # Number of unique outcome (label) values + # (used to determine outcome variable type) self._label_list = list(set(self._y)) # Determine if label is discrete discrete_label = (len(self._label_list) <= self.discrete_threshold) @@ -146,28 +260,36 @@ def fit(self, X, y, weights=None): self._class_type = 'continuous' self.mcmap = 0 - # Training labels standard deviation -- only used if the training labels are continuous + # Training labels standard deviation -- + # only used if the training labels are continuous self._labels_std = 0. if len(self._label_list) > self.discrete_threshold: self._labels_std = np.std(self._y, ddof=1) - self._num_attributes = len(self._X[0]) # Number of features in training data - + # Number of features in training data + self._num_attributes = len(self._X[0]) + # Number of missing data values in predictor variable matrix. self._missing_data_count = np.isnan(self._X).sum() - """Assign internal headers for the features (scikit-learn does not accept external headers from dataset): - The pre_normalize() function relies on the headers being ordered, e.g., X01, X02, etc. - If this is changed, then the sort in the pre_normalize() function needs to be adapted as well. """ + """Assign internal headers for the features (scikit-learn does + not accept external headers from dataset): + The pre_normalize() function relies on the headers + being ordered, e.g., X01, X02, etc. + If this is changed, then the sort in the pre_normalize() + function needs to be adapted as well. """ xlen = len(self._X[0]) mxlen = len(str(xlen + 1)) - self._headers = ['X{}'.format(str(i).zfill(mxlen)) for i in range(1, xlen + 1)] + self._headers = ['X{}'.format(str(i).zfill(mxlen)) + for i in range(1, xlen + 1)] start = time.time() # Runtime tracking - # Determine data types for all features/attributes in training data (i.e. discrete or continuous) + # Determine data types for all features/attributes + # in training data (i.e. discrete or continuous) C = D = False - # Examines each feature and applies discrete_threshold to determine variable type. + # Examines each feature and applies discrete_threshold + # to determine variable type. self.attr = self._get_attribute_info() for key in self.attr.keys(): if self.attr[key][0] == 'discrete': @@ -175,7 +297,9 @@ def fit(self, X, y, weights=None): if self.attr[key][0] == 'continuous': C = True - # For downstream computational efficiency, determine if dataset is comprised of all discrete, all continuous, or a mix of discrete/continuous features. + # For downstream computational efficiency, determine + # if dataset is comprised of all discrete, all continuous, + # or a mix of discrete/continuous features. 
if C and D: self.data_type = 'mixed' elif D and not C: @@ -184,28 +308,32 @@ def fit(self, X, y, weights=None): self.data_type = 'continuous' else: raise ValueError('Invalid data type in data set.') - #-------------------------------------------------------------------------------------------------------------------- - # Compute the distance array between all data points ---------------------------------------------------------------- - # For downstream efficiency, separate features in dataset by type (i.e. discrete/continuous) + # Compute the distance array between all data points + # For downstream efficiency, separate features + # in dataset by type (i.e. discrete/continuous) diffs, cidx, didx = self._dtype_array() - cdiffs = diffs[cidx] # max/min continuous value difference for continuous features. + cdiffs = diffs[cidx] # max/min continuous value + # difference for continuous features. xc = self._X[:, cidx] # Subset of continuous-valued feature data xd = self._X[:, didx] # Subset of discrete-valued feature data - """ For efficiency, the distance array is computed more efficiently for data with no missing values. - This distance array will only be used to identify nearest neighbors. """ + """ For efficiency, the distance array is computed more efficiently + for data with no missing values. This distance array will only be + used to identify nearest neighbors. """ if self._missing_data_count > 0: if not isinstance(self._weights, np.ndarray): self._distance_array = self._distarray_missing(xc, xd, cdiffs) else: - self._distance_array = self._distarray_missing_iter(xc, xd, cdiffs, self._weights) + self._distance_array = \ + self._distarray_missing_iter(xc, xd, cdiffs, self._weights) else: if not isinstance(self._weights, np.ndarray): self._distance_array = self._distarray_no_missing(xc, xd) else: - self._distance_array = self._distarray_no_missing_iter(xc, xd, self._weights) + self._distance_array = \ + self._distarray_no_missing_iter(xc, xd, self._weights) if self.verbose: elapsed = time.time() - start @@ -213,10 +341,11 @@ def fit(self, X, y, weights=None): print('Feature scoring under way ...') start = time.time() - #-------------------------------------------------------------------------------------------------------------------- - - # Run remainder of algorithm (i.e. identification of 'neighbors' for each instance, and feature scoring).------------ - # Stores feature importance scores for ReliefF or respective Relief-based algorithm. + # Run remainder of algorithm + # (i.e. identification of 'neighbors' for each instance, + # and feature scoring). + # Stores feature importance scores for ReliefF or + # respective Relief-based algorithm. self.feature_importances_ = self._run_algorithm() # Delete the internal distance array because it is no longer needed @@ -228,15 +357,16 @@ def fit(self, X, y, weights=None): # Compute indices of top features if self.rank_absolute: - self.top_features_ = np.argsort(np.absolute(self.feature_importances_))[::-1] + self.top_features_ = np.argsort( + np.absolute(self.feature_importances_))[::-1] else: self.top_features_ = np.argsort(self.feature_importances_)[::-1] return self - #=========================================================================# def transform(self, X): - """Scikit-learn required: Reduces the feature set down to the top `n_features_to_select` features. + """Scikit-learn required: Reduces the feature set down + to the top `n_features_to_select` features. 
Parameters ---------- X: array-like {n_samples, n_features} @@ -247,13 +377,15 @@ def transform(self, X): Reduced feature matrix """ if self._num_attributes < self.n_features_to_select: - raise ValueError('Number of features to select is larger than the number of features in the dataset.') - + raise ValueError('Number of features to select is \ + larger than the number of features in the dataset.') + return X[:, self.top_features_[:self.n_features_to_select]] - #=========================================================================# def fit_transform(self, X, y): - """Scikit-learn required: Computes the feature importance scores from the training data, then reduces the feature set down to the top `n_features_to_select` features. + """Scikit-learn required: Computes the feature importance + scores from the training data, then reduces the feature set + down to the top `n_features_to_select` features. Parameters ---------- X: array-like {n_samples, n_features} @@ -269,13 +401,20 @@ def fit_transform(self, X, y): return self.transform(X) -######################### SUPPORTING FUNCTIONS ########################### + # SUPPORTING FUNCTIONS def _getMultiClassMap(self): - """ Relief algorithms handle the scoring updates a little differently for data with multiclass outcomes. In ReBATE we implement multiclass scoring in line with - the strategy described by Kononenko 1994 within the RELIEF-F variant which was suggested to outperform the RELIEF-E multiclass variant. This strategy weights - score updates derived from misses of different classes by the class frequency observed in the training data. 'The idea is that the algorithm should estimate the - ability of attributes to separate each pair of classes regardless of which two classes are closest to each other'. In this method we prepare for this normalization - by creating a class dictionary, and storing respective class frequencies. This is needed for ReliefF multiclass score update normalizations. """ + """ Relief algorithms handle the scoring updates a little differently + for data with multiclass outcomes. In ReBATE we implement multiclass + scoring in line with the strategy described by Kononenko 1994 within + the RELIEF-F variant which was suggested to outperform the RELIEF-E + multiclass variant. This strategy weights score updates derived from + misses of different classes by the class frequency observed in the + training data. 'The idea is that the algorithm should estimate the + ability of attributes to separate each pair of classes regardless of + which two classes are closest to each other'. In this method we + prepare for this normalization by creating a class dictionary, and + storing respective class frequencies. This is needed for ReliefF + multiclass score update normalizations. """ mcmap = dict() for i in range(self._datalen): @@ -290,7 +429,9 @@ def _getMultiClassMap(self): return mcmap def _get_attribute_info(self): - """ Preprocess the training dataset to identify which features/attributes are discrete vs. continuous valued. Ignores missing values in this determination.""" + """ Preprocess the training dataset to identify which + features/attributes are discrete vs. continuous valued. 
+ Ignores missing values in this determination.""" attr = dict() d = 0 limit = self.discrete_threshold @@ -300,7 +441,8 @@ def _get_attribute_info(self): h = self._headers[idx] z = w[idx] if self._missing_data_count > 0: - z = z[np.logical_not(np.isnan(z))] # Exclude any missing values from consideration + # Exclude any missing values from consideration + z = z[np.logical_not(np.isnan(z))] zlen = len(np.unique(z)) if zlen <= limit: attr[h] = ('discrete', 0, 0, 0, 0) @@ -310,20 +452,26 @@ def _get_attribute_info(self): mn = np.min(z) sd = np.std(z) attr[h] = ('continuous', mx, mn, mx - mn, sd) - # For each feature/attribute we store (type, max value, min value, max min difference, average, standard deviation) - the latter three values are set to zero if feature is discrete. + # For each feature/attribute we store (type, max value, min value, + # max min difference, average, standard deviation) - the latter three + # values are set to zero if feature is discrete. return attr def _distarray_no_missing(self, xc, xd): - """Distance array calculation for data with no missing values. The 'pdist() function outputs a condense distance array, and squareform() converts this vector-form + """Distance array calculation for data with no missing values. + The 'pdist() function outputs a condense distance array, and + squareform() converts this vector-form distance vector to a square-form, redundant distance matrix. - *This could be a target for saving memory in the future, by not needing to expand to the redundant square-form matrix. """ + *This could be a target for saving memory in the future, + by not needing to expand to the redundant square-form matrix. """ from scipy.spatial.distance import pdist, squareform - #------------------------------------------# def pre_normalize(x): - """Normalizes continuous features so they are in the same range (0 to 1)""" + """Normalizes continuous features so they are in the + same range (0 to 1)""" idx = 0 - # goes through all named features (doesn really need to) this method is only applied to continuous features + # goes through all named features (doesn really need to) + # this method is only applied to continuous features for i in sorted(self.attr.keys()): if self.attr[i][0] == 'discrete': continue @@ -333,23 +481,24 @@ def pre_normalize(x): x[:, idx] /= diff idx += 1 return x - #------------------------------------------# - if self.data_type == 'discrete': # discrete features only + # discrete features only + if self.data_type == 'discrete': return squareform(pdist(self._X, metric='hamming')) - elif self.data_type == 'mixed': # mix of discrete and continuous features + # mix of discrete and continuous features + elif self.data_type == 'mixed': d_dist = squareform(pdist(xd, metric='hamming')) # Cityblock is also known as Manhattan distance c_dist = squareform(pdist(pre_normalize(xc), metric='cityblock')) return np.add(d_dist * self._num_attributes, c_dist) - else: #continuous features only - #xc = pre_normalize(xc) + else: # continuous features only + # xc = pre_normalize(xc) return squareform(pdist(pre_normalize(xc), metric='cityblock')) - #==================================================================# def _dtype_array(self): - """Return mask for discrete(0)/continuous(1) attributes and their indices. Return array of max/min diffs of attributes.""" + """Return mask for discrete(0)/continuous(1) attributes and their + indices. 
Return array of max/min diffs of attributes.""" attrtype = [] attrdiff = [] @@ -367,38 +516,57 @@ def _dtype_array(self): attrdiff = np.array(attrdiff) return attrdiff, cidx, didx - #==================================================================# + def _distarray_missing(self, xc, xd, cdiffs): """Distance array calculation for data with missing values""" cindices = [] dindices = [] - # Get Boolean mask locating missing values for continuous and discrete features separately. These correspond to xc and xd respectively. + # Get Boolean mask locating missing values for continuous and + # discrete features separately. These correspond to + # xc and xd respectively. for i in range(self._datalen): cindices.append(np.where(np.isnan(xc[i]))[0]) dindices.append(np.where(np.isnan(xd[i]))[0]) if self.n_jobs != 1: dist_array = Parallel(n_jobs=self.n_jobs)(delayed(get_row_missing)( - xc, xd, cdiffs, index, cindices, dindices) for index in range(self._datalen)) + xc, xd, cdiffs, index, cindices, dindices) + for index in range(self._datalen)) else: - # For each instance calculate distance from all other instances (in non-redundant manner) (i.e. computes triangle, and puts zeros in for rest to form square). - dist_array = [get_row_missing(xc, xd, cdiffs, index, cindices, dindices) - for index in range(self._datalen)] + # For each instance calculate distance from all other + # instances (in non-redundant manner) (i.e. computes triangle, + # and puts zeros in for rest to form square). + dist_array = [ + get_row_missing(xc, xd, cdiffs, index, cindices, dindices) + for index in range(self._datalen)] + + if type(self) == ReliefF: + # Pad the array with zeros back to square form so we can use + # numba operations later. Numba gets very confused with variable + # sized rows - as the array is interpreted as object type. + for i in range(len(dist_array)): + dist_array[i] = np.pad(dist_array[i], (0, self._datalen - i)) + + return np.array(dist_array, dtype='float64') return np.array(dist_array) - #==================================================================# + # For Iter Relief def _distarray_no_missing_iter(self, xc, xd, weights): - """Distance array calculation for data with no missing values. The 'pdist() function outputs a condense distance array, and squareform() converts this vector-form + """Distance array calculation for data with no missing values. + The 'pdist() function outputs a condense distance array, and + squareform() converts this vector-form distance vector to a square-form, redundant distance matrix. - *This could be a target for saving memory in the future, by not needing to expand to the redundant square-form matrix. """ + *This could be a target for saving memory in the future, by not + needing to expand to the redundant square-form matrix. 
""" from scipy.spatial.distance import pdist, squareform - # ------------------------------------------# def pre_normalize(x): - """Normalizes continuous features so they are in the same range (0 to 1)""" + """Normalizes continuous features so they are in the + same range (0 to 1)""" idx = 0 - # goes through all named features (doesn really need to) this method is only applied to continuous features + # goes through all named features (doesn really need to) + # this method is only applied to continuous features for i in sorted(self.attr.keys()): if self.attr[i][0] == 'discrete': continue @@ -409,48 +577,89 @@ def pre_normalize(x): idx += 1 return x - # ------------------------------------------# - - if self.data_type == 'discrete': # discrete features only + # discrete features only + if self.data_type == 'discrete': return squareform(pdist(self._X, metric='hamming', w=weights)) - elif self.data_type == 'mixed': # mix of discrete and continuous features + # mix of discrete and continuous features + elif self.data_type == 'mixed': d_dist = squareform(pdist(xd, metric='hamming', w=weights)) # Cityblock is also known as Manhattan distance - c_dist = squareform(pdist(pre_normalize(xc), metric='cityblock', w=weights)) + c_dist = squareform( + pdist(pre_normalize(xc), metric='cityblock', w=weights)) return np.add(d_dist, c_dist) / self._num_attributes else: # continuous features only # xc = pre_normalize(xc) - return squareform(pdist(pre_normalize(xc), metric='cityblock', w=weights)) - - # ==================================================================# + return squareform( + pdist(pre_normalize(xc), metric='cityblock', w=weights)) # For Iterrelief - get_row_missing_iter is called - def _distarray_missing_iter(self, xc, xd, cdiffs, weights): + def _distarray_missing_iter(self, xc, xd, cdiffs, weights, pad_mode=False): """Distance array calculation for data with missing values""" cindices = [] dindices = [] - # Get Boolean mask locating missing values for continuous and discrete features separately. These correspond to xc and xd respectively. + # Get Boolean mask locating missing values for continuous and discrete + # features separately. These correspond to xc and xd respectively. for i in range(self._datalen): cindices.append(np.where(np.isnan(xc[i]))[0]) dindices.append(np.where(np.isnan(xd[i]))[0]) if self.n_jobs != 1: - dist_array = Parallel(n_jobs=self.n_jobs)(delayed(get_row_missing_iter)( - xc, xd, cdiffs, index, cindices, dindices, weights) for index in range(self._datalen)) + dist_array = Parallel(n_jobs=self.n_jobs) + ( + delayed(get_row_missing_iter)( + xc, xd, cdiffs, index, cindices, dindices, weights + ) for index in range(self._datalen) + ) else: - # For each instance calculate distance from all other instances (in non-redundant manner) (i.e. computes triangle, and puts zeros in for rest to form square). - dist_array = [get_row_missing_iter(xc, xd, cdiffs, index, cindices, dindices, weights) - for index in range(self._datalen)] + # For each instance calculate distance from all other instances + # (in non-redundant manner) (i.e. computes triangle, and puts + # zeros in for rest to form square). + dist_array = [get_row_missing_iter( + xc, xd, cdiffs, index, cindices, dindices, weights) + for index in range(self._datalen)] + + if type(self) == ReliefF: + # Pad the array with zeros back to square form so we can use + # numba operations later. Numba gets very confused with variable + # sized rows - as the array is interpreted as object type. 
+ for i in range(len(dist_array)): + dist_array[i] = np.pad(dist_array[i], (0, self._datalen - i)) + + return np.array(dist_array, dtype='float64') return np.array(dist_array) - # ==================================================================# -############################# ReliefF ############################################ + def _find_neighbors_batch(self, instances): + res = [] + for inst in instances: + res.append(self._find_neighbors(inst)) + + return np.array(res) def _find_neighbors(self, inst): - """ Identify k nearest hits and k nearest misses for given instance. This is accomplished differently based on the type of endpoint (i.e. binary, multiclass, and continuous). """ - # Make a vector of distances between target instance (inst) and all others + """ Identify k nearest hits and k nearest misses for given instance. + This is accomplished differently based on the type of endpoint + (i.e. binary, multiclass, and continuous). """ + + if self._class_type == 'binary': + return find_neighbors_binary( + inst, + self._datalen, + self._distance_array, + self.n_neighbors, + self._y + ) + elif self._class_type == 'multiclass': + # TODO - implement speedup + pass + else: + # continuous + # TODO - implement speedup + pass + + # Make a vector of distances between target instance + # (inst) and all others dist_vect = [] for j in range(self._datalen): if inst != j: @@ -464,14 +673,15 @@ def _find_neighbors(self, inst): dist_vect = np.array(dist_vect) - # Identify neighbors------------------------------------------------------- + # Identify neighbors """ NN for Binary Endpoints: """ if self._class_type == 'binary': nn_list = [] match_count = 0 miss_count = 0 for nn_index in np.argsort(dist_vect): - if self._y[inst] == self._y[nn_index]: # Hit neighbor identified + if self._y[inst] == self._y[nn_index]: + # Hit neighbor identified if match_count >= self.n_neighbors: continue nn_list.append(nn_index) @@ -482,7 +692,8 @@ def _find_neighbors(self, inst): nn_list.append(nn_index) miss_count += 1 - if match_count >= self.n_neighbors and miss_count >= self.n_neighbors: + if match_count >= self.n_neighbors and \ + miss_count >= self.n_neighbors: break elif self._class_type == 'multiclass': @@ -490,7 +701,8 @@ def _find_neighbors(self, inst): match_count = 0 miss_count = dict.fromkeys(self._label_list, 0) for nn_index in np.argsort(dist_vect): - if self._y[inst] == self._y[nn_index]: # Hit neighbor identified + if self._y[inst] == self._y[nn_index]: + # Hit neighbor identified if match_count >= self.n_neighbors: continue nn_list.append(nn_index) @@ -503,14 +715,18 @@ def _find_neighbors(self, inst): nn_list.append(nn_index) miss_count[label] += 1 - if match_count >= self.n_neighbors and all(v >= self.n_neighbors for v in miss_count.values()): + if match_count >= self.n_neighbors and \ + all( + v >= self.n_neighbors for v in miss_count.values() + ): break else: nn_list = [] match_count = 0 miss_count = 0 for nn_index in np.argsort(dist_vect): - if abs(self._y[inst]-self._y[nn_index]) < self._labels_std: # Hit neighbor identified + if abs(self._y[inst]-self._y[nn_index]) < self._labels_std: + # Hit neighbor identified if match_count >= self.n_neighbors: continue nn_list.append(nn_index) @@ -521,31 +737,88 @@ def _find_neighbors(self, inst): nn_list.append(nn_index) miss_count += 1 - if match_count >= self.n_neighbors and miss_count >= self.n_neighbors: + if match_count >= self.n_neighbors and \ + miss_count >= self.n_neighbors: break return np.array(nn_list) def _run_algorithm(self): - """ Runs nearest 
neighbor (NN) identification and feature scoring to yield ReliefF scores. """ + """ Runs nearest neighbor (NN) identification and feature scoring + to yield ReliefF scores. """ # Find nearest neighbors - NNlist = map(self._find_neighbors, range(self._datalen)) + if self.n_jobs != 1: + # we'll use this as an approx number of chunks for now + # chunking here seems to work better in testing - not clear why it + # does better than Parallel's built-in chunking + num_chunks = cpu_count() * 2 + inst_chunks = list(chunks(range(self._datalen), num_chunks)) + NNlist = np.concatenate( + Parallel(n_jobs=self.n_jobs)( + delayed(self._find_neighbors_batch)(chunk) + for chunk in inst_chunks + ) + ) + else: + NNlist = map(self._find_neighbors, range(self._datalen)) - # Feature scoring - using identified nearest neighbors nan_entries = np.isnan(self._X) # boolean mask for missing data values - - # Call the scoring method for the ReliefF algorithm if isinstance(self._weights, np.ndarray) and self.weight_final_scores: # Call the scoring method for the ReliefF algorithm for IRelief + scores = np.sum(Parallel(n_jobs=self.n_jobs)(delayed( - ReliefF_compute_scores)(instance_num, self.attr, nan_entries, self._num_attributes, self.mcmap, - NN, self._headers, self._class_type, self._X, self._y, self._labels_std, self.data_type, self._weights) - for instance_num, NN in zip(range(self._datalen), NNlist)), axis=0) + ReliefF_compute_scores)( + instance_num, + self.attr, + nan_entries, + self._num_attributes, + self.mcmap, + NN, + self._headers, + self._class_type, + self._X, + self._y, + self._labels_std, + self.data_type, + self._weights + ) for instance_num, NN in zip( + range(self._datalen), + NNlist) + ), axis=0) else: # Call the scoring method for the ReliefF algorithm - scores = np.sum(Parallel(n_jobs=self.n_jobs)(delayed( - ReliefF_compute_scores)(instance_num, self.attr, nan_entries, self._num_attributes, self.mcmap, - NN, self._headers, self._class_type, self._X, self._y, self._labels_std, self.data_type) - for instance_num, NN in zip(range(self._datalen), NNlist)), axis=0) + + if self._class_type == 'binary' and \ + all([x[0] == 'discrete' for x in self.attr.values()]): + # use optimized function + scores = np.sum(Parallel(n_jobs=self.n_jobs)( + delayed(compute_score_binary)( + self._X[instance_num], + self._X[NNs], + self._y[instance_num], + self._y[NNs] + ) for instance_num, NNs in zip( + range(self._datalen), + NNlist) + ), axis=0) / (self._datalen * self.n_neighbors * 2) + else: + scores = np.sum(Parallel(n_jobs=self.n_jobs)( + delayed(ReliefF_compute_scores)( + instance_num, + self.attr, + nan_entries, + self._num_attributes, + self.mcmap, + NN, + self._headers, + self._class_type, + self._X, + self._y, + self._labels_std, + self.data_type + ) for instance_num, NN in zip( + range(self._datalen), + NNlist) + ), axis=0) return np.array(scores) diff --git a/skrebate/scoring_utils.py b/skrebate/scoring_utils.py index fd98949..0e12169 100644 --- a/skrebate/scoring_utils.py +++ b/skrebate/scoring_utils.py @@ -223,7 +223,7 @@ def compute_score(attr, mcmap, NN, feature, inst, nan_entries, headers, class_ty count_hit += 1 if ftype == 'continuous': - #diff_hit -= abs(xinstfeature - xNNifeature) / mmdiff #Hits differently add continuous value differences rather than subtract them + #diff_hit -= abs(xinstfeature - xNNifeature) / mmdiff #Hits differently add continuous value differences rather than subtract them diff_hit -= (1-ramp_function(data_type, attr, fname, xinstfeature, xNNifeature)) #Sameness should yield 
most negative score else: #discrete feature if xinstfeature == xNNifeature: # The same feature value is observed (Used for more efficient 'far' scoring, since there should be fewer same values for 'far' instances) @@ -231,7 +231,7 @@ def compute_score(attr, mcmap, NN, feature, inst, nan_entries, headers, class_ty else: # MISS count_miss += 1 if ftype == 'continuous': - #diff_miss += abs(xinstfeature - xNNifeature) / mmdiff #Misses differntly subtract continuous value differences rather than add them + #diff_miss += abs(xinstfeature - xNNifeature) / mmdiff #Misses differntly subtract continuous value differences rather than add them diff_miss += (1-ramp_function(data_type, attr, fname, xinstfeature, xNNifeature)) #Sameness should yield most negative score else: #discrete feature if xinstfeature == xNNifeature: # The same feature value is observed (Used for more efficient 'far' scoring, since there should be fewer same values for 'far' instances) @@ -271,7 +271,7 @@ def compute_score(attr, mcmap, NN, feature, inst, nan_entries, headers, class_ty count_hit += 1 if ftype == 'continuous': #diff_hit -= abs(xinstfeature - xNNifeature) / mmdiff - diff_hit -= ramp_function(data_type, attr, fname, xinstfeature, xNNifeature) + diff_hit -= ramp_function(data_type, attr, fname, xinstfeature, xNNifeature) else: #discrete feature if xinstfeature != xNNifeature: # Feature score is reduced when we observe feature difference between 'near' instances with the same class. @@ -293,7 +293,7 @@ def compute_score(attr, mcmap, NN, feature, inst, nan_entries, headers, class_ty if(y[inst] == y[NN[i]]): # HIT count_hit += 1 if ftype == 'continuous': - #diff_hit -= abs(xinstfeature - xNNifeature) / mmdiff #Hits differently add continuous value differences rather than subtract them + #diff_hit -= abs(xinstfeature - xNNifeature) / mmdiff #Hits differently add continuous value differences rather than subtract them diff_hit -= (1-ramp_function(data_type, attr, fname, xinstfeature, xNNifeature)) #Sameness should yield most negative score else: #discrete features if xinstfeature == xNNifeature: @@ -329,7 +329,7 @@ def compute_score(attr, mcmap, NN, feature, inst, nan_entries, headers, class_ty for each in class_store: #multiclass normalization diff += class_store[each][1] * (class_store[each][0] / count_miss) * len(class_store)# Contribution of given miss class weighted by it's observed frequency within NN set. diff = diff / count_miss #'m' normalization - + #Hit component: with 'h' normalization if count_hit == 0: pass diff --git a/skrebate/surf.py b/skrebate/surf.py index e40c6e0..35fe637 100644 --- a/skrebate/surf.py +++ b/skrebate/surf.py @@ -78,6 +78,7 @@ def _find_neighbors(self, inst, avg_dist): if i > inst: locator.reverse() d = self._distance_array[locator[0]][locator[1]] + if d < avg_dist: # Defining the neighborhood with an average distance radius. 
min_indicies.append(i) for i in range(len(min_indicies)): @@ -104,6 +105,7 @@ def _run_algorithm(self): scores = np.sum(Parallel(n_jobs=self.n_jobs)(delayed( SURF_compute_scores)(instance_num, self.attr, nan_entries, self._num_attributes, self.mcmap, NN, self._headers, self._class_type, self._X, self._y, self._labels_std,self.data_type) - for instance_num, NN in zip(range(self._datalen), NNlist)),axis=0) + for instance_num, NN in zip(range(self._datalen), NNlist)),axis=0) + return np.array(scores) diff --git a/tests.py b/tests.py index 3ba82a4..588798b 100644 --- a/tests.py +++ b/tests.py @@ -26,7 +26,7 @@ from skrebate import ReliefF, SURF, SURFstar, MultiSURF, MultiSURFstar from sklearn.pipeline import make_pipeline from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor -from sklearn.preprocessing import Imputer +from sklearn.impute import SimpleImputer from sklearn.model_selection import cross_val_score import pandas as pd import numpy as np @@ -290,7 +290,7 @@ def test_relieff_pipeline_multiclass(): np.random.seed(49082) clf = make_pipeline(ReliefF(n_features_to_select=2, n_neighbors=10), - Imputer(), + SimpleImputer(), RandomForestClassifier(n_estimators=100, n_jobs=-1)) assert np.mean(cross_val_score(clf, features_multiclass, @@ -302,7 +302,7 @@ def test_surf_pipeline_multiclass(): np.random.seed(240932) clf = make_pipeline(SURF(n_features_to_select=2), - Imputer(), + SimpleImputer(), RandomForestClassifier(n_estimators=100, n_jobs=-1)) assert np.mean(cross_val_score(clf, features_multiclass, @@ -314,7 +314,7 @@ def test_surfstar_pipeline_multiclass(): np.random.seed(9238745) clf = make_pipeline(SURFstar(n_features_to_select=2), - Imputer(), + SimpleImputer(), RandomForestClassifier(n_estimators=100, n_jobs=-1)) assert np.mean(cross_val_score(clf, features_multiclass, @@ -326,7 +326,7 @@ def test_multisurfstar_pipeline_multiclass(): np.random.seed(320931) clf = make_pipeline(MultiSURFstar(n_features_to_select=2), - Imputer(), + SimpleImputer(), RandomForestClassifier(n_estimators=100, n_jobs=-1)) assert np.mean(cross_val_score(clf, features_multiclass, @@ -338,7 +338,7 @@ def test_multisurf_pipeline_multiclass(): np.random.seed(320931) clf = make_pipeline(MultiSURF(n_features_to_select=2), - Imputer(), + SimpleImputer(), RandomForestClassifier(n_estimators=100, n_jobs=-1)) assert np.mean(cross_val_score(clf, features_multiclass, @@ -466,7 +466,7 @@ def test_relieff_pipeline_missing_values(): np.random.seed(49082) clf = make_pipeline(ReliefF(n_features_to_select=2, n_neighbors=10), - Imputer(), + SimpleImputer(), RandomForestClassifier(n_estimators=100, n_jobs=-1)) assert np.mean(cross_val_score(clf, features_missing_values, @@ -478,7 +478,7 @@ def test_surf_pipeline_missing_values(): np.random.seed(240932) clf = make_pipeline(SURF(n_features_to_select=2), - Imputer(), + SimpleImputer(), RandomForestClassifier(n_estimators=100, n_jobs=-1)) assert np.mean(cross_val_score(clf, features_missing_values, @@ -490,7 +490,7 @@ def test_surfstar_pipeline_missing_values(): np.random.seed(9238745) clf = make_pipeline(SURFstar(n_features_to_select=2), - Imputer(), + SimpleImputer(), RandomForestClassifier(n_estimators=100, n_jobs=-1)) assert np.mean(cross_val_score(clf, features_missing_values, @@ -502,7 +502,7 @@ def test_multisurfstar_pipeline_missing_values(): np.random.seed(320931) clf = make_pipeline(MultiSURFstar(n_features_to_select=2), - Imputer(), + SimpleImputer(), RandomForestClassifier(n_estimators=100, n_jobs=-1)) assert np.mean(cross_val_score(clf, 
features_missing_values, @@ -514,7 +514,7 @@ def test_multisurf_pipeline_missing_values(): np.random.seed(320931) clf = make_pipeline(MultiSURF(n_features_to_select=2), - Imputer(), + SimpleImputer(), RandomForestClassifier(n_estimators=100, n_jobs=-1)) assert np.mean(cross_val_score(clf, features_missing_values,
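
For readers skimming the patch, the following is a minimal, self-contained sketch (not part of the diff above) of the vectorized hit/miss scoring that the new compute_score_binary helper in skrebate/relieff.py performs for binary endpoints with all-discrete features. The toy arrays and the helper name compute_score_binary_sketch are invented for illustration only.

import numpy as np

def compute_score_binary_sketch(inst_x, nns_x, inst_y, nns_y):
    # 1.0 wherever a neighbor's feature value differs from the target instance
    diffs = (inst_x != nns_x).astype(float)
    # -1 for hits (same class as the target), +1 for misses (different class)
    signs = np.where(nns_y == inst_y, -1.0, 1.0).reshape(-1, 1)
    # hits subtract one per differing feature, misses add one per differing feature
    return np.nansum(diffs * signs, axis=0)

# Toy data: one target instance, two hit neighbors, two miss neighbors.
inst_x = np.array([0, 1, 2])
nns_x = np.array([[0, 1, 2],    # hit, identical to the target
                  [0, 0, 2],    # hit, differs on feature 1
                  [1, 1, 2],    # miss, differs on feature 0
                  [1, 0, 0]])   # miss, differs on every feature
inst_y = 0
nns_y = np.array([0, 0, 1, 1])
print(compute_score_binary_sketch(inst_x, nns_x, inst_y, nns_y))  # [2. 0. 1.]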