Skip to content

Commit

Permalink
Merge pull request #989 from lennon310/status_code
Browse files Browse the repository at this point in the history
Set Http Status Code in REST Predict
  • Loading branch information
seldondev authored Oct 24, 2019
2 parents 95de04c + a465dec commit c10c321
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 1 deletion.
6 changes: 5 additions & 1 deletion python/seldon_core/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,12 @@ def Predict():
requestJson = get_request()
logger.debug("REST Request: %s", request)
response = seldon_core.seldon_methods.predict(user_model, requestJson)
json_response = jsonify(response)
if 'status' in response and 'code' in response['status']:
json_response.status_code = response['status']['code']

logger.debug("REST Response: %s", response)
return jsonify(response)
return json_response

@app.route("/send-feedback", methods=["GET", "POST"])
@app.route("/api/v0.1/feedback", methods=["POST"])
Expand Down
73 changes: 73 additions & 0 deletions python/tests/test_model_microservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,57 @@ def send_feedback_rest(self, request):
def send_feedback_grpc(self, request):
print("Feedback called")

class UserObjectLowLevelWithStatusInResponse(SeldonComponent):
def __init__(self, metrics_ok=True, ret_nparray=False):
self.metrics_ok = metrics_ok
self.ret_nparray = ret_nparray
self.nparray = np.array([1, 2, 3])

def predict_rest(self, request):
return {"data": {"ndarray": [9, 9]}, "status": {"code": 400, "status": "FAILURE"}}

def predict_grpc(self, request):
arr = np.array([9, 9])
datadef = prediction_pb2.DefaultData(
tensor=prediction_pb2.Tensor(
shape=(2, 1),
values=arr
)
)
request = prediction_pb2.SeldonMessage(data=datadef)
return request

def send_feedback_rest(self, request):
print("Feedback called")

def send_feedback_grpc(self, request):
print("Feedback called")


class UserObjectLowLevelWithStatusInResponseWithPredictRaw(SeldonComponent):
def __init__(self, check_name):
self.check_name=check_name

def predict_raw(self, msg):
msg=json_to_seldon_message(msg)
if self.check_name == 'img':
file_data=msg.binData
img = Image.open(io.BytesIO (file_data))
img.verify()
return {"meta": seldon_message_to_json(msg.meta),
"data": {"ndarray": [rs232_checksum(file_data).decode('utf-8')]},
"status": {"code": 400, "status": "FAILURE"}}
elif self.check_name == 'txt':
file_data=msg.binData
return {"meta": seldon_message_to_json(msg.meta),
"data": {"ndarray": [file_data.decode('utf-8')]},
"status": {"code": 400, "status": "FAILURE"}}
elif self.check_name == 'strData':
file_data=msg.strData
return {"meta": seldon_message_to_json(msg.meta),
"data": {"ndarray": [file_data]},
"status": {"code": 400, "status": "FAILURE"}}


class UserObjectLowLevelWithPredictRaw(SeldonComponent):
def __init__(self, check_name):
Expand Down Expand Up @@ -259,6 +310,18 @@ def test_model_lowlevel_multi_form_data_strData_ok():
)



def test_model_lowlevel_multi_form_data_strData_non200status():
user_object = UserObjectLowLevelWithStatusInResponseWithPredictRaw('strData')
app = get_rest_microservice(user_object)
client = app.test_client()
rv = client.post('/predict',data={"meta":'{"puid":"1234"}',"strData":(f'./tests/resources/test.txt','test.txt')},content_type='multipart/form-data')
j = json.loads(rv.data)
assert rv.status_code == 400
assert j["meta"]["puid"] == "1234"
assert j["data"]["ndarray"][0] == "this is test file for testing multipart/form-data input\n"


def test_model_multi_form_data_ok():
user_object = UserObject()
app = get_rest_microservice(user_object)
Expand Down Expand Up @@ -302,6 +365,16 @@ def test_model_feedback_lowlevel_ok():
assert rv.status_code == 200


def test_model_non200status_lowlevel():
user_object = UserObjectLowLevelWithStatusInResponse()
app = get_rest_microservice(user_object)
client = app.test_client()
rv = client.get('/predict?json={"request":{"data":{"ndarray":[]}},"reward":1.0}')
j = json.loads(rv.data)
print(j)
assert rv.status_code == 400


@skipif_tf_missing
def test_model_tftensor_ok():
user_object = UserObject()
Expand Down

0 comments on commit c10c321

Please sign in to comment.