diff --git a/python/tests/test_model_microservice.py b/python/tests/test_model_microservice.py index 816a94b8b2..ad39cbab71 100644 --- a/python/tests/test_model_microservice.py +++ b/python/tests/test_model_microservice.py @@ -97,6 +97,57 @@ def send_feedback_rest(self, request): def send_feedback_grpc(self, request): print("Feedback called") +class UserObjectLowLevelWithStatusInResponse(SeldonComponent): + def __init__(self, metrics_ok=True, ret_nparray=False): + self.metrics_ok = metrics_ok + self.ret_nparray = ret_nparray + self.nparray = np.array([1, 2, 3]) + + def predict_rest(self, request): + return {"data": {"ndarray": [9, 9]}, "status": {"code": 400, "status": "FAILURE"}} + + def predict_grpc(self, request): + arr = np.array([9, 9]) + datadef = prediction_pb2.DefaultData( + tensor=prediction_pb2.Tensor( + shape=(2, 1), + values=arr + ) + ) + request = prediction_pb2.SeldonMessage(data=datadef) + return request + + def send_feedback_rest(self, request): + print("Feedback called") + + def send_feedback_grpc(self, request): + print("Feedback called") + + +class UserObjectLowLevelWithStatusInResponseWithPredictRaw(SeldonComponent): + def __init__(self, check_name): + self.check_name=check_name + + def predict_raw(self, msg): + msg=json_to_seldon_message(msg) + if self.check_name == 'img': + file_data=msg.binData + img = Image.open(io.BytesIO (file_data)) + img.verify() + return {"meta": seldon_message_to_json(msg.meta), + "data": {"ndarray": [rs232_checksum(file_data).decode('utf-8')]}, + "status": {"code": 400, "status": "FAILURE"}} + elif self.check_name == 'txt': + file_data=msg.binData + return {"meta": seldon_message_to_json(msg.meta), + "data": {"ndarray": [file_data.decode('utf-8')]}, + "status": {"code": 400, "status": "FAILURE"}} + elif self.check_name == 'strData': + file_data=msg.strData + return {"meta": seldon_message_to_json(msg.meta), + "data": {"ndarray": [file_data]}, + "status": {"code": 400, "status": "FAILURE"}} + class UserObjectLowLevelWithPredictRaw(SeldonComponent): def __init__(self, check_name): @@ -211,6 +262,18 @@ def test_model_lowlevel_multi_form_data_strData_ok(): assert j["meta"]["puid"] == "1234" assert j["data"]["ndarray"][0] == "this is test file for testing multipart/form-data input\n" + +def test_model_lowlevel_multi_form_data_strData_non200status(): + user_object = UserObjectLowLevelWithStatusInResponseWithPredictRaw('strData') + app = get_rest_microservice(user_object) + client = app.test_client() + rv = client.post('/predict',data={"meta":'{"puid":"1234"}',"strData":(f'./tests/resources/test.txt','test.txt')},content_type='multipart/form-data') + j = json.loads(rv.data) + assert rv.status_code == 400 + assert j["meta"]["puid"] == "1234" + assert j["data"]["ndarray"][0] == "this is test file for testing multipart/form-data input\n" + + def test_model_multi_form_data_ok(): user_object = UserObject() app = get_rest_microservice(user_object) @@ -245,6 +308,16 @@ def test_model_feedback_lowlevel_ok(): assert rv.status_code == 200 +def test_model_non200status_lowlevel(): + user_object = UserObjectLowLevelWithStatusInResponse() + app = get_rest_microservice(user_object) + client = app.test_client() + rv = client.get('/predict?json={"request":{"data":{"ndarray":[]}},"reward":1.0}') + j = json.loads(rv.data) + print(j) + assert rv.status_code == 400 + + def test_model_tftensor_ok(): user_object = UserObject() app = get_rest_microservice(user_object)