Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(#516): add origin check to websockets #519

Merged
merged 2 commits into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public class GraphQLConfiguration {
private final ContextSetting contextSetting;
private final GraphQLResponseCacheManager responseCacheManager;
@Getter private final Executor asyncExecutor;
@Getter private final List<String> allowedOrigins;
private HttpRequestHandler requestHandler;

private GraphQLConfiguration(
Expand All @@ -49,9 +50,11 @@ private GraphQLConfiguration(
ContextSetting contextSetting,
Supplier<BatchInputPreProcessor> batchInputPreProcessor,
GraphQLResponseCacheManager responseCacheManager,
Executor asyncExecutor) {
Executor asyncExecutor,
List<String> allowedOrigins) {
this.invocationInputFactory = invocationInputFactory;
this.asyncExecutor = asyncExecutor;
this.allowedOrigins = allowedOrigins;
this.graphQLInvoker = graphQLInvoker != null ? graphQLInvoker : queryInvoker.toGraphQLInvoker();
this.objectMapper = objectMapper;
this.listeners = listeners;
Expand Down Expand Up @@ -148,6 +151,7 @@ public static class Builder {
private int asyncMaxPoolSize = 200;
private Executor asyncExecutor;
private AsyncTaskDecorator asyncTaskDecorator;
private List<String> allowedOrigins = new ArrayList<>();

private Builder(GraphQLInvocationInputFactory.Builder invocationInputFactoryBuilder) {
this.invocationInputFactoryBuilder = invocationInputFactoryBuilder;
Expand Down Expand Up @@ -249,6 +253,13 @@ public Builder with(AsyncTaskDecorator asyncTaskDecorator) {
return this;
}

public Builder allowedOrigins(List<String> allowedOrigins) {
if (allowedOrigins != null) {
this.allowedOrigins.addAll(allowedOrigins);
}
return this;
}

private Executor getAsyncExecutor() {
if (asyncExecutor != null) {
return asyncExecutor;
Expand Down Expand Up @@ -279,7 +290,8 @@ public GraphQLConfiguration build() {
contextSetting,
batchInputPreProcessorSupplier,
responseCacheManager,
getAsyncTaskExecutor());
getAsyncTaskExecutor(),
allowedOrigins);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package graphql.kickstart.servlet;

import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;

Expand Down Expand Up @@ -65,6 +66,7 @@ public class GraphQLWebsocketServlet extends Endpoint {
private final AtomicBoolean isShuttingDown = new AtomicBoolean(false);
private final AtomicBoolean isShutDown = new AtomicBoolean(false);
private final Object cacheLock = new Object();
private final List<String> allowedOrigins;

public GraphQLWebsocketServlet(GraphQLConfiguration configuration) {
this(configuration, null);
Expand All @@ -77,21 +79,23 @@ public GraphQLWebsocketServlet(
configuration.getGraphQLInvoker(),
configuration.getInvocationInputFactory(),
configuration.getObjectMapper(),
connectionListeners);
connectionListeners,
configuration.getAllowedOrigins());
}

public GraphQLWebsocketServlet(
GraphQLInvoker graphQLInvoker,
GraphQLSubscriptionInvocationInputFactory invocationInputFactory,
GraphQLObjectMapper graphQLObjectMapper) {
this(graphQLInvoker, invocationInputFactory, graphQLObjectMapper, null);
this(graphQLInvoker, invocationInputFactory, graphQLObjectMapper, null, emptyList());
}

public GraphQLWebsocketServlet(
GraphQLInvoker graphQLInvoker,
GraphQLSubscriptionInvocationInputFactory invocationInputFactory,
GraphQLObjectMapper graphQLObjectMapper,
Collection<SubscriptionConnectionListener> connectionListeners) {
Collection<SubscriptionConnectionListener> connectionListeners,
List<String> allowedOrigins) {
List<ApolloSubscriptionConnectionListener> listeners = new ArrayList<>();
if (connectionListeners != null) {
connectionListeners.stream()
Expand All @@ -114,12 +118,10 @@ public GraphQLWebsocketServlet(
Stream.of(fallbackSubscriptionProtocolFactory))
.map(SubscriptionProtocolFactory::getProtocol)
.collect(toList());
this.allowedOrigins = allowedOrigins;
}

public GraphQLWebsocketServlet(
GraphQLInvoker graphQLInvoker,
GraphQLSubscriptionInvocationInputFactory invocationInputFactory,
GraphQLObjectMapper graphQLObjectMapper,
List<SubscriptionProtocolFactory> subscriptionProtocolFactory,
SubscriptionProtocolFactory fallbackSubscriptionProtocolFactory) {

Expand All @@ -132,6 +134,8 @@ public GraphQLWebsocketServlet(
Stream.of(fallbackSubscriptionProtocolFactory))
.map(SubscriptionProtocolFactory::getProtocol)
.collect(toList());

this.allowedOrigins = emptyList();
}

@Override
Expand Down Expand Up @@ -202,6 +206,26 @@ private void closeUnexpectedly(Session session, Throwable t) {
}
}

public boolean checkOrigin(String originHeaderValue) {
if (originHeaderValue == null || originHeaderValue.isBlank()) {
return allowedOrigins.isEmpty();
}
String originToCheck = trimTrailingSlash(originHeaderValue);
if (!allowedOrigins.isEmpty()) {
if (allowedOrigins.contains("*")) {
return true;
}
return allowedOrigins.stream()
.map(this::trimTrailingSlash)
.anyMatch(originToCheck::equalsIgnoreCase);
}
return true;
}

private String trimTrailingSlash(String origin) {
return (origin.endsWith("/") ? origin.substring(0, origin.length() - 1) : origin);
}

public void modifyHandshake(
ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
sec.getUserProperties().put(HANDSHAKE_REQUEST_KEY, request);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package graphql.kickstart.servlet

import spock.lang.Specification

class GraphQLWebsocketServletSpec extends Specification {

def "checkOrigin without any allowed origins allows given origin"() {
given: "a websocket servlet with no allowed origins"
def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).build())

when: "we check origin http://localhost:8080"
def allowed = servlet.checkOrigin("http://localhost:8080")

then:
allowed
}

def "checkOrigin without any allowed origins allows when no origin given"() {
given: "a websocket servlet with no allowed origins"
def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).build())

when: "we check origin null"
def allowed = servlet.checkOrigin(null)

then:
allowed
}

def "checkOrigin without any allowed origins allows when origin is empty"() {
given: "a websocket servlet with no allowed origins"
def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).build())

when: "we check origin null"
def allowed = servlet.checkOrigin(" ")

then:
allowed
}

def "checkOrigin with allow all origins allows given origin"() {
given: "a websocket servlet with allow all origins"
def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).allowedOrigins(List.of("*")).build())

when: "we check origin http://localhost:8080"
def allowed = servlet.checkOrigin("http://localhost:8080")

then:
allowed
}

def "checkOrigin with specific allowed origins allows given origin"() {
given: "a websocket servlet with allow all origins"
def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).allowedOrigins(List.of("http://localhost:8080")).build())

when: "we check origin http://localhost:8080"
def allowed = servlet.checkOrigin("http://localhost:8080")

then:
allowed
}

def "checkOrigin with specific allowed origins allows given origin with trailing slash"() {
given: "a websocket servlet with allow all origins"
def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).allowedOrigins(List.of("http://localhost:8080")).build())

when: "we check origin http://localhost:8080/"
def allowed = servlet.checkOrigin("http://localhost:8080/")

then:
allowed
}

def "checkOrigin with specific allowed origins with trailing slash allows given origin without trailing slash"() {
given: "a websocket servlet with allow all origins"
def servlet = new GraphQLWebsocketServlet(GraphQLConfiguration.with(TestUtils.createGraphQlSchema()).allowedOrigins(List.of("http://localhost:8080/")).build())

when: "we check origin http://localhost:8080"
def allowed = servlet.checkOrigin("http://localhost:8080")

then:
allowed
}
}