Skip to content

Commit

Permalink
Merge pull request #222 from matsluni/more_options_on_configuration_o…
Browse files Browse the repository at this point in the history
…f_connectionFactory

AMQP: add more options to configuration of the ConnectionFactory, #191
  • Loading branch information
2m authored May 4, 2017
2 parents 399b968 + ba303e7 commit a0ad914
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 24 deletions.
37 changes: 33 additions & 4 deletions amqp/src/main/scala/akka/stream/alpakka/amqp/AmqpConnector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,47 @@ private[amqp] trait AmqpConnector {
val factory = new ConnectionFactory
settings match {
case AmqpConnectionUri(uri) => factory.setUri(uri)
case AmqpConnectionDetails(host, port, maybeCredentials, maybeVirtualHost, sslProtocol) =>
factory.setHost(host)
factory.setPort(port)
case AmqpConnectionDetails(_,
maybeCredentials,
maybeVirtualHost,
sslProtocol,
requestedHeartbeat,
connectionTimeout,
handshakeTimeout,
shutdownTimeout,
networkRecoveryInterval,
automaticRecoveryEnabled,
topologyRecoveryEnabled,
exceptionHandler) =>
maybeCredentials.foreach { credentials =>
factory.setUsername(credentials.username)
factory.setPassword(credentials.password)
}
maybeVirtualHost.foreach(factory.setVirtualHost)
sslProtocol.foreach(factory.useSslProtocol)
requestedHeartbeat.foreach(factory.setRequestedHeartbeat)
connectionTimeout.foreach(factory.setConnectionTimeout)
handshakeTimeout.foreach(factory.setHandshakeTimeout)
shutdownTimeout.foreach(factory.setShutdownTimeout)
networkRecoveryInterval.foreach(factory.setNetworkRecoveryInterval)
automaticRecoveryEnabled.foreach(factory.setAutomaticRecoveryEnabled)
topologyRecoveryEnabled.foreach(factory.setTopologyRecoveryEnabled)
exceptionHandler.foreach(factory.setExceptionHandler)
case DefaultAmqpConnection => // leave it be as is
}
factory
}

def newConnection(factory: ConnectionFactory, settings: AmqpConnectionSettings): Connection = settings match {
case a: AmqpConnectionDetails => {
import scala.collection.JavaConverters._
if (a.hostAndPortList.nonEmpty)
factory.newConnection(a.hostAndPortList.map(hp => new Address(hp._1, hp._2)).asJava)
else
throw new IllegalArgumentException("You need to supply at least one host/port pair.")
}
case _ => factory.newConnection()
}
}

