diff --git a/api/predict.py b/api/predict.py index 02dbebc..4792023 100644 --- a/api/predict.py +++ b/api/predict.py @@ -1,5 +1,5 @@ from core.model import ModelWrapper -from flask_restplus import fields, abort +from flask_restplus import fields from werkzeug.datastructures import FileStorage from maxfw.core import MAX_API, PredictAPI @@ -58,11 +58,8 @@ def post(self): """Make a prediction given input data""" result = {'status': 'error'} args = input_parser.parse_args() - try: - input_data = args['file'].read() - image = self.model_wrapper._read_image(input_data) - except OSError as e: - abort(400, "Please submit a valid image in PNG, Tiff or JPEG format") + input_data = args['file'].read() + image = self.model_wrapper._read_image(input_data) label_preds = self.model_wrapper.predict(image) result['predictions'] = label_preds diff --git a/core/model.py b/core/model.py index 67d19b3..363c18c 100644 --- a/core/model.py +++ b/core/model.py @@ -1,9 +1,11 @@ +from maxfw.model import MAXModelWrapper + import io import logging import time from PIL import Image import numpy as np -from maxfw.model import MAXModelWrapper +from flask_restplus import abort from core.tf_pose.estimator import TfPoseEstimator from config import DEFAULT_MODEL_PATH, DEFAULT_IMAGE_SIZE, MODEL_NAME @@ -35,10 +37,16 @@ def __init__(self, path=DEFAULT_MODEL_PATH): logger.info("W = {}, H = {} ".format(self.w, self.h)) def _read_image(self, image_data): - image = Image.open(io.BytesIO(image_data)) - # Convert RGB to BGR for OpenCV. - image = np.array(image)[:, :, ::-1] - return image + try: + image = Image.open(io.BytesIO(image_data)) + if image.mode is not 'RGB': + image = image.convert('RGB') + # Convert RGB to BGR for OpenCV. + image = np.array(image)[:, :, ::-1] + return image + except IOError as e: + logger.error(str(e)) + abort(400, "Please submit a valid image in PNG, TIFF or JPEG format") def _predict(self, x): t = time.time() diff --git a/tests/Pilots.jpg b/tests/Pilots.jpg new file mode 100644 index 0000000..82bf41d Binary files /dev/null and b/tests/Pilots.jpg differ diff --git a/tests/Pilots.png b/tests/Pilots.png new file mode 100644 index 0000000..fa1bb7b Binary files /dev/null and b/tests/Pilots.png differ diff --git a/tests/Pilots.tiff b/tests/Pilots.tiff new file mode 100644 index 0000000..786bb6e Binary files /dev/null and b/tests/Pilots.tiff differ diff --git a/tests/test.py b/tests/test.py index 23443d1..7009380 100644 --- a/tests/test.py +++ b/tests/test.py @@ -29,17 +29,7 @@ def test_metadata(): assert metadata['license'] == 'Apache License 2.0' -def test_predict(): - - model_endpoint = 'http://localhost:5000/model/predict' - - # Test by the image with multiple faces - img1_path = 'assets/Pilots.jpg' - - with open(img1_path, 'rb') as file: - file_form = {'file': (img1_path, file, 'image/jpeg')} - r = requests.post(url=model_endpoint, files=file_form) - +def _check_response(r): assert r.status_code == 200 response = r.json() @@ -49,6 +39,20 @@ def test_predict(): assert len(response['predictions'][0]['pose_lines']) > 0 assert len(response['predictions'][0]['body_parts']) > 0 + +def test_predict(): + + model_endpoint = 'http://localhost:5000/model/predict' + formats = ['jpg', 'png', 'tiff'] + img_path = 'tests/Pilots.{}' + + for f in formats: + p = img_path.format(f) + with open(p, 'rb') as file: + file_form = {'file': (p, file, 'image/{}'.format(f))} + r = requests.post(url=model_endpoint, files=file_form) + _check_response(r) + # Test by the image without faces img2_path = 'assets/IBM.jpeg'