Skip to content

Commit

Permalink
add test_fit_predict
Browse files Browse the repository at this point in the history
  • Loading branch information
Weixuan Fu committed Jun 28, 2017
1 parent 22af31f commit 55af5e5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
23 changes: 20 additions & 3 deletions tests/tpot_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,8 +603,25 @@ def test_fit3():
assert not (tpot_obj._start_datetime is None)


def test_fit_predict():
"""Assert that the TPOT fit_predict function provides an optimized pipeline and correct output."""
tpot_obj = TPOTClassifier(
random_state=42,
population_size=1,
offspring_size=2,
generations=1,
verbosity=0,
config_dict='TPOT light'
)
result = tpot_obj.fit_predict(training_features, training_target)

assert isinstance(tpot_obj._optimized_pipeline, creator.Individual)
assert not (tpot_obj._start_datetime is None)
assert result.shape == (training_features.shape[0],)


def test_update_top_pipeline():
"""Assert that the TPOT _update_top_pipeline updated an optimized pipeline"""
"""Assert that the TPOT _update_top_pipeline updated an optimized pipeline."""
tpot_obj = TPOTClassifier(
random_state=42,
population_size=1,
Expand All @@ -623,7 +640,7 @@ def test_update_top_pipeline():


def test_update_top_pipeline_2():
"""Assert that the TPOT _update_top_pipeline raises RuntimeError when self._pareto_front is empty"""
"""Assert that the TPOT _update_top_pipeline raises RuntimeError when self._pareto_front is empty."""
tpot_obj = TPOTClassifier(
random_state=42,
population_size=1,
Expand All @@ -641,7 +658,7 @@ def pareto_eq(ind1, ind2):


def test_update_top_pipeline_3():
"""Assert that the TPOT _update_top_pipeline raises RuntimeError when self._optimized_pipeline is updated"""
"""Assert that the TPOT _update_top_pipeline raises RuntimeError when self._optimized_pipeline is not updated."""
tpot_obj = TPOTClassifier(
random_state=42,
population_size=1,
Expand Down
10 changes: 8 additions & 2 deletions tpot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ def predict(self, features):

return self.fitted_pipeline_.predict(features)

def fit_predict(self, features, target):
def fit_predict(self, features, target, sample_weight=None, groups=None):
"""Call fit and predict in sequence.
Parameters
Expand All @@ -673,14 +673,20 @@ def fit_predict(self, features, target):
Feature matrix
target: array-like {n_samples}
List of class labels for prediction
sample_weight: array-like {n_samples}, optional
Per-sample weights. Higher weights force TPOT to put more emphasis on those points
groups: array-like, with shape {n_samples, }, optional
Group labels for the samples used when performing cross-validation.
This parameter should only be used in conjunction with sklearn's Group cross-validation
functions, such as sklearn.model_selection.GroupKFold
Returns
----------
array-like: {n_samples}
Predicted target for the provided features
"""
self.fit(features, target)
self.fit(features, target, sample_weight=sample_weight, groups=groups)
return self.predict(features)

def score(self, testing_features, testing_target):
Expand Down

0 comments on commit 55af5e5

Please sign in to comment.