Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parameter "progress file" #1038

Merged
merged 13 commits into from
Apr 10, 2020
11 changes: 10 additions & 1 deletion docs_sources/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
<strong>periodic_checkpoint_folder</strong>=None,
<strong>early_stop</strong>=None,
<strong>verbosity</strong>=0,
<strong>disable_update_check</strong>=False</em>)</pre>
<strong>disable_update_check</strong>=False,
<strong>log_file</strong>=None
</em>)</pre>
<div align="right"><a href="https://github.com/EpistasisLab/tpot/blob/master/tpot/base.py">source</a></div>

Automated machine learning for supervised classification tasks.
Expand Down Expand Up @@ -222,6 +224,13 @@ Flag indicating whether the TPOT version checker should be disabled.
<br /><br />
The update checker will tell you when a new version of TPOT has been released.
</blockquote>

<strong>log_file</strong>: io.TextIOWrapper or io.StringIO, optional (default: sys.stdout)
<br /><br />
<blockquote>
Save progress content to a file.
</blockquote>

</td>
</tr>

Expand Down
Empty file added progress.txt
Empty file.
72 changes: 72 additions & 0 deletions tests/test_log_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-

"""This file is part of the TPOT library.

TPOT was primarily developed at the University of Pennsylvania by:
- Randal S. Olson (rso@randalolson.com)
- Weixuan Fu (weixuanf@upenn.edu)
- Daniel Angell (dpa34@drexel.edu)
- and many more generous open source contributors

TPOT is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as
published by the Free Software Foundation, either version 3 of
the License, or (at your option) any later version.

TPOT is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public
License along with TPOT. If not, see <http://www.gnu.org/licenses/>.

"""

from tpot import TPOTClassifier
from sklearn.datasets import load_iris
from nose.tools import assert_equal
import os

data = load_iris()
X = data['data']
y = data['target']

def test_log_file_verbosity_1():
    """Verbosity 1 writes no progress output, so the log file must stay empty.

    The handle is opened in a ``with`` block so it is flushed and closed
    before the on-disk size is checked (otherwise buffered writes could make
    the size check unreliable, and the handle would leak).
    """
    file_name = "progress_verbose_1.log"
    with open(file_name, "w") as tracking_progress_file:
        tpot_obj = TPOTClassifier(
            population_size=10,
            generations=10,
            verbosity=1,
            log_file=tracking_progress_file
        )
        tpot_obj.fit(X, y)
    # File is now closed/flushed; its size reflects everything TPOT wrote.
    assert_equal(os.path.getsize(file_name), 0)

def test_log_file_verbosity_2():
    """Verbosity 2 shows a progress bar, so the log file must be non-empty.

    The handle is opened in a ``with`` block so it is flushed and closed
    before the on-disk size is checked; checking ``os.path.getsize`` on a
    still-open buffered file can see 0 bytes even after TPOT has written.
    """
    file_name = "progress_verbose_2.log"
    with open(file_name, "w") as tracking_progress_file:
        tpot_obj = TPOTClassifier(
            population_size=10,
            generations=10,
            verbosity=2,
            log_file=tracking_progress_file
        )
        tpot_obj.fit(X, y)
    # File is now closed/flushed; progress-bar output must be on disk.
    assert_equal(os.path.getsize(file_name) > 0, True)

def test_log_file_verbose_3():
    """Verbosity 3 logs progress details, so the log file must be non-empty.

    The handle is opened in a ``with`` block so it is flushed and closed
    before the on-disk size is checked; checking ``os.path.getsize`` on a
    still-open buffered file can see 0 bytes even after TPOT has written.
    """
    file_name = "progress_verbosity_3.log"
    with open(file_name, "w") as tracking_progress_file:
        tpot_obj = TPOTClassifier(
            population_size=10,
            generations=10,
            verbosity=3,
            log_file=tracking_progress_file
        )
        tpot_obj.fit(X, y)
    # File is now closed/flushed; verbose output must be on disk.
    assert_equal(os.path.getsize(file_name) > 0, True)
30 changes: 20 additions & 10 deletions tests/tpot_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import warnings
from multiprocessing import cpu_count
import os
import sys
from re import search
from datetime import datetime
from time import sleep
Expand Down Expand Up @@ -155,7 +156,8 @@ def test_init_custom_parameters():
verbosity=1,
random_state=42,
disable_update_check=True,
warm_start=True
warm_start=True,
log_file=None
)

assert tpot_obj.population_size == 500
Expand All @@ -168,6 +170,7 @@ def test_init_custom_parameters():
assert tpot_obj.max_time_mins is None
assert tpot_obj.warm_start is True
assert tpot_obj.verbosity == 1
assert tpot_obj.log_file == None

tpot_obj._fit_init()

Expand All @@ -179,7 +182,14 @@ def test_init_custom_parameters():
assert tpot_obj._optimized_pipeline_score == None
assert tpot_obj.fitted_pipeline_ == None
assert tpot_obj._exported_pipeline_text == []
assert tpot_obj.log_file == sys.stdout

def test_init_custom_progress_file():
    """ Assert that TPOT stores the provided file handle in ``log_file``.

    The handle is managed with ``with`` so it does not leak; the attribute
    check happens while the handle is still open, matching real usage.
    """
    file_name = "progress.txt"
    with open(file_name, "w") as file_handle:
        tpot_obj = TPOTClassifier(log_file=file_handle)
        assert tpot_obj.log_file == file_handle

def test_init_default_scoring():
"""Assert that TPOT intitializes with the correct default scoring function."""
Expand Down Expand Up @@ -1196,7 +1206,7 @@ def test_check_periodic_pipeline():
)
tpot_obj.fit(training_features, training_target)
with closing(StringIO()) as our_file:
tpot_obj._file = our_file
tpot_obj.log_file = our_file
tpot_obj.verbosity = 3
tpot_obj._last_pipeline_write = datetime.now()
sleep(0.11)
Expand Down Expand Up @@ -1240,7 +1250,7 @@ def test_save_periodic_pipeline():
)
tpot_obj.fit(training_features, training_target)
with closing(StringIO()) as our_file:
tpot_obj._file = our_file
tpot_obj.log_file = our_file
tpot_obj.verbosity = 3
tpot_obj._last_pipeline_write = datetime.now()
sleep(0.11)
Expand Down Expand Up @@ -1270,7 +1280,7 @@ def test_save_periodic_pipeline_2():
)
tpot_obj.fit(training_features, training_target)
with closing(StringIO()) as our_file:
tpot_obj._file = our_file
tpot_obj.log_file = our_file
tpot_obj.verbosity = 3
tpot_obj._last_pipeline_write = datetime.now()
sleep(0.11)
Expand Down Expand Up @@ -1301,7 +1311,7 @@ def test_check_periodic_pipeline_3():
)
tpot_obj.fit(training_features, training_target)
with closing(StringIO()) as our_file:
tpot_obj._file = our_file
tpot_obj.log_file = our_file
tpot_obj.verbosity = 3
tpot_obj._exported_pipeline_text = []
tpot_obj._last_pipeline_write = datetime.now()
Expand Down Expand Up @@ -1544,7 +1554,7 @@ def test_update_pbar():
# reset verbosity = 3 for checking pbar message
tpot_obj.verbosity = 3
with closing(StringIO()) as our_file:
tpot_obj._file=our_file
tpot_obj.log_file=our_file
tpot_obj._pbar = tqdm(total=10, disable=False, file=our_file)
tpot_obj._update_pbar(pbar_num=2, pbar_msg="Test Warning Message")
our_file.seek(0)
Expand All @@ -1563,7 +1573,7 @@ def test_update_val():
# reset verbosity = 3 for checking pbar message
tpot_obj.verbosity = 3
with closing(StringIO()) as our_file:
tpot_obj._file=our_file
tpot_obj.log_file=our_file
tpot_obj._pbar = tqdm(total=10, disable=False, file=our_file)
result_score_list = []
result_score_list = tpot_obj._update_val(0.9999, result_score_list)
Expand Down Expand Up @@ -1610,7 +1620,7 @@ def test_preprocess_individuals():
# reset verbosity = 3 for checking pbar message
tpot_obj.verbosity = 3
with closing(StringIO()) as our_file:
tpot_obj._file=our_file
tpot_obj.log_file=our_file
tpot_obj._pbar = tqdm(total=2, disable=False, file=our_file)
operator_counts, eval_individuals_str, sklearn_pipeline_list, stats_dicts = \
tpot_obj._preprocess_individuals(individuals)
Expand Down Expand Up @@ -1656,7 +1666,7 @@ def test_preprocess_individuals_2():
# reset verbosity = 3 for checking pbar message
tpot_obj.verbosity = 3
with closing(StringIO()) as our_file:
tpot_obj._file=our_file
tpot_obj.log_file=our_file
tpot_obj._pbar = tqdm(total=3, disable=False, file=our_file)
operator_counts, eval_individuals_str, sklearn_pipeline_list, stats_dicts = \
tpot_obj._preprocess_individuals(individuals)
Expand Down Expand Up @@ -1703,7 +1713,7 @@ def test_preprocess_individuals_3():
# reset verbosity = 3 for checking pbar message

