-
Notifications
You must be signed in to change notification settings - Fork 645
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
AWS SQS source: reduce consumption on empty receives (#1743)
- Loading branch information
Showing
3 changed files
with
285 additions
and
3 deletions.
There are no files selected for viewing
120 changes: 120 additions & 0 deletions
120
sqs/src/main/scala/akka/stream/alpakka/sqs/impl/BalancingMapAsync.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
/* | ||
* Copyright (C) 2016-2019 Lightbend Inc. <http://www.lightbend.com> | ||
*/ | ||
|
||
package akka.stream.alpakka.sqs.impl | ||
|
||
import akka.annotation.InternalApi | ||
import akka.stream.ActorAttributes.SupervisionStrategy | ||
import akka.stream.Attributes.name | ||
import akka.stream._ | ||
import akka.stream.impl.fusing.MapAsync | ||
import akka.stream.impl.{Buffer => BufferImpl} | ||
import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler} | ||
|
||
import scala.annotation.tailrec | ||
import scala.concurrent.Future | ||
import scala.util.control.NonFatal | ||
import scala.util.{Failure, Success} | ||
|
||
@InternalApi private[akka] final case class BalancingMapAsync[In, Out]( | ||
maxParallelism: Int, | ||
f: In => Future[Out], | ||
balancingF: (Out, Int) ⇒ Int | ||
) extends GraphStage[FlowShape[In, Out]] { | ||
|
||
import MapAsync._ | ||
|
||
private val in = Inlet[In]("BalancingMapAsync.in") | ||
private val out = Outlet[Out]("BalancingMapAsync.out") | ||
|
||
override def initialAttributes = name("BalancingMapAsync") | ||
|
||
override val shape = FlowShape(in, out) | ||
|
||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = | ||
new GraphStageLogic(shape) with InHandler with OutHandler { | ||
|
||
lazy val decider = inheritedAttributes.mandatoryAttribute[SupervisionStrategy].decider | ||
var buffer: BufferImpl[Holder[Out]] = _ | ||
var parallelism = maxParallelism | ||
|
||
private val futureCB = getAsyncCallback[Holder[Out]]( | ||
holder => | ||
holder.elem match { | ||
case Success(value) => | ||
parallelism = balancingF(value, parallelism) | ||
pushNextIfPossible() | ||
case Failure(ex) => | ||
holder.supervisionDirectiveFor(decider, ex) match { | ||
// fail fast as if supervision says so | ||
case Supervision.Stop => failStage(ex) | ||
case _ => pushNextIfPossible() | ||
} | ||
} | ||
) | ||
|
||
override def preStart(): Unit = buffer = BufferImpl(parallelism, materializer) | ||
|
||
override def onPull(): Unit = pushNextIfPossible() | ||
|
||
override def onPush(): Unit = { | ||
try { | ||
val future = f(grab(in)) | ||
val holder = new Holder[Out](NotYetThere, futureCB) | ||
buffer.enqueue(holder) | ||
|
||
future.value match { | ||
case None => future.onComplete(holder)(akka.dispatch.ExecutionContexts.sameThreadExecutionContext) | ||
case Some(v) => | ||
// #20217 the future is already here, optimization: avoid scheduling it on the dispatcher and | ||
// run the logic directly on this thread | ||
holder.setElem(v) | ||
v match { | ||
// this optimization also requires us to stop the stage to fail fast if the decider says so: | ||
case Failure(ex) if holder.supervisionDirectiveFor(decider, ex) == Supervision.Stop => failStage(ex) | ||
case _ => pushNextIfPossible() | ||
} | ||
} | ||
|
||
} catch { | ||
// this logic must only be executed if f throws, not if the future is failed | ||
case NonFatal(ex) => if (decider(ex) == Supervision.Stop) failStage(ex) | ||
} | ||
|
||
pullIfNeeded() | ||
} | ||
|
||
override def onUpstreamFinish(): Unit = if (buffer.isEmpty) completeStage() | ||
|
||
@tailrec | ||
private def pushNextIfPossible(): Unit = | ||
if (buffer.isEmpty) { | ||
if (isClosed(in)) completeStage() | ||
else pullIfNeeded() | ||
} else if (buffer.peek().elem eq NotYetThere) pullIfNeeded() // ahead of line blocking to keep order | ||
else if (isAvailable(out)) { | ||
val holder = buffer.dequeue() | ||
holder.elem match { | ||
case Success(elem) => | ||
push(out, elem) | ||
pullIfNeeded() | ||
|
||
case Failure(NonFatal(ex)) => | ||
holder.supervisionDirectiveFor(decider, ex) match { | ||
// this could happen if we are looping in pushNextIfPossible and end up on a failed future before the | ||
// onComplete callback has run | ||
case Supervision.Stop => failStage(ex) | ||
case _ => | ||
// try next element | ||
pushNextIfPossible() | ||
} | ||
} | ||
} | ||
|
||
private def pullIfNeeded(): Unit = | ||
if (buffer.used < parallelism && !hasBeenPulled(in)) tryPull(in) | ||
|
||
setHandlers(in, out, this) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
150 changes: 150 additions & 0 deletions
150
sqs/src/test/scala/akka/stream/alpakka/sqs/scaladsl/SqsSourceSpec.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
/* | ||
* Copyright (C) 2016-2019 Lightbend Inc. <http://www.lightbend.com> | ||
*/ | ||
|
||
package akka.stream.alpakka.sqs.scaladsl | ||
|
||
import java.util.concurrent.CompletableFuture | ||
|
||
import akka.stream.alpakka.sqs.SqsSourceSettings | ||
import akka.stream.testkit.scaladsl.TestSink | ||
import org.mockito.ArgumentMatchers._ | ||
import org.mockito.Mockito.{atMost => atMostTimes, _} | ||
import org.mockito.invocation.InvocationOnMock | ||
import org.mockito.stubbing.Answer | ||
import org.scalatest.{FlatSpec, Matchers} | ||
import org.scalatestplus.mockito.MockitoSugar.mock | ||
import software.amazon.awssdk.services.sqs.SqsAsyncClient | ||
import software.amazon.awssdk.services.sqs.model.{Message, ReceiveMessageRequest, ReceiveMessageResponse} | ||
|
||
import scala.compat.java8.FutureConverters._ | ||
import scala.concurrent.Future | ||
import scala.concurrent.duration._ | ||
|
||
class SqsSourceSpec extends FlatSpec with Matchers with DefaultTestContext { | ||
val defaultMessages = (1 to 10).map { i => | ||
Message.builder().body(s"message $i").build() | ||
} | ||
|
||
"SqsSource" should "send a request and unwrap the response" in { | ||
implicit val sqsClient: SqsAsyncClient = mock[SqsAsyncClient] | ||
when(sqsClient.receiveMessage(any[ReceiveMessageRequest])) | ||
.thenReturn( | ||
CompletableFuture.completedFuture( | ||
ReceiveMessageResponse | ||
.builder() | ||
.messages(defaultMessages: _*) | ||
.build() | ||
) | ||
) | ||
|
||
val probe = SqsSource( | ||
"url", | ||
SqsSourceSettings.Defaults.withMaxBufferSize(10) | ||
).runWith(TestSink.probe[Message]) | ||
|
||
defaultMessages.foreach(probe.requestNext) | ||
|
||
/** | ||
* Invocations: | ||
* 1 - to initially fill the buffer | ||
* 2 - asynchronous call after the buffer is filled | ||
* 3 - buffer proxies pull when it's not full -> async stage provides the data and executes the next call | ||
*/ | ||
verify(sqsClient, times(3)).receiveMessage(any[ReceiveMessageRequest]) | ||
} | ||
|
||
it should "buffer messages and acquire them fast with slow sqs" in { | ||
implicit val sqsClient: SqsAsyncClient = mock[SqsAsyncClient] | ||
val timeout = 1.second | ||
val bufferToBatchRatio = 5 | ||
|
||
when(sqsClient.receiveMessage(any[ReceiveMessageRequest])) | ||
.thenAnswer(new Answer[CompletableFuture[ReceiveMessageResponse]] { | ||
def answer(invocation: InvocationOnMock) = | ||
akka.pattern | ||
.after(timeout, system.scheduler) { | ||
Future.successful( | ||
ReceiveMessageResponse | ||
.builder() | ||
.messages(defaultMessages: _*) | ||
.build() | ||
) | ||
}(system.dispatcher) | ||
.toJava | ||
.toCompletableFuture | ||
}) | ||
|
||
val probe = SqsSource( | ||
"url", | ||
SqsSourceSettings.Defaults.withMaxBufferSize(SqsSourceSettings.Defaults.maxBatchSize * bufferToBatchRatio) | ||
).runWith(TestSink.probe[Message]) | ||
|
||
Thread.sleep(timeout.toMillis * (bufferToBatchRatio + 1)) | ||
|
||
for { | ||
i <- 1 to bufferToBatchRatio | ||
message <- defaultMessages | ||
} { | ||
probe.requestNext(10.milliseconds) shouldEqual message | ||
} | ||
} | ||
|
||
it should "enable throttling on emptyReceives and disable throttling when a new message arrives" in { | ||
implicit val sqsClient: SqsAsyncClient = mock[SqsAsyncClient] | ||
val firstWithDataCount = 30 | ||
val thenEmptyCount = 15 | ||
val parallelism = 10 | ||
val timeout = 1.second | ||
var requestsCounter = 0 | ||
when(sqsClient.receiveMessage(any[ReceiveMessageRequest])) | ||
.thenAnswer(new Answer[CompletableFuture[ReceiveMessageResponse]] { | ||
def answer(invocation: InvocationOnMock) = { | ||
requestsCounter += 1 | ||
|
||
if (requestsCounter > firstWithDataCount && requestsCounter <= firstWithDataCount + thenEmptyCount) { | ||
akka.pattern | ||
.after(timeout, system.scheduler) { | ||
Future.successful( | ||
ReceiveMessageResponse | ||
.builder() | ||
.messages(List.empty[Message]: _*) | ||
.build() | ||
) | ||
}(system.dispatcher) | ||
.toJava | ||
.toCompletableFuture | ||
} else { | ||
CompletableFuture.completedFuture( | ||
ReceiveMessageResponse | ||
.builder() | ||
.messages(defaultMessages: _*) | ||
.build() | ||
) | ||
} | ||
} | ||
}) | ||
|
||
val probe = SqsSource( | ||
"url", | ||
SqsSourceSettings.Defaults | ||
.withMaxBufferSize(10) | ||
.withParallelRequests(10) | ||
.withWaitTime(timeout) | ||
).runWith(TestSink.probe[Message]) | ||
|
||
(1 to firstWithDataCount * 10).foreach(_ => probe.requestNext(10.milliseconds)) | ||
|
||
verify(sqsClient, atMostTimes(firstWithDataCount + parallelism)).receiveMessage(any[ReceiveMessageRequest]) | ||
|
||
// now the throttling kicks in | ||
probe.request(1) | ||
probe.expectNoMessage((thenEmptyCount - parallelism + 1) * timeout) | ||
verify(sqsClient, atMostTimes(firstWithDataCount + thenEmptyCount)).receiveMessage(any[ReceiveMessageRequest]) | ||
|
||
// now the throttling is off | ||
probe.expectNext() | ||
|
||
(1 to 1000).foreach(_ => probe.requestNext(10.milliseconds)) | ||
} | ||
} |