More unit tests #515

Merged (20 commits) on Jul 17, 2017
131 changes: 128 additions & 3 deletions tests/driver_tests.py
@@ -29,7 +29,7 @@

import numpy as np

from tpot.driver import positive_integer, float_range, _get_arg_parser, _print_args, _read_data_file, load_scoring_function
from tpot.driver import positive_integer, float_range, _get_arg_parser, _print_args, _read_data_file, load_scoring_function, tpot_driver
from nose.tools import assert_raises, assert_equal, assert_in
from unittest import TestCase

@@ -44,11 +44,12 @@ def captured_output():
finally:
sys.stdout, sys.stderr = old_out, old_err


def test_scoring_function_argument():
with captured_output() as (out, err):
# regular argument returns regular string
assert_equal(load_scoring_function("roc_auc"), "roc_auc")

# bad function returns exception
assert_raises(Exception, load_scoring_function, scoring_func="tests.__fake_BAD_FUNC_NAME")

@@ -75,6 +76,90 @@ def test_driver():
ret_stdout = subprocess.check_output(batcmd, shell=True)
try:
ret_val = float(ret_stdout.decode('UTF-8').split('\n')[-2].split(': ')[-1])

except Exception as e:
ret_val = -float('inf')
assert ret_val > 0.0


def test_driver_2():
"""Assert that the tpot_driver() in TPOT driver outputs normal result with verbosity = 1."""
args_list = [
'tests/tests.csv',
'-is', ',',
'-target', 'class',
'-g', '1',
'-p', '2',
'-os', '4',
'-cv', '5',
'-s', '45',
'-config', 'TPOT light',
'-v', '1'
]
args = _get_arg_parser().parse_args(args_list)
with captured_output() as (out, err):
tpot_driver(args)
ret_stdout = out.getvalue()

assert "TPOT settings" not in ret_stdout
assert "Final Pareto front testing scores" not in ret_stdout
try:
ret_val = float(ret_stdout.split('\n')[-2].split(': ')[-1])
except Exception:
ret_val = -float('inf')
assert ret_val > 0.0


def test_driver_3():
"""Assert that the tpot_driver() in TPOT driver outputs normal result with verbosity = 2."""
args_list = [
'tests/tests.csv',
'-is', ',',
'-target', 'class',
'-g', '1',
'-p', '2',
'-os', '4',
'-cv', '5',
'-s', '45',
'-config', 'TPOT light',
'-v', '2'
]
args = _get_arg_parser().parse_args(args_list)
with captured_output() as (out, err):
tpot_driver(args)
ret_stdout = out.getvalue()
assert "TPOT settings" in ret_stdout
assert "Final Pareto front testing scores" not in ret_stdout
try:
ret_val = float(ret_stdout.split('\n')[-2].split(': ')[-1])
except Exception:
ret_val = -float('inf')
assert ret_val > 0.0


def test_driver_4():
"""Assert that the tpot_driver() in TPOT driver outputs normal result with verbosity = 3."""
args_list = [
'tests/tests.csv',
'-is', ',',
'-target', 'class',
'-g', '1',
'-p', '2',
'-os', '4',
'-cv', '5',
'-s', '42',
'-config', 'TPOT light',
'-v', '3'
]
args = _get_arg_parser().parse_args(args_list)
with captured_output() as (out, err):
tpot_driver(args)
ret_stdout = out.getvalue()

assert "TPOT settings" in ret_stdout
assert "Final Pareto front testing scores" in ret_stdout
try:
ret_val = float(ret_stdout.split('\n')[-2].split('\t')[1])
except Exception:
ret_val = -float('inf')
assert ret_val > 0.0
@@ -107,6 +192,7 @@ class ParserTest(TestCase):
def setUp(self):
self.parser = _get_arg_parser()


def test_default_param(self):
"""Assert that the TPOT driver stores correct default values for all parameters."""
args = self.parser.parse_args(['tests/tests.csv'])
@@ -129,8 +215,9 @@ def test_default_param(self):
self.assertEqual(args.TPOT_MODE, 'classification')
self.assertEqual(args.VERBOSITY, 1)


def test_print_args(self):
"""Assert that _print_args prints correct values for all parameters."""
"""Assert that _print_args prints correct values for all parameters in default settings."""
args = self.parser.parse_args(['tests/tests.csv'])
with captured_output() as (out, err):
_print_args(args)
@@ -158,6 +245,44 @@ def test_print_args(self):
TPOT_MODE\t=\tclassification
VERBOSITY\t=\t1

"""

self.assertEqual(_sort_lines(expected_output), _sort_lines(output))


def test_print_args_2(self):
"""Assert that _print_args prints correct values for all parameters in regression mode."""
args_list = [
'tests/tests.csv',
'-mode', 'regression',
]
args = self.parser.parse_args(args_list)
with captured_output() as (out, err):
_print_args(args)
output = out.getvalue()
expected_output = """
TPOT settings:
CONFIG_FILE\t=\tNone
CROSSOVER_RATE\t=\t0.1
GENERATIONS\t=\t100
INPUT_FILE\t=\ttests/tests.csv
INPUT_SEPARATOR\t=\t\t
MAX_EVAL_MINS\t=\t5
MAX_TIME_MINS\t=\tNone
MUTATION_RATE\t=\t0.9
NUM_CV_FOLDS\t=\t5
NUM_JOBS\t=\t1
OFFSPRING_SIZE\t=\t100
CHECKPOINT_FOLDER\t=\tNone
OUTPUT_FILE\t=\t
POPULATION_SIZE\t=\t100
RANDOM_STATE\t=\tNone
SCORING_FN\t=\tneg_mean_squared_error
SUBSAMPLE\t=\t1.0
TARGET_NAME\t=\tclass
TPOT_MODE\t=\tregression
VERBOSITY\t=\t1

"""

self.assertEqual(_sort_lines(expected_output), _sort_lines(output))
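Note: the new driver tests above capture TPOT's console output through the captured_output() helper already defined near the top of tests/driver_tests.py. A minimal sketch of that pattern, shown here only for reference and not guaranteed to match the upstream implementation exactly:

import sys
from contextlib import contextmanager
from io import StringIO

@contextmanager
def captured_output():
    # Temporarily redirect stdout/stderr so a test can inspect what was printed.
    new_out, new_err = StringIO(), StringIO()
    old_out, old_err = sys.stdout, sys.stderr
    try:
        sys.stdout, sys.stderr = new_out, new_err
        yield new_out, new_err
    finally:
        sys.stdout, sys.stderr = old_out, old_err
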
35 changes: 35 additions & 0 deletions tests/test_config.py.bad
@@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
Contributor:

Please rename this file to test_config_bad.py.

Contributor Author:

That would cause an error during unit tests since it is not in the right format as Python code. Check this link

Contributor (@rhiever, Jun 28, 2017):

That's confusing. The CI systems are running that file? They shouldn't be?

Contributor Author (@weixuanfu, Jun 28, 2017):

nosetests -s -v checks all the .py files in the tests folder.


"""Copyright 2015-Present Randal S. Olson.

This file is part of the TPOT library.

TPOT is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as
published by the Free Software Foundation, either version 3 of
the License, or (at your option) any later version.

TPOT is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public
License along with TPOT. If not, see <http://www.gnu.org/licenses/>.

"""

tpot_config = {
'sklearn.naive_bayes.GaussianNB': {
},

'sklearn.naive_bayes.BernoulliNB': {
'alpha': [1e-3, 1e-2, 1e-1, 1., 10., 100.],
'fit_prior': [True, False]
},

'sklearn.naive_bayes.MultinomialNB': {
'alpha': [1e-3, 1e-2, 1e-1, 1., 10., 100.]  # deliberately missing a "," here
'fit_prior': [True, False]
}
}
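As the review thread above explains, this file keeps the .py.bad suffix so that nosetests does not try to import it as a test module; it is only read by path in test_read_config_file_3. The sketch below is an approximation of what _read_config_file in tpot/base.py does with it (the helper name here is illustrative, not the actual TPOT code): the missing comma triggers a SyntaxError, which is re-raised as a ValueError.

import imp  # the same (deprecated) module tpot/base.py uses

def read_config_file(config_path):
    # Sketch: execute a candidate configuration file and return its tpot_config dict.
    try:
        custom_config = imp.new_module('custom_config')
        with open(config_path, 'r') as config_file:
            # tests/test_config.py.bad fails here with a SyntaxError (missing comma)
            exec(config_file.read(), custom_config.__dict__)
        return custom_config.tpot_config
    except Exception:
        raise ValueError('An error occurred while attempting to read the '
                         'custom TPOT operator configuration file.')
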
55 changes: 51 additions & 4 deletions tests/tpot_tests.py
@@ -115,15 +115,15 @@ def test_init_default_scoring_2():
assert tpot_obj.scoring_function == 'balanced_accuracy'


def test_invaild_score_warning():
def test_invalid_score_warning():
"""Assert that the TPOT intitializes raises a ValueError when the scoring metrics is not available in SCORERS."""
# Mis-spelled scorer
assert_raises(ValueError, TPOTClassifier, scoring='balanced_accuray')
# Correctly spelled
TPOTClassifier(scoring='balanced_accuracy')


def test_invaild_dataset_warning():
def test_invalid_dataset_warning():
"""Assert that the TPOT fit function raises a ValueError when dataset is not in right format."""
tpot_obj = TPOTClassifier(
random_state=42,
@@ -137,15 +137,15 @@ def test_invaild_dataset_warning():
assert_raises(ValueError, tpot_obj.fit, training_features, bad_training_target)


def test_invaild_subsample_ratio_warning():
def test_invalid_subsample_ratio_warning():
"""Assert that the TPOT intitializes raises a ValueError when subsample ratio is not in the range (0.0, 1.0]."""
# Invalid ratio
assert_raises(ValueError, TPOTClassifier, subsample=0.0)
# Valid ratio
TPOTClassifier(subsample=0.1)


def test_invaild_mut_rate_plus_xo_rate():
def test_invalid_mut_rate_plus_xo_rate():
"""Assert that the TPOT intitializes raises a ValueError when the sum of crossover and mutation probabilities is large than 1."""
# Invalid ratio
assert_raises(ValueError, TPOTClassifier, mutation_rate=0.8, crossover_rate=0.8)
@@ -199,6 +199,29 @@ def test_timeout():
assert return_value == "Timeout"


def test_invalid_pipeline():
"""Assert that _wrapped_cross_val_score return -float(\'inf\') with a invalid_pipeline"""
tpot_obj = TPOTClassifier()
# a invalid pipeline
# Dual or primal formulation. Dual formulation is only implemented for l2 penalty.
pipeline_string = (
'LogisticRegression(input_matrix, LogisticRegression__C=10.0, '
'LogisticRegression__dual=True, LogisticRegression__penalty=l1)'
)
tpot_obj._optimized_pipeline = creator.Individual.from_string(pipeline_string, tpot_obj._pset)
tpot_obj.fitted_pipeline_ = tpot_obj._toolbox.compile(expr=tpot_obj._optimized_pipeline)
# _wrapped_cross_val_score cannot fit this pipeline, so it should return -inf
return_value = _wrapped_cross_val_score(tpot_obj.fitted_pipeline_,
training_features,
training_target,
cv=5,
scoring_function='accuracy',
sample_weight=None,
groups=None,
timeout=300)
assert return_value == -float('inf')


def test_balanced_accuracy():
"""Assert that the balanced_accuracy in TPOT returns correct accuracy."""
y_true = np.array([1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4])
@@ -298,6 +321,30 @@ def test_conf_dict_3():
assert tpot_obj.config_dict == tested_config_dict


def test_read_config_file():
"""Assert that _read_config_file rasie FileNotFoundError with a wrong path."""
tpot_obj = TPOTRegressor()
# typo for "tests/test_config.py"
try:
FileNotFoundError
except NameError: # python 2 has no FileNotFoundError
FileNotFoundError = IOError

assert_raises(FileNotFoundError, tpot_obj._read_config_file, "tests/test_confg.py")


def test_read_config_file_2():
"""Assert that _read_config_file rasie AttributeError with a wrong dictionary name."""
tpot_obj = TPOTRegressor()
assert_raises(AttributeError, tpot_obj._read_config_file, "tpot/config/classifier_light.py")


def test_read_config_file_3():
"""Assert that _read_config_file rasie ValueError with wrong dictionary format"""
tpot_obj = TPOTRegressor()
assert_raises(ValueError, tpot_obj._read_config_file, "tests/test_config.py.bad")


def test_random_ind():
"""Assert that the TPOTClassifier can generate the same pipeline with same random seed."""
tpot_obj = TPOTClassifier(random_state=43)
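For context on test_invalid_pipeline above: scikit-learn's liblinear solver implements the dual formulation only for the l2 penalty, so combining dual=True with penalty='l1' makes fitting raise a ValueError, which _wrapped_cross_val_score converts into -inf. A minimal standalone illustration (assuming a scikit-learn version where this combination is rejected at fit time):

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)
try:
    # dual=True is only supported together with penalty='l2'
    LogisticRegression(C=10.0, dual=True, penalty='l1').fit(X, y)
except ValueError as e:
    print('invalid parameter combination: {}'.format(e))
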
7 changes: 6 additions & 1 deletion tpot/base.py
@@ -338,6 +338,11 @@ def _setup_config(self, config_dict):
self.config_dict = self.default_config_dict

def _read_config_file(self, config_path):
try:
FileNotFoundError
except NameError: # python 2 has no FileNotFoundError
FileNotFoundError = IOError

try:
custom_config = imp.new_module('custom_config')

@@ -357,7 +362,7 @@ def _read_config_file(self, config_path):
'a dictionary named "tpot_config".'
)
except Exception as e:
raise type(e)(
raise ValueError(
'An error occurred while attempting to read the specified '
'custom TPOT operator configuration file.'
)
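The FileNotFoundError shim added above is a standard Python 2/3 compatibility idiom: Python 2 has no FileNotFoundError, so the name is aliased to IOError before it is used. In isolation it behaves like this (a small sketch, not TPOT code):

try:
    FileNotFoundError
except NameError:  # Python 2 only defines IOError
    FileNotFoundError = IOError

try:
    open('tests/test_confg.py')  # deliberately misspelled path
except FileNotFoundError as e:
    print('caught on both Python 2 and 3: {}'.format(e))
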
12 changes: 7 additions & 5 deletions tpot/driver.py
@@ -442,18 +442,17 @@ def load_scoring_function(scoring_func):
sys.path.insert(0, module_path)
scoring_func = getattr(import_module(module_name), func_name)
sys.path.pop(0)

print('manual scoring function: {}'.format(scoring_func))
print('taken from module: {}'.format(module_name))
except Exception as e:
print('failed importing custom scoring function, error: {}'.format(str(e)))
raise ValueError(e)

return scoring_func

def main():
def tpot_driver(args):
"""Perform a TPOT run."""
args = _get_arg_parser().parse_args()
if args.VERBOSITY >= 2:
_print_args(args)

@@ -464,7 +463,7 @@ def main():
axis=1
)


scoring_func = load_scoring_function(args.SCORING_FN)

training_features, testing_features, training_target, testing_target = \
@@ -511,6 +510,9 @@ def main():
if args.OUTPUT_FILE != '':
tpot_obj.export(args.OUTPUT_FILE)

def main():
args = _get_arg_parser().parse_args()
tpot_driver(args)

if __name__ == '__main__':
main()
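Splitting the old main() into tpot_driver(args) plus a thin main() wrapper is what makes the new driver tests possible: callers can build the argument namespace directly instead of going through sys.argv. For example (the argument values below simply mirror the new tests; they are not a recommended configuration):

from tpot.driver import _get_arg_parser, tpot_driver

args = _get_arg_parser().parse_args([
    'tests/tests.csv',
    '-is', ',',
    '-target', 'class',
    '-g', '1',
    '-p', '2',
    '-v', '2',
])
tpot_driver(args)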