diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/TlsClientEndpointTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/TlsClientEndpointTest.java index 8f5af7885b530..ea69019783846 100644 --- a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/TlsClientEndpointTest.java +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/TlsClientEndpointTest.java @@ -5,6 +5,7 @@ import java.io.File; import java.net.URI; +import java.net.URISyntaxException; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; @@ -52,7 +53,14 @@ public class TlsClientEndpointTest { URI uri; @Test - void testClient() throws InterruptedException { + void testClient() throws InterruptedException, URISyntaxException { + assertClient(uri); + URI wssUri = new URI("wss", uri.getUserInfo(), uri.getHost(), uri.getPort(), uri.getPath(), uri.getQuery(), + uri.getFragment()); + assertClient(wssUri); + } + + void assertClient(URI uri) throws InterruptedException, URISyntaxException { WebSocketClientConnection connection = connector .baseUri(uri) // The value will be encoded automatically @@ -63,19 +71,22 @@ void testClient() throws InterruptedException { assertEquals("Lu=", connection.pathParam("name")); connection.sendTextAndAwait("Hi!"); - assertTrue(ClientEndpoint.MESSAGE_LATCH.await(5, TimeUnit.SECONDS)); + assertTrue(ClientEndpoint.messageLatch.await(5, TimeUnit.SECONDS)); assertEquals("Lu=:Hello Lu=!", ClientEndpoint.MESSAGES.get(0)); assertEquals("Lu=:Hi!", ClientEndpoint.MESSAGES.get(1)); connection.closeAndAwait(); - assertTrue(ClientEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS)); - assertTrue(ServerEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS)); + assertTrue(ClientEndpoint.closedLatch.await(5, TimeUnit.SECONDS)); + assertTrue(ServerEndpoint.closedLatch.await(5, TimeUnit.SECONDS)); + + ServerEndpoint.reset(); + ClientEndpoint.reset(); } @WebSocket(path = "/endpoint/{name}") public static class ServerEndpoint { - static final CountDownLatch CLOSED_LATCH = new CountDownLatch(1); + static volatile CountDownLatch closedLatch = new CountDownLatch(1); @OnOpen String open(@PathParam String name) { @@ -89,7 +100,11 @@ String echo(String message) { @OnClose void close() { - CLOSED_LATCH.countDown(); + closedLatch.countDown(); + } + + static void reset() { + closedLatch = new CountDownLatch(1); } } @@ -97,11 +112,11 @@ void close() { @WebSocketClient(path = "/endpoint/{name}") public static class ClientEndpoint { - static final CountDownLatch MESSAGE_LATCH = new CountDownLatch(2); + static volatile CountDownLatch messageLatch = new CountDownLatch(2); static final List MESSAGES = new CopyOnWriteArrayList<>(); - static final CountDownLatch CLOSED_LATCH = new CountDownLatch(1); + static volatile CountDownLatch closedLatch = new CountDownLatch(1); @OnTextMessage void onMessage(@PathParam String name, String message, WebSocketClientConnection connection) { @@ -109,12 +124,18 @@ void onMessage(@PathParam String name, String message, WebSocketClientConnection throw new IllegalArgumentException(); } MESSAGES.add(name + ":" + message); - MESSAGE_LATCH.countDown(); + messageLatch.countDown(); } @OnClose void close() { - CLOSED_LATCH.countDown(); + closedLatch.countDown(); + } + + static void reset() { + MESSAGES.clear(); + messageLatch = new CountDownLatch(2); + closedLatch = new CountDownLatch(1); } } diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorBase.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorBase.java index 1a878b6b6cb18..dcc142d80aea5 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorBase.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorBase.java @@ -161,18 +161,18 @@ protected WebSocketClientOptions populateClientOptions() { protected WebSocketConnectOptions newConnectOptions(URI serverEndpointUri) { WebSocketConnectOptions connectOptions = new WebSocketConnectOptions() - .setSsl(isHttps(serverEndpointUri)) + .setSsl(isSecure(serverEndpointUri)) .setHost(serverEndpointUri.getHost()); if (serverEndpointUri.getPort() != -1) { connectOptions.setPort(serverEndpointUri.getPort()); - } else if (isHttps(serverEndpointUri)) { - // If port is undefined and https is used then use 443 by default + } else if (isSecure(serverEndpointUri)) { + // If port is undefined and https/wss is used then use 443 by default connectOptions.setPort(443); } return connectOptions; } - protected boolean isHttps(URI uri) { - return "https".equals(uri.getScheme()); + protected boolean isSecure(URI uri) { + return "https".equals(uri.getScheme()) || "wss".equals(uri.getScheme()); } }