Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flag to allow multioutput. #903

Open
wants to merge 1 commit into
base: development
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions tpot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(self, generations=100, population_size=100, offspring_size=None,
random_state=None, config_dict=None, template=None,
warm_start=False, memory=None, use_dask=False,
periodic_checkpoint_folder=None, early_stop=None,
verbosity=0, disable_update_check=False):
verbosity=0, disable_update_check=False, multi_output=True):
"""Set up the genetic programming algorithm for pipeline optimization.

Parameters
Expand Down Expand Up @@ -262,6 +262,7 @@ def __init__(self, generations=100, population_size=100, offspring_size=None,
A setting of 2 or higher will add a progress bar during the optimization procedure.
disable_update_check: bool, optional (default: False)
Flag indicating whether the TPOT version checker should be disabled.
multi_output: bool, whether or not to allow multi-dimensional y data.


Returns
Expand Down Expand Up @@ -293,6 +294,7 @@ def __init__(self, generations=100, population_size=100, offspring_size=None,
self.verbosity = verbosity
self.disable_update_check = disable_update_check
self.random_state = random_state
self.multi_output = multi_output


def _setup_template(self, template):
Expand Down Expand Up @@ -1190,7 +1192,7 @@ def _check_dataset(self, features, target, sample_weight=None):

try:
if target is not None:
X, y = check_X_y(features, target, accept_sparse=True, dtype=None)
X, y = check_X_y(features, target, accept_sparse=True, dtype=None, multi_output=self.multi_output)
if self._imputed:
return X, y
else:
Expand All @@ -1204,9 +1206,9 @@ def _check_dataset(self, features, target, sample_weight=None):
except (AssertionError, ValueError):
raise ValueError(
'Error: Input data is not in a valid format. Please confirm '
'that the input data is scikit-learn compatible. For example, '
'that the input data is scikit-learn compatible.' + (' For example, '
'the features must be a 2-D array and target labels must be a '
'1-D array.'
'1-D array.' if self.multi_output else '')
)


Expand Down