Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MQTT streaming: SourceQueue backpressure #1577

Merged
merged 6 commits into from
Jul 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mqtt-streaming/src/main/mima-filters/1.0.2.backwards.excludes
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# PR #1577
# https://github.com/akka/alpakka/pull/1577
ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.alpakka.mqtt.streaming.MqttSessionSettings.this")
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ object MqttSessionSettings {
* Configuration settings for client and server usage.
*/
final class MqttSessionSettings private (val maxPacketSize: Int = 4096,
val clientSendBufferSize: Int = 10,
val clientTerminationWatcherBufferSize: Int = 100,
val commandParallelism: Int = 50,
val eventParallelism: Int = 10,
Expand All @@ -52,6 +53,13 @@ final class MqttSessionSettings private (val maxPacketSize: Int = 4096,

import akka.util.JavaDurationConverters._

/**
* Just for clients - the number of commands that can be buffered while connected to a server. Defaults
* to 10. Any commands received beyond this will apply backpressure.
*/
def withClientSendBufferSize(clientSendBufferSize: Int): MqttSessionSettings =
copy(clientSendBufferSize = clientSendBufferSize)

/**
* The maximum size of a packet that is allowed to be decoded. Defaults to 4k.
*/
Expand Down Expand Up @@ -224,12 +232,13 @@ final class MqttSessionSettings private (val maxPacketSize: Int = 4096,

/**
* Just for servers - the number of commands that can be buffered while connected to a client. Defaults
* to 100. Any commands received beyond this will be dropped.
* to 100. Any commands received beyond this will apply backpressure.
*/
def withServerSendBufferSize(serverSendBufferSize: Int): MqttSessionSettings =
copy(serverSendBufferSize = serverSendBufferSize)

private def copy(maxPacketSize: Int = maxPacketSize,
clientSendBufferSize: Int = clientSendBufferSize,
clientTerminationWatcherBufferSize: Int = clientTerminationWatcherBufferSize,
commandParallelism: Int = commandParallelism,
eventParallelism: Int = eventParallelism,
Expand All @@ -245,6 +254,7 @@ final class MqttSessionSettings private (val maxPacketSize: Int = 4096,
serverSendBufferSize: Int = serverSendBufferSize) =
new MqttSessionSettings(
maxPacketSize,
clientSendBufferSize,
clientTerminationWatcherBufferSize,
commandParallelism,
eventParallelism,
Expand All @@ -261,5 +271,21 @@ final class MqttSessionSettings private (val maxPacketSize: Int = 4096,
)

override def toString: String =
s"MqttSessionSettings(maxPacketSize=$maxPacketSize,clientTerminationWatcherBufferSize=$clientTerminationWatcherBufferSize,commandParallelism=$commandParallelism,eventParallelism=$eventParallelism,receiveConnectTimeout=$receiveConnectTimeout,receiveConnAckTimeout=$receiveConnAckTimeout,receivePubAckRecTimeout=$producerPubAckRecTimeout,receivePubCompTimeout=$producerPubCompTimeout,receivePubAckRecTimeout=$consumerPubAckRecTimeout,receivePubCompTimeout=$consumerPubCompTimeout,receivePubRelTimeout=$consumerPubRelTimeout,receiveSubAckTimeout=$receiveSubAckTimeout,receiveUnsubAckTimeout=$receiveUnsubAckTimeout,serverSendBufferSize=$serverSendBufferSize)"
"MqttSessionSettings(" +
s"maxPacketSize=$maxPacketSize," +
s"clientSendBufferSize=$clientSendBufferSize," +
s"clientTerminationWatcherBufferSize=$clientTerminationWatcherBufferSize," +
s"commandParallelism=$commandParallelism," +
s"eventParallelism=$eventParallelism," +
s"receiveConnectTimeout=${receiveConnectTimeout.toCoarsest}," +
s"receiveConnAckTimeout=${receiveConnAckTimeout.toCoarsest}," +
s"receivePubAckRecTimeout=${producerPubAckRecTimeout.toCoarsest}," +
s"receivePubCompTimeout=${producerPubCompTimeout.toCoarsest}," +
s"receivePubAckRecTimeout=${consumerPubAckRecTimeout.toCoarsest}," +
s"receivePubCompTimeout=${consumerPubCompTimeout.toCoarsest}," +
s"receivePubRelTimeout=${consumerPubRelTimeout.toCoarsest}," +
s"receiveSubAckTimeout=${receiveSubAckTimeout.toCoarsest}," +
s"receiveUnsubAckTimeout=${receiveUnsubAckTimeout.toCoarsest}," +
s"serverSendBufferSize=$serverSendBufferSize" +
")"
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@ package akka.stream.alpakka.mqtt.streaming
package impl

import akka.NotUsed
import akka.actor.typed.{ActorRef, Behavior, ChildFailed, PostStop, Terminated}
import akka.actor.typed._
import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import akka.annotation.InternalApi
import akka.stream.{Materializer, OverflowStrategy}
import akka.stream.{Materializer, OverflowStrategy, QueueOfferResult}
import akka.stream.scaladsl.{BroadcastHub, Keep, Source, SourceQueueWithComplete}
import akka.util.ByteString

import scala.concurrent.Promise
import scala.concurrent.duration.FiniteDuration
import scala.util.control.NoStackTrace
import scala.util.{Failure, Success}
import scala.util.{Either, Failure, Success}

/*
* A client connector is a Finite State Machine that manages MQTT client
Expand Down Expand Up @@ -154,6 +154,8 @@ import scala.util.{Failure, Success}
settings
)

final case class WaitingForQueueOfferResult(nextBehavior: Behavior[Event], stash: Seq[Event])

sealed abstract class Event(val connectionId: ByteString)

final case class ConnectReceivedLocally(override val connectionId: ByteString,
Expand Down Expand Up @@ -207,6 +209,11 @@ import scala.util.{Failure, Success}
remote: Promise[Unsubscriber.ForwardUnsubscribe])
extends Event(connectionId)

final case class QueueOfferCompleted(override val connectionId: ByteString,
result: Either[Throwable, QueueOfferResult])
extends Event(connectionId)
with QueueOfferState.QueueOfferCompleted

sealed abstract class Command
sealed abstract class ForwardConnectCommand
case object ForwardConnect extends ForwardConnectCommand
Expand All @@ -226,18 +233,22 @@ import scala.util.{Failure, Success}
Behaviors
.receivePartial[Event] {
case (context, ConnectReceivedLocally(connectionId, connect, connectData, remote)) =>
import context.executionContext

val (queue, source) = Source
.queue[ForwardConnectCommand](1, OverflowStrategy.dropHead)
.queue[ForwardConnectCommand](data.settings.clientSendBufferSize, OverflowStrategy.backpressure)
.toMat(BroadcastHub.sink)(Keep.both)
.run()

remote.success(source)

queue.offer(ForwardConnect)
data.stash.foreach(context.self.tell)
queue
.offer(ForwardConnect)
.onComplete(result => context.self.tell(QueueOfferCompleted(connectionId, result.toEither)))

if (connect.connectFlags.contains(ConnectFlags.CleanSession)) {
val nextState = if (connect.connectFlags.contains(ConnectFlags.CleanSession)) {
context.children.foreach(context.stop)

serverConnect(
ConnectReceived(
connectionId,
Expand Down Expand Up @@ -281,6 +292,9 @@ import scala.util.{Failure, Success}
)

}

QueueOfferState.waitForQueueOfferCompleted(nextState, stash = data.stash)

case (_, ConnectionLost(_)) =>
Behavior.same
case (_, e) =>
Expand Down Expand Up @@ -391,30 +405,37 @@ import scala.util.{Failure, Success}
if connectionId != data.connectionId =>
context.self ! connect
disconnect(context, data.remote, data)

case (_, event) if event.connectionId.nonEmpty && event.connectionId != data.connectionId =>
Behaviors.same

case (context, ConnectionLost(_)) =>
timer.cancel(SendPingreq)
disconnect(context, data.remote, data)

case (context, DisconnectReceivedLocally(_, remote)) =>
remote.success(ForwardDisconnect)
timer.cancel(SendPingreq)
disconnect(context, data.remote, data)

case (context, SubscribeReceivedLocally(_, _, subscribeData, remote)) =>
context.watch(
context.spawnAnonymous(Subscriber(subscribeData, remote, data.subscriberPacketRouter, data.settings))
)
serverConnected(data)

case (context, UnsubscribeReceivedLocally(_, _, unsubscribeData, remote)) =>
context.watch(
context
.spawnAnonymous(Unsubscriber(unsubscribeData, remote, data.unsubscriberPacketRouter, data.settings))
)
serverConnected(data)

case (_, PublishReceivedFromRemote(_, publish, local))
if (publish.flags & ControlPacketFlags.QoSReserved).underlying == 0 =>
local.success(Consumer.ForwardPublish)
serverConnected(data, resetPingReqTimer = false)

case (context,
prfr @ PublishReceivedFromRemote(_, publish @ Publish(_, topicName, Some(packetId), _), local)) =>
data.activeConsumers.get(topicName) match {
Expand All @@ -424,17 +445,21 @@ import scala.util.{Failure, Success}
context.spawn(Consumer(publish, None, packetId, local, data.consumerPacketRouter, data.settings),
consumerName)
context.watchWith(consumer, ConsumerFree(publish.topicName))

serverConnected(data.copy(activeConsumers = data.activeConsumers + (publish.topicName -> consumer)),
resetPingReqTimer = false)

case Some(consumer) if publish.flags.contains(ControlPacketFlags.DUP) =>
consumer ! Consumer.DupPublishReceivedFromRemote(local)
serverConnected(data, resetPingReqTimer = false)

case Some(_) =>
serverConnected(
data.copy(pendingRemotePublications = data.pendingRemotePublications :+ (publish.topicName -> prfr)),
resetPingReqTimer = false
)
}

case (context, ConsumerFree(topicName)) =>
val i = data.pendingRemotePublications.indexWhere(_._1 == topicName)
if (i >= 0) {
Expand Down Expand Up @@ -463,10 +488,20 @@ import scala.util.{Failure, Success}
} else {
serverConnected(data.copy(activeConsumers = data.activeConsumers - topicName))
}
case (_, PublishReceivedLocally(publish, _))

case (context, PublishReceivedLocally(publish, _))
if (publish.flags & ControlPacketFlags.QoSReserved).underlying == 0 =>
data.remote.offer(ForwardPublish(publish, None))
serverConnected(data)
import context.executionContext

data.remote
.offer(ForwardPublish(publish, None))
.onComplete(result => context.self.tell(QueueOfferCompleted(ByteString.empty, result.toEither)))

QueueOfferState.waitForQueueOfferCompleted(
serverConnected(data),
stash = Seq.empty
)

case (context, prl @ PublishReceivedLocally(publish, publishData)) =>
val producerName = ActorName.mkName(ProducerNamePrefix + publish.topicName + "-" + context.children.size)
if (!data.activeProducers.contains(publish.topicName)) {
Expand All @@ -489,6 +524,7 @@ import scala.util.{Failure, Success}
data.copy(pendingLocalPublications = data.pendingLocalPublications :+ (publish.topicName -> prl))
)
}

case (context, ProducerFree(topicName)) =>
val i = data.pendingLocalPublications.indexWhere(_._1 == topicName)
if (i >= 0) {
Expand Down Expand Up @@ -518,19 +554,48 @@ import scala.util.{Failure, Success}
} else {
serverConnected(data.copy(activeProducers = data.activeProducers - topicName))
}
case (_, ReceivedProducerPublishingCommand(Producer.ForwardPublish(publish, packetId))) =>
data.remote.offer(ForwardPublish(publish, packetId))
Behaviors.same
case (_, ReceivedProducerPublishingCommand(Producer.ForwardPubRel(_, packetId))) =>
data.remote.offer(ForwardPubRel(packetId))
Behaviors.same

case (context, ReceivedProducerPublishingCommand(Producer.ForwardPublish(publish, packetId))) =>
import context.executionContext

data.remote
.offer(ForwardPublish(publish, packetId))
.onComplete(result => context.self.tell(QueueOfferCompleted(ByteString.empty, result.toEither)))

QueueOfferState.waitForQueueOfferCompleted(
serverConnected(data, resetPingReqTimer = false),
stash = Seq.empty
)

case (context, ReceivedProducerPublishingCommand(Producer.ForwardPubRel(_, packetId))) =>
import context.executionContext

data.remote
.offer(ForwardPubRel(packetId))
.onComplete(result => context.self.tell(QueueOfferCompleted(ByteString.empty, result.toEither)))

QueueOfferState.waitForQueueOfferCompleted(
serverConnected(data, resetPingReqTimer = false),
stash = Seq.empty
)

case (context, SendPingReqTimeout(_)) if data.pendingPingResp =>
data.remote.fail(PingFailed)
timer.cancel(SendPingreq)
disconnect(context, data.remote, data)
case (_, SendPingReqTimeout(_)) =>
data.remote.offer(ForwardPingReq)
serverConnected(data.copy(pendingPingResp = true))

case (context, SendPingReqTimeout(_)) =>
import context.executionContext

data.remote
.offer(ForwardPingReq)
.onComplete(result => context.self.tell(QueueOfferCompleted(ByteString.empty, result.toEither)))

QueueOfferState.waitForQueueOfferCompleted(
serverConnected(data.copy(pendingPingResp = true)),
stash = Seq.empty
)

case (_, PingRespReceivedFromRemote(_, local)) =>
local.success(ForwardPingResp)
serverConnected(data.copy(pendingPingResp = false))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright (C) 2016-2019 Lightbend Inc. <http://www.lightbend.com>
*/

package akka.stream.alpakka.mqtt.streaming.impl

import akka.actor.typed.Behavior
import akka.actor.typed.scaladsl.Behaviors
import akka.stream.QueueOfferResult

private[mqtt] object QueueOfferState {

/**
* A marker trait that holds a result for SourceQueue#offer
*/
trait QueueOfferCompleted {
def result: Either[Throwable, QueueOfferResult]
}

/**
* A behavior that stashes messages until a response to the SourceQueue#offer
* method is received.
*
* This is to be used only with SourceQueues that use backpressure.
*/
def waitForQueueOfferCompleted[T](behavior: Behavior[T], stash: Seq[T]): Behavior[T] =
Behaviors
.receive[T] {
case (context, completed: QueueOfferCompleted) =>
completed.result match {
case Right(QueueOfferResult.Enqueued) =>
stash.foreach(context.self.tell)

behavior

case Right(other) =>
throw new IllegalStateException(s"Failed to offer to queue: $other")

case Left(failure) =>
throw failure
}

case (_, other) =>
waitForQueueOfferCompleted(behavior, stash = stash :+ other)
}
.orElse(behavior) // handle signals immediately
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing I'm not clear on are the semantics here. Note that on L43, we catch all events that flow. Thus, the intention is that all signals go to behavior but nothing else (until we hit L34 in the future)

If behavior's signal handler (i.e. the behavior being run via orElse) returns Behavior.same, I'd like to clarify that it doesn't change the behavior of the actor to itself, i.e. waitForQueueOfferCompleted (L26-L46) remains the behavior. Hopefully my question makes sense..

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the behaviour would indeed remain the same i.e. stay as per the behaviour of waitForQueueOfferCompleted. But I've not yet proved this.

However, given L46 with the orElse, I'm not seeing the signal handler of behavior being applied. When I run the tests I see a DeathPactException which means that there is no signal handler for Terminated, but there is such a signal handler for behavior...

}
Loading