Skip to content

Commit

Permalink
Add custom inference fn suport to the sklearn model handlers (#23642)
Browse files Browse the repository at this point in the history
* Add custom inference fn suport to the sklearn model handlers

* Clean up import order

* Update typing, add numpy unit test

* Add Pandas unit test

* Formatting, linting

* yapf run

* Remove trailing whitespace

* import order

* Change inference_fn to keyword-only arg
  • Loading branch information
jrmccluskey authored Nov 4, 2022
1 parent 8617b86 commit 1cfdb12
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 11 deletions.
55 changes: 44 additions & 11 deletions sdks/python/apache_beam/ml/inference/sklearn_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pickle
import sys
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
Expand All @@ -45,6 +46,9 @@
'SklearnModelHandlerPandas',
]

NumpyInferenceFn = Callable[
[BaseEstimator, Sequence[numpy.ndarray], Optional[Dict[str, Any]]], Any]


class ModelFileType(enum.Enum):
"""Defines how a model file is serialized. Options are pickle or joblib."""
Expand Down Expand Up @@ -84,13 +88,24 @@ def _convert_to_result(
return [PredictionResult(x, y) for x, y in zip(batch, predictions)]


def _default_numpy_inference_fn(
model: BaseEstimator,
batch: Sequence[numpy.ndarray],
inference_args: Optional[Dict[str, Any]] = None) -> Any:
# vectorize data for better performance
vectorized_batch = numpy.stack(batch, axis=0)
return model.predict(vectorized_batch)


class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
PredictionResult,
BaseEstimator]):
def __init__(
self,
model_uri: str,
model_file_type: ModelFileType = ModelFileType.PICKLE):
model_file_type: ModelFileType = ModelFileType.PICKLE,
*,
inference_fn: NumpyInferenceFn = _default_numpy_inference_fn):
""" Implementation of the ModelHandler interface for scikit-learn
using numpy arrays as input.
Expand All @@ -102,9 +117,12 @@ def __init__(
model_uri: The URI to where the model is saved.
model_file_type: The method of serialization of the argument.
default=pickle
inference_fn: The inference function to use.
default=_default_numpy_inference_fn
"""
self._model_uri = model_uri
self._model_file_type = model_file_type
self._model_inference_fn = inference_fn

