From b47b1361715eaa36103ab6290848ea1ae378ba30 Mon Sep 17 00:00:00 2001 From: ronald-d-rogers Date: Wed, 1 Nov 2017 21:19:51 -0400 Subject: [PATCH 1/5] Added support for query batching. --- README.md | 6 +- .../java/graphql/servlet/GraphQLServlet.java | 242 ++++++++++----- .../graphql/servlet/GraphQLServletSpec.groovy | 289 ++++++++++++++++++ 3 files changed, 462 insertions(+), 75 deletions(-) diff --git a/README.md b/README.md index 8a1f0413..afa3f247 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ # GraphQL Servlet -This module implements a GraphQL Java Servlet. It also supports Relay.js and OSGi out of the box. +This module implements a GraphQL Java Servlet. It also supports Relay.js, Apollo and OSGi out of the box. # Downloading @@ -114,6 +114,10 @@ You **MUST** pass this execution strategy to the servlet for Relay.js support. This is the default execution strategy for the `OsgiGraphQLServlet`, and must be added as a dependency when using that servlet. +## Apollo support + +Query batching is supported, no configuration required. + ## Spring Framework support To use the servlet with Spring Framework, either use the [Spring Boot starter](https://github.com/graphql-java/graphql-spring-boot) or simply define a `ServletRegistrationBean` in a web app: diff --git a/src/main/java/graphql/servlet/GraphQLServlet.java b/src/main/java/graphql/servlet/GraphQLServlet.java index b22bdacf..4e655ead 100644 --- a/src/main/java/graphql/servlet/GraphQLServlet.java +++ b/src/main/java/graphql/servlet/GraphQLServlet.java @@ -1,13 +1,8 @@ package graphql.servlet; -import com.fasterxml.jackson.annotation.JacksonInject; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.DeserializationContext; -import com.fasterxml.jackson.databind.InjectableValues; -import com.fasterxml.jackson.databind.JsonDeserializer; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.ObjectReader; +import com.fasterxml.jackson.databind.*; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import graphql.ExecutionInput; import graphql.ExecutionResult; @@ -30,17 +25,11 @@ import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import java.io.IOException; -import java.io.InputStream; +import java.io.*; +import java.nio.charset.StandardCharsets; import java.security.AccessController; import java.security.PrivilegedAction; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; +import java.util.*; import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Function; @@ -69,8 +58,8 @@ public abstract class GraphQLServlet extends HttpServlet implements Servlet, Gra private final List listeners; private final ServletFileUpload fileUpload; - private final RequestHandler getHandler; - private final RequestHandler postHandler; + private final HttpRequestHandler getHandler; + private final HttpRequestHandler postHandler; public GraphQLServlet() { this(null, null, null); @@ -84,23 +73,31 @@ public GraphQLServlet(ObjectMapperConfigurer objectMapperConfigurer, List { final GraphQLContext context = createContext(Optional.of(request), Optional.of(response)); final Object rootObject = createRootObject(Optional.of(request), Optional.of(response)); + String path = request.getPathInfo(); if (path == null) { path = request.getServletPath(); } if (path.contentEquals("/schema.json")) { - query(IntrospectionQuery.INTROSPECTION_QUERY, null, new HashMap<>(), getSchemaProvider().getSchema(request), request, response, context, rootObject); + doQuery(IntrospectionQuery.INTROSPECTION_QUERY, null, new HashMap<>(), getSchemaProvider().getSchema(request), context, rootObject, request, response); } else { - if (request.getParameter("query") != null) { - final Map variables = new HashMap<>(); - if (request.getParameter("variables") != null) { - variables.putAll(deserializeVariables(request.getParameter("variables"))); - } - String operationName = null; - if (request.getParameter("operationName") != null) { - operationName = request.getParameter("operationName"); + String query = request.getParameter("query"); + if (query != null) { + if (isBatchedQuery(query)) { + doBatchedQuery(getGraphQLRequestMapper().readValues(query), getSchemaProvider().getReadOnlySchema(request), context, rootObject, request, response); + } else { + final Map variables = new HashMap<>(); + if (request.getParameter("variables") != null) { + variables.putAll(deserializeVariables(request.getParameter("variables"))); + } + + String operationName = null; + if (request.getParameter("operationName") != null) { + operationName = request.getParameter("operationName"); + } + + doQuery(query, operationName, variables, getSchemaProvider().getReadOnlySchema(request), context, rootObject, request, response); } - query(request.getParameter("query"), operationName, variables, getSchemaProvider().getReadOnlySchema(request), request, response, context, rootObject); } else { response.setStatus(STATUS_BAD_REQUEST); log.info("Bad GET request: path was not \"/schema.json\" or no query variable named \"query\" given"); @@ -111,70 +108,68 @@ public GraphQLServlet(ObjectMapperConfigurer objectMapperConfigurer, List { final GraphQLContext context = createContext(Optional.of(request), Optional.of(response)); final Object rootObject = createRootObject(Optional.of(request), Optional.of(response)); - GraphQLRequest graphQLRequest = null; try { - InputStream inputStream = null; - if (ServletFileUpload.isMultipartContent(request)) { final Map> fileItems = fileUpload.parseParameterMap(request); + context.setFiles(Optional.of(fileItems)); if (fileItems.containsKey("graphql")) { final Optional graphqlItem = getFileItem(fileItems, "graphql"); if (graphqlItem.isPresent()) { - inputStream = graphqlItem.get().getInputStream(); + String query = new String(graphqlItem.get().get()); + + if (isBatchedQuery(query)) { + doBatchedQuery(getGraphQLRequestMapper().readValues(query), getSchemaProvider().getSchema(request), context, rootObject, request, response); + return; + } else { + doQuery(getGraphQLRequestMapper().readValue(query), getSchemaProvider().getSchema(request), context, rootObject, request, response); + return; + } } - } else if (fileItems.containsKey("query")) { final Optional queryItem = getFileItem(fileItems, "query"); if (queryItem.isPresent()) { - graphQLRequest = new GraphQLRequest(); - graphQLRequest.setQuery(new String(queryItem.get().get())); - - final Optional operationNameItem = getFileItem(fileItems, "operationName"); - if (operationNameItem.isPresent()) { - graphQLRequest.setOperationName(new String(operationNameItem.get().get()).trim()); - } + String query = new String(queryItem.get().get()); + + if (isBatchedQuery(query)) { + doBatchedQuery(getGraphQLRequestMapper().readValues(query), getSchemaProvider().getSchema(request), context, rootObject, request, response); + return; + } else { + Map variables = null; + final Optional variablesItem = getFileItem(fileItems, "variables"); + if (variablesItem.isPresent()) { + variables = deserializeVariables(new String(variablesItem.get().get())); + } - final Optional variablesItem = getFileItem(fileItems, "variables"); - if (variablesItem.isPresent()) { - String variables = new String(variablesItem.get().get()); - if (!variables.isEmpty()) { - graphQLRequest.setVariables(deserializeVariables(variables)); + String operationName = null; + final Optional operationNameItem = getFileItem(fileItems, "operationName"); + if (operationNameItem.isPresent()) { + operationName = new String(operationNameItem.get().get()).trim(); } + + doQuery(query, operationName, variables, getSchemaProvider().getSchema(request), context, rootObject, request, response); + return; } } } - if (inputStream == null && graphQLRequest == null) { - response.setStatus(STATUS_BAD_REQUEST); - log.info("Bad POST multipart request: no part named \"graphql\" or \"query\""); - return; - } - - context.setFiles(Optional.of(fileItems)); - + response.setStatus(STATUS_BAD_REQUEST); + log.info("Bad POST multipart request: no part named \"graphql\" or \"query\""); } else { // this is not a multipart request - inputStream = request.getInputStream(); - } + String query = inputStreamToString(request.getInputStream()); - if (graphQLRequest == null) { - graphQLRequest = getGraphQLRequestMapper().readValue(inputStream); + if (isBatchedQuery(query)) { + doBatchedQuery(getGraphQLRequestMapper().readValues(query), getSchemaProvider().getSchema(request), context, rootObject, request, response); + } else { + doQuery(getGraphQLRequestMapper().readValue(query), getSchemaProvider().getSchema(request), context, rootObject, request, response); + } } - } catch (Exception e) { log.info("Bad POST request: parsing failed", e); response.setStatus(STATUS_BAD_REQUEST); - return; } - - Map variables = graphQLRequest.getVariables(); - if (variables == null) { - variables = new HashMap<>(); - } - - query(graphQLRequest.getQuery(), graphQLRequest.getOperationName(), variables, getSchemaProvider().getSchema(request), request, response, context, rootObject); }; } @@ -221,7 +216,7 @@ public String executeQuery(String query) { } } - private void doRequest(HttpServletRequest request, HttpServletResponse response, RequestHandler handler) { + private void doRequest(HttpServletRequest request, HttpServletResponse response, HttpRequestHandler handler) { List requestCallbacks = runListeners(l -> l.onRequest(request, response)); @@ -266,14 +261,48 @@ private GraphQL newGraphQL(GraphQLSchema schema) { .build(); } - private void query(String query, String operationName, Map variables, GraphQLSchema schema, HttpServletRequest req, HttpServletResponse resp, GraphQLContext context, Object rootObject) throws IOException { + private void doQuery(GraphQLRequest graphQLRequest, GraphQLSchema schema, GraphQLContext context, Object rootObject, HttpServletRequest httpReq, HttpServletResponse httpRes) throws Exception { + doQuery(graphQLRequest.getQuery(), graphQLRequest.getOperationName(), graphQLRequest.getVariables(), schema, context, rootObject, httpReq, httpRes); + } + + private void doQuery(String query, String operationName, Map variables, GraphQLSchema schema, GraphQLContext context, Object rootObject, HttpServletRequest req, HttpServletResponse resp) throws Exception { + query(query, operationName, variables, schema, context, rootObject, (r) -> { + resp.setContentType(APPLICATION_JSON_UTF8); + resp.setStatus(r.getStatus()); + resp.getWriter().write(r.getResponse()); + }); + } + + private void doBatchedQuery(Iterator graphQLRequests, GraphQLSchema schema, GraphQLContext context, Object rootObject, HttpServletRequest req, HttpServletResponse resp) throws Exception { + final List graphQLResponses = new ArrayList<>(); + + while (graphQLRequests.hasNext()) { + GraphQLRequest graphQLRequest = graphQLRequests.next(); + query(graphQLRequest.getQuery(), graphQLRequest.getOperationName(), graphQLRequest.getVariables(), schema, context, rootObject, graphQLResponses::add); + } + + resp.setContentType(APPLICATION_JSON_UTF8); + resp.setStatus(STATUS_OK); + + Writer responseWriter = resp.getWriter(); + responseWriter.write('['); + for (Iterator i = graphQLResponses.iterator(); i.hasNext();) { + responseWriter.write(i.next().getResponse()); + if (i.hasNext()) { + responseWriter.write(','); + } + } + responseWriter.write(']'); + } + + private void query(String query, String operationName, Map variables, GraphQLSchema schema, GraphQLContext context, Object rootObject, GraphQLResponseHandler responseHandler) throws Exception { if (operationName != null && operationName.isEmpty()) { - query(query, null, variables, schema, req, resp, context, rootObject); + query(query, null, variables, schema, context, rootObject, responseHandler); } else if (Subject.getSubject(AccessController.getContext()) == null && context.getSubject().isPresent()) { Subject.doAs(context.getSubject().get(), (PrivilegedAction) () -> { try { - query(query, operationName, variables, schema, req, resp, context, rootObject); - } catch (IOException e) { + query(query, operationName, variables, schema, context, rootObject, responseHandler); + } catch (Exception e) { throw new RuntimeException(e); } return null; @@ -287,9 +316,10 @@ private void query(String query, String operationName, Map varia final String response = getMapper().writeValueAsString(createResultFromDataAndErrors(data, errors)); - resp.setContentType(APPLICATION_JSON_UTF8); - resp.setStatus(STATUS_OK); - resp.getWriter().write(response); + GraphQLResponse graphQLResponse = new GraphQLResponse(); + graphQLResponse.setStatus(STATUS_OK); + graphQLResponse.setResponse(response); + responseHandler.handle(graphQLResponse); if(getGraphQLErrorHandler().errorsPresent(errors)) { runCallbacks(operationCallbacks, c -> c.onError(context, operationName, query, variables, data, errors)); @@ -373,6 +403,36 @@ private static Map deserializeVariablesObject(Object variables, } } + private boolean isBatchedQuery(String query) { + if (query == null) { + return false; + } + + // return true if the first non whitespace character is the beginning of an array + for (int i = 0; i < query.length(); i++) { + char ch = query.charAt(i); + if (!Character.isWhitespace(ch)) { + return ch == '['; + } + } + + return false; + } + + private String inputStreamToString(InputStream inputStream) throws IOException { + if (inputStream == null) { + return null; + } + + ByteArrayOutputStream result = new ByteArrayOutputStream(); + byte[] buffer = new byte[1024]; + int length; + while ((length = inputStream.read(buffer)) != -1) { + result.write(buffer, 0, length); + } + return result.toString(StandardCharsets.UTF_8.name()); + } + protected static class GraphQLRequest { private String query; @JsonDeserialize(using = GraphQLServlet.VariablesDeserializer.class) @@ -404,7 +464,28 @@ public void setOperationName(String operationName) { } } - protected interface RequestHandler extends BiConsumer { + protected static class GraphQLResponse { + private int status; + private String response; + + public int getStatus() { + return status; + } + + public void setStatus(int status) { + this.status = status; + } + + public String getResponse() { + return response; + } + + public void setResponse(String response) { + this.response = response; + } + } + + protected interface HttpRequestHandler extends BiConsumer { @Override default void accept(HttpServletRequest request, HttpServletResponse response) { try { @@ -416,4 +497,17 @@ default void accept(HttpServletRequest request, HttpServletResponse response) { void handle(HttpServletRequest request, HttpServletResponse response) throws Exception; } + + protected interface GraphQLResponseHandler extends Consumer { + @Override + default void accept(GraphQLResponse response) { + try { + handle(response); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + void handle(GraphQLResponse r) throws Exception; + } } diff --git a/src/test/groovy/graphql/servlet/GraphQLServletSpec.groovy b/src/test/groovy/graphql/servlet/GraphQLServletSpec.groovy index fa53f48f..4b32e332 100644 --- a/src/test/groovy/graphql/servlet/GraphQLServletSpec.groovy +++ b/src/test/groovy/graphql/servlet/GraphQLServletSpec.groovy @@ -75,6 +75,10 @@ class GraphQLServletSpec extends Specification { mapper.readValue(response.getContentAsByteArray(), Map) } + List> getBatchedResponseContent() { + mapper.readValue(response.getContentAsByteArray(), List) + } + def "HTTP GET without info returns bad request"() { when: servlet.doGet(request, response) @@ -161,7 +165,76 @@ class GraphQLServletSpec extends Specification { response.getStatus() == STATUS_OK response.getContentType() == CONTENT_TYPE_JSON_UTF8 getResponseContent().data.echo == "test" + } + + def "batched query over HTTP GET returns data"() { + setup: + request.addParameter('query', '[{ "query": "query { echo(arg:\\"test\\") }" }, { "query": "query { echo(arg:\\"test\\") }" }]') + + when: + servlet.doGet(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + + def "batched query over HTTP GET with variables returns data"() { + setup: + request.addParameter('query', '[{ "query": "query { echo(arg:\\"test\\") }", "variables": { "arg": "test" } }, { "query": "query { echo(arg:\\"test\\") }", "variables": { "arg": "test" } }]') + + when: + servlet.doGet(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + + def "batched query over HTTP GET with variables as string returns data"() { + setup: + request.addParameter('query', '[{ "query": "query { echo(arg:\\"test\\") }", "variables": "{ \\"arg\\": \\"test\\" }" }, { "query": "query { echo(arg:\\"test\\") }", "variables": "{ \\"arg\\": \\"test\\" }" }]') + when: + servlet.doGet(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + + def "batched query over HTTP GET with operationName returns data"() { + when: + response = new MockHttpServletResponse() + request.addParameter('query', '[{ "query": "query one{ echoOne: echo(arg:\\"test-one\\") } query two{ echoTwo: echo(arg:\\"test-two\\") }", "operationName": "one" }, { "query": "query one{ echoOne: echo(arg:\\"test-one\\") } query two{ echoTwo: echo(arg:\\"test-two\\") }", "operationName": "two" }]') + servlet.doGet(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echoOne == "test-one" + getBatchedResponseContent()[0].data.echoTwo == null + getBatchedResponseContent()[1].data.echoOne == null + getBatchedResponseContent()[1].data.echoTwo == "test-two" + } + + def "batched query over HTTP GET with empty non-null operationName returns data"() { + when: + response = new MockHttpServletResponse() + request.addParameter('query', '[{ "query": "query echo{ echo: echo(arg:\\"test\\") }", "operationName": "" }, { "query": "query echo{ echo: echo(arg:\\"test\\") }", "operationName": "" }]') + servlet.doGet(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" } def "mutation over HTTP GET returns errors"() { @@ -177,6 +250,20 @@ class GraphQLServletSpec extends Specification { getResponseContent().errors.size() == 1 } + def "batched mutation over HTTP GET returns errors"() { + setup: + request.addParameter('query', '[{ "query": "mutation { echo(arg:\\"test\\") }" }, { "query": "mutation {echo(arg:\\"test\\") }" }]') + + when: + servlet.doGet(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].errors.size() == 1 + getBatchedResponseContent()[1].errors.size() == 1 + } + def "query over HTTP POST without part or body returns bad request"() { when: servlet.doPost(request, response) @@ -339,6 +426,157 @@ class GraphQLServletSpec extends Specification { getResponseContent().data.echo == "test" } + def "batched query over HTTP POST body returns data"() { + setup: + request.setContent('[{ "query": "query { echo(arg:\\"test\\") }" }, { "query": "query { echo(arg:\\"test\\") }" }]'.bytes) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + + def "batched query over HTTP POST body with variables returns data"() { + setup: + request.setContent('[{ "query": "query { echo(arg:\\"test\\") }", "variables": { "arg": "test" } }, { "query": "query { echo(arg:\\"test\\") }", "variables": { "arg": "test" } }]'.bytes) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + + def "batched query over HTTP POST body with operationName returns data"() { + setup: + request.setContent('[{ "query": "query one{ echoOne: echo(arg:\\"test-one\\") } query two{ echoTwo: echo(arg:\\"test-two\\") }", "operationName": "one" }, { "query": "query one{ echoOne: echo(arg:\\"test-one\\") } query two{ echoTwo: echo(arg:\\"test-two\\") }", "operationName": "two" }]'.bytes) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echoOne == "test-one" + getBatchedResponseContent()[0].data.echoTwo == null + getBatchedResponseContent()[1].data.echoOne == null + getBatchedResponseContent()[1].data.echoTwo == "test-two" + } + + def "batched query over HTTP POST body with empty non-null operationName returns data"() { + setup: + request.setContent('[{ "query": "query echo{ echo: echo(arg:\\"test\\") }", "operationName": "" }, { "query": "query echo{ echo: echo(arg:\\"test\\") }", "operationName": "" }]'.bytes) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + + def "batched query over HTTP POST multipart named 'graphql' returns data"() { + setup: + request.setContentType("multipart/form-data, boundary=test") + request.setMethod("POST") + + request.setContent(new TestMultipartContentBuilder() + .addPart('graphql', '[{ "query": "query { echo(arg:\\"test\\") }" }, { "query": "query { echo(arg:\\"test\\") }" }]') + .build()) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + + def "batched query over HTTP POST multipart named 'query' returns data"() { + setup: + request.setContentType("multipart/form-data, boundary=test") + request.setMethod("POST") + request.setContent(new TestMultipartContentBuilder() + .addPart('query', '[{ "query": "query { echo(arg:\\"test\\") }" }, { "query": "query { echo(arg:\\"test\\") }" }]') + .build()) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + + def "batched query over HTTP POST multipart named 'query' with operationName returns data"() { + setup: + request.setContentType("multipart/form-data, boundary=test") + request.setMethod("POST") + request.setContent(new TestMultipartContentBuilder() + .addPart('query', '[{ "query": "query one{ echoOne: echo(arg:\\"test-one\\") } query two{ echoTwo: echo(arg:\\"test-two\\") }", "operationName": "one" }, { "query": "query one{ echoOne: echo(arg:\\"test-one\\") } query two{ echoTwo: echo(arg:\\"test-two\\") }", "operationName": "two" }]') + .build()) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echoOne == "test-one" + getBatchedResponseContent()[0].data.echoTwo == null + getBatchedResponseContent()[1].data.echoOne == null + getBatchedResponseContent()[1].data.echoTwo == "test-two" + } + + def "batched query over HTTP POST multipart named 'query' with empty non-null operationName returns data"() { + setup: + request.setContentType("multipart/form-data, boundary=test") + request.setMethod("POST") + request.setContent(new TestMultipartContentBuilder() + .addPart('query', '[{ "query": "query echo{ echo: echo(arg:\\"test\\") }", "operationName": "" }, { "query": "query echo{ echo: echo(arg:\\"test\\") }", "operationName": "" }]') + .build()) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + + def "batched query over HTTP POST multipart named 'query' with variables returns data"() { + setup: + request.setContentType("multipart/form-data, boundary=test") + request.setMethod("POST") + request.setContent(new TestMultipartContentBuilder() + .addPart('query', '[{ "query": "query echo($arg: String) { echo(arg:$arg) }", "variables": { "arg": "test" } }, { "query": "query echo($arg: String) { echo(arg:$arg) }", "variables": { "arg": "test" } }]') + .build()) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + def "mutation over HTTP POST body returns data"() { setup: request.setContent(mapper.writeValueAsBytes([ @@ -354,6 +592,20 @@ class GraphQLServletSpec extends Specification { getResponseContent().data.echo == "test" } + def "batched mutation over HTTP POST body returns data"() { + setup: + request.setContent('[{ "query": "mutation { echo(arg:\\"test\\") }" }, { "query": "mutation { echo(arg:\\"test\\") }" }]'.bytes) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + def "errors before graphql schema execution return internal server error"() { setup: servlet = new SimpleGraphQLServlet(servlet.getSchemaProvider().getSchema()) { @@ -389,6 +641,24 @@ class GraphQLServletSpec extends Specification { errors.first().message.startsWith("Internal Server Error(s)") } + def "batched errors while data fetching are masked in the response"() { + setup: + servlet = createServlet({ throw new TestException() }) + request.addParameter('query', '[{ "query": "query { echo(arg:\\"test\\") }" }, { "query": "query { echo(arg:\\"test\\") }" }]') + + when: + servlet.doGet(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + def errors = getBatchedResponseContent().errors + errors[0].size() == 1 + errors[0].first().message.startsWith("Internal Server Error(s)") + errors[1].size() == 1 + errors[1].first().message.startsWith("Internal Server Error(s)") + } + def "data field is present and null if no data can be returned"() { setup: request.addParameter('query', 'query { not-a-field(arg:"test") }') @@ -405,6 +675,25 @@ class GraphQLServletSpec extends Specification { resp.errors != null } + def "batched data field is present and null if no data can be returned"() { + setup: + request.addParameter('query', '[{ "query": "query { not-a-field(arg:\\"test\\") }" }, { "query": "query { not-a-field(arg:\\"test\\") }" }]') + + when: + servlet.doGet(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + def resp = getBatchedResponseContent() + resp[0].containsKey("data") + resp[0].data == null + resp[0].errors != null + resp[1].containsKey("data") + resp[1].data == null + resp[1].errors != null + } + def "typeInfo is serialized correctly"() { expect: servlet.getMapper().writeValueAsString(ExecutionTypeInfo.newTypeInfo().type(new GraphQLNonNull(Scalars.GraphQLString)).build()) != "{}" From 6317887380e2ed4f5bd0a01c14f51366a833c21e Mon Sep 17 00:00:00 2001 From: ronald-d-rogers Date: Wed, 1 Nov 2017 21:23:54 -0400 Subject: [PATCH 2/5] Removed wild-card imports. --- .../java/graphql/servlet/GraphQLServlet.java | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/main/java/graphql/servlet/GraphQLServlet.java b/src/main/java/graphql/servlet/GraphQLServlet.java index 4e655ead..144d8eb9 100644 --- a/src/main/java/graphql/servlet/GraphQLServlet.java +++ b/src/main/java/graphql/servlet/GraphQLServlet.java @@ -2,7 +2,11 @@ import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.*; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.InjectableValues; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.ObjectReader; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import graphql.ExecutionInput; import graphql.ExecutionResult; @@ -25,11 +29,21 @@ import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import java.io.*; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.Writer; import java.nio.charset.StandardCharsets; import java.security.AccessController; import java.security.PrivilegedAction; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Function; From 145971a48d90277b59f5ddc6ba50852dfc7996af Mon Sep 17 00:00:00 2001 From: ronald-d-rogers Date: Wed, 1 Nov 2017 23:11:35 -0400 Subject: [PATCH 3/5] Switched back to using input streams where possible. --- .../java/graphql/servlet/GraphQLServlet.java | 88 ++++++++++++------- 1 file changed, 57 insertions(+), 31 deletions(-) diff --git a/src/main/java/graphql/servlet/GraphQLServlet.java b/src/main/java/graphql/servlet/GraphQLServlet.java index 144d8eb9..4072f31f 100644 --- a/src/main/java/graphql/servlet/GraphQLServlet.java +++ b/src/main/java/graphql/servlet/GraphQLServlet.java @@ -29,10 +29,7 @@ import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.Writer; +import java.io.*; import java.nio.charset.StandardCharsets; import java.security.AccessController; import java.security.PrivilegedAction; @@ -131,25 +128,35 @@ public GraphQLServlet(ObjectMapperConfigurer objectMapperConfigurer, List graphqlItem = getFileItem(fileItems, "graphql"); if (graphqlItem.isPresent()) { - String query = new String(graphqlItem.get().get()); + InputStream inputStream = graphqlItem.get().getInputStream(); - if (isBatchedQuery(query)) { - doBatchedQuery(getGraphQLRequestMapper().readValues(query), getSchemaProvider().getSchema(request), context, rootObject, request, response); + if (!inputStream.markSupported()) { + inputStream = new BufferedInputStream(inputStream); + } + + if (isBatchedQuery(inputStream)) { + doBatchedQuery(getGraphQLRequestMapper().readValues(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response); return; } else { - doQuery(getGraphQLRequestMapper().readValue(query), getSchemaProvider().getSchema(request), context, rootObject, request, response); + doQuery(getGraphQLRequestMapper().readValue(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response); return; } } } else if (fileItems.containsKey("query")) { final Optional queryItem = getFileItem(fileItems, "query"); if (queryItem.isPresent()) { - String query = new String(queryItem.get().get()); + InputStream inputStream = queryItem.get().getInputStream(); + + if (!inputStream.markSupported()) { + inputStream = new BufferedInputStream(inputStream); + } - if (isBatchedQuery(query)) { - doBatchedQuery(getGraphQLRequestMapper().readValues(query), getSchemaProvider().getSchema(request), context, rootObject, request, response); + if (isBatchedQuery(inputStream)) { + doBatchedQuery(getGraphQLRequestMapper().readValues(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response); return; } else { + String query = new String(queryItem.get().get()); + Map variables = null; final Optional variablesItem = getFileItem(fileItems, "variables"); if (variablesItem.isPresent()) { @@ -172,12 +179,16 @@ public GraphQLServlet(ObjectMapperConfigurer objectMapperConfigurer, List deserializeVariablesObject(Object variables, } } - private boolean isBatchedQuery(String query) { - if (query == null) { + private boolean isBatchedQuery(InputStream inputStream) throws IOException { + if (inputStream == null) { return false; } - // return true if the first non whitespace character is the beginning of an array - for (int i = 0; i < query.length(); i++) { - char ch = query.charAt(i); - if (!Character.isWhitespace(ch)) { - return ch == '['; + ByteArrayOutputStream result = new ByteArrayOutputStream(); + byte[] buffer = new byte[128]; + int length; + + inputStream.mark(0); + while ((length = inputStream.read(buffer)) != -1) { + result.write(buffer, 0, length); + String chunk = result.toString(); + Boolean isArrayStart = isArrayStart(chunk); + if (isArrayStart != null) { + inputStream.reset(); + return isArrayStart; } } + inputStream.reset(); return false; } - private String inputStreamToString(InputStream inputStream) throws IOException { - if (inputStream == null) { - return null; + private boolean isBatchedQuery(String query) { + if (query == null) { + return false; } - ByteArrayOutputStream result = new ByteArrayOutputStream(); - byte[] buffer = new byte[1024]; - int length; - while ((length = inputStream.read(buffer)) != -1) { - result.write(buffer, 0, length); + Boolean isArrayStart = isArrayStart(query); + return isArrayStart != null && isArrayStart; + } + + // return true if the first non whitespace character is the beginning of an array + private Boolean isArrayStart(String s) { + for (int i = 0; i < s.length(); i++) { + char ch = s.charAt(i); + if (!Character.isWhitespace(ch)) { + return ch == '['; + } } - return result.toString(StandardCharsets.UTF_8.name()); + + return null; } protected static class GraphQLRequest { From 43033e4373306197028f2cb729dec40e7199c575 Mon Sep 17 00:00:00 2001 From: ronald-d-rogers Date: Wed, 1 Nov 2017 23:24:04 -0400 Subject: [PATCH 4/5] Streaming output of batched queries. --- .../java/graphql/servlet/GraphQLServlet.java | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/src/main/java/graphql/servlet/GraphQLServlet.java b/src/main/java/graphql/servlet/GraphQLServlet.java index 4072f31f..54ab6b41 100644 --- a/src/main/java/graphql/servlet/GraphQLServlet.java +++ b/src/main/java/graphql/servlet/GraphQLServlet.java @@ -299,25 +299,19 @@ private void doQuery(String query, String operationName, Map var } private void doBatchedQuery(Iterator graphQLRequests, GraphQLSchema schema, GraphQLContext context, Object rootObject, HttpServletRequest req, HttpServletResponse resp) throws Exception { - final List graphQLResponses = new ArrayList<>(); - - while (graphQLRequests.hasNext()) { - GraphQLRequest graphQLRequest = graphQLRequests.next(); - query(graphQLRequest.getQuery(), graphQLRequest.getOperationName(), graphQLRequest.getVariables(), schema, context, rootObject, graphQLResponses::add); - } - resp.setContentType(APPLICATION_JSON_UTF8); resp.setStatus(STATUS_OK); - Writer responseWriter = resp.getWriter(); - responseWriter.write('['); - for (Iterator i = graphQLResponses.iterator(); i.hasNext();) { - responseWriter.write(i.next().getResponse()); - if (i.hasNext()) { - responseWriter.write(','); + Writer respWriter = resp.getWriter(); + respWriter.write('['); + while (graphQLRequests.hasNext()) { + GraphQLRequest graphQLRequest = graphQLRequests.next(); + query(graphQLRequest.getQuery(), graphQLRequest.getOperationName(), graphQLRequest.getVariables(), schema, context, rootObject, (r) -> respWriter.write(r.getResponse())); + if (graphQLRequests.hasNext()) { + respWriter.write(','); } } - responseWriter.write(']'); + respWriter.write(']'); } private void query(String query, String operationName, Map variables, GraphQLSchema schema, GraphQLContext context, Object rootObject, GraphQLResponseHandler responseHandler) throws Exception { From 4c2602ae134448d2ba53099c91d0907316613261 Mon Sep 17 00:00:00 2001 From: ronald-d-rogers Date: Thu, 2 Nov 2017 12:45:33 -0400 Subject: [PATCH 5/5] Removed wild-card imports again. --- src/main/java/graphql/servlet/GraphQLServlet.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/main/java/graphql/servlet/GraphQLServlet.java b/src/main/java/graphql/servlet/GraphQLServlet.java index 54ab6b41..421327c5 100644 --- a/src/main/java/graphql/servlet/GraphQLServlet.java +++ b/src/main/java/graphql/servlet/GraphQLServlet.java @@ -29,8 +29,11 @@ import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import java.io.*; -import java.nio.charset.StandardCharsets; +import java.io.BufferedInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.Writer; import java.security.AccessController; import java.security.PrivilegedAction; import java.util.ArrayList;