Skip to content

Commit

Permalink
Merge branch 'development' into behaviour_prep_bake
Browse files Browse the repository at this point in the history
  • Loading branch information
rvandewater authored Sep 23, 2022
2 parents a86a381 + b969a2a commit 4a2e62b
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
6 changes: 3 additions & 3 deletions icu_benchmarks/recipes/ingredients.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
class Ingredients(pd.DataFrame):
_metadata = ["roles"]

def __init__(self, data=None, index=None, columns=None, dtype=None, copy=None,):
super().__init__(data, index, columns, dtype, )
self.roles = {}
def __init__(self, data=None, index=None, columns=None, dtype=None, copy=None, roles = {}):
super().__init__(data, index, columns, dtype, copy, )
self.roles = roles

@property
def _constructor(self):
Expand Down
4 changes: 2 additions & 2 deletions icu_benchmarks/recipes/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@
rec.add_step(StepImputeFill(method='ffill'))
rec.add_step(StepImputeFill(value=0))

rec.prep()
rec.bake()
rec.prep(df.iloc[:-10000, :])
rec.bake(df.iloc[10000:, :])

11 changes: 5 additions & 6 deletions icu_benchmarks/recipes/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def add_step(self, step):
def _check_data(self, data):
if data is None:
data = self.data
elif data.__class__ == pd.DataFrame:
data = Ingredients(data, roles=self.data.roles)
if not data.columns.equals(self.data.columns):
raise ValueError('Columns of data argument differs from recipe data.')
return data
Expand All @@ -60,7 +62,8 @@ def prep(self, data=None, refit=False):
data = copy(data)
self._apply_fit_transform(self, data, refit)
return self

return pd.DataFrame(data)

def bake(self, data=None):
"""
Transforms, or bakes, the data if it has been prepped.
Expand All @@ -69,10 +72,6 @@ def bake(self, data=None):
"""
data = self._check_data(data)
data = copy(data)
for step in self.steps:
if not step.trained:
raise RuntimeError(f'Step {step} not trained. Run prep first.')

self._apply_fit_transform(self, data)
return data

Expand All @@ -84,7 +83,7 @@ def _apply_fit_transform(self, data=None, refit=False):
data = step.fit_transform(data)
else:
data = step.transform(data)
return data
return pd.DataFrame(data)

def __repr__(self):
repr = 'Recipe\n\n'
Expand Down

0 comments on commit 4a2e62b

Please sign in to comment.