Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Commit

Permalink
Delete the cached summaries when refitting a pipeline or a predictor. (#109)
Browse files Browse the repository at this point in the history

* Fix build issue on Windows when VS2019 is installed.

Note: The -version option could not be added directly
to the FOR command due to a command script parsing issue.

* Add missing arguments to fix build issue with latest version of autoflake.

* Delete the cached summaries when refitting a pipeline or a predictor.
Fixes #106

* Simplify the code that deletes cached summaries when calling fit.
  • Loading branch information
pieths authored and ganik committed Jun 1, 2019
1 parent 8da35e1 commit b4ec723
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 3 deletions.
5 changes: 5 additions & 0 deletions src/python/nimbusml/base_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ def fit(self, X, y=None, **params):
self.X_ = X
self.y_ = y

# Clear cached summary since it should not
# retain its value after a new call to fit
if hasattr(self, 'model_summary_'):
delattr(self, 'model_summary_')

pipeline = Pipeline([self])
try:
pipeline.fit(X, y, **params)
Expand Down
8 changes: 5 additions & 3 deletions src/python/nimbusml/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,13 +1089,15 @@ def fit(self, X, y=None, verbose=1, **params):
clone = self.clone()
self.steps = clone.steps

# Clear cached values
for attr in ["_run_time_error", "model_summary"]:
if hasattr(self, attr):
delattr(self, attr)

# Caches the predictor to restore it as it was
# in case of exception. It is deleted after the training.
self._cache_predictor = deepcopy(self.steps[-1])

if hasattr(self, "_run_time_error"):
delattr(self, "_run_time_error")

# Checks that no node was ever trained.
for i, n in enumerate(self.nodes):
if hasattr(n, "model_") and n.model_ is not None:
Expand Down
23 changes: 23 additions & 0 deletions src/python/nimbusml/tests/model_summary/test_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,29 @@ def test_summary_called_back_to_back_on_predictor(self):
ols.summary()
ols.summary()

def test_pipeline_summary_is_refreshed_after_refitting(self):
    """Check that a refitted pipeline reports a new summary.

    The same pipeline is fitted twice on identical features but
    different targets; the summary obtained after the second fit
    must not compare equal to the one obtained after the first.
    """
    regressor = OrdinaryLeastSquaresRegressor(normalize='No',
                                              l2_regularization=0)
    pipe = Pipeline([regressor])

    # First training round: targets are x + 1.
    pipe.fit([0, 1, 2, 3], [1, 2, 3, 4])
    summary_before_refit = pipe.summary()

    # Second training round on the same object: targets are 3x + 2.
    pipe.fit([0, 1, 2, 3], [2, 5, 8, 11])
    summary_after_refit = pipe.summary()

    # Equal summaries here would mean the first result was reused.
    summaries_match = summary_before_refit.equals(summary_after_refit)
    self.assertFalse(summaries_match)

def test_predictor_summary_is_refreshed_after_refitting(self):
    """Check that a refitted standalone predictor reports a new summary.

    The predictor is fitted directly (without an explicit pipeline)
    twice on identical features but different targets; the summary
    after the second fit must not compare equal to the first one.
    """
    regressor = OrdinaryLeastSquaresRegressor(normalize='No',
                                              l2_regularization=0)

    # First training round: targets are x + 1.
    regressor.fit([0, 1, 2, 3], [1, 2, 3, 4])
    summary_before_refit = regressor.summary()

    # Second training round on the same object: targets are 3x + 2.
    regressor.fit([0, 1, 2, 3], [2, 5, 8, 11])
    summary_after_refit = regressor.summary()

    # Equal summaries here would mean the first result was reused.
    summaries_match = summary_before_refit.equals(summary_after_refit)
    self.assertFalse(summaries_match)


if __name__ == '__main__':
unittest.main()

0 comments on commit b4ec723

Please sign in to comment.