diff --git a/fseval/pipelines/_experiment.py b/fseval/pipelines/_experiment.py
index 615bffe..34cc9cf 100644
--- a/fseval/pipelines/_experiment.py
+++ b/fseval/pipelines/_experiment.py
@@ -83,10 +83,48 @@ def _fit_estimator(self, X, y, step_number, estimator):
 
         return estimator
 
+    def _remove_and_get_callbacks(self) -> Tuple[Dict[str, Callback], List[str]]:
+        """
+        Removes callbacks from this Experiment object. Returns them as a
+        tuple. This is necessary because the callbacks might contain state that
+        either **cannot** be pickled, or **should** not be taken over to a
+        process fork. For example, SQLAlchemy's engine object can intentionally
+        not be forked - and with good reason.
+
+        @see https://docs.sqlalchemy.org/en/14/core/pooling.html
+        """
+        callback_objects = dict()
+        callback_names = self.callbacks.callback_names
+
+        for callback_name in self.callbacks.callback_names:
+            callback_objects[callback_name] = getattr(self.callbacks, callback_name)
+            delattr(self.callbacks, callback_name)
+
+        self.callbacks.callback_names = []
+
+        return callback_objects, callback_names
+
+    def _set_callbacks(
+        self, callback_objects: Dict[str, Callback], callback_names: List[str]
+    ):
+        """
+        Restores the callbacks, by setting them back onto this class object.
+
+        Attributes:
+            callback_objects (Dict[str, Callback]): The callback objects, as returned by
+                `_remove_and_get_callbacks()`.
+            callback_names (List[str]): The corresponding callback names, as returned by
+                `_remove_and_get_callbacks()`.
+        """
+        self.callbacks.callback_names = callback_names
+
+        for callback_name in self.callbacks.callback_names:
+            setattr(self.callbacks, callback_name, callback_objects[callback_name])
+
     def fit(self, X, y) -> AbstractEstimator:
         """Sequentially fits all estimators in this experiment.
 
-        Args:
+        Attributes:
             X (np.ndarray): design matrix X
             y (np.ndarray): target labels y"""
 
@@ -97,46 +135,19 @@ def fit(self, X, y) -> AbstractEstimator:
         if n_jobs is not None and (n_jobs > 1 or n_jobs == -1):
             assert n_jobs >= 1 or n_jobs == -1, f"incorrect `n_jobs`: {n_jobs}"
 
+            # remove callbacks from this object and store locally in main thread.
+            callback_objects, callback_names = self._remove_and_get_callbacks()
+
+            # determine number of CPUs to use. ALL if n_jobs is -1, else n_jobs.
             cpus = multiprocessing.cpu_count() if n_jobs == -1 else n_jobs
             self.logger.info(f"Using {cpus} CPU's in parallel (n_jobs={n_jobs})")
 
+            # input to `self._fit_estimator`
             star_input = [
                 (X, y, step_number, estimator)
                 for step_number, estimator in enumerate(self.estimators)
             ]
 
-            # # dispose SQLAlchemy engine if present
-            # # @see https://docs.sqlalchemy.org/en/14/core/pooling.html
-            # if "to_sql" in self.callbacks.callback_names:
-            #     engine: Engine = self.callbacks.to_sql.engine
-            #     engine.dispose()
-
-            def remove_and_get_callbacks() -> Tuple[Dict[str, Callback], List[str]]:
-                callback_objects = dict()
-                callback_names = self.callbacks.callback_names
-
-                for callback_name in self.callbacks.callback_names:
-                    callback_objects[callback_name] = getattr(
-                        self.callbacks, callback_name
-                    )
-                    delattr(self.callbacks, callback_name)
-
-                self.callbacks.callback_names = []
-
-                return callback_objects, callback_names
-
-            def set_callbacks(
-                callback_objects: Dict[str, Callback], callback_names: List[str]
-            ):
-                self.callbacks.callback_names = callback_names
-
-                for callback_name in self.callbacks.callback_names:
-                    setattr(
-                        self.callbacks, callback_name, callback_objects[callback_name]
-                    )
-
-            callback_objects, callback_names = remove_and_get_callbacks()
-
             # open pool and fit estimators.
             pool = multiprocessing.Pool(processes=cpus)
             estimators = pool.starmap(self._fit_estimator, star_input)
@@ -144,7 +155,7 @@ def set_callbacks(
             pool.join()
 
             # restore callbacks in main thread
-            set_callbacks(callback_objects, callback_names)
+            self._set_callbacks(callback_objects, callback_names)
 
             # set collected estimators to this local object
             self.estimators = estimators
diff --git a/tests/integration/conf/to_sql_multiprocessing.yaml b/tests/integration/conf/callbacks_multiprocessing.yaml
similarity index 100%
rename from tests/integration/conf/to_sql_multiprocessing.yaml
rename to tests/integration/conf/callbacks_multiprocessing.yaml
diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py
index 491aed9..d389580 100644
--- a/tests/integration/test_main.py
+++ b/tests/integration/test_main.py
@@ -119,19 +119,19 @@ def test_pipeline_failure(failing_cfg: PipelineConfig):
 
 
 @pytest.fixture
-def to_sql_multiprocessing_cfg() -> PipelineConfig:
+def callbacks_multiprocessing_cfg() -> PipelineConfig:
     config = get_config(
         config_module="tests.integration.conf",
-        config_name="to_sql_multiprocessing",
+        config_name="callbacks_multiprocessing",
     )
 
     return config
 
 
-def test_to_sql_multiprocessing(to_sql_multiprocessing_cfg: PipelineConfig):
+def test_callbacks_multiprocessing(callbacks_multiprocessing_cfg: PipelineConfig):
     # execute from temporary dir
     tmpdir = tempfile.mkdtemp()
     os.chdir(tmpdir)
 
     # run pipeline
-    run_pipeline(to_sql_multiprocessing_cfg)
+    run_pipeline(callbacks_multiprocessing_cfg)