diff --git a/icu_benchmarks/recipes/recipe.py b/icu_benchmarks/recipes/recipe.py
index fb722806..b10e8522 100644
--- a/icu_benchmarks/recipes/recipe.py
+++ b/icu_benchmarks/recipes/recipe.py
@@ -51,30 +51,48 @@ def _apply_group(self, data, step):
             data = data.groupby(group_vars)
         return data
 
-    def prep(self, data=None, fresh=False):
+    def prep(self, data=None, refit=False):
+        """
+        Fits and transforms, in other words preps, the data.
+
+        @param data: Data to prep; it is passed through self._check_data and copied first.
+        @param refit: If True, refit every step, including steps that are already trained.
+        @return: The prepped data as a pandas DataFrame.
+        """
         data = self._check_data(data)
         data = copy(data)
-
-        for step in self.steps:
-            data = self._apply_group(data, step)
-            if fresh or not step.trained:
-                data = step.fit_transform(data)
-            else:
-                data = step.transform(data)
-
-        return pd.DataFrame(data)
+        return self._apply_fit_transform(data, refit)
 
     def bake(self, data=None):
+        """
+        Transforms, or bakes, the data if it has been prepped.
+
+        @param data: Data to bake; it is passed through self._check_data and copied first.
+        @return: The baked data as a pandas DataFrame.
+        @raise RuntimeError: If any step has not been trained yet.
+        """
         data = self._check_data(data)
         data = copy(data)
-
+        # Baking must never fit: refuse to run until prep has trained every step.
         for step in self.steps:
-            data = self._apply_group(data, step)
             if not step.trained:
                 raise RuntimeError(f'Step {step} not trained. Run prep first.')
+        return self._apply_fit_transform(data)
+
+    def _apply_fit_transform(self, data=None, refit=False):
+        """
+        Runs every step on the data, fitting a step first when that is needed.
+
+        @param data: Data handed to the first step.
+        @param refit: If True, fit every step again, even if it is already trained.
+        @return: The output of the last step wrapped in a pandas DataFrame.
+        """
+        for step in self.steps:
+            data = self._apply_group(data, step)
+            if refit or not step.trained:
+                data = step.fit_transform(data)
             else:
                 data = step.transform(data)
-
         return pd.DataFrame(data)
 
     def __repr__(self):