Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More unit tests #515

Merged
merged 20 commits into from
Jul 17, 2017
130 changes: 127 additions & 3 deletions tests/driver_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import numpy as np

from tpot.driver import positive_integer, float_range, _get_arg_parser, _print_args, _read_data_file, load_scoring_function
from tpot.driver import positive_integer, float_range, _get_arg_parser, _print_args, _read_data_file, load_scoring_function, main
from nose.tools import assert_raises, assert_equal, assert_in
from unittest import TestCase

Expand All @@ -44,11 +44,12 @@ def captured_output():
finally:
sys.stdout, sys.stderr = old_out, old_err


def test_scoring_function_argument():
with captured_output() as (out, err):
# regular argument returns regular string
assert_equal(load_scoring_function("roc_auc"), "roc_auc")

# bad function returns exception
assert_raises(Exception, load_scoring_function, scoring_func="tests.__fake_BAD_FUNC_NAME")

Expand Down Expand Up @@ -80,6 +81,89 @@ def test_driver():
assert ret_val > 0.0


def test_driver_2():
"""Assert that the main() in TPOT driver outputs normal result with verbosity = 1"""
args_list = [
'tests/tests.csv',
'-is', ',',
'-target', 'class',
'-g', '1',
'-p', '2',
'-os', '4',
'-cv', '5',
'-s',' 45',
'-config', 'TPOT light',
'-v', '1'
]
args = _get_arg_parser().parse_args(args_list)
with captured_output() as (out, err):
main(args)
ret_stdout = out.getvalue()

assert "TPOT settings" not in ret_stdout
assert "Final Pareto front testing scores" not in ret_stdout
try:
ret_val = float(ret_stdout.split('\n')[-2].split(': ')[-1])
except Exception:
ret_val = -float('inf')
assert ret_val > 0.0


def test_driver_3():
"""Assert that the main() in TPOT driver outputs normal result with verbosity = 2"""
args_list = [
'tests/tests.csv',
'-is', ',',
'-target', 'class',
'-g', '1',
'-p', '2',
'-os', '4',
'-cv', '5',
'-s',' 45',
'-config', 'TPOT light',
'-v', '2'
]
args = _get_arg_parser().parse_args(args_list)
with captured_output() as (out, err):
main(args)
ret_stdout = out.getvalue()
assert "TPOT settings" in ret_stdout
assert "Final Pareto front testing scores" not in ret_stdout
try:
ret_val = float(ret_stdout.split('\n')[-2].split(': ')[-1])
except Exception:
ret_val = -float('inf')
assert ret_val > 0.0


def test_driver_4():
"""Assert that the main() in TPOT driver outputs normal result with verbosity = 3"""
args_list = [
'tests/tests.csv',
'-is', ',',
'-target', 'class',
'-g', '1',
'-p', '2',
'-os', '4',
'-cv', '5',
'-s', '42',
'-config', 'TPOT light',
'-v', '3'
]
args = _get_arg_parser().parse_args(args_list)
with captured_output() as (out, err):
main(args)
ret_stdout = out.getvalue()

assert "TPOT settings" in ret_stdout
assert "Final Pareto front testing scores" in ret_stdout
try:
ret_val = float(ret_stdout.split('\n')[-2].split('\t')[1])
except Exception:
ret_val = -float('inf')
assert ret_val > 0.0


def test_read_data_file():
"""Assert that _read_data_file raises ValueError when the targe column is missing."""
# Mis-spelled target
Expand Down Expand Up @@ -107,6 +191,7 @@ class ParserTest(TestCase):
def setUp(self):
self.parser = _get_arg_parser()


def test_default_param(self):
"""Assert that the TPOT driver stores correct default values for all parameters."""
args = self.parser.parse_args(['tests/tests.csv'])
Expand All @@ -129,8 +214,9 @@ def test_default_param(self):
self.assertEqual(args.TPOT_MODE, 'classification')
self.assertEqual(args.VERBOSITY, 1)


def test_print_args(self):
"""Assert that _print_args prints correct values for all parameters."""
"""Assert that _print_args prints correct values for all parameters in default settings"""
args = self.parser.parse_args(['tests/tests.csv'])
with captured_output() as (out, err):
_print_args(args)
Expand Down Expand Up @@ -158,6 +244,44 @@ def test_print_args(self):
TPOT_MODE\t=\tclassification
VERBOSITY\t=\t1

