From 07e631d56db3bfd666f519f2816a1db6930f8ec8 Mon Sep 17 00:00:00 2001 From: Jeroen Overschie Date: Sat, 12 Jun 2021 22:36:19 +0200 Subject: [PATCH] Allow ignoring cache per-estimator, closes #16 --- fseval/pipeline/estimator.py | 26 ++++++++++++++++--- fseval/storage_providers/wandb.py | 2 +- fseval/types.py | 19 ++++++++++++++ .../pipelines/test_rank_and_validate.py | 3 +++ 4 files changed, 45 insertions(+), 5 deletions(-) diff --git a/fseval/pipeline/estimator.py b/fseval/pipeline/estimator.py index ce9d02b..85a2634 100644 --- a/fseval/pipeline/estimator.py +++ b/fseval/pipeline/estimator.py @@ -6,6 +6,7 @@ from fseval.types import ( AbstractEstimator, AbstractStorageProvider, + CacheUsage, IncompatibilityError, Task, ) @@ -17,7 +18,8 @@ @dataclass class EstimatorConfig: estimator: Any = None # must have _target_ of type BaseEstimator. - use_cache_if_available: bool = True + load_cache: CacheUsage = CacheUsage.allow + save_cache: CacheUsage = CacheUsage.allow # tags multioutput: Optional[bool] = None multioutput_only: Optional[bool] = None @@ -36,7 +38,8 @@ class TaskedEstimatorConfig(EstimatorConfig): name: str = MISSING classifier: Optional[EstimatorConfig] = None regressor: Optional[EstimatorConfig] = None - use_cache_if_available: bool = True + load_cache: CacheUsage = CacheUsage.allow + save_cache: CacheUsage = CacheUsage.allow # tags multioutput: Optional[bool] = False multioutput_only: Optional[bool] = False @@ -75,16 +78,31 @@ def _get_class_repr(cls, estimator): return f"{module_name}.{class_name}" def _load_cache(self, filename: str, storage_provider: AbstractStorageProvider): + if self.load_cache == CacheUsage.never: + self.logger.info("ignoring cache load completely.") + return + restored = storage_provider.restore_pickle(filename) self.estimator = restored or self.estimator self._is_fitted = bool(restored) + if self.load_cache == CacheUsage.must: + assert self._is_fitted, ( + "Cache usage was set to 'must' but loading cached estimator failed." + + " Pickle file might be corrupt or could not be found." + ) + def _save_cache(self, filename: str, storage_provider: AbstractStorageProvider): - storage_provider.save_pickle(filename, self.estimator) + if self.save_cache == CacheUsage.never: + self.logger.info("ignoring cache save completely.") + return + else: + storage_provider.save_pickle(filename, self.estimator) + # TODO check whether file was successfully saved. def fit(self, X, y): # don't refit if cache available and `use_cache_if_available` is enabled - if self._is_fitted and self.use_cache_if_available: + if self._is_fitted: self.logger.debug("using estimator from cache.") return self diff --git a/fseval/storage_providers/wandb.py b/fseval/storage_providers/wandb.py index 4601b1a..6be2968 100644 --- a/fseval/storage_providers/wandb.py +++ b/fseval/storage_providers/wandb.py @@ -29,8 +29,8 @@ class WandbStorageProvider(AbstractStorageProvider): run_id: Optional[str] - recover from a specific run id.""" - resume: Optional[str] = None local_dir: Optional[str] = None + resume: Optional[str] = None entity: Optional[str] = None project: Optional[str] = None run_id: Optional[str] = None diff --git a/fseval/types.py b/fseval/types.py index 2e8026c..cb432f4 100644 --- a/fseval/types.py +++ b/fseval/types.py @@ -11,6 +11,25 @@ class Task(Enum): classification = 2 +class CacheUsage(Enum): + """ + Determines how cache usage is handled. In the case of **loading** caches: + + - `allow`: program might use cache; if found and could be restored + - `must`: program should fail if no cache found + - `never`: program should not load cache even if found + + When **saving** caches: + - `allow`: program might save cache; no fatal error thrown when fails + - `must`: program must save cache; throws error if fails (e.g. due to out of memory) + - `never`: program does not try to save a cached version + """ + + allow = 1 + must = 2 + never = 3 + + class IncompatibilityError(Exception): ... diff --git a/tests/integration/pipelines/test_rank_and_validate.py b/tests/integration/pipelines/test_rank_and_validate.py index 7db0cce..860dc80 100644 --- a/tests/integration/pipelines/test_rank_and_validate.py +++ b/tests/integration/pipelines/test_rank_and_validate.py @@ -144,6 +144,9 @@ def cfg(dataset, cv, resample, classifier, ranker, validator): n_bootstraps=2, n_jobs=None, all_features_to_select="range(1, min(50, p) + 1)", + upload_ranking_scores=True, + upload_validation_scores=True, + upload_best_scores=True, ) cfg = OmegaConf.create(config.__dict__)