diff --git a/java/org/apache/tomcat/websocket/AsyncChannelWrapperSecure.java b/java/org/apache/tomcat/websocket/AsyncChannelWrapperSecure.java index 7ad581ddd477..eb744ea7700c 100644 --- a/java/org/apache/tomcat/websocket/AsyncChannelWrapperSecure.java +++ b/java/org/apache/tomcat/websocket/AsyncChannelWrapperSecure.java @@ -42,6 +42,7 @@ import org.apache.juli.logging.Log; import org.apache.juli.logging.LogFactory; import org.apache.tomcat.util.res.StringManager; +import org.apache.tomcat.util.threads.VirtualThreadExecutor; /** * Wraps the {@link AsynchronousSocketChannel} with SSL/TLS. This needs a lot more testing before it can be considered @@ -57,14 +58,23 @@ public class AsyncChannelWrapperSecure implements AsyncChannelWrapper { private final SSLEngine sslEngine; private final ByteBuffer socketReadBuffer; private final ByteBuffer socketWriteBuffer; - // One thread for read, one for write - private final ExecutorService executor = Executors.newFixedThreadPool(2, new SecureIOThreadFactory()); + private final ExecutorService executor; private final AtomicBoolean writing = new AtomicBoolean(false); private final AtomicBoolean reading = new AtomicBoolean(false); public AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel, SSLEngine sslEngine) { + // One thread for read, one for write + this(socketChannel, sslEngine, Executors.newFixedThreadPool(2, new SecureIOThreadFactory())); + } + + public AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel, SSLEngine sslEngine, VirtualThreadExecutor executor) { + this(socketChannel, sslEngine, (ExecutorService) executor); + } + + private AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel, SSLEngine sslEngine, ExecutorService executor) { this.socketChannel = socketChannel; this.sslEngine = sslEngine; + this.executor = executor; int socketBufferSize = sslEngine.getSession().getPacketBufferSize(); socketReadBuffer = ByteBuffer.allocateDirect(socketBufferSize); @@ -142,7 +152,10 @@ public void close() { log.info(sm.getString("asyncChannelWrapperSecure.closeFail")); } } - executor.shutdownNow(); + + if (!(executor instanceof VirtualThreadExecutor)) { + executor.shutdownNow(); + } } @Override diff --git a/java/org/apache/tomcat/websocket/WsWebSocketContainer.java b/java/org/apache/tomcat/websocket/WsWebSocketContainer.java index 3827be38cc9d..ea938c3abcdf 100644 --- a/java/org/apache/tomcat/websocket/WsWebSocketContainer.java +++ b/java/org/apache/tomcat/websocket/WsWebSocketContainer.java @@ -70,6 +70,7 @@ import org.apache.tomcat.util.buf.StringUtils; import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap; import org.apache.tomcat.util.res.StringManager; +import org.apache.tomcat.util.threads.VirtualThreadExecutor; public class WsWebSocketContainer implements WebSocketContainer, BackgroundProcess { @@ -102,6 +103,8 @@ public class WsWebSocketContainer implements WebSocketContainer, BackgroundProce private InstanceManager instanceManager; + private VirtualThreadExecutor virtualThreadExecutor; + protected InstanceManager getInstanceManager(ClassLoader classLoader) { if (instanceManager != null) { return instanceManager; @@ -302,7 +305,11 @@ private Session connectToServerRecursive(ClientEndpointHolder clientEndpointHold // proxy CONNECT, need to use TLS from this point on so wrap the // original AsynchronousSocketChannel SSLEngine sslEngine = createSSLEngine(clientEndpointConfiguration, host, port); - channel = new AsyncChannelWrapperSecure(socketChannel, sslEngine); + if (useVirtualThreads()) { + channel = new AsyncChannelWrapperSecure(socketChannel, sslEngine, virtualThreadExecutor); + } else { + channel = new AsyncChannelWrapperSecure(socketChannel, sslEngine); + } } else if (channel == null) { // Only need to wrap as this point if it wasn't wrapped to process a // proxy CONNECT @@ -1010,6 +1017,10 @@ public void destroy() { } } } + + if (useVirtualThreads()) { + virtualThreadExecutor.close(); + } } @@ -1028,6 +1039,18 @@ private AsynchronousChannelGroup getAsynchronousChannelGroup() { return result; } + public void setUseVirtualThreads(boolean useVirtualThreads) { + if (useVirtualThreads) { + virtualThreadExecutor = new VirtualThreadExecutor("WebSocketClient-IO-"); + } else { + virtualThreadExecutor = null; + } + } + + public boolean useVirtualThreads() { + return virtualThreadExecutor != null; + } + // ----------------------------------------------- BackgroundProcess methods