"""

self.assertEqual(_sort_lines(expected_output), _sort_lines(output))


def test_print_args_2(self):
"""Assert that _print_args prints correct values for all parameters in regression mode."""
args_list = [
'tests/tests.csv',
'-mode', 'regression',
]
args = self.parser.parse_args(args_list)
with captured_output() as (out, err):
_print_args(args)
output = out.getvalue()
expected_output = """
TPOT settings:
CONFIG_FILE\t=\tNone
CROSSOVER_RATE\t=\t0.1
GENERATIONS\t=\t100
INPUT_FILE\t=\ttests/tests.csv
INPUT_SEPARATOR\t=\t\t
MAX_EVAL_MINS\t=\t5
MAX_TIME_MINS\t=\tNone
MUTATION_RATE\t=\t0.9
NUM_CV_FOLDS\t=\t5
NUM_JOBS\t=\t1
OFFSPRING_SIZE\t=\t100
CHECKPOINT_FOLDER\t=\tNone
OUTPUT_FILE\t=\t
POPULATION_SIZE\t=\t100
RANDOM_STATE\t=\tNone
SCORING_FN\t=\tneg_mean_squared_error
SUBSAMPLE\t=\t1.0
TARGET_NAME\t=\tclass
TPOT_MODE\t=\tregression
VERBOSITY\t=\t1

"""

self.assertEqual(_sort_lines(expected_output), _sort_lines(output))
Expand Down
25 changes: 24 additions & 1 deletion tests/tpot_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_init_default_scoring_2():
assert tpot_obj.scoring_function == 'balanced_accuracy'


def test_invaild_score_warning():
def test_invalid_score_warning():
"""Assert that the TPOT intitializes raises a ValueError when the scoring metrics is not available in SCORERS."""
# Mis-spelled scorer
assert_raises(ValueError, TPOTClassifier, scoring='balanced_accuray')
Expand Down Expand Up @@ -199,6 +199,29 @@ def test_timeout():
assert return_value == "Timeout"


def test_invaild_pipeline():
"""Assert that _wrapped_cross_val_score return -float(\'inf\') with a invalid_pipeline"""
tpot_obj = TPOTClassifier()
# a invalid pipeline
# Dual or primal formulation. Dual formulation is only implemented for l2 penalty.
pipeline_string = (
'LogisticRegression(input_matrix, LogisticRegression__C=10.0, '
'LogisticRegression__dual=True, LogisticRegression__penalty=l1)'
)
tpot_obj._optimized_pipeline = creator.Individual.from_string(pipeline_string, tpot_obj._pset)
tpot_obj.fitted_pipeline_ = tpot_obj._toolbox.compile(expr=tpot_obj._optimized_pipeline)
# test _wrapped_cross_val_score with cv=20 so that it is impossible to finish in 1 second
return_value = _wrapped_cross_val_score(tpot_obj.fitted_pipeline_,
training_features,
training_target,
cv=5,
scoring_function='accuracy',
sample_weight=None,
groups=None,
timeout=300)
assert return_value == -float('inf')


def test_balanced_accuracy():
"""Assert that the balanced_accuracy in TPOT returns correct accuracy."""
y_true = np.array([1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4])
Expand Down
12 changes: 6 additions & 6 deletions tpot/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,18 +442,17 @@ def load_scoring_function(scoring_func):
sys.path.insert(0, module_path)
scoring_func = getattr(import_module(module_name), func_name)
sys.path.pop(0)

print('manual scoring function: {}'.format(scoring_func))
print('taken from module: {}'.format(module_name))
except Exception as e:
print('failed importing custom scoring function, error: {}'.format(str(e)))
raise ValueError(e)

return scoring_func

def main():
def main(args):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change will break the CLI.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CLI issue is fixed in the commits below

"""Perform a TPOT run."""
args = _get_arg_parser().parse_args()
if args.VERBOSITY >= 2:
_print_args(args)

Expand All @@ -464,7 +463,7 @@ def main():
axis=1
)


scoring_func = load_scoring_function(args.SCORING_FN)

training_features, testing_features, training_target, testing_target = \
Expand Down Expand Up @@ -513,4 +512,5 @@ def main():


if __name__ == '__main__':
main()
args = _get_arg_parser().parse_args()
main(args)