From e9c31c9c6c2bde6f05a11bf306d5af62fe002f6e Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 14 Aug 2019 13:33:26 +0100 Subject: [PATCH] Flag to allow multioutput. --- tpot/base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tpot/base.py b/tpot/base.py index 626bb247..974e7f74 100644 --- a/tpot/base.py +++ b/tpot/base.py @@ -114,7 +114,7 @@ def __init__(self, generations=100, population_size=100, offspring_size=None, random_state=None, config_dict=None, template=None, warm_start=False, memory=None, use_dask=False, periodic_checkpoint_folder=None, early_stop=None, - verbosity=0, disable_update_check=False): + verbosity=0, disable_update_check=False, multi_output=True): """Set up the genetic programming algorithm for pipeline optimization. Parameters @@ -262,6 +262,7 @@ def __init__(self, generations=100, population_size=100, offspring_size=None, A setting of 2 or higher will add a progress bar during the optimization procedure. disable_update_check: bool, optional (default: False) Flag indicating whether the TPOT version checker should be disabled. + multi_output: bool, whether or not to allow multi-dimensional y data. Returns @@ -293,6 +294,7 @@ def __init__(self, generations=100, population_size=100, offspring_size=None, self.verbosity = verbosity self.disable_update_check = disable_update_check self.random_state = random_state + self.multi_output = multi_output def _setup_template(self, template): @@ -1190,7 +1192,7 @@ def _check_dataset(self, features, target, sample_weight=None): try: if target is not None: - X, y = check_X_y(features, target, accept_sparse=True, dtype=None) + X, y = check_X_y(features, target, accept_sparse=True, dtype=None, multi_output=self.multi_output) if self._imputed: return X, y else: @@ -1204,9 +1206,9 @@ def _check_dataset(self, features, target, sample_weight=None): except (AssertionError, ValueError): raise ValueError( 'Error: Input data is not in a valid format. Please confirm ' - 'that the input data is scikit-learn compatible. For example, ' + 'that the input data is scikit-learn compatible.' + (' For example, ' 'the features must be a 2-D array and target labels must be a ' - '1-D array.' + '1-D array.' if self.multi_output else '') )