
Run history traditional #121

Merged
21 changes: 18 additions & 3 deletions autoPyTorch/api/base_task.py
@@ -25,7 +25,7 @@

import pandas as pd

from smac.runhistory.runhistory import RunHistory
from smac.runhistory.runhistory import DataOrigin, RunHistory
from smac.stats.stats import Stats
from smac.tae import StatusType

@@ -173,7 +173,7 @@ def __init__(
self._dataset_requirements: Optional[List[FitRequirement]] = None
self._metric: Optional[autoPyTorchMetric] = None
self._logger: Optional[PicklableClientLogger] = None
self.run_history: Optional[RunHistory] = None
self.run_history: RunHistory = RunHistory()
self.trajectory: Optional[List] = None
self.dataset_name: Optional[str] = None
self.cv_models_: Dict = {}
@@ -582,6 +582,10 @@ def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_tim
assert self._logger is not None
assert self._dask_client is not None

self._logger.info("Starting to create traditional classifier predictions.")

# Initialise run history for the traditional classifiers
run_history = RunHistory()
memory_limit = self._memory_limit
if memory_limit is not None:
memory_limit = int(math.ceil(memory_limit))
@@ -651,6 +655,11 @@ def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_tim
if status == StatusType.SUCCESS:
self._logger.info(
f"Fitting {cls} took {runtime}s, performance:{cost}/{additional_info}")
configuration = additional_info['pipeline_configuration']
origin = additional_info['configuration_origin']
run_history.add(config=configuration, cost=cost,
time=runtime, status=status, seed=self.seed,
origin=origin)
else:
if additional_info.get('exitcode') == -6:
self._logger.error(
@@ -677,6 +686,11 @@ def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_tim
"Please consider increasing the run time to further improve performance.")
break

self._logger.debug("Run history traditional: {}".format(run_history))
# add run history of traditional to api run history
self.run_history.update(run_history, DataOrigin.EXTERNAL_SAME_INSTANCES)
run_history.save_json(os.path.join(self._backend.internals_directory, 'traditional_run_history.json'),
save_external=True)
return num_run

def _search(
@@ -947,8 +961,9 @@ def _search(
search_space_updates=self.search_space_updates
)
try:
self.run_history, self.trajectory, budget_type = \
run_history, self.trajectory, budget_type = \
_proc_smac.run_smbo()
self.run_history.update(run_history, DataOrigin.INTERNAL)
trajectory_filename = os.path.join(
self._backend.get_smac_output_directory_for_run(self.seed),
'trajectory.json')
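
The base_task.py changes hinge on SMAC's RunHistory API: traditional classifiers are recorded in a local RunHistory, merged into the API-level history as external data, persisted to JSON, and the SMBO results are later merged as internal data. Below is a minimal, self-contained sketch of that pattern (not autoPyTorch code); the configuration space, cost, runtime, and output path are invented for illustration, and it assumes a SMAC version in which RunHistory() takes no constructor arguments, as used in this PR.

```python
import os
import tempfile

from ConfigSpace import ConfigurationSpace
from ConfigSpace.hyperparameters import CategoricalHyperparameter

from smac.runhistory.runhistory import DataOrigin, RunHistory
from smac.tae import StatusType

# Toy configuration standing in for a traditional classifier's pipeline configuration
cs = ConfigurationSpace()
cs.add_hyperparameter(CategoricalHyperparameter('classifier', ['lgb', 'catboost']))
config = cs.get_default_configuration()

# Local run history for the traditional classifiers,
# mirroring _do_traditional_prediction above
traditional_rh = RunHistory()
traditional_rh.add(config=config, cost=0.25, time=12.3,
                   status=StatusType.SUCCESS, seed=1)

# API-level run history: traditional runs are merged as external data;
# the SMBO run history (not shown) would be merged with DataOrigin.INTERNAL
api_rh = RunHistory()
api_rh.update(traditional_rh, DataOrigin.EXTERNAL_SAME_INSTANCES)

# Persist the traditional runs next to the other SMAC artefacts
out_file = os.path.join(tempfile.mkdtemp(), 'traditional_run_history.json')
traditional_rh.save_json(out_file, save_external=True)
print(len(api_rh.data), 'run(s) merged')
```
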
26 changes: 18 additions & 8 deletions autoPyTorch/evaluation/abstract_evaluator.py
@@ -67,8 +67,8 @@ def __init__(self, config: str,
configuration_space = self.pipeline.get_hyperparameter_search_space()
default_configuration = configuration_space.get_default_configuration().get_dictionary()
default_configuration['model_trainer:tabular_classifier:classifier'] = config
configuration = Configuration(configuration_space, default_configuration)
self.pipeline.set_hyperparameters(configuration)
self.configuration = Configuration(configuration_space, default_configuration)
self.pipeline.set_hyperparameters(self.configuration)

def fit(self, X: Dict[str, Any], y: Any,
sample_weight: Optional[np.ndarray] = None) -> object:
@@ -85,8 +85,18 @@ def predict(self, X: Union[np.ndarray, pd.DataFrame],
def estimator_supports_iterative_fit(self) -> bool: # pylint: disable=R0201
return False

def get_additional_run_info(self) -> None: # pylint: disable=R0201
return None
def get_additional_run_info(self) -> Dict[str, Any]: # pylint: disable=R0201
"""
Can be used to return additional info for the run.
Returns:
Dict[str, Any]:
Currently contains
1. pipeline_configuration: the configuration of the pipeline, i.e., the traditional model used
2. trainer_configuration: the parameters for the traditional model used.
Can be found in autoPyTorch/pipeline/components/setup/traditional_ml/classifier_configs
"""
return {'pipeline_configuration': self.configuration,
'trainer_configuration': self.pipeline.named_steps['model_trainer'].choice.model.get_config()}

def get_pipeline_representation(self) -> Dict[str, str]:
return self.pipeline.get_pipeline_representation()
@@ -131,8 +141,8 @@ def predict(self, X: Union[np.ndarray, pd.DataFrame],
def estimator_supports_iterative_fit(self) -> bool: # pylint: disable=R0201
return False

def get_additional_run_info(self) -> None: # pylint: disable=R0201
return None
def get_additional_run_info(self) -> Dict: # pylint: disable=R0201
return {}

def get_pipeline_representation(self) -> Dict[str, str]:
return {
@@ -172,8 +182,8 @@ def predict(self, X: Union[np.ndarray, pd.DataFrame],
def estimator_supports_iterative_fit(self) -> bool: # pylint: disable=R0201
return False

def get_additional_run_info(self) -> None: # pylint: disable=R0201
return None
def get_additional_run_info(self) -> Dict: # pylint: disable=R0201
return {}

@staticmethod
def get_default_pipeline_options() -> Dict[str, Any]:
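
With these changes the evaluator wrappers return a dictionary rather than None from get_additional_run_info, with the traditional pipeline reporting its configuration and the trainer's parameters. The toy classes below sketch that contract with made-up values (they are not the autoPyTorch classes); returning an empty dict instead of None lets callers merge the results without a None check.

```python
from typing import Any, Dict


class DummyPipelineLike:
    """Stands in for the dummy pipelines changed above."""

    def get_additional_run_info(self) -> Dict[str, Any]:
        # Previously None; an empty dict keeps callers branch-free
        return {}


class TraditionalPipelineLike:
    """Stands in for the traditional-classifier pipeline changed above."""

    def get_additional_run_info(self) -> Dict[str, Any]:
        return {
            'pipeline_configuration': 'catboost',          # illustrative value
            'trainer_configuration': {'iterations': 100},  # illustrative value
        }


additional_run_info: Dict[str, Any] = {}
for pipeline in (DummyPipelineLike(), TraditionalPipelineLike()):
    # No None check needed any more
    additional_run_info.update(pipeline.get_additional_run_info())
print(additional_run_info)
```
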
2 changes: 1 addition & 1 deletion autoPyTorch/evaluation/tae.py
@@ -57,7 +57,7 @@ def fit_predict_try_except_decorator(
def get_cost_of_crash(metric: autoPyTorchMetric) -> float:
# The metric must always be defined to extract optimum/worst
if not isinstance(metric, autoPyTorchMetric):
raise ValueError("The metric must be stricly be an instance of autoPyTorchMetric")
raise ValueError("The metric must be strictly be an instance of autoPyTorchMetric")

# Autopytorch optimizes the err. This function translates
# worst_possible_result to be a minimization problem.
5 changes: 4 additions & 1 deletion autoPyTorch/evaluation/train_evaluator.py
@@ -143,6 +143,8 @@ def fit_predict_and_loss(self) -> None:
# weights for opt_losses.
opt_fold_weights = [np.NaN] * self.num_folds

additional_run_info = {}

for i, (train_split, test_split) in enumerate(self.splits):

pipeline = self.pipelines[i]
@@ -178,7 +180,8 @@ def fit_predict_and_loss(self) -> None:
# number of optimization data points for this fold.
# Used for weighting the average.
opt_fold_weights[i] = len(train_split)

additional_run_info.update(pipeline.get_additional_run_info() if hasattr(
pipeline, 'get_additional_run_info') and pipeline.get_additional_run_info() is not None else {})
# Compute weights of each fold based on the number of samples in each
# fold.
train_fold_weights = [w / sum(train_fold_weights)
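
The fold loop in train_evaluator.py now collects additional run info from every fitted pipeline, guarding against pipelines that do not define the method or that still return None. A short sketch of that guard, with hypothetical pipeline objects standing in for the fitted folds:

```python
from typing import Any, Dict, Optional


class OldStylePipeline:
    # Hypothetical: still returns None, as before this PR
    def get_additional_run_info(self) -> Optional[Dict[str, Any]]:
        return None


class PlainPipeline:
    # Hypothetical: does not define get_additional_run_info at all
    pass


class NewStylePipeline:
    def get_additional_run_info(self) -> Dict[str, Any]:
        return {'pipeline_configuration': 'extra_trees'}  # illustrative value


additional_run_info: Dict[str, Any] = {}
for pipeline in (OldStylePipeline(), PlainPipeline(), NewStylePipeline()):
    # Same guard as in fit_predict_and_loss above: only merge real dicts
    additional_run_info.update(
        pipeline.get_additional_run_info()
        if hasattr(pipeline, 'get_additional_run_info')
        and pipeline.get_additional_run_info() is not None
        else {})

print(additional_run_info)  # {'pipeline_configuration': 'extra_trees'}
```
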
46 changes: 27 additions & 19 deletions test/test_api/test_api.py
@@ -14,6 +14,8 @@
import sklearn.datasets
from sklearn.ensemble import VotingClassifier, VotingRegressor

from smac.runhistory.runhistory import RunHistory

import torch

from autoPyTorch.api.tabular_classification import TabularClassificationTask
@@ -104,17 +106,20 @@ def test_tabular_classification(openml_id, resampling_strategy, backend):

# Search for an existing run key on disc. An individual model might have
# a timeout and hence was not written to disc
successful_num_run = None
SUCCESS = False
for i, (run_key, value) in enumerate(estimator.run_history.data.items()):
if 'SUCCESS' not in str(value.status):
continue

run_key_model_run_dir = estimator._backend.get_numrun_directory(
estimator.seed, run_key.config_id + 1, run_key.budget)
if os.path.exists(run_key_model_run_dir):
# Runkey config id is different from the num_run
# more specifically num_run = config_id + 1(dummy)
if 'SUCCESS' in str(value.status):
run_key_model_run_dir = estimator._backend.get_numrun_directory(
estimator.seed, run_key.config_id + 1, run_key.budget)
successful_num_run = run_key.config_id + 1
break
if os.path.exists(run_key_model_run_dir):
# Runkey config id is different from the num_run
# more specifically num_run = config_id + 1(dummy)
SUCCESS = True
break

assert SUCCESS, f"Successful run was not properly saved for num_run: {successful_num_run}"

if resampling_strategy == HoldoutValTypes.holdout_validation:
model_file = os.path.join(run_key_model_run_dir,
@@ -272,17 +277,20 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):

# Search for an existing run key on disc. An individual model might have
# a timeout and hence was not written to disc
successful_num_run = None
SUCCESS = False
for i, (run_key, value) in enumerate(estimator.run_history.data.items()):
if 'SUCCESS' not in str(value.status):
continue

run_key_model_run_dir = estimator._backend.get_numrun_directory(
estimator.seed, run_key.config_id + 1, run_key.budget)
if os.path.exists(run_key_model_run_dir):
# Runkey config id is different from the num_run
# more specifically num_run = config_id + 1(dummy)
if 'SUCCESS' in str(value.status):
run_key_model_run_dir = estimator._backend.get_numrun_directory(
estimator.seed, run_key.config_id + 1, run_key.budget)
successful_num_run = run_key.config_id + 1
break
if os.path.exists(run_key_model_run_dir):
# Runkey config id is different from the num_run
# more specifically num_run = config_id + 1(dummy)
SUCCESS = True
break

assert SUCCESS, f"Successful run was not properly saved for num_run: {successful_num_run}"

if resampling_strategy == HoldoutValTypes.holdout_validation:
model_file = os.path.join(run_key_model_run_dir,
@@ -384,7 +392,7 @@ def test_tabular_input_support(openml_id, backend):
estimator._do_dummy_prediction = unittest.mock.MagicMock()

with unittest.mock.patch.object(AutoMLSMBO, 'run_smbo') as AutoMLSMBOMock:
AutoMLSMBOMock.return_value = ({}, {}, 'epochs')
AutoMLSMBOMock.return_value = (RunHistory(), {}, 'epochs')
estimator.search(
X_train=X_train, y_train=y_train,
X_test=X_test, y_test=y_test,
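
The updated tests derive the expected model directory number from the first successful entry in the run history. A standalone sketch of that scan (the configuration space, cost, and seed below are fabricated for illustration):

```python
from ConfigSpace import ConfigurationSpace
from ConfigSpace.hyperparameters import UniformFloatHyperparameter

from smac.runhistory.runhistory import RunHistory
from smac.tae import StatusType

cs = ConfigurationSpace()
cs.add_hyperparameter(UniformFloatHyperparameter('lr', 1e-4, 1e-1))

run_history = RunHistory()
run_history.add(config=cs.sample_configuration(), cost=0.1, time=1.0,
                status=StatusType.SUCCESS, seed=42)

successful_num_run = None
for run_key, run_value in run_history.data.items():
    if 'SUCCESS' not in str(run_value.status):
        continue
    # num_run on disc is config_id + 1 (the dummy run occupies num_run 1)
    successful_num_run = run_key.config_id + 1
    break

assert successful_num_run is not None
print(f"first successful num_run: {successful_num_run}")
```
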
4 changes: 2 additions & 2 deletions test/test_evaluation/test_train_evaluator.py
@@ -50,8 +50,8 @@ def __init__(self):
def predict_proba(self, X, batch_size=None):
return np.tile([0.6, 0.4], (len(X), 1))

def get_additional_run_info(self) -> None:
return None
def get_additional_run_info(self):
return {}


class TestTrainEvaluator(BaseEvaluatorTest, unittest.TestCase):