diff --git a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ClientState.scala b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ClientState.scala index 3710301150..ce6591e0a7 100644 --- a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ClientState.scala +++ b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ClientState.scala @@ -444,8 +444,7 @@ import scala.util.{Either, Failure, Success} val consumer = context.spawn(Consumer(publish, None, packetId, local, data.consumerPacketRouter, data.settings), consumerName) - context.watchWith(consumer, ConsumerFree(publish.topicName)) - + context.watch(consumer) serverConnected(data.copy(activeConsumers = data.activeConsumers + (publish.topicName -> consumer)), resetPingReqTimer = false) @@ -474,10 +473,7 @@ import scala.util.{Either, Failure, Success} data.settings), consumerName ) - context.watchWith( - consumer, - ConsumerFree(topicName) - ) + context.watch(consumer) serverConnected( data.copy( activeConsumers = data.activeConsumers + (topicName -> consumer), @@ -514,10 +510,7 @@ import scala.util.{Either, Failure, Success} val producer = context.spawn(Producer(publish, publishData, reply, data.producerPacketRouter, data.settings), producerName) - context.watchWith( - producer, - ProducerFree(publish.topicName) - ) + context.watch(producer) serverConnected(data.copy(activeProducers = data.activeProducers + (publish.topicName -> producer))) } else { serverConnected( @@ -540,10 +533,7 @@ import scala.util.{Either, Failure, Success} Producer(prl.publish, prl.publishData, reply, data.producerPacketRouter, data.settings), producerName ) - context.watchWith( - producer, - ProducerFree(topicName) - ) + context.watch(producer) serverConnected( data.copy( activeProducers = data.activeProducers + (topicName -> producer), @@ -601,13 +591,23 @@ import scala.util.{Either, Failure, Success} serverConnected(data.copy(pendingPingResp = false)) } .receiveSignal { - case (context, ChildFailed(_, failure)) if failure == Subscriber.SubscribeFailed => - data.remote.fail(Subscriber.SubscribeFailed) + case (context, ChildFailed(_, failure)) + if failure == Subscriber.SubscribeFailed || + failure == Unsubscriber.UnsubscribeFailed || + failure.isInstanceOf[Consumer.ConsumeFailed] => + data.remote.fail(failure) disconnect(context, data.remote, data) - case (context, ChildFailed(_, failure)) if failure == Unsubscriber.UnsubscribeFailed => - data.remote.fail(Unsubscriber.UnsubscribeFailed) - disconnect(context, data.remote, data) - case (_, _: Terminated) => + case (context, t: Terminated) => + data.activeConsumers.find(_._2 == t.ref) match { + case Some((topic, _)) => + context.self ! ConsumerFree(topic) + case None => + data.activeProducers.find(_._2 == t.ref) match { + case Some((topic, _)) => + context.self ! ProducerFree(topic) + case None => + } + } serverConnected(data) case (_, PostStop) => data.remote.complete() diff --git a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala index 67a9de6ce5..a0dc65e29f 100644 --- a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala +++ b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala @@ -541,7 +541,7 @@ import scala.util.{Failure, Success} data.consumerPacketRouter, data.settings), consumerName) - context.watchWith(consumer, ConsumerFree(publish.topicName)) + context.watch(consumer) clientConnected(data.copy(activeConsumers = data.activeConsumers + (publish.topicName -> consumer))) case Some(consumer) if publish.flags.contains(ControlPacketFlags.DUP) => consumer ! Consumer.DupPublishReceivedFromRemote(local) @@ -565,10 +565,7 @@ import scala.util.{Failure, Success} data.settings), consumerName ) - context.watchWith( - consumer, - ConsumerFree(topicName) - ) + context.watch(consumer) clientConnected( data.copy( activeConsumers = data.activeConsumers + (topicName -> consumer), @@ -605,10 +602,7 @@ import scala.util.{Failure, Success} val producer = context.spawn(Producer(publish, publishData, reply, data.producerPacketRouter, data.settings), producerName) - context.watchWith( - producer, - ProducerFree(publish.topicName) - ) + context.watch(producer) clientConnected(data.copy(activeProducers = data.activeProducers + (publish.topicName -> producer))) } else { clientConnected( @@ -629,10 +623,7 @@ import scala.util.{Failure, Success} Producer(prl.publish, prl.publishData, reply, data.producerPacketRouter, data.settings), producerName ) - context.watchWith( - producer, - ProducerFree(topicName) - ) + context.watch(producer) clientConnected( data.copy( activeProducers = data.activeProducers + (topicName -> producer), @@ -731,9 +722,22 @@ import scala.util.{Failure, Success} } .receiveSignal { case (context, ChildFailed(_, failure)) - if failure == Publisher.SubscribeFailed || failure == Unpublisher.UnsubscribeFailed => + if failure == Subscriber.SubscribeFailed || + failure == Unsubscriber.UnsubscribeFailed || + failure.isInstanceOf[Consumer.ConsumeFailed] => + data.remote.fail(failure) disconnect(context, data.remote, data) - case (_, _: Terminated) => + case (context, t: Terminated) => + data.activeConsumers.find(_._2 == t.ref) match { + case Some((topic, _)) => + context.self ! ConsumerFree(topic) + case None => + data.activeProducers.find(_._2 == t.ref) match { + case Some((topic, _)) => + context.self ! ProducerFree(topic) + case None => + } + } Behaviors.same case (_, PostStop) => data.remote.complete()