Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Save predictor_model when pickling a pipeline. #295

Merged
merged 7 commits into from
Oct 3, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions src/python/nimbusml/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2482,7 +2482,7 @@ def load_model(self, src):
self.steps = []

def __getstate__(self):
odict = {'export_version': 1}
odict = {'export_version': 2}

if hasattr(self, 'steps'):
odict['steps'] = self.steps
Expand All @@ -2494,25 +2494,49 @@ def __getstate__(self):
with open(self.model, "rb") as f:
odict['modelbytes'] = f.read()

if (hasattr(self, 'predictor_model') and
self.predictor_model is not None and
os.path.isfile(self.predictor_model)):

with open(self.predictor_model, "rb") as f:
odict['predictor_model_bytes'] = f.read()

return odict

def __setstate__(self, state):
    """Restore a pipeline from the dict produced by ``__getstate__``.

    Supported ``export_version`` values:

    * ``0`` (key absent): pickles created before ``export_version``
      existed, which used the default implementation based on the
      instance's ``__dict__``. Only ``steps`` is restored.
    * ``1`` and ``2``: ``steps`` is restored, and the serialized model
      payloads (``modelbytes`` and, when present,
      ``predictor_model_bytes``) are each written to a fresh temporary
      file whose path is stored in ``self.model`` /
      ``self.predictor_model``.

    The temporary files are created with ``tempfile.mkstemp`` and are
    not deleted here; whoever owns the restored pipeline is responsible
    for removing them.

    :param state: the dict returned by ``__getstate__``.
    :raises ValueError: if ``export_version`` is not 0, 1 or 2.
    """
    self.steps = []
    self.model = None
    self.random_state = None

    export_version = state.get('export_version', 0)
    if export_version not in {0, 1, 2}:
        raise ValueError('Pipeline version not supported.')

    if 'steps' in state:
        self.steps = state['steps']

    if export_version == 0:
        # Legacy pickles carry no serialized model payloads.
        return

    def _write_temp_file(payload):
        # The context manager closes the descriptor even if write() raises.
        fd, path = tempfile.mkstemp()
        with os.fdopen(fd, "wb") as handle:
            handle.write(payload)
        return path

    # Version 1 pickles simply lack 'predictor_model_bytes'; both versions
    # share the same restore path, so no per-version guard is needed.
    if 'modelbytes' in state:
        self.model = _write_temp_file(state['modelbytes'])

    if 'predictor_model_bytes' in state:
        self.predictor_model = _write_temp_file(
            state['predictor_model_bytes'])

@trace
def score(
self,
Expand Down
43 changes: 42 additions & 1 deletion src/python/nimbusml/tests/pipeline/test_load_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
import pickle
import unittest

import numpy as np
import pandas as pd

from nimbusml import Pipeline
from nimbusml.datasets import get_dataset
from nimbusml.feature_extraction.categorical import OneHotVectorizer
from nimbusml.linear_model import FastLinearBinaryClassifier
from nimbusml.linear_model import FastLinearBinaryClassifier, OnlineGradientDescentRegressor
from nimbusml.utils import get_X_y
from numpy.testing import assert_almost_equal

Expand Down Expand Up @@ -326,5 +329,43 @@ def test_predictor_loaded_from_zip_has_feature_contributions(self):

os.remove(model_filename)

def test_pickled_pipeline_with_predictor_model(self):
    """A pickled pipeline carries its predictor model with it.

    Fits a pipeline with ``output_predictor_model=True``, pickles it,
    deletes the original model files, unpickles, and checks that a new
    pipeline loaded from the restored predictor model yields the same
    predictions as the original pipeline.
    """
    dtypes = {'c1': np.float64, 'c2': np.float64}

    train_data = {'c1': [1, 2, 3, 4], 'c2': [2, 3, 4, 5]}
    train_df = pd.DataFrame(train_data).astype(dtypes)

    test_data = {'c1': [1.5, 2.3, 3.7], 'c2': [2.2, 4.9, 2.7]}
    test_df = pd.DataFrame(test_data).astype(dtypes)

    # Create predictor model and use it to predict
    pipeline = Pipeline([OnlineGradientDescentRegressor(label='c2')],
                        random_state=0)
    pipeline.fit(train_df, output_predictor_model=True)
    result_1 = pipeline.predict(test_df)

    self.assertTrue(pipeline.model)
    self.assertTrue(pipeline.predictor_model)
    self.assertNotEqual(pipeline.model, pipeline.predictor_model)

    pickle_filename = 'nimbusml_model.p'
    try:
        with open(pickle_filename, 'wb') as f:
            pickle.dump(pipeline, f)

        # Remove the original model files so the round trip can only
        # succeed if the pickle itself contains the model bytes.
        os.remove(pipeline.model)
        os.remove(pipeline.predictor_model)

        with open(pickle_filename, "rb") as f:
            pipeline_pickle = pickle.load(f)
    finally:
        # Do not leave the pickle behind in the working directory,
        # even when dumping or loading fails.
        if os.path.isfile(pickle_filename):
            os.remove(pickle_filename)

    # Unpickling writes the model payloads to temporary files; collect
    # their paths so they are removed whatever the outcome below.
    restored_files = [pipeline_pickle.model,
                      pipeline_pickle.predictor_model]
    try:
        # Load predictor pipeline and score data
        predictor_pipeline = Pipeline()
        predictor_pipeline.load_model(pipeline_pickle.predictor_model)
        result_2 = predictor_pipeline.predict(test_df)

        self.assertTrue(result_1.equals(result_2))
    finally:
        for path in restored_files:
            if path and os.path.isfile(path):
                os.remove(path)


# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()