From 53987d079244eb2604f9b4ecd76ae28ec9080dd2 Mon Sep 17 00:00:00 2001 From: Jeroen Overschie Date: Sat, 5 Jun 2021 19:45:50 +0200 Subject: [PATCH] =?UTF-8?q?Add=20multiprocessing=20support=20=F0=9F=9A=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fseval/pipelines/_experiment.py | 48 ++++++++++++++----- .../rank_and_validate/_components.py | 3 +- fseval/storage_providers/wandb.py | 10 +++- 3 files changed, 47 insertions(+), 14 deletions(-) diff --git a/fseval/pipelines/_experiment.py b/fseval/pipelines/_experiment.py index 9117b8b..b05a62c 100644 --- a/fseval/pipelines/_experiment.py +++ b/fseval/pipelines/_experiment.py @@ -1,19 +1,20 @@ +import multiprocessing from dataclasses import dataclass, field from logging import Logger, getLogger from time import perf_counter -from typing import List +from typing import List, Optional import pandas as pd -from humanfriendly import format_timespan - from fseval.pipeline.estimator import Estimator from fseval.types import AbstractEstimator, TerminalColor +from humanfriendly import format_timespan @dataclass class Experiment(AbstractEstimator): estimators: List[AbstractEstimator] = field(default_factory=lambda: []) logger: Logger = getLogger(__name__) + n_jobs: Optional[int] = None def __post_init__(self): self.estimators = list(self._get_estimator()) @@ -51,6 +52,19 @@ def _step_text(self, step_name, step_number, estimator): def _prepare_data(self, X, y): return X, y + def _fit_estimator(self, X, y, step_number, estimator): + logger = self._logger(estimator) + text = self._step_text("fit", step_number, estimator) + + start_time = perf_counter() + estimator.fit(X, y) + fit_time = perf_counter() - start_time + setattr(estimator, "fit_time_", fit_time) + + logger(text(fit_time)) + + return estimator + def fit(self, X, y) -> AbstractEstimator: """Sequentially fits all estimators in this experiment, and record timings; which will be stored in a `fit_time_` attribute in each estimator itself. @@ -61,16 +75,28 @@ def fit(self, X, y) -> AbstractEstimator: X, y = self._prepare_data(X, y) - for step_number, estimator in enumerate(self.estimators): - logger = self._logger(estimator) - text = self._step_text("fit", step_number, estimator) + if self.n_jobs is not None: + assert ( + self.n_jobs >= 1 or self.n_jobs == -1 + ), f"incorrect `n_jobs`: {self.n_jobs}" + + cpus = multiprocessing.cpu_count() if self.n_jobs == -1 else self.n_jobs + self.logger.info(f"Using {cpus} CPU's in parallel (n_jobs={self.n_jobs})") - start_time = perf_counter() - estimator.fit(X, y) - fit_time = perf_counter() - start_time - setattr(estimator, "fit_time_", fit_time) + star_input = [ + (X, y, step_number, estimator) + for step_number, estimator in enumerate(self.estimators) + ] - logger(text(fit_time)) + pool = multiprocessing.Pool(processes=cpus) + estimators = pool.starmap(self._fit_estimator, star_input) + pool.close() + pool.join() + + self.estimators = estimators + else: + for step_number, estimator in enumerate(self.estimators): + self._fit_estimator(X, y, step_number, estimator) return self diff --git a/fseval/pipelines/rank_and_validate/_components.py b/fseval/pipelines/rank_and_validate/_components.py index fb07f0c..9c3713b 100644 --- a/fseval/pipelines/rank_and_validate/_components.py +++ b/fseval/pipelines/rank_and_validate/_components.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from logging import Logger, getLogger -from typing import List, cast +from typing import List, Optional, cast import numpy as np import pandas as pd @@ -127,6 +127,7 @@ class BootstrappedRankAndValidate(Experiment, RankAndValidatePipeline): that various metrics can be better approximated.""" logger: Logger = getLogger(__name__) + n_jobs: Optional[int] = -1 # utilize all CPU's def _get_estimator(self): for bootstrap_state in np.arange(1, self.n_bootstraps + 1): diff --git a/fseval/storage_providers/wandb.py b/fseval/storage_providers/wandb.py index 2fc8314..5091379 100644 --- a/fseval/storage_providers/wandb.py +++ b/fseval/storage_providers/wandb.py @@ -3,10 +3,10 @@ from pickle import dump, load from typing import Any, Callable, Dict -import wandb - from fseval.types import AbstractStorageProvider +import wandb + class WandbStorageProvider(AbstractStorageProvider): logger: Logger = getLogger(__name__) @@ -18,6 +18,9 @@ def set_config(self, config: Dict): super(WandbStorageProvider, self).set_config(config) def save(self, filename: str, writer: Callable, mode: str = "w"): + if __name__ != "__main__": + return + filedir = wandb.run.dir # type: ignore filepath = os.path.join(filedir, filename) @@ -50,6 +53,9 @@ def _get_restore_file_handle(self, filename: str): return None def restore(self, filename: str, reader: Callable, mode: str = "r") -> Any: + if __name__ != "__main__": + return + file_handle = self._get_restore_file_handle(filename) if not file_handle: