diff --git a/activemq-amqp/src/test/java/org/apache/activemq/transport/amqp/client/AmqpClientTestSupport.java b/activemq-amqp/src/test/java/org/apache/activemq/transport/amqp/client/AmqpClientTestSupport.java index d429b27c537..997bb8ded84 100644 --- a/activemq-amqp/src/test/java/org/apache/activemq/transport/amqp/client/AmqpClientTestSupport.java +++ b/activemq-amqp/src/test/java/org/apache/activemq/transport/amqp/client/AmqpClientTestSupport.java @@ -55,7 +55,7 @@ public void tearDown() throws Exception { } public String getConnectorScheme() { - return connectorScheme; + return connectorScheme.contains("ws") ? connectorScheme.replace("amqp+", "") : connectorScheme; } public boolean isUseSSL() { diff --git a/activemq-amqp/src/test/java/org/apache/activemq/transport/amqp/interop/AmqpConfiguredMaxConnectionsTest.java b/activemq-amqp/src/test/java/org/apache/activemq/transport/amqp/interop/AmqpConfiguredMaxConnectionsTest.java index 6b97698af6a..b77e3046996 100644 --- a/activemq-amqp/src/test/java/org/apache/activemq/transport/amqp/interop/AmqpConfiguredMaxConnectionsTest.java +++ b/activemq-amqp/src/test/java/org/apache/activemq/transport/amqp/interop/AmqpConfiguredMaxConnectionsTest.java @@ -45,7 +45,9 @@ public class AmqpConfiguredMaxConnectionsTest extends AmqpClientTestSupport { public static Collection data() { return Arrays.asList(new Object[][] { {"amqp", false}, + {"amqp+ws", false}, {"amqp+nio", false}, + {"amqp+wss", true} }); } diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractMQTTSocket.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractMQTTSocket.java index a3e5bf1a3a1..c992f2f78c8 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractMQTTSocket.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractMQTTSocket.java @@ -35,7 +35,7 @@ import org.apache.activemq.util.ServiceStopper; import org.fusesource.mqtt.codec.MQTTFrame; -public abstract class AbstractMQTTSocket extends TransportSupport implements MQTTTransport, BrokerServiceAware { +public abstract class AbstractMQTTSocket extends AbstractWsSocket implements MQTTTransport, BrokerServiceAware { protected ReentrantLock protocolLock = new ReentrantLock(); protected volatile MQTTProtocolConverter protocolConverter = null; diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractStompSocket.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractStompSocket.java index dd25a1deeeb..30da1c5fd68 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractStompSocket.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractStompSocket.java @@ -38,7 +38,7 @@ /** * Base implementation of a STOMP based WebSocket handler. */ -public abstract class AbstractStompSocket extends TransportSupport implements StompTransport { +public abstract class AbstractStompSocket extends AbstractWsSocket implements StompTransport { private static final Logger LOG = LoggerFactory.getLogger(AbstractStompSocket.class); diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractWsSocket.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractWsSocket.java new file mode 100644 index 00000000000..edd2fa6d468 --- /dev/null +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractWsSocket.java @@ -0,0 +1,43 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.ws; + +import org.apache.activemq.transport.TransportSupport; +import org.eclipse.jetty.websocket.api.WebSocketListener; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public abstract class AbstractWsSocket extends TransportSupport implements WebSocketListener { + + private static final Logger LOG = LoggerFactory.getLogger(AbstractWsSocket.class); + + @Override + public void onWebSocketClose(int statusCode, String reason) { + WebSocketListener.super.onWebSocketClose(statusCode, reason); + + try { + stop(); + LOG.debug("Stopped socket: {}", getRemoteAddress()); + } catch (Exception e) { + LOG.debug("Could not stop socket to {}. This exception is ignored.", getRemoteAddress(), e); + } + + doWebSocketClose(statusCode, reason); + } + + public abstract void doWebSocketClose(int statusCode, String reason); +} diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/StompWSConnection.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/StompWSConnection.java index 8c0fccb1cc8..f3a21a67af0 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/StompWSConnection.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/StompWSConnection.java @@ -39,7 +39,9 @@ public class StompWSConnection extends WebSocketAdapter implements WebSocketList private static final Logger LOG = LoggerFactory.getLogger(StompWSConnection.class); private Session connection; + private Throwable connectionError; private final CountDownLatch connectLatch = new CountDownLatch(1); + private final CountDownLatch errorLatch = new CountDownLatch(1); private final BlockingQueue prefetch = new LinkedBlockingDeque(); private final StompWireFormat wireFormat = new StompWireFormat(); @@ -49,7 +51,7 @@ public class StompWSConnection extends WebSocketAdapter implements WebSocketList @Override public boolean isConnected() { - return connection != null ? connection.isOpen() : false; + return connection != null && connection.isOpen(); } public void close() { @@ -62,6 +64,9 @@ protected Session getConnection() { return connection; } + public Throwable getConnectionError() { + return connectionError; + } //---- Send methods ------------------------------------------------------// public synchronized void sendRawFrame(String rawFrame) throws Exception { @@ -106,6 +111,10 @@ public boolean awaitConnection(long time, TimeUnit unit) throws InterruptedExcep return connectLatch.await(time, unit); } + public boolean awaitError(long time, TimeUnit unit) throws InterruptedException { + return errorLatch.await(time, unit); + } + //----- Property Accessors -----------------------------------------------// public int getCloseCode() { @@ -148,6 +157,13 @@ public void onWebSocketConnect(org.eclipse.jetty.websocket.api.Session session) this.connectLatch.countDown(); } + @Override + public void onWebSocketError(Throwable cause) { + this.connection = null; + this.connectionError = cause; + this.errorLatch.countDown(); + } + //----- Internal implementation ------------------------------------------// private void checkConnected() throws IOException { diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/WSTransportProxy.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/WSTransportProxy.java index a8242d4eddd..d8e4016f0d3 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/WSTransportProxy.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/WSTransportProxy.java @@ -43,7 +43,7 @@ * A proxy class that manages sending WebSocket events to the wrapped protocol level * WebSocket Transport. */ -public final class WSTransportProxy extends TransportSupport implements Transport, WebSocketListener, BrokerServiceAware, WSTransportSink { +public final class WSTransportProxy extends AbstractWsSocket implements Transport, WebSocketListener, BrokerServiceAware, WSTransportSink { private static final Logger LOG = LoggerFactory.getLogger(WSTransportProxy.class); @@ -201,7 +201,7 @@ public void onWebSocketText(String data) { } @Override - public void onWebSocketClose(int statusCode, String reason) { + public void doWebSocketClose(int statusCode, String reason) { try { if (protocolLock.tryLock() || protocolLock.tryLock(ORDERLY_CLOSE_TIMEOUT, TimeUnit.SECONDS)) { LOG.debug("WebSocket closed: code[{}] message[{}]", statusCode, reason); diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/WSTransportServer.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/WSTransportServer.java index 75cf1fd56cc..430a597f5b0 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/WSTransportServer.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/WSTransportServer.java @@ -20,6 +20,7 @@ import java.net.InetSocketAddress; import java.net.URI; import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; import jakarta.servlet.Servlet; @@ -53,6 +54,11 @@ public class WSTransportServer extends WebTransportServerSupport implements Brok private BrokerService brokerService; private WSServlet servlet; + /** + * The maximum number of sockets allowed for this server + */ + protected int maximumConnections = Integer.MAX_VALUE; + public WSTransportServer(URI location) { super(location); this.bindAddress = location; @@ -122,6 +128,7 @@ private Servlet createWSServlet() throws Exception { servlet = new WSServlet(); servlet.setTransportOptions(transportOptions); servlet.setBrokerService(brokerService); + servlet.setMaximumConnections(maximumConnections); return servlet; } @@ -176,12 +183,35 @@ public void setBrokerService(BrokerService brokerService) { @Override public long getMaxConnectionExceededCount() { - // Max Connection Count not supported for ws - return -1l; + if (servlet != null) { + return servlet.getMaxConnectionExceededCount(); + } + return 0; } @Override public void resetStatistics() { - // Statistics not implemented for ws + if(servlet != null) { + servlet.resetStatistics(); + } } + + public int getMaximumConnections() { + return maximumConnections; + } + + public void setMaximumConnections(int maximumConnections) { + this.maximumConnections = maximumConnections; + if (servlet != null) { + servlet.setMaximumConnections(maximumConnections); + } + } + + public AtomicInteger getCurrentTransportCount() { + if (servlet != null) { + return servlet.getCurrentTransportCount(); + } + return new AtomicInteger(0); + } + } diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty11/MQTTSocket.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty11/MQTTSocket.java index 853c201c346..ae4f5c3ba5a 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty11/MQTTSocket.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty11/MQTTSocket.java @@ -95,7 +95,7 @@ public void onWebSocketBinary(byte[] bytes, int offset, int length) { } @Override - public void onWebSocketClose(int arg0, String arg1) { + public void doWebSocketClose(int arg0, String arg1) { try { if (protocolLock.tryLock() || protocolLock.tryLock(ORDERLY_CLOSE_TIMEOUT, TimeUnit.SECONDS)) { LOG.debug("MQTT WebSocket closed: code[{}] message[{}]", arg0, arg1); @@ -120,14 +120,6 @@ public void onWebSocketConnect(Session session) { this.session.setIdleTimeout(Duration.ZERO); } - @Override - public void onWebSocketError(Throwable arg0) { - - } - - @Override - public void onWebSocketText(String arg0) { - } private static int getDefaultSendTimeOut() { return Integer.getInteger("org.apache.activemq.transport.ws.MQTTSocket.sendTimeout", 30); diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty11/StompSocket.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty11/StompSocket.java index 5c718c8ab06..28e099d3461 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty11/StompSocket.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty11/StompSocket.java @@ -69,7 +69,7 @@ public void onWebSocketBinary(byte[] arg0, int arg1, int arg2) { } @Override - public void onWebSocketClose(int arg0, String arg1) { + public void doWebSocketClose(int arg0, String arg1) { try { if (protocolLock.tryLock() || protocolLock.tryLock(ORDERLY_CLOSE_TIMEOUT, TimeUnit.SECONDS)) { LOG.debug("Stomp WebSocket closed: code[{}] message[{}]", arg0, arg1); diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty11/WSServlet.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty11/WSServlet.java index 607248e4dee..9fbcdd9c4d4 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty11/WSServlet.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty11/WSServlet.java @@ -26,18 +26,23 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; +import org.apache.activemq.Service; import org.apache.activemq.broker.BrokerService; import org.apache.activemq.broker.BrokerServiceAware; import org.apache.activemq.transport.Transport; import org.apache.activemq.transport.TransportAcceptListener; import org.apache.activemq.transport.TransportFactory; +import org.apache.activemq.transport.tcp.ExceededMaximumConnectionsException; import org.apache.activemq.transport.util.HttpTransportUtils; import org.apache.activemq.transport.ws.WSTransportProxy; +import org.apache.activemq.util.ServiceListener; import org.eclipse.jetty.websocket.api.WebSocketListener; import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest; import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse; @@ -48,7 +53,7 @@ /** * Handle connection upgrade requests and creates web sockets */ -public class WSServlet extends JettyWebSocketServlet implements BrokerServiceAware { +public class WSServlet extends JettyWebSocketServlet implements BrokerServiceAware, ServiceListener { private static final long serialVersionUID = -4716657876092884139L; @@ -60,6 +65,10 @@ public class WSServlet extends JettyWebSocketServlet implements BrokerServiceAwa private Map transportOptions; private BrokerService brokerService; + private int maximumConnections = Integer.MAX_VALUE; + protected final AtomicLong maximumConnectionsExceededCount = new AtomicLong(); + private final AtomicInteger currentTransportCount = new AtomicInteger(); + private enum Protocol { MQTT, STOMP, UNKNOWN } @@ -93,6 +102,23 @@ public void configure(JettyWebSocketServletFactory factory) { factory.setCreator(new JettyWebSocketCreator() { @Override public Object createWebSocket(JettyServerUpgradeRequest req, JettyServerUpgradeResponse resp) { + int currentCount; + do { + currentCount = currentTransportCount.get(); + if (currentCount >= maximumConnections) { + maximumConnectionsExceededCount.incrementAndGet(); + listener.onAcceptError(new ExceededMaximumConnectionsException( + "Exceeded the maximum number of allowed client connections. See the '" + + "maximumConnections' property on the WS transport configuration URI " + + "in the ActiveMQ configuration file (e.g., activemq.xml)")); + return null; + } + + //Increment this value before configuring the transport + //This is necessary because some of the transport servers must read from the + //socket during configureTransport() so we want to make sure this value is + //accurate as the transport server could pause here waiting for data to be sent from a client + } while(!currentTransportCount.compareAndSet(currentCount, currentCount + 1)); WebSocketListener socket; Protocol requestedProtocol = Protocol.UNKNOWN; @@ -114,6 +140,7 @@ public Object createWebSocket(JettyServerUpgradeRequest req, JettyServerUpgradeR socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest())); ((MQTTSocket) socket).setTransportOptions(new HashMap<>(transportOptions)); ((MQTTSocket) socket).setPeerCertificates(req.getCertificates()); + ((MQTTSocket) socket).addServiceListener(WSServlet.this); resp.setAcceptedSubProtocol(getAcceptedSubProtocol(mqttProtocols, req.getSubProtocols(), "mqtt")); break; case UNKNOWN: @@ -124,11 +151,13 @@ public Object createWebSocket(JettyServerUpgradeRequest req, JettyServerUpgradeR case STOMP: socket = new StompSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest())); ((StompSocket) socket).setPeerCertificates(req.getCertificates()); + ((StompSocket) socket).addServiceListener(WSServlet.this); resp.setAcceptedSubProtocol(getAcceptedSubProtocol(stompProtocols, req.getSubProtocols(), "stomp")); break; default: socket = null; listener.onAcceptError(new IOException("Unknown protocol requested")); + currentTransportCount.decrementAndGet(); break; } @@ -160,6 +189,7 @@ private WebSocketListener findWSTransport(JettyServerUpgradeRequest request, Jet proxy = new WSTransportProxy(remoteAddress, transport); proxy.setPeerCertificates(request.getCertificates()); proxy.setTransportOptions(new HashMap<>(transportOptions)); + proxy.addServiceListener(this); response.setAcceptedSubProtocol(proxy.getSubProtocol()); } catch (Exception e) { @@ -217,4 +247,42 @@ public void setTransportOptions(Map transportOptions) { public void setBrokerService(BrokerService brokerService) { this.brokerService = brokerService; } + + + @Override + public void started(Service service) { + } + + @Override + public void stopped(Service service) { + this.currentTransportCount.decrementAndGet(); + } + + /** + * @return the maximumConnections + */ + public int getMaximumConnections() { + return maximumConnections; + } + + + public long getMaxConnectionExceededCount() { + return this.maximumConnectionsExceededCount.get(); + } + + public void resetStatistics() { + this.maximumConnectionsExceededCount.set(0L); + } + + /** + * @param maximumConnections + * the maximumConnections to set + */ + public void setMaximumConnections(int maximumConnections) { + this.maximumConnections = maximumConnections; + } + + public AtomicInteger getCurrentTransportCount() { + return currentTransportCount; + } } diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConfiguredMaxConnectionsTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConfiguredMaxConnectionsTest.java new file mode 100644 index 00000000000..e8fb98162f2 --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConfiguredMaxConnectionsTest.java @@ -0,0 +1,123 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.ws; + +import org.apache.activemq.util.Wait; +import org.eclipse.jetty.client.HttpClient; +import org.eclipse.jetty.client.dynamic.HttpClientTransportDynamic; +import org.eclipse.jetty.io.ClientConnector; +import org.eclipse.jetty.util.ssl.SslContextFactory; +import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; +import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.fusesource.hawtbuf.UTF8Buffer; +import org.fusesource.mqtt.codec.CONNACK; +import org.fusesource.mqtt.codec.CONNECT; +import org.fusesource.mqtt.codec.MQTTFrame; +import org.junit.Test; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class MQTTWSConfiguredMaxConnectionsTest extends WSTransportTestSupport { + + private static final String connectorScheme = "ws"; + + private static final int MAX_CONNECTIONS = 10; + + @Test(timeout = 60000) + public void testMaxConnectionsSettingIsHonored() throws Exception { + List connections = new ArrayList<>(); + + for (int i = 0; i < MAX_CONNECTIONS; i++) { + MQTTWSConnection connection = createConnection(); + if (!connection.awaitConnection(30, TimeUnit.SECONDS)) { + throw new IOException("Could not connect to STOMP WS endpoint"); + } + + connections.add(connection); + + CONNECT command = new CONNECT(); + + command.clientId(new UTF8Buffer(UUID.randomUUID().toString())); + command.cleanSession(false); + command.version(3); + command.keepAlive((short) 60); + + connection.sendFrame(command.encode()); + + MQTTFrame received = connection.receive(15, TimeUnit.SECONDS); + if (received == null || received.messageType() != CONNACK.TYPE) { + fail("Client did not get expected CONNACK"); + } + } + + assertEquals(MAX_CONNECTIONS, getProxyToBroker().getCurrentConnectionsCount()); + assertEquals(Long.valueOf(0L), Long.valueOf(getProxyToConnectionView(connectorScheme).getMaxConnectionExceededCount())); + + { + MQTTWSConnection connection = createConnection(); + if (!connection.awaitError(30, TimeUnit.SECONDS)) { + throw new IOException("WS endpoint has maximumConnections=" + MAX_CONNECTIONS); + } + + connections.add(connection); + } + + assertEquals(MAX_CONNECTIONS, getProxyToBroker().getCurrentConnectionsCount()); + assertEquals(1, getProxyToConnectionView(connectorScheme).getMaxConnectionExceededCount()); + + for (MQTTWSConnection connection : connections) { + connection.close(); + } + assertTrue("Connection should close", Wait.waitFor(() -> getProxyToBroker().getCurrentConnectionsCount() == 0)); + + // Confirm reset statistics + getProxyToConnectionView(connectorScheme).resetStatistics(); + assertEquals(0, getProxyToConnectionView(connectorScheme).getMaxConnectionExceededCount()); + } + + @Override + protected String getWSConnectorURI() { + return super.getWSConnectorURI() + "&maximumConnections=" + MAX_CONNECTIONS; + } + + private MQTTWSConnection createConnection() throws Exception { + MQTTWSConnection connection = new MQTTWSConnection(); + + ClientUpgradeRequest request = new ClientUpgradeRequest(); + request.setSubProtocols("mqttv3.1"); + + SslContextFactory.Client sslContextFactory = new SslContextFactory.Client(); + sslContextFactory.setTrustAll(true); + ClientConnector clientConnector = new ClientConnector(); + clientConnector.setSslContextFactory(sslContextFactory); + + WebSocketClient wsClient = new WebSocketClient(new HttpClient(new HttpClientTransportDynamic(clientConnector))); + wsClient.start(); + + wsClient.connect(connection, wsConnectUri, request); + + return connection; + } +} diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java index e127d077d0c..6641f7adb05 100644 --- a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java @@ -56,7 +56,9 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe private static final MQTTFrame PING_RESP_FRAME = new PINGRESP().encode(); private Session connection; + private Throwable connectionError; private final CountDownLatch connectLatch = new CountDownLatch(1); + private final CountDownLatch errorLatch = new CountDownLatch(1); private final MQTTWireFormat wireFormat = new MQTTWireFormat(); private final BlockingQueue prefetch = new LinkedBlockingDeque<>(); @@ -80,6 +82,10 @@ protected Session getConnection() { return connection; } + public Throwable getConnectionError() { + return connectionError; + } + //----- Connection and Disconnection methods -----------------------------// public void connect() throws Exception { @@ -160,6 +166,10 @@ public boolean awaitConnection(long time, TimeUnit unit) throws InterruptedExcep return connectLatch.await(time, unit); } + public boolean awaitError(long time, TimeUnit unit) throws InterruptedException { + return errorLatch.await(time, unit); + } + //----- Property Accessors -----------------------------------------------// public int getCloseCode() { @@ -286,4 +296,11 @@ public void onWebSocketConnect(org.eclipse.jetty.websocket.api.Session session) this.connection.setIdleTimeout(Duration.ZERO); this.connectLatch.countDown(); } + + @Override + public void onWebSocketError(Throwable cause) { + this.connection = null; + this.connectionError = cause; + this.errorLatch.countDown(); + } } diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSConfiguredMaxConnectionsTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSConfiguredMaxConnectionsTest.java new file mode 100644 index 00000000000..49e5e168864 --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSConfiguredMaxConnectionsTest.java @@ -0,0 +1,117 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.ws; + +import org.apache.activemq.transport.stomp.Stomp; +import org.apache.activemq.util.Wait; +import org.eclipse.jetty.client.HttpClient; +import org.eclipse.jetty.client.dynamic.HttpClientTransportDynamic; +import org.eclipse.jetty.io.ClientConnector; +import org.eclipse.jetty.util.ssl.SslContextFactory; +import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; +import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.junit.Test; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +public class StompWSConfiguredMaxConnectionsTest extends WSTransportTestSupport { + + private static final String connectorScheme = "ws"; + + private static final int MAX_CONNECTIONS = 10; + + @Test(timeout = 60000) + public void testMaxConnectionsSettingIsHonored() throws Exception { + List connections = new ArrayList<>(); + + for (int i = 0; i < MAX_CONNECTIONS; i++) { + StompWSConnection connection = createConnection(); + if (!connection.awaitConnection(30, TimeUnit.SECONDS)) { + throw new IOException("Could not connect to STOMP WS endpoint"); + } + + connections.add(connection); + + String connectFrame = "STOMP\n" + + "login:system\n" + + "passcode:manager\n" + + "accept-version:1.0,1.1\n" + + "host:localhost\n" + + "\n" + Stomp.NULL; + connection.sendRawFrame(connectFrame); + + String incoming = connection.receive(30, TimeUnit.SECONDS); + assertNotNull(incoming); + assertTrue(incoming.startsWith("CONNECTED")); + } + + assertEquals(MAX_CONNECTIONS, getProxyToBroker().getCurrentConnectionsCount()); + assertEquals(Long.valueOf(0l), Long.valueOf(getProxyToConnectionView(connectorScheme).getMaxConnectionExceededCount())); + + { + StompWSConnection connection = createConnection(); + if (!connection.awaitError(30, TimeUnit.SECONDS)) { + throw new IOException("WS endpoint has maximumConnections=" + MAX_CONNECTIONS); + } + + connections.add(connection); + } + + assertEquals(MAX_CONNECTIONS, getProxyToBroker().getCurrentConnectionsCount()); + assertEquals(1, getProxyToConnectionView(connectorScheme).getMaxConnectionExceededCount()); + + for (StompWSConnection connection : connections) { + connection.close(); + } + assertTrue("Connection should close", Wait.waitFor(() -> getProxyToBroker().getCurrentConnectionsCount() == 0)); + + // Confirm reset statistics + getProxyToConnectionView(connectorScheme).resetStatistics(); + assertEquals(0, getProxyToConnectionView(connectorScheme).getMaxConnectionExceededCount()); + } + + @Override + protected String getWSConnectorURI() { + return super.getWSConnectorURI() + "&maximumConnections=" + MAX_CONNECTIONS; + } + + private StompWSConnection createConnection() throws Exception { + StompWSConnection connection = new StompWSConnection(); + + ClientUpgradeRequest request = new ClientUpgradeRequest(); + request.setSubProtocols("v11.stomp"); + + SslContextFactory.Client sslContextFactory = new SslContextFactory.Client(); + sslContextFactory.setTrustAll(true); + ClientConnector clientConnector = new ClientConnector(); + clientConnector.setSslContextFactory(sslContextFactory); + + WebSocketClient wsClient = new WebSocketClient(new HttpClient(new HttpClientTransportDynamic(clientConnector))); + wsClient.start(); + + wsClient.connect(connection, wsConnectUri, request); + + return connection; + } +} diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/WSTransportTestSupport.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/WSTransportTestSupport.java index 8c4a2c76844..d3a292e1065 100644 --- a/activemq-http/src/test/java/org/apache/activemq/transport/ws/WSTransportTestSupport.java +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/WSTransportTestSupport.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.net.ServerSocket; import java.net.URI; +import java.util.Set; import jakarta.jms.JMSException; import javax.management.MalformedObjectNameException; @@ -27,6 +28,7 @@ import org.apache.activemq.broker.BrokerService; import org.apache.activemq.broker.jmx.BrokerViewMBean; +import org.apache.activemq.broker.jmx.ConnectorViewMBean; import org.apache.activemq.broker.jmx.QueueViewMBean; import org.apache.activemq.broker.jmx.TopicViewMBean; import org.apache.activemq.spring.SpringSslContext; @@ -173,4 +175,20 @@ protected TopicViewMBean getProxyToTopic(String name) throws MalformedObjectName .newProxyInstance(topicViewMBeanName, TopicViewMBean.class, true); return proxy; } + + + protected ConnectorViewMBean getProxyToConnectionView(String connectionType) throws Exception { + ObjectName connectorQuery = new ObjectName( + "org.apache.activemq:type=Broker,brokerName=localhost,connector=clientConnectors,connectorName="+connectionType+"_//*"); + + Set results = broker.getManagementContext().queryNames(connectorQuery, null); + + if (results == null || results.size() != 1) { + throw new Exception("Unable to find the exact Connector instance."); + } + + ConnectorViewMBean proxy = (ConnectorViewMBean) broker.getManagementContext() + .newProxyInstance(results.iterator().next(), ConnectorViewMBean.class, true); + return proxy; + } }