From 3dc40701a4b4b0fbb23b191ad996412072251be5 Mon Sep 17 00:00:00 2001 From: weixuanfu2016 Date: Tue, 14 Nov 2017 14:16:36 -0500 Subject: [PATCH 01/10] temp fix --- tests/tpot_tests.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/tpot_tests.py b/tests/tpot_tests.py index 6bcaadae..e964a567 100644 --- a/tests/tpot_tests.py +++ b/tests/tpot_tests.py @@ -144,11 +144,11 @@ def test_init_default_scoring(): def test_init_default_scoring_2(): """Assert that TPOT intitializes with a valid customized metric function.""" - with warnings.catch_warnings(record=True) as w: - tpot_obj = TPOTClassifier(scoring=balanced_accuracy) - assert len(w) == 1 - assert issubclass(w[-1].category, DeprecationWarning) - assert "This scoring type was deprecated" in str(w[-1].message) + #with warnings.catch_warnings(record=True) as w: + tpot_obj = TPOTClassifier(scoring=balanced_accuracy) + #assert len(w) == 1 # deap 1.2.2 warning message made this unit test failed + #assert issubclass(w[-1].category, DeprecationWarning) # deap 1.2.2 warning message made this unit test failed + #assert "This scoring type was deprecated" in str(w[-1].message) # deap 1.2.2 warning message made this unit test failed assert tpot_obj.scoring_function == 'balanced_accuracy' @@ -156,7 +156,7 @@ def test_init_default_scoring_3(): """Assert that TPOT intitializes with a valid _BaseScorer.""" with warnings.catch_warnings(record=True) as w: tpot_obj = TPOTClassifier(scoring=make_scorer(balanced_accuracy)) - assert len(w) == 0 + #assert len(w) == 0 # deap 1.2.2 warning message made this unit test failed assert tpot_obj.scoring_function == 'balanced_accuracy' @@ -167,7 +167,7 @@ def my_scorer(clf, X, y): with warnings.catch_warnings(record=True) as w: tpot_obj = TPOTClassifier(scoring=my_scorer) - assert len(w) == 0 + #assert len(w) == 0 # deap 1.2.2 warning message made this unit test failed assert tpot_obj.scoring_function == 'my_scorer' @@ -1757,7 +1757,8 @@ def test_mutNodeReplacement(): if new_prims_list == old_prims_list: # Terminal mutated assert new_ret_type_list == old_ret_type_list else: # Primitive mutated - diff_prims = list(set(new_prims_list).symmetric_difference(old_prims_list)) + diff_prims = [x for x in new_prims_list if x not in old_prims_list] + diff_prims += [x for x in old_prims_list if x not in new_prims_list] if len(diff_prims) > 1: # Sometimes mutation randomly replaces an operator that already in the pipelines assert diff_prims[0].ret == diff_prims[1].ret assert mut_ind[0][0].ret == Output_Array @@ -1795,7 +1796,8 @@ def test_mutNodeReplacement_2(): if isinstance(node, gp.Primitive): Primitive_Count += 1 assert Primitive_Count == 4 - diff_prims = list(set(new_prims_list).symmetric_difference(old_prims_list)) + diff_prims = [x for x in new_prims_list if x not in old_prims_list] + diff_prims += [x for x in old_prims_list if x not in new_prims_list] if len(diff_prims) > 1: # Sometimes mutation randomly replaces an operator that already in the pipelines assert diff_prims[0].ret == diff_prims[1].ret assert mut_ind[0][0].ret == Output_Array From fdfbd64f6a72c28ee5389879365e624f6674e741 Mon Sep 17 00:00:00 2001 From: weixuanfu2016 Date: Wed, 15 Nov 2017 11:39:21 -0500 Subject: [PATCH 02/10] warning catch --- tests/tpot_tests.py | 10 +++++----- tpot/base.py | 8 +++++--- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/tpot_tests.py b/tests/tpot_tests.py index e964a567..6e191ffd 100644 --- a/tests/tpot_tests.py +++ b/tests/tpot_tests.py @@ -146,9 +146,9 @@ def test_init_default_scoring_2(): """Assert that TPOT intitializes with a valid customized metric function.""" #with warnings.catch_warnings(record=True) as w: tpot_obj = TPOTClassifier(scoring=balanced_accuracy) - #assert len(w) == 1 # deap 1.2.2 warning message made this unit test failed - #assert issubclass(w[-1].category, DeprecationWarning) # deap 1.2.2 warning message made this unit test failed - #assert "This scoring type was deprecated" in str(w[-1].message) # deap 1.2.2 warning message made this unit test failed + assert len(w) == 1 # deap 1.2.2 warning message made this unit test failed + assert issubclass(w[-1].category, DeprecationWarning) # deap 1.2.2 warning message made this unit test failed + assert "This scoring type was deprecated" in str(w[-1].message) # deap 1.2.2 warning message made this unit test failed assert tpot_obj.scoring_function == 'balanced_accuracy' @@ -156,7 +156,7 @@ def test_init_default_scoring_3(): """Assert that TPOT intitializes with a valid _BaseScorer.""" with warnings.catch_warnings(record=True) as w: tpot_obj = TPOTClassifier(scoring=make_scorer(balanced_accuracy)) - #assert len(w) == 0 # deap 1.2.2 warning message made this unit test failed + assert len(w) == 0 # deap 1.2.2 warning message made this unit test failed assert tpot_obj.scoring_function == 'balanced_accuracy' @@ -167,7 +167,7 @@ def my_scorer(clf, X, y): with warnings.catch_warnings(record=True) as w: tpot_obj = TPOTClassifier(scoring=my_scorer) - #assert len(w) == 0 # deap 1.2.2 warning message made this unit test failed + assert len(w) == 0 # deap 1.2.2 warning message made this unit test failed assert tpot_obj.scoring_function == 'my_scorer' diff --git a/tpot/base.py b/tpot/base.py index 86f6a9ec..42d81fe3 100644 --- a/tpot/base.py +++ b/tpot/base.py @@ -481,8 +481,10 @@ def _add_terminals(self): self._pset.addTerminal(val, _type, name=terminal_name) def _setup_toolbox(self): - creator.create('FitnessMulti', base.Fitness, weights=(-1.0, 1.0)) - creator.create('Individual', gp.PrimitiveTree, fitness=creator.FitnessMulti, statistics=dict) + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + creator.create('FitnessMulti', base.Fitness, weights=(-1.0, 1.0)) + creator.create('Individual', gp.PrimitiveTree, fitness=creator.FitnessMulti, statistics=dict) self._toolbox = base.Toolbox() self._toolbox.register('expr', self._gen_grow_safe, pset=self._pset, min_=1, max_=3) @@ -946,7 +948,7 @@ def _create_periodic_checkpoint_folder(self): if e.errno == errno.EEXIST and os.path.isdir(self.periodic_checkpoint_folder): pass # Folder already exists. User probably created it. else: - raise ValueError('Failed creating the periodic_checkpoint_folder:\n{}'.format(e)) + raise ValueError('Failed creating the periodic_checkpoint_folder:\n{}'.format(e)) def export(self, output_file_name, skip_if_repeated=False): """Export the optimized pipeline as Python code. From e69bbacdfc296a419e0a5431fc9ced1221c5ae95 Mon Sep 17 00:00:00 2001 From: weixuanfu2016 Date: Wed, 15 Nov 2017 11:40:54 -0500 Subject: [PATCH 03/10] bug fix --- tests/tpot_tests.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tpot_tests.py b/tests/tpot_tests.py index 6e191ffd..79eb6af0 100644 --- a/tests/tpot_tests.py +++ b/tests/tpot_tests.py @@ -144,8 +144,8 @@ def test_init_default_scoring(): def test_init_default_scoring_2(): """Assert that TPOT intitializes with a valid customized metric function.""" - #with warnings.catch_warnings(record=True) as w: - tpot_obj = TPOTClassifier(scoring=balanced_accuracy) + with warnings.catch_warnings(record=True) as w: + tpot_obj = TPOTClassifier(scoring=balanced_accuracy) assert len(w) == 1 # deap 1.2.2 warning message made this unit test failed assert issubclass(w[-1].category, DeprecationWarning) # deap 1.2.2 warning message made this unit test failed assert "This scoring type was deprecated" in str(w[-1].message) # deap 1.2.2 warning message made this unit test failed From 4e2c2dfc6f8f8bd326d986e434a56063cfb170d3 Mon Sep 17 00:00:00 2001 From: weixuanfu2016 Date: Wed, 15 Nov 2017 11:47:37 -0500 Subject: [PATCH 04/10] clean up unit tests --- tests/export_tests.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/export_tests.py b/tests/export_tests.py index b4e2c73f..8ce60ec3 100644 --- a/tests/export_tests.py +++ b/tests/export_tests.py @@ -76,7 +76,6 @@ def test_export_random_ind(): exported_pipeline.fit(training_features, training_target) results = exported_pipeline.predict(testing_features) """ - print(export_pipeline(pipeline, tpot_obj.operators, tpot_obj._pset)) assert expected_code == export_pipeline(pipeline, tpot_obj.operators, tpot_obj._pset) From ee898b73176a04049408dfa30e0bc09013197410 Mon Sep 17 00:00:00 2001 From: weixuanfu2016 Date: Wed, 15 Nov 2017 13:35:40 -0500 Subject: [PATCH 05/10] clean codes --- tests/tpot_tests.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/tpot_tests.py b/tests/tpot_tests.py index 79eb6af0..9f21c17c 100644 --- a/tests/tpot_tests.py +++ b/tests/tpot_tests.py @@ -146,9 +146,9 @@ def test_init_default_scoring_2(): """Assert that TPOT intitializes with a valid customized metric function.""" with warnings.catch_warnings(record=True) as w: tpot_obj = TPOTClassifier(scoring=balanced_accuracy) - assert len(w) == 1 # deap 1.2.2 warning message made this unit test failed - assert issubclass(w[-1].category, DeprecationWarning) # deap 1.2.2 warning message made this unit test failed - assert "This scoring type was deprecated" in str(w[-1].message) # deap 1.2.2 warning message made this unit test failed + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "This scoring type was deprecated" in str(w[-1].message) assert tpot_obj.scoring_function == 'balanced_accuracy' @@ -156,7 +156,7 @@ def test_init_default_scoring_3(): """Assert that TPOT intitializes with a valid _BaseScorer.""" with warnings.catch_warnings(record=True) as w: tpot_obj = TPOTClassifier(scoring=make_scorer(balanced_accuracy)) - assert len(w) == 0 # deap 1.2.2 warning message made this unit test failed + assert len(w) == 0 assert tpot_obj.scoring_function == 'balanced_accuracy' @@ -167,7 +167,7 @@ def my_scorer(clf, X, y): with warnings.catch_warnings(record=True) as w: tpot_obj = TPOTClassifier(scoring=my_scorer) - assert len(w) == 0 # deap 1.2.2 warning message made this unit test failed + assert len(w) == 0 assert tpot_obj.scoring_function == 'my_scorer' From 1f0c13c808957fa5f0831a8f0dedc90a2dd857b3 Mon Sep 17 00:00:00 2001 From: weixuanfu2016 Date: Fri, 17 Nov 2017 11:40:41 -0500 Subject: [PATCH 06/10] fix a bug in scoring api --- tests/tpot_tests.py | 11 ++++++++++- tpot/base.py | 10 +++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/tests/tpot_tests.py b/tests/tpot_tests.py index 6bcaadae..7cc5d71d 100644 --- a/tests/tpot_tests.py +++ b/tests/tpot_tests.py @@ -56,7 +56,7 @@ from sklearn.datasets import load_digits, load_boston from sklearn.model_selection import train_test_split, cross_val_score, GroupKFold from sklearn.externals.joblib import Memory -from sklearn.metrics import make_scorer +from sklearn.metrics import make_scorer, roc_auc_score from deap import creator, gp from deap.tools import ParetoFront from nose.tools import assert_raises, assert_not_equal, assert_greater_equal, assert_equal, assert_in @@ -170,6 +170,15 @@ def my_scorer(clf, X, y): assert len(w) == 0 assert tpot_obj.scoring_function == 'my_scorer' +def test_init_default_scoring_5(): + """Assert that TPOT intitializes with a valid sklearn metric function roc_auc_score.""" + with warnings.catch_warnings(record=True) as w: + tpot_obj = TPOTClassifier(scoring=roc_auc_score) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "This scoring type was deprecated" in str(w[-1].message) + assert tpot_obj.scoring_function == 'roc_auc_score' + def test_invalid_score_warning(): """Assert that the TPOT intitializes raises a ValueError when the scoring metrics is not available in SCORERS.""" diff --git a/tpot/base.py b/tpot/base.py index 86f6a9ec..f9a6bf05 100644 --- a/tpot/base.py +++ b/tpot/base.py @@ -358,10 +358,14 @@ def _setup_scoring_function(self, scoring): elif callable(scoring): # Heuristic to ensure user has not passed a metric module = getattr(scoring, '__module__', None) - if hasattr(module, 'startswith') and \ + if sys.version_info[0] < 3: + args_list = inspect.getargspec(scoring)[0] + else: + args_list = inspect.getfullargspec(scoring)[0] + if args_list == ["y_true", "y_pred"] or (hasattr(module, 'startswith') and \ (module.startswith('sklearn.metrics.') or module.startswith('tpot.metrics')) and \ not module.startswith('sklearn.metrics.scorer') and \ - not module.startswith('sklearn.metrics.tests.'): + not module.startswith('sklearn.metrics.tests.')): scoring_name = scoring.__name__ greater_is_better = 'loss' not in scoring_name and 'error' not in scoring_name SCORERS[scoring_name] = make_scorer(scoring, greater_is_better=greater_is_better) @@ -946,7 +950,7 @@ def _create_periodic_checkpoint_folder(self): if e.errno == errno.EEXIST and os.path.isdir(self.periodic_checkpoint_folder): pass # Folder already exists. User probably created it. else: - raise ValueError('Failed creating the periodic_checkpoint_folder:\n{}'.format(e)) + raise ValueError('Failed creating the periodic_checkpoint_folder:\n{}'.format(e)) def export(self, output_file_name, skip_if_repeated=False): """Export the optimized pipeline as Python code. From ade169e1a588c1a33116990c3fff5aaedaeb885a Mon Sep 17 00:00:00 2001 From: weixuanfu2016 Date: Fri, 17 Nov 2017 11:44:32 -0500 Subject: [PATCH 07/10] add one more unit test --- tests/tpot_tests.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/tpot_tests.py b/tests/tpot_tests.py index 2c3f4a05..40542375 100644 --- a/tests/tpot_tests.py +++ b/tests/tpot_tests.py @@ -170,6 +170,7 @@ def my_scorer(clf, X, y): assert len(w) == 0 assert tpot_obj.scoring_function == 'my_scorer' + def test_init_default_scoring_5(): """Assert that TPOT intitializes with a valid sklearn metric function roc_auc_score.""" with warnings.catch_warnings(record=True) as w: @@ -180,6 +181,19 @@ def test_init_default_scoring_5(): assert tpot_obj.scoring_function == 'roc_auc_score' +def test_init_default_scoring_6(): + """Assert that TPOT intitializes with a valid customized metric function in __main__""" + def my_scorer(y_true, y_pred): + return roc_auc_score(y_true, y_pred) + with warnings.catch_warnings(record=True) as w: + tpot_obj = TPOTClassifier(scoring=my_scorer) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "This scoring type was deprecated" in str(w[-1].message) + print(tpot_obj.scoring_function) + assert tpot_obj.scoring_function == 'my_scorer' + + def test_invalid_score_warning(): """Assert that the TPOT intitializes raises a ValueError when the scoring metrics is not available in SCORERS.""" # Mis-spelled scorer From 7e3dfffcb0e82d5343443a1582fecbd69f43bdf1 Mon Sep 17 00:00:00 2001 From: weixuanfu2016 Date: Fri, 17 Nov 2017 12:07:55 -0500 Subject: [PATCH 08/10] support for py27 --- tpot/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tpot/base.py b/tpot/base.py index 8d909ae0..cfe711be 100644 --- a/tpot/base.py +++ b/tpot/base.py @@ -359,7 +359,10 @@ def _setup_scoring_function(self, scoring): # Heuristic to ensure user has not passed a metric module = getattr(scoring, '__module__', None) if sys.version_info[0] < 3: - args_list = inspect.getargspec(scoring)[0] + if not inspect.isclass(scoring) + args_list = inspect.getargspec(scoring)[0] + else: + args_list = ['NA', 'NA'] else: args_list = inspect.getfullargspec(scoring)[0] if args_list == ["y_true", "y_pred"] or (hasattr(module, 'startswith') and \ From f32aaa3a6e00a3c4cfc240e0dec5f19152dffe38 Mon Sep 17 00:00:00 2001 From: weixuanfu2016 Date: Fri, 17 Nov 2017 12:08:28 -0500 Subject: [PATCH 09/10] clean codes --- tpot/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tpot/base.py b/tpot/base.py index cfe711be..a19bcb0f 100644 --- a/tpot/base.py +++ b/tpot/base.py @@ -359,7 +359,7 @@ def _setup_scoring_function(self, scoring): # Heuristic to ensure user has not passed a metric module = getattr(scoring, '__module__', None) if sys.version_info[0] < 3: - if not inspect.isclass(scoring) + if not inspect.isclass(scoring): args_list = inspect.getargspec(scoring)[0] else: args_list = ['NA', 'NA'] From eea87496fcd2e88b8e940b8285d8606910b04c4a Mon Sep 17 00:00:00 2001 From: weixuanfu2016 Date: Fri, 17 Nov 2017 12:19:39 -0500 Subject: [PATCH 10/10] fix py27 --- tpot/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tpot/base.py b/tpot/base.py index a19bcb0f..2a3bacd1 100644 --- a/tpot/base.py +++ b/tpot/base.py @@ -359,10 +359,10 @@ def _setup_scoring_function(self, scoring): # Heuristic to ensure user has not passed a metric module = getattr(scoring, '__module__', None) if sys.version_info[0] < 3: - if not inspect.isclass(scoring): + if inspect.isfunction(scoring): args_list = inspect.getargspec(scoring)[0] else: - args_list = ['NA', 'NA'] + args_list = inspect.getargspec(scoring.__call__)[0] else: args_list = inspect.getfullargspec(scoring)[0] if args_list == ["y_true", "y_pred"] or (hasattr(module, 'startswith') and \