From e8076613803038989f8c9c96dcb02c7712ee33aa Mon Sep 17 00:00:00 2001 From: Jan Ittner Date: Thu, 9 Sep 2021 11:31:19 +0200 Subject: [PATCH 1/2] API: expect Iterable not positional args in run_jobs() and run_queues() --- src/facet/inspection/_shap.py | 32 ++++++++++++++--------------- src/facet/selection/_selection.py | 2 +- src/facet/simulation/_simulation.py | 24 +++++++++------------- 3 files changed, 26 insertions(+), 32 deletions(-) diff --git a/src/facet/inspection/_shap.py b/src/facet/inspection/_shap.py index 1e488cf0a..39ba1ff58 100644 --- a/src/facet/inspection/_shap.py +++ b/src/facet/inspection/_shap.py @@ -275,23 +275,21 @@ def _make_explainer(_model: T_LearnerPipelineDF) -> BaseExplainer: else: shap_df_per_split = JobRunner.from_parallelizable(self).run_jobs( - *( - Job.delayed(self._get_shap_for_split)( - model, - sample, - _make_explainer(model), - self.feature_index_, - self._convert_raw_shap_to_df, - self.get_multi_output_type(), - self._get_multi_output_names(model=model, sample=sample), - ) - for model, sample in zip( - crossfit.models(), - ( - sample.subsample(iloc=oob_split) - for _, oob_split in crossfit.splits() - ), - ) + Job.delayed(self._get_shap_for_split)( + model, + sample, + _make_explainer(model), + self.feature_index_, + self._convert_raw_shap_to_df, + self.get_multi_output_type(), + self._get_multi_output_names(model=model, sample=sample), + ) + for model, sample in zip( + crossfit.models(), + ( + sample.subsample(iloc=oob_split) + for _, oob_split in crossfit.splits() + ), ) ) diff --git a/src/facet/selection/_selection.py b/src/facet/selection/_selection.py index c8dfe6abb..e27a39497 100644 --- a/src/facet/selection/_selection.py +++ b/src/facet/selection/_selection.py @@ -518,7 +518,7 @@ def _rank_learners( ) pipeline_scorings: List[np.ndarray] = list( - JobRunner.from_parallelizable(self).run_queues(*queues) + JobRunner.from_parallelizable(self).run_queues(queues) ) for crossfit, pipeline_parameters, pipeline_scoring in zip( diff --git a/src/facet/simulation/_simulation.py b/src/facet/simulation/_simulation.py index 33fff80ac..1782e7a90 100644 --- a/src/facet/simulation/_simulation.py +++ b/src/facet/simulation/_simulation.py @@ -346,12 +346,10 @@ def simulate_actuals(self) -> pd.Series: y_mean = self.expected_output() result: List[float] = JobRunner.from_parallelizable(self).run_jobs( - *( - Job.delayed(self._simulate_actuals)( - model, subsample.features, y_mean, self._simulate - ) - for model, subsample in self._get_simulations() + Job.delayed(self._simulate_actuals)( + model, subsample.features, y_mean, self._simulate ) + for model, subsample in self._get_simulations() ) return pd.Series( @@ -417,16 +415,14 @@ def _simulate_feature_with_values( simulation_results_per_split: List[np.ndarray] = JobRunner.from_parallelizable( self ).run_jobs( - *( - Job.delayed(UnivariateUpliftSimulator._simulate_values_for_split)( - model=model, - subsample=subsample, - feature_name=feature_name, - simulated_values=simulation_values, - simulate_fn=self._simulate, - ) - for (model, subsample) in self._get_simulations() + Job.delayed(UnivariateUpliftSimulator._simulate_values_for_split)( + model=model, + subsample=subsample, + feature_name=feature_name, + simulated_values=simulation_values, + simulate_fn=self._simulate, ) + for (model, subsample) in self._get_simulations() ) return pd.DataFrame( From b238bd139c69ed9a6052e3042da522341662c037 Mon Sep 17 00:00:00 2001 From: Jan Ittner Date: Fri, 10 Sep 2021 11:27:39 +0200 Subject: [PATCH 2/2] API: rename JobQueue.collate() to .aggregate() --- src/facet/crossfit/_crossfit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/facet/crossfit/_crossfit.py b/src/facet/crossfit/_crossfit.py index 8a39b8aaa..3b4789d0f 100644 --- a/src/facet/crossfit/_crossfit.py +++ b/src/facet/crossfit/_crossfit.py @@ -480,7 +480,7 @@ def on_run(self) -> None: if do_fit: crossfit._reset_fit() - def collate(self, job_results: List[FitResult]) -> Optional[np.ndarray]: + def aggregate(self, job_results: List[FitResult]) -> Optional[np.ndarray]: models, scores = zip(*job_results) if do_fit: