Skip to content

Commit

Permalink
Flag to allow multioutput.
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelith committed Aug 14, 2019
1 parent 815b0e2 commit e9c31c9
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions tpot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(self, generations=100, population_size=100, offspring_size=None,
random_state=None, config_dict=None, template=None,
warm_start=False, memory=None, use_dask=False,
periodic_checkpoint_folder=None, early_stop=None,
verbosity=0, disable_update_check=False):
verbosity=0, disable_update_check=False, multi_output=True):
"""Set up the genetic programming algorithm for pipeline optimization.
Parameters
Expand Down Expand Up @@ -262,6 +262,7 @@ def __init__(self, generations=100, population_size=100, offspring_size=None,
A setting of 2 or higher will add a progress bar during the optimization procedure.
disable_update_check: bool, optional (default: False)
Flag indicating whether the TPOT version checker should be disabled.
multi_output: bool, whether or not to allow multi-dimensional y data.
Returns
Expand Down Expand Up @@ -293,6 +294,7 @@ def __init__(self, generations=100, population_size=100, offspring_size=None,
self.verbosity = verbosity
self.disable_update_check = disable_update_check
self.random_state = random_state
self.multi_output = multi_output


def _setup_template(self, template):
Expand Down Expand Up @@ -1190,7 +1192,7 @@ def _check_dataset(self, features, target, sample_weight=None):

try:
if target is not None:
X, y = check_X_y(features, target, accept_sparse=True, dtype=None)
X, y = check_X_y(features, target, accept_sparse=True, dtype=None, multi_output=self.multi_output)
if self._imputed:
return X, y
else:
Expand All @@ -1204,9 +1206,9 @@ def _check_dataset(self, features, target, sample_weight=None):
except (AssertionError, ValueError):
raise ValueError(
'Error: Input data is not in a valid format. Please confirm '
'that the input data is scikit-learn compatible. For example, '
'that the input data is scikit-learn compatible.' + (' For example, '
'the features must be a 2-D array and target labels must be a '
'1-D array.'
'1-D array.' if self.multi_output else '')
)


Expand Down

0 comments on commit e9c31c9

Please sign in to comment.