Skip to content

Added support for query batching #48

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 8, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

# GraphQL Servlet

This module implements a GraphQL Java Servlet. It also supports Relay.js and OSGi out of the box.
This module implements a GraphQL Java Servlet. It also supports Relay.js, Apollo and OSGi out of the box.

# Downloading

Expand Down Expand Up @@ -114,6 +114,10 @@ You **MUST** pass this execution strategy to the servlet for Relay.js support.

This is the default execution strategy for the `OsgiGraphQLServlet`, and must be added as a dependency when using that servlet.

## Apollo support

Query batching is supported, no configuration required.

## Spring Framework support

To use the servlet with Spring Framework, either use the [Spring Boot starter](https://github.com/graphql-java/graphql-spring-boot) or simply define a `ServletRegistrationBean` in a web app:
Expand Down
247 changes: 189 additions & 58 deletions src/main/java/graphql/servlet/GraphQLServlet.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package graphql.servlet;

import com.fasterxml.jackson.annotation.JacksonInject;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationContext;
Expand Down Expand Up @@ -30,13 +29,17 @@
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Writer;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -69,8 +72,8 @@ public abstract class GraphQLServlet extends HttpServlet implements Servlet, Gra
private final List<GraphQLServletListener> listeners;
private final ServletFileUpload fileUpload;

private final RequestHandler getHandler;
private final RequestHandler postHandler;
private final HttpRequestHandler getHandler;
private final HttpRequestHandler postHandler;

public GraphQLServlet() {
this(null, null, null);
Expand All @@ -84,23 +87,31 @@ public GraphQLServlet(ObjectMapperConfigurer objectMapperConfigurer, List<GraphQ
this.getHandler = (request, response) -> {
final GraphQLContext context = createContext(Optional.of(request), Optional.of(response));
final Object rootObject = createRootObject(Optional.of(request), Optional.of(response));

String path = request.getPathInfo();
if (path == null) {
path = request.getServletPath();
}
if (path.contentEquals("/schema.json")) {
query(IntrospectionQuery.INTROSPECTION_QUERY, null, new HashMap<>(), getSchemaProvider().getSchema(request), request, response, context, rootObject);
doQuery(IntrospectionQuery.INTROSPECTION_QUERY, null, new HashMap<>(), getSchemaProvider().getSchema(request), context, rootObject, request, response);
} else {
if (request.getParameter("query") != null) {
final Map<String, Object> variables = new HashMap<>();
if (request.getParameter("variables") != null) {
variables.putAll(deserializeVariables(request.getParameter("variables")));
}
String operationName = null;
if (request.getParameter("operationName") != null) {
operationName = request.getParameter("operationName");
String query = request.getParameter("query");
if (query != null) {
if (isBatchedQuery(query)) {
doBatchedQuery(getGraphQLRequestMapper().readValues(query), getSchemaProvider().getReadOnlySchema(request), context, rootObject, request, response);
} else {
final Map<String, Object> variables = new HashMap<>();
if (request.getParameter("variables") != null) {
variables.putAll(deserializeVariables(request.getParameter("variables")));
}

String operationName = null;
if (request.getParameter("operationName") != null) {
operationName = request.getParameter("operationName");
}

doQuery(query, operationName, variables, getSchemaProvider().getReadOnlySchema(request), context, rootObject, request, response);
}
query(request.getParameter("query"), operationName, variables, getSchemaProvider().getReadOnlySchema(request), request, response, context, rootObject);
} else {
response.setStatus(STATUS_BAD_REQUEST);
log.info("Bad GET request: path was not \"/schema.json\" or no query variable named \"query\" given");
Expand All @@ -111,70 +122,82 @@ public GraphQLServlet(ObjectMapperConfigurer objectMapperConfigurer, List<GraphQ
this.postHandler = (request, response) -> {
final GraphQLContext context = createContext(Optional.of(request), Optional.of(response));
final Object rootObject = createRootObject(Optional.of(request), Optional.of(response));
GraphQLRequest graphQLRequest = null;

try {
InputStream inputStream = null;

if (ServletFileUpload.isMultipartContent(request)) {
final Map<String, List<FileItem>> fileItems = fileUpload.parseParameterMap(request);
context.setFiles(Optional.of(fileItems));

if (fileItems.containsKey("graphql")) {
final Optional<FileItem> graphqlItem = getFileItem(fileItems, "graphql");
if (graphqlItem.isPresent()) {
inputStream = graphqlItem.get().getInputStream();
}
InputStream inputStream = graphqlItem.get().getInputStream();

if (!inputStream.markSupported()) {
inputStream = new BufferedInputStream(inputStream);
}

if (isBatchedQuery(inputStream)) {
doBatchedQuery(getGraphQLRequestMapper().readValues(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response);
return;
} else {
doQuery(getGraphQLRequestMapper().readValue(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response);
return;
}
}
} else if (fileItems.containsKey("query")) {
final Optional<FileItem> queryItem = getFileItem(fileItems, "query");
if (queryItem.isPresent()) {
graphQLRequest = new GraphQLRequest();
graphQLRequest.setQuery(new String(queryItem.get().get()));
InputStream inputStream = queryItem.get().getInputStream();

final Optional<FileItem> operationNameItem = getFileItem(fileItems, "operationName");
if (operationNameItem.isPresent()) {
graphQLRequest.setOperationName(new String(operationNameItem.get().get()).trim());
if (!inputStream.markSupported()) {
inputStream = new BufferedInputStream(inputStream);
}

final Optional<FileItem> variablesItem = getFileItem(fileItems, "variables");
if (variablesItem.isPresent()) {
String variables = new String(variablesItem.get().get());
if (!variables.isEmpty()) {
graphQLRequest.setVariables(deserializeVariables(variables));
if (isBatchedQuery(inputStream)) {
doBatchedQuery(getGraphQLRequestMapper().readValues(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response);
return;
} else {
String query = new String(queryItem.get().get());

Map<String, Object> variables = null;
final Optional<FileItem> variablesItem = getFileItem(fileItems, "variables");
if (variablesItem.isPresent()) {
variables = deserializeVariables(new String(variablesItem.get().get()));
}

String operationName = null;
final Optional<FileItem> operationNameItem = getFileItem(fileItems, "operationName");
if (operationNameItem.isPresent()) {
operationName = new String(operationNameItem.get().get()).trim();
}

doQuery(query, operationName, variables, getSchemaProvider().getSchema(request), context, rootObject, request, response);
return;
}
}
}

if (inputStream == null && graphQLRequest == null) {
response.setStatus(STATUS_BAD_REQUEST);
log.info("Bad POST multipart request: no part named \"graphql\" or \"query\"");
return;
}

context.setFiles(Optional.of(fileItems));

response.setStatus(STATUS_BAD_REQUEST);
log.info("Bad POST multipart request: no part named \"graphql\" or \"query\"");
} else {
// this is not a multipart request
inputStream = request.getInputStream();
}
InputStream inputStream = request.getInputStream();

if (graphQLRequest == null) {
graphQLRequest = getGraphQLRequestMapper().readValue(inputStream);
}
if (!inputStream.markSupported()) {
inputStream = new BufferedInputStream(inputStream);
}

if (isBatchedQuery(inputStream)) {
doBatchedQuery(getGraphQLRequestMapper().readValues(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response);
} else {
doQuery(getGraphQLRequestMapper().readValue(inputStream), getSchemaProvider().getSchema(request), context, rootObject, request, response);
}
}
} catch (Exception e) {
log.info("Bad POST request: parsing failed", e);
response.setStatus(STATUS_BAD_REQUEST);
return;
}

Map<String,Object> variables = graphQLRequest.getVariables();
if (variables == null) {
variables = new HashMap<>();
}

query(graphQLRequest.getQuery(), graphQLRequest.getOperationName(), variables, getSchemaProvider().getSchema(request), request, response, context, rootObject);
};
}

Expand Down Expand Up @@ -221,7 +244,7 @@ public String executeQuery(String query) {
}
}

private void doRequest(HttpServletRequest request, HttpServletResponse response, RequestHandler handler) {
private void doRequest(HttpServletRequest request, HttpServletResponse response, HttpRequestHandler handler) {

List<GraphQLServletListener.RequestCallback> requestCallbacks = runListeners(l -> l.onRequest(request, response));

Expand Down Expand Up @@ -266,14 +289,42 @@ private GraphQL newGraphQL(GraphQLSchema schema) {
.build();
}

private void query(String query, String operationName, Map<String, Object> variables, GraphQLSchema schema, HttpServletRequest req, HttpServletResponse resp, GraphQLContext context, Object rootObject) throws IOException {
private void doQuery(GraphQLRequest graphQLRequest, GraphQLSchema schema, GraphQLContext context, Object rootObject, HttpServletRequest httpReq, HttpServletResponse httpRes) throws Exception {
doQuery(graphQLRequest.getQuery(), graphQLRequest.getOperationName(), graphQLRequest.getVariables(), schema, context, rootObject, httpReq, httpRes);
}

private void doQuery(String query, String operationName, Map<String, Object> variables, GraphQLSchema schema, GraphQLContext context, Object rootObject, HttpServletRequest req, HttpServletResponse resp) throws Exception {
query(query, operationName, variables, schema, context, rootObject, (r) -> {
resp.setContentType(APPLICATION_JSON_UTF8);
resp.setStatus(r.getStatus());
resp.getWriter().write(r.getResponse());
});
}

private void doBatchedQuery(Iterator<GraphQLRequest> graphQLRequests, GraphQLSchema schema, GraphQLContext context, Object rootObject, HttpServletRequest req, HttpServletResponse resp) throws Exception {
resp.setContentType(APPLICATION_JSON_UTF8);
resp.setStatus(STATUS_OK);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A call to query below might result in an exception. Is it ok to set this here?


Writer respWriter = resp.getWriter();
respWriter.write('[');
while (graphQLRequests.hasNext()) {
GraphQLRequest graphQLRequest = graphQLRequests.next();
query(graphQLRequest.getQuery(), graphQLRequest.getOperationName(), graphQLRequest.getVariables(), schema, context, rootObject, (r) -> respWriter.write(r.getResponse()));
if (graphQLRequests.hasNext()) {
respWriter.write(',');
}
}
respWriter.write(']');
}

private void query(String query, String operationName, Map<String, Object> variables, GraphQLSchema schema, GraphQLContext context, Object rootObject, GraphQLResponseHandler responseHandler) throws Exception {
if (operationName != null && operationName.isEmpty()) {
query(query, null, variables, schema, req, resp, context, rootObject);
query(query, null, variables, schema, context, rootObject, responseHandler);
} else if (Subject.getSubject(AccessController.getContext()) == null && context.getSubject().isPresent()) {
Subject.doAs(context.getSubject().get(), (PrivilegedAction<Void>) () -> {
try {
query(query, operationName, variables, schema, req, resp, context, rootObject);
} catch (IOException e) {
query(query, operationName, variables, schema, context, rootObject, responseHandler);
} catch (Exception e) {
throw new RuntimeException(e);
}
return null;
Expand All @@ -287,9 +338,10 @@ private void query(String query, String operationName, Map<String, Object> varia

final String response = getMapper().writeValueAsString(createResultFromDataAndErrors(data, errors));

resp.setContentType(APPLICATION_JSON_UTF8);
resp.setStatus(STATUS_OK);
resp.getWriter().write(response);
GraphQLResponse graphQLResponse = new GraphQLResponse();
graphQLResponse.setStatus(STATUS_OK);
graphQLResponse.setResponse(response);
responseHandler.handle(graphQLResponse);

if(getGraphQLErrorHandler().errorsPresent(errors)) {
runCallbacks(operationCallbacks, c -> c.onError(context, operationName, query, variables, data, errors));
Expand Down Expand Up @@ -373,6 +425,51 @@ private static Map<String, Object> deserializeVariablesObject(Object variables,
}
}

private boolean isBatchedQuery(InputStream inputStream) throws IOException {
if (inputStream == null) {
return false;
}

ByteArrayOutputStream result = new ByteArrayOutputStream();
byte[] buffer = new byte[128];
int length;

inputStream.mark(0);
while ((length = inputStream.read(buffer)) != -1) {
result.write(buffer, 0, length);
String chunk = result.toString();
Boolean isArrayStart = isArrayStart(chunk);
if (isArrayStart != null) {
inputStream.reset();
return isArrayStart;
}
}

inputStream.reset();
return false;
}

private boolean isBatchedQuery(String query) {
if (query == null) {
return false;
}

Boolean isArrayStart = isArrayStart(query);
return isArrayStart != null && isArrayStart;
}

// return true if the first non whitespace character is the beginning of an array
private Boolean isArrayStart(String s) {
for (int i = 0; i < s.length(); i++) {
char ch = s.charAt(i);
if (!Character.isWhitespace(ch)) {
return ch == '[';
}
}

return null;
}

protected static class GraphQLRequest {
private String query;
@JsonDeserialize(using = GraphQLServlet.VariablesDeserializer.class)
Expand Down Expand Up @@ -404,7 +501,28 @@ public void setOperationName(String operationName) {
}
}

protected interface RequestHandler extends BiConsumer<HttpServletRequest, HttpServletResponse> {
protected static class GraphQLResponse {
private int status;
private String response;

public int getStatus() {
return status;
}

public void setStatus(int status) {
this.status = status;
}

public String getResponse() {
return response;
}

public void setResponse(String response) {
this.response = response;
}
}

protected interface HttpRequestHandler extends BiConsumer<HttpServletRequest, HttpServletResponse> {
@Override
default void accept(HttpServletRequest request, HttpServletResponse response) {
try {
Expand All @@ -416,4 +534,17 @@ default void accept(HttpServletRequest request, HttpServletResponse response) {

void handle(HttpServletRequest request, HttpServletResponse response) throws Exception;
}

protected interface GraphQLResponseHandler extends Consumer<GraphQLResponse> {
@Override
default void accept(GraphQLResponse response) {
try {
handle(response);
} catch (Exception e) {
throw new RuntimeException(e);
}
}

void handle(GraphQLResponse r) throws Exception;
}
}
Loading