Skip to content

Commit

Permalink
WebSockets Next: activate CDI request context only if needed
Browse files Browse the repository at this point in the history
- related to quarkusio#39148
  • Loading branch information
mkouba committed Oct 16, 2024
1 parent 2c9a5ac commit 744238e
Show file tree
Hide file tree
Showing 15 changed files with 362 additions and 23 deletions.
5 changes: 4 additions & 1 deletion docs/src/main/asciidoc/websockets-next-reference.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,10 @@ The session context remains active until the `@OnClose` method completes executi
In cases where a WebSocket endpoint does not declare an `@OnOpen` method, the session context is still created.
It remains active until the connection terminates, regardless of the presence of an `@OnClose` method.

Methods annotated with `@OnTextMessage,` `@OnBinaryMessage,` `@OnOpen`, and `@OnClose` also have the request scope activated for the duration of the method execution (until it produced its result).
Endpoint callbacks may also have the request context activated for the duration of the method execution (until it produced its result).
By default, the request context is only activated if needed, i.e. if there is a request scoped bean , or a bean annotated with a security annotation (such as `@RolesAllowed`) in the dependency tree of the endpoint.
However, it is possible to set the `quarkus.websockets-next.server.activate-request-context` config property to `always`.
In this case, the request context is always activated when an endpoint callback is invoked.

[[callback-methods]]
=== Callback methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
Expand Down Expand Up @@ -51,6 +54,8 @@
import io.quarkus.arc.deployment.ValidationPhaseBuildItem.ValidationErrorBuildItem;
import io.quarkus.arc.processor.Annotations;
import io.quarkus.arc.processor.BeanInfo;
import io.quarkus.arc.processor.BeanResolver;
import io.quarkus.arc.processor.BuiltinBean;
import io.quarkus.arc.processor.BuiltinScope;
import io.quarkus.arc.processor.DotNames;
import io.quarkus.arc.processor.InjectionPointInfo;
Expand Down Expand Up @@ -95,6 +100,7 @@
import io.quarkus.websockets.next.WebSocketConnection;
import io.quarkus.websockets.next.WebSocketException;
import io.quarkus.websockets.next.WebSocketServerException;
import io.quarkus.websockets.next.WebSocketsServerBuildConfig;
import io.quarkus.websockets.next.deployment.Callback.MessageType;
import io.quarkus.websockets.next.deployment.Callback.Target;
import io.quarkus.websockets.next.runtime.BasicWebSocketConnectorImpl;
Expand Down Expand Up @@ -443,19 +449,85 @@ public String apply(String name) {
@Consume(SyntheticBeansRuntimeInitBuildItem.class) // SecurityHttpUpgradeCheck is runtime init due to runtime config
@Record(RUNTIME_INIT)
@BuildStep
public void registerRoutes(WebSocketServerRecorder recorder, List<GeneratedEndpointBuildItem> generatedEndpoints,
BuildProducer<RouteBuildItem> routes) {
public void registerRoutes(WebSocketServerRecorder recorder, List<WebSocketEndpointBuildItem> endpoints,
List<GeneratedEndpointBuildItem> generatedEndpoints, WebSocketsServerBuildConfig config,
ValidationPhaseBuildItem validationPhase, BuildProducer<RouteBuildItem> routes) {
for (GeneratedEndpointBuildItem endpoint : generatedEndpoints.stream().filter(GeneratedEndpointBuildItem::isServer)
.toList()) {
RouteBuildItem.Builder builder = RouteBuildItem.builder()
.route(endpoint.path)
.displayOnNotFoundPage("WebSocket Endpoint")
.handlerType(HandlerType.NORMAL)
.handler(recorder.createEndpointHandler(endpoint.generatedClassName, endpoint.endpointId));
.handler(recorder.createEndpointHandler(endpoint.generatedClassName, endpoint.endpointId,
activateRequestContext(config, endpoint.endpointId, endpoints, validationPhase.getBeanResolver())));
routes.produce(builder.build());
}
}

private boolean activateRequestContext(WebSocketsServerBuildConfig config, String endpointId,
List<WebSocketEndpointBuildItem> endpoints, BeanResolver beanResolver) {
return switch (config.activateRequestContext()) {
case ALWAYS -> true;
case AUTO -> needsRequestContext(findEndpoint(endpointId, endpoints).bean, new HashSet<>(), beanResolver);
default -> throw new IllegalArgumentException("Unexpected value: " + config.activateRequestContext());
};
}

private WebSocketEndpointBuildItem findEndpoint(String endpointId, List<WebSocketEndpointBuildItem> endpoints) {
for (WebSocketEndpointBuildItem endpoint : endpoints) {
if (endpoint.id.equals(endpointId)) {
return endpoint;
}
}
throw new IllegalArgumentException("Endpoint not found: " + endpointId);
}

private boolean needsRequestContext(BeanInfo bean, Set<String> processedBeans, BeanResolver beanResolver) {
if (processedBeans.add(bean.getIdentifier())) {
if (BuiltinScope.REQUEST.is(bean.getScope())
|| (bean.isClassBean()
&& bean.hasAroundInvokeInterceptors()
&& SecurityTransformerUtils.hasSecurityAnnotation(bean.getTarget().get().asClass()))) {
// Bean is:
// 1. Request scoped, or
// 2. Is class-based, has an aroundInvoke interceptor associated and is annotated with a security annotation
return true;
}
for (InjectionPointInfo injectionPoint : bean.getAllInjectionPoints()) {
BeanInfo dependency = injectionPoint.getResolvedBean();
if (dependency != null) {
if (needsRequestContext(dependency, processedBeans, beanResolver)) {
return true;
}
} else {
Type requiredType = null;
Set<AnnotationInstance> qualifiers = null;
if (BuiltinBean.INSTANCE.matches(injectionPoint)) {
requiredType = injectionPoint.getRequiredType();
qualifiers = injectionPoint.getRequiredQualifiers();
} else if (BuiltinBean.LIST.matches(injectionPoint)) {
requiredType = injectionPoint.getRequiredType().asParameterizedType().arguments().get(0);
qualifiers = new HashSet<>(injectionPoint.getRequiredQualifiers());
for (Iterator<AnnotationInstance> it = qualifiers.iterator(); it.hasNext();) {
if (it.next().name().equals(DotNames.ALL)) {
it.remove();
}
}
}
if (requiredType != null) {
// For programmatic lookup and @All List<> we need to resolve the beans manually
for (BeanInfo lookupDependency : beanResolver.resolveBeans(requiredType, qualifiers)) {
if (needsRequestContext(lookupDependency, processedBeans, beanResolver)) {
return true;
}
}
}
}
}
}
return false;
}

@BuildStep
UnremovableBeanBuildItem makeHttpUpgradeChecksUnremovable() {
// we access the checks programmatically
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ public static class MyEndpoint {
@OnTextMessage
public String onMessage(String message) {
assertNotNull(Arc.container().getActiveContext(SessionScoped.class));
assertNotNull(Arc.container().getActiveContext(RequestScoped.class));
// By default, the request context is only activated if needed
assertNull(Arc.container().getActiveContext(RequestScoped.class));
assertNotNull(connection.id());
return message.toUpperCase();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package io.quarkus.websockets.next.test.requestcontext;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.net.URI;

import jakarta.enterprise.inject.Instance;
import jakarta.inject.Inject;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.arc.Arc;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.test.utils.WSClient;
import io.vertx.core.Vertx;

public class RequestContextActivatedByInstanceTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(Endpoint.class, WSClient.class, RequestScopedBean.class);
});

@Inject
Vertx vertx;

@TestHTTPResource("end")
URI endUri;

@Test
void testRequestContext() throws InterruptedException {
try (WSClient client = WSClient.create(vertx).connect(endUri)) {
client.sendAndAwait("ping");
client.waitForMessages(1);
assertEquals("pong:true", client.getLastMessage().toString());
}
}

@WebSocket(path = "/end")
public static class Endpoint {

@Inject
Instance<RequestScopedBean> instance;

@OnTextMessage
String process(String message) {
return "pong:" + Arc.container().requestContext().isActive();
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package io.quarkus.websockets.next.test.requestcontext;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.net.URI;
import java.util.List;

import jakarta.inject.Inject;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.arc.All;
import io.quarkus.arc.Arc;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.test.utils.WSClient;
import io.vertx.core.Vertx;

public class RequestContextActivatedByListTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(Endpoint.class, WSClient.class, RequestScopedBean.class);
});

@Inject
Vertx vertx;

@TestHTTPResource("end")
URI endUri;

@Test
void testRequestContext() throws InterruptedException {
try (WSClient client = WSClient.create(vertx).connect(endUri)) {
client.sendAndAwait("ping");
client.waitForMessages(1);
assertEquals("pong:true", client.getLastMessage().toString());
}
}

@WebSocket(path = "/end")
public static class Endpoint {

@All
List<RequestScopedBean> list;

@OnTextMessage
String process(String message) {
return "pong:" + Arc.container().requestContext().isActive();
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package io.quarkus.websockets.next.test.requestcontext;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.net.URI;

import jakarta.inject.Inject;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.arc.Arc;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.test.utils.WSClient;
import io.vertx.core.Vertx;

public class RequestContextAlwaysActiveTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(Endpoint.class, WSClient.class);
})
.overrideConfigKey("quarkus.websockets-next.server.activate-request-context", "always");

@Inject
Vertx vertx;

@TestHTTPResource("end")
URI endUri;

@Test
void testRequestContext() throws InterruptedException {
try (WSClient client = WSClient.create(vertx).connect(endUri)) {
client.sendAndAwait("ping");
client.waitForMessages(1);
assertEquals("pong:true", client.getLastMessage().toString());
}
}

@WebSocket(path = "/end")
public static class Endpoint {

@OnTextMessage
String process(String message) {
return "pong:" + Arc.container().requestContext().isActive();
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package io.quarkus.websockets.next.test.requestcontext;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.net.URI;

import jakarta.inject.Inject;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.arc.Arc;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.test.utils.WSClient;
import io.vertx.core.Vertx;

public class RequestContextNotActiveTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(Endpoint.class, WSClient.class);
});

@Inject
Vertx vertx;

@TestHTTPResource("end")
URI endUri;

@Test
void testRequestContext() throws InterruptedException {
try (WSClient client = WSClient.create(vertx).connect(endUri)) {
client.sendAndAwait("ping");
client.waitForMessages(1);
assertEquals("pong:false", client.getLastMessage().toString());
}
}

@WebSocket(path = "/end")
public static class Endpoint {

@OnTextMessage
String process(String message) {
return "pong:" + Arc.container().requestContext().isActive();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@
import io.quarkus.websockets.next.OnOpen;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.test.security.EagerSecurityTest.Endpoint;
import io.quarkus.websockets.next.test.utils.WSClient;

public class LazySecurityTest extends SecurityTestBase {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest()
.withApplicationRoot((jar) -> jar
.withApplicationRoot(root -> root
.addAsResource(new StringAsset("quarkus.http.auth.proactive=false\n" +
"quarkus.http.auth.permission.secured.paths=/end\n" +
"quarkus.http.auth.permission.secured.policy=authenticated\n"), "application.properties")
Expand Down
Loading

0 comments on commit 744238e

Please sign in to comment.