def load_model(self) -> BaseEstimator:
"""Loads and initializes a model for processing."""
Expand All @@ -128,9 +146,7 @@ def run_inference(
Returns:
An Iterable of type PredictionResult.
"""
# vectorize data for better performance
vectorized_batch = numpy.stack(batch, axis=0)
predictions = model.predict(vectorized_batch)
predictions = self._model_inference_fn(model, batch, inference_args)

return _convert_to_result(batch, predictions)

Expand All @@ -149,14 +165,33 @@ def get_metrics_namespace(self) -> str:
return 'BeamML_Sklearn'


PandasInferenceFn = Callable[
[BaseEstimator, Sequence[pandas.DataFrame], Optional[Dict[str, Any]]], Any]


def _default_pandas_inference_fn(
model: BaseEstimator,
batch: Sequence[pandas.DataFrame],
inference_args: Optional[Dict[str, Any]] = None) -> Any:
# vectorize data for better performance
vectorized_batch = pandas.concat(batch, axis=0)
predictions = model.predict(vectorized_batch)
splits = [
vectorized_batch.iloc[[i]] for i in range(vectorized_batch.shape[0])
]
return predictions, splits


@experimental(extra_message="No backwards-compatibility guarantees.")
class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
PredictionResult,
BaseEstimator]):
def __init__(
self,
model_uri: str,
model_file_type: ModelFileType = ModelFileType.PICKLE):
model_file_type: ModelFileType = ModelFileType.PICKLE,
*,
inference_fn: PandasInferenceFn = _default_pandas_inference_fn):
"""Implementation of the ModelHandler interface for scikit-learn that
supports pandas dataframes.
Expand All @@ -171,9 +206,12 @@ def __init__(
model_uri: The URI to where the model is saved.
model_file_type: The method of serialization of the argument.
default=pickle
inference_fn: The inference function to use.
default=_default_pandas_inference_fn
"""
self._model_uri = model_uri
self._model_file_type = model_file_type
self._model_inference_fn = inference_fn

def load_model(self) -> BaseEstimator:
"""Loads and initializes a model for processing."""
Expand Down Expand Up @@ -203,12 +241,7 @@ def run_inference(
if dataframe.shape[0] != 1:
raise ValueError('Only dataframes with single rows are supported.')

# vectorize data for better performance
vectorized_batch = pandas.concat(batch, axis=0)
predictions = model.predict(vectorized_batch)
splits = [
vectorized_batch.iloc[[i]] for i in range(vectorized_batch.shape[0])
]
predictions, splits = self._model_inference_fn(model, batch, inference_args)

return _convert_to_result(splits, predictions)

Expand Down
65 changes: 65 additions & 0 deletions sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,17 @@
import sys
import tempfile
import unittest
from typing import Any
from typing import Dict
from typing import Optional
from typing import Sequence

import joblib
import numpy
import pandas
from sklearn import linear_model
from sklearn import svm
from sklearn.base import BaseEstimator
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
Expand Down Expand Up @@ -150,6 +155,27 @@ def convert_inference_to_floor(prediction_result):
return math.floor(prediction_result.inference)


def alternate_numpy_inference_fn(
model: BaseEstimator,
batch: Sequence[numpy.ndarray],
inference_args: Optional[Dict[str, Any]] = None) -> Any:
return [0]


def alternate_pandas_inference_fn(
model: BaseEstimator,
batch: Sequence[pandas.DataFrame],
inference_args: Optional[Dict[str, Any]] = None) -> Any:
# vectorize data for better performance
vectorized_batch = pandas.concat(batch, axis=0)
predictions = model.predict(vectorized_batch)
splits = [
vectorized_batch.iloc[[i]] for i in range(vectorized_batch.shape[0])
]
predictions = predictions - 1
return predictions, splits


class SkLearnRunInferenceTest(unittest.TestCase):
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
Expand All @@ -172,6 +198,22 @@ def test_predict_output(self):
for actual, expected in zip(inferences, expected_predictions):
self.assertTrue(_compare_prediction_result(actual, expected))

def test_custom_inference_fn(self):
fake_model = FakeModel()
inference_runner = SklearnModelHandlerNumpy(
model_uri='unused', inference_fn=alternate_numpy_inference_fn)
batched_examples = [
numpy.array([1, 2, 3]), numpy.array([4, 5, 6]), numpy.array([7, 8, 9])
]
expected_predictions = [
PredictionResult(numpy.array([1, 2, 3]), 0),
PredictionResult(numpy.array([4, 5, 6]), 0),
PredictionResult(numpy.array([7, 8, 9]), 0)
]
inferences = inference_runner.run_inference(batched_examples, fake_model)
for actual, expected in zip(inferences, expected_predictions):
self.assertTrue(_compare_prediction_result(actual, expected))

def test_predict_output_dict(self):
fake_model = FakeNumpyModelDictOut()
inference_runner = SklearnModelHandlerNumpy(model_uri='unused')
Expand Down Expand Up @@ -296,6 +338,29 @@ def test_pipeline_pandas(self):
assert_that(
actual, equal_to(expected, equals_fn=_compare_dataframe_predictions))

def test_pipeline_pandas_custom_inference(self):
temp_file_name = self.tmpdir + os.sep + 'pickled_file'
with open(temp_file_name, 'wb') as file:
pickle.dump(build_pandas_pipeline(), file)
with TestPipeline() as pipeline:
dataframe = pandas_dataframe()
splits = [dataframe.loc[[i]] for i in dataframe.index]
pcoll = pipeline | 'start' >> beam.Create(splits)
actual = pcoll | RunInference(
SklearnModelHandlerPandas(
model_uri=temp_file_name,
inference_fn=alternate_pandas_inference_fn))

expected = [
PredictionResult(splits[0], 4),
PredictionResult(splits[1], 7),
PredictionResult(splits[2], 0),
PredictionResult(splits[3], 0),
PredictionResult(splits[4], 1),
]
assert_that(
actual, equal_to(expected, equals_fn=_compare_dataframe_predictions))

def test_pipeline_pandas_dict_out(self):
temp_file_name = self.tmpdir + os.sep + 'pickled_file'
with open(temp_file_name, 'wb') as file:
Expand Down

0 comments on commit 1cfdb12

Please sign in to comment.