diff --git a/fseval/config.py b/fseval/config.py index acb7bfa..641c4e4 100644 --- a/fseval/config.py +++ b/fseval/config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Any, Dict +from typing import Any, Dict, Optional from hydra.core.config_store import ConfigStore from omegaconf import MISSING @@ -12,6 +12,7 @@ @dataclass class StorageProviderConfig: _target_: str = MISSING + local_dir: Optional[str] = None @dataclass diff --git a/fseval/pipelines/rank_and_validate/_components.py b/fseval/pipelines/rank_and_validate/_components.py index d920521..bdcb98d 100644 --- a/fseval/pipelines/rank_and_validate/_components.py +++ b/fseval/pipelines/rank_and_validate/_components.py @@ -4,13 +4,12 @@ import numpy as np import pandas as pd +from fseval.callbacks import WandbCallback +from fseval.types import TerminalColor from omegaconf import MISSING from sklearn.base import clone from tqdm import tqdm -from fseval.callbacks import WandbCallback -from fseval.types import TerminalColor - from .._experiment import Experiment from ._config import RankAndValidatePipeline from ._dataset_validator import DatasetValidator @@ -168,9 +167,10 @@ def score(self, X, y): # summary summary = dict(best=best) - ##### Upload tables wandb_callback = getattr(self.callbacks, "wandb", False) if wandb_callback: + ##### Upload tables + self.logger.info(f"Uploading tables to wandb...") wandb_callback = cast(WandbCallback, wandb_callback) ### upload best scores @@ -211,5 +211,6 @@ def score(self, X, y): "feature_ranking_", "feature_ranking" ) wandb_callback.upload_table(ranking_table, "feature_ranking") + self.logger.info(f"Tables uploaded {TerminalColor.green('✓')}") return summary diff --git a/fseval/storage_providers/wandb.py b/fseval/storage_providers/wandb.py index f05423e..8ec4a32 100644 --- a/fseval/storage_providers/wandb.py +++ b/fseval/storage_providers/wandb.py @@ -1,14 +1,25 @@ import os +from dataclasses import dataclass from logging import Logger, getLogger from pickle import dump, load -from typing import Any, Callable, Dict - -import wandb +from typing import Any, Callable, Dict, Optional from fseval.types import AbstractStorageProvider, TerminalColor +import wandb + +@dataclass class WandbStorageProvider(AbstractStorageProvider): + """Storage provider for Weights and Biases (wandb), allowing users to save- and + restore files to the service. + + Arguments: + local_dir: Optional[str] - when set, an attempt is made to load from the + designated local directory first, before downloading the data off of wandb. Can + be used to perform faster loads or prevent being rate-limited on wandb.""" + + local_dir: Optional[str] = None logger: Logger = getLogger(__name__) def _assert_wandb_available(self): @@ -61,24 +72,60 @@ def _get_restore_file_handle(self, filename: str): return None - def restore(self, filename: str, reader: Callable, mode: str = "r") -> Any: - self._assert_wandb_available() + def _local_restoration(self, filename: str, reader: Callable, mode: str = "r"): + local_file = os.path.join(self.local_dir or "", filename) + if self.local_dir is not None and os.path.exists(local_file): + filepath = local_file - file_handle = self._get_restore_file_handle(filename) + with open(filepath, mode=mode) as file_handle: + file = reader(file_handle) - if not file_handle: + return file or None + else: return None - filedir = wandb.run.dir # type: ignore - filepath = os.path.join(filedir, filename) - with open(filepath, mode=mode) as file_handle: - file = reader(file_handle) + def _wandb_restoration(self, filename: str, reader: Callable, mode: str = "r"): + if self._get_restore_file_handle(filename): + filedir = wandb.run.dir # type: ignore + filepath = os.path.join(filedir, filename) - self.logger.info( - f"successfully restored {TerminalColor.yellow(filename)} from wandb servers " - + TerminalColor.green("✓") - ) - return file + with open(filepath, mode=mode) as file_handle: + file = reader(file_handle) + + return file or None + else: + return None + + def restore(self, filename: str, reader: Callable, mode: str = "r") -> Any: + """Given a filename, restores the file either from local disk or from wandb, + depending on the availability of the file. First, the local disk is searched + for the file, taking in regard the `local_dir` value in the + `WandbStorageProvider` constructor. If this file is not found, the file will + be downloaded fresh from wandb servers.""" + + self._assert_wandb_available() + + # (1) attempt local restoration if available + file = self._local_restoration(filename, reader, mode) + if file: + self.logger.info( + f"successfully restored {TerminalColor.yellow(filename)} from " + + TerminalColor.blue("disk ") + + TerminalColor.green("✓") + ) + + return file + # (2) otherwise, restore by downloading from wandb + elif self._wandb_restoration(filename, reader, mode): + self.logger.info( + f"successfully restored {TerminalColor.yellow(filename)} from " + + TerminalColor.blue("wandb servers ") + + TerminalColor.green("✓") + ) + + return file + else: + return None def restore_pickle(self, filename: str) -> Any: return self.restore(filename, load, mode="rb")