Skip to content

Commit

Permalink
Add multiprocessing support 🚀
Browse files Browse the repository at this point in the history
  • Loading branch information
dunnkers committed Jun 5, 2021
1 parent 008f9ed commit 53987d0
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 14 deletions.
48 changes: 37 additions & 11 deletions fseval/pipelines/_experiment.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import multiprocessing
from dataclasses import dataclass, field
from logging import Logger, getLogger
from time import perf_counter
from typing import List
from typing import List, Optional

import pandas as pd
from humanfriendly import format_timespan

from fseval.pipeline.estimator import Estimator
from fseval.types import AbstractEstimator, TerminalColor
from humanfriendly import format_timespan


@dataclass
class Experiment(AbstractEstimator):
estimators: List[AbstractEstimator] = field(default_factory=lambda: [])
logger: Logger = getLogger(__name__)
n_jobs: Optional[int] = None

def __post_init__(self):
self.estimators = list(self._get_estimator())
Expand Down Expand Up @@ -51,6 +52,19 @@ def _step_text(self, step_name, step_number, estimator):
def _prepare_data(self, X, y):
return X, y

def _fit_estimator(self, X, y, step_number, estimator):
logger = self._logger(estimator)
text = self._step_text("fit", step_number, estimator)

start_time = perf_counter()
estimator.fit(X, y)
fit_time = perf_counter() - start_time
setattr(estimator, "fit_time_", fit_time)

logger(text(fit_time))

return estimator

def fit(self, X, y) -> AbstractEstimator:
"""Sequentially fits all estimators in this experiment, and record timings;
which will be stored in a `fit_time_` attribute in each estimator itself.
Expand All @@ -61,16 +75,28 @@ def fit(self, X, y) -> AbstractEstimator:

X, y = self._prepare_data(X, y)

for step_number, estimator in enumerate(self.estimators):
logger = self._logger(estimator)
text = self._step_text("fit", step_number, estimator)
if self.n_jobs is not None:
assert (
self.n_jobs >= 1 or self.n_jobs == -1
), f"incorrect `n_jobs`: {self.n_jobs}"

cpus = multiprocessing.cpu_count() if self.n_jobs == -1 else self.n_jobs
self.logger.info(f"Using {cpus} CPU's in parallel (n_jobs={self.n_jobs})")

start_time = perf_counter()
estimator.fit(X, y)
fit_time = perf_counter() - start_time
setattr(estimator, "fit_time_", fit_time)
star_input = [
(X, y, step_number, estimator)
for step_number, estimator in enumerate(self.estimators)
]

logger(text(fit_time))
pool = multiprocessing.Pool(processes=cpus)
estimators = pool.starmap(self._fit_estimator, star_input)
pool.close()
pool.join()

self.estimators = estimators
else:
for step_number, estimator in enumerate(self.estimators):
self._fit_estimator(X, y, step_number, estimator)

return self

Expand Down
3 changes: 2 additions & 1 deletion fseval/pipelines/rank_and_validate/_components.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from logging import Logger, getLogger
from typing import List, cast
from typing import List, Optional, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -127,6 +127,7 @@ class BootstrappedRankAndValidate(Experiment, RankAndValidatePipeline):
that various metrics can be better approximated."""

logger: Logger = getLogger(__name__)
n_jobs: Optional[int] = -1 # utilize all CPU's

def _get_estimator(self):
for bootstrap_state in np.arange(1, self.n_bootstraps + 1):
Expand Down
10 changes: 8 additions & 2 deletions fseval/storage_providers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from pickle import dump, load
from typing import Any, Callable, Dict

import wandb

from fseval.types import AbstractStorageProvider

import wandb


class WandbStorageProvider(AbstractStorageProvider):
logger: Logger = getLogger(__name__)
Expand All @@ -18,6 +18,9 @@ def set_config(self, config: Dict):
super(WandbStorageProvider, self).set_config(config)

def save(self, filename: str, writer: Callable, mode: str = "w"):
if __name__ != "__main__":
return

filedir = wandb.run.dir # type: ignore
filepath = os.path.join(filedir, filename)

Expand Down Expand Up @@ -50,6 +53,9 @@ def _get_restore_file_handle(self, filename: str):
return None

def restore(self, filename: str, reader: Callable, mode: str = "r") -> Any:
if __name__ != "__main__":
return

file_handle = self._get_restore_file_handle(filename)

if not file_handle:
Expand Down

0 comments on commit 53987d0

Please sign in to comment.