diff --git a/src/main/java/com/hivemq/bootstrap/ClientConnection.java b/src/main/java/com/hivemq/bootstrap/ClientConnection.java index c5a94cec4..db5b59a09 100644 --- a/src/main/java/com/hivemq/bootstrap/ClientConnection.java +++ b/src/main/java/com/hivemq/bootstrap/ClientConnection.java @@ -31,6 +31,7 @@ import com.hivemq.mqtt.handler.publish.PublishFlushHandler; import com.hivemq.mqtt.message.ProtocolVersion; import com.hivemq.mqtt.message.connect.CONNECT; +import com.hivemq.mqtt.message.connect.MqttWillPublish; import com.hivemq.mqtt.message.mqtt5.Mqtt5UserProperties; import com.hivemq.mqtt.message.pool.FreePacketIdRanges; import com.hivemq.security.auth.SslClientCertificate; @@ -60,7 +61,7 @@ public class ClientConnection implements ClientConnectionContext { private @NotNull String clientId; private boolean cleanStart; private @Nullable ModifiableDefaultPermissions authPermissions; - private @Nullable CONNECT connectMessage; + private @Nullable MqttWillPublish willPublish; private @Nullable AtomicInteger inFlightMessageCount; private @Nullable Integer clientReceiveMaximum; private @Nullable Integer connectKeepAlive; @@ -131,7 +132,7 @@ public class ClientConnection implements ClientConnectionContext { context.cleanStart, context.authPermissions, context.connectedListener, - context.connectMessage, + context.willPublish, context.clientReceiveMaximum, context.connectKeepAlive, context.queueSizeMaximum, @@ -178,7 +179,7 @@ public ClientConnection( final boolean cleanStart, final @Nullable ModifiableDefaultPermissions authPermissions, final @NotNull Listener connectedListener, - final @Nullable CONNECT connectMessage, + final @Nullable MqttWillPublish mqttWillPublish, final @Nullable Integer clientReceiveMaximum, final @Nullable Integer connectKeepAlive, final @Nullable Long queueSizeMaximum, @@ -219,7 +220,7 @@ public ClientConnection( this.cleanStart = cleanStart; this.authPermissions = authPermissions; this.connectedListener = connectedListener; - this.connectMessage = connectMessage; + this.willPublish = mqttWillPublish; this.clientReceiveMaximum = clientReceiveMaximum; this.connectKeepAlive = connectKeepAlive; this.queueSizeMaximum = queueSizeMaximum; @@ -327,13 +328,13 @@ public void setAuthPermissions(final @NotNull ModifiableDefaultPermissions authP return connectedListener; } - public @Nullable CONNECT getConnectMessage() { - return connectMessage; + public @Nullable MqttWillPublish getWillPublish() { + return willPublish; } @Override - public void setConnectMessage(final @Nullable CONNECT connectMessage) { - this.connectMessage = connectMessage; + public void setWillPublish(final @Nullable MqttWillPublish willPublish) { + this.willPublish = willPublish; } /** diff --git a/src/main/java/com/hivemq/bootstrap/ClientConnectionContext.java b/src/main/java/com/hivemq/bootstrap/ClientConnectionContext.java index 7ee017f90..a75f58e8e 100644 --- a/src/main/java/com/hivemq/bootstrap/ClientConnectionContext.java +++ b/src/main/java/com/hivemq/bootstrap/ClientConnectionContext.java @@ -29,6 +29,7 @@ import com.hivemq.extensions.events.client.parameters.ClientEventListeners; import com.hivemq.mqtt.message.ProtocolVersion; import com.hivemq.mqtt.message.connect.CONNECT; +import com.hivemq.mqtt.message.connect.MqttWillPublish; import com.hivemq.mqtt.message.mqtt5.Mqtt5UserProperties; import com.hivemq.security.auth.SslClientCertificate; import io.netty.channel.Channel; @@ -136,7 +137,7 @@ public interface ClientConnectionContext { void setRequestProblemInformation(boolean problemInformationRequested); - void setConnectMessage(@Nullable CONNECT msg); + void setWillPublish(@Nullable MqttWillPublish willPublish); @NotNull String @Nullable [] getTopicAliasMapping(); diff --git a/src/main/java/com/hivemq/bootstrap/UndefinedClientConnection.java b/src/main/java/com/hivemq/bootstrap/UndefinedClientConnection.java index f93e55da8..c5debf035 100644 --- a/src/main/java/com/hivemq/bootstrap/UndefinedClientConnection.java +++ b/src/main/java/com/hivemq/bootstrap/UndefinedClientConnection.java @@ -30,6 +30,7 @@ import com.hivemq.mqtt.handler.publish.PublishFlushHandler; import com.hivemq.mqtt.message.ProtocolVersion; import com.hivemq.mqtt.message.connect.CONNECT; +import com.hivemq.mqtt.message.connect.MqttWillPublish; import com.hivemq.mqtt.message.mqtt5.Mqtt5UserProperties; import com.hivemq.security.auth.SslClientCertificate; import io.netty.channel.Channel; @@ -53,7 +54,7 @@ public class UndefinedClientConnection implements ClientConnectionContext { @Nullable String clientId; boolean cleanStart; @Nullable ModifiableDefaultPermissions authPermissions; - @Nullable CONNECT connectMessage; + @Nullable MqttWillPublish willPublish; @Nullable Integer clientReceiveMaximum; @Nullable Integer connectKeepAlive; @Nullable Long queueSizeMaximum; @@ -165,8 +166,8 @@ public void setAuthPermissions(final @NotNull ModifiableDefaultPermissions authP } @Override - public void setConnectMessage(final @Nullable CONNECT connectMessage) { - this.connectMessage = connectMessage; + public void setWillPublish(final @Nullable MqttWillPublish willPublish) { + this.willPublish = willPublish; } @Override diff --git a/src/main/java/com/hivemq/extensions/handler/PluginInitializerHandler.java b/src/main/java/com/hivemq/extensions/handler/PluginInitializerHandler.java index 898a64195..a95884e1d 100644 --- a/src/main/java/com/hivemq/extensions/handler/PluginInitializerHandler.java +++ b/src/main/java/com/hivemq/extensions/handler/PluginInitializerHandler.java @@ -37,7 +37,6 @@ import com.hivemq.mqtt.handler.connack.MqttConnacker; import com.hivemq.mqtt.handler.publish.DefaultPermissionsEvaluator; import com.hivemq.mqtt.message.connack.CONNACK; -import com.hivemq.mqtt.message.connect.CONNECT; import com.hivemq.mqtt.message.connect.MqttWillPublish; import com.hivemq.mqtt.message.mqtt5.Mqtt5UserProperties; import com.hivemq.mqtt.message.reason.Mqtt5ConnAckReasonCode; @@ -124,9 +123,9 @@ private void fireInitialize( if (pluginInitializerMap.isEmpty() && msg != null) { clientConnection.setPreventLwt(false); ctx.writeAndFlush(msg, promise); - // Prevent leaking the retained CONNECT message for any existing ClientConnection. - // The CONNECT message would otherwise be owned by the plugin initialization below outside this scope. - clientConnection.setConnectMessage(null); + // Prevent leaking the retained WILL message for any existing ClientConnection. + // The WILL message would otherwise be owned by the plugin initialization below outside this scope. + clientConnection.setWillPublish(null); return; } @@ -175,14 +174,14 @@ private void fireInitialize( @Override public void onSuccess(@Nullable final Void result) { authenticateWill(ctx, msg, promise); - clientConnection.setConnectMessage(null); + clientConnection.setWillPublish(null); } @Override public void onFailure(final @NotNull Throwable t) { Exceptions.rethrowError(t); log.error("Calling initializer failed", t); - clientConnection.setConnectMessage(null); + clientConnection.setWillPublish(null); ctx.writeAndFlush(msg, promise); } }, ctx.executor()); @@ -195,13 +194,12 @@ private void authenticateWill( final ClientConnection clientConnection = ClientConnection.of(ctx.channel()); - final CONNECT connect = clientConnection.getConnectMessage(); - if (connect == null || connect.getWillPublish() == null) { + final MqttWillPublish willPublish = clientConnection.getWillPublish(); + if (willPublish == null) { ctx.writeAndFlush(msg, promise); return; } - final MqttWillPublish willPublish = connect.getWillPublish(); final ModifiableDefaultPermissions permissions = clientConnection.getAuthPermissions(); if (DefaultPermissionsEvaluator.checkWillPublish(permissions, willPublish)) { clientConnection.setPreventLwt(false); //clear prevent flag, Will is authorized @@ -213,7 +211,7 @@ private void authenticateWill( clientConnection.setPreventLwt(true); //We have already added the will to the session, so we need to remove it again final ListenableFuture removeWillFuture = - clientSessionPersistence.deleteWill(connect.getClientIdentifier()); + clientSessionPersistence.deleteWill(clientConnection.getClientId()); Futures.addCallback(removeWillFuture, new FutureCallback<>() { @Override public void onSuccess(@Nullable final Void result) { diff --git a/src/main/java/com/hivemq/mqtt/handler/connect/ConnectHandler.java b/src/main/java/com/hivemq/mqtt/handler/connect/ConnectHandler.java index 4f46cb468..0fb5ab37a 100644 --- a/src/main/java/com/hivemq/mqtt/handler/connect/ConnectHandler.java +++ b/src/main/java/com/hivemq/mqtt/handler/connect/ConnectHandler.java @@ -606,8 +606,8 @@ private void sendConnackSuccess( final ChannelFuture connackSent; - // We retain the CONNECT message in memory during the initialization progress, e.g. for plugin initialization. - clientConnection.setConnectMessage(msg); + // We retain the WILL message in memory during the initialization progress, e.g. for plugin initialization. + clientConnection.setWillPublish(msg.getWillPublish()); if (msg.getProtocolVersion() == ProtocolVersion.MQTTv5) { final CONNACK connack = buildMqtt5Connack(clientConnection, msg, sessionPresent); diff --git a/src/test/java/com/hivemq/extensions/handler/PluginInitializerHandlerTest.java b/src/test/java/com/hivemq/extensions/handler/PluginInitializerHandlerTest.java index a6b8ad696..34c68606c 100644 --- a/src/test/java/com/hivemq/extensions/handler/PluginInitializerHandlerTest.java +++ b/src/test/java/com/hivemq/extensions/handler/PluginInitializerHandlerTest.java @@ -41,7 +41,6 @@ import com.hivemq.mqtt.message.ProtocolVersion; import com.hivemq.mqtt.message.QoS; import com.hivemq.mqtt.message.connack.CONNACK; -import com.hivemq.mqtt.message.connect.CONNECT; import com.hivemq.mqtt.message.connect.MqttWillPublish; import com.hivemq.mqtt.message.mqtt5.Mqtt5UserProperties; import com.hivemq.mqtt.message.reason.Mqtt5ConnAckReasonCode; @@ -112,7 +111,7 @@ public void setUp() throws Exception { channel = new EmbeddedChannel(); clientConnection = new DummyClientConnection(channel, publishFlushHandler); - clientConnection.setConnectMessage(mock(CONNECT.class)); + clientConnection.setWillPublish(mock(MqttWillPublish.class)); clientConnection.setClientId("test_client"); clientConnection.setProtocolVersion(ProtocolVersion.MQTTv5); @@ -163,7 +162,7 @@ public void test_write_connack_no_initializer() throws Exception { verify(channelHandlerContext).writeAndFlush(any(Object.class), eq(channelPromise)); assertFalse(ClientConnection.of(channel).isPreventLwt()); - assertNull(clientConnection.getConnectMessage()); + assertNull(clientConnection.getWillPublish()); } @Test(timeout = 10000) @@ -193,7 +192,7 @@ public void test_write_connack_fire_initialize() throws Exception { verify(initializers, timeout(5000).times(1)).getClientInitializerMap(); verify(channelHandlerContext, timeout(5000)).writeAndFlush(any(Object.class), eq(channelPromise)); verify(channelPipeline).remove(any(ChannelHandler.class)); - assertNull(clientConnection.getConnectMessage()); + assertNull(clientConnection.getWillPublish()); } @Test(timeout = 10000) @@ -215,10 +214,7 @@ public void test_write_will_publish_not_authorized() throws Exception { .withPayload(new byte[]{1, 2, 3}) .build(); - final CONNECT connect = - new CONNECT.Mqtt5Builder().withClientIdentifier("test-client").withWillPublish(willPublish).build(); - - ClientConnection.of(channel).setConnectMessage(connect); + ClientConnection.of(channel).setWillPublish(willPublish); final ModifiableDefaultPermissionsImpl permissions = new ModifiableDefaultPermissionsImpl(); permissions.add(new TopicPermissionBuilderImpl(new TestConfigurationBootstrap().getFullConfigurationService()).topicFilter( @@ -238,7 +234,7 @@ public void test_write_will_publish_not_authorized() throws Exception { verify(channelPipeline).remove(any(ChannelHandler.class)); assertTrue(ClientConnection.of(channel).isPreventLwt()); - assertNull(clientConnection.getConnectMessage()); + assertNull(clientConnection.getWillPublish()); } @Test(timeout = 10000) @@ -251,10 +247,7 @@ public void test_write_will_publish_authorized() throws Exception { .withPayload(new byte[]{1, 2, 3}) .build(); - final CONNECT connect = - new CONNECT.Mqtt5Builder().withClientIdentifier("test-client").withWillPublish(willPublish).build(); - - ClientConnection.of(channel).setConnectMessage(connect); + ClientConnection.of(channel).setWillPublish(willPublish); final ModifiableDefaultPermissionsImpl permissions = new ModifiableDefaultPermissionsImpl(); permissions.add(new TopicPermissionBuilderImpl(new TestConfigurationBootstrap().getFullConfigurationService()).topicFilter( @@ -271,7 +264,7 @@ public void test_write_will_publish_authorized() throws Exception { verify(channelPipeline).remove(any(ChannelHandler.class)); assertFalse(ClientConnection.of(channel).isPreventLwt()); - assertNull(clientConnection.getConnectMessage()); + assertNull(clientConnection.getWillPublish()); } private Map createClientInitializerMap() throws Exception {