Skip to content

Commit

Permalink
Finish multiprocessing support 🎉
Browse files Browse the repository at this point in the history
-> also, move caching to Estimator 🙌🏻
  • Loading branch information
dunnkers committed Jun 7, 2021
1 parent 53987d0 commit cacc902
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 43 deletions.
1 change: 1 addition & 0 deletions fseval/conf/pipeline/rank_and_validate.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ defaults:
- /estimator@ranker: relieff
- /estimator@validator: decision_tree
n_bootstraps: 2
n_jobs: -1
all_features_to_select: range(1, min(50, p) + 1)
10 changes: 8 additions & 2 deletions fseval/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,18 +80,24 @@ def main(cfg: BaseConfig) -> None:
X_train, X_test, y_train, y_test = cv.train_test_split(X, y)

try:
logger.info(f"pipeline {TerminalColor.cyan('prefit')}...")
pipeline.prefit()
logger.info(f"pipeline {TerminalColor.cyan('fit')}...")
pipeline.fit(X_train, y_train)
logger.info(f"pipeline {TerminalColor.cyan('postfit')}...")
pipeline.postfit()
logger.info(f"pipeline {TerminalColor.cyan('score')}...")
scores = pipeline.score(X_test, y_test)
except Exception as e:
print_exc()
logger.error(e)
logger.info(
"error occured during pipeline fitting step... "
"error occured during pipeline `prefit`, `fit` or `score` step... "
+ "exiting with a status code 1."
)
callbacks.on_end(exit_code=1)
raise e

scores = pipeline.score(X_test, y_test)
logger.info(f"{pipeline_name} pipeline finished {TerminalColor.green('✓')}")
callbacks.on_summary(scores)
callbacks.on_end()
Expand Down
26 changes: 24 additions & 2 deletions fseval/pipeline/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,22 @@
from logging import Logger, getLogger
from typing import Any, Optional

from fseval.types import (
AbstractEstimator,
AbstractStorageProvider,
IncompatibilityError,
Task,
)
from hydra.utils import instantiate
from omegaconf import II, MISSING, OmegaConf
from sklearn.preprocessing import minmax_scale

from fseval.types import AbstractEstimator, IncompatibilityError, Task


@dataclass
class EstimatorConfig:
estimator: Any = None # must have _target_ of type BaseEstimator.
use_cache_if_available: bool = True
# tags
multioutput: Optional[bool] = None
multioutput_only: Optional[bool] = None
requires_positive_X: Optional[bool] = None
Expand All @@ -30,6 +36,7 @@ class TaskedEstimatorConfig(EstimatorConfig):
name: str = MISSING
classifier: Optional[EstimatorConfig] = None
regressor: Optional[EstimatorConfig] = None
use_cache_if_available: bool = True
# tags
multioutput: Optional[bool] = False
multioutput_only: Optional[bool] = False
Expand All @@ -51,6 +58,7 @@ class Estimator(AbstractEstimator, EstimatorConfig):
task: Task = MISSING

logger: Logger = getLogger(__name__)
_is_fitted: bool = False

@classmethod
def _get_estimator_repr(cls, estimator):
Expand All @@ -66,13 +74,27 @@ def _get_class_repr(cls, estimator):
class_name = type(estimator).__name__
return f"{module_name}.{class_name}"

def _load_cache(self, filename: str, storage_provider: AbstractStorageProvider):
    """Try to replace the wrapped estimator with a pickled one from storage.

    Asks `storage_provider` to restore the pickle stored under `filename`.
    When a truthy object comes back it becomes `self.estimator`; otherwise the
    current estimator is left untouched. `self._is_fitted` records whether a
    cached estimator was installed.
    """
    restored = storage_provider.restore_pickle(filename)
    cache_hit = bool(restored)
    if cache_hit:
        self.estimator = restored
    # Store a real boolean: `_is_fitted` is declared as `bool`, and assigning
    # the restored object itself would leak the estimator (or `None`) into a flag.
    self._is_fitted = cache_hit

def _save_cache(self, filename: str, storage_provider: AbstractStorageProvider):
    """Pickle the wrapped estimator to `filename` through `storage_provider`."""
    fitted_estimator = self.estimator
    storage_provider.save_pickle(filename, fitted_estimator)