/**
Expand All @@ -40,12 +68,13 @@ private[amqp] trait AmqpConnectorLogic { this: GraphStageLogic =>

def settings: AmqpConnectorSettings
def connectionFactoryFrom(settings: AmqpConnectionSettings): ConnectionFactory
def newConnection(factory: ConnectionFactory, settings: AmqpConnectionSettings): Connection
def whenConnected(): Unit

final override def preStart(): Unit = {
val factory = connectionFactoryFrom(settings.connectionSettings)

connection = factory.newConnection()
connection = newConnection(factory, settings.connectionSettings)
channel = connection.createChannel()

val connShutdownCallback = getAsyncCallback[ShutdownSignalException] { ex =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import akka.stream._
import akka.stream.stage._
import akka.util.ByteString
import com.rabbitmq.client.AMQP.BasicProperties
import com.rabbitmq.client.{DefaultConsumer, Envelope, ShutdownSignalException}
import com.rabbitmq.client._

import scala.collection.mutable
import scala.concurrent.{Future, Promise}
Expand Down Expand Up @@ -49,6 +49,8 @@ final class AmqpRpcFlowStage(settings: AmqpSinkSettings, bufferSize: Int, respon
private var outstandingMessages = 0

override def connectionFactoryFrom(settings: AmqpConnectionSettings) = stage.connectionFactoryFrom(settings)
override def newConnection(factory: ConnectionFactory, settings: AmqpConnectionSettings): Connection =
stage.newConnection(factory, settings)

override def whenConnected(): Unit = {
import scala.collection.JavaConverters._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ final class AmqpSinkStage(settings: AmqpSinkSettings)
private val routingKey = settings.routingKey.getOrElse("")

override def connectionFactoryFrom(settings: AmqpConnectionSettings) = stage.connectionFactoryFrom(settings)
override def newConnection(factory: ConnectionFactory, settings: AmqpConnectionSettings) =
stage.newConnection(factory, settings)

override def whenConnected(): Unit = {
val shutdownCallback = getAsyncCallback[ShutdownSignalException] { ex =>
Expand Down Expand Up @@ -124,6 +126,8 @@ final class AmqpReplyToSinkStage(settings: AmqpReplyToSinkSettings)
override val settings = stage.settings

override def connectionFactoryFrom(settings: AmqpConnectionSettings) = stage.connectionFactoryFrom(settings)
override def newConnection(factory: ConnectionFactory, settings: AmqpConnectionSettings): Connection =
stage.newConnection(factory, settings)

override def whenConnected(): Unit = {
val shutdownCallback = getAsyncCallback[ShutdownSignalException] { ex =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ final class AmqpSourceStage(settings: AmqpSourceSettings, bufferSize: Int)

override val settings = stage.settings
override def connectionFactoryFrom(settings: AmqpConnectionSettings) = stage.connectionFactoryFrom(settings)
override def newConnection(factory: ConnectionFactory, settings: AmqpConnectionSettings) =
stage.newConnection(factory, settings)

private val queue = mutable.Queue[IncomingMessage]()

Expand Down
66 changes: 62 additions & 4 deletions amqp/src/main/scala/akka/stream/alpakka/amqp/model.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
*/
package akka.stream.alpakka.amqp

import com.rabbitmq.client.ExceptionHandler
import scala.collection.JavaConverters._

/**
* Internal API
*/
Expand Down Expand Up @@ -147,15 +150,70 @@ object AmqpConnectionUri {
}

final case class AmqpConnectionDetails(
host: String,
port: Int,
hostAndPortList: Seq[(String, Int)],
credentials: Option[AmqpCredentials] = None,
virtualHost: Option[String] = None,
sslProtocol: Option[String] = None
) extends AmqpConnectionSettings {}
sslProtocol: Option[String] = None,
requestedHeartbeat: Option[Int] = None,
connectionTimeout: Option[Int] = None,
handshakeTimeout: Option[Int] = None,
shutdownTimeout: Option[Int] = None,
networkRecoveryInterval: Option[Int] = None,
automaticRecoveryEnabled: Option[Boolean] = None,
topologyRecoveryEnabled: Option[Boolean] = None,
exceptionHandler: Option[ExceptionHandler] = None
) extends AmqpConnectionSettings {

def withHostsAndPorts(hostAndPort: (String, Int), hostAndPorts: (String, Int)*): AmqpConnectionDetails =
copy(hostAndPortList = (hostAndPort +: hostAndPorts).toList)

def withCredentials(amqpCredentials: AmqpCredentials): AmqpConnectionDetails =
copy(credentials = Option(amqpCredentials))

def withVirtualHost(virtualHost: String): AmqpConnectionDetails =
copy(virtualHost = Option(virtualHost))

def withSslProtocol(sslProtocol: String): AmqpConnectionDetails =
copy(sslProtocol = Option(sslProtocol))

def withRequestedHeartbeat(requestedHeartbeat: Int): AmqpConnectionDetails =
copy(requestedHeartbeat = Option(requestedHeartbeat))

def withConnectionTimeout(connectionTimeout: Int): AmqpConnectionDetails =
copy(connectionTimeout = Option(connectionTimeout))

def withHandshakeTimeout(handshakeTimeout: Int): AmqpConnectionDetails =
copy(handshakeTimeout = Option(handshakeTimeout))

def withShutdownTimeout(shutdownTimeout: Int): AmqpConnectionDetails =
copy(shutdownTimeout = Option(shutdownTimeout))

def withNetworkRecoveryInterval(networkRecoveryInterval: Int): AmqpConnectionDetails =
copy(networkRecoveryInterval = Option(networkRecoveryInterval))

def withAutomaticRecoveryEnabled(automaticRecoveryEnabled: Boolean): AmqpConnectionDetails =
copy(automaticRecoveryEnabled = Option(automaticRecoveryEnabled))

def withTopologyRecoveryEnabled(topologyRecoveryEnabled: Boolean): AmqpConnectionDetails =
copy(topologyRecoveryEnabled = Option(topologyRecoveryEnabled))

def withExceptionHandler(exceptionHandler: ExceptionHandler): AmqpConnectionDetails =
copy(exceptionHandler = Option(exceptionHandler))

/**
* Java API:
*/
@annotation.varargs
def withHostsAndPorts(hostAndPort: akka.japi.Pair[String, Int],
hostAndPorts: akka.japi.Pair[String, Int]*): AmqpConnectionDetails =
copy(hostAndPortList = (hostAndPort +: hostAndPorts).map(_.toScala).toList)
}

