Skip to content

Commit

Permalink
Fix callbacks multiprocessing issue
Browse files Browse the repository at this point in the history
  • Loading branch information
dunnkers committed Jan 30, 2022
1 parent 7203d9c commit 333b36a
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 38 deletions.
79 changes: 45 additions & 34 deletions fseval/pipelines/_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,48 @@ def _fit_estimator(self, X, y, step_number, estimator):

return estimator

def _remove_and_get_callbacks(self) -> Tuple[Dict[str, Callback], List[str]]:
"""
Removes callbacks from this Experiment object. Returns them as a
tuple. This is necessary because the callbacks might contain state that
either **cannot** be pickled, or **should** not be taken over to a
process fork. For example, SQLAlchemy's engine object can intentionally
not be forked - and with good reason.
@see https://docs.sqlalchemy.org/en/14/core/pooling.html
"""
callback_objects = dict()
callback_names = self.callbacks.callback_names

for callback_name in self.callbacks.callback_names:
callback_objects[callback_name] = getattr(self.callbacks, callback_name)
delattr(self.callbacks, callback_name)

self.callbacks.callback_names = []

return callback_objects, callback_names

def _set_callbacks(
self, callback_objects: Dict[str, Callback], callback_names: List[str]
):
"""
Restores the callbacks, by setting them back onto this class object.
Attributes:
callback_objects (Dict[str, Callback]): The callback objects, like stored in
`_remove_and_get_callbacks()`.
callback_names (List[str]): The corresponding callback names, like stored in
`_remove_and_get_callbacks()`.
"""
self.callbacks.callback_names = callback_names

for callback_name in self.callbacks.callback_names:
setattr(self.callbacks, callback_name, callback_objects[callback_name])

def fit(self, X, y) -> AbstractEstimator:
"""Sequentially fits all estimators in this experiment.
Args:
Attributes:
X (np.ndarray): design matrix X
y (np.ndarray): target labels y"""

Expand All @@ -97,54 +135,27 @@ def fit(self, X, y) -> AbstractEstimator:
if n_jobs is not None and (n_jobs > 1 or n_jobs == -1):
assert n_jobs >= 1 or n_jobs == -1, f"incorrect `n_jobs`: {n_jobs}"

# remove callbacks this object and store locally in main thread.
callback_objects, callback_names = self.remove_and_get_callbacks()

# determine amount of CPU's to use. ALL if n_jobs is -1, else n_jobs.
cpus = multiprocessing.cpu_count() if n_jobs == -1 else n_jobs
self.logger.info(f"Using {cpus} CPU's in parallel (n_jobs={n_jobs})")

# input to `self._fit_esitmator`
star_input = [
(X, y, step_number, estimator)
for step_number, estimator in enumerate(self.estimators)
]

# # dispose SQLAlchemy engine if present
# # @see https://docs.sqlalchemy.org/en/14/core/pooling.html
# if "to_sql" in self.callbacks.callback_names:
# engine: Engine = self.callbacks.to_sql.engine
# engine.dispose()

def remove_and_get_callbacks() -> Tuple[Dict[str, Callback], List[str]]:
callback_objects = dict()
callback_names = self.callbacks.callback_names

for callback_name in self.callbacks.callback_names:
callback_objects[callback_name] = getattr(
self.callbacks, callback_name
)
delattr(self.callbacks, callback_name)

self.callbacks.callback_names = []

return callback_objects, callback_names

def set_callbacks(
callback_objects: Dict[str, Callback], callback_names: List[str]
):
self.callbacks.callback_names = callback_names

for callback_name in self.callbacks.callback_names:
setattr(
self.callbacks, callback_name, callback_objects[callback_name]
)

callback_objects, callback_names = remove_and_get_callbacks()

# open pool and fit estimators.
pool = multiprocessing.Pool(processes=cpus)
estimators = pool.starmap(self._fit_estimator, star_input)
pool.close()
pool.join()

# restore callbacks in main thread
set_callbacks(callback_objects, callback_names)
self.set_callbacks(callback_objects, callback_names)

# set collected estimators to this local object
self.estimators = estimators
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,19 @@ def test_pipeline_failure(failing_cfg: PipelineConfig):


@pytest.fixture
def to_sql_multiprocessing_cfg() -> PipelineConfig:
def callbacks_multiprocessing_cfg() -> PipelineConfig:
config = get_config(
config_module="tests.integration.conf",
config_name="to_sql_multiprocessing",
config_name="callbacks_multiprocessing",
)

return config


def test_to_sql_multiprocessing(to_sql_multiprocessing_cfg: PipelineConfig):
def test_callbacks_multiprocessing(callbacks_multiprocessing_cfg: PipelineConfig):
# execute from temporary dir
tmpdir = tempfile.mkdtemp()
os.chdir(tmpdir)

# run pipeline
run_pipeline(to_sql_multiprocessing_cfg)
run_pipeline(callbacks_multiprocessing_cfg)

0 comments on commit 333b36a

Please sign in to comment.