diff --git a/docs/src/main/asciidoc/websockets-next-reference.adoc b/docs/src/main/asciidoc/websockets-next-reference.adoc index 36d7a9b98f3f2c..fb0f9f36955c70 100644 --- a/docs/src/main/asciidoc/websockets-next-reference.adoc +++ b/docs/src/main/asciidoc/websockets-next-reference.adoc @@ -182,7 +182,10 @@ The session context remains active until the `@OnClose` method completes executi In cases where a WebSocket endpoint does not declare an `@OnOpen` method, the session context is still created. It remains active until the connection terminates, regardless of the presence of an `@OnClose` method. -Methods annotated with `@OnTextMessage,` `@OnBinaryMessage,` `@OnOpen`, and `@OnClose` also have the request scope activated for the duration of the method execution (until it produced its result). +Endpoint callbacks may also have the request context activated for the duration of the method execution (until it produced its result). +By default, the request context is only activated if needed, i.e. if there is a request scoped bean , or a bean annotated with a security annotation (such as `@RolesAllowed`) in the dependency tree of the endpoint. +However, it is possible to set the `quarkus.websockets-next.server.activate-request-context` config property to `always`. +In this case, the request context is always activated when an endpoint callback is invoked. [[callback-methods]] === Callback methods diff --git a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java index ec6054118bc561..bb2f699544ce44 100644 --- a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java +++ b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java @@ -6,9 +6,12 @@ import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; @@ -51,6 +54,8 @@ import io.quarkus.arc.deployment.ValidationPhaseBuildItem.ValidationErrorBuildItem; import io.quarkus.arc.processor.Annotations; import io.quarkus.arc.processor.BeanInfo; +import io.quarkus.arc.processor.BeanResolver; +import io.quarkus.arc.processor.BuiltinBean; import io.quarkus.arc.processor.BuiltinScope; import io.quarkus.arc.processor.DotNames; import io.quarkus.arc.processor.InjectionPointInfo; @@ -95,6 +100,7 @@ import io.quarkus.websockets.next.WebSocketConnection; import io.quarkus.websockets.next.WebSocketException; import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketsServerBuildConfig; import io.quarkus.websockets.next.deployment.Callback.MessageType; import io.quarkus.websockets.next.deployment.Callback.Target; import io.quarkus.websockets.next.runtime.BasicWebSocketConnectorImpl; @@ -443,19 +449,85 @@ public String apply(String name) { @Consume(SyntheticBeansRuntimeInitBuildItem.class) // SecurityHttpUpgradeCheck is runtime init due to runtime config @Record(RUNTIME_INIT) @BuildStep - public void registerRoutes(WebSocketServerRecorder recorder, List generatedEndpoints, - BuildProducer routes) { + public void registerRoutes(WebSocketServerRecorder recorder, List endpoints, + List generatedEndpoints, WebSocketsServerBuildConfig config, + ValidationPhaseBuildItem validationPhase, BuildProducer routes) { for (GeneratedEndpointBuildItem endpoint : generatedEndpoints.stream().filter(GeneratedEndpointBuildItem::isServer) .toList()) { RouteBuildItem.Builder builder = RouteBuildItem.builder() .route(endpoint.path) .displayOnNotFoundPage("WebSocket Endpoint") .handlerType(HandlerType.NORMAL) - .handler(recorder.createEndpointHandler(endpoint.generatedClassName, endpoint.endpointId)); + .handler(recorder.createEndpointHandler(endpoint.generatedClassName, endpoint.endpointId, + activateRequestContext(config, endpoint.endpointId, endpoints, validationPhase.getBeanResolver()))); routes.produce(builder.build()); } } + private boolean activateRequestContext(WebSocketsServerBuildConfig config, String endpointId, + List endpoints, BeanResolver beanResolver) { + return switch (config.activateRequestContext()) { + case ALWAYS -> true; + case AUTO -> needsRequestContext(findEndpoint(endpointId, endpoints).bean, new HashSet<>(), beanResolver); + default -> throw new IllegalArgumentException("Unexpected value: " + config.activateRequestContext()); + }; + } + + private WebSocketEndpointBuildItem findEndpoint(String endpointId, List endpoints) { + for (WebSocketEndpointBuildItem endpoint : endpoints) { + if (endpoint.id.equals(endpointId)) { + return endpoint; + } + } + throw new IllegalArgumentException("Endpoint not found: " + endpointId); + } + + private boolean needsRequestContext(BeanInfo bean, Set processedBeans, BeanResolver beanResolver) { + if (processedBeans.add(bean.getIdentifier())) { + if (BuiltinScope.REQUEST.is(bean.getScope()) + || (bean.isClassBean() + && bean.hasAroundInvokeInterceptors() + && SecurityTransformerUtils.hasSecurityAnnotation(bean.getTarget().get().asClass()))) { + // Bean is: + // 1. Request scoped, or + // 2. Is class-based, has an aroundInvoke interceptor associated and is annotated with a security annotation + return true; + } + for (InjectionPointInfo injectionPoint : bean.getAllInjectionPoints()) { + BeanInfo dependency = injectionPoint.getResolvedBean(); + if (dependency != null) { + if (needsRequestContext(dependency, processedBeans, beanResolver)) { + return true; + } + } else { + Type requiredType = null; + Set qualifiers = null; + if (BuiltinBean.INSTANCE.matches(injectionPoint)) { + requiredType = injectionPoint.getRequiredType(); + qualifiers = injectionPoint.getRequiredQualifiers(); + } else if (BuiltinBean.LIST.matches(injectionPoint)) { + requiredType = injectionPoint.getRequiredType().asParameterizedType().arguments().get(0); + qualifiers = new HashSet<>(injectionPoint.getRequiredQualifiers()); + for (Iterator it = qualifiers.iterator(); it.hasNext();) { + if (it.next().name().equals(DotNames.ALL)) { + it.remove(); + } + } + } + if (requiredType != null) { + // For programmatic lookup and @All List<> we need to resolve the beans manually + for (BeanInfo lookupDependency : beanResolver.resolveBeans(requiredType, qualifiers)) { + if (needsRequestContext(lookupDependency, processedBeans, beanResolver)) { + return true; + } + } + } + } + } + } + return false; + } + @BuildStep UnremovableBeanBuildItem makeHttpUpgradeChecksUnremovable() { // we access the checks programmatically diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/connection/ServiceConnectionScopeTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/connection/ServiceConnectionScopeTest.java index 7db59696cc863b..897e8720eb0844 100644 --- a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/connection/ServiceConnectionScopeTest.java +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/connection/ServiceConnectionScopeTest.java @@ -64,7 +64,8 @@ public static class MyEndpoint { @OnTextMessage public String onMessage(String message) { assertNotNull(Arc.container().getActiveContext(SessionScoped.class)); - assertNotNull(Arc.container().getActiveContext(RequestScoped.class)); + // By default, the request context is only activated if needed + assertNull(Arc.container().getActiveContext(RequestScoped.class)); assertNotNull(connection.id()); return message.toUpperCase(); } diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestContextActivatedByInstanceTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestContextActivatedByInstanceTest.java new file mode 100644 index 00000000000000..bf77ea6beb6a50 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestContextActivatedByInstanceTest.java @@ -0,0 +1,56 @@ +package io.quarkus.websockets.next.test.requestcontext; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; + +import jakarta.enterprise.inject.Instance; +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.arc.Arc; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; + +public class RequestContextActivatedByInstanceTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Endpoint.class, WSClient.class, RequestScopedBean.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("end") + URI endUri; + + @Test + void testRequestContext() throws InterruptedException { + try (WSClient client = WSClient.create(vertx).connect(endUri)) { + client.sendAndAwait("ping"); + client.waitForMessages(1); + assertEquals("pong:true", client.getLastMessage().toString()); + } + } + + @WebSocket(path = "/end") + public static class Endpoint { + + @Inject + Instance instance; + + @OnTextMessage + String process(String message) { + return "pong:" + Arc.container().requestContext().isActive(); + } + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestContextActivatedByListTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestContextActivatedByListTest.java new file mode 100644 index 00000000000000..dc43444e6942d5 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestContextActivatedByListTest.java @@ -0,0 +1,57 @@ +package io.quarkus.websockets.next.test.requestcontext; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; +import java.util.List; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.arc.All; +import io.quarkus.arc.Arc; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; + +public class RequestContextActivatedByListTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Endpoint.class, WSClient.class, RequestScopedBean.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("end") + URI endUri; + + @Test + void testRequestContext() throws InterruptedException { + try (WSClient client = WSClient.create(vertx).connect(endUri)) { + client.sendAndAwait("ping"); + client.waitForMessages(1); + assertEquals("pong:true", client.getLastMessage().toString()); + } + } + + @WebSocket(path = "/end") + public static class Endpoint { + + @All + List list; + + @OnTextMessage + String process(String message) { + return "pong:" + Arc.container().requestContext().isActive(); + } + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestContextAlwaysActiveTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestContextAlwaysActiveTest.java new file mode 100644 index 00000000000000..f14744ebe47fed --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestContextAlwaysActiveTest.java @@ -0,0 +1,53 @@ +package io.quarkus.websockets.next.test.requestcontext; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.arc.Arc; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; + +public class RequestContextAlwaysActiveTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Endpoint.class, WSClient.class); + }) + .overrideConfigKey("quarkus.websockets-next.server.activate-request-context", "always"); + + @Inject + Vertx vertx; + + @TestHTTPResource("end") + URI endUri; + + @Test + void testRequestContext() throws InterruptedException { + try (WSClient client = WSClient.create(vertx).connect(endUri)) { + client.sendAndAwait("ping"); + client.waitForMessages(1); + assertEquals("pong:true", client.getLastMessage().toString()); + } + } + + @WebSocket(path = "/end") + public static class Endpoint { + + @OnTextMessage + String process(String message) { + return "pong:" + Arc.container().requestContext().isActive(); + } + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestContextNotActiveTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestContextNotActiveTest.java new file mode 100644 index 00000000000000..6f4d35fddf66b2 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestContextNotActiveTest.java @@ -0,0 +1,52 @@ +package io.quarkus.websockets.next.test.requestcontext; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.arc.Arc; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; + +public class RequestContextNotActiveTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Endpoint.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("end") + URI endUri; + + @Test + void testRequestContext() throws InterruptedException { + try (WSClient client = WSClient.create(vertx).connect(endUri)) { + client.sendAndAwait("ping"); + client.waitForMessages(1); + assertEquals("pong:false", client.getLastMessage().toString()); + } + } + + @WebSocket(path = "/end") + public static class Endpoint { + + @OnTextMessage + String process(String message) { + return "pong:" + Arc.container().requestContext().isActive(); + } + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/LazySecurityTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/LazySecurityTest.java index 7d21f28dbc2c55..5d61c001248224 100644 --- a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/LazySecurityTest.java +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/LazySecurityTest.java @@ -16,14 +16,13 @@ import io.quarkus.websockets.next.OnOpen; import io.quarkus.websockets.next.OnTextMessage; import io.quarkus.websockets.next.WebSocket; -import io.quarkus.websockets.next.test.security.EagerSecurityTest.Endpoint; import io.quarkus.websockets.next.test.utils.WSClient; public class LazySecurityTest extends SecurityTestBase { @RegisterExtension static final QuarkusUnitTest config = new QuarkusUnitTest() - .withApplicationRoot((jar) -> jar + .withApplicationRoot(root -> root .addAsResource(new StringAsset("quarkus.http.auth.proactive=false\n" + "quarkus.http.auth.permission.secured.paths=/end\n" + "quarkus.http.auth.permission.secured.policy=authenticated\n"), "application.properties") diff --git a/extensions/websockets-next/deployment/src/test/java21/io/quarkus/websockets/next/test/virtualthreads/RunOnVirtualThreadTest.java b/extensions/websockets-next/deployment/src/test/java21/io/quarkus/websockets/next/test/virtualthreads/RunOnVirtualThreadTest.java index 9de1896a6d1645..0c767e18834cd4 100644 --- a/extensions/websockets-next/deployment/src/test/java21/io/quarkus/websockets/next/test/virtualthreads/RunOnVirtualThreadTest.java +++ b/extensions/websockets-next/deployment/src/test/java21/io/quarkus/websockets/next/test/virtualthreads/RunOnVirtualThreadTest.java @@ -18,6 +18,7 @@ import io.quarkus.websockets.next.test.utils.WSClient; import io.smallrye.common.annotation.RunOnVirtualThread; import io.vertx.core.Vertx; +import jakarta.enterprise.context.RequestScoped; import jakarta.inject.Inject; public class RunOnVirtualThreadTest { @@ -25,7 +26,7 @@ public class RunOnVirtualThreadTest { @RegisterExtension public static final QuarkusUnitTest test = new QuarkusUnitTest() .withApplicationRoot(root -> { - root.addClasses(Endpoint.class, WSClient.class) + root.addClasses(Endpoint.class, WSClient.class, RequestScopedBean.class) .addAsResource(new StringAsset( "quarkus.virtual-threads.name-prefix=wsnext-virtual-thread-"), "application.properties"); @@ -54,6 +55,9 @@ void testVirtualThreads() { @WebSocket(path = "/end") public static class Endpoint { + @Inject + RequestScopedBean bean; + @RunOnVirtualThread @OnTextMessage String text(String ignored) { @@ -67,5 +71,10 @@ String error(Throwable t) { } } + + @RequestScoped + public static class RequestScopedBean { + + } } diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerBuildConfig.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerBuildConfig.java new file mode 100644 index 00000000000000..94860bcd0c18f1 --- /dev/null +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerBuildConfig.java @@ -0,0 +1,31 @@ +package io.quarkus.websockets.next; + +import io.quarkus.runtime.annotations.ConfigPhase; +import io.quarkus.runtime.annotations.ConfigRoot; +import io.smallrye.config.ConfigMapping; +import io.smallrye.config.WithDefault; + +@ConfigMapping(prefix = "quarkus.websockets-next.server") +@ConfigRoot(phase = ConfigPhase.BUILD_AND_RUN_TIME_FIXED) +public interface WebSocketsServerBuildConfig { + + /** + * Specifies whether to activate the CDI request context when an endpoint callback is invoked. By default, the request + * context is only activated if needed. + */ + @WithDefault("auto") + RequestContextActivation activateRequestContext(); + + enum RequestContextActivation { + /** + * The request context is only activated if needed, i.e. if there is a request scoped bean , or a bean annotated + * with a security annotation (such as {@code @RolesAllowed}) in the dependency tree of the endpoint. + */ + AUTO, + /** + * The request context is always activated. + */ + ALWAYS + } + +} diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java index b36d4dc834b3e0..7b4a605d8ddc1b 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java @@ -36,7 +36,9 @@ void start() { void start(ContextState requestContextState) { LOG.debugf("Start contexts: %s", connection); startSession(); - requestContext.activate(requestContextState); + if (requestContext != null) { + requestContext.activate(requestContextState); + } } void startSession() { @@ -51,10 +53,12 @@ void end(boolean terminateSession) { void end(boolean terminateRequest, boolean terminateSession) { LOG.debugf("End contexts: %s [terminateRequest: %s, terminateSession: %s]", connection, terminateRequest, terminateSession); - if (terminateRequest) { - requestContext.terminate(); - } else { - requestContext.deactivate(); + if (requestContext != null) { + if (terminateRequest) { + requestContext.terminate(); + } else { + requestContext.deactivate(); + } } if (terminateSession) { // OnClose - terminate the session context @@ -68,10 +72,6 @@ void endSession() { sessionContext.terminate(); } - ContextState currentRequestContextState() { - return requestContext.getStateIfActive(); - } - static Context createNewDuplicatedContext(Context context, WebSocketConnectionBase connection) { Context duplicated = VertxContext.createNewDuplicatedContext(context); VertxContextSafetyToggle.setContextSafe(duplicated, true); diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java index 587dfe047bba18..7ccc97539e7f28 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java @@ -34,7 +34,7 @@ class Endpoints { static void initialize(Vertx vertx, ArcContainer container, Codecs codecs, WebSocketConnectionBase connection, WebSocketBase ws, String generatedEndpointClass, Optional autoPingInterval, SecuritySupport securitySupport, UnhandledFailureStrategy unhandledFailureStrategy, TrafficLogger trafficLogger, - Runnable onClose) { + Runnable onClose, boolean activateRequestContext) { Context context = vertx.getOrCreateContext(); @@ -44,7 +44,7 @@ static void initialize(Vertx vertx, ArcContainer container, Codecs codecs, WebSo SessionContextState sessionContextState = sessionContext.initializeContextState(); ContextSupport contextSupport = new ContextSupport(connection, sessionContextState, sessionContext(container), - container.requestContext()); + activateRequestContext ? container.requestContext() : null); // Create an endpoint that delegates callbacks to the endpoint bean WebSocketEndpoint endpoint = createEndpoint(generatedEndpointClass, context, connection, codecs, contextSupport, diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/SecuritySupport.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/SecuritySupport.java index eeb5f5a5ad342c..0d7f0c6b9e0d12 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/SecuritySupport.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/SecuritySupport.java @@ -7,6 +7,8 @@ import org.jboss.logging.Logger; +import io.quarkus.arc.Arc; +import io.quarkus.arc.ManagedContext; import io.quarkus.security.identity.CurrentIdentityAssociation; import io.quarkus.security.identity.SecurityIdentity; import io.quarkus.websockets.next.CloseReason; @@ -20,6 +22,7 @@ public class SecuritySupport { private final Instance currentIdentity; private final SecurityIdentity identity; private final Runnable onClose; + private final ManagedContext requestContext; SecuritySupport(Instance currentIdentity, SecurityIdentity identity, Vertx vertx, WebSocketConnectionImpl connection) { @@ -31,13 +34,15 @@ public class SecuritySupport { this.identity = null; this.onClose = null; } + this.requestContext = Arc.container().requestContext(); } /** * This method is called before an endpoint callback is invoked. */ void start() { - if (currentIdentity != null) { + if (currentIdentity != null && requestContext.isActive()) { + // If the request context is active then set the current identity CurrentIdentityAssociation current = currentIdentity.get(); current.setIdentity(identity); } diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java index 686f132c71038d..05b41ce6a336c9 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java @@ -106,7 +106,7 @@ public Uni connect() { () -> { connectionManager.remove(clientEndpoint.generatedEndpointClass, connection); client.close(); - }); + }, true); return connection; }); diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java index ff5030af7ee24a..077dca8885fee8 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java @@ -60,7 +60,8 @@ public Object get() { }; } - public Handler createEndpointHandler(String generatedEndpointClass, String endpointId) { + public Handler createEndpointHandler(String generatedEndpointClass, String endpointId, + boolean activateRequestContext) { ArcContainer container = Arc.container(); ConnectionManager connectionManager = container.instance(ConnectionManager.class).get(); Codecs codecs = container.instance(Codecs.class).get(); @@ -107,7 +108,7 @@ private void httpUpgrade(RoutingContext ctx) { Endpoints.initialize(vertx, container, codecs, connection, ws, generatedEndpointClass, config.autoPingInterval(), securitySupport, config.unhandledFailureStrategy(), trafficLogger, - () -> connectionManager.remove(generatedEndpointClass, connection)); + () -> connectionManager.remove(generatedEndpointClass, connection), activateRequestContext); }); }