Skip to content

Commit

Permalink
Add ability to load from local disk cache 💪🏻💪🏻
Browse files Browse the repository at this point in the history
  • Loading branch information
dunnkers committed Jun 7, 2021
1 parent 415150a commit ff1cb5b
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 21 deletions.
3 changes: 2 additions & 1 deletion fseval/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Any, Dict
from typing import Any, Dict, Optional

from hydra.core.config_store import ConfigStore
from omegaconf import MISSING
Expand All @@ -12,6 +12,7 @@
@dataclass
class StorageProviderConfig:
    """Structured config describing which storage provider to use."""

    # Has no usable default (`MISSING`), so a value must be supplied.
    # NOTE(review): presumably the import path of the provider class to
    # instantiate -- confirm against where this config is consumed.
    _target_: str = MISSING
    # Optional local directory; defaults to `None`. `WandbStorageProvider`
    # documents its own `local_dir` as a directory that is tried first before
    # downloading from wandb -- presumably this value is forwarded to it.
    local_dir: Optional[str] = None


@dataclass
Expand Down
9 changes: 5 additions & 4 deletions fseval/pipelines/rank_and_validate/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@

import numpy as np
import pandas as pd
from fseval.callbacks import WandbCallback
from fseval.types import TerminalColor
from omegaconf import MISSING
from sklearn.base import clone
from tqdm import tqdm

from fseval.callbacks import WandbCallback
from fseval.types import TerminalColor

from .._experiment import Experiment
from ._config import RankAndValidatePipeline
from ._dataset_validator import DatasetValidator
Expand Down Expand Up @@ -168,9 +167,10 @@ def score(self, X, y):
# summary
summary = dict(best=best)

##### Upload tables
wandb_callback = getattr(self.callbacks, "wandb", False)
if wandb_callback:
##### Upload tables
self.logger.info(f"Uploading tables to wandb...")
wandb_callback = cast(WandbCallback, wandb_callback)

### upload best scores
Expand Down Expand Up @@ -211,5 +211,6 @@ def score(self, X, y):
"feature_ranking_", "feature_ranking"
)
wandb_callback.upload_table(ranking_table, "feature_ranking")
self.logger.info(f"Tables uploaded {TerminalColor.green('✓')}")

return summary
79 changes: 63 additions & 16 deletions fseval/storage_providers/wandb.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
import os
from dataclasses import dataclass
from logging import Logger, getLogger
from pickle import dump, load
from typing import Any, Callable, Dict

import wandb
from typing import Any, Callable, Dict, Optional

from fseval.types import AbstractStorageProvider, TerminalColor

import wandb


@dataclass
class WandbStorageProvider(AbstractStorageProvider):
"""Storage provider for Weights and Biases (wandb), allowing users to save- and
restore files to the service.
Arguments:
local_dir: Optional[str] - when set, an attempt is made to load from the
designated local directory first, before downloading the data off of wandb. Can
be used to perform faster loads or prevent being rate-limited on wandb."""

local_dir: Optional[str] = None
logger: Logger = getLogger(__name__)

def _assert_wandb_available(self):
Expand Down Expand Up @@ -61,24 +72,60 @@ def _get_restore_file_handle(self, filename: str):

return None

def restore(self, filename: str, reader: Callable, mode: str = "r") -> Any:
self._assert_wandb_available()
def _local_restoration(self, filename: str, reader: Callable, mode: str = "r"):
local_file = os.path.join(self.local_dir or "", filename)
if self.local_dir is not None and os.path.exists(local_file):
filepath = local_file

file_handle = self._get_restore_file_handle(filename)
with open(filepath, mode=mode) as file_handle:
file = reader(file_handle)

if not file_handle:
return file or None
else:
return None

filedir = wandb.run.dir # type: ignore
filepath = os.path.join(filedir, filename)
with open(filepath, mode=mode) as file_handle:
file = reader(file_handle)
def _wandb_restoration(self, filename: str, reader: Callable, mode: str = "r"):
if self._get_restore_file_handle(filename):
filedir = wandb.run.dir # type: ignore
filepath = os.path.join(filedir, filename)

self.logger.info(
f"successfully restored {TerminalColor.yellow(filename)} from wandb servers "
+ TerminalColor.green("✓")
)
return file
with open(filepath, mode=mode) as file_handle:
file = reader(file_handle)

return file or None
else:
return None

def restore(self, filename: str, reader: Callable, mode: str = "r") -> Any:
    """Restore a file from local disk or, failing that, from wandb.

    The local disk is searched first, honouring the `local_dir` value passed
    to the `WandbStorageProvider` constructor. If nothing could be restored
    locally, the file is downloaded fresh from the wandb servers.

    Arguments:
        filename: name of the file to restore.
        reader: callable that receives the opened file handle and returns
            the restored object.
        mode: mode used to open the file, e.g. "r" or "rb".

    Returns:
        The object produced by `reader`, or `None` when the file could be
        restored from neither local disk nor wandb.
    """

    self._assert_wandb_available()

    # (1) attempt local restoration if available
    restored = self._local_restoration(filename, reader, mode)
    if restored is not None:
        self.logger.info(
            f"successfully restored {TerminalColor.yellow(filename)} from "
            + TerminalColor.blue("disk ")
            + TerminalColor.green("✓")
        )
        return restored

    # (2) otherwise, restore by downloading from wandb. The result has to be
    # captured here: previously it was only used as a condition and the
    # (empty) local result was returned instead, so callers always got `None`.
    restored = self._wandb_restoration(filename, reader, mode)
    if restored is not None:
        self.logger.info(
            f"successfully restored {TerminalColor.yellow(filename)} from "
            + TerminalColor.blue("wandb servers ")
            + TerminalColor.green("✓")
        )
        return restored

    # (3) the file is available in neither location
    return None

def restore_pickle(self, filename: str) -> Any:
    """Restore `filename` and deserialize its contents with `pickle.load`.

    Delegates to `restore`, opening the file in binary mode ("rb").
    Returns whatever `restore` returns.
    """
    # NOTE: unpickling can execute arbitrary code; only restore pickles that
    # come from a trusted source.
    unpickled = self.restore(filename, load, mode="rb")
    return unpickled

0 comments on commit ff1cb5b

Please sign in to comment.