def fit(self, X, y):
    """Fit the wrapped estimator on `X` and `y`, then return this wrapper.

    Fitting is skipped entirely when the estimator is already marked as fitted
    (`_is_fitted`) and `use_cache_if_available` is enabled. Otherwise `X` is
    min-max rescaled first if the estimator is tagged `requires_positive_X`.
    """
    can_skip_fitting = self._is_fitted and self.use_cache_if_available
    if not can_skip_fitting:
        features = X
        if self.requires_positive_X:
            # estimator only accepts positive features: squeeze X into range
            features = minmax_scale(X)
            self.logger.info(
                "rescaled X: this estimator strictly requires positive features."
            )

        self.logger.debug(f"Fitting {Estimator._get_class_repr(self)}...")
        self.estimator.fit(features, y)

    return self
Expand Down
35 changes: 28 additions & 7 deletions fseval/pipelines/_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
class Experiment(AbstractEstimator):
estimators: List[AbstractEstimator] = field(default_factory=lambda: [])
logger: Logger = getLogger(__name__)
n_jobs: Optional[int] = None

def __post_init__(self):
    """Materialise whatever `_get_estimator()` produces into `self.estimators`."""
    produced = self._get_estimator()
    self.estimators = list(produced)

def _get_n_jobs(self):
    """Return the number of parallel jobs `fit()` should use.

    The base experiment returns `None`; subclasses may override this hook.
    """
    return None

def _get_estimator(self):
    """Return the estimators making up this experiment; none for the base class."""
    return list()

Expand Down Expand Up @@ -50,8 +52,17 @@ def _step_text(self, step_name, step_number, estimator):
)

def _prepare_data(self, X, y):
    """Callback. Can be used to implement any data preparation schemes.

    The base implementation is a pass-through: `X` and `y` come back unchanged.
    """
    return (X, y)

def prefit(self):
    """Pre-fit hook. Is executed right before calling `fit()`.

    Forwards the call to every child estimator that exposes a callable
    `prefit`, e.g. so it can load itself from cache or do other preparatory
    work. Children without such a hook are skipped.
    """
    for child in self.estimators:
        hook = getattr(child, "prefit", None)
        if callable(hook):
            hook()

def _fit_estimator(self, X, y, step_number, estimator):
logger = self._logger(estimator)
text = self._step_text("fit", step_number, estimator)
Expand All @@ -75,13 +86,13 @@ def fit(self, X, y) -> AbstractEstimator:

X, y = self._prepare_data(X, y)

if self.n_jobs is not None:
assert (
self.n_jobs >= 1 or self.n_jobs == -1
), f"incorrect `n_jobs`: {self.n_jobs}"
## Run `fit`
n_jobs = self._get_n_jobs()
if n_jobs is not None:
assert n_jobs >= 1 or n_jobs == -1, f"incorrect `n_jobs`: {n_jobs}"

cpus = multiprocessing.cpu_count() if self.n_jobs == -1 else self.n_jobs
self.logger.info(f"Using {cpus} CPU's in parallel (n_jobs={self.n_jobs})")
cpus = multiprocessing.cpu_count() if n_jobs == -1 else n_jobs
self.logger.info(f"Using {cpus} CPU's in parallel (n_jobs={n_jobs})")