with closing(StringIO()) as our_file:
tpot_obj._file=our_file
tpot_obj.log_file=our_file
tpot_obj._lambda=4
tpot_obj._pbar = tqdm(total=2, disable=False, file=our_file)
tpot_obj._pbar.n = 2
Expand Down
28 changes: 16 additions & 12 deletions tpot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def __init__(self, generations=100, population_size=100, offspring_size=None,
random_state=None, config_dict=None, template=None,
warm_start=False, memory=None, use_dask=False,
periodic_checkpoint_folder=None, early_stop=None,
verbosity=0, disable_update_check=False):
verbosity=0, disable_update_check=False,
log_file=None):
"""Set up the genetic programming algorithm for pipeline optimization.

Parameters
Expand Down Expand Up @@ -235,7 +236,8 @@ def __init__(self, generations=100, population_size=100, offspring_size=None,
A setting of 2 or higher will add a progress bar during the optimization procedure.
disable_update_check: bool, optional (default: False)
Flag indicating whether the TPOT version checker should be disabled.

log_file: io.TextIOWrapper or io.StringIO, optional (default: sys.stdout)
Save progress content to a file.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please rename the progress_file to log_file? I think there are some log infos when verbosity=3.


Returns
-------
Expand Down Expand Up @@ -266,6 +268,7 @@ def __init__(self, generations=100, population_size=100, offspring_size=None,
self.verbosity = verbosity
self.disable_update_check = disable_update_check
self.random_state = random_state
self.log_file = log_file


def _setup_template(self, template):
Expand Down Expand Up @@ -566,9 +569,9 @@ def _fit_init(self):
)

self._pbar = None
# Specifies where to output the progress messages (default: sys.stdout).
# Maybe open this API in future version of TPOT.(io.TextIOWrapper or io.StringIO)
self._file = sys.stdout

if not self.log_file:
self.log_file = sys.stdout

self._setup_scoring_function(self.scoring)

Expand Down Expand Up @@ -693,7 +696,7 @@ def pareto_eq(ind1, ind2):
else:
total_evals = self._lambda * self.generations + self.population_size

self._pbar = tqdm(total=total_evals, unit='pipeline', leave=False,
self._pbar = tqdm(total=total_evals, unit='pipeline', leave=False, file=self.log_file,
disable=not (self.verbosity >= 2), desc='Optimization Progress')

try:
Expand All @@ -717,9 +720,9 @@ def pareto_eq(ind1, ind2):
# Allow for certain exceptions to signal a premature fit() cancellation
except (KeyboardInterrupt, SystemExit, StopIteration) as e:
if self.verbosity > 0:
self._pbar.write('', file=self._file)
self._pbar.write('', file=self.log_file)
self._pbar.write('{}\nTPOT closed prematurely. Will use the current best pipeline.'.format(e),
file=self._file)
file=self.log_file)
finally:
# clean population for the next call if warm_start=False
if not self.warm_start:
Expand Down Expand Up @@ -1331,10 +1334,11 @@ def _evaluate_individuals(self, population, features, target, sample_weight=None

except (KeyboardInterrupt, SystemExit, StopIteration) as e:
if self.verbosity > 0:
self._pbar.write('', file=self._file)
self._pbar.write('', file=self.log_file)
self._pbar.write('{}\nTPOT closed during evaluation in one generation.\n'
'WARNING: TPOT may not provide a good pipeline if TPOT is stopped/interrupted in an early generation.'.format(e),
file=self._file)
'WARNING: TPOT may not provide a good pipeline if TPOT is stopped/interrupted in a early generation.'.format(e),
file=self.log_file)

# number of individuals already evaluated in this generation
num_eval_ind = len(result_score_list)
self._update_evaluated_individuals_(result_score_list,
Expand Down Expand Up @@ -1482,7 +1486,7 @@ def _update_pbar(self, pbar_num=1, pbar_msg=None):
"""
if not isinstance(self._pbar, type(None)):
if self.verbosity > 2 and pbar_msg is not None:
self._pbar.write(pbar_msg, file=self._file)
self._pbar.write(pbar_msg, file=self.log_file)
if not self._pbar.disable:
self._pbar.update(pbar_num)

Expand Down