Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom inference fn support to the sklearn model handlers #23642

Merged
merged 9 commits into from
Nov 4, 2022
39 changes: 28 additions & 11 deletions sdks/python/apache_beam/ml/inference/sklearn_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pickle
import sys
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
Expand Down Expand Up @@ -83,14 +84,21 @@ def _convert_to_result(
]
return [PredictionResult(x, y) for x, y in zip(batch, predictions)]

def _default_numpy_inference_fn(
model: BaseEstimator, batch: Sequence[numpy.ndarray], inference_args: Optional[Dict[str, Any]] = None
) -> Any:
# vectorize data for better performance
vectorized_batch = numpy.stack(batch, axis=0)
return model.predict(vectorized_batch)

class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
PredictionResult,
BaseEstimator]):
def __init__(
self,
model_uri: str,
model_file_type: ModelFileType = ModelFileType.PICKLE):
model_file_type: ModelFileType = ModelFileType.PICKLE,
inference_fn: Optional[Callable[[BaseEstimator, Sequence[numpy.ndarray], Optional[Dict[str, Any]]], Any]] = _default_numpy_inference_fn):
jrmccluskey marked this conversation as resolved.
Show resolved Hide resolved
""" Implementation of the ModelHandler interface for scikit-learn
using numpy arrays as input.

Expand All @@ -102,9 +110,12 @@ def __init__(
model_uri: The URI to where the model is saved.
model_file_type: The method of serialization of the argument.
default=pickle
inference_fn: The inference function to use.
default=_default_numpy_inference_fn
"""
self._model_uri = model_uri
self._model_file_type = model_file_type
self._model_inference_fn = inference_fn

def load_model(self) -> BaseEstimator:
"""Loads and initializes a model for processing."""
Expand All @@ -128,9 +139,7 @@ def run_inference(
Returns:
An Iterable of type PredictionResult.
"""
# vectorize data for better performance
vectorized_batch = numpy.stack(batch, axis=0)
predictions = model.predict(vectorized_batch)
predictions = self._model_inference_fn(model, batch, inference_args)

return _convert_to_result(batch, predictions)

Expand All @@ -148,6 +157,15 @@ def get_metrics_namespace(self) -> str:
"""
return 'BeamML_Sklearn'

def _default_pandas_inference_fn(model: BaseEstimator, batch: Sequence[pandas.DataFrame], inference_args: Optional[Dict[str, Any]] = None
) -> Any:
# vectorize data for better performance
vectorized_batch = pandas.concat(batch, axis=0)
predictions = model.predict(vectorized_batch)
splits = [
vectorized_batch.iloc[[i]] for i in range(vectorized_batch.shape[0])
]
return predictions, splits

@experimental(extra_message="No backwards-compatibility guarantees.")
class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
Expand All @@ -156,7 +174,8 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
def __init__(
self,
model_uri: str,
model_file_type: ModelFileType = ModelFileType.PICKLE):
model_file_type: ModelFileType = ModelFileType.PICKLE,
inference_fn: Optional[Callable[[BaseEstimator, Sequence[pandas.DataFrame], Optional[Dict[str, Any]]], Any]] = _default_pandas_inference_fn):
jrmccluskey marked this conversation as resolved.
Show resolved Hide resolved
"""Implementation of the ModelHandler interface for scikit-learn that
supports pandas dataframes.

Expand All @@ -171,9 +190,12 @@ def __init__(
model_uri: The URI to where the model is saved.
model_file_type: The method of serialization of the argument.
default=pickle
inference_fn: The inference function to use.
default=_default_pandas_inference_fn
"""
self._model_uri = model_uri
self._model_file_type = model_file_type
self._model_inference_fn = inference_fn

def load_model(self) -> BaseEstimator:
"""Loads and initializes a model for processing."""
Expand Down Expand Up @@ -203,12 +225,7 @@ def run_inference(
if dataframe.shape[0] != 1:
raise ValueError('Only dataframes with single rows are supported.')

# vectorize data for better performance
vectorized_batch = pandas.concat(batch, axis=0)
predictions = model.predict(vectorized_batch)
splits = [
vectorized_batch.iloc[[i]] for i in range(vectorized_batch.shape[0])
]
predictions, splits = self._model_inference_fn(model, batch, inference_args)

return _convert_to_result(splits, predictions)

Expand Down