From 7f7d142fef2da9e0b4bcdbacd06a51ead29c4f5e Mon Sep 17 00:00:00 2001
From: Gaël Ferrachat
Date: Tue, 23 May 2023 15:02:34 +0200
Subject: [PATCH 1/2] Add getObjectByRanges to S3 API

---
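Notes:
    A minimal usage sketch of the new API (not part of the diff: bucket, key
    and sizes are made-up values; an implicit ActorSystem carrying the usual
    alpakka.s3 configuration is assumed):

        import akka.actor.ActorSystem
        import akka.stream.alpakka.s3.ObjectMetadata
        import akka.stream.alpakka.s3.scaladsl.S3
        import akka.stream.scaladsl.{Keep, Sink, Source}
        import akka.util.ByteString

        import scala.concurrent.Future

        object ByteRangeDownload extends App {
          implicit val system: ActorSystem = ActorSystem("example")

          // Download s3://my-bucket/my-key in 8 MiB ranges, 4 ranges in flight at a time.
          val download: Source[ByteString, Future[ObjectMetadata]] =
            S3.getObjectByRanges("my-bucket", "my-key", rangeSize = 8 * 1024 * 1024, parallelism = 4)

          // Bytes arrive in object order even though the ranges are fetched in parallel.
          val (metadata, bytes) = download.toMat(Sink.fold(ByteString.empty)(_ ++ _))(Keep.both).run()
        }
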
 .../alpakka/s3/impl/MergeOrderedN.scala       | 128 ++++++++++++++++++
 .../stream/alpakka/s3/impl/S3Stream.scala     |  72 ++++++++++
 .../akka/stream/alpakka/s3/scaladsl/S3.scala  |  49 +++++++
 .../alpakka/s3/scaladsl/S3WireMockBase.scala  |  14 ++
 .../scala/docs/scaladsl/S3SourceSpec.scala    |  38 ++++++
 5 files changed, 301 insertions(+)
 create mode 100644 s3/src/main/scala/akka/stream/alpakka/s3/impl/MergeOrderedN.scala

diff --git a/s3/src/main/scala/akka/stream/alpakka/s3/impl/MergeOrderedN.scala b/s3/src/main/scala/akka/stream/alpakka/s3/impl/MergeOrderedN.scala
new file mode 100644
index 0000000000..cab6c09665
--- /dev/null
+++ b/s3/src/main/scala/akka/stream/alpakka/s3/impl/MergeOrderedN.scala
@@ -0,0 +1,128 @@
+package akka.stream.alpakka.s3.impl
+
+import akka.annotation.InternalApi
+import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler}
+import akka.stream.{Attributes, Inlet, Outlet, UniformFanInShape}
+
+import scala.collection.{immutable, mutable}
+
+@InternalApi private[impl] object MergeOrderedN {
+  /** @see [[MergeOrderedN]] */
+  def apply[T](inputPorts: Int, breadth: Int) =
+    new MergeOrderedN[T](inputPorts, breadth)
+}
+
+/**
+ * Takes multiple streams (in ascending order of input ports) and pushes the elements of each stream only after
+ * all elements from the previous stream(s) have been pushed downstream.
+ *
+ * The `breadth` parameter controls how many upstreams are pulled in parallel.
+ * That means elements might be received in any order, but are buffered (if necessary) until their turn comes.
+ *
+ * '''Emits when''' the next element from upstream (in ascending order of input ports) is available
+ *
+ * '''Backpressures when''' downstream backpressures
+ *
+ * '''Completes when''' all upstreams complete and there are no more buffered elements
+ *
+ * '''Cancels when''' downstream cancels
+ */
+@InternalApi private[impl] final class MergeOrderedN[T](val inputPorts: Int, val breadth: Int) extends GraphStage[UniformFanInShape[T, T]] {
+  require(inputPorts > 1, "input ports must be > 1")
+  require(breadth > 0, "breadth must be > 0")
+
+  val in: immutable.IndexedSeq[Inlet[T]] = Vector.tabulate(inputPorts)(i => Inlet[T]("MergeOrderedN.in" + i))
+  val out: Outlet[T] = Outlet[T]("MergeOrderedN.out")
+  override val shape: UniformFanInShape[T, T] = UniformFanInShape(out, in: _*)
+
+  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with OutHandler {
+    private val bufferByInPort = mutable.Map.empty[Int, mutable.Queue[T]] // Queues must never be empty; once drained, the entry is removed
+    private var currentHeadInPortIdx = 0
+    private var currentLastInPortIdx = 0
+    private val overallLastInPortIdx = inputPorts - 1
+
+    setHandler(out, this)
+
+    in.zipWithIndex.foreach { case (inPort, idx) =>
+      setHandler(inPort, new InHandler {
+        override def onPush(): Unit = {
+          val elem = grab(inPort)
+          if (currentHeadInPortIdx != idx || !isAvailable(out)) {
+            bufferByInPort.updateWith(idx) {
+              case Some(inPortBuffer) =>
+                Some(inPortBuffer.enqueue(elem))
+              case None =>
+                val inPortBuffer = mutable.Queue.empty[T]
+                inPortBuffer.enqueue(elem)
+                Some(inPortBuffer)
+            }
+          } else {
+            pushUsingQueue(Some(elem))
+          }
+          tryPull(inPort)
+        }
+
+        override def onUpstreamFinish(): Unit = {
+          if (canCompleteStage)
+            completeStage()
+          else if (canSlideFrame)
+            slideFrame()
+        }
+      })
+    }
+
+    override def onPull(): Unit = pushUsingQueue()
+
+    private def pushUsingQueue(next: Option[T] = None): Unit = {
+      val maybeBuffer = bufferByInPort.get(currentHeadInPortIdx)
+      if (maybeBuffer.forall(_.isEmpty) && next.nonEmpty) {
+        push(out, next.get)
+      } else if (maybeBuffer.exists(_.nonEmpty) && next.nonEmpty) {
+        maybeBuffer.get.enqueue(next.get)
+        push(out, maybeBuffer.get.dequeue())
+      } else if (maybeBuffer.exists(_.nonEmpty) && next.isEmpty) {
+        push(out, maybeBuffer.get.dequeue())
+      } else {
+        // Both empty
+      }
+
+      if (maybeBuffer.exists(_.isEmpty))
+        bufferByInPort.remove(currentHeadInPortIdx)
+
+      if (canCompleteStage)
+        completeStage()
+      else if (canSlideFrame)
+        slideFrame()
+    }
+
+    override def preStart(): Unit = {
+      if (breadth >= inputPorts) {
+        in.foreach(pull)
+        currentLastInPortIdx = overallLastInPortIdx
+      } else {
+        in.slice(0, breadth).foreach(pull)
+        currentLastInPortIdx = breadth - 1
+      }
+    }
+
+    private def canSlideFrame: Boolean =
+      (!bufferByInPort.contains(currentHeadInPortIdx) || bufferByInPort(currentHeadInPortIdx).isEmpty) &&
+      isClosed(in(currentHeadInPortIdx))
+
+    private def canCompleteStage: Boolean =
+      canSlideFrame && currentHeadInPortIdx == overallLastInPortIdx
+
+    private def slideFrame(): Unit = {
+      currentHeadInPortIdx += 1
+
+      if (isAvailable(out))
+        pushUsingQueue()
+
+      if (currentLastInPortIdx != overallLastInPortIdx)
+        currentLastInPortIdx += 1
+
+      if (!hasBeenPulled(in(currentLastInPortIdx)))
+        tryPull(in(currentLastInPortIdx))
+    }
+  }
+}
diff --git a/s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala b/s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala
index e4b2dedf2a..9b5c2f98bc 100644
--- a/s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala
+++ b/s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala
@@ -12,6 +12,7 @@ import akka.annotation.InternalApi
 import akka.dispatch.ExecutionContexts
 import akka.http.scaladsl.Http.OutgoingConnection
 import akka.http.scaladsl.model.StatusCodes.{NoContent, NotFound, OK}
+import akka.http.scaladsl.model.headers.ByteRange.FromOffset
 import akka.http.scaladsl.model.headers._
 import akka.http.scaladsl.model.{headers => http, _}
 import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings}
@@ -20,6 +21,7 @@ import akka.http.scaladsl.{ClientTransport, Http}
 import akka.stream.alpakka.s3.BucketAccess.{AccessDenied, AccessGranted, NotExists}
 import akka.stream.alpakka.s3._
 import akka.stream.alpakka.s3.impl.auth.{CredentialScope, Signer, SigningKey}
+import akka.stream.alpakka.s3.scaladsl.S3
 import akka.stream.scaladsl.{Flow, Keep, RetryFlow, RunnableGraph, Sink, Source, Tcp}
 import akka.stream.{Attributes, Materializer}
 import akka.util.ByteString
@@ -27,6 +29,7 @@ import akka.{Done, NotUsed}
 import software.amazon.awssdk.regions.Region
 
 import scala.collection.immutable
+import scala.collection.mutable.ListBuffer
 import scala.concurrent.{Future, Promise}
 import scala.util.{Failure, Success, Try}
 
@@ -37,6 +40,9 @@
     BucketAndKey.validateObjectKey(key, conf)
     this
   }
+
+  def mkString: String =
+    s"s3://$bucket/$key"
 }
 
 /** Internal Api */
@@ -165,6 +171,7 @@
   import Marshalling._
 
   val MinChunkSize: Int = 5 * 1024 * 1024 //in bytes
+  val DefaultByteRangeSize: Long = 8 * 1024 * 1024
 
   val atLeastOneByteString: Flow[ByteString, ByteString, NotUsed] =
     Flow[ByteString].orElse(Source.single(ByteString.empty))
@@ -232,6 +239,71 @@ import scala.util.{Failure, Success, Try}
       .mapMaterializedValue(_.flatMap(identity)(ExecutionContexts.parasitic))
   }
 
+  def getObjectByRanges(
+      s3Location: S3Location,
+      versionId: Option[String],
+      s3Headers: S3Headers,
+      rangeSize: Long = DefaultByteRangeSize,
+      parallelism: Int = 4
+  ): Source[ByteString, Future[ObjectMetadata]] = {
+    Source.fromMaterializer { (_, _) =>
+      val objectMetadataMat = Promise[ObjectMetadata]()
+      getObjectMetadata(s3Location.bucket, s3Location.key, versionId, s3Headers)
+        .flatMapConcat {
+          case Some(s3Meta) if s3Meta.contentLength == 0 =>
+            objectMetadataMat.success(s3Meta)
+            Source.empty[ByteString]
+          case Some(s3Meta) =>
+            objectMetadataMat.success(s3Meta)
+            val byteRanges = computeByteRanges(s3Meta.contentLength, rangeSize)
+            if (byteRanges.size <= 1) {
+              getObject(s3Location, None, versionId, s3Headers)
+            } else {
+              val rangeSources = prepareRangeSources(s3Location, versionId, s3Headers, byteRanges)
+              Source.combine[ByteString, ByteString](
+                rangeSources.head,
+                rangeSources(1),
+                rangeSources.drop(2): _*
+              )(p => MergeOrderedN(p, parallelism))
+            }
+          case None =>
+            Source.failed(throw new NoSuchElementException(s"Object does not exist at location [${s3Location.mkString}]"))
+        }
+        .mapError {
+          case e: Throwable =>
+            objectMetadataMat.tryFailure(e)
+            e
+        }
+        .mapMaterializedValue(_ => objectMetadataMat.future)
+    }
+      .mapMaterializedValue(_.flatMap(identity)(ExecutionContexts.parasitic))
+  }
+
+  private def computeByteRanges(contentLength: Long, rangeSize: Long): Seq[ByteRange] = {
+    require(contentLength >= 0, s"contentLength ($contentLength) must be >= 0")
+    require(rangeSize > 0, s"rangeSize ($rangeSize) must be > 0")
+    if (contentLength <= rangeSize)
+      Nil
+    else {
+      val ranges = ListBuffer[ByteRange]()
+      for (i <- 0L until contentLength by rangeSize) {
+        if ((i + rangeSize) >= contentLength)
+          ranges += FromOffset(i)
+        else
+          ranges += ByteRange(i, i + rangeSize - 1)
+      }
+      ranges.result()
+    }
+  }
+
+  private def prepareRangeSources(
+      s3Location: S3Location,
+      versionId: Option[String],
+      s3Headers: S3Headers,
+      byteRanges: Seq[ByteRange]
+  ): Seq[Source[ByteString, Future[ObjectMetadata]]] =
+    byteRanges.map(br => getObject(s3Location, Some(br), versionId, s3Headers))
+
   /**
    * An ADT that represents the current state of pagination
    */
diff --git a/s3/src/main/scala/akka/stream/alpakka/s3/scaladsl/S3.scala b/s3/src/main/scala/akka/stream/alpakka/s3/scaladsl/S3.scala
index cacb6cbe6b..4b88b3b094 100644
--- a/s3/src/main/scala/akka/stream/alpakka/s3/scaladsl/S3.scala
+++ b/s3/src/main/scala/akka/stream/alpakka/s3/scaladsl/S3.scala
@@ -263,6 +263,55 @@
   ): Source[ByteString, Future[ObjectMetadata]] =
     S3Stream.getObject(S3Location(bucket, key), range, versionId, s3Headers)
 
+  /**
+   * Gets an S3 object using `Byte-Range Fetches`
+   *
+   * @param bucket the s3 bucket name
+   * @param key the s3 object key
+   * @param sse [optional] the server side encryption used on upload
+   * @param rangeSize size of each range to request
+   * @param parallelism number of ranges to request in parallel
+   *
+   * @return A [[akka.stream.scaladsl.Source]] containing the object's data as a [[akka.util.ByteString]] along with a materialized value containing the
+   *         [[akka.stream.alpakka.s3.ObjectMetadata]]
+   */
+  def getObjectByRanges(
+      bucket: String,
+      key: String,
+      versionId: Option[String] = None,
+      sse: Option[ServerSideEncryption] = None,
+      rangeSize: Long = MinChunkSize,
+      parallelism: Int = 4
+  ): Source[ByteString, Future[ObjectMetadata]] =
+    S3Stream.getObjectByRanges(S3Location(bucket, key),
+                               versionId,
+                               S3Headers.empty.withOptionalServerSideEncryption(sse),
+                               rangeSize,
+                               parallelism
+    )
+
+  /**
+   * Gets an S3 object using `Byte-Range Fetches`
+   *
+   * @param bucket the s3 bucket name
+   * @param key the s3 object key
+   * @param s3Headers any headers you want to add
+   * @param rangeSize size of each range to request
+   * @param parallelism number of ranges to request in parallel
+   *
+   * @return A [[akka.stream.scaladsl.Source]] containing the object's data as a [[akka.util.ByteString]] along with a materialized value containing the
+   *         [[akka.stream.alpakka.s3.ObjectMetadata]]
+   */
+  def getObjectByRanges(
+      bucket: String,
+      key: String,
+      versionId: Option[String],
+      s3Headers: S3Headers,
+      rangeSize: Long,
+      parallelism: Int
+  ): Source[ByteString, Future[ObjectMetadata]] =
+    S3Stream.getObjectByRanges(S3Location(bucket, key), versionId, s3Headers, rangeSize, parallelism)
+
   /**
    * Will return a list containing all of the buckets for the current AWS account
    *
diff --git a/s3/src/test/scala/akka/stream/alpakka/s3/scaladsl/S3WireMockBase.scala b/s3/src/test/scala/akka/stream/alpakka/s3/scaladsl/S3WireMockBase.scala
index 67f58f9fa4..242611251a 100644
--- a/s3/src/test/scala/akka/stream/alpakka/s3/scaladsl/S3WireMockBase.scala
+++ b/s3/src/test/scala/akka/stream/alpakka/s3/scaladsl/S3WireMockBase.scala
@@ -5,6 +5,7 @@
 package akka.stream.alpakka.s3.scaladsl
 
 import akka.actor.ActorSystem
+import akka.http.scaladsl.model.headers.ByteRange
 import akka.stream.alpakka.s3.S3Settings
 import akka.stream.alpakka.s3.headers.ServerSideEncryption
 import akka.stream.alpakka.s3.impl.S3Stream
@@ -177,6 +178,19 @@ abstract class S3WireMockBase(_system: ActorSystem, val _wireMockServer: WireMoc
       )
     )
 
+  def mockRangedDownload(byteRange: ByteRange, range: String): Unit =
+    mock
+      .register(
+        get(urlEqualTo(s"/$bucketKey"))
+          .withHeader("Range", new EqualToPattern(s"bytes=$byteRange"))
+          .willReturn(
+            aResponse()
+              .withStatus(200)
+              .withHeader("ETag", """"fba9dede5f27731c9771645a39863328"""")
+              .withBody(range)
+          )
+      )
+
   def mockRangedDownload(): Unit =
     mock
       .register(
diff --git a/s3/src/test/scala/docs/scaladsl/S3SourceSpec.scala b/s3/src/test/scala/docs/scaladsl/S3SourceSpec.scala
index 1782746099..b72b60adf1 100644
--- a/s3/src/test/scala/docs/scaladsl/S3SourceSpec.scala
+++ b/s3/src/test/scala/docs/scaladsl/S3SourceSpec.scala
@@ -5,6 +5,7 @@
 package docs.scaladsl
 
 import akka.http.scaladsl.model.headers.ByteRange
+import akka.http.scaladsl.model.headers.ByteRange.FromOffset
 import akka.http.scaladsl.model.{ContentType, ContentTypes, HttpEntity, HttpResponse, IllegalUriException}
 import akka.stream.Attributes
 import akka.stream.alpakka.s3.BucketAccess.{AccessDenied, AccessGranted, NotExists}
@@ -26,6 +27,43 @@ class S3SourceSpec extends S3WireMockBase with S3ClientIntegrationSpec {
 
   override protected def afterEach(): Unit = mock.removeMappings()
 
+  "S3Source" should "download a stream of bytes by ranges from S3" in {
+
+    val bodyBytes = ByteString(body)
+    val bodyRanges = bodyBytes.grouped(10).toList
+    val rangeHeaders = bodyRanges.zipWithIndex.map {
+      case (_, idx) if idx != bodyRanges.size - 1 =>
+        ByteRange(idx * 10, (idx * 10) + 10 - 1)
+      case (_, idx) =>
+        FromOffset(idx * 10)
+    }
+    val rangesWithHeaders = bodyRanges.zip(rangeHeaders)
+
+    mockHead(bodyBytes.size)
+    rangesWithHeaders.foreach { case (bs, br) => mockRangedDownload(br, bs.utf8String) }
+
+    val s3Source: Source[ByteString, Future[ObjectMetadata]] =
+      S3.getObjectByRanges(bucket, bucketKey, rangeSize = 10L)
+
+    val (metadataFuture, dataFuture) =
+      s3Source.toMat(Sink.reduce[ByteString](_ ++ _))(Keep.both).run()
+
+    val data = dataFuture.futureValue
+    val metadata = metadataFuture.futureValue
+
+    data.utf8String shouldBe body
+
+    HttpResponse(
+      entity = HttpEntity(
+        metadata.contentType
+          .flatMap(ContentType.parse(_).toOption)
+          .getOrElse(ContentTypes.`application/octet-stream`),
+        metadata.contentLength,
+        s3Source
+      )
+    )
+  }
+
   "S3Source" should "download a stream of bytes from S3" in {
 
     mockDownload()

From f8bb980bfa259fc67ee93c88d3dea5f3e0f7d565 Mon Sep 17 00:00:00 2001
From: Gaël Ferrachat
Date: Mon, 29 May 2023 09:55:43 +0200
Subject: [PATCH 2/2] Change fetching ranges strategy

---
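Notes:
    The reordering idea behind this change, sketched in isolation (not part of
    the diff; names and values are made up, with plain strings standing in for
    the per-range S3 downloads): every range source is terminated by an end
    marker, ranges are fetched concurrently with flatMapMerge, and a stateful
    stage emits the current range directly while buffering later ranges until
    their turn comes.

        import akka.actor.ActorSystem
        import akka.stream.scaladsl.{Sink, Source}

        import scala.collection.mutable

        object OrderedMergeSketch extends App {
          implicit val system: ActorSystem = ActorSystem("sketch")

          // Stand-ins for the per-range downloads; with parallelism 3 their
          // elements may arrive interleaved.
          val ranges = Vector(Source(List("a1", "a2")), Source(List("b1")), Source(List("c1", "c2")))

          // None plays the role of the end marker (the patch uses a ByteString("$END$") sentinel).
          val merged = Source(ranges.indices.toVector)
            .flatMapMerge(3, i => ranges(i).map(Option(_)).concat(Source.single(None)).map(_ -> i.toLong))
            .statefulMapConcat { () =>
              var current = 0L // index of the range allowed to emit
              var done = Set.empty[Long] // ranges whose end marker has arrived
              val buffered = mutable.Map.empty[Long, mutable.Queue[String]]

              // Flush buffers of ranges that have become current.
              def drain(): List[String] = {
                val out = List.newBuilder[String]
                while (done(current) || buffered.contains(current)) {
                  out ++= buffered.remove(current).toList.flatten
                  if (done(current)) current += 1
                  else return out.result() // current range still running, wait for more
                }
                out.result()
              }

              { case (elem, idx) =>
                elem match {
                  case None => done += idx; drain() // end marker for range idx
                  case Some(s) if idx == current => s :: Nil // in order: emit immediately
                  case Some(s) => buffered.getOrElseUpdate(idx, mutable.Queue.empty).enqueue(s); Nil
                }
              }
            }

          merged.runWith(Sink.foreach(println)) // always prints a1 a2 b1 c1 c2
        }
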
 .../alpakka/s3/impl/MergeOrderedN.scala       | 128 ------------------
 .../stream/alpakka/s3/impl/S3Stream.scala     | 106 +++++++++++----
 2 files changed, 82 insertions(+), 152 deletions(-)
 delete mode 100644 s3/src/main/scala/akka/stream/alpakka/s3/impl/MergeOrderedN.scala

diff --git a/s3/src/main/scala/akka/stream/alpakka/s3/impl/MergeOrderedN.scala b/s3/src/main/scala/akka/stream/alpakka/s3/impl/MergeOrderedN.scala
deleted file mode 100644
index cab6c09665..0000000000
--- a/s3/src/main/scala/akka/stream/alpakka/s3/impl/MergeOrderedN.scala
+++ /dev/null
@@ -1,128 +0,0 @@
-package akka.stream.alpakka.s3.impl
-
-import akka.annotation.InternalApi
-import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler}
-import akka.stream.{Attributes, Inlet, Outlet, UniformFanInShape}
-
-import scala.collection.{immutable, mutable}
-
-@InternalApi private[impl] object MergeOrderedN {
-  /** @see [[MergeOrderedN]] */
-  def apply[T](inputPorts: Int, breadth: Int) =
-    new MergeOrderedN[T](inputPorts, breadth)
-}
-
-/**
- * Takes multiple streams (in ascending order of input ports) and pushes the elements of each stream only after
- * all elements from the previous stream(s) have been pushed downstream.
- *
- * The `breadth` parameter controls how many upstreams are pulled in parallel.
- * That means elements might be received in any order, but are buffered (if necessary) until their turn comes.
- *
- * '''Emits when''' the next element from upstream (in ascending order of input ports) is available
- *
- * '''Backpressures when''' downstream backpressures
- *
- * '''Completes when''' all upstreams complete and there are no more buffered elements
- *
- * '''Cancels when''' downstream cancels
- */
-@InternalApi private[impl] final class MergeOrderedN[T](val inputPorts: Int, val breadth: Int) extends GraphStage[UniformFanInShape[T, T]] {
-  require(inputPorts > 1, "input ports must be > 1")
-  require(breadth > 0, "breadth must be > 0")
-
-  val in: immutable.IndexedSeq[Inlet[T]] = Vector.tabulate(inputPorts)(i => Inlet[T]("MergeOrderedN.in" + i))
-  val out: Outlet[T] = Outlet[T]("MergeOrderedN.out")
-  override val shape: UniformFanInShape[T, T] = UniformFanInShape(out, in: _*)
-
-  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with OutHandler {
-    private val bufferByInPort = mutable.Map.empty[Int, mutable.Queue[T]] // Queues must never be empty; once drained, the entry is removed
-    private var currentHeadInPortIdx = 0
-    private var currentLastInPortIdx = 0
-    private val overallLastInPortIdx = inputPorts - 1
-
-    setHandler(out, this)
-
-    in.zipWithIndex.foreach { case (inPort, idx) =>
-      setHandler(inPort, new InHandler {
-        override def onPush(): Unit = {
-          val elem = grab(inPort)
-          if (currentHeadInPortIdx != idx || !isAvailable(out)) {
-            bufferByInPort.updateWith(idx) {
-              case Some(inPortBuffer) =>
-                Some(inPortBuffer.enqueue(elem))
-              case None =>
-                val inPortBuffer = mutable.Queue.empty[T]
-                inPortBuffer.enqueue(elem)
-                Some(inPortBuffer)
-            }
-          } else {
-            pushUsingQueue(Some(elem))
-          }
-          tryPull(inPort)
-        }
-
-        override def onUpstreamFinish(): Unit = {
-          if (canCompleteStage)
-            completeStage()
-          else if (canSlideFrame)
-            slideFrame()
-        }
-      })
-    }
-
-    override def onPull(): Unit = pushUsingQueue()
-
-    private def pushUsingQueue(next: Option[T] = None): Unit = {
-      val maybeBuffer = bufferByInPort.get(currentHeadInPortIdx)
-      if (maybeBuffer.forall(_.isEmpty) && next.nonEmpty) {
-        push(out, next.get)
-      } else if (maybeBuffer.exists(_.nonEmpty) && next.nonEmpty) {
-        maybeBuffer.get.enqueue(next.get)
-        push(out, maybeBuffer.get.dequeue())
-      } else if (maybeBuffer.exists(_.nonEmpty) && next.isEmpty) {
-        push(out, maybeBuffer.get.dequeue())
-      } else {
-        // Both empty
-      }
-
-      if (maybeBuffer.exists(_.isEmpty))
-        bufferByInPort.remove(currentHeadInPortIdx)
-
-      if (canCompleteStage)
-        completeStage()
-      else if (canSlideFrame)
-        slideFrame()
-    }
-
-    override def preStart(): Unit = {
-      if (breadth >= inputPorts) {
-        in.foreach(pull)
-        currentLastInPortIdx = overallLastInPortIdx
-      } else {
-        in.slice(0, breadth).foreach(pull)
-        currentLastInPortIdx = breadth - 1
-      }
-    }
-
-    private def canSlideFrame: Boolean =
-      (!bufferByInPort.contains(currentHeadInPortIdx) || bufferByInPort(currentHeadInPortIdx).isEmpty) &&
-      isClosed(in(currentHeadInPortIdx))
-
-    private def canCompleteStage: Boolean =
-      canSlideFrame && currentHeadInPortIdx == overallLastInPortIdx
-
-    private def slideFrame(): Unit = {
-      currentHeadInPortIdx += 1
-
-      if (isAvailable(out))
-        pushUsingQueue()
-
-      if (currentLastInPortIdx != overallLastInPortIdx)
-        currentLastInPortIdx += 1
-
-      if (!hasBeenPulled(in(currentLastInPortIdx)))
-        tryPull(in(currentLastInPortIdx))
-    }
-  }
-}
diff --git a/s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala b/s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala
index 9b5c2f98bc..8bfe68479c 100644
--- a/s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala
+++ b/s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala
@@ -4,9 +4,6 @@
 
 package akka.stream.alpakka.s3.impl
 
-import java.net.InetSocketAddress
-import java.time.{Instant, ZoneOffset, ZonedDateTime}
-import scala.annotation.nowarn
 import akka.actor.ActorSystem
 import akka.annotation.InternalApi
 import akka.dispatch.ExecutionContexts
@@ -21,15 +18,17 @@ import akka.http.scaladsl.{ClientTransport, Http}
 import akka.stream.alpakka.s3.BucketAccess.{AccessDenied, AccessGranted, NotExists}
 import akka.stream.alpakka.s3._
 import akka.stream.alpakka.s3.impl.auth.{CredentialScope, Signer, SigningKey}
-import akka.stream.alpakka.s3.scaladsl.S3
 import akka.stream.scaladsl.{Flow, Keep, RetryFlow, RunnableGraph, Sink, Source, Tcp}
 import akka.stream.{Attributes, Materializer}
 import akka.util.ByteString
 import akka.{Done, NotUsed}
 import software.amazon.awssdk.regions.Region
 
-import scala.collection.immutable
+import java.net.InetSocketAddress
+import java.time.{Instant, ZoneOffset, ZonedDateTime}
+import scala.annotation.{nowarn, tailrec}
 import scala.collection.mutable.ListBuffer
+import scala.collection.{immutable, mutable}
 import scala.concurrent.{Future, Promise}
 import scala.util.{Failure, Success, Try}
 
@@ -175,6 +174,8 @@ import scala.util.{Failure, Success, Try}
   val atLeastOneByteString: Flow[ByteString, ByteString, NotUsed] =
     Flow[ByteString].orElse(Source.single(ByteString.empty))
 
+  private val RangeEndMarker = "$END$"
+
   // def because tokens can expire
   private def signingKey(implicit settings: S3Settings) = {
     val requestDate = ZonedDateTime.now(ZoneOffset.UTC)
@@ -255,19 +256,11 @@
           Source.empty[ByteString]
         case Some(s3Meta) =>
           objectMetadataMat.success(s3Meta)
-          val byteRanges = computeByteRanges(s3Meta.contentLength, rangeSize)
-          if (byteRanges.size <= 1) {
-            getObject(s3Location, None, versionId, s3Headers)
-          } else {
-            val rangeSources = prepareRangeSources(s3Location, versionId, s3Headers, byteRanges)
-            Source.combine[ByteString, ByteString](
-              rangeSources.head,
-              rangeSources(1),
-              rangeSources.drop(2): _*
-            )(p => MergeOrderedN(p, parallelism))
-          }
+          doGetByRanges(s3Location, versionId, s3Headers, s3Meta.contentLength, rangeSize, parallelism)
        case None =>
-          Source.failed(throw new NoSuchElementException(s"Object does not exist at location [${s3Location.mkString}]"))
+          val exc = new NoSuchElementException(s"Object does not exist at location [${s3Location.mkString}]")
+          objectMetadataMat.failure(exc)
+          Source.failed(exc)
       }
       .mapError {
         case e: Throwable =>
@@ -279,6 +272,29 @@
       .mapMaterializedValue(_.flatMap(identity)(ExecutionContexts.parasitic))
   }
 
+  private def doGetByRanges(
+      s3Location: S3Location,
+      versionId: Option[String],
+      s3Headers: S3Headers,
+      contentLength: Long,
+      rangeSize: Long,
+      parallelism: Int
+  ): Source[ByteString, Any] = {
+    val byteRanges = computeByteRanges(contentLength, rangeSize)
+    if (byteRanges.size <= 1) {
+      getObject(s3Location, None, versionId, s3Headers)
+    } else {
+      Source(byteRanges)
+        .zipWithIndex
+        .flatMapMerge(parallelism, brToIdx => {
+          val (br, idx) = brToIdx
+          val endMarker = Source.single(ByteString(RangeEndMarker))
+          getObject(s3Location, Some(br), versionId, s3Headers).concat(endMarker).map(_ -> idx)
+        })
+        .statefulMapConcat(RangeMapConcat)
+    }
+  }
+
   private def computeByteRanges(contentLength: Long, rangeSize: Long): Seq[ByteRange] = {
     require(contentLength >= 0, s"contentLength ($contentLength) must be >= 0")
     require(rangeSize > 0, s"rangeSize ($rangeSize) must be > 0")
@@ -296,13 +312,55 @@ import scala.util.{Failure, Success, Try}
     }
   }
 
-  private def prepareRangeSources(
-      s3Location: S3Location,
-      versionId: Option[String],
-      s3Headers: S3Headers,
-      byteRanges: Seq[ByteRange]
-  ): Seq[Source[ByteString, Future[ObjectMetadata]]] =
-    byteRanges.map(br => getObject(s3Location, Some(br), versionId, s3Headers))
+  private val RangeMapConcat: () => ((ByteString, Long)) => IterableOnce[ByteString] = () => {
+    var currentRangeIdx = 0L
+    var completedRanges = Set.empty[Long]
+    var bufferByRangeIdx = Map.empty[Long, mutable.Queue[ByteString]]
+
+    val isEndMarker: ByteString => Boolean = bs => bs.size == RangeEndMarker.length && bs.utf8String == RangeEndMarker
+
+    def foldRangeBuffers(): Option[ByteString] = {
+      @tailrec
+      def innerFoldRangeBuffers(acc: Option[ByteString]): Option[ByteString] = {
+        bufferByRangeIdx.get(currentRangeIdx) match {
+          case None =>
+            if (completedRanges.contains(currentRangeIdx))
+              currentRangeIdx += 1
+            if (bufferByRangeIdx.contains(currentRangeIdx))
+              innerFoldRangeBuffers(acc)
+            else
+              acc
+          case Some(queue) =>
+            val next = queue.dequeueAll(_ => true).foldLeft(acc.getOrElse(ByteString.empty))(_ ++ _)
+            bufferByRangeIdx -= currentRangeIdx
+            if (completedRanges.contains(currentRangeIdx))
+              currentRangeIdx += 1
+            if (bufferByRangeIdx.contains(currentRangeIdx))
+              innerFoldRangeBuffers(Some(next))
+            else
+              Some(next)
+        }
+      }
+
+      innerFoldRangeBuffers(None)
+    }
+
+    bsToIdx => {
+      val (bs, idx) = bsToIdx
+      if (isEndMarker(bs)) {
+        completedRanges = completedRanges + idx
+        foldRangeBuffers().toList
+      } else if (idx == currentRangeIdx) {
+        bs :: Nil
+      } else {
+        bufferByRangeIdx = bufferByRangeIdx.updatedWith(idx) {
+          case Some(queue) => Some(queue.enqueue(bs))
+          case None => Some(mutable.Queue(bs))
+        }
+        Nil
+      }
+    }
+  }
 
   /**
    * An ADT that represents the current state of pagination
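
---

A worked example of the range computation in computeByteRanges, with
illustrative values only: for contentLength = 20 and rangeSize = 8 the loop
yields bytes=0-7, bytes=8-15 and the open-ended bytes=16-, so the final
request always absorbs the remainder (when contentLength <= rangeSize the
ranged path is skipped entirely and a plain getObject is issued, mirroring
the `byteRanges.size <= 1` branch):

    import akka.http.scaladsl.model.headers.ByteRange
    import akka.http.scaladsl.model.headers.ByteRange.FromOffset

    // Standalone sketch mirroring computeByteRanges; values are made up.
    val contentLength = 20L
    val rangeSize = 8L
    val ranges: Seq[ByteRange] =
      if (contentLength <= rangeSize) Nil // single request, no Range header
      else
        (0L until contentLength by rangeSize).map { i =>
          if (i + rangeSize >= contentLength) FromOffset(i) // last range: "bytes=16-"
          else ByteRange(i, i + rangeSize - 1) // bounded ranges: "bytes=0-7", "bytes=8-15"
        }
    // ranges == Vector(ByteRange(0, 7), ByteRange(8, 15), FromOffset(16))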