Skip to content

Commit

Permalink
Merge pull request #519 from graphql-java-kickstart/feature/516-webso…
Browse files Browse the repository at this point in the history
…cket-origin

fix(#516): add origin check to websockets
  • Loading branch information
oliemansm authored May 9, 2023
2 parents 6a0b786 + 060f753 commit b25975b
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public class GraphQLConfiguration {
private final ContextSetting contextSetting;
private final GraphQLResponseCacheManager responseCacheManager;
@Getter private final Executor asyncExecutor;
@Getter private final List<String> allowedOrigins;
private HttpRequestHandler requestHandler;

private GraphQLConfiguration(
Expand All @@ -49,9 +50,11 @@ private GraphQLConfiguration(
ContextSetting contextSetting,
Supplier<BatchInputPreProcessor> batchInputPreProcessor,
GraphQLResponseCacheManager responseCacheManager,
Executor asyncExecutor) {
Executor asyncExecutor,
List<String> allowedOrigins) {
this.invocationInputFactory = invocationInputFactory;
this.asyncExecutor = asyncExecutor;
this.allowedOrigins = allowedOrigins;
this.graphQLInvoker = graphQLInvoker != null ? graphQLInvoker : queryInvoker.toGraphQLInvoker();
this.objectMapper = objectMapper;
this.listeners = listeners;
Expand Down Expand Up @@ -148,6 +151,7 @@ public static class Builder {
private int asyncMaxPoolSize = 200;
private Executor asyncExecutor;
private AsyncTaskDecorator asyncTaskDecorator;
private List<String> allowedOrigins = new ArrayList<>();

private Builder(GraphQLInvocationInputFactory.Builder invocationInputFactoryBuilder) {
this.invocationInputFactoryBuilder = invocationInputFactoryBuilder;
Expand Down Expand Up @@ -249,6 +253,13 @@ public Builder with(AsyncTaskDecorator asyncTaskDecorator) {
return this;
}

public Builder allowedOrigins(List<String> allowedOrigins) {
if (allowedOrigins != null) {
this.allowedOrigins.addAll(allowedOrigins);
}
return this;
}

private Executor getAsyncExecutor() {
if (asyncExecutor != null) {
return asyncExecutor;
Expand Down Expand Up @@ -279,7 +290,8 @@ public GraphQLConfiguration build() {
contextSetting,
batchInputPreProcessorSupplier,
responseCacheManager,
getAsyncTaskExecutor());
getAsyncTaskExecutor(),
allowedOrigins);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package graphql.kickstart.servlet;

import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;

Expand Down Expand Up @@ -65,6 +66,7 @@ public class GraphQLWebsocketServlet extends Endpoint {
private final AtomicBoolean isShuttingDown = new AtomicBoolean(false);
private final AtomicBoolean isShutDown = new AtomicBoolean(false);
private final Object cacheLock = new Object();
private final List<String> allowedOrigins;

public GraphQLWebsocketServlet(GraphQLConfiguration configuration) {
this(configuration, null);
Expand All @@ -77,21 +79,23 @@ public GraphQLWebsocketServlet(
configuration.getGraphQLInvoker(),
configuration.getInvocationInputFactory(),
configuration.getObjectMapper(),
connectionListeners);
connectionListeners,
configuration.getAllowedOrigins());
}

public GraphQLWebsocketServlet(
GraphQLInvoker graphQLInvoker,
GraphQLSubscriptionInvocationInputFactory invocationInputFactory,
GraphQLObjectMapper graphQLObjectMapper) {
this(graphQLInvoker, invocationInputFactory, graphQLObjectMapper, null);
this(graphQLInvoker, invocationInputFactory, graphQLObjectMapper, null, emptyList());
}

public GraphQLWebsocketServlet(
GraphQLInvoker graphQLInvoker,
GraphQLSubscriptionInvocationInputFactory invocationInputFactory,
GraphQLObjectMapper graphQLObjectMapper,
Collection<SubscriptionConnectionListener> connectionListeners) {
Collection<SubscriptionConnectionListener> connectionListeners,
List<String> allowedOrigins) {
List<ApolloSubscriptionConnectionListener> listeners = new ArrayList<>();
if (connectionListeners != null) {
connectionListeners.stream()
Expand All @@ -114,12 +118,10 @@ public GraphQLWebsocketServlet(
Stream.of(fallbackSubscriptionProtocolFactory))
.map(SubscriptionProtocolFactory::getProtocol)
.collect(toList());
this.allowedOrigins = allowedOrigins;
}

public GraphQLWebsocketServlet(
GraphQLInvoker graphQLInvoker,
GraphQLSubscriptionInvocationInputFactory invocationInputFactory,
GraphQLObjectMapper graphQLObjectMapper,
List<SubscriptionProtocolFactory> subscriptionProtocolFactory,
SubscriptionProtocolFactory fallbackSubscriptionProtocolFactory) {

Expand All @@ -132,6 +134,8 @@ public GraphQLWebsocketServlet(
Stream.of(fallbackSubscriptionProtocolFactory))
.map(SubscriptionProtocolFactory::getProtocol)
.collect(toList());

this.allowedOrigins = emptyList();
}

@Override
Expand Down Expand Up @@ -202,6 +206,26 @@ private void closeUnexpectedly(Session session, Throwable t) {
}
}

public boolean checkOrigin(String originHeaderValue) {
if (originHeaderValue == null || originHeaderValue.isBlank()) {
return allowedOrigins.isEmpty();
}
String originToCheck = trimTrailingSlash(originHeaderValue);
if (!allowedOrigins.isEmpty()) {
if (allowedOrigins.contains("*")) {
return true;
}
return allowedOrigins.stream()
.map(this::trimTrailingSlash)
.anyMatch(originToCheck::equalsIgnoreCase);
}
return true;
}

private String trimTrailingSlash(String origin) {
return (origin.endsWith("/") ? origin.substring(0, origin.length() - 1) : origin);
}

public void modifyHandshake(
ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
sec.getUserProperties().put(HANDSHAKE_REQUEST_KEY, request);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package graphql.kickstart.servlet

import spock.lang.Specification

class GraphQLWebsocketServletSpec extends Specification {

def "checkOrigin without any allowed origins allows given origin"() {
given: "a websocket servlet with no allowed origins"
def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).build())

when: "we check origin http://localhost:8080"
def allowed = servlet.checkOrigin("http://localhost:8080")

then:
allowed
}

def "checkOrigin without any allowed origins allows when no origin given"() {
given: "a websocket servlet with no allowed origins"
def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).build())

when: "we check origin null"
def allowed = servlet.checkOrigin(null)

then:
allowed
}

def "checkOrigin without any allowed origins allows when origin is empty"() {
given: "a websocket servlet with no allowed origins"
def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).build())

when: "we check origin null"
def allowed = servlet.checkOrigin(" ")

then:
allowed
}

def "checkOrigin with allow all origins allows given origin"() {
given: "a websocket servlet with allow all origins"
def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).allowedOrigins(List.of("*")).build())

when: "we check origin http://localhost:8080"
def allowed = servlet.checkOrigin("http://localhost:8080")

then:
allowed
}

def "checkOrigin with specific allowed origins allows given origin"() {
given: "a websocket servlet with allow all origins"
def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).allowedOrigins(List.of("http://localhost:8080")).build())

when: "we check origin http://localhost:8080"
def allowed = servlet.checkOrigin("http://localhost:8080")

then:
allowed
}

def "checkOrigin with specific allowed origins allows given origin with trailing slash"() {
given: "a websocket servlet with allow all origins"
def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).allowedOrigins(List.of("http://localhost:8080")).build())

when: "we check origin http://localhost:8080/"
def allowed = servlet.checkOrigin("http://localhost:8080/")

then:
allowed
}

def "checkOrigin with specific allowed origins with trailing slash allows given origin without trailing slash"() {
given: "a websocket servlet with allow all origins"
def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).allowedOrigins(List.of("http://localhost:8080/")).build())

when: "we check origin http://localhost:8080"
def allowed = servlet.checkOrigin("http://localhost:8080")

then:
allowed
}
}

0 comments on commit b25975b

Please sign in to comment.