Skip to content

Commit

Permalink
AWS SQS source: reduce consumption on empty receives (#1743)
Browse files Browse the repository at this point in the history
  • Loading branch information
ennru authored Jun 25, 2019
2 parents be3b89a + e466618 commit 9c17a7c
Show file tree
Hide file tree
Showing 3 changed files with 285 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Copyright (C) 2016-2019 Lightbend Inc. <http://www.lightbend.com>
*/

package akka.stream.alpakka.sqs.impl

import akka.annotation.InternalApi
import akka.stream.ActorAttributes.SupervisionStrategy
import akka.stream.Attributes.name
import akka.stream._
import akka.stream.impl.fusing.MapAsync
import akka.stream.impl.{Buffer => BufferImpl}
import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler}

import scala.annotation.tailrec
import scala.concurrent.Future
import scala.util.control.NonFatal
import scala.util.{Failure, Success}

/**
 * Internal API.
 *
 * A `mapAsync`-style stage: `f` is applied to every incoming element and the results are emitted in
 * the order of the incoming elements (a result that is not yet available blocks the ones behind it).
 * Unlike a fixed-parallelism stage, the number of elements allowed in flight is re-evaluated by
 * `balancingF` every time a future completes successfully.
 *
 * @param maxParallelism initial in-flight limit; also the size the internal buffer is created with,
 *                       so the limit is never allowed to grow beyond it
 * @param f              asynchronous function applied to each incoming element
 * @param balancingF     receives a successful result and the current limit and returns the limit to
 *                       use from now on; the returned value is clamped to `1..maxParallelism`
 */
@InternalApi private[akka] final case class BalancingMapAsync[In, Out](
    maxParallelism: Int,
    f: In => Future[Out],
    balancingF: (Out, Int) => Int
) extends GraphStage[FlowShape[In, Out]] {

  // brings Holder and NotYetThere into scope
  import MapAsync._

  private val in = Inlet[In]("BalancingMapAsync.in")
  private val out = Outlet[Out]("BalancingMapAsync.out")

  override def initialAttributes = name("BalancingMapAsync")

  override val shape = FlowShape(in, out)

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
    new GraphStageLogic(shape) with InHandler with OutHandler {

      lazy val decider = inheritedAttributes.mandatoryAttribute[SupervisionStrategy].decider
      var buffer: BufferImpl[Holder[Out]] = _
      // current in-flight limit, adjusted by `rebalance` after every successful result
      var parallelism = maxParallelism

      /**
       * Asks `balancingF` for the new in-flight limit and stores it.
       * The result is kept within `1..maxParallelism`: a limit below 1 would stop the stage from ever
       * pulling again, a limit above `maxParallelism` would exceed the size the buffer was created with.
       */
      private def rebalance(value: Out): Unit =
        parallelism = math.min(maxParallelism, math.max(1, balancingF(value, parallelism)))

      // invoked when a future that was still running in `onPush` has completed
      private val futureCB = getAsyncCallback[Holder[Out]] { holder =>
        holder.elem match {
          case Success(value) =>
            rebalance(value)
            pushNextIfPossible()
          case Failure(ex) =>
            holder.supervisionDirectiveFor(decider, ex) match {
              // fail fast as if supervision says so
              case Supervision.Stop => failStage(ex)
              case _ => pushNextIfPossible()
            }
        }
      }

      // parallelism still equals maxParallelism here, the limit can only shrink afterwards
      override def preStart(): Unit = buffer = BufferImpl(maxParallelism, materializer)

      override def onPull(): Unit = pushNextIfPossible()

      override def onPush(): Unit = {
        try {
          val future = f(grab(in))
          val holder = new Holder[Out](NotYetThere, futureCB)
          buffer.enqueue(holder)

          future.value match {
            case None => future.onComplete(holder)(akka.dispatch.ExecutionContexts.sameThreadExecutionContext)
            case Some(v) =>
              // #20217 the future is already here, optimization: avoid scheduling it on the dispatcher and
              // run the logic directly on this thread
              holder.setElem(v)
              v match {
                // this optimization also requires us to stop the stage to fail fast if the decider says so:
                case Failure(ex) if holder.supervisionDirectiveFor(decider, ex) == Supervision.Stop => failStage(ex)
                case Success(value) =>
                  // futureCB is never invoked on this path, so the limit has to be re-evaluated here;
                  // otherwise results that are available immediately could never change the parallelism
                  rebalance(value)
                  pushNextIfPossible()
                case _ => pushNextIfPossible()
              }
          }

        } catch {
          // this logic must only be executed if f throws, not if the future is failed
          case NonFatal(ex) => if (decider(ex) == Supervision.Stop) failStage(ex)
        }

        pullIfNeeded()
      }

      // elements still in the buffer are emitted first; completion then happens in pushNextIfPossible
      override def onUpstreamFinish(): Unit = if (buffer.isEmpty) completeStage()

      /**
       * Emits the element at the head of the buffer if its result is available and downstream has
       * demand. Failed results that supervision does not turn into a stage failure are dropped and
       * the next buffered element is tried.
       */
      @tailrec
      private def pushNextIfPossible(): Unit =
        if (buffer.isEmpty) {
          if (isClosed(in)) completeStage()
          else pullIfNeeded()
        } else if (buffer.peek().elem eq NotYetThere) pullIfNeeded() // ahead of line blocking to keep order
        else if (isAvailable(out)) {
          val holder = buffer.dequeue()
          holder.elem match {
            case Success(elem) =>
              push(out, elem)
              pullIfNeeded()

            case Failure(NonFatal(ex)) =>
              holder.supervisionDirectiveFor(decider, ex) match {
                // this could happen if we are looping in pushNextIfPossible and end up on a failed future before the
                // onComplete callback has run
                case Supervision.Stop => failStage(ex)
                case _ =>
                  // try next element
                  pushNextIfPossible()
              }
          }
        }

      // pulls only while fewer than `parallelism` elements are buffered and no pull is pending
      private def pullIfNeeded(): Unit =
        if (buffer.used < parallelism && !hasBeenPulled(in)) tryPull(in)

      setHandlers(in, out, this)
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ package akka.stream.alpakka.sqs.scaladsl
import akka._
import akka.stream._
import akka.stream.alpakka.sqs.SqsSourceSettings
import akka.stream.scaladsl.Source
import akka.stream.alpakka.sqs.impl.BalancingMapAsync
import akka.stream.scaladsl.{Flow, Source}
import software.amazon.awssdk.services.sqs.SqsAsyncClient
import software.amazon.awssdk.services.sqs.model.{Message, QueueAttributeName, ReceiveMessageRequest}
import software.amazon.awssdk.services.sqs.model._

import scala.collection.JavaConverters._
import scala.compat.java8.FutureConverters._
Expand Down Expand Up @@ -42,9 +43,20 @@ object SqsSource {
case Some(t) => requestBuilder.visibilityTimeout(t.toSeconds.toInt).build()
}
}
.mapAsync(settings.parallelRequests)(sqsClient.receiveMessage(_).toScala)
.via(resolveHandler(settings.parallelRequests))
.map(_.messages().asScala.toList)
.takeWhile(messages => !settings.closeOnEmptyReceive || messages.nonEmpty)
.mapConcat(identity)
.buffer(settings.maxBufferSize, OverflowStrategy.backpressure)

/**
 * Picks the flow that turns receive requests into responses.
 *
 * For a parallelism of exactly 1 a plain `mapAsyncUnordered` is used. For every other value the
 * requests go through `BalancingMapAsync`, whose balancing function returns 1 for a response
 * without messages and `parallelism` for a response that contains messages.
 */
private def resolveHandler(parallelism: Int)(implicit sqsClient: SqsAsyncClient) =
  parallelism match {
    case 1 =>
      Flow[ReceiveMessageRequest].mapAsyncUnordered(parallelism) { request =>
        sqsClient.receiveMessage(request).toScala
      }
    case _ =>
      BalancingMapAsync[ReceiveMessageRequest, ReceiveMessageResponse](
        maxParallelism = parallelism,
        f = request => sqsClient.receiveMessage(request).toScala,
        balancingF = (response, _) => if (response.messages().isEmpty) 1 else parallelism
      )
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/*
* Copyright (C) 2016-2019 Lightbend Inc. <http://www.lightbend.com>
*/

package akka.stream.alpakka.sqs.scaladsl

import java.util.concurrent.CompletableFuture

import akka.stream.alpakka.sqs.SqsSourceSettings
import akka.stream.testkit.scaladsl.TestSink
import org.mockito.ArgumentMatchers._
import org.mockito.Mockito.{atMost => atMostTimes, _}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.{FlatSpec, Matchers}
import org.scalatestplus.mockito.MockitoSugar.mock
import software.amazon.awssdk.services.sqs.SqsAsyncClient
import software.amazon.awssdk.services.sqs.model.{Message, ReceiveMessageRequest, ReceiveMessageResponse}

import scala.compat.java8.FutureConverters._
import scala.concurrent.Future
import scala.concurrent.duration._

/**
 * Exercises `SqsSource` with a Mockito mock of `SqsAsyncClient`; the responses of
 * `receiveMessage` are stubbed per test.
 */
class SqsSourceSpec extends FlatSpec with Matchers with DefaultTestContext {
// Ten messages with the bodies "message 1" to "message 10", used as the content of non-empty responses.
val defaultMessages = (1 to 10).map { i =>
Message.builder().body(s"message $i").build()
}

"SqsSource" should "send a request and unwrap the response" in {
implicit val sqsClient: SqsAsyncClient = mock[SqsAsyncClient]
// every receive call returns an already completed future holding all ten default messages
when(sqsClient.receiveMessage(any[ReceiveMessageRequest]))
.thenReturn(
CompletableFuture.completedFuture(
ReceiveMessageResponse
.builder()
.messages(defaultMessages: _*)
.build()
)
)

val probe = SqsSource(
"url",
SqsSourceSettings.Defaults.withMaxBufferSize(10)
).runWith(TestSink.probe[Message])

// each default message is passed to probe.requestNext, one call per message in order
defaultMessages.foreach(probe.requestNext)

/**
* Invocations:
* 1 - to initially fill the buffer
* 2 - asynchronous call after the buffer is filled
* 3 - buffer proxies pull when it's not full -> async stage provides the data and executes the next call
*/
verify(sqsClient, times(3)).receiveMessage(any[ReceiveMessageRequest])
}

it should "buffer messages and acquire them fast with slow sqs" in {
implicit val sqsClient: SqsAsyncClient = mock[SqsAsyncClient]
// delay applied to every stubbed response
val timeout = 1.second
// the source buffer is sized to hold this many batches of maxBatchSize messages
val bufferToBatchRatio = 5

// every receive call completes only after `timeout` and then yields the ten default messages
when(sqsClient.receiveMessage(any[ReceiveMessageRequest]))
.thenAnswer(new Answer[CompletableFuture[ReceiveMessageResponse]] {
def answer(invocation: InvocationOnMock) =
akka.pattern
.after(timeout, system.scheduler) {
Future.successful(
ReceiveMessageResponse
.builder()
.messages(defaultMessages: _*)
.build()
)
}(system.dispatcher)
.toJava
.toCompletableFuture
})

val probe = SqsSource(
"url",
SqsSourceSettings.Defaults.withMaxBufferSize(SqsSourceSettings.Defaults.maxBatchSize * bufferToBatchRatio)
).runWith(TestSink.probe[Message])

// wait (bufferToBatchRatio + 1) timeouts before requesting anything, giving the source time to fill its buffer
Thread.sleep(timeout.toMillis * (bufferToBatchRatio + 1))

// bufferToBatchRatio rounds over the default messages: each element is requested with a 10 ms
// duration, far below the 1 second response delay, and must equal the expected message
for {
i <- 1 to bufferToBatchRatio
message <- defaultMessages
} {
probe.requestNext(10.milliseconds) shouldEqual message
}
}

it should "enable throttling on emptyReceives and disable throttling when a new message arrives" in {
implicit val sqsClient: SqsAsyncClient = mock[SqsAsyncClient]
// the first 30 receive calls return data immediately
val firstWithDataCount = 30
// the following 15 calls return an empty response after `timeout`
val thenEmptyCount = 15
// NOTE(review): only used in the assertions below; withParallelRequests further down repeats the literal 10
val parallelism = 10
val timeout = 1.second
// number of receiveMessage invocations seen by the mock so far
var requestsCounter = 0
when(sqsClient.receiveMessage(any[ReceiveMessageRequest]))
.thenAnswer(new Answer[CompletableFuture[ReceiveMessageResponse]] {
def answer(invocation: InvocationOnMock) = {
requestsCounter += 1

// calls firstWithDataCount + 1 .. firstWithDataCount + thenEmptyCount: delayed empty response
if (requestsCounter > firstWithDataCount && requestsCounter <= firstWithDataCount + thenEmptyCount) {
akka.pattern
.after(timeout, system.scheduler) {
Future.successful(
ReceiveMessageResponse
.builder()
.messages(List.empty[Message]: _*)
.build()
)
}(system.dispatcher)
.toJava
.toCompletableFuture
} else {
// every other call: already completed response with the ten default messages
CompletableFuture.completedFuture(
ReceiveMessageResponse
.builder()
.messages(defaultMessages: _*)
.build()
)
}
}
})

val probe = SqsSource(
"url",
SqsSourceSettings.Defaults
.withMaxBufferSize(10)
.withParallelRequests(10)
.withWaitTime(timeout)
).runWith(TestSink.probe[Message])

// consume firstWithDataCount * 10 elements, i.e. the content of the first 30 data responses
(1 to firstWithDataCount * 10).foreach(_ => probe.requestNext(10.milliseconds))

// up to `parallelism` further requests may already have been issued on top of the 30 consumed ones
verify(sqsClient, atMostTimes(firstWithDataCount + parallelism)).receiveMessage(any[ReceiveMessageRequest])

// now the throttling kicks in
probe.request(1)
// (thenEmptyCount - parallelism + 1) * timeout = 6 seconds without any element
probe.expectNoMessage((thenEmptyCount - parallelism + 1) * timeout)
// during that time no request beyond the stubbed empty responses may have been made
verify(sqsClient, atMostTimes(firstWithDataCount + thenEmptyCount)).receiveMessage(any[ReceiveMessageRequest])

// now the throttling is off
probe.expectNext()

// data responses are immediate again, so 1000 further elements are each requested with a 10 ms duration
(1 to 1000).foreach(_ => probe.requestNext(10.milliseconds))
}
}

0 comments on commit 9c17a7c

Please sign in to comment.