From 29225e2484448842c7fd735ccdcb8ff21b74ddd6 Mon Sep 17 00:00:00 2001 From: rvandewater Date: Fri, 23 Sep 2022 16:12:11 +0200 Subject: [PATCH 1/2] Redid prep and bake and extracted common functionality --- icu_benchmarks/recipes/recipe.py | 36 +++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/icu_benchmarks/recipes/recipe.py b/icu_benchmarks/recipes/recipe.py index 8532dc49..81528624 100644 --- a/icu_benchmarks/recipes/recipe.py +++ b/icu_benchmarks/recipes/recipe.py @@ -49,30 +49,42 @@ def _apply_group(self, data, step): data = data.groupby(group_vars) return data - def prep(self, data=None, fresh=False): + def prep(self, data=None, refit=False): + """ + Fits and transforms or preps the data + @param data: + @param refit: Refit all columns + @return: + """ data = self._check_data(data) data = copy(data) - - for step in self.steps: - data = self._apply_group(data, step) - if fresh or not step.trained: - data = step.fit_transform(data) - else: - data = step.transform(data) - + self._apply_fit_transform(self, data, refit) return self def bake(self, data=None): + """ + Transforms or bakes the data if it has been prepped. + @param data: + @return: + """ data = self._check_data(data) data = copy(data) - for step in self.steps: - data = self._apply_group(data, step) if not step.trained: raise RuntimeError(f'Step {step} not trained. Run prep first.') + + self._apply_fit_transform(self, data) + return data + + def _apply_fit_transform(self, data=None, refit=False): + # applies transform or fit and transform (when refit or not trained yet) + for step in self.steps: + data = self._apply_group(data, step) + + if refit or not step.trained: + data = step.fit_transform(data) else: data = step.transform(data) - return data def __repr__(self): From a86a3814f9f29cf6cd7de458e68a22fc5cdd2556 Mon Sep 17 00:00:00 2001 From: rvandewater Date: Fri, 23 Sep 2022 16:18:55 +0200 Subject: [PATCH 2/2] Small changes --- icu_benchmarks/recipes/recipe.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/icu_benchmarks/recipes/recipe.py b/icu_benchmarks/recipes/recipe.py index 81528624..abca7f63 100644 --- a/icu_benchmarks/recipes/recipe.py +++ b/icu_benchmarks/recipes/recipe.py @@ -51,7 +51,7 @@ def _apply_group(self, data, step): def prep(self, data=None, refit=False): """ - Fits and transforms or preps the data + Fits and transforms, in other words preps, the data. @param data: @param refit: Refit all columns @return: @@ -63,7 +63,7 @@ def prep(self, data=None, refit=False): def bake(self, data=None): """ - Transforms or bakes the data if it has been prepped. + Transforms, or bakes, the data if it has been prepped. @param data: @return: """ @@ -80,7 +80,6 @@ def _apply_fit_transform(self, data=None, refit=False): # applies transform or fit and transform (when refit or not trained yet) for step in self.steps: data = self._apply_group(data, step) - if refit or not step.trained: data = step.fit_transform(data) else: