Skip to content

Commit

Permalink
Fix caching 🔨🔨🔨
Browse files Browse the repository at this point in the history
  • Loading branch information
dunnkers committed Jun 8, 2021
1 parent ff1cb5b commit 86f4ca3
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 21 deletions.
12 changes: 4 additions & 8 deletions fseval/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,16 @@
from fseval.pipelines.rank_and_validate import RankAndValidateConfig


@dataclass
class StorageProviderConfig:
_target_: str = MISSING
local_dir: Optional[str] = None


@dataclass
class BaseConfig:
dataset: DatasetConfig = MISSING
cv: CrossValidatorConfig = MISSING
pipeline: Any = MISSING
callbacks: Dict = field(default_factory=lambda: dict())
storage_provider: StorageProviderConfig = StorageProviderConfig(
_target_="fseval.storage_providers.mock.MockStorageProvider"
storage_provider: Any = field(
default_factory=lambda: dict(
_target_="fseval.storage_providers.mock.MockStorageProvider"
)
)


Expand Down
10 changes: 5 additions & 5 deletions fseval/pipeline/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
from logging import Logger, getLogger
from typing import Any, Optional

from hydra.utils import instantiate
from omegaconf import II, MISSING, OmegaConf
from sklearn.preprocessing import minmax_scale

from fseval.types import (
AbstractEstimator,
AbstractStorageProvider,
IncompatibilityError,
Task,
)
from hydra.utils import instantiate
from omegaconf import II, MISSING, OmegaConf
from sklearn.preprocessing import minmax_scale


@dataclass
Expand Down Expand Up @@ -78,14 +77,15 @@ def _get_class_repr(cls, estimator):
def _load_cache(self, filename: str, storage_provider: AbstractStorageProvider):
restored = storage_provider.restore_pickle(filename)
self.estimator = restored or self.estimator
self._is_fitted = restored
self._is_fitted = bool(restored)

def _save_cache(self, filename: str, storage_provider: AbstractStorageProvider):
storage_provider.save_pickle(filename, self.estimator)

def fit(self, X, y):
# don't refit if cache available and `use_cache_if_available` is enabled
if self._is_fitted and self.use_cache_if_available:
self.logger.debug("using estimator from cache.")
return self

# rescale if necessary
Expand Down
5 changes: 2 additions & 3 deletions fseval/pipelines/_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
from typing import List

import pandas as pd
from humanfriendly import format_timespan

from fseval.pipeline.estimator import Estimator
from fseval.types import AbstractEstimator, TerminalColor
from humanfriendly import format_timespan


@dataclass
Expand Down Expand Up @@ -92,7 +91,7 @@ def fit(self, X, y) -> AbstractEstimator:

## Run `fit`
n_jobs = self._get_n_jobs()
if n_jobs is not None:
if n_jobs is not None and (n_jobs > 1 or n_jobs == -1):
assert n_jobs >= 1 or n_jobs == -1, f"incorrect `n_jobs`: {n_jobs}"

cpus = multiprocessing.cpu_count() if n_jobs == -1 else n_jobs
Expand Down
27 changes: 22 additions & 5 deletions fseval/storage_providers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@ class WandbStorageProvider(AbstractStorageProvider):
Arguments:
local_dir: Optional[str] - when set, an attempt is made to load from the
designated local directory first, before downloading the data off of wandb. Can
be used to perform faster loads or prevent being rate-limited on wandb."""
be used to perform faster loads or prevent being rate-limited on wandb.
wandb_entity: Optional[str] - allows you to recover from a specific entity,
instead of using the entity that is set for the 'current' run.
wandb_project: Optional[str] - idem
wandb_run_id: Optional[str] - idem"""

local_dir: Optional[str] = None
wandb_entity: Optional[str] = None
wandb_project: Optional[str] = None
wandb_run_id: Optional[str] = None
logger: Logger = getLogger(__name__)

def _assert_wandb_available(self):
Expand Down Expand Up @@ -55,7 +63,13 @@ def save_pickle(self, filename: str, obj: Any):

def _get_restore_file_handle(self, filename: str):
try:
file_handle = wandb.restore(filename)
entity = self.wandb_entity or wandb.run.entity # type: ignore
project = self.wandb_project or wandb.run.project # type: ignore
run_id = self.wandb_run_id or wandb.run.id # type: ignore

file_handle = wandb.restore(
filename, run_path=f"{entity}/{project}/{run_id}"
)
return file_handle
except ValueError as err:
config = self.config
Expand Down Expand Up @@ -115,17 +129,20 @@ def restore(self, filename: str, reader: Callable, mode: str = "r") -> Any:
)

return file

# (2) otherwise, restore by downloading from wandb
elif self._wandb_restoration(filename, reader, mode):
file = self._wandb_restoration(filename, reader, mode)
if file:
self.logger.info(
f"successfully restored {TerminalColor.yellow(filename)} from "
+ TerminalColor.blue("wandb servers ")
+ TerminalColor.green("✓")
)

return file
else:
return None

# (3) if no cache is available anywhere, return None.
return None

def restore_pickle(self, filename: str) -> Any:
return self.restore(filename, load, mode="rb")

0 comments on commit 86f4ca3

Please sign in to comment.