Merge pull request #704 from axsaucedo/tfserving_json_support
TFServing Enabled Text Response and Fixed JSON Parse
ukclivecox authored Jul 24, 2019
2 parents 304aab0 + 73bd7f4 commit 5a05a5c
Showing 2 changed files with 40 additions and 26 deletions.
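The "text response" half of this change lives in the REST branch of predict() in the diff below: instead of assuming TFServing always returns JSON, the proxy now tries response.json() and falls back to the raw body text. A minimal standalone sketch of that fallback, assuming only the requests library; the endpoint URL, function name, and payload are illustrative placeholders, not part of the diff:

import json
import requests

def post_to_tfserving(rest_endpoint, payload):
    # Mirror the proxy's call: a JSON-encoded POST body to the REST endpoint.
    response = requests.post(rest_endpoint, data=json.dumps(payload))
    try:
        # Preferred path: the body parses as JSON.
        return response.json()
    except ValueError:
        # Behaviour enabled by this PR: pass plain-text bodies through untouched.
        return response.text

# Hypothetical usage against a local TFServing REST endpoint.
print(post_to_tfserving(
    "http://localhost:8501/v1/models/mymodel:predict",
    {"instances": [[1.0, 2.0, 3.0]]},
))
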
2 changes: 1 addition & 1 deletion integrations/tfserving/Makefile
@@ -1,4 +1,4 @@
IMAGE_VERSION=0.4
IMAGE_VERSION=0.5
IMAGE_NAME = docker.io/seldonio/tfserving-proxy

SELDON_CORE_DIR=../../..
64 changes: 39 additions & 25 deletions integrations/tfserving/TfServingProxy.py
@@ -5,17 +5,18 @@
from tensorflow.python.saved_model import signature_constants
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
from seldon_core.utils import get_data_from_proto, array_to_grpc_datadef
from seldon_core.utils import get_data_from_proto, array_to_grpc_datadef, json_to_seldon_message
from seldon_core.proto import prediction_pb2
from google.protobuf.json_format import ParseError

import requests
import json
import numpy as np

class TensorflowServerError(Exception):
import logging

log = logging.getLogger()

def __init__(self, message):
self.message = message

'''
A basic tensorflow serving proxy
@@ -30,8 +31,8 @@ def __init__(
signature_name=None,
model_input=None,
model_output=None):
print("rest_endpoint:",rest_endpoint)
print("grpc_endpoint:",grpc_endpoint)
log.warning("rest_endpoint:",rest_endpoint)
log.warning("grpc_endpoint:",grpc_endpoint)
if not grpc_endpoint is None:
self.grpc = True
channel = grpc.insecure_channel(grpc_endpoint)
@@ -50,40 +51,44 @@

# if we have a TFTensor message we got directly without converting the message otherwise we go the usual route
def predict_raw(self,request):
print("Predict raw")
log.debug("Predict raw")
request_data_type = request.WhichOneof("data_oneof")
default_data_type = request.data.WhichOneof("data_oneof")
print(default_data_type)
log.debug(str(request_data_type), str(default_data_type))
if default_data_type == "tftensor" and self.grpc:
tfrequest = predict_pb2.PredictRequest()
tfrequest.model_spec.name = self.model_name
tfrequest.model_spec.signature_name = self.signature_name
tfrequest.inputs[self.model_input].CopyFrom(request.data.tftensor)
result = self.stub.Predict(tfrequest)
print(result)
log.debug(result)
datadef = prediction_pb2.DefaultData(
tftensor=result.outputs[self.model_output]
)
return prediction_pb2.SeldonMessage(data=datadef)

elif default_data_type == "jsonData":
predictions = self.predict(request.jsonData, features_names=[])
return prediction_pb2.SeldonMessage(jsonData=predictions)
elif request_data_type == "jsonData":
features = get_data_from_proto(request)
predictions = self.predict(features, features_names=[])
try:
sm = json_to_seldon_message({"jsonData": predictions})
except ParseError as e:
sm = prediction_pb2.SeldonMessage(strData=predictions)
return sm

else:
features = get_data_from_proto(request)
datadef = request.data
data_type = request.WhichOneof("data_oneof")
predictions = self.predict(features, datadef.names)
predictions = np.array(predictions)

if data_type == "data":
default_data_type = request.data.WhichOneof("data_oneof")
else:
if request_data_type is not "data":
default_data_type = "tensor"

class_names = []
data = array_to_grpc_datadef(
predictions, class_names, default_data_type)

return prediction_pb2.SeldonMessage(data=data)


@@ -95,32 +100,41 @@ def predict(self,X,features_names=[]):
request.model_spec.signature_name = self.signature_name
request.inputs[self.model_input].CopyFrom(tf.contrib.util.make_tensor_proto(X.tolist(), shape=X.shape))
result = self.stub.Predict(request)
print(result)
log.debug("GRPC Response: ", str(result))
response = numpy.array(result.outputs[self.model_output].float_val)
if len(response.shape) == 1:
response = numpy.expand_dims(response, axis=0)
return response
else:
print(self.rest_endpoint)
log.debug(self.rest_endpoint)
if type(X) is dict:
print("JSON Request")
log.debug("JSON Request")
data = X
else:
print("Data Request")
log.debug("Data Request")
data = {"instances":X.tolist()}
if not self.signature_name is None:
data["signature_name"] = self.signature_name
print(data)
log.debug(str(data))

response = requests.post(self.rest_endpoint, data=json.dumps(data))

if response.status_code == 200:
print(response.json())
log.debug(response.text)
if type(X) is dict:
return response.json()
try:
return response.json()
except ValueError:
return response.text
else:
result = numpy.array(response.json()["predictions"])
if len(result.shape) == 1:
result = numpy.expand_dims(result, axis=0)
return result
else:
print("Error from server:",response)
return response.json()
log.warning("Error from server: "+ str(response) + " content: " + str(response.text))
try:
return response.json()
except ValueError:
return response.text
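
On the gRPC side, the "fixed JSON parse" half of the change wraps the model's jsonData output back into a SeldonMessage and degrades to strData when the payload cannot be parsed into the proto. A minimal sketch of that pattern, assuming seldon_core is installed; the wrapper function name and the example payload are illustrative, not from the diff:

from google.protobuf.json_format import ParseError
from seldon_core.proto import prediction_pb2
from seldon_core.utils import json_to_seldon_message

def wrap_predictions(predictions):
    # Preferred path: carry the model output as jsonData in a SeldonMessage.
    try:
        return json_to_seldon_message({"jsonData": predictions})
    except ParseError:
        # Fallback from the diff: treat the output as a plain string payload
        # (in the proxy this is the text body returned by TFServing).
        return prediction_pb2.SeldonMessage(strData=predictions)

# Illustrative call with a JSON-serialisable prediction dict.
msg = wrap_predictions({"scores": [0.1, 0.9]})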
