diff --git a/README.md b/README.md index 8a1f0413..afa3f247 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ # GraphQL Servlet -This module implements a GraphQL Java Servlet. It also supports Relay.js and OSGi out of the box. +This module implements a GraphQL Java Servlet. It also supports Relay.js, Apollo and OSGi out of the box. # Downloading @@ -114,6 +114,10 @@ You **MUST** pass this execution strategy to the servlet for Relay.js support. This is the default execution strategy for the `OsgiGraphQLServlet`, and must be added as a dependency when using that servlet. +## Apollo support + +Query batching is supported, no configuration required. + ## Spring Framework support To use the servlet with Spring Framework, either use the [Spring Boot starter](https://github.com/graphql-java/graphql-spring-boot) or simply define a `ServletRegistrationBean` in a web app: diff --git a/src/main/java/graphql/servlet/GraphQLServlet.java b/src/main/java/graphql/servlet/GraphQLServlet.java index b22bdacf..421327c5 100644 --- a/src/main/java/graphql/servlet/GraphQLServlet.java +++ b/src/main/java/graphql/servlet/GraphQLServlet.java @@ -1,6 +1,5 @@ package graphql.servlet; -import com.fasterxml.jackson.annotation.JacksonInject; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.DeserializationContext; @@ -30,13 +29,17 @@ import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import java.io.BufferedInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; +import java.io.Writer; import java.security.AccessController; import java.security.PrivilegedAction; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -69,8 +72,8 @@ public abstract class GraphQLServlet extends HttpServlet implements Servlet, Gra private final List listeners; private final ServletFileUpload fileUpload; - private final RequestHandler getHandler; - private final RequestHandler postHandler; + private final HttpRequestHandler getHandler; + private final HttpRequestHandler postHandler; public GraphQLServlet() { this(null, null, null); @@ -84,23 +87,31 @@ public GraphQLServlet(ObjectMapperConfigurer objectMapperConfigurer, List { final GraphQLContext context = createContext(Optional.of(request), Optional.of(response)); final Object rootObject = createRootObject(Optional.of(request), Optional.of(response)); + String path = request.getPathInfo(); if (path == null) { path = request.getServletPath(); } if (path.contentEquals("/schema.json")) { - query(IntrospectionQuery.INTROSPECTION_QUERY, null, new HashMap<>(), getSchemaProvider().getSchema(request), request, response, context, rootObject); + doQuery(IntrospectionQuery.INTROSPECTION_QUERY, null, new HashMap<>(), getSchemaProvider().getSchema(request), context, rootObject, request, response); } else { - if (request.getParameter("query") != null) { - final Map variables = new HashMap<>(); - if (request.getParameter("variables") != null) { - variables.putAll(deserializeVariables(request.getParameter("variables"))); - } - String operationName = null; - if (request.getParameter("operationName") != null) { - operationName = request.getParameter("operationName"); + String query = request.getParameter("query"); + if (query != null) { + if (isBatchedQuery(query)) { + doBatchedQuery(getGraphQLRequestMapper().readValues(query), getSchemaProvider().getReadOnlySchema(request), context, rootObject, request, response); + } else { + final Map variables = new HashMap<>(); + if (request.getParameter("variables") != null) { + variables.putAll(deserializeVariables(request.getParameter("variables"))); + } + + String operationName = null; + if (request.getParameter("operationName") != null) { + operationName = request.getParameter("operationName"); + } + + doQuery(query, operationName, variables, getSchemaProvider().getReadOnlySchema(request), context, rootObject, request, response); } - query(request.getParameter("query"), operationName, variables, getSchemaProvider().getReadOnlySchema(request), request, response, context, rootObject); } else { response.setStatus(STATUS_BAD_REQUEST); log.info("Bad GET request: path was not \"/schema.json\" or no query variable named \"query\" given"); @@ -111,70 +122,82 @@ public GraphQLServlet(ObjectMapperConfigurer objectMapperConfigurer, List { final GraphQLContext context = createContext(Optional.of(request), Optional.of(response)); final Object rootObject = createRootObject(Optional.of(request), Optional.of(response)); - GraphQLRequest graphQLRequest = null; try { - InputStream inputStream = null; - if (ServletFileUpload.isMultipartContent(request)) { final Map> fileItems = fileUpload.parseParameterMap(request); + context.setFiles(Optional.of(fileItems)); if (fileItems.containsKey("graphql")) { final Optional graphqlItem = getFileItem(fileItems, "graphql"); if (graphqlItem.isPresent()) { - inputStream = graphqlItem.get().getInputStream(); - } + InputStream inputStream = graphqlItem.get().getInputStream(); + if (!inputStream.markSupported()) { + inputStream = new BufferedInputStream(inputStream); + } + + if (isBatchedQuery(inputStream)) { + doBatchedQuery(getGraphQLRequestMapper().readValues(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response); + return; + } else { + doQuery(getGraphQLRequestMapper().readValue(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response); + return; + } + } } else if (fileItems.containsKey("query")) { final Optional queryItem = getFileItem(fileItems, "query"); if (queryItem.isPresent()) { - graphQLRequest = new GraphQLRequest(); - graphQLRequest.setQuery(new String(queryItem.get().get())); + InputStream inputStream = queryItem.get().getInputStream(); - final Optional operationNameItem = getFileItem(fileItems, "operationName"); - if (operationNameItem.isPresent()) { - graphQLRequest.setOperationName(new String(operationNameItem.get().get()).trim()); + if (!inputStream.markSupported()) { + inputStream = new BufferedInputStream(inputStream); } - final Optional variablesItem = getFileItem(fileItems, "variables"); - if (variablesItem.isPresent()) { - String variables = new String(variablesItem.get().get()); - if (!variables.isEmpty()) { - graphQLRequest.setVariables(deserializeVariables(variables)); + if (isBatchedQuery(inputStream)) { + doBatchedQuery(getGraphQLRequestMapper().readValues(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response); + return; + } else { + String query = new String(queryItem.get().get()); + + Map variables = null; + final Optional variablesItem = getFileItem(fileItems, "variables"); + if (variablesItem.isPresent()) { + variables = deserializeVariables(new String(variablesItem.get().get())); } + + String operationName = null; + final Optional operationNameItem = getFileItem(fileItems, "operationName"); + if (operationNameItem.isPresent()) { + operationName = new String(operationNameItem.get().get()).trim(); + } + + doQuery(query, operationName, variables, getSchemaProvider().getSchema(request), context, rootObject, request, response); + return; } } } - if (inputStream == null && graphQLRequest == null) { - response.setStatus(STATUS_BAD_REQUEST); - log.info("Bad POST multipart request: no part named \"graphql\" or \"query\""); - return; - } - - context.setFiles(Optional.of(fileItems)); - + response.setStatus(STATUS_BAD_REQUEST); + log.info("Bad POST multipart request: no part named \"graphql\" or \"query\""); } else { // this is not a multipart request - inputStream = request.getInputStream(); - } + InputStream inputStream = request.getInputStream(); - if (graphQLRequest == null) { - graphQLRequest = getGraphQLRequestMapper().readValue(inputStream); - } + if (!inputStream.markSupported()) { + inputStream = new BufferedInputStream(inputStream); + } + if (isBatchedQuery(inputStream)) { + doBatchedQuery(getGraphQLRequestMapper().readValues(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response); + } else { + doQuery(getGraphQLRequestMapper().readValue(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response); + } + } } catch (Exception e) { log.info("Bad POST request: parsing failed", e); response.setStatus(STATUS_BAD_REQUEST); - return; - } - - Map variables = graphQLRequest.getVariables(); - if (variables == null) { - variables = new HashMap<>(); } - - query(graphQLRequest.getQuery(), graphQLRequest.getOperationName(), variables, getSchemaProvider().getSchema(request), request, response, context, rootObject); }; } @@ -221,7 +244,7 @@ public String executeQuery(String query) { } } - private void doRequest(HttpServletRequest request, HttpServletResponse response, RequestHandler handler) { + private void doRequest(HttpServletRequest request, HttpServletResponse response, HttpRequestHandler handler) { List requestCallbacks = runListeners(l -> l.onRequest(request, response)); @@ -266,14 +289,42 @@ private GraphQL newGraphQL(GraphQLSchema schema) { .build(); } - private void query(String query, String operationName, Map variables, GraphQLSchema schema, HttpServletRequest req, HttpServletResponse resp, GraphQLContext context, Object rootObject) throws IOException { + private void doQuery(GraphQLRequest graphQLRequest, GraphQLSchema schema, GraphQLContext context, Object rootObject, HttpServletRequest httpReq, HttpServletResponse httpRes) throws Exception { + doQuery(graphQLRequest.getQuery(), graphQLRequest.getOperationName(), graphQLRequest.getVariables(), schema, context, rootObject, httpReq, httpRes); + } + + private void doQuery(String query, String operationName, Map variables, GraphQLSchema schema, GraphQLContext context, Object rootObject, HttpServletRequest req, HttpServletResponse resp) throws Exception { + query(query, operationName, variables, schema, context, rootObject, (r) -> { + resp.setContentType(APPLICATION_JSON_UTF8); + resp.setStatus(r.getStatus()); + resp.getWriter().write(r.getResponse()); + }); + } + + private void doBatchedQuery(Iterator graphQLRequests, GraphQLSchema schema, GraphQLContext context, Object rootObject, HttpServletRequest req, HttpServletResponse resp) throws Exception { + resp.setContentType(APPLICATION_JSON_UTF8); + resp.setStatus(STATUS_OK); + + Writer respWriter = resp.getWriter(); + respWriter.write('['); + while (graphQLRequests.hasNext()) { + GraphQLRequest graphQLRequest = graphQLRequests.next(); + query(graphQLRequest.getQuery(), graphQLRequest.getOperationName(), graphQLRequest.getVariables(), schema, context, rootObject, (r) -> respWriter.write(r.getResponse())); + if (graphQLRequests.hasNext()) { + respWriter.write(','); + } + } + respWriter.write(']'); + } + + private void query(String query, String operationName, Map variables, GraphQLSchema schema, GraphQLContext context, Object rootObject, GraphQLResponseHandler responseHandler) throws Exception { if (operationName != null && operationName.isEmpty()) { - query(query, null, variables, schema, req, resp, context, rootObject); + query(query, null, variables, schema, context, rootObject, responseHandler); } else if (Subject.getSubject(AccessController.getContext()) == null && context.getSubject().isPresent()) { Subject.doAs(context.getSubject().get(), (PrivilegedAction) () -> { try { - query(query, operationName, variables, schema, req, resp, context, rootObject); - } catch (IOException e) { + query(query, operationName, variables, schema, context, rootObject, responseHandler); + } catch (Exception e) { throw new RuntimeException(e); } return null; @@ -287,9 +338,10 @@ private void query(String query, String operationName, Map varia final String response = getMapper().writeValueAsString(createResultFromDataAndErrors(data, errors)); - resp.setContentType(APPLICATION_JSON_UTF8); - resp.setStatus(STATUS_OK); - resp.getWriter().write(response); + GraphQLResponse graphQLResponse = new GraphQLResponse(); + graphQLResponse.setStatus(STATUS_OK); + graphQLResponse.setResponse(response); + responseHandler.handle(graphQLResponse); if(getGraphQLErrorHandler().errorsPresent(errors)) { runCallbacks(operationCallbacks, c -> c.onError(context, operationName, query, variables, data, errors)); @@ -373,6 +425,51 @@ private static Map deserializeVariablesObject(Object variables, } } + private boolean isBatchedQuery(InputStream inputStream) throws IOException { + if (inputStream == null) { + return false; + } + + ByteArrayOutputStream result = new ByteArrayOutputStream(); + byte[] buffer = new byte[128]; + int length; + + inputStream.mark(0); + while ((length = inputStream.read(buffer)) != -1) { + result.write(buffer, 0, length); + String chunk = result.toString(); + Boolean isArrayStart = isArrayStart(chunk); + if (isArrayStart != null) { + inputStream.reset(); + return isArrayStart; + } + } + + inputStream.reset(); + return false; + } + + private boolean isBatchedQuery(String query) { + if (query == null) { + return false; + } + + Boolean isArrayStart = isArrayStart(query); + return isArrayStart != null && isArrayStart; + } + + // return true if the first non whitespace character is the beginning of an array + private Boolean isArrayStart(String s) { + for (int i = 0; i < s.length(); i++) { + char ch = s.charAt(i); + if (!Character.isWhitespace(ch)) { + return ch == '['; + } + } + + return null; + } + protected static class GraphQLRequest { private String query; @JsonDeserialize(using = GraphQLServlet.VariablesDeserializer.class) @@ -404,7 +501,28 @@ public void setOperationName(String operationName) { } } - protected interface RequestHandler extends BiConsumer { + protected static class GraphQLResponse { + private int status; + private String response; + + public int getStatus() { + return status; + } + + public void setStatus(int status) { + this.status = status; + } + + public String getResponse() { + return response; + } + + public void setResponse(String response) { + this.response = response; + } + } + + protected interface HttpRequestHandler extends BiConsumer { @Override default void accept(HttpServletRequest request, HttpServletResponse response) { try { @@ -416,4 +534,17 @@ default void accept(HttpServletRequest request, HttpServletResponse response) { void handle(HttpServletRequest request, HttpServletResponse response) throws Exception; } + + protected interface GraphQLResponseHandler extends Consumer { + @Override + default void accept(GraphQLResponse response) { + try { + handle(response); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + void handle(GraphQLResponse r) throws Exception; + } } diff --git a/src/test/groovy/graphql/servlet/GraphQLServletSpec.groovy b/src/test/groovy/graphql/servlet/GraphQLServletSpec.groovy index fa53f48f..4b32e332 100644 --- a/src/test/groovy/graphql/servlet/GraphQLServletSpec.groovy +++ b/src/test/groovy/graphql/servlet/GraphQLServletSpec.groovy @@ -75,6 +75,10 @@ class GraphQLServletSpec extends Specification { mapper.readValue(response.getContentAsByteArray(), Map) } + List> getBatchedResponseContent() { + mapper.readValue(response.getContentAsByteArray(), List) + } + def "HTTP GET without info returns bad request"() { when: servlet.doGet(request, response) @@ -161,7 +165,76 @@ class GraphQLServletSpec extends Specification { response.getStatus() == STATUS_OK response.getContentType() == CONTENT_TYPE_JSON_UTF8 getResponseContent().data.echo == "test" + } + + def "batched query over HTTP GET returns data"() { + setup: + request.addParameter('query', '[{ "query": "query { echo(arg:\\"test\\") }" }, { "query": "query { echo(arg:\\"test\\") }" }]') + + when: + servlet.doGet(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + + def "batched query over HTTP GET with variables returns data"() { + setup: + request.addParameter('query', '[{ "query": "query { echo(arg:\\"test\\") }", "variables": { "arg": "test" } }, { "query": "query { echo(arg:\\"test\\") }", "variables": { "arg": "test" } }]') + + when: + servlet.doGet(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + + def "batched query over HTTP GET with variables as string returns data"() { + setup: + request.addParameter('query', '[{ "query": "query { echo(arg:\\"test\\") }", "variables": "{ \\"arg\\": \\"test\\" }" }, { "query": "query { echo(arg:\\"test\\") }", "variables": "{ \\"arg\\": \\"test\\" }" }]') + when: + servlet.doGet(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + + def "batched query over HTTP GET with operationName returns data"() { + when: + response = new MockHttpServletResponse() + request.addParameter('query', '[{ "query": "query one{ echoOne: echo(arg:\\"test-one\\") } query two{ echoTwo: echo(arg:\\"test-two\\") }", "operationName": "one" }, { "query": "query one{ echoOne: echo(arg:\\"test-one\\") } query two{ echoTwo: echo(arg:\\"test-two\\") }", "operationName": "two" }]') + servlet.doGet(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echoOne == "test-one" + getBatchedResponseContent()[0].data.echoTwo == null + getBatchedResponseContent()[1].data.echoOne == null + getBatchedResponseContent()[1].data.echoTwo == "test-two" + } + + def "batched query over HTTP GET with empty non-null operationName returns data"() { + when: + response = new MockHttpServletResponse() + request.addParameter('query', '[{ "query": "query echo{ echo: echo(arg:\\"test\\") }", "operationName": "" }, { "query": "query echo{ echo: echo(arg:\\"test\\") }", "operationName": "" }]') + servlet.doGet(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" } def "mutation over HTTP GET returns errors"() { @@ -177,6 +250,20 @@ class GraphQLServletSpec extends Specification { getResponseContent().errors.size() == 1 } + def "batched mutation over HTTP GET returns errors"() { + setup: + request.addParameter('query', '[{ "query": "mutation { echo(arg:\\"test\\") }" }, { "query": "mutation {echo(arg:\\"test\\") }" }]') + + when: + servlet.doGet(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].errors.size() == 1 + getBatchedResponseContent()[1].errors.size() == 1 + } + def "query over HTTP POST without part or body returns bad request"() { when: servlet.doPost(request, response) @@ -339,6 +426,157 @@ class GraphQLServletSpec extends Specification { getResponseContent().data.echo == "test" } + def "batched query over HTTP POST body returns data"() { + setup: + request.setContent('[{ "query": "query { echo(arg:\\"test\\") }" }, { "query": "query { echo(arg:\\"test\\") }" }]'.bytes) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + + def "batched query over HTTP POST body with variables returns data"() { + setup: + request.setContent('[{ "query": "query { echo(arg:\\"test\\") }", "variables": { "arg": "test" } }, { "query": "query { echo(arg:\\"test\\") }", "variables": { "arg": "test" } }]'.bytes) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + + def "batched query over HTTP POST body with operationName returns data"() { + setup: + request.setContent('[{ "query": "query one{ echoOne: echo(arg:\\"test-one\\") } query two{ echoTwo: echo(arg:\\"test-two\\") }", "operationName": "one" }, { "query": "query one{ echoOne: echo(arg:\\"test-one\\") } query two{ echoTwo: echo(arg:\\"test-two\\") }", "operationName": "two" }]'.bytes) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echoOne == "test-one" + getBatchedResponseContent()[0].data.echoTwo == null + getBatchedResponseContent()[1].data.echoOne == null + getBatchedResponseContent()[1].data.echoTwo == "test-two" + } + + def "batched query over HTTP POST body with empty non-null operationName returns data"() { + setup: + request.setContent('[{ "query": "query echo{ echo: echo(arg:\\"test\\") }", "operationName": "" }, { "query": "query echo{ echo: echo(arg:\\"test\\") }", "operationName": "" }]'.bytes) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + + def "batched query over HTTP POST multipart named 'graphql' returns data"() { + setup: + request.setContentType("multipart/form-data, boundary=test") + request.setMethod("POST") + + request.setContent(new TestMultipartContentBuilder() + .addPart('graphql', '[{ "query": "query { echo(arg:\\"test\\") }" }, { "query": "query { echo(arg:\\"test\\") }" }]') + .build()) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + + def "batched query over HTTP POST multipart named 'query' returns data"() { + setup: + request.setContentType("multipart/form-data, boundary=test") + request.setMethod("POST") + request.setContent(new TestMultipartContentBuilder() + .addPart('query', '[{ "query": "query { echo(arg:\\"test\\") }" }, { "query": "query { echo(arg:\\"test\\") }" }]') + .build()) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + + def "batched query over HTTP POST multipart named 'query' with operationName returns data"() { + setup: + request.setContentType("multipart/form-data, boundary=test") + request.setMethod("POST") + request.setContent(new TestMultipartContentBuilder() + .addPart('query', '[{ "query": "query one{ echoOne: echo(arg:\\"test-one\\") } query two{ echoTwo: echo(arg:\\"test-two\\") }", "operationName": "one" }, { "query": "query one{ echoOne: echo(arg:\\"test-one\\") } query two{ echoTwo: echo(arg:\\"test-two\\") }", "operationName": "two" }]') + .build()) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echoOne == "test-one" + getBatchedResponseContent()[0].data.echoTwo == null + getBatchedResponseContent()[1].data.echoOne == null + getBatchedResponseContent()[1].data.echoTwo == "test-two" + } + + def "batched query over HTTP POST multipart named 'query' with empty non-null operationName returns data"() { + setup: + request.setContentType("multipart/form-data, boundary=test") + request.setMethod("POST") + request.setContent(new TestMultipartContentBuilder() + .addPart('query', '[{ "query": "query echo{ echo: echo(arg:\\"test\\") }", "operationName": "" }, { "query": "query echo{ echo: echo(arg:\\"test\\") }", "operationName": "" }]') + .build()) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + + def "batched query over HTTP POST multipart named 'query' with variables returns data"() { + setup: + request.setContentType("multipart/form-data, boundary=test") + request.setMethod("POST") + request.setContent(new TestMultipartContentBuilder() + .addPart('query', '[{ "query": "query echo($arg: String) { echo(arg:$arg) }", "variables": { "arg": "test" } }, { "query": "query echo($arg: String) { echo(arg:$arg) }", "variables": { "arg": "test" } }]') + .build()) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + def "mutation over HTTP POST body returns data"() { setup: request.setContent(mapper.writeValueAsBytes([ @@ -354,6 +592,20 @@ class GraphQLServletSpec extends Specification { getResponseContent().data.echo == "test" } + def "batched mutation over HTTP POST body returns data"() { + setup: + request.setContent('[{ "query": "mutation { echo(arg:\\"test\\") }" }, { "query": "mutation { echo(arg:\\"test\\") }" }]'.bytes) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + getBatchedResponseContent()[0].data.echo == "test" + getBatchedResponseContent()[1].data.echo == "test" + } + def "errors before graphql schema execution return internal server error"() { setup: servlet = new SimpleGraphQLServlet(servlet.getSchemaProvider().getSchema()) { @@ -389,6 +641,24 @@ class GraphQLServletSpec extends Specification { errors.first().message.startsWith("Internal Server Error(s)") } + def "batched errors while data fetching are masked in the response"() { + setup: + servlet = createServlet({ throw new TestException() }) + request.addParameter('query', '[{ "query": "query { echo(arg:\\"test\\") }" }, { "query": "query { echo(arg:\\"test\\") }" }]') + + when: + servlet.doGet(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + def errors = getBatchedResponseContent().errors + errors[0].size() == 1 + errors[0].first().message.startsWith("Internal Server Error(s)") + errors[1].size() == 1 + errors[1].first().message.startsWith("Internal Server Error(s)") + } + def "data field is present and null if no data can be returned"() { setup: request.addParameter('query', 'query { not-a-field(arg:"test") }') @@ -405,6 +675,25 @@ class GraphQLServletSpec extends Specification { resp.errors != null } + def "batched data field is present and null if no data can be returned"() { + setup: + request.addParameter('query', '[{ "query": "query { not-a-field(arg:\\"test\\") }" }, { "query": "query { not-a-field(arg:\\"test\\") }" }]') + + when: + servlet.doGet(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_JSON_UTF8 + def resp = getBatchedResponseContent() + resp[0].containsKey("data") + resp[0].data == null + resp[0].errors != null + resp[1].containsKey("data") + resp[1].data == null + resp[1].errors != null + } + def "typeInfo is serialized correctly"() { expect: servlet.getMapper().writeValueAsString(ExecutionTypeInfo.newTypeInfo().type(new GraphQLNonNull(Scalars.GraphQLString)).build()) != "{}"