diff --git a/servers/mlflowserver/mlflowserver/MLFlowServer.py b/servers/mlflowserver/mlflowserver/MLFlowServer.py index f6ace0d7a3..97069527e0 100644 --- a/servers/mlflowserver/mlflowserver/MLFlowServer.py +++ b/servers/mlflowserver/mlflowserver/MLFlowServer.py @@ -7,7 +7,7 @@ from mlflow import pyfunc from seldon_core import Storage -from seldon_core.user_model import SeldonComponent +from seldon_core.user_model import SeldonComponent, SeldonNotImplementedError from typing import Dict, List, Union logger = logging.getLogger() @@ -23,6 +23,7 @@ def __init__(self, model_uri: str, xtype: str = "ndarray"): self.model_uri = model_uri self.xtype = xtype self.ready = False + self.column_names = None def load(self): logger.info(f"Downloading model from {self.model_uri}") @@ -47,6 +48,11 @@ def predict( df = pd.DataFrame(data=X) result = self._model.predict(df) + if isinstance(result, pd.DataFrame): + if self.column_names is None: + self.column_names = result.columns.to_list() + result = result.to_numpy() + logger.debug(f"Prediction result: {result}") return result @@ -64,3 +70,9 @@ def init_metadata(self): f"metadata file {file_path} present but does not contain valid yaml" ) return {} + + def class_names(self): + if self.column_names is not None: + return self.column_names + + raise SeldonNotImplementedError("prediction result is not a dataframe")