From ea62c095975efd872044084cafcdd8a981c9ad1a Mon Sep 17 00:00:00 2001 From: Teagan glenn Date: Thu, 2 Apr 2020 12:46:09 -0600 Subject: [PATCH] Octet should accept byte array payload --- .../engine/api/rest/RestClientController.java | 32 ++++++++++-- .../config/MessageConvertersConfig.java | 49 ------------------- .../api/rest/TestRestClientController.java | 14 +++--- 3 files changed, 34 insertions(+), 61 deletions(-) delete mode 100644 engine/src/main/java/io/seldon/engine/config/MessageConvertersConfig.java diff --git a/engine/src/main/java/io/seldon/engine/api/rest/RestClientController.java b/engine/src/main/java/io/seldon/engine/api/rest/RestClientController.java index 3887c5a36c..75e39030ec 100644 --- a/engine/src/main/java/io/seldon/engine/api/rest/RestClientController.java +++ b/engine/src/main/java/io/seldon/engine/api/rest/RestClientController.java @@ -165,7 +165,7 @@ public ResponseEntity predictions_text(RequestEntity requestEnti method = RequestMethod.POST, consumes = "application/octet-stream", produces = "application/json; charset=utf-8") - public ResponseEntity predictions_binary(RequestEntity requestEntity) { + public ResponseEntity predictions_binary(RequestEntity requestEntity) { logger.debug("Received binary predict request"); Span tracingSpan = null; if (tracingProvider.isActive()) { @@ -174,10 +174,7 @@ public ResponseEntity predictions_binary(RequestEntity requ tracer.scopeManager().activate(tracingSpan); } try { - return _predictions(toByteArray(requestEntity.getBody())); - } catch (IOException e) { - logger.error("Bad request", e); - throw new APIException(ApiExceptionType.REQUEST_IO_EXCEPTION, e.getMessage()); + return _predictions(requestEntity.getBody()); } finally { if (tracingSpan != null) { tracingSpan.finish(); @@ -276,6 +273,31 @@ private ResponseEntity _predictions(String json) { throw new APIException(ApiExceptionType.ENGINE_INVALID_JSON, json); } + return _predictions(request); + } + + /** + * It calls the prediction service for the input byte array. + * + * @param bytes - Input byte array to predict REST api + * @return The response for prediction service + */ + private ResponseEntity _predictions(byte[] bytes) { + SeldonMessage request = SeldonMessage.newBuilder() + .setBinData(ByteString.copyFrom(bytes)) + .build(); + + return _predictions(request); + } + + /** + * It calls the prediction service for the request. It is the base function for all forms of + * request Content-type + * + * @param request - The SeldonMessage request to the predict REST api + * @return The response for prediction service + */ + private ResponseEntity _predictions(SeldonMessage request) { try { SeldonMessage response = predictionService.predict(request); String responseJson = ProtoBufUtils.toJson(response); diff --git a/engine/src/main/java/io/seldon/engine/config/MessageConvertersConfig.java b/engine/src/main/java/io/seldon/engine/config/MessageConvertersConfig.java deleted file mode 100644 index a05b6baa7a..0000000000 --- a/engine/src/main/java/io/seldon/engine/config/MessageConvertersConfig.java +++ /dev/null @@ -1,49 +0,0 @@ -package io.seldon.engine.config; - -import org.jetbrains.annotations.NotNull; -import org.springframework.context.annotation.Configuration; -import org.springframework.http.HttpInputMessage; -import org.springframework.http.HttpOutputMessage; -import org.springframework.http.MediaType; -import org.springframework.http.converter.AbstractHttpMessageConverter; -import org.springframework.http.converter.HttpMessageConverter; -import org.springframework.http.converter.HttpMessageNotReadableException; -import org.springframework.http.converter.HttpMessageNotWritableException; -import org.springframework.web.servlet.config.annotation.WebMvcConfigurationSupport; - -import java.io.IOException; -import java.io.InputStream; -import java.util.List; - -import static io.seldon.engine.util.StreamUtils.copyStream; - -/** - * Configure Spring Boot to allow upload of octet-stream. - */ -@Configuration -public class MessageConvertersConfig extends WebMvcConfigurationSupport { - - @Override - protected void configureMessageConverters(List> converters) { - converters.add(new AbstractHttpMessageConverter(MediaType.APPLICATION_OCTET_STREAM) { - protected boolean supports(@NotNull Class clazz) { - return InputStream.class.isAssignableFrom(clazz); - } - - @NotNull - protected InputStream readInternal( - @NotNull Class clazz, - @NotNull HttpInputMessage inputMessage) throws IOException, HttpMessageNotReadableException { - return inputMessage.getBody(); - } - - protected void writeInternal( - @NotNull InputStream inputStream, - @NotNull HttpOutputMessage outputMessage) throws IOException, HttpMessageNotWritableException { - copyStream(inputStream, outputMessage.getBody()); - } - }); - - super.configureMessageConverters(converters); - } -} diff --git a/engine/src/test/java/io/seldon/engine/api/rest/TestRestClientController.java b/engine/src/test/java/io/seldon/engine/api/rest/TestRestClientController.java index 7fcb7c6e3a..5b0af24c19 100644 --- a/engine/src/test/java/io/seldon/engine/api/rest/TestRestClientController.java +++ b/engine/src/test/java/io/seldon/engine/api/rest/TestRestClientController.java @@ -347,7 +347,7 @@ public void testPredict_b64img_as_text() throws Exception { MvcResult res = mvc.perform( MockMvcRequestBuilders.post("/api/v1.0/predictions") - .accept(MediaType.APPLICATION_JSON) + .accept(MediaType.APPLICATION_JSON_UTF8) .content(base64Image) .contentType(MediaType.TEXT_PLAIN)) .andReturn(); @@ -364,7 +364,7 @@ public void testPredict_b64img_as_text() throws Exception { Assert.assertEquals(base64Image, seldonMessage.getStrData()); // No Puid specified in request, verify response generated random of correct length Assert.assertNotNull(seldonMessage.getMeta().getPuid()); - Assert.assertTrue(Pattern.matches("[a-z0-7]{26}", seldonMessage.getMeta().getPuid())); + Assert.assertTrue(Pattern.matches("[a-z0-9]{26}", seldonMessage.getMeta().getPuid())); } @Test @@ -373,15 +373,15 @@ public void testPredict_img_as_binary() throws Exception { MvcResult res = mvc.perform( MockMvcRequestBuilders.post("/api/v1.0/predictions") - .accept(MediaType.APPLICATION_JSON) + .accept(MediaType.APPLICATION_JSON_UTF8) .content(imageBytes) .contentType(MediaType.APPLICATION_OCTET_STREAM)) .andReturn(); - String response = res.getResponse().getContentAsString(); - System.out.println(response); + byte[] response = res.getResponse().getContentAsByteArray(); + System.out.println(String(response); Assert.assertEquals(200, res.getResponse().getStatus()); SeldonMessage.Builder builder = SeldonMessage.newBuilder(); - ProtoBufUtils.updateMessageBuilderFromJson(builder, response); + ProtoBufUtils.updateMessageBuilderFromJson(builder, new String(response)); SeldonMessage seldonMessage = builder.build(); Assert.assertEquals(3, seldonMessage.getMeta().getMetricsCount()); Assert.assertEquals("COUNTER", seldonMessage.getMeta().getMetrics(0).getType().toString()); @@ -390,6 +390,6 @@ public void testPredict_img_as_binary() throws Exception { Assert.assertEquals(imageBytes, seldonMessage.getBinData().toByteArray()); // No Puid specified in request, verify response generated random of correct length Assert.assertNotNull(seldonMessage.getMeta().getPuid()); - Assert.assertTrue(Pattern.matches("[a-z0-7]{26}", seldonMessage.getMeta().getPuid())); + Assert.assertTrue(Pattern.matches("[a-z0-9]{26}", seldonMessage.getMeta().getPuid())); } }