object AmqpConnectionDetails {

def apply(host: String, port: Int): AmqpConnectionDetails =
AmqpConnectionDetails(List((host, port)))

/**
* Java API:
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,35 @@
*/
package akka.stream.alpakka.amqp.javadsl;

import akka.stream.alpakka.amqp.*;
import akka.stream.testkit.TestSubscriber;
import akka.stream.testkit.javadsl.TestSink;
import static org.junit.Assert.assertEquals;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeUnit;

import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

import static org.junit.Assert.*;

import akka.Done;
import akka.NotUsed;
import akka.actor.ActorSystem;
import akka.japi.Pair;
import akka.stream.*;
import akka.stream.javadsl.*;
import akka.stream.ActorMaterializer;
import akka.stream.Materializer;
import akka.stream.alpakka.amqp.*;
import akka.stream.javadsl.Flow;
import akka.stream.javadsl.Sink;
import akka.stream.javadsl.Source;
import akka.stream.testkit.TestSubscriber;
import akka.stream.testkit.javadsl.TestSink;
import akka.testkit.JavaTestKit;
import akka.util.ByteString;
import scala.Some;
import scala.concurrent.duration.Duration;

import java.util.*;
import java.util.concurrent.*;
import java.util.stream.*;

/**
* Needs a local running AMQP server on the default port with no password.
*/
Expand All @@ -53,9 +58,13 @@ public void publishAndConsume() throws Exception {
final QueueDeclaration queueDeclaration = QueueDeclaration.create(queueName);
//#queue-declaration

@SuppressWarnings("unchecked")
AmqpConnectionDetails amqpConnectionDetails = AmqpConnectionDetails.create("invalid", 5673)
.withHostsAndPorts(Pair.create("localhost", 5672), Pair.create("localhost", 5674));

//#create-sink
final Sink<ByteString, CompletionStage<Done>> amqpSink = AmqpSink.createSimple(
AmqpSinkSettings.create()
AmqpSinkSettings.create(amqpConnectionDetails)
.withRoutingKey(queueName)
.withDeclarations(queueDeclaration)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,26 @@ class AmqpConnectorsSpec extends AmqpSpec {
"The AMQP Connectors" should {

"publish and consume elements through a simple queue again in the same JVM" in {

// use a list of host/port pairs where one is normally invalid, but
// it should still work as expected,
val connectionSettings =
AmqpConnectionDetails(List(("invalid", 5673))).withHostsAndPorts(("localhost", 5672))

//#queue-declaration
val queueName = "amqp-conn-it-spec-simple-queue-" + System.currentTimeMillis()
val queueDeclaration = QueueDeclaration(queueName)
//#queue-declaration

//#create-sink
val amqpSink = AmqpSink.simple(
AmqpSinkSettings(DefaultAmqpConnection).withRoutingKey(queueName).withDeclarations(queueDeclaration)
AmqpSinkSettings(connectionSettings).withRoutingKey(queueName).withDeclarations(queueDeclaration)
)
//#create-sink

//#create-source
val amqpSource = AmqpSource(
NamedQueueSourceSettings(DefaultAmqpConnection, queueName).withDeclarations(queueDeclaration),
NamedQueueSourceSettings(connectionSettings, queueName).withDeclarations(queueDeclaration),
bufferSize = 10
)
//#create-source
Expand Down

0 comments on commit a0ad914

Please sign in to comment.