star_input = [
(X, y, step_number, estimator)
Expand All @@ -100,6 +111,16 @@ def fit(self, X, y) -> AbstractEstimator:

return self

def postfit(self):
    """Post-fit hook. Is executed right after calling `fit()`.

    Forwards the call to every child estimator that exposes a callable
    `postfit`, e.g. so it can save itself to cache. Children without such a
    hook are skipped.
    """
    for child in self.estimators:
        hook = getattr(child, "postfit", None)
        if callable(hook):
            hook()

def transform(self, X, y):
...

Expand Down
11 changes: 9 additions & 2 deletions fseval/pipelines/rank_and_validate/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
from fseval.callbacks import WandbCallback
from fseval.pipeline.estimator import Estimator
from fseval.types import TerminalColor
from omegaconf import MISSING
from sklearn.base import clone
from tqdm import tqdm
Expand Down Expand Up @@ -116,7 +117,9 @@ def score(self, X, y):
scores = scores.append(validation_score)
scores["bootstrap_state"] = self.bootstrap_state

self.logger.info(f"scored bootstrap_state={self.bootstrap_state} ✓")
self.logger.info(
f"scored bootstrap_state={self.bootstrap_state} " + TerminalColor.green("✓")
)
return scores


Expand All @@ -127,7 +130,11 @@ class BootstrappedRankAndValidate(Experiment, RankAndValidatePipeline):
that various metrics can be better approximated."""

logger: Logger = getLogger(__name__)
n_jobs: Optional[int] = -1 # utilize all CPU's

def _get_n_jobs(self):
    """Allow each bootstrap experiment to run on a separate CPU.

    Returns the `n_jobs` value configured on this pipeline.
    """
    configured_jobs = self.n_jobs
    return configured_jobs

def _get_estimator(self):
for bootstrap_state in np.arange(1, self.n_bootstraps + 1):
Expand Down
3 changes: 3 additions & 0 deletions fseval/pipelines/rank_and_validate/_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from dataclasses import dataclass
from typing import Optional

from fseval.pipeline.estimator import Estimator, TaskedEstimatorConfig
from fseval.pipeline.resample import Resample, ResampleConfig
Expand All @@ -20,6 +21,7 @@ class RankAndValidateConfig:
ranker: TaskedEstimatorConfig = MISSING
validator: TaskedEstimatorConfig = MISSING
n_bootstraps: int = MISSING
n_jobs: Optional[int] = MISSING
all_features_to_select: str = MISSING


Expand All @@ -33,6 +35,7 @@ class RankAndValidatePipeline(Pipeline):
ranker: Estimator = MISSING
validator: Estimator = MISSING
n_bootstraps: int = MISSING
n_jobs: Optional[int] = MISSING
all_features_to_select: str = MISSING

def _get_config(self):
Expand Down
27 changes: 15 additions & 12 deletions fseval/pipelines/rank_and_validate/_ranking_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@

import numpy as np
import pandas as pd
from fseval.types import IncompatibilityError
from omegaconf import MISSING
from sklearn.metrics import log_loss, r2_score

from fseval.types import IncompatibilityError

from .._experiment import Experiment
from ._config import RankAndValidatePipeline

Expand All @@ -31,20 +30,24 @@ def __post_init__(self):

super(RankingValidator, self).__post_init__()

@property
def _cache_filename(self):
    """File name under which this bootstrap's ranking is cached as a pickle."""
    return f"ranking[bootstrap_state={self.bootstrap_state}].pickle"

def _get_estimator(self):
    """Yield the single estimator this experiment runs: the ranker."""
    yield from (self.ranker,)

def prefit(self):
    """Pre-fit hook: ask the ranker to load itself from its cache file."""
    load_from_cache = self.ranker._load_cache
    load_from_cache(self._cache_filename, self.storage_provider)

def fit(self, X, y):
override = f"bootstrap_state={self.bootstrap_state}"
filename = f"ranking[{override}].pickle"
restored = self.storage_provider.restore_pickle(filename)

if restored:
self.ranker.estimator = restored
self.logger.info("restored ranking from storage provider ✓")
else:
super(RankingValidator, self).fit(X, y)
self.storage_provider.save_pickle(filename, self.ranker.estimator)
super(RankingValidator, self).fit(X, y)

def postfit(self):
    """Post-fit hook: ask the ranker to save itself to its cache file."""
    save_to_cache = self.ranker._save_cache
    save_to_cache(self._cache_filename, self.storage_provider)

def score(self, X, y):
"""Scores a feature ranker, if a ground-truth on the desired dataset
Expand Down
25 changes: 14 additions & 11 deletions fseval/pipelines/rank_and_validate/_subset_validator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from dataclasses import dataclass

import numpy as np
from omegaconf import MISSING
from sklearn.feature_selection import SelectFromModel

from fseval.pipeline.estimator import Estimator
from fseval.types import IncompatibilityError
from omegaconf import MISSING
from sklearn.feature_selection import SelectFromModel

from .._experiment import Experiment
from ._config import RankAndValidatePipeline
Expand Down Expand Up @@ -57,18 +56,22 @@ def _prepare_data(self, X, y):
X = selector.transform(X)
return X, y

def fit(self, X, y):
@property
def _cache_filename(self):
override = f"bootstrap_state={self.bootstrap_state}"
override += f",n_features_to_select={self.n_features_to_select}"
filename = f"validation[{override}].pickle"
restored = self.storage_provider.restore_pickle(filename)

if restored:
self.validator.estimator = restored
self.logger.info("restored validator from storage provider ✓")
else:
super(SubsetValidator, self).fit(X, y)
self.storage_provider.save_pickle(filename, self.validator.estimator)
return filename

def prefit(self):
    """Pre-fit hook: ask the validator to load itself from its cache file."""
    load_from_cache = self.validator._load_cache
    load_from_cache(self._cache_filename, self.storage_provider)

def fit(self, X, y):
    """Fit this subset validator by delegating to the parent `fit`.

    Returns whatever the parent implementation returns, so that callers can
    rely on the usual `fit(...)`-returns-the-estimator convention instead of
    receiving `None`.
    """
    return super(SubsetValidator, self).fit(X, y)

def postfit(self):
    """Post-fit hook: ask the validator to save itself to its cache file."""
    save_to_cache = self.validator._save_cache
    save_to_cache(self._cache_filename, self.storage_provider)

def score(self, X, y):
score = super(SubsetValidator, self).score(X, y)
Expand Down
25 changes: 18 additions & 7 deletions fseval/storage_providers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,29 @@
from pickle import dump, load
from typing import Any, Callable, Dict

from fseval.types import AbstractStorageProvider
from fseval.types import AbstractStorageProvider, TerminalColor

import wandb


class WandbStorageProvider(AbstractStorageProvider):
logger: Logger = getLogger(__name__)

def _assert_wandb_available(self):
    """Fail fast when no wandb run is active in the current process.

    Raises:
        AssertionError: if `wandb.run` is `None`.
    """
    # Raised explicitly rather than through `assert` so the guard also holds
    # when Python runs with `-O`; the exception type callers see is unchanged.
    if wandb.run is None:
        raise AssertionError(
            "`wandb.run` is not available in this process. you are perhaps using "
            "multi-processing: make sure to only use the wandb storage provider "
            "from the main thread. see "
            "https://docs.wandb.ai/guides/track/advanced/distributed-training."
        )

def set_config(self, config: Dict):
    """Validate that the wandb callback is enabled, then delegate to the parent.

    Raises:
        AssertionError: if `config["callbacks"]` has no truthy `"wandb"` entry.
    """
    wandb_callback = config["callbacks"].get("wandb")
    assert (
        wandb_callback
    ), "wandb callback must be enabled to use wandb storage provider."
    super(WandbStorageProvider, self).set_config(config)

def save(self, filename: str, writer: Callable, mode: str = "w"):
if __name__ != "__main__":
return
self._assert_wandb_available()

filedir = wandb.run.dir # type: ignore
filepath = os.path.join(filedir, filename)
Expand All @@ -28,7 +34,10 @@ def save(self, filename: str, writer: Callable, mode: str = "w"):
writer(file_handle)

wandb.save(filename, base_path="/") # type: ignore
self.logger.info(f"successfully saved `{filename}` to wandb servers ✓")
self.logger.info(
f"successfully saved {TerminalColor.yellow(filename)} to wandb servers "
+ TerminalColor.green("✓")
)

def save_pickle(self, filename: str, obj: Any):
self.save(filename, lambda file: dump(obj, file), mode="wb")
Expand All @@ -53,8 +62,7 @@ def _get_restore_file_handle(self, filename: str):
return None

def restore(self, filename: str, reader: Callable, mode: str = "r") -> Any:
if __name__ != "__main__":
return
self._assert_wandb_available()

file_handle = self._get_restore_file_handle(filename)

Expand All @@ -66,7 +74,10 @@ def restore(self, filename: str, reader: Callable, mode: str = "r") -> Any:
with open(filepath, mode=mode) as file_handle:
file = reader(file_handle)

self.logger.info(f"successfully restored `{filename}` from wandb servers ✓")
self.logger.info(
f"successfully restored {TerminalColor.yellow(filename)} from wandb servers "
+ TerminalColor.green("✓")
)
return file

def restore_pickle(self, filename: str) -> Any:
Expand Down

0 comments on commit cacc902

Please sign in to comment.