Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WebSockets Next - client: support the wss scheme correctly #42826

Merged
merged 1 commit into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import java.io.File;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
Expand Down Expand Up @@ -52,7 +53,14 @@ public class TlsClientEndpointTest {
URI uri;

@Test
void testClient() throws InterruptedException {
void testClient() throws InterruptedException, URISyntaxException {
assertClient(uri);
URI wssUri = new URI("wss", uri.getUserInfo(), uri.getHost(), uri.getPort(), uri.getPath(), uri.getQuery(),
uri.getFragment());
assertClient(wssUri);
}

void assertClient(URI uri) throws InterruptedException, URISyntaxException {
WebSocketClientConnection connection = connector
.baseUri(uri)
// The value will be encoded automatically
Expand All @@ -63,19 +71,22 @@ void testClient() throws InterruptedException {
assertEquals("Lu=", connection.pathParam("name"));
connection.sendTextAndAwait("Hi!");

assertTrue(ClientEndpoint.MESSAGE_LATCH.await(5, TimeUnit.SECONDS));
assertTrue(ClientEndpoint.messageLatch.await(5, TimeUnit.SECONDS));
assertEquals("Lu=:Hello Lu=!", ClientEndpoint.MESSAGES.get(0));
assertEquals("Lu=:Hi!", ClientEndpoint.MESSAGES.get(1));

connection.closeAndAwait();
assertTrue(ClientEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS));
assertTrue(ServerEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS));
assertTrue(ClientEndpoint.closedLatch.await(5, TimeUnit.SECONDS));
assertTrue(ServerEndpoint.closedLatch.await(5, TimeUnit.SECONDS));

ServerEndpoint.reset();
ClientEndpoint.reset();
}

@WebSocket(path = "/endpoint/{name}")
public static class ServerEndpoint {

static final CountDownLatch CLOSED_LATCH = new CountDownLatch(1);
static volatile CountDownLatch closedLatch = new CountDownLatch(1);

@OnOpen
String open(@PathParam String name) {
Expand All @@ -89,32 +100,42 @@ String echo(String message) {

@OnClose
void close() {
CLOSED_LATCH.countDown();
closedLatch.countDown();
}

static void reset() {
closedLatch = new CountDownLatch(1);
}

}

@WebSocketClient(path = "/endpoint/{name}")
public static class ClientEndpoint {

static final CountDownLatch MESSAGE_LATCH = new CountDownLatch(2);
static volatile CountDownLatch messageLatch = new CountDownLatch(2);

static final List<String> MESSAGES = new CopyOnWriteArrayList<>();

static final CountDownLatch CLOSED_LATCH = new CountDownLatch(1);
static volatile CountDownLatch closedLatch = new CountDownLatch(1);

@OnTextMessage
void onMessage(@PathParam String name, String message, WebSocketClientConnection connection) {
if (!name.equals(connection.pathParam("name"))) {
throw new IllegalArgumentException();
}
MESSAGES.add(name + ":" + message);
MESSAGE_LATCH.countDown();
messageLatch.countDown();
}

@OnClose
void close() {
CLOSED_LATCH.countDown();
closedLatch.countDown();
}

static void reset() {
MESSAGES.clear();
messageLatch = new CountDownLatch(2);
closedLatch = new CountDownLatch(1);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,18 +161,18 @@ protected WebSocketClientOptions populateClientOptions() {

protected WebSocketConnectOptions newConnectOptions(URI serverEndpointUri) {
WebSocketConnectOptions connectOptions = new WebSocketConnectOptions()
.setSsl(isHttps(serverEndpointUri))
.setSsl(isSecure(serverEndpointUri))
.setHost(serverEndpointUri.getHost());
if (serverEndpointUri.getPort() != -1) {
connectOptions.setPort(serverEndpointUri.getPort());
} else if (isHttps(serverEndpointUri)) {
// If port is undefined and https is used then use 443 by default
} else if (isSecure(serverEndpointUri)) {
// If port is undefined and https/wss is used then use 443 by default
connectOptions.setPort(443);
}
return connectOptions;
}

protected boolean isHttps(URI uri) {
return "https".equals(uri.getScheme());
protected boolean isSecure(URI uri) {
return "https".equals(uri.getScheme()) || "wss".equals(uri.getScheme());
}
}