Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a bug in the old scoring API when introducing the new API in TPOT 0.9.1 #626

Merged
merged 11 commits into from
Nov 28, 2017
Merged
1 change: 0 additions & 1 deletion tests/export_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def test_export_random_ind():
exported_pipeline.fit(training_features, training_target)
results = exported_pipeline.predict(testing_features)
"""
print(export_pipeline(pipeline, tpot_obj.operators, tpot_obj._pset))
assert expected_code == export_pipeline(pipeline, tpot_obj.operators, tpot_obj._pset)


Expand Down
31 changes: 28 additions & 3 deletions tests/tpot_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from sklearn.datasets import load_digits, load_boston
from sklearn.model_selection import train_test_split, cross_val_score, GroupKFold
from sklearn.externals.joblib import Memory
from sklearn.metrics import make_scorer
from sklearn.metrics import make_scorer, roc_auc_score
from deap import creator, gp
from deap.tools import ParetoFront
from nose.tools import assert_raises, assert_not_equal, assert_greater_equal, assert_equal, assert_in
Expand Down Expand Up @@ -171,6 +171,29 @@ def my_scorer(clf, X, y):
assert tpot_obj.scoring_function == 'my_scorer'


def test_init_default_scoring_5():
"""Assert that TPOT intitializes with a valid sklearn metric function roc_auc_score."""
with warnings.catch_warnings(record=True) as w:
tpot_obj = TPOTClassifier(scoring=roc_auc_score)
assert len(w) == 1
assert issubclass(w[-1].category, DeprecationWarning)
assert "This scoring type was deprecated" in str(w[-1].message)
assert tpot_obj.scoring_function == 'roc_auc_score'


def test_init_default_scoring_6():
"""Assert that TPOT intitializes with a valid customized metric function in __main__"""
def my_scorer(y_true, y_pred):
return roc_auc_score(y_true, y_pred)
with warnings.catch_warnings(record=True) as w:
tpot_obj = TPOTClassifier(scoring=my_scorer)
assert len(w) == 1
assert issubclass(w[-1].category, DeprecationWarning)
assert "This scoring type was deprecated" in str(w[-1].message)
print(tpot_obj.scoring_function)
assert tpot_obj.scoring_function == 'my_scorer'


def test_invalid_score_warning():
"""Assert that the TPOT intitializes raises a ValueError when the scoring metrics is not available in SCORERS."""
# Mis-spelled scorer
Expand Down Expand Up @@ -1757,7 +1780,8 @@ def test_mutNodeReplacement():
if new_prims_list == old_prims_list: # Terminal mutated
assert new_ret_type_list == old_ret_type_list
else: # Primitive mutated
diff_prims = list(set(new_prims_list).symmetric_difference(old_prims_list))
diff_prims = [x for x in new_prims_list if x not in old_prims_list]
diff_prims += [x for x in old_prims_list if x not in new_prims_list]
if len(diff_prims) > 1: # Sometimes mutation randomly replaces an operator that already in the pipelines
assert diff_prims[0].ret == diff_prims[1].ret
assert mut_ind[0][0].ret == Output_Array
Expand Down Expand Up @@ -1795,7 +1819,8 @@ def test_mutNodeReplacement_2():
if isinstance(node, gp.Primitive):
Primitive_Count += 1
assert Primitive_Count == 4
diff_prims = list(set(new_prims_list).symmetric_difference(old_prims_list))
diff_prims = [x for x in new_prims_list if x not in old_prims_list]
diff_prims += [x for x in old_prims_list if x not in new_prims_list]
if len(diff_prims) > 1: # Sometimes mutation randomly replaces an operator that already in the pipelines
assert diff_prims[0].ret == diff_prims[1].ret
assert mut_ind[0][0].ret == Output_Array
Expand Down
19 changes: 14 additions & 5 deletions tpot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,17 @@ def _setup_scoring_function(self, scoring):
elif callable(scoring):
# Heuristic to ensure user has not passed a metric
module = getattr(scoring, '__module__', None)
if hasattr(module, 'startswith') and \
if sys.version_info[0] < 3:
if inspect.isfunction(scoring):
args_list = inspect.getargspec(scoring)[0]
else:
args_list = inspect.getargspec(scoring.__call__)[0]
else:
args_list = inspect.getfullargspec(scoring)[0]
if args_list == ["y_true", "y_pred"] or (hasattr(module, 'startswith') and \
(module.startswith('sklearn.metrics.') or module.startswith('tpot.metrics')) and \
not module.startswith('sklearn.metrics.scorer') and \
not module.startswith('sklearn.metrics.tests.'):
not module.startswith('sklearn.metrics.tests.')):
scoring_name = scoring.__name__
greater_is_better = 'loss' not in scoring_name and 'error' not in scoring_name
SCORERS[scoring_name] = make_scorer(scoring, greater_is_better=greater_is_better)
Expand Down Expand Up @@ -481,8 +488,10 @@ def _add_terminals(self):
self._pset.addTerminal(val, _type, name=terminal_name)

def _setup_toolbox(self):
creator.create('FitnessMulti', base.Fitness, weights=(-1.0, 1.0))
creator.create('Individual', gp.PrimitiveTree, fitness=creator.FitnessMulti, statistics=dict)
with warnings.catch_warnings():
warnings.simplefilter('ignore')
creator.create('FitnessMulti', base.Fitness, weights=(-1.0, 1.0))
creator.create('Individual', gp.PrimitiveTree, fitness=creator.FitnessMulti, statistics=dict)

self._toolbox = base.Toolbox()
self._toolbox.register('expr', self._gen_grow_safe, pset=self._pset, min_=1, max_=3)
Expand Down Expand Up @@ -946,7 +955,7 @@ def _create_periodic_checkpoint_folder(self):
if e.errno == errno.EEXIST and os.path.isdir(self.periodic_checkpoint_folder):
pass # Folder already exists. User probably created it.
else:
raise ValueError('Failed creating the periodic_checkpoint_folder:\n{}'.format(e))
raise ValueError('Failed creating the periodic_checkpoint_folder:\n{}'.format(e))

def export(self, output_file_name, skip_if_repeated=False):
"""Export the optimized pipeline as Python code.
Expand Down