Skip to content

Commit

Permalink
Add Flow to support RabbitMQ RPC workflow akka#160
Browse files Browse the repository at this point in the history
  • Loading branch information
Falmarri committed Jan 23, 2017
1 parent 2f03442 commit 9024e8c
Show file tree
Hide file tree
Showing 10 changed files with 455 additions and 5 deletions.
179 changes: 179 additions & 0 deletions amqp/src/main/scala/akka/stream/alpakka/amqp/AmqpRpcFlowStage.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package akka.stream.alpakka.amqp

import akka.stream._
import akka.stream.stage._
import akka.util.ByteString
import com.rabbitmq.client.AMQP.BasicProperties
import com.rabbitmq.client.{DefaultConsumer, Envelope, ShutdownSignalException}

import scala.collection.mutable
import scala.concurrent.{Future, Promise}

object AmqpRpcFlowStage {

/**
* Internal API
*/
private val defaultAttributes =
Attributes.name("AmqpRpcFlow").and(ActorAttributes.dispatcher("akka.stream.default-blocking-io-dispatcher"))
}

/**
* This stage materializes to a Future[String], which is the name of the private exclusive queue used for RPC communication
*
* @param responsesPerMessage The number of responses that should be expected for each message placed on the queue. This
* can be overridden per message by including `expectedReplies` in the the header of the [[OutgoingMessage]]
*/
final class AmqpRpcFlowStage(settings: AmqpSinkSettings, bufferSize: Int, responsesPerMessage: Int = 1)
extends GraphStageWithMaterializedValue[FlowShape[OutgoingMessage, IncomingMessage], Future[String]]
with AmqpConnector {
stage =>

import AmqpRpcFlowStage._

val in = Inlet[OutgoingMessage]("AmqpRpcFlow.in")
val out = Outlet[IncomingMessage]("AmqpRpcFlow.out")

override def shape: FlowShape[OutgoingMessage, IncomingMessage] = FlowShape.of(in, out)

override protected def initialAttributes: Attributes = defaultAttributes

override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Future[String]) = {
val promise = Promise[String]()
(new GraphStageLogic(shape) with AmqpConnectorLogic {

override val settings = stage.settings
private val exchange = settings.exchange.getOrElse("")
private val routingKey = settings.routingKey.getOrElse("")
private val queue = mutable.Queue[IncomingMessage]()
private var queueName: String = _
private var outstandingMessages = 0

override def connectionFactoryFrom(settings: AmqpConnectionSettings) = stage.connectionFactoryFrom(settings)

override def whenConnected(): Unit = {
import scala.collection.JavaConverters._
val shutdownCallback = getAsyncCallback[Option[ShutdownSignalException]] {
case Some(ex) =>
promise.tryFailure(ex)
failStage(ex)
case None =>
promise.trySuccess("")
completeStage()
}

pull(in)

// we have only one consumer per connection so global is ok
channel.basicQos(bufferSize, true)
val consumerCallback = getAsyncCallback(handleDelivery)

val amqpSourceConsumer = new DefaultConsumer(channel) {
override def handleDelivery(
consumerTag: String,
envelope: Envelope,
properties: BasicProperties,
body: Array[Byte]
): Unit =
consumerCallback.invoke(IncomingMessage(ByteString(body), envelope, properties))

override def handleCancel(consumerTag: String): Unit =
// non consumer initiated cancel, for example happens when the queue has been deleted.
shutdownCallback.invoke(None)

override def handleShutdownSignal(consumerTag: String, sig: ShutdownSignalException): Unit =
// "Called when either the channel or the underlying connection has been shut down."
shutdownCallback.invoke(Option(sig))
}

// Create an exclusive queue with a randomly generated name for use as the replyTo portion of RPC
queueName = channel.queueDeclare(
"",
false,
true,
true,
Map.empty[String, AnyRef].asJava
).getQueue

channel.basicConsume(
queueName,
amqpSourceConsumer
)
promise.success(queueName)
}

def handleDelivery(message: IncomingMessage): Unit =
if (isAvailable(out)) {
pushAndAckMessage(message)
} else {
if (queue.size + 1 > bufferSize) {
failStage(new RuntimeException(s"Reached maximum buffer size $bufferSize"))
} else {
queue.enqueue(message)
}
}

def pushAndAckMessage(message: IncomingMessage): Unit = {
push(out, message)
// ack it as soon as we have passed it downstream
// TODO ack less often and do batch acks with multiple = true would probably be more performant
channel.basicAck(
message.envelope.getDeliveryTag,
false // just this single message
)
outstandingMessages -= 1

if (outstandingMessages == 0 && isClosed(in)) {
completeStage()
}
}

setHandler(
out,
new OutHandler {
override def onPull(): Unit =
if (queue.nonEmpty) {
pushAndAckMessage(queue.dequeue())
}
}
)

setHandler(
in,
new InHandler {
// We don't want to finish since we're still waiting
// on incoming messages from rabbit
override def onUpstreamFinish(): Unit = {}

override def onPush(): Unit = {
import scala.collection.JavaConverters._
val elem = grab(in)
val props = elem.props.getOrElse(new BasicProperties()).builder.replyTo(queueName).build()
channel.basicPublish(
exchange,
routingKey,
elem.mandatory,
elem.immediate,
props,
elem.bytes.toArray
)

// TODO: This is pretty expensive just to see if a message specifies how many responses it should be expecting
Option(props.getHeaders)
.map(_.asScala)
.getOrElse(Map.empty[String, AnyRef])
.get("expectedReplies")
.map(_.asInstanceOf[Int])
.getOrElse(responsesPerMessage)

outstandingMessages += responsesPerMessage
pull(in)
}
}
)
}, promise.future)
}

override def toString: String = "AmqpRpcFlow"

}
68 changes: 67 additions & 1 deletion amqp/src/main/scala/akka/stream/alpakka/amqp/AmqpSinkStage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ object AmqpSinkStage {
* Internal API
*/
private val defaultAttributes =
Attributes.name("AmsqpSink").and(ActorAttributes.dispatcher("akka.stream.default-blocking-io-dispatcher"))
Attributes.name("AmqpSink").and(ActorAttributes.dispatcher("akka.stream.default-blocking-io-dispatcher"))
}

/**
Expand Down Expand Up @@ -77,3 +77,69 @@ final class AmqpSinkStage(settings: AmqpSinkSettings)

override def toString: String = "AmqpSink"
}


object AmqpReplyToStage {

/**
* Internal API
*/
private val defaultAttributes =
Attributes.name("AmqpReplyToSink").and(ActorAttributes.dispatcher("akka.stream.default-blocking-io-dispatcher"))
}

/**
* Connects to an AMQP server upon materialization and sends incoming messages to the server.
* Each materialized sink will create one connection to the broker. This stage sends messages to
* the queue named in the replyTo options of the message instead of from settings declared at construction.
*/
final class AmqpReplyToStage(settings: AmqpSinkSettings)
extends GraphStage[SinkShape[OutgoingMessage]]
with AmqpConnector { stage =>
import AmqpReplyToStage._

val in = Inlet[OutgoingMessage]("AmqpReplyToSink.in")

override def shape: SinkShape[OutgoingMessage] = SinkShape.of(in)

override protected def initialAttributes: Attributes = defaultAttributes

override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with AmqpConnectorLogic {
override val settings = stage.settings

override def connectionFactoryFrom(settings: AmqpConnectionSettings) = stage.connectionFactoryFrom(settings)

override def whenConnected(): Unit = {
val shutdownCallback = getAsyncCallback[ShutdownSignalException] { ex =>
failStage(ex)
}
channel.addShutdownListener(new ShutdownListener {
override def shutdownCompleted(cause: ShutdownSignalException): Unit =
shutdownCallback.invoke(cause)
})
pull(in)
}

setHandler(in,
new InHandler {
override def onPush(): Unit = {
val elem = grab(in)
channel.basicPublish(
"",
elem.props.map(_.getReplyTo).get,
elem.mandatory,
elem.immediate,
elem.props.orNull,
elem.bytes.toArray
)
pull(in)
}
})

}

override def toString: String = "AmqpReplyToSink"
}


Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package akka.stream.alpakka.amqp.javadsl

import akka.NotUsed
import akka.stream.alpakka.amqp._
import akka.stream.javadsl.{Flow, Source}
import akka.util.ByteString

import scala.concurrent.Future

object AmqpRpcFlow {

/**
* Java API: Creates an [[AmqpRpcFlow]] with given settings and buffer size.
*/
def create(settings: AmqpSinkSettings, bufferSize: Int) =
Flow.fromGraph(new AmqpRpcFlowStage(settings, bufferSize))

/**
* Java API: Creates an [[AmqpRpcFlow]] with given settings and buffer size.
*/
def create(settings: AmqpSinkSettings, bufferSize: Int, repliesPerMessage: Int) =
Flow.fromGraph(new AmqpRpcFlowStage(settings, bufferSize, repliesPerMessage))

/**
* Java API: Creates an [[AmqpRpcFlow]] with given settings and buffer size.
*/
def createSimple(settings: AmqpSinkSettings, repliesPerMessage: Int): Flow[ByteString, ByteString, Future[String]] =
akka.stream.alpakka.amqp.scaladsl.AmqpRpcFlow.simple(settings, repliesPerMessage).asJava

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
package akka.stream.alpakka.amqp.javadsl

import akka.NotUsed
import akka.stream.alpakka.amqp.{AmqpSinkSettings, AmqpSinkStage, OutgoingMessage}
import akka.stream.alpakka.amqp.{AmqpReplyToStage, AmqpSinkSettings, AmqpSinkStage, OutgoingMessage}
import akka.stream.javadsl.Sink
import akka.util.ByteString

Expand All @@ -16,6 +16,12 @@ object AmqpSink {
def create(settings: AmqpSinkSettings): akka.stream.javadsl.Sink[OutgoingMessage, NotUsed] =
Sink.fromGraph(new AmqpSinkStage(settings))

/**
* Java API: Creates an [[AmqpSink]] that accepts [[OutgoingMessage]] elements.
*/
def createReplyTo(settings: AmqpSinkSettings): akka.stream.javadsl.Sink[OutgoingMessage, NotUsed] =
Sink.fromGraph(new AmqpReplyToStage(settings))

/**
* Java API: Creates an [[AmqpSink]] that accepts ByteString elements.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package akka.stream.alpakka.amqp.scaladsl

import akka.NotUsed
import akka.stream.alpakka.amqp.{AmqpRpcFlowStage, AmqpSinkSettings, IncomingMessage, OutgoingMessage}
import akka.stream.scaladsl.{Flow, Keep, Sink}
import akka.util.ByteString

import scala.concurrent.Future

object AmqpRpcFlow {
/**
* Scala API: Creates an [[AmqpRpcFlow]] that accepts ByteString elements and emits ByteString elements.
*/
def simple(settings: AmqpSinkSettings, repliesPerMessage: Int = 1) =
Flow[ByteString].map(bytes => OutgoingMessage(bytes, false, false, None)).viaMat(apply(settings, 1, repliesPerMessage))(Keep.right).map(_.bytes)

/**
* Scala API: Creates an [[AmqpRpcFlow]] that accepts [[OutgoingMessage]] elements and emits corresponding IncomingMessage elements.
*/
def apply(settings: AmqpSinkSettings, bufferSize: Int, repliesPerMessage: Int = 1): Flow[OutgoingMessage, IncomingMessage, Future[String]] =
Flow.fromGraph(new AmqpRpcFlowStage(settings, bufferSize, repliesPerMessage))

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
package akka.stream.alpakka.amqp.scaladsl

import akka.NotUsed
import akka.stream.alpakka.amqp.{AmqpSinkSettings, AmqpSinkStage, OutgoingMessage}
import akka.stream.alpakka.amqp.{AmqpReplyToStage, AmqpSinkSettings, AmqpSinkStage, OutgoingMessage}
import akka.stream.scaladsl.Sink
import akka.util.ByteString

Expand All @@ -16,6 +16,9 @@ object AmqpSink {
def simple(settings: AmqpSinkSettings): Sink[ByteString, NotUsed] =
apply(settings).contramap[ByteString](bytes => OutgoingMessage(bytes, false, false, None))

def replyTo(settings: AmqpSinkSettings): Sink[OutgoingMessage, NotUsed] =
Sink.fromGraph(new AmqpReplyToStage(settings))

/**
* Scala API: Creates an [[AmqpSink]] that accepts [[OutgoingMessage]] elements.
*/
Expand Down
Loading

0 comments on commit 9024e8c

Please sign in to comment.