Skip to content

Commit

Permalink
[FIX] clone test of dummy estimator (#1129)
Browse files Browse the repository at this point in the history
* [FIX] clone test of dummy estimator

* [FIX] flake8
  • Loading branch information
franchuterivera authored Apr 21, 2021
1 parent 79627e1 commit 0982410
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 11 deletions.
11 changes: 0 additions & 11 deletions test/test_automl/test_automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import pandas as pd
import pytest
import sklearn.datasets
from sklearn.base import clone
from smac.scenario.scenario import Scenario
from smac.facade.roar_facade import ROAR

Expand Down Expand Up @@ -427,16 +426,6 @@ def test_do_dummy_prediction(backend, dask_client, datasets):
'predictions_ensemble_1_1_0.0.npy')
)

model_path = os.path.join(backend.temporary_directory, '.auto-sklearn',
'runs', '1_1_0.0',
'1.1.0.0.model')

# Make sure the dummy model complies with scikit learn
# get/set params
assert os.path.exists(model_path)
with open(model_path, 'rb') as model_handler:
clone(pickle.load(model_handler))

auto._clean_logger()

del auto
Expand Down
31 changes: 31 additions & 0 deletions test/test_evaluation/test_dummy_pipelines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np

import pytest

from sklearn.base import clone
from sklearn.datasets import make_classification, make_regression
from sklearn.utils.validation import check_is_fitted

from autosklearn.evaluation.abstract_evaluator import MyDummyClassifier, MyDummyRegressor


@pytest.mark.parametrize("task_type", ['classification', 'regression'])
def test_dummy_pipeline(task_type):
if task_type == 'classification':
estimator_class = MyDummyClassifier
data_maker = make_classification
elif task_type == 'regression':
estimator_class = MyDummyRegressor
data_maker = make_regression
else:
pytest.fail(task_type)

estimator = estimator_class(config=1, random_state=np.random.RandomState(42))
X, y = data_maker()
estimator.fit(X, y)
check_is_fitted(estimator)

assert np.shape(X)[0] == np.shape(estimator.predict(X))[0]

# make sure we comply with scikit-learn estimator API
clone(estimator)

0 comments on commit 0982410

Please sign in to comment.