diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/EchoMessageError.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/EchoMessageError.java new file mode 100644 index 0000000000000..3d52df32d1473 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/EchoMessageError.java @@ -0,0 +1,23 @@ +package io.quarkus.websockets.next.test.errors; + +import java.util.concurrent.CountDownLatch; + +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; + +@WebSocket(path = "/echo") +public class EchoMessageError { + + static final CountDownLatch MESSAGE_FAILURE_CALLED = new CountDownLatch(1); + + @OnTextMessage + String echo(String message) { + if ("foo".equals(message)) { + MESSAGE_FAILURE_CALLED.countDown(); + throw new IllegalStateException("I cannot do it!"); + } else { + return message; + } + } + +} \ No newline at end of file diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/EchoOpenError.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/EchoOpenError.java new file mode 100644 index 0000000000000..7a079a0eb45c2 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/EchoOpenError.java @@ -0,0 +1,25 @@ +package io.quarkus.websockets.next.test.errors; + +import java.util.concurrent.CountDownLatch; + +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; + +@WebSocket(path = "/echo") +public class EchoOpenError { + + static final CountDownLatch OPEN_CALLED = new CountDownLatch(1); + + @OnOpen + void open() { + OPEN_CALLED.countDown(); + throw new IllegalStateException("I cannot do it!"); + } + + @OnTextMessage + String echo(String message) { + return message; + } + +} \ No newline at end of file diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/UnhandledMessageFailureDefaultStrategyTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/UnhandledMessageFailureDefaultStrategyTest.java new file mode 100644 index 0000000000000..1207e6689277a --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/UnhandledMessageFailureDefaultStrategyTest.java @@ -0,0 +1,46 @@ +package io.quarkus.websockets.next.test.errors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.time.Duration; +import java.util.concurrent.TimeUnit; + +import jakarta.inject.Inject; + +import org.awaitility.Awaitility; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; + +public class UnhandledMessageFailureDefaultStrategyTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(EchoMessageError.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI testUri; + + @Test + void testError() throws InterruptedException { + try (WSClient client = WSClient.create(vertx).connect(testUri)) { + client.sendAndAwait("foo"); + assertTrue(EchoMessageError.MESSAGE_FAILURE_CALLED.await(5, TimeUnit.SECONDS)); + Awaitility.await().atMost(Duration.ofSeconds(5)).until(() -> client.isClosed()); + assertEquals(WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code(), client.closeStatusCode()); + } + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/UnhandledMessageFailureLogStrategyTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/UnhandledMessageFailureLogStrategyTest.java new file mode 100644 index 0000000000000..0061937345fcf --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/UnhandledMessageFailureLogStrategyTest.java @@ -0,0 +1,44 @@ +package io.quarkus.websockets.next.test.errors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.concurrent.TimeUnit; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; + +public class UnhandledMessageFailureLogStrategyTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(EchoMessageError.class, WSClient.class); + }).overrideConfigKey("quarkus.websockets-next.server.unhandled-failure-strategy", "log"); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI testUri; + + @Test + void testErrorDoesNotCloseConnection() throws InterruptedException { + try (WSClient client = WSClient.create(vertx).connect(testUri)) { + client.sendAndAwait("foo"); + assertTrue(EchoMessageError.MESSAGE_FAILURE_CALLED.await(5, TimeUnit.SECONDS)); + client.sendAndAwait("bar"); + client.waitForMessages(1); + assertEquals("bar", client.getLastMessage().toString()); + } + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/UnhandledOpenFailureDefaultStrategyTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/UnhandledOpenFailureDefaultStrategyTest.java new file mode 100644 index 0000000000000..61c712d005d86 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/UnhandledOpenFailureDefaultStrategyTest.java @@ -0,0 +1,45 @@ +package io.quarkus.websockets.next.test.errors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.time.Duration; +import java.util.concurrent.TimeUnit; + +import jakarta.inject.Inject; + +import org.awaitility.Awaitility; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; + +public class UnhandledOpenFailureDefaultStrategyTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(EchoOpenError.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI testUri; + + @Test + void testError() throws InterruptedException { + try (WSClient client = WSClient.create(vertx).connect(testUri)) { + assertTrue(EchoOpenError.OPEN_CALLED.await(5, TimeUnit.SECONDS)); + Awaitility.await().atMost(Duration.ofSeconds(5)).until(() -> client.isClosed()); + assertEquals(WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code(), client.closeStatusCode()); + } + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/UnhandledOpenFailureLogStrategyTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/UnhandledOpenFailureLogStrategyTest.java new file mode 100644 index 0000000000000..b704e8c551cde --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/errors/UnhandledOpenFailureLogStrategyTest.java @@ -0,0 +1,43 @@ +package io.quarkus.websockets.next.test.errors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.concurrent.TimeUnit; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; + +public class UnhandledOpenFailureLogStrategyTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(EchoOpenError.class, WSClient.class); + }).overrideConfigKey("quarkus.websockets-next.server.unhandled-failure-strategy", "log"); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI testUri; + + @Test + void testErrorDoesNotCloseConnection() throws InterruptedException { + try (WSClient client = WSClient.create(vertx).connect(testUri)) { + assertTrue(EchoOpenError.OPEN_CALLED.await(5, TimeUnit.SECONDS)); + client.sendAndAwait("foo"); + client.waitForMessages(1); + assertEquals("foo", client.getLastMessage().toString()); + } + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java index 773b9ab8d134f..955eb9c1b315c 100644 --- a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java @@ -126,6 +126,10 @@ public boolean isClosed() { return socket.get().isClosed(); } + public int closeStatusCode() { + return socket.get().closeStatusCode(); + } + @Override public void close() { disconnect(); diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/CloseReason.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/CloseReason.java index 55e100a9b9e7d..108c2d150b55b 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/CloseReason.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/CloseReason.java @@ -15,6 +15,8 @@ public class CloseReason { public static final CloseReason NORMAL = new CloseReason(WebSocketCloseStatus.NORMAL_CLOSURE.code()); + public static final CloseReason INTERNAL_SERVER_ERROR = new CloseReason(WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code()); + private final int code; private final String message; diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/UnhandledFailureStrategy.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/UnhandledFailureStrategy.java new file mode 100644 index 0000000000000..bdfb1f17ad2be --- /dev/null +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/UnhandledFailureStrategy.java @@ -0,0 +1,20 @@ +package io.quarkus.websockets.next; + +/** + * The strategy used when an error occurs but no error handler can handle the failure. + */ +public enum UnhandledFailureStrategy { + /** + * Close the connection. + */ + CLOSE, + /** + * Log an error message. + */ + LOG, + /** + * No operation. + */ + NOOP; + +} \ No newline at end of file diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsClientRuntimeConfig.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsClientRuntimeConfig.java index dff4780aa45c7..ecaf0bb169d0d 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsClientRuntimeConfig.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsClientRuntimeConfig.java @@ -40,4 +40,12 @@ public interface WebSocketsClientRuntimeConfig { */ Optional autoPingInterval(); + /** + * The strategy used when an error occurs but no error handler can handle the failure. + *

+ * By default, the connection is closed when an unhandled failure occurs. + */ + @WithDefault("close") + UnhandledFailureStrategy unhandledFailureStrategy(); + } diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerRuntimeConfig.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerRuntimeConfig.java index 28e9d284c2fce..43beffda35600 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerRuntimeConfig.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerRuntimeConfig.java @@ -46,4 +46,12 @@ public interface WebSocketsServerRuntimeConfig { */ Optional autoPingInterval(); + /** + * The strategy used when an error occurs but no error handler can handle the failure. + *

+ * By default, the connection is closed when an unhandled failure occurs. + */ + @WithDefault("close") + UnhandledFailureStrategy unhandledFailureStrategy(); + } diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java index e8ed61d23620c..ce4d2c096628d 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java @@ -13,6 +13,8 @@ import io.quarkus.security.AuthenticationFailedException; import io.quarkus.security.ForbiddenException; import io.quarkus.security.UnauthorizedException; +import io.quarkus.websockets.next.CloseReason; +import io.quarkus.websockets.next.UnhandledFailureStrategy; import io.quarkus.websockets.next.WebSocketException; import io.quarkus.websockets.next.runtime.WebSocketSessionContext.SessionContextState; import io.smallrye.mutiny.Multi; @@ -29,7 +31,7 @@ class Endpoints { static void initialize(Vertx vertx, ArcContainer container, Codecs codecs, WebSocketConnectionBase connection, WebSocketBase ws, String generatedEndpointClass, Optional autoPingInterval, - SecuritySupport securitySupport, Runnable onClose) { + SecuritySupport securitySupport, UnhandledFailureStrategy unhandledFailureStrategy, Runnable onClose) { Context context = vertx.getOrCreateContext(); @@ -75,7 +77,7 @@ public void handle(Void event) { LOG.debugf("@OnTextMessage callback consuming Multi completed: %s", connection); } else { - logFailure(r.cause(), + handleFailure(unhandledFailureStrategy, r.cause(), "Unable to complete @OnTextMessage callback consuming Multi", connection); } @@ -93,7 +95,7 @@ public void handle(Void event) { LOG.debugf("@OnBinaryMessage callback consuming Multi completed: %s", connection); } else { - logFailure(r.cause(), + handleFailure(unhandledFailureStrategy, r.cause(), "Unable to complete @OnBinaryMessage callback consuming Multi", connection); } @@ -102,7 +104,7 @@ public void handle(Void event) { }); } } else { - logFailure(r.cause(), "Unable to complete @OnOpen callback", connection); + handleFailure(unhandledFailureStrategy, r.cause(), "Unable to complete @OnOpen callback", connection); } }); } @@ -115,7 +117,8 @@ public void handle(Void event) { if (r.succeeded()) { LOG.debugf("@OnTextMessage callback consumed text message: %s", connection); } else { - logFailure(r.cause(), "Unable to consume text message in @OnTextMessage callback", + handleFailure(unhandledFailureStrategy, r.cause(), + "Unable to consume text message in @OnTextMessage callback", connection); } }); @@ -130,7 +133,8 @@ public void handle(Void event) { } catch (Throwable throwable) { endpoint.doOnError(throwable).subscribe().with( v -> LOG.debugf("Text message >> Multi: %s", connection), - t -> LOG.errorf(t, "Unable to send text message to Multi: %s", connection)); + t -> handleFailure(unhandledFailureStrategy, t, "Unable to send text message to Multi", + connection)); } finally { contextSupport.end(false); } @@ -144,7 +148,8 @@ public void handle(Void event) { if (r.succeeded()) { LOG.debugf("@OnBinaryMessage callback consumed binary message: %s", connection); } else { - logFailure(r.cause(), "Unable to consume binary message in @OnBinaryMessage callback", + handleFailure(unhandledFailureStrategy, r.cause(), + "Unable to consume binary message in @OnBinaryMessage callback", connection); } }); @@ -159,7 +164,8 @@ public void handle(Void event) { } catch (Throwable throwable) { endpoint.doOnError(throwable).subscribe().with( v -> LOG.debugf("Binary message >> Multi: %s", connection), - t -> LOG.errorf(t, "Unable to send binary message to Multi: %s", connection)); + t -> handleFailure(unhandledFailureStrategy, t, "Unable to send binary message to Multi", + connection)); } finally { contextSupport.end(false); } @@ -171,7 +177,8 @@ public void handle(Void event) { if (r.succeeded()) { LOG.debugf("@OnPongMessage callback consumed text message: %s", connection); } else { - logFailure(r.cause(), "Unable to consume text message in @OnPongMessage callback", connection); + handleFailure(unhandledFailureStrategy, r.cause(), + "Unable to consume text message in @OnPongMessage callback", connection); } }); }); @@ -198,7 +205,8 @@ public void handle(Void event) { if (r.succeeded()) { LOG.debugf("@OnClose callback completed: %s", connection); } else { - logFailure(r.cause(), "Unable to complete @OnClose callback", connection); + handleFailure(unhandledFailureStrategy, r.cause(), "Unable to complete @OnClose callback", + connection); } onClose.run(); if (timerId != null) { @@ -218,14 +226,30 @@ public void handle(Throwable t) { public void handle(Void event) { endpoint.doOnError(t).subscribe().with( v -> LOG.debugf("Error [%s] processed: %s", t.getClass(), connection), - t -> LOG.errorf(t, "Unhandled error occurred: %s", t.toString(), - connection)); + t -> handleFailure(unhandledFailureStrategy, t, "Unhandled error occurred", connection)); } }); } }); } + private static void handleFailure(UnhandledFailureStrategy strategy, Throwable cause, String message, + WebSocketConnectionBase connection) { + switch (strategy) { + case CLOSE -> closeConnection(cause, connection); + case LOG -> logFailure(cause, message, connection); + case NOOP -> LOG.tracef("Unhandled failure ignored: %s", connection); + default -> throw new IllegalArgumentException("Unexpected strategy: " + strategy); + } + } + + private static void closeConnection(Throwable cause, WebSocketConnectionBase connection) { + connection.close(CloseReason.INTERNAL_SERVER_ERROR).subscribe().with( + v -> LOG.debugf("Connection closed due to unhandled failure %s: %s", cause, connection), + t -> LOG.errorf("Unable to close connection [%s] due to unhandled failure [%s]: %s", connection.id(), cause, + t)); + } + private static void logFailure(Throwable throwable, String message, WebSocketConnectionBase connection) { if (isWebSocketIsClosedFailure(throwable, connection)) { LOG.debugf(throwable, diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java index d6281e5da71f4..8b8781ccac2ed 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java @@ -116,6 +116,7 @@ public Uni connect() { Endpoints.initialize(vertx, Arc.container(), codecs, connection, ws, clientEndpoint.generatedEndpointClass, config.autoPingInterval(), SecuritySupport.NOOP, + config.unhandledFailureStrategy(), () -> { connectionManager.remove(clientEndpoint.generatedEndpointClass, connection); client.close(); diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java index 9384f8d60fc47..35bdae2ca2206 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java @@ -102,7 +102,7 @@ public void handle(RoutingContext ctx) { LOG.debugf("Connection created: %s", connection); Endpoints.initialize(vertx, container, codecs, connection, ws, generatedEndpointClass, - config.autoPingInterval(), securitySupport, + config.autoPingInterval(), securitySupport, config.unhandledFailureStrategy(), () -> connectionManager.remove(generatedEndpointClass, connection